From 3879a8d0160e1536a7bb6c8f3c6b6cd0d00ef8ce Mon Sep 17 00:00:00 2001 From: Tejas Narayan Date: Tue, 23 Jun 2026 08:33:01 +0000 Subject: [PATCH 01/11] fix: updating security fixes --- .dockerignore | 47 ++++++++ .gitignore | 3 - app/config.py | 20 +++- app/core/security_headers_middleware.py | 37 +++++++ docker/Dockerfile.api | 46 ++++---- .../docs/getting-started/authentication.mdx | 9 +- .../test_security_headers_middleware.py | 104 ++++++++++++++++++ 7 files changed, 234 insertions(+), 32 deletions(-) create mode 100644 .dockerignore create mode 100644 tests/test_core/test_security_headers_middleware.py diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 00000000..8c597a4f --- /dev/null +++ b/.dockerignore @@ -0,0 +1,47 @@ +# Git and IDE +.git +.gitignore +.vscode +.idea + +# Python artifacts +**/__pycache__ +**/*.py[cod] +**/.pytest_cache +**/.mypy_cache +**/.ruff_cache +.venv +venv +env +newenv +*.egg-info +dist +build + +# Local config and secrets +.env +.env.* +config.yml +config.docker.yml +.encryption_key +.data +secrets +uploads + +# Frontend (built in Docker; avoid shipping host node_modules) +frontend/node_modules +frontend/dist + +# Docs and dev-only content +docs-fumadocs +*.md +!README.md + +# Tests +tests + +# Logs and temp files +*.log +logs +tmp +temp diff --git a/.gitignore b/.gitignore index a502d377..713908a1 100644 --- a/.gitignore +++ b/.gitignore @@ -54,9 +54,6 @@ uv.lock .encryption_key .data/ -# Docker -.dockerignore - # Uploads uploads/ *.wav diff --git a/app/config.py b/app/config.py index 5a1d7968..5747c69d 100644 --- a/app/config.py +++ b/app/config.py @@ -83,7 +83,25 @@ class Settings(BaseSettings): # Frontend FRONTEND_DIR: str = "./frontend/dist" - + + # Content Security Policy (Report-Only by default; set CSP_REPORT_ONLY=false to enforce) + CSP_ENABLED: bool = True + CSP_REPORT_ONLY: bool = True + CSP_POLICY: str = ( + "default-src 'self'; " + "script-src 'self'; " + "style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; " + "font-src 'self' https://fonts.gstatic.com; " + "img-src 'self' data: blob: https:; " + "connect-src 'self' wss: ws:; " + "media-src 'self' blob: https:; " + "frame-src 'self' blob:; " + "object-src 'none'; " + "base-uri 'self'; " + "form-action 'self'; " + "frame-ancestors 'self'" + ) + # SMTP / Email Notifications (for Alerts) SMTP_HOST: Optional[str] = None # e.g., "smtp.gmail.com" SMTP_PORT: int = 587 diff --git a/app/core/security_headers_middleware.py b/app/core/security_headers_middleware.py index 1d115019..caa726a9 100644 --- a/app/core/security_headers_middleware.py +++ b/app/core/security_headers_middleware.py @@ -6,6 +6,41 @@ from starlette.requests import Request from starlette.responses import Response +from app.config import settings + +_NO_STORE_CACHE = "no-cache, no-store, must-revalidate" +_ASSET_CACHE = "public, max-age=31536000, immutable" + + +def _apply_cache_control(request: Request, response: Response) -> None: + if "cache-control" in response.headers: + return + + path = request.url.path + if path.startswith("/assets/"): + response.headers["Cache-Control"] = _ASSET_CACHE + return + + cache_value = _NO_STORE_CACHE + if path.startswith("/api/"): + cache_value = f"{_NO_STORE_CACHE}, private" + + response.headers["Cache-Control"] = cache_value + response.headers["Pragma"] = "no-cache" + response.headers["Expires"] = "0" + + +def _apply_csp(response: Response) -> None: + if not settings.CSP_ENABLED: + return + + header_name = ( + "Content-Security-Policy-Report-Only" + if settings.CSP_REPORT_ONLY + else "Content-Security-Policy" + ) + response.headers[header_name] = settings.CSP_POLICY + class SecurityHeadersMiddleware(BaseHTTPMiddleware): """Add baseline security headers to every response.""" @@ -15,4 +50,6 @@ async def dispatch(self, request: Request, call_next) -> Response: response.headers["X-Frame-Options"] = "SAMEORIGIN" response.headers["X-Content-Type-Options"] = "nosniff" response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin" + _apply_cache_control(request, response) + _apply_csp(response) return response diff --git a/docker/Dockerfile.api b/docker/Dockerfile.api index 3a50fcf1..fb79cc57 100644 --- a/docker/Dockerfile.api +++ b/docker/Dockerfile.api @@ -1,10 +1,26 @@ +# ============================================================ +# Stage 1: Frontend build (Node only — not shipped to runtime) +# ============================================================ +FROM node:18-bookworm-slim AS frontend-builder + +WORKDIR /app/frontend + +COPY frontend/package.json frontend/package-lock.json* ./ +RUN --mount=type=cache,target=/root/.npm \ + npm ci --legacy-peer-deps || npm install --legacy-peer-deps + +COPY frontend/ ./ +RUN npm run build + +# ============================================================ +# Stage 2: API runtime (Python only — no node_modules / esbuild) +# ============================================================ FROM python:3.11-slim -# Install Node.js, npm, build tools, and WeasyPrint system libraries. +# Install build tools and WeasyPrint system libraries. # build-essential + cmake: packages like praat-parselmouth on arm64/Apple Silicon # libgobject/libpango/libcairo/libgdk-pixbuf/libffi/shared-mime-info: PDF rendering RUN apt-get update && apt-get install -y \ - curl \ build-essential \ cmake \ libgobject-2.0-0 \ @@ -14,15 +30,12 @@ RUN apt-get update && apt-get install -y \ libgdk-pixbuf-2.0-0 \ libffi-dev \ shared-mime-info \ - && curl -fsSL https://deb.nodesource.com/setup_18.x | bash - \ - && apt-get install -y nodejs \ && rm -rf /var/lib/apt/lists/* # Install uv via pip to get the glibc-linked binary (the musl-linked binary # from ghcr.io/astral-sh/uv has DNS resolution issues inside Docker) RUN pip install --no-cache-dir uv -# Set working directory WORKDIR /app # ============================================================ @@ -43,28 +56,13 @@ RUN if find /usr/local/lib/ -name "litellm_init.pth" 2>/dev/null | grep -q .; th fi # ============================================================ -# LAYER 2: Frontend dependencies (cached unless package.json changes) -# ============================================================ -COPY frontend/package.json frontend/package-lock.json* ./frontend/ -WORKDIR /app/frontend - -# Install all dependencies (including dev dependencies needed for build) -# Use --legacy-peer-deps to resolve peer dependency conflicts with HeroUI and Tailwind -RUN --mount=type=cache,target=/root/.npm \ - npm ci --legacy-peer-deps || npm install --legacy-peer-deps - -# ============================================================ -# LAYER 3: Frontend build (rebuilds on frontend code changes) +# LAYER 2: Frontend static assets (built in stage 1, dist only) # ============================================================ -COPY frontend/ ./ - -# Build frontend for production -RUN npm run build +COPY --from=frontend-builder /app/frontend/dist /app/frontend/dist # ============================================================ -# LAYER 4: Backend code (rebuilds on backend code changes) +# LAYER 3: Backend code (rebuilds on backend code changes) # ============================================================ -WORKDIR /app COPY src/ ./src/ COPY app/ ./app/ COPY scripts/ ./scripts/ @@ -74,8 +72,6 @@ RUN uv pip install --system -e . # Create uploads directory RUN mkdir -p /app/uploads -# Expose port EXPOSE 8000 -# Run the application CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/docs-fumadocs/content/docs/getting-started/authentication.mdx b/docs-fumadocs/content/docs/getting-started/authentication.mdx index 857c2896..3ee9bdc7 100644 --- a/docs-fumadocs/content/docs/getting-started/authentication.mdx +++ b/docs-fumadocs/content/docs/getting-started/authentication.mdx @@ -121,9 +121,12 @@ Rules the UI enforces: - Rotate `SECRET_KEY` to invalidate existing sessions. - Put the app behind a reverse proxy (Nginx, Caddy, Cloudflare) that terminates TLS and enforces HSTS. -- The bundled FastAPI server sends `X-Frame-Options: SAMEORIGIN` on all - responses. If you terminate traffic at an external reverse proxy, keep - that header (or stricter CSP `frame-ancestors`) enabled there too. +- The bundled FastAPI server sends baseline security headers on all + responses (`X-Frame-Options`, `X-Content-Type-Options`, `Referrer-Policy`, + `Cache-Control`, and `Content-Security-Policy-Report-Only` by default). If + you terminate traffic at an external reverse proxy, keep those headers (or + stricter CSP `frame-ancestors`) enabled there too. After reviewing CSP + violation reports, set `CSP_REPORT_ONLY=false` to enforce the policy. - Pin third-party observability container images to fixed tags and rebuild external reverse-proxy images on patched runtimes. If a scanner reports a Go stdlib CVE in a binary this repo does not build, identify the diff --git a/tests/test_core/test_security_headers_middleware.py b/tests/test_core/test_security_headers_middleware.py new file mode 100644 index 00000000..320e2163 --- /dev/null +++ b/tests/test_core/test_security_headers_middleware.py @@ -0,0 +1,104 @@ +"""Tests for HTTP security response headers.""" + +from __future__ import annotations + +from pathlib import Path + +import pytest +from fastapi import FastAPI +from fastapi.staticfiles import StaticFiles +from fastapi.testclient import TestClient + +from app.config import settings +from app.core.security_headers_middleware import SecurityHeadersMiddleware + + +@pytest.fixture +def security_client(tmp_path): + """Minimal app with SecurityHeadersMiddleware (avoids heavy conftest client).""" + app = FastAPI() + app.add_middleware(SecurityHeadersMiddleware) + + @app.get("/health") + def health(): + return {"status": "healthy"} + + @app.get("/api/v1/auth/config") + def auth_config(): + return {"providers": []} + + assets_dir = tmp_path / "assets" + assets_dir.mkdir() + (assets_dir / "app.js").write_text("console.log('ok');", encoding="utf-8") + app.mount("/assets", StaticFiles(directory=str(assets_dir)), name="assets") + + with TestClient(app) as client: + yield client + + +def test_health_includes_baseline_security_headers(security_client): + response = security_client.get("/health") + + assert response.status_code == 200 + assert response.headers["X-Content-Type-Options"] == "nosniff" + assert response.headers["Referrer-Policy"] == "strict-origin-when-cross-origin" + assert response.headers["X-Frame-Options"] == "SAMEORIGIN" + assert "no-store" in response.headers["Cache-Control"] + assert response.headers["Pragma"] == "no-cache" + assert response.headers["Expires"] == "0" + + +def test_api_routes_include_private_no_store_cache(security_client): + response = security_client.get("/api/v1/auth/config") + + assert response.status_code == 200 + assert "no-store" in response.headers["Cache-Control"] + assert "private" in response.headers["Cache-Control"] + + +def test_csp_report_only_header_when_enabled(security_client, monkeypatch): + monkeypatch.setattr(settings, "CSP_ENABLED", True) + monkeypatch.setattr(settings, "CSP_REPORT_ONLY", True) + + response = security_client.get("/health") + + assert "Content-Security-Policy-Report-Only" in response.headers + assert "Content-Security-Policy" not in response.headers + assert "default-src 'self'" in response.headers["Content-Security-Policy-Report-Only"] + + +def test_csp_enforcing_header_when_report_only_disabled(security_client, monkeypatch): + monkeypatch.setattr(settings, "CSP_ENABLED", True) + monkeypatch.setattr(settings, "CSP_REPORT_ONLY", False) + + response = security_client.get("/health") + + assert "Content-Security-Policy" in response.headers + assert "Content-Security-Policy-Report-Only" not in response.headers + + +def test_asset_routes_use_long_cache(security_client): + response = security_client.get("/assets/app.js") + + assert response.status_code == 200 + assert response.headers["Cache-Control"] == "public, max-age=31536000, immutable" + assert "Pragma" not in response.headers + + +@pytest.mark.skipif( + not (Path(settings.FRONTEND_DIR) / "assets").is_dir(), + reason="frontend dist assets not built in test environment", +) +def test_built_frontend_assets_use_long_cache(): + """Integration check when frontend/dist exists locally.""" + from app.main import create_app + + app = create_app() + assets_dir = Path(settings.FRONTEND_DIR) / "assets" + asset_file = next(assets_dir.iterdir()) + + with TestClient(app) as client: + response = client.get(f"/assets/{asset_file.name}") + + assert response.status_code == 200 + assert response.headers["Cache-Control"] == "public, max-age=31536000, immutable" From 461545190547d5b41ef7bee0d04a3ab96e33999c Mon Sep 17 00:00:00 2001 From: Tejas Narayan Date: Tue, 23 Jun 2026 10:47:32 +0000 Subject: [PATCH 02/11] fix: updating public health check access --- app/config.py | 11 + app/core/health.py | 25 +++ app/core/migration_middleware.py | 33 ++- app/core/operational_access_middleware.py | 115 ++++++++++ app/main.py | 37 ++-- config.yml.example | 8 + .../docs/getting-started/authentication.mdx | 8 + tests/test_core/test_operational_endpoints.py | 200 ++++++++++++++++++ 8 files changed, 402 insertions(+), 35 deletions(-) create mode 100644 app/core/health.py create mode 100644 app/core/operational_access_middleware.py create mode 100644 tests/test_core/test_operational_endpoints.py diff --git a/app/config.py b/app/config.py index 5747c69d..2c6e401b 100644 --- a/app/config.py +++ b/app/config.py @@ -102,6 +102,10 @@ class Settings(BaseSettings): "frame-ancestors 'self'" ) + # Operational endpoints (/health, /metrics) + OPERATIONAL_PUBLIC: bool = False + OPERATIONAL_TRUSTED_IPS: List[str] = [] + # SMTP / Email Notifications (for Alerts) SMTP_HOST: Optional[str] = None # e.g., "smtp.gmail.com" SMTP_PORT: int = 587 @@ -563,6 +567,13 @@ def load_config_from_file(config_path: str) -> None: if "csv_max_rows" in ja_cfg: settings.JUDGE_ALIGNMENT_CSV_MAX_ROWS = int(ja_cfg["csv_max_rows"]) + if "operational" in config_data: + operational_config = config_data["operational"] + if "public" in operational_config: + settings.OPERATIONAL_PUBLIC = bool(operational_config["public"]) + if "trusted_ips" in operational_config: + settings.OPERATIONAL_TRUSTED_IPS = operational_config["trusted_ips"] + # Update Celery URLs if they weren't explicitly set if not settings.CELERY_BROKER_URL: settings.CELERY_BROKER_URL = settings.REDIS_URL diff --git a/app/core/health.py b/app/core/health.py new file mode 100644 index 00000000..a8284778 --- /dev/null +++ b/app/core/health.py @@ -0,0 +1,25 @@ +"""Health check response helpers.""" + +from __future__ import annotations + +from app.core.migrations import check_migrations_status + + +def build_health_status(*, detailed: bool) -> tuple[dict, int]: + """Build health payload and HTTP status code.""" + is_up_to_date, pending = check_migrations_status() + + if is_up_to_date: + if detailed: + return {"status": "healthy", "migrations": "up_to_date"}, 200 + return {"status": "healthy"}, 200 + + if detailed: + return { + "status": "degraded", + "migrations": "pending", + "pending_migrations": pending, + "message": f"{len(pending)} migration(s) pending: {', '.join(pending)}", + }, 503 + + return {"status": "degraded"}, 503 diff --git a/app/core/migration_middleware.py b/app/core/migration_middleware.py index da3441cd..8873e838 100644 --- a/app/core/migration_middleware.py +++ b/app/core/migration_middleware.py @@ -3,9 +3,11 @@ Blocks API requests if migrations are pending. """ -from fastapi import Request, HTTPException, status +from fastapi import Request, status from fastapi.responses import JSONResponse from starlette.middleware.base import BaseHTTPMiddleware + +from app.config import settings from app.core.migrations import check_migrations_status import logging @@ -16,32 +18,30 @@ _migrations_up_to_date = False +def _migration_bypass_paths() -> list[str]: + paths = ["/health"] + if settings.DEBUG: + paths.extend(["/docs", "/redoc", "/openapi.json"]) + return paths + + class MigrationCheckMiddleware(BaseHTTPMiddleware): """ Middleware that blocks API requests if database migrations are pending. Allows health checks and migration-related endpoints to pass through. """ - - # Endpoints that should be allowed even if migrations are pending - ALLOWED_PATHS = [ - "/health", - "/docs", - "/redoc", - "/openapi.json", - ] - + async def dispatch(self, request: Request, call_next): - # Allow health checks and docs - if any(request.url.path.startswith(path) for path in self.ALLOWED_PATHS): + if any(request.url.path.startswith(path) for path in _migration_bypass_paths()): return await call_next(request) - + # Allow static assets (frontend) if request.url.path.startswith("/assets/"): return await call_next(request) - + # Check migration status is_up_to_date, pending = check_migrations_status() - + if not is_up_to_date: # Block API requests if migrations are pending if request.url.path.startswith("/api/"): @@ -56,6 +56,5 @@ async def dispatch(self, request: Request, call_next): "message": f"Application is starting up. {len(pending)} migration(s) need to be applied: {', '.join(pending)}" } ) - - return await call_next(request) + return await call_next(request) diff --git a/app/core/operational_access_middleware.py b/app/core/operational_access_middleware.py new file mode 100644 index 00000000..78b10bc3 --- /dev/null +++ b/app/core/operational_access_middleware.py @@ -0,0 +1,115 @@ +"""Restrict operational endpoints (/health, /metrics) from the public internet.""" + +from __future__ import annotations + +import ipaddress +import logging +from typing import Iterable + +from fastapi import HTTPException +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + +from app.config import settings +from app.core.auth.dependency import _resolve +from app.database import SessionLocal + +logger = logging.getLogger(__name__) + +_PROBE_USER_AGENTS = ("ELB-HealthChecker", "kube-probe") +_HEALTH_PATH = "/health" +_METRICS_PREFIX = "/metrics" + + +def _client_ip(request: Request) -> str | None: + forwarded = request.headers.get("x-forwarded-for") + if forwarded: + return forwarded.split(",")[0].strip() + if request.client: + return request.client.host + return None + + +def _ip_in_trusted(ip: str, trusted: Iterable[str]) -> bool: + try: + address = ipaddress.ip_address(ip) + except ValueError: + return False + + for entry in trusted: + entry = entry.strip() + if not entry: + continue + try: + if "/" in entry: + if address in ipaddress.ip_network(entry, strict=False): + return True + elif address == ipaddress.ip_address(entry): + return True + except ValueError: + logger.warning("Ignoring invalid operational trusted IP entry: %s", entry) + return False + + +def _has_authenticated_caller(request: Request) -> bool: + db = SessionLocal() + try: + principal = _resolve( + request.headers.get("authorization"), + request.headers.get("x-api-key"), + request.headers.get("x-efficientai-api-key"), + db, + ) + return principal is not None + except HTTPException: + return False + finally: + db.close() + + +def _is_health_probe(request: Request) -> bool: + user_agent = request.headers.get("user-agent", "") + if any(marker in user_agent for marker in _PROBE_USER_AGENTS): + return True + + if request.headers.get("x-forwarded-for"): + return False + + direct_ip = request.client.host if request.client else None + if direct_ip and _ip_in_trusted(direct_ip, settings.OPERATIONAL_TRUSTED_IPS): + return True + return False + + +def is_operational_access_allowed(request: Request) -> bool: + """Return True when the caller may access a protected operational endpoint.""" + if settings.OPERATIONAL_PUBLIC: + return True + + path = request.url.path + if path == _HEALTH_PATH and _is_health_probe(request): + return True + + client_ip = _client_ip(request) + if client_ip and _ip_in_trusted(client_ip, settings.OPERATIONAL_TRUSTED_IPS): + return True + + if _has_authenticated_caller(request): + return True + + return False + + +def is_operational_path(path: str) -> bool: + return path == _HEALTH_PATH or path.startswith(_METRICS_PREFIX) + + +class OperationalAccessMiddleware(BaseHTTPMiddleware): + """Block anonymous public access to /health and /metrics.""" + + async def dispatch(self, request: Request, call_next) -> Response: + if is_operational_path(request.url.path) and not is_operational_access_allowed(request): + return JSONResponse(status_code=404, content={"detail": "Not found"}) + + return await call_next(request) diff --git a/app/main.py b/app/main.py index c1477972..ee29a224 100644 --- a/app/main.py +++ b/app/main.py @@ -4,14 +4,17 @@ from contextlib import asynccontextmanager from pathlib import Path -from fastapi import FastAPI, Request +from fastapi import Depends, FastAPI, Request from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import FileResponse +from fastapi.responses import FileResponse, JSONResponse from fastapi.staticfiles import StaticFiles from app.config import load_config_from_file, settings, validate_auth_configuration +from app.core.auth.rbac import require_admin +from app.core.health import build_health_status from app.core.migration_middleware import MigrationCheckMiddleware from app.core.migrations import check_migrations_status, ensure_migrations_directory, run_migrations +from app.core.operational_access_middleware import OperationalAccessMiddleware from app.core.rbac_middleware import ReaderReadOnlyMiddleware from app.core.security_headers_middleware import SecurityHeadersMiddleware from app.database import init_db @@ -99,8 +102,9 @@ def create_app() -> FastAPI: title=settings.APP_NAME, version=settings.APP_VERSION, description="EfficientAI Voice AI Evaluation Platform API", - docs_url="/docs", - redoc_url="/redoc", + docs_url="/docs" if settings.DEBUG else None, + redoc_url="/redoc" if settings.DEBUG else None, + openapi_url="/openapi.json" if settings.DEBUG else None, lifespan=lifespan, ) @@ -120,6 +124,7 @@ def create_app() -> FastAPI: allow_headers=["*"], ) app.add_middleware(SecurityHeadersMiddleware) + app.add_middleware(OperationalAccessMiddleware) if settings.OBSERVABILITY_ENABLED: from prometheus_fastapi_instrumentator import Instrumentator @@ -137,20 +142,15 @@ def create_app() -> FastAPI: @app.get("/health") async def health_check(): - """ - Health check endpoint. - Returns migration status to help diagnose issues. - """ - is_up_to_date, pending = check_migrations_status() - - if is_up_to_date: - return {"status": "healthy", "migrations": "up_to_date"} - return { - "status": "degraded", - "migrations": "pending", - "pending_migrations": pending, - "message": f"{len(pending)} migration(s) pending: {', '.join(pending)}", - } + """Minimal health probe for load balancers (see OperationalAccessMiddleware).""" + payload, status_code = build_health_status(detailed=False) + return JSONResponse(content=payload, status_code=status_code) + + @app.get("/health/detail") + async def health_detail(_admin=Depends(require_admin)): + """Authenticated migration diagnostics for operators.""" + payload, status_code = build_health_status(detailed=True) + return JSONResponse(content=payload, status_code=status_code) frontend_dist = Path(settings.FRONTEND_DIR) if frontend_dist.exists() and frontend_dist.is_dir(): @@ -167,6 +167,7 @@ async def serve_frontend(full_path: str, request: Request): or full_path.startswith("redoc") or full_path.startswith("assets/") or full_path == "health" + or full_path == "health/detail" or full_path == "metrics" ): return {"detail": "Not found"} diff --git a/config.yml.example b/config.yml.example index fe461ae4..7cf5e492 100644 --- a/config.yml.example +++ b/config.yml.example @@ -12,6 +12,14 @@ server: host: "0.0.0.0" port: 8000 +# Operational endpoints (/health, /metrics) +# Keep public: false in production/sandbox. ALB health checks use the +# ELB-HealthChecker user-agent; Prometheus scrapers should use trusted_ips. +operational: + public: false + trusted_ips: + - "10.0.0.0/8" + # Database Configuration database: url: "postgresql://efficientai:password@localhost:5432/efficientai" diff --git a/docs-fumadocs/content/docs/getting-started/authentication.mdx b/docs-fumadocs/content/docs/getting-started/authentication.mdx index 3ee9bdc7..f45d45e9 100644 --- a/docs-fumadocs/content/docs/getting-started/authentication.mdx +++ b/docs-fumadocs/content/docs/getting-started/authentication.mdx @@ -132,6 +132,14 @@ Rules the UI enforces: a Go stdlib CVE in a binary this repo does not build, identify the flagged container or proxy artifact and upgrade it separately. - Restrict `cors.origins` to the exact domain(s) serving the SPA. +- Set `app.debug: false` in production so `/docs`, `/redoc`, and `/openapi.json` + are not served on the public hostname. +- Keep `operational.public: false` so `/health` and `/metrics` return **404** + to anonymous public clients (including vulnerability scanners hitting your ALB + hostname). AWS ALB target health checks still succeed via the + `ELB-HealthChecker` user-agent. Add your VPC CIDRs to + `operational.trusted_ips` for internal Prometheus scrapers. Full migration + diagnostics are available at `GET /health/detail` for org admins. --- diff --git a/tests/test_core/test_operational_endpoints.py b/tests/test_core/test_operational_endpoints.py new file mode 100644 index 00000000..3e766344 --- /dev/null +++ b/tests/test_core/test_operational_endpoints.py @@ -0,0 +1,200 @@ +"""Tests for operational endpoint access controls.""" + +from __future__ import annotations + +import sys +from types import SimpleNamespace + +import pytest +from fastapi import APIRouter, FastAPI +from fastapi.responses import JSONResponse +from fastapi.testclient import TestClient + +from app.config import settings +from app.core.health import build_health_status +from app.core.migration_middleware import MigrationCheckMiddleware +from app.core.operational_access_middleware import OperationalAccessMiddleware + + +@pytest.fixture +def operational_client(monkeypatch): + monkeypatch.setattr(settings, "OPERATIONAL_PUBLIC", False) + monkeypatch.setattr(settings, "OPERATIONAL_TRUSTED_IPS", ["10.0.0.0/8"]) + monkeypatch.setattr(settings, "DEBUG", True) + + app = FastAPI() + app.add_middleware(MigrationCheckMiddleware) + app.add_middleware(OperationalAccessMiddleware) + + @app.get("/health") + def health(): + payload, status_code = build_health_status(detailed=False) + return JSONResponse(content=payload, status_code=status_code) + + @app.get("/metrics") + def metrics(): + return "metrics" + + with TestClient(app) as client: + yield client + + +def test_public_health_via_alb_is_blocked(operational_client): + response = operational_client.get( + "/health", + headers={"X-Forwarded-For": "203.0.113.1", "User-Agent": "SecurityScanner/1.0"}, + ) + + assert response.status_code == 404 + + +def test_alb_health_probe_is_allowed(operational_client): + response = operational_client.get( + "/health", + headers={"User-Agent": "ELB-HealthChecker/2.0"}, + ) + + assert response.status_code == 200 + assert response.json() == {"status": "healthy"} + + +def test_health_allowed_for_trusted_client_ip(operational_client): + response = operational_client.get( + "/health", + headers={"X-Forwarded-For": "10.1.2.3", "User-Agent": "Mozilla/5.0"}, + ) + + assert response.status_code == 200 + assert set(response.json().keys()) == {"status"} + + +def test_metrics_blocked_for_public_client(operational_client): + response = operational_client.get( + "/metrics", + headers={"X-Forwarded-For": "203.0.113.1"}, + ) + + assert response.status_code == 404 + + +def test_metrics_allowed_for_trusted_client_ip(operational_client): + response = operational_client.get( + "/metrics", + headers={"X-Forwarded-For": "10.1.2.3"}, + ) + + assert response.status_code == 200 + assert response.json() == "metrics" + + +def test_operational_public_allows_anonymous_health(monkeypatch): + monkeypatch.setattr(settings, "OPERATIONAL_PUBLIC", True) + monkeypatch.setattr(settings, "OPERATIONAL_TRUSTED_IPS", []) + + app = FastAPI() + app.add_middleware(OperationalAccessMiddleware) + + @app.get("/health") + def health(): + return {"status": "healthy"} + + with TestClient(app) as client: + response = client.get( + "/health", + headers={"X-Forwarded-For": "203.0.113.1"}, + ) + + assert response.status_code == 200 + + +def test_docs_not_registered_when_debug_disabled(monkeypatch): + monkeypatch.setitem( + sys.modules, + "app.api.v1.api", + SimpleNamespace(api_router=APIRouter()), + ) + monkeypatch.setattr(settings, "FRONTEND_DIR", "__missing_frontend__") + monkeypatch.setattr(settings, "OBSERVABILITY_ENABLED", False) + + from app.main import create_app + + monkeypatch.setattr(settings, "DEBUG", False) + app = create_app() + route_paths = {getattr(route, "path", None) for route in app.routes} + + assert "/docs" not in route_paths + assert "/redoc" not in route_paths + assert "/openapi.json" not in route_paths + + +def test_build_health_status_minimal_excludes_migration_details(monkeypatch): + monkeypatch.setattr( + "app.core.health.check_migrations_status", + lambda: (False, ["033_add_workspaces.sql"]), + ) + + payload, status_code = build_health_status(detailed=False) + + assert status_code == 503 + assert payload == {"status": "degraded"} + assert "pending_migrations" not in payload + + +def test_build_health_status_detailed_includes_migration_details(monkeypatch): + monkeypatch.setattr( + "app.core.health.check_migrations_status", + lambda: (False, ["033_add_workspaces.sql"]), + ) + + payload, status_code = build_health_status(detailed=True) + + assert status_code == 503 + assert payload["status"] == "degraded" + assert payload["pending_migrations"] == ["033_add_workspaces.sql"] + + +def test_health_detail_returns_migration_info_for_admin(monkeypatch): + monkeypatch.setitem( + sys.modules, + "app.api.v1.api", + SimpleNamespace(api_router=APIRouter()), + ) + monkeypatch.setattr(settings, "OPERATIONAL_PUBLIC", True) + monkeypatch.setattr(settings, "DEBUG", False) + monkeypatch.setattr(settings, "FRONTEND_DIR", "__missing_frontend__") + monkeypatch.setattr(settings, "OBSERVABILITY_ENABLED", False) + monkeypatch.setattr( + "app.core.health.check_migrations_status", + lambda: (True, []), + ) + + from app.core.auth.rbac import require_admin + from app.main import create_app + + app = create_app() + app.dependency_overrides[require_admin] = lambda: object() + + with TestClient(app) as client: + response = client.get("/health/detail") + + assert response.status_code == 200 + assert response.json()["migrations"] == "up_to_date" + + +def test_health_detail_requires_authentication_via_create_app(monkeypatch): + monkeypatch.setitem( + sys.modules, + "app.api.v1.api", + SimpleNamespace(api_router=APIRouter()), + ) + monkeypatch.setattr(settings, "OPERATIONAL_PUBLIC", True) + monkeypatch.setattr(settings, "DEBUG", False) + monkeypatch.setattr(settings, "FRONTEND_DIR", "__missing_frontend__") + monkeypatch.setattr(settings, "OBSERVABILITY_ENABLED", False) + + from app.main import create_app + + with TestClient(create_app()) as client: + response = client.get("/health/detail") + + assert response.status_code == 401 From 54758611882326f5c981ca0bd353c7bd44a49087 Mon Sep 17 00:00:00 2001 From: Tejas Narayan Date: Tue, 23 Jun 2026 11:18:23 +0000 Subject: [PATCH 03/11] fix: failing test cases --- tests/test_core/test_operational_endpoints.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_core/test_operational_endpoints.py b/tests/test_core/test_operational_endpoints.py index 3e766344..911aa68b 100644 --- a/tests/test_core/test_operational_endpoints.py +++ b/tests/test_core/test_operational_endpoints.py @@ -21,6 +21,10 @@ def operational_client(monkeypatch): monkeypatch.setattr(settings, "OPERATIONAL_PUBLIC", False) monkeypatch.setattr(settings, "OPERATIONAL_TRUSTED_IPS", ["10.0.0.0/8"]) monkeypatch.setattr(settings, "DEBUG", True) + monkeypatch.setattr( + "app.core.health.check_migrations_status", + lambda: (True, []), + ) app = FastAPI() app.add_middleware(MigrationCheckMiddleware) From fb86bd2c5ac96ebe64908092ec00f42205100fc2 Mon Sep 17 00:00:00 2001 From: Tejas Narayan Date: Tue, 23 Jun 2026 12:13:42 +0000 Subject: [PATCH 04/11] feat: updating IP security --- app/core/operational_access_middleware.py | 55 +++++----- config.yml.example | 5 +- .../docs/getting-started/authentication.mdx | 8 +- tests/test_core/test_operational_endpoints.py | 101 ++++++++++++++---- 4 files changed, 119 insertions(+), 50 deletions(-) diff --git a/app/core/operational_access_middleware.py b/app/core/operational_access_middleware.py index 78b10bc3..1d53c9ad 100644 --- a/app/core/operational_access_middleware.py +++ b/app/core/operational_access_middleware.py @@ -17,15 +17,11 @@ logger = logging.getLogger(__name__) -_PROBE_USER_AGENTS = ("ELB-HealthChecker", "kube-probe") _HEALTH_PATH = "/health" _METRICS_PREFIX = "/metrics" -def _client_ip(request: Request) -> str | None: - forwarded = request.headers.get("x-forwarded-for") - if forwarded: - return forwarded.split(",")[0].strip() +def _peer_ip(request: Request) -> str | None: if request.client: return request.client.host return None @@ -52,6 +48,33 @@ def _ip_in_trusted(ip: str, trusted: Iterable[str]) -> bool: return False +def _resolved_trusted_ip(request: Request) -> str | None: + """Resolve the client IP that may be matched against OPERATIONAL_TRUSTED_IPS. + + - Direct connections (no X-Forwarded-For): use the TCP peer address. This + covers ALB/kube health checks that connect from a VPC address. + - Proxied connections: only honor X-Forwarded-For when the TCP peer is + itself in the trusted list (known load balancer). Use the rightmost hop, + which is the address appended by that proxy — not the leftmost hop, which + can be client-supplied and spoofed. + """ + peer = _peer_ip(request) + if peer is None: + return None + + forwarded = request.headers.get("x-forwarded-for") + if not forwarded: + return peer + + if not _ip_in_trusted(peer, settings.OPERATIONAL_TRUSTED_IPS): + return None + + hops = [hop.strip() for hop in forwarded.split(",") if hop.strip()] + if not hops: + return peer + return hops[-1] + + def _has_authenticated_caller(request: Request) -> bool: db = SessionLocal() try: @@ -68,31 +91,13 @@ def _has_authenticated_caller(request: Request) -> bool: db.close() -def _is_health_probe(request: Request) -> bool: - user_agent = request.headers.get("user-agent", "") - if any(marker in user_agent for marker in _PROBE_USER_AGENTS): - return True - - if request.headers.get("x-forwarded-for"): - return False - - direct_ip = request.client.host if request.client else None - if direct_ip and _ip_in_trusted(direct_ip, settings.OPERATIONAL_TRUSTED_IPS): - return True - return False - - def is_operational_access_allowed(request: Request) -> bool: """Return True when the caller may access a protected operational endpoint.""" if settings.OPERATIONAL_PUBLIC: return True - path = request.url.path - if path == _HEALTH_PATH and _is_health_probe(request): - return True - - client_ip = _client_ip(request) - if client_ip and _ip_in_trusted(client_ip, settings.OPERATIONAL_TRUSTED_IPS): + resolved_ip = _resolved_trusted_ip(request) + if resolved_ip and _ip_in_trusted(resolved_ip, settings.OPERATIONAL_TRUSTED_IPS): return True if _has_authenticated_caller(request): diff --git a/config.yml.example b/config.yml.example index 7cf5e492..3d763a9d 100644 --- a/config.yml.example +++ b/config.yml.example @@ -13,8 +13,9 @@ server: port: 8000 # Operational endpoints (/health, /metrics) -# Keep public: false in production/sandbox. ALB health checks use the -# ELB-HealthChecker user-agent; Prometheus scrapers should use trusted_ips. +# Keep public: false in production/sandbox. ALB health checks connect +# directly from VPC addresses (no X-Forwarded-For). Include your VPC/LB CIDRs +# in trusted_ips so probes and internal scrapers are allowed. operational: public: false trusted_ips: diff --git a/docs-fumadocs/content/docs/getting-started/authentication.mdx b/docs-fumadocs/content/docs/getting-started/authentication.mdx index f45d45e9..f0c90605 100644 --- a/docs-fumadocs/content/docs/getting-started/authentication.mdx +++ b/docs-fumadocs/content/docs/getting-started/authentication.mdx @@ -136,10 +136,10 @@ Rules the UI enforces: are not served on the public hostname. - Keep `operational.public: false` so `/health` and `/metrics` return **404** to anonymous public clients (including vulnerability scanners hitting your ALB - hostname). AWS ALB target health checks still succeed via the - `ELB-HealthChecker` user-agent. Add your VPC CIDRs to - `operational.trusted_ips` for internal Prometheus scrapers. Full migration - diagnostics are available at `GET /health/detail` for org admins. + hostname). AWS ALB target health checks connect directly from VPC addresses + (no `X-Forwarded-For`); include your VPC/LB CIDRs in + `operational.trusted_ips`. Full migration diagnostics are available at + `GET /health/detail` for org admins. --- diff --git a/tests/test_core/test_operational_endpoints.py b/tests/test_core/test_operational_endpoints.py index 911aa68b..4fd116f6 100644 --- a/tests/test_core/test_operational_endpoints.py +++ b/tests/test_core/test_operational_endpoints.py @@ -9,11 +9,34 @@ from fastapi import APIRouter, FastAPI from fastapi.responses import JSONResponse from fastapi.testclient import TestClient +from starlette.requests import Request from app.config import settings from app.core.health import build_health_status from app.core.migration_middleware import MigrationCheckMiddleware -from app.core.operational_access_middleware import OperationalAccessMiddleware +from app.core.operational_access_middleware import ( + OperationalAccessMiddleware, + is_operational_access_allowed, +) + + +def _request( + *, + client_host: str, + headers: list[tuple[bytes, bytes]] | None = None, + path: str = "/health", +) -> Request: + scope = { + "type": "http", + "http_version": "1.1", + "method": "GET", + "path": path, + "headers": headers or [], + "client": (client_host, 12345), + "scheme": "http", + "server": ("testserver", 80), + } + return Request(scope) @pytest.fixture @@ -52,24 +75,61 @@ def test_public_health_via_alb_is_blocked(operational_client): assert response.status_code == 404 -def test_alb_health_probe_is_allowed(operational_client): - response = operational_client.get( - "/health", - headers={"User-Agent": "ELB-HealthChecker/2.0"}, +def test_spoofed_user_agent_does_not_bypass_gate(monkeypatch): + monkeypatch.setattr(settings, "OPERATIONAL_PUBLIC", False) + monkeypatch.setattr(settings, "OPERATIONAL_TRUSTED_IPS", ["10.0.0.0/8"]) + + request = _request( + client_host="203.0.113.1", + headers=[(b"user-agent", b"ELB-HealthChecker/2.0")], ) - assert response.status_code == 200 - assert response.json() == {"status": "healthy"} + assert is_operational_access_allowed(request) is False -def test_health_allowed_for_trusted_client_ip(operational_client): - response = operational_client.get( - "/health", - headers={"X-Forwarded-For": "10.1.2.3", "User-Agent": "Mozilla/5.0"}, +def test_direct_trusted_peer_without_xff_allowed(monkeypatch): + monkeypatch.setattr(settings, "OPERATIONAL_PUBLIC", False) + monkeypatch.setattr(settings, "OPERATIONAL_TRUSTED_IPS", ["10.0.0.0/8"]) + + request = _request(client_host="10.1.2.3") + + assert is_operational_access_allowed(request) is True + + +def test_health_allowed_for_trusted_client_ip_via_proxy(monkeypatch): + monkeypatch.setattr(settings, "OPERATIONAL_PUBLIC", False) + monkeypatch.setattr(settings, "OPERATIONAL_TRUSTED_IPS", ["10.0.0.0/8"]) + + request = _request( + client_host="10.0.0.50", + headers=[(b"x-forwarded-for", b"10.1.2.3")], ) - assert response.status_code == 200 - assert set(response.json().keys()) == {"status"} + assert is_operational_access_allowed(request) is True + + +def test_public_client_via_trusted_proxy_is_blocked(monkeypatch): + monkeypatch.setattr(settings, "OPERATIONAL_PUBLIC", False) + monkeypatch.setattr(settings, "OPERATIONAL_TRUSTED_IPS", ["10.0.0.0/8"]) + + request = _request( + client_host="10.0.0.50", + headers=[(b"x-forwarded-for", b"203.0.113.1")], + ) + + assert is_operational_access_allowed(request) is False + + +def test_spoofed_leftmost_xff_without_trusted_peer_is_blocked(monkeypatch): + monkeypatch.setattr(settings, "OPERATIONAL_PUBLIC", False) + monkeypatch.setattr(settings, "OPERATIONAL_TRUSTED_IPS", ["10.0.0.0/8"]) + + request = _request( + client_host="203.0.113.1", + headers=[(b"x-forwarded-for", b"10.0.0.1")], + ) + + assert is_operational_access_allowed(request) is False def test_metrics_blocked_for_public_client(operational_client): @@ -81,14 +141,17 @@ def test_metrics_blocked_for_public_client(operational_client): assert response.status_code == 404 -def test_metrics_allowed_for_trusted_client_ip(operational_client): - response = operational_client.get( - "/metrics", - headers={"X-Forwarded-For": "10.1.2.3"}, +def test_metrics_allowed_for_trusted_client_ip(operational_client, monkeypatch): + monkeypatch.setattr(settings, "OPERATIONAL_PUBLIC", False) + monkeypatch.setattr(settings, "OPERATIONAL_TRUSTED_IPS", ["10.0.0.0/8"]) + + request = _request( + client_host="10.0.0.50", + path="/metrics", + headers=[(b"x-forwarded-for", b"10.1.2.3")], ) - assert response.status_code == 200 - assert response.json() == "metrics" + assert is_operational_access_allowed(request) is True def test_operational_public_allows_anonymous_health(monkeypatch): From ed83afe50c27a9e0b896dce51fdc72bc9d5b0123 Mon Sep 17 00:00:00 2001 From: Tejas Narayan Date: Fri, 26 Jun 2026 10:55:18 +0000 Subject: [PATCH 05/11] feat: updating flex price features --- app/api/v1/routes/auth.py | 10 +- app/api/v1/routes/call_import_evaluations.py | 14 +- app/api/v1/routes/call_imports.py | 48 +- app/api/v1/routes/chat.py | 17 +- app/api/v1/routes/evaluations.py | 13 +- app/api/v1/routes/evaluators.py | 16 +- app/api/v1/routes/judge_alignment.py | 29 +- app/api/v1/routes/metrics.py | 25 +- app/api/v1/routes/observability.py | 21 +- app/api/v1/routes/playground.py | 41 +- app/api/v1/routes/prompt_optimization.py | 16 +- app/api/v1/routes/public_blind_test.py | 15 +- app/api/v1/routes/test_agents.py | 24 +- app/api/v1/routes/voice_playground.py | 33 +- app/config.py | 14 + app/core/auth/oidc_common.py | 10 +- app/services/billing/__init__.py | 1 + app/services/billing/flexprice_service.py | 761 ++++++++++++++++++ app/services/evaluation/evaluation_service.py | 8 + app/services/organization_provisioning.py | 11 + app/workers/tasks/evaluate_call_import_row.py | 14 + app/workers/tasks/process_call_import_row.py | 9 + app/workers/tasks/process_evaluator_result.py | 21 + app/workers/tasks/run_evaluator.py | 9 + app/workers/tasks/run_judge_alignment.py | 12 + app/workers/tasks/run_prompt_optimization.py | 12 + app/workers/tasks/tts_comparison.py | 12 + app/workers/tasks/tts_report.py | 9 + config.yml.example | 8 + pyproject.toml | 1 + scripts/create_api_key.py | 9 +- .../test_api/test_voice_playground_routes.py | 96 +++ tests/test_services/test_flexprice_service.py | 229 ++++++ 33 files changed, 1543 insertions(+), 25 deletions(-) create mode 100644 app/services/billing/__init__.py create mode 100644 app/services/billing/flexprice_service.py create mode 100644 tests/test_services/test_flexprice_service.py diff --git a/app/api/v1/routes/auth.py b/app/api/v1/routes/auth.py index b813e33f..97dbc838 100644 --- a/app/api/v1/routes/auth.py +++ b/app/api/v1/routes/auth.py @@ -50,7 +50,10 @@ RoleEnum, User, ) -from app.services.organization_provisioning import provision_default_workspace +from app.services.organization_provisioning import ( + provision_billing_customer, + provision_default_workspace, +) router = APIRouter(prefix="/auth", tags=["Authentication"]) @@ -324,6 +327,11 @@ def signup(payload: SignupRequest, db: Session = Depends(get_db)) -> TokenRespon organization_id=organization.id, created_by_user_id=user.id, ) + provision_billing_customer( + organization_id=organization.id, + name=org_name, + email=payload.email, + ) user.last_login_at = datetime.now(timezone.utc) db.commit() db.refresh(user) diff --git a/app/api/v1/routes/call_import_evaluations.py b/app/api/v1/routes/call_import_evaluations.py index 33562c96..5e3b9504 100644 --- a/app/api/v1/routes/call_import_evaluations.py +++ b/app/api/v1/routes/call_import_evaluations.py @@ -15,7 +15,7 @@ from datetime import date, datetime, timedelta, timezone -from fastapi import APIRouter, Body, Depends, HTTPException, Query, Response, status +from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException, Query, Response, status from fastapi.responses import StreamingResponse from loguru import logger from pydantic import BaseModel, Field, field_validator @@ -32,6 +32,7 @@ get_workspace_id, require_enterprise_feature, ) +from app.services.billing.flexprice_service import record_call_import_evaluation_started from app.services.workspace_rbac import resolve_workspace_capabilities from app.models.database import ( AIProvider, @@ -557,6 +558,7 @@ def _rollup_evaluation_status(evaluation: CallImportEvaluation, db: Session) -> async def create_call_import_evaluation( call_import_id: UUID, payload: CallImportEvaluationCreate, + background_tasks: BackgroundTasks, api_key: str = Depends(get_api_key), organization_id: UUID = Depends(get_organization_id), db: Session = Depends(get_db), @@ -1090,6 +1092,16 @@ def _name_for_source(source: str) -> Optional[str]: db.commit() for evaluation in created_evaluations: db.refresh(evaluation) + if evaluation.status == "running": + background_tasks.add_task( + record_call_import_evaluation_started, + organization_id, + evaluation.id, + workspace_id=call_import.workspace_id, + call_import_id=call_import.id, + row_count=evaluation.total_rows, + metric_count=len(leaf_metric_ids), + ) return _serialize_eval( db, primary_evaluation, sibling_evaluation_ids=sibling_ids ) diff --git a/app/api/v1/routes/call_imports.py b/app/api/v1/routes/call_imports.py index 8be373a8..538ae494 100644 --- a/app/api/v1/routes/call_imports.py +++ b/app/api/v1/routes/call_imports.py @@ -19,7 +19,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple from uuid import UUID -from fastapi import APIRouter, Body, Depends, File, Form, HTTPException, Query, Response, UploadFile, status +from fastapi import APIRouter, Body, BackgroundTasks, Depends, File, Form, HTTPException, Query, Response, UploadFile, status from loguru import logger from sqlalchemy import desc, func, or_ from sqlalchemy.orm import Session @@ -32,6 +32,10 @@ get_workspace_id, require_enterprise_feature, ) +from app.services.billing.flexprice_service import ( + record_call_import_batch_created, + record_call_import_row_imported, +) from app.models.database import ( CallImport, CallImportRow, @@ -1556,6 +1560,7 @@ async def update_call_import_mapping( async def start_call_import( call_import_id: UUID, payload: CallImportStartRequest, + background_tasks: BackgroundTasks, api_key: str = Depends(get_api_key), organization_id: UUID = Depends(get_organization_id), workspace_id: UUID = Depends(get_workspace_id), @@ -1682,6 +1687,16 @@ async def start_call_import( db.commit() db.refresh(call_import) + background_tasks.add_task( + record_call_import_batch_created, + organization_id, + call_import.id, + workspace_id=workspace_id, + total_rows=call_import.total_rows, + source="csv", + provider=call_import.provider, + ) + _enqueue_row_tasks(db, call_import, row_models) return CallImportUploadResponse( @@ -1705,6 +1720,7 @@ async def start_call_import( deprecated=True, ) async def upload_call_import_csv( + background_tasks: BackgroundTasks, file: UploadFile = File(...), provider: Optional[str] = Form( None, @@ -1886,6 +1902,16 @@ async def upload_call_import_csv( db.commit() db.refresh(call_import) + background_tasks.add_task( + record_call_import_batch_created, + organization_id, + call_import.id, + workspace_id=workspace_id, + total_rows=call_import.total_rows, + source="csv", + provider=call_import.provider, + ) + _enqueue_row_tasks(db, call_import, row_models) return CallImportUploadResponse( @@ -1908,6 +1934,7 @@ async def upload_call_import_csv( operation_id="uploadCallImportAudio", ) async def upload_call_import_audio( + background_tasks: BackgroundTasks, files: List[UploadFile] = File( ..., description="One or more manual call recording audio files.", @@ -1999,6 +2026,7 @@ async def upload_call_import_audio( ) total_size = sum(len(item["contents"]) for item in prepared) uploaded_keys: List[str] = [] + imported_row_ids: List[UUID] = [] from app.services.storage.s3_service import s3_service @@ -2058,6 +2086,7 @@ async def upload_call_import_audio( row.recording_s3_key = key row.recording_content_type = item["content_type"] row.recording_size_bytes = len(item["contents"]) + imported_row_ids.append(row.id) db.commit() except Exception as exc: @@ -2076,6 +2105,23 @@ async def upload_call_import_audio( ) from exc db.refresh(call_import) + background_tasks.add_task( + record_call_import_batch_created, + organization_id, + call_import.id, + workspace_id=workspace_id, + total_rows=call_import.total_rows, + source="audio", + provider=None, + ) + for row_id in imported_row_ids: + background_tasks.add_task( + record_call_import_row_imported, + organization_id, + row_id, + workspace_id=workspace_id, + call_import_id=call_import.id, + ) return CallImportUploadResponse( id=call_import.id, total_rows=call_import.total_rows, diff --git a/app/api/v1/routes/chat.py b/app/api/v1/routes/chat.py index b25d731e..a8317844 100644 --- a/app/api/v1/routes/chat.py +++ b/app/api/v1/routes/chat.py @@ -2,14 +2,15 @@ Chat/Inference API Routes For generating responses from AI models """ -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, status from sqlalchemy.orm import Session from typing import List, Dict, Optional, Any -from uuid import UUID +from uuid import UUID, uuid4 from pydantic import BaseModel -from app.dependencies import get_db, get_organization_id +from app.dependencies import get_db, get_organization_id, get_workspace_id from app.services.ai.llm_service import llm_service +from app.services.billing.flexprice_service import record_chat_completion from app.models.schemas import ModelProvider router = APIRouter(prefix="/chat", tags=["chat"]) @@ -39,7 +40,9 @@ class ChatResponse(BaseModel): @router.post("/completion", response_model=ChatResponse) async def chat_completion( request: ChatRequest, + background_tasks: BackgroundTasks, organization_id: UUID = Depends(get_organization_id), + workspace_id: UUID = Depends(get_workspace_id), db: Session = Depends(get_db) ): """Generate a chat completion using the specified AI provider and model.""" @@ -59,6 +62,14 @@ async def chat_completion( task_defaults={"temperature": 0.7}, ) + background_tasks.add_task( + record_chat_completion, + organization_id, + uuid4(), + workspace_id=workspace_id, + model=result.get("model", request.model), + ) + return ChatResponse( text=result.get("text", ""), model=result.get("model", request.model), diff --git a/app/api/v1/routes/evaluations.py b/app/api/v1/routes/evaluations.py index 840d3803..10b5ea3d 100644 --- a/app/api/v1/routes/evaluations.py +++ b/app/api/v1/routes/evaluations.py @@ -1,6 +1,6 @@ """Evaluation routes.""" -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException from sqlalchemy.orm import Session from typing import List from uuid import UUID @@ -13,6 +13,7 @@ EvaluationStatusResponse, MessageResponse, ) +from app.services.billing.flexprice_service import record_evaluation_created from app.workers.celery_app import process_evaluation_task from app.core.exceptions import EvaluationNotFoundError @@ -22,6 +23,7 @@ @router.post("/create", response_model=EvaluationResponse, status_code=201) def create_evaluation( evaluation_data: EvaluationCreate, + background_tasks: BackgroundTasks, api_key: str = Depends(get_api_key), organization_id: UUID = Depends(get_organization_id), workspace_id: UUID = Depends(get_workspace_id), @@ -55,6 +57,15 @@ def create_evaluation( db.commit() db.refresh(evaluation) + background_tasks.add_task( + record_evaluation_created, + organization_id, + evaluation.id, + workspace_id=workspace_id, + audio_id=evaluation.audio_id, + metrics_requested=len(evaluation.metrics_requested or []), + ) + # Queue async task process_evaluation_task.delay(str(evaluation.id)) diff --git a/app/api/v1/routes/evaluators.py b/app/api/v1/routes/evaluators.py index 610ba16d..0e13b1f8 100644 --- a/app/api/v1/routes/evaluators.py +++ b/app/api/v1/routes/evaluators.py @@ -1,10 +1,10 @@ """Evaluator routes.""" -from fastapi import APIRouter, Depends, HTTPException, status, Query +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, status, Query from fastapi.responses import JSONResponse, Response from sqlalchemy.orm import Session from sqlalchemy import and_ -from uuid import UUID +from uuid import UUID, uuid4 import random from typing import List from pydantic import BaseModel @@ -12,6 +12,7 @@ from app.database import get_db from app.dependencies import get_organization_id, get_workspace_id, get_api_key +from app.services.billing.flexprice_service import record_evaluator_run_requested from app.models.database import Evaluator, Agent, Persona, Scenario, EvaluatorResult, EvaluatorResultStatus, VoiceBundle, Metric from app.models.schemas import ( EvaluatorCreate, @@ -548,6 +549,7 @@ def delete_evaluator( @router.post("/run", response_model=RunEvaluatorsResponse, status_code=200) def run_evaluators( request: RunEvaluatorsRequest, + background_tasks: BackgroundTasks, organization_id: UUID = Depends(get_organization_id), workspace_id: UUID = Depends(get_workspace_id), db: Session = Depends(get_db), @@ -649,7 +651,15 @@ def run_evaluators( if not task_ids: raise HTTPException(status_code=500, detail="Failed to create any tasks") - + + background_tasks.add_task( + record_evaluator_run_requested, + organization_id, + uuid4(), + workspace_id=workspace_id, + quantity=len(request.evaluator_ids), + ) + return RunEvaluatorsResponse( task_ids=task_ids, evaluator_results=evaluator_results diff --git a/app/api/v1/routes/judge_alignment.py b/app/api/v1/routes/judge_alignment.py index 9fab1ee6..1bb96006 100644 --- a/app/api/v1/routes/judge_alignment.py +++ b/app/api/v1/routes/judge_alignment.py @@ -16,7 +16,7 @@ from typing import Any, Dict, List, Optional from uuid import UUID -from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status +from fastapi import APIRouter, BackgroundTasks, Depends, File, Form, HTTPException, UploadFile, status from loguru import logger from pydantic import BaseModel, ConfigDict, Field from sqlalchemy.orm import Session @@ -29,6 +29,7 @@ get_workspace_id, Principal, ) +from app.services.billing.flexprice_service import record_judge_alignment_run_started from app.models.database import ( Agent, Evaluator, @@ -617,6 +618,7 @@ def list_runs( def trigger_judge_run( dataset_id: UUID, body: JudgeRunCreate, + background_tasks: BackgroundTasks, principal: Principal = Depends(get_principal), workspace_id: UUID = Depends(get_workspace_id), db: Session = Depends(get_db), @@ -662,6 +664,19 @@ def trigger_judge_run( ), ) + from app.services.judge_alignment.judge_runner import select_samples_for_split + + sample_id_strs = ( + [str(s) for s in (body.sample_ids or [])] if body.split != "all" else None + ) + selected_samples = select_samples_for_split( + dataset_id=dataset.id, + split=body.split, + db=db, + sample_ids=sample_id_strs, + ) + sample_count = len(selected_samples) + judge_run = JudgeRun( dataset_id=dataset.id, organization_id=organization_id, @@ -679,10 +694,6 @@ def trigger_judge_run( db.commit() db.refresh(judge_run) - sample_id_strs = ( - [str(s) for s in (body.sample_ids or [])] if body.split != "all" else None - ) - try: from app.workers.tasks.run_judge_alignment import run_judge_alignment_task @@ -692,6 +703,14 @@ def trigger_judge_run( judge_run.celery_task_id = async_result.id db.commit() db.refresh(judge_run) + background_tasks.add_task( + record_judge_alignment_run_started, + organization_id, + judge_run.id, + workspace_id=workspace_id, + dataset_id=dataset.id, + sample_count=sample_count, + ) except Exception as exc: logger.error(f"[JudgeAlignment] Failed to enqueue judge task: {exc}") judge_run.status = "failed" diff --git a/app/api/v1/routes/metrics.py b/app/api/v1/routes/metrics.py index 8e7be9aa..f1aea458 100644 --- a/app/api/v1/routes/metrics.py +++ b/app/api/v1/routes/metrics.py @@ -2,16 +2,17 @@ import json import re -from fastapi import APIRouter, Depends, HTTPException, Query, status +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query, status from sqlalchemy.orm import Session from sqlalchemy import and_, or_ -from uuid import UUID +from uuid import UUID, uuid4 from typing import List, Optional, Dict, Any, Literal from pydantic import BaseModel, Field from loguru import logger from app.database import get_db from app.dependencies import get_organization_id, get_api_key, get_workspace_id +from app.services.billing.flexprice_service import record_metrics_llm_assist from app.models.database import Metric, MetricCategory, MetricType, MetricTrigger, ModelProvider from app.models.schemas import ( MetricCreate, @@ -1539,7 +1540,9 @@ def _parse_metric_generation_response(text: str) -> Dict[str, Any]: @router.post("/generate", response_model=MetricGenerateResponse) def generate_metric( req: MetricGenerateRequest, + background_tasks: BackgroundTasks, organization_id: UUID = Depends(get_organization_id), + workspace_id: UUID = Depends(get_workspace_id), db: Session = Depends(get_db), ): """Use an LLM to suggest a metric definition. Does NOT persist anything.""" @@ -1631,6 +1634,14 @@ def generate_metric( suffix += 1 name = f"{name} ({suffix})" + background_tasks.add_task( + record_metrics_llm_assist, + organization_id, + uuid4(), + workspace_id=workspace_id, + mode=req.mode, + ) + return MetricGenerateResponse( name=name, description=(parsed.get("description") or "").strip()[:1000], @@ -1953,7 +1964,9 @@ def _taken(name: str) -> bool: @router.post("/parse-bulk", response_model=MetricParseBulkResponse) def parse_bulk_metric( req: MetricParseBulkRequest, + background_tasks: BackgroundTasks, organization_id: UUID = Depends(get_organization_id), + workspace_id: UUID = Depends(get_workspace_id), db: Session = Depends(get_db), ): """Parse a multi-label rubric into a *list* of independent metric drafts. @@ -2077,6 +2090,14 @@ def parse_bulk_metric( ) ) + background_tasks.add_task( + record_metrics_llm_assist, + organization_id, + uuid4(), + workspace_id=workspace_id, + mode="parse-bulk", + ) + return MetricParseBulkResponse(metrics=drafts, parent=parent_payload) diff --git a/app/api/v1/routes/observability.py b/app/api/v1/routes/observability.py index 61770404..94ac5d0a 100644 --- a/app/api/v1/routes/observability.py +++ b/app/api/v1/routes/observability.py @@ -5,11 +5,15 @@ from typing import Any, Dict, List, Optional, Union from uuid import UUID -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, status from pydantic import BaseModel, ConfigDict from sqlalchemy.orm import Session from app.dependencies import get_api_key, get_db, get_organization_id, get_workspace_id +from app.services.billing.flexprice_service import ( + record_observability_call_evaluated, + record_observability_call_ingested, +) from app.models.database import ( Agent, APIKey, CallRecording, CallRecordingStatus, CallRecordingSource, Evaluator, EvaluatorResult, EvaluatorResultStatus, Scenario, Workspace, @@ -188,6 +192,13 @@ def _upsert_call_recording( response = _serialize_call_recording(call_recording, include_data=True) response["action"] = action + if action == "created": + record_observability_call_ingested( + organization_id, + call_recording.call_short_id, + workspace_id=workspace_id, + provider=provider_platform, + ) return response @@ -466,6 +477,7 @@ def _messages_to_speaker_segments(messages: List[Dict[str, Any]]) -> List[Dict[s async def evaluate_call( call_short_id: str, payload: EvaluateCallPayload, + background_tasks: BackgroundTasks, organization_id: UUID = Depends(get_organization_id), workspace_id: UUID = Depends(get_workspace_id), api_key: str = Depends(get_api_key), @@ -575,6 +587,13 @@ async def evaluate_call( except Exception: pass + background_tasks.add_task( + record_observability_call_evaluated, + organization_id, + call_short_id, + workspace_id=workspace_id, + ) + return { "evaluator_result_id": str(evaluator_result.id), "result_id": evaluator_result.result_id, diff --git a/app/api/v1/routes/playground.py b/app/api/v1/routes/playground.py index 02651ab1..ee39eec8 100644 --- a/app/api/v1/routes/playground.py +++ b/app/api/v1/routes/playground.py @@ -14,7 +14,13 @@ import uuid as _uuid from datetime import datetime -from app.dependencies import get_db, get_organization_id, get_workspace_id, get_api_key +from app.services.billing.flexprice_service import ( + record_playground_call_evaluated, + record_playground_web_call_started, + record_playground_websocket_session_started, +) +from app.database import get_db +from app.dependencies import get_organization_id, get_workspace_id, get_api_key from app.models.database import ( Agent, Integration, @@ -671,6 +677,14 @@ async def create_web_call( db.add(call_recording) db.commit() db.refresh(call_recording) + + background_tasks.add_task( + record_playground_web_call_started, + organization_id, + call_short_id, + workspace_id=workspace_id, + agent_id=agent.id, + ) # Start background task to poll for call metrics # Note: We need to pass the decrypted API key, but we should be careful with security @@ -777,6 +791,7 @@ async def list_call_recordings( @router.post("/custom-websocket-sessions", response_model=Dict[str, Any]) async def create_custom_websocket_session( + background_tasks: BackgroundTasks, agent_id: str = Form(...), websocket_url: str = Form(...), transcript_entries: str = Form("[]"), @@ -902,6 +917,13 @@ async def create_custom_websocket_session( db.commit() db.refresh(call_recording) + background_tasks.add_task( + record_playground_websocket_session_started, + organization_id, + call_short_id, + workspace_id=workspace_id, + ) + return { "message": "Custom websocket session saved", "call_short_id": call_short_id, @@ -913,6 +935,7 @@ async def create_custom_websocket_session( @router.post("/custom-websocket-sessions/{call_short_id}/evaluate", response_model=Dict[str, Any]) async def evaluate_custom_websocket_session( call_short_id: str, + background_tasks: BackgroundTasks, organization_id: UUID = Depends(get_organization_id), workspace_id: UUID = Depends(get_workspace_id), api_key: str = Depends(get_api_key), @@ -999,6 +1022,14 @@ async def evaluate_custom_websocket_session( logger.error(f"[Custom WebSocket] Failed to trigger evaluation worker: {e}") raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to trigger evaluation worker") + background_tasks.add_task( + record_playground_call_evaluated, + organization_id, + call_short_id, + workspace_id=workspace_id, + metric_count=0, + ) + return { "message": "Evaluation queued", "evaluator_result_id": str(evaluator_result.id), @@ -1367,6 +1398,14 @@ def _download_audio_from_payload(payload: Dict[str, Any]): logger.error(f"[Re-evaluate] Failed to trigger Celery task: {e}") raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to trigger evaluation worker") + background_tasks.add_task( + record_playground_call_evaluated, + organization_id, + call_short_id, + workspace_id=workspace_id, + metric_count=0, + ) + return { "message": "Re-evaluation started", "evaluator_result_id": str(evaluator_result.id), diff --git a/app/api/v1/routes/prompt_optimization.py b/app/api/v1/routes/prompt_optimization.py index f0805797..cae668e2 100644 --- a/app/api/v1/routes/prompt_optimization.py +++ b/app/api/v1/routes/prompt_optimization.py @@ -9,7 +9,7 @@ from typing import Dict, Any, List, Optional from uuid import UUID -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException from pydantic import BaseModel, ConfigDict, Field from sqlalchemy.orm import Session from loguru import logger @@ -21,6 +21,7 @@ get_api_key, require_enterprise_feature, ) +from app.services.billing.flexprice_service import record_prompt_optimization_run_started from app.models.database import ( Agent, Evaluator, @@ -99,6 +100,7 @@ class CandidateResponse(BaseModel): @router.post("/runs", response_model=OptimizationRunResponse, status_code=201) def create_optimization_run( data: OptimizationRunCreate, + background_tasks: BackgroundTasks, organization_id: UUID = Depends(get_organization_id), workspace_id: UUID = Depends(get_workspace_id), api_key: str = Depends(get_api_key), @@ -156,6 +158,18 @@ def create_optimization_run( run.celery_task_id = task.id db.commit() + config = run.config if isinstance(run.config, dict) else {} + max_metric_calls = config.get("max_metric_calls") + + background_tasks.add_task( + record_prompt_optimization_run_started, + organization_id, + run.id, + workspace_id=workspace_id, + agent_id=agent.id, + max_metric_calls=max_metric_calls, + ) + logger.info(f"[GEPA] Created optimization run {run.id} for agent {agent.name}") return run diff --git a/app/api/v1/routes/public_blind_test.py b/app/api/v1/routes/public_blind_test.py index 07982c7e..79bdef77 100644 --- a/app/api/v1/routes/public_blind_test.py +++ b/app/api/v1/routes/public_blind_test.py @@ -25,7 +25,7 @@ import time from typing import Any, Dict, List, Optional -from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request from pydantic import BaseModel from sqlalchemy.orm import Session @@ -40,6 +40,7 @@ TTSSampleStatus, ) from app.services.storage.s3_service import s3_service +from app.services.billing.flexprice_service import record_blind_test_response_submitted router = APIRouter( @@ -225,6 +226,7 @@ async def submit_public_blind_test_response( share_token: str, data: PublicBlindResponseSubmit, request: Request, + background_tasks: BackgroundTasks, db: Session = Depends(get_db), ): """Accept a rater's submission and merge into the comparison summary.""" @@ -351,8 +353,19 @@ def _clean_ratings(raw: Dict[str, float]) -> Dict[str, float]: detail="This email has already submitted a response for this blind test.", ) + db.refresh(record) + # Re-aggregate into the comparison's evaluation_summary from app.api.v1.routes.voice_playground import _recompute_summary _recompute_summary(comparison, db) + background_tasks.add_task( + record_blind_test_response_submitted, + share.organization_id, + record.id, + share_id=share.id, + workspace_id=share.workspace_id, + response_count=len(cleaned_entries), + ) + return {"message": "Thanks for your response!"} diff --git a/app/api/v1/routes/test_agents.py b/app/api/v1/routes/test_agents.py index 0c402156..f6b6ea1b 100644 --- a/app/api/v1/routes/test_agents.py +++ b/app/api/v1/routes/test_agents.py @@ -3,7 +3,7 @@ API endpoints for managing test agent conversations. """ -from fastapi import APIRouter, Depends, HTTPException, status, UploadFile, File +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, status, UploadFile, File from sqlalchemy.orm import Session from uuid import UUID from typing import List, Optional @@ -16,6 +16,10 @@ TestAgentConversationUpdate, TestAgentConversationResponse ) +from app.services.billing.flexprice_service import ( + record_test_agent_conversation_ended, + record_test_agent_conversation_started, +) from app.services.testing.test_agent_service import test_agent_service router = APIRouter(prefix="/test-agents", tags=["test-agents"]) @@ -86,6 +90,7 @@ async def get_conversation( @router.post("/conversations/{conversation_id}/start", response_model=TestAgentConversationResponse, operation_id="startTestAgentConversation") async def start_conversation( conversation_id: UUID, + background_tasks: BackgroundTasks, organization_id: UUID = Depends(get_organization_id), workspace_id: UUID = Depends(get_workspace_id), db: Session = Depends(get_db) @@ -97,6 +102,12 @@ async def start_conversation( organization_id=organization_id, db=db ) + background_tasks.add_task( + record_test_agent_conversation_started, + organization_id, + conversation_id, + workspace_id=workspace_id, + ) return conversation except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -209,8 +220,10 @@ async def get_response_audio( @router.post("/conversations/{conversation_id}/end", response_model=TestAgentConversationResponse, operation_id="endTestAgentConversation") async def end_conversation( conversation_id: UUID, + background_tasks: BackgroundTasks, final_audio_key: Optional[str] = None, organization_id: UUID = Depends(get_organization_id), + workspace_id: UUID = Depends(get_workspace_id), db: Session = Depends(get_db) ): """End a test agent conversation.""" @@ -221,6 +234,15 @@ async def end_conversation( db=db, final_audio_key=final_audio_key ) + turn_count = len(conversation.live_transcription or []) + background_tasks.add_task( + record_test_agent_conversation_ended, + organization_id, + conversation_id, + workspace_id=workspace_id, + duration_seconds=conversation.duration_seconds, + turn_count=turn_count, + ) return conversation except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) diff --git a/app/api/v1/routes/voice_playground.py b/app/api/v1/routes/voice_playground.py index 85d79baa..8cc4519d 100644 --- a/app/api/v1/routes/voice_playground.py +++ b/app/api/v1/routes/voice_playground.py @@ -8,7 +8,7 @@ import uuid as _uuid -from fastapi import APIRouter, Depends, HTTPException, Response, UploadFile, File, status +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Response, UploadFile, File, status from sqlalchemy.orm import Session from sqlalchemy import func from typing import Dict, Any, List, Literal, Optional @@ -17,6 +17,11 @@ from loguru import logger from app.dependencies import get_db, get_organization_id, get_workspace_id, get_api_key, require_enterprise_feature +from app.services.billing.flexprice_service import ( + record_blind_test_share_created, + record_tts_generation_started, + record_tts_report_requested, +) from app.models.database import ( AIProvider, CallImport, @@ -1371,6 +1376,7 @@ async def get_comparison( @router.post("/comparisons/{comparison_id}/generate", operation_id="generateTTSComparison") async def generate_comparison( comparison_id: UUID, + background_tasks: BackgroundTasks, organization_id: UUID = Depends(get_organization_id), workspace_id: UUID = Depends(get_workspace_id), api_key: str = Depends(get_api_key), @@ -1432,6 +1438,13 @@ async def generate_comparison( comparison.celery_task_id = task.id db.commit() logger.info(f"[VoicePlayground] Dispatched generation task {task.id} for comparison {comparison.id}") + background_tasks.add_task( + record_tts_generation_started, + organization_id, + comparison.id, + workspace_id=workspace_id, + sample_count=pending_tts_samples, + ) except Exception as e: comparison.status = TTSComparisonStatus.FAILED.value comparison.error_message = str(e) @@ -1468,6 +1481,7 @@ async def submit_blind_test( async def create_blind_test_share( comparison_id: UUID, data: BlindTestShareCreate, + background_tasks: BackgroundTasks, organization_id: UUID = Depends(get_organization_id), workspace_id: UUID = Depends(get_workspace_id), api_key: str = Depends(get_api_key), @@ -1540,6 +1554,15 @@ async def create_blind_test_share( db.commit() db.refresh(share) + if not existing: + background_tasks.add_task( + record_blind_test_share_created, + organization_id, + share.id, + workspace_id=workspace_id, + comparison_id=comparison.id, + ) + _recompute_summary(comparison, db) return _serialize_share(share, db, include_aggregates=True) @@ -1891,6 +1914,7 @@ async def download_tts_comparison_report( @router.post("/comparisons/{comparison_id}/reports", operation_id="createTTSComparisonReportJob") async def create_tts_comparison_report_job( comparison_id: UUID, + background_tasks: BackgroundTasks, data: Optional[TTSReportJobCreate] = None, organization_id: UUID = Depends(get_organization_id), workspace_id: UUID = Depends(get_workspace_id), @@ -1924,6 +1948,13 @@ async def create_tts_comparison_report_job( task = generate_tts_report_pdf_task.delay(str(report_job.id), options_dict) report_job.celery_task_id = task.id db.commit() + background_tasks.add_task( + record_tts_report_requested, + organization_id, + report_job.id, + workspace_id=workspace_id, + comparison_id=comparison.id, + ) except Exception as e: report_job.status = TTSReportJobStatus.FAILED.value report_job.error_message = f"Failed to queue report task: {str(e)}" diff --git a/app/config.py b/app/config.py index 2c6e401b..191da0e6 100644 --- a/app/config.py +++ b/app/config.py @@ -155,6 +155,11 @@ class Settings(BaseSettings): JUDGE_ALIGNMENT_ENABLED: bool = True JUDGE_ALIGNMENT_CSV_MAX_ROWS: int = 5000 + # Flexprice usage-based billing (optional; disabled when unset) + FLEXPRICE_ENABLED: bool = False + FLEXPRICE_API_KEY: Optional[str] = None + FLEXPRICE_API_HOST: str = "https://us.api.flexprice.io/v1" + model_config = SettingsConfigDict( env_file=".env", env_file_encoding="utf-8", @@ -574,6 +579,15 @@ def load_config_from_file(config_path: str) -> None: if "trusted_ips" in operational_config: settings.OPERATIONAL_TRUSTED_IPS = operational_config["trusted_ips"] + if "flexprice" in config_data: + flexprice_config = config_data["flexprice"] + if "enabled" in flexprice_config: + settings.FLEXPRICE_ENABLED = bool(flexprice_config["enabled"]) + if flexprice_config.get("api_key"): + settings.FLEXPRICE_API_KEY = flexprice_config["api_key"] + if flexprice_config.get("api_host"): + settings.FLEXPRICE_API_HOST = flexprice_config["api_host"] + # Update Celery URLs if they weren't explicitly set if not settings.CELERY_BROKER_URL: settings.CELERY_BROKER_URL = settings.REDIS_URL diff --git a/app/core/auth/oidc_common.py b/app/core/auth/oidc_common.py index a455db23..6c2542c7 100644 --- a/app/core/auth/oidc_common.py +++ b/app/core/auth/oidc_common.py @@ -24,7 +24,10 @@ from app.core.auth.providers import AuthError from app.models.database import Organization, OrganizationMember, User from app.models.enums import RoleEnum -from app.services.organization_provisioning import provision_default_workspace +from app.services.organization_provisioning import ( + provision_billing_customer, + provision_default_workspace, +) # -- JWKS cache --------------------------------------------------------------- @@ -196,6 +199,11 @@ def upsert_user_and_membership( organization_id=organization.id, created_by_user_id=user.id, ) + provision_billing_customer( + organization_id=organization.id, + name=organization_name, + email=email, + ) logger.info(f"Provisioned new organization '{organization_name}' for {email}") member = OrganizationMember( diff --git a/app/services/billing/__init__.py b/app/services/billing/__init__.py new file mode 100644 index 00000000..dd930582 --- /dev/null +++ b/app/services/billing/__init__.py @@ -0,0 +1 @@ +"""Billing and usage metering integrations.""" diff --git a/app/services/billing/flexprice_service.py b/app/services/billing/flexprice_service.py new file mode 100644 index 00000000..4091fdda --- /dev/null +++ b/app/services/billing/flexprice_service.py @@ -0,0 +1,761 @@ +"""Flexprice usage metering (optional; no-op when disabled). + +Event naming: ``product.action`` snake_case (e.g. ``call_import.batch_created``). +Every event uses ``external_customer_id=str(organization.id)`` and a stable +``event_id`` for idempotency. ``properties`` should include ``workspace_id`` and +``feature`` (license key) when the surface is gated. +""" + +from __future__ import annotations + +from typing import Any, Optional, Union +from uuid import UUID + +from loguru import logger + +from app.config import settings + +EVENT_SOURCE = "efficientai" +FEATURE_CALL_IMPORTS = "call_imports" +FEATURE_VOICE_PLAYGROUND = "voice_playground" +FEATURE_GEPA = "gepa_optimization" + +# Event names +BLIND_TEST_SHARE_CREATED = "blind_test.share_created" +BLIND_TEST_RESPONSE_SUBMITTED = "blind_test.response_submitted" +TTS_GENERATION_STARTED = "tts.generation_started" +TTS_SAMPLE_SYNTHESIZED = "tts.sample_synthesized" +TTS_REPORT_REQUESTED = "tts.report_requested" +TTS_REPORT_COMPLETED = "tts.report_completed" +CALL_IMPORT_BATCH_CREATED = "call_import.batch_created" +CALL_IMPORT_ROW_IMPORTED = "call_import.row_imported" +CALL_IMPORT_EVALUATION_STARTED = "call_import.evaluation_started" +CALL_IMPORT_EVALUATION_ROW_COMPLETED = "call_import.evaluation_row_completed" +PLAYGROUND_WEB_CALL_STARTED = "playground.web_call_started" +PLAYGROUND_WEBSOCKET_SESSION_STARTED = "playground.websocket_session_started" +PLAYGROUND_CALL_EVALUATED = "playground.call_evaluated" +PLAYGROUND_EVALUATION_COMPLETED = "playground.evaluation_completed" +EVALUATOR_RUN_REQUESTED = "evaluator.run_requested" +EVALUATOR_RUN_COMPLETED = "evaluator.run_completed" +EVALUATION_CREATED = "evaluation.created" +EVALUATION_COMPLETED = "evaluation.completed" +PROMPT_OPTIMIZATION_RUN_STARTED = "prompt_optimization.run_started" +PROMPT_OPTIMIZATION_RUN_COMPLETED = "prompt_optimization.run_completed" +JUDGE_ALIGNMENT_RUN_STARTED = "judge_alignment.run_started" +JUDGE_ALIGNMENT_RUN_COMPLETED = "judge_alignment.run_completed" +OBSERVABILITY_CALL_INGESTED = "observability.call_ingested" +OBSERVABILITY_CALL_EVALUATED = "observability.call_evaluated" +TEST_AGENT_CONVERSATION_STARTED = "test_agent.conversation_started" +TEST_AGENT_CONVERSATION_ENDED = "test_agent.conversation_ended" +METRICS_LLM_ASSIST = "metrics.llm_assist" +CHAT_COMPLETION = "chat.completion" + + +def is_enabled() -> bool: + """Return True only when Flexprice is explicitly enabled with an API key.""" + return bool(settings.FLEXPRICE_ENABLED and settings.FLEXPRICE_API_KEY) + + +def _is_customer_already_exists(exc: Exception) -> bool: + message = str(exc).lower() + if "already exist" in message or "duplicate" in message: + return True + status_code = getattr(exc, "status_code", None) + return status_code == 409 + + +def _ingest_usage_event(client, payload: dict) -> None: + """Call Flexprice event ingest across SDK versions (flat kwargs vs request=).""" + events = client.events + ingest = getattr(events, "ingest_event", None) or getattr(events, "ingest", None) + if ingest is None: + raise AttributeError("Flexprice SDK has no events.ingest_event or events.ingest") + + try: + ingest(**payload) + except TypeError: + ingest(request=payload) + + +def _coerce_properties(properties: Optional[dict[str, Any]]) -> dict[str, Any]: + if not properties: + return {} + out: dict[str, Any] = {} + for key, value in properties.items(): + if value is None: + continue + if isinstance(value, UUID): + out[key] = str(value) + else: + out[key] = value + return out + + +def record_event( + event_name: str, + organization_id: UUID, + event_id: Union[str, UUID], + *, + properties: Optional[dict[str, Any]] = None, +) -> None: + """Ingest a usage event. No-op when Flexprice is disabled; never raises.""" + if not is_enabled(): + return + + try: + from flexprice import Flexprice + + with Flexprice( + server_url=settings.FLEXPRICE_API_HOST, + api_key_auth=settings.FLEXPRICE_API_KEY, + ) as client: + _ingest_usage_event( + client, + { + "event_name": event_name, + "external_customer_id": str(organization_id), + "event_id": str(event_id), + "source": EVENT_SOURCE, + "properties": _coerce_properties(properties), + }, + ) + except Exception as exc: + logger.warning( + "Flexprice {} ingest failed (event_id={}): {}", + event_name, + event_id, + exc, + ) + + +def ensure_customer( + organization_id: UUID, + *, + name: str, + email: Optional[str] = None, +) -> None: + """Register an organization as a Flexprice customer. No-op when disabled.""" + if not is_enabled(): + return + + try: + from flexprice import Flexprice + + with Flexprice( + server_url=settings.FLEXPRICE_API_HOST, + api_key_auth=settings.FLEXPRICE_API_KEY, + ) as client: + client.customers.create_customer( + external_id=str(organization_id), + name=name, + email=email, + ) + except Exception as exc: + if _is_customer_already_exists(exc): + return + logger.warning( + "Flexprice ensure_customer failed for org {}: {}", + organization_id, + exc, + ) + + +# --- Voice playground --- + + +def record_blind_test_share_created( + organization_id: UUID, + share_id: UUID, + *, + workspace_id: UUID, + comparison_id: UUID, +) -> None: + record_event( + BLIND_TEST_SHARE_CREATED, + organization_id, + share_id, + properties={ + "workspace_id": workspace_id, + "feature": FEATURE_VOICE_PLAYGROUND, + "share_id": share_id, + "comparison_id": comparison_id, + }, + ) + + +def record_blind_test_response_submitted( + organization_id: UUID, + response_id: UUID, + *, + share_id: UUID, + workspace_id: UUID, + response_count: int, +) -> None: + record_event( + BLIND_TEST_RESPONSE_SUBMITTED, + organization_id, + response_id, + properties={ + "workspace_id": workspace_id, + "feature": FEATURE_VOICE_PLAYGROUND, + "share_id": share_id, + "response_count": response_count, + "quantity": response_count, + }, + ) + + +def record_tts_generation_started( + organization_id: UUID, + comparison_id: UUID, + *, + workspace_id: UUID, + sample_count: int, +) -> None: + record_event( + TTS_GENERATION_STARTED, + organization_id, + comparison_id, + properties={ + "workspace_id": workspace_id, + "feature": FEATURE_VOICE_PLAYGROUND, + "comparison_id": comparison_id, + "sample_count": sample_count, + }, + ) + + +def record_tts_sample_synthesized( + organization_id: UUID, + sample_id: UUID, + *, + workspace_id: UUID, + comparison_id: UUID, + provider: Optional[str] = None, + side: Optional[str] = None, + duration_seconds: Optional[float] = None, +) -> None: + record_event( + TTS_SAMPLE_SYNTHESIZED, + organization_id, + sample_id, + properties={ + "workspace_id": workspace_id, + "feature": FEATURE_VOICE_PLAYGROUND, + "comparison_id": comparison_id, + "sample_id": sample_id, + "provider": provider, + "side": side, + "duration_seconds": duration_seconds, + "quantity": 1, + }, + ) + + +def record_tts_report_requested( + organization_id: UUID, + report_job_id: UUID, + *, + workspace_id: UUID, + comparison_id: UUID, +) -> None: + record_event( + TTS_REPORT_REQUESTED, + organization_id, + report_job_id, + properties={ + "workspace_id": workspace_id, + "feature": FEATURE_VOICE_PLAYGROUND, + "comparison_id": comparison_id, + "report_job_id": report_job_id, + }, + ) + + +def record_tts_report_completed( + organization_id: UUID, + report_job_id: UUID, + *, + workspace_id: UUID, + comparison_id: UUID, +) -> None: + record_event( + TTS_REPORT_COMPLETED, + organization_id, + report_job_id, + properties={ + "workspace_id": workspace_id, + "feature": FEATURE_VOICE_PLAYGROUND, + "comparison_id": comparison_id, + "report_job_id": report_job_id, + }, + ) + + +# --- Call imports --- + + +def record_call_import_batch_created( + organization_id: UUID, + call_import_id: UUID, + *, + workspace_id: UUID, + total_rows: int, + source: str, + provider: Optional[str] = None, +) -> None: + record_event( + CALL_IMPORT_BATCH_CREATED, + organization_id, + call_import_id, + properties={ + "workspace_id": workspace_id, + "feature": FEATURE_CALL_IMPORTS, + "call_import_id": call_import_id, + "total_rows": total_rows, + "source": source, + "provider": provider, + }, + ) + + +def record_call_import_row_imported( + organization_id: UUID, + row_id: UUID, + *, + workspace_id: UUID, + call_import_id: UUID, + duration_seconds: Optional[float] = None, +) -> None: + record_event( + CALL_IMPORT_ROW_IMPORTED, + organization_id, + row_id, + properties={ + "workspace_id": workspace_id, + "feature": FEATURE_CALL_IMPORTS, + "call_import_id": call_import_id, + "row_id": row_id, + "duration_seconds": duration_seconds, + "quantity": 1, + }, + ) + + +def record_call_import_evaluation_started( + organization_id: UUID, + evaluation_id: UUID, + *, + workspace_id: UUID, + call_import_id: UUID, + row_count: int, + metric_count: int, +) -> None: + record_event( + CALL_IMPORT_EVALUATION_STARTED, + organization_id, + evaluation_id, + properties={ + "workspace_id": workspace_id, + "feature": FEATURE_CALL_IMPORTS, + "call_import_id": call_import_id, + "evaluation_id": evaluation_id, + "row_count": row_count, + "metric_count": metric_count, + "quantity": row_count * metric_count, + }, + ) + + +def record_call_import_evaluation_row_completed( + organization_id: UUID, + evaluation_id: UUID, + row_id: UUID, + *, + workspace_id: UUID, + call_import_id: UUID, + metrics_scored: int, +) -> None: + record_event( + CALL_IMPORT_EVALUATION_ROW_COMPLETED, + organization_id, + f"{evaluation_id}:{row_id}", + properties={ + "workspace_id": workspace_id, + "feature": FEATURE_CALL_IMPORTS, + "call_import_id": call_import_id, + "evaluation_id": evaluation_id, + "row_id": row_id, + "metrics_scored": metrics_scored, + "quantity": 1, + }, + ) + + +# --- Agent playground --- + + +def record_playground_web_call_started( + organization_id: UUID, + call_short_id: str, + *, + workspace_id: UUID, + agent_id: UUID, +) -> None: + record_event( + PLAYGROUND_WEB_CALL_STARTED, + organization_id, + call_short_id, + properties={ + "workspace_id": workspace_id, + "agent_id": agent_id, + "call_short_id": call_short_id, + }, + ) + + +def record_playground_websocket_session_started( + organization_id: UUID, + call_short_id: str, + *, + workspace_id: UUID, +) -> None: + record_event( + PLAYGROUND_WEBSOCKET_SESSION_STARTED, + organization_id, + call_short_id, + properties={ + "workspace_id": workspace_id, + "call_short_id": call_short_id, + }, + ) + + +def record_playground_call_evaluated( + organization_id: UUID, + call_short_id: str, + *, + workspace_id: UUID, + metric_count: int, +) -> None: + record_event( + PLAYGROUND_CALL_EVALUATED, + organization_id, + call_short_id, + properties={ + "workspace_id": workspace_id, + "call_short_id": call_short_id, + "metric_count": metric_count, + }, + ) + + +def record_playground_evaluation_completed( + organization_id: UUID, + call_short_id: str, + *, + workspace_id: UUID, + duration_seconds: Optional[float] = None, +) -> None: + record_event( + PLAYGROUND_EVALUATION_COMPLETED, + organization_id, + call_short_id, + properties={ + "workspace_id": workspace_id, + "call_short_id": call_short_id, + "duration_seconds": duration_seconds, + }, + ) + + +# --- Evaluators --- + + +def record_evaluator_run_requested( + organization_id: UUID, + request_id: UUID, + *, + workspace_id: UUID, + quantity: int, +) -> None: + record_event( + EVALUATOR_RUN_REQUESTED, + organization_id, + request_id, + properties={ + "workspace_id": workspace_id, + "quantity": quantity, + }, + ) + + +def record_evaluator_run_completed( + organization_id: UUID, + result_id: str, + *, + workspace_id: UUID, + evaluator_id: UUID, + call_count: int = 1, +) -> None: + record_event( + EVALUATOR_RUN_COMPLETED, + organization_id, + result_id, + properties={ + "workspace_id": workspace_id, + "evaluator_id": evaluator_id, + "result_id": result_id, + "call_count": call_count, + }, + ) + + +# --- Legacy evaluations --- + + +def record_evaluation_created( + organization_id: UUID, + evaluation_id: UUID, + *, + workspace_id: UUID, + audio_id: UUID, + metrics_requested: int, +) -> None: + record_event( + EVALUATION_CREATED, + organization_id, + evaluation_id, + properties={ + "workspace_id": workspace_id, + "audio_id": audio_id, + "metrics_requested": metrics_requested, + }, + ) + + +def record_evaluation_completed( + organization_id: UUID, + evaluation_id: UUID, + *, + workspace_id: UUID, +) -> None: + record_event( + EVALUATION_COMPLETED, + organization_id, + evaluation_id, + properties={"workspace_id": workspace_id}, + ) + + +# --- Prompt optimization --- + + +def record_prompt_optimization_run_started( + organization_id: UUID, + run_id: UUID, + *, + workspace_id: UUID, + agent_id: UUID, + max_metric_calls: Optional[int] = None, +) -> None: + record_event( + PROMPT_OPTIMIZATION_RUN_STARTED, + organization_id, + run_id, + properties={ + "workspace_id": workspace_id, + "feature": FEATURE_GEPA, + "run_id": run_id, + "agent_id": agent_id, + "max_metric_calls": max_metric_calls, + }, + ) + + +def record_prompt_optimization_run_completed( + organization_id: UUID, + run_id: UUID, + *, + workspace_id: UUID, + agent_id: UUID, + candidates_count: int = 0, +) -> None: + record_event( + PROMPT_OPTIMIZATION_RUN_COMPLETED, + organization_id, + run_id, + properties={ + "workspace_id": workspace_id, + "feature": FEATURE_GEPA, + "run_id": run_id, + "agent_id": agent_id, + "candidates_count": candidates_count, + }, + ) + + +# --- Judge alignment --- + + +def record_judge_alignment_run_started( + organization_id: UUID, + run_id: UUID, + *, + workspace_id: UUID, + dataset_id: UUID, + sample_count: int, +) -> None: + record_event( + JUDGE_ALIGNMENT_RUN_STARTED, + organization_id, + run_id, + properties={ + "workspace_id": workspace_id, + "run_id": run_id, + "dataset_id": dataset_id, + "sample_count": sample_count, + }, + ) + + +def record_judge_alignment_run_completed( + organization_id: UUID, + run_id: UUID, + *, + workspace_id: UUID, + dataset_id: UUID, + samples_scored: int, +) -> None: + record_event( + JUDGE_ALIGNMENT_RUN_COMPLETED, + organization_id, + run_id, + properties={ + "workspace_id": workspace_id, + "run_id": run_id, + "dataset_id": dataset_id, + "samples_scored": samples_scored, + }, + ) + + +# --- Observability --- + + +def record_observability_call_ingested( + organization_id: UUID, + call_short_id: str, + *, + workspace_id: UUID, + provider: Optional[str] = None, +) -> None: + record_event( + OBSERVABILITY_CALL_INGESTED, + organization_id, + call_short_id, + properties={ + "workspace_id": workspace_id, + "call_short_id": call_short_id, + "provider": provider, + }, + ) + + +def record_observability_call_evaluated( + organization_id: UUID, + call_short_id: str, + *, + workspace_id: UUID, +) -> None: + record_event( + OBSERVABILITY_CALL_EVALUATED, + organization_id, + call_short_id, + properties={ + "workspace_id": workspace_id, + "call_short_id": call_short_id, + }, + ) + + +# --- Test agents --- + + +def record_test_agent_conversation_started( + organization_id: UUID, + conversation_id: UUID, + *, + workspace_id: UUID, +) -> None: + record_event( + TEST_AGENT_CONVERSATION_STARTED, + organization_id, + conversation_id, + properties={ + "workspace_id": workspace_id, + "conversation_id": conversation_id, + }, + ) + + +def record_test_agent_conversation_ended( + organization_id: UUID, + conversation_id: UUID, + *, + workspace_id: UUID, + duration_seconds: Optional[float] = None, + turn_count: int = 0, +) -> None: + record_event( + TEST_AGENT_CONVERSATION_ENDED, + organization_id, + conversation_id, + properties={ + "workspace_id": workspace_id, + "conversation_id": conversation_id, + "duration_seconds": duration_seconds, + "turn_count": turn_count, + "quantity": duration_seconds or 1, + }, + ) + + +# --- LLM assist --- + + +def record_metrics_llm_assist( + organization_id: UUID, + request_id: UUID, + *, + workspace_id: Optional[UUID], + mode: str, +) -> None: + record_event( + METRICS_LLM_ASSIST, + organization_id, + request_id, + properties={ + "workspace_id": workspace_id, + "mode": mode, + }, + ) + + +def record_chat_completion( + organization_id: UUID, + request_id: UUID, + *, + workspace_id: Optional[UUID], + model: Optional[str] = None, +) -> None: + record_event( + CHAT_COMPLETION, + organization_id, + request_id, + properties={ + "workspace_id": workspace_id, + "model": model, + "quantity": 1, + }, + ) diff --git a/app/services/evaluation/evaluation_service.py b/app/services/evaluation/evaluation_service.py index c5a07ffb..dd3c5760 100644 --- a/app/services/evaluation/evaluation_service.py +++ b/app/services/evaluation/evaluation_service.py @@ -135,6 +135,14 @@ def process_evaluation( evaluation.completed_at = datetime.now(UTC) db.commit() + from app.services.billing.flexprice_service import record_evaluation_completed + + record_evaluation_completed( + evaluation.organization_id, + evaluation.id, + workspace_id=evaluation.workspace_id, + ) + return { "evaluation_id": str(evaluation.id), "status": "completed", diff --git a/app/services/organization_provisioning.py b/app/services/organization_provisioning.py index 54981021..a839ca67 100644 --- a/app/services/organization_provisioning.py +++ b/app/services/organization_provisioning.py @@ -7,6 +7,7 @@ from sqlalchemy.orm import Session from app.models.database import Workspace +from app.services.billing.flexprice_service import ensure_customer from app.services.workspace_rbac import ( backfill_org_workspace_memberships, ensure_creator_workspace_admin, @@ -51,3 +52,13 @@ def provision_default_workspace( ) backfill_org_workspace_memberships(db, organization_id=organization_id) return workspace + + +def provision_billing_customer( + *, + organization_id: UUID, + name: str, + email: str | None = None, +) -> None: + """Register the org with Flexprice when billing is enabled (no-op otherwise).""" + ensure_customer(organization_id, name=name, email=email) diff --git a/app/workers/tasks/evaluate_call_import_row.py b/app/workers/tasks/evaluate_call_import_row.py index 6653e48a..7641feb9 100644 --- a/app/workers/tasks/evaluate_call_import_row.py +++ b/app/workers/tasks/evaluate_call_import_row.py @@ -1059,6 +1059,20 @@ def _resolve_pm( _rollup_parent(db, evaluation) db.commit() + if eval_row.status == "completed": + from app.services.billing.flexprice_service import ( + record_call_import_evaluation_row_completed, + ) + + record_call_import_evaluation_row_completed( + evaluation.organization_id, + evaluation.id, + source_row.id, + workspace_id=evaluation.workspace_id, + call_import_id=evaluation.call_import_id, + metrics_scored=len(eval_row.metric_scores or {}), + ) + return { "status": eval_row.status, "eval_row_id": eval_row_id, diff --git a/app/workers/tasks/process_call_import_row.py b/app/workers/tasks/process_call_import_row.py index c02d2fad..d30490b5 100644 --- a/app/workers/tasks/process_call_import_row.py +++ b/app/workers/tasks/process_call_import_row.py @@ -433,6 +433,15 @@ def process_call_import_row_task(self, row_id: str): _rollup_parent_status(db, call_import) db.commit() + from app.services.billing.flexprice_service import record_call_import_row_imported + + record_call_import_row_imported( + row.organization_id, + row.id, + workspace_id=call_import.workspace_id, + call_import_id=call_import.id, + ) + return { "status": "completed", "row_id": row_id, diff --git a/app/workers/tasks/process_evaluator_result.py b/app/workers/tasks/process_evaluator_result.py index 51036586..71b7c486 100644 --- a/app/workers/tasks/process_evaluator_result.py +++ b/app/workers/tasks/process_evaluator_result.py @@ -591,6 +591,27 @@ def process_evaluator_result_task(self, result_id: str): result.status = EvaluatorResultStatus.COMPLETED.value db.commit() + from app.models.database import CallRecording, CallRecordingSource + from app.services.billing.flexprice_service import ( + record_playground_evaluation_completed, + ) + + call_recording = ( + db.query(CallRecording) + .filter( + CallRecording.evaluator_result_id == result.id, + CallRecording.source == CallRecordingSource.PLAYGROUND, + ) + .first() + ) + if call_recording: + record_playground_evaluation_completed( + result.organization_id, + call_recording.call_short_id, + workspace_id=result.workspace_id, + duration_seconds=result.duration_seconds, + ) + total_time = time.time() - task_start_time logger.info( f"[EvaluatorResult {result.result_id}] Completed in {total_time:.2f}s, " diff --git a/app/workers/tasks/run_evaluator.py b/app/workers/tasks/run_evaluator.py index 536e8c53..a82f0199 100644 --- a/app/workers/tasks/run_evaluator.py +++ b/app/workers/tasks/run_evaluator.py @@ -91,6 +91,15 @@ def run_evaluator_task(self, evaluator_id: str, evaluator_result_id: str): db.commit() + from app.services.billing.flexprice_service import record_evaluator_run_completed + + record_evaluator_run_completed( + evaluator.organization_id, + result.result_id, + workspace_id=evaluator.workspace_id, + evaluator_id=evaluator.id, + ) + return { "evaluator_id": evaluator_id, "result_id": evaluator_result_id, diff --git a/app/workers/tasks/run_judge_alignment.py b/app/workers/tasks/run_judge_alignment.py index cf39652b..6ffbb655 100644 --- a/app/workers/tasks/run_judge_alignment.py +++ b/app/workers/tasks/run_judge_alignment.py @@ -88,6 +88,18 @@ def run_judge_alignment_task( _fail(db, run, str(exc)) return {"error": str(exc)} + from app.services.billing.flexprice_service import ( + record_judge_alignment_run_completed, + ) + + record_judge_alignment_run_completed( + run.organization_id, + run.id, + workspace_id=run.workspace_id, + dataset_id=run.dataset_id, + samples_scored=len(samples), + ) + return {"judge_run_id": judge_run_id, "metrics": metrics} finally: diff --git a/app/workers/tasks/run_prompt_optimization.py b/app/workers/tasks/run_prompt_optimization.py index 1b388465..dbd9c869 100644 --- a/app/workers/tasks/run_prompt_optimization.py +++ b/app/workers/tasks/run_prompt_optimization.py @@ -141,6 +141,18 @@ def run_prompt_optimization_task(self, optimization_run_id: str): f"Best score: {run.best_score}" ) + from app.services.billing.flexprice_service import ( + record_prompt_optimization_run_completed, + ) + + record_prompt_optimization_run_completed( + run.organization_id, + run.id, + workspace_id=run.workspace_id, + agent_id=run.agent_id, + candidates_count=len(result.get("candidates", [])), + ) + except Exception as e: logger.error(f"[GEPA] Optimization run {optimization_run_id} failed: {e}", exc_info=True) try: diff --git a/app/workers/tasks/tts_comparison.py b/app/workers/tasks/tts_comparison.py index 76092147..8ae42ea9 100644 --- a/app/workers/tasks/tts_comparison.py +++ b/app/workers/tasks/tts_comparison.py @@ -291,6 +291,18 @@ def _resolve_voice_meta(sample_obj): sample.status = TTSSampleStatus.COMPLETED.value db.commit() + from app.services.billing.flexprice_service import record_tts_sample_synthesized + + record_tts_sample_synthesized( + comp.organization_id, + sample.id, + workspace_id=comp.workspace_id, + comparison_id=comp.id, + provider=sample.provider, + side=sample.side, + duration_seconds=sample.duration_seconds, + ) + logger.info( f"[TTS Generate] Sample {sample.id} done – " f"{sample.provider}/{sample.voice_name} ttfb={ttfb_ms:.0f}ms total={latency_ms:.0f}ms" diff --git a/app/workers/tasks/tts_report.py b/app/workers/tasks/tts_report.py index c738244f..f61d5e1e 100644 --- a/app/workers/tasks/tts_report.py +++ b/app/workers/tasks/tts_report.py @@ -79,6 +79,15 @@ def generate_tts_report_pdf_task(self, report_job_id: str, report_options: dict report_job.error_message = None db.commit() + from app.services.billing.flexprice_service import record_tts_report_completed + + record_tts_report_completed( + report_job.organization_id, + report_job.id, + workspace_id=report_job.workspace_id, + comparison_id=comparison.id, + ) + return {"status": "completed", "s3_key": s3_key} except Exception as exc: logger.error(f"[TTS Report] Task failed: {exc}", exc_info=True) diff --git a/config.yml.example b/config.yml.example index 3d763a9d..ca83c979 100644 --- a/config.yml.example +++ b/config.yml.example @@ -163,3 +163,11 @@ judge_alignment: # license: # key: "eyJhbGciOi..." +# Flexprice usage-based billing (optional). When disabled or api_key is unset, +# no SDK calls are made and the app behaves as today. Cloud SaaS: set enabled +# true and provide FLEXPRICE_API_KEY (env var overrides api_key below). +# flexprice: +# enabled: false +# api_key: null +# api_host: "https://us.api.flexprice.io/v1" + diff --git a/pyproject.toml b/pyproject.toml index 8379d846..9eea5839 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,6 +65,7 @@ dependencies = [ "plivo>=4.47.0", "openpyxl>=3.1.0", "weasyprint>=62.0", + "flexprice>=2.1.0", ] [project.optional-dependencies] diff --git a/scripts/create_api_key.py b/scripts/create_api_key.py index 631015ee..a2b0e33b 100644 --- a/scripts/create_api_key.py +++ b/scripts/create_api_key.py @@ -23,7 +23,10 @@ from app.database import SessionLocal, init_db from app.models.database import APIKey, Organization -from app.services.organization_provisioning import provision_default_workspace +from app.services.organization_provisioning import ( + provision_billing_customer, + provision_default_workspace, +) def main() -> int: @@ -61,6 +64,10 @@ def main() -> int: db.add(org) db.flush() provision_default_workspace(db, organization_id=org.id) + provision_billing_customer( + organization_id=org.id, + name=args.new_org, + ) api_key = secrets.token_urlsafe(32) db_key = APIKey( diff --git a/tests/test_api/test_voice_playground_routes.py b/tests/test_api/test_voice_playground_routes.py index 126da011..c627a686 100644 --- a/tests/test_api/test_voice_playground_routes.py +++ b/tests/test_api/test_voice_playground_routes.py @@ -301,3 +301,99 @@ def test_blind_test_only_rejects_cross_workspace_tts_sample( assert response.status_code == 404 + +def _seed_blind_test_ready_comparison(db_session, org_id, workspace_id): + comparison = TTSComparison( + id=uuid4(), + organization_id=org_id, + workspace_id=workspace_id, + simulation_id="sim-blind-share", + name="Blind share metering test", + status=TTSComparisonStatus.COMPLETED.value, + mode="benchmark", + provider_a="openai", + provider_b="elevenlabs", + model_a="gpt-4o-mini-tts", + model_b="eleven_multilingual_v2", + voices_a=[{"id": "alloy", "name": "Alloy"}], + voices_b=[{"id": "rachel", "name": "Rachel"}], + sample_texts=["hello"], + num_runs=1, + ) + db_session.add(comparison) + db_session.flush() + + db_session.add_all( + [ + TTSSample( + id=uuid4(), + comparison_id=comparison.id, + organization_id=org_id, + workspace_id=workspace_id, + provider="openai", + model="gpt-4o-mini-tts", + voice_id="alloy", + voice_name="Alloy", + side="A", + sample_index=0, + run_index=0, + text="hello", + audio_s3_key="organizations/test/a.wav", + status=TTSSampleStatus.COMPLETED.value, + source_type="tts", + ), + TTSSample( + id=uuid4(), + comparison_id=comparison.id, + organization_id=org_id, + workspace_id=workspace_id, + provider="elevenlabs", + model="eleven_multilingual_v2", + voice_id="rachel", + voice_name="Rachel", + side="B", + sample_index=0, + run_index=0, + text="hello", + audio_s3_key="organizations/test/b.wav", + status=TTSSampleStatus.COMPLETED.value, + source_type="tts", + ), + ] + ) + db_session.commit() + return comparison + + +def test_create_blind_test_share_meters_only_on_first_create( + authenticated_client, db_session, org_id, default_workspace, monkeypatch +): + from app.api.v1.routes import voice_playground as vp_routes + + comparison = _seed_blind_test_ready_comparison(db_session, org_id, default_workspace.id) + calls = [] + + def _record_metering(*args, **kwargs): + calls.append((args, kwargs)) + + monkeypatch.setattr(vp_routes, "record_blind_test_share_created", _record_metering) + + payload = {"title": "Public blind test", "custom_metrics": []} + first = authenticated_client.post( + f"/api/v1/voice-playground/comparisons/{comparison.id}/share", + json=payload, + ) + assert first.status_code == 200 + + second = authenticated_client.post( + f"/api/v1/voice-playground/comparisons/{comparison.id}/share", + json={"title": "Updated title", "custom_metrics": []}, + ) + assert second.status_code == 200 + + assert len(calls) == 1 + args, kwargs = calls[0] + assert args[0] == org_id + assert kwargs["comparison_id"] == comparison.id + assert kwargs["workspace_id"] == default_workspace.id + diff --git a/tests/test_services/test_flexprice_service.py b/tests/test_services/test_flexprice_service.py new file mode 100644 index 00000000..78c60eb9 --- /dev/null +++ b/tests/test_services/test_flexprice_service.py @@ -0,0 +1,229 @@ +"""Unit tests for Flexprice billing service (optional metering).""" + +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest + +from app.config import settings +from app.services.billing import flexprice_service as svc + + +@pytest.fixture(autouse=True) +def reset_flexprice_settings(): + previous = ( + settings.FLEXPRICE_ENABLED, + settings.FLEXPRICE_API_KEY, + settings.FLEXPRICE_API_HOST, + ) + yield + ( + settings.FLEXPRICE_ENABLED, + settings.FLEXPRICE_API_KEY, + settings.FLEXPRICE_API_HOST, + ) = previous + + +def test_is_enabled_false_when_disabled(): + settings.FLEXPRICE_ENABLED = False + settings.FLEXPRICE_API_KEY = "test-key" + assert svc.is_enabled() is False + + +def test_is_enabled_false_when_api_key_missing(): + settings.FLEXPRICE_ENABLED = True + settings.FLEXPRICE_API_KEY = None + assert svc.is_enabled() is False + + +def test_is_enabled_true_when_configured(): + settings.FLEXPRICE_ENABLED = True + settings.FLEXPRICE_API_KEY = "test-key" + assert svc.is_enabled() is True + + +@patch("flexprice.Flexprice") +def test_ensure_customer_no_op_when_disabled(mock_flexprice): + settings.FLEXPRICE_ENABLED = False + settings.FLEXPRICE_API_KEY = "test-key" + + svc.ensure_customer(uuid4(), name="Acme") + + mock_flexprice.assert_not_called() + + +@patch("flexprice.Flexprice") +def test_ensure_customer_calls_create_customer(mock_flexprice): + settings.FLEXPRICE_ENABLED = True + settings.FLEXPRICE_API_KEY = "test-key" + settings.FLEXPRICE_API_HOST = "https://us.api.flexprice.io/v1" + + org_id = uuid4() + mock_client = MagicMock() + mock_flexprice.return_value.__enter__.return_value = mock_client + + svc.ensure_customer(org_id, name="Acme Inc", email="admin@acme.com") + + mock_flexprice.assert_called_once_with( + server_url="https://us.api.flexprice.io/v1", + api_key_auth="test-key", + ) + mock_client.customers.create_customer.assert_called_once_with( + external_id=str(org_id), + name="Acme Inc", + email="admin@acme.com", + ) + + +@patch("flexprice.Flexprice") +def test_ensure_customer_swallows_already_exists(mock_flexprice): + settings.FLEXPRICE_ENABLED = True + settings.FLEXPRICE_API_KEY = "test-key" + + mock_client = MagicMock() + mock_client.customers.create_customer.side_effect = Exception("Customer already exists") + mock_flexprice.return_value.__enter__.return_value = mock_client + + svc.ensure_customer(uuid4(), name="Acme Inc") + + +@patch("flexprice.Flexprice") +def test_record_blind_test_share_created_no_op_when_disabled(mock_flexprice): + settings.FLEXPRICE_ENABLED = False + settings.FLEXPRICE_API_KEY = "test-key" + + svc.record_blind_test_share_created(uuid4(), uuid4(), workspace_id=uuid4(), comparison_id=uuid4()) + + mock_flexprice.assert_not_called() + + +@patch("flexprice.Flexprice") +def test_record_blind_test_share_created_ingests_event(mock_flexprice): + settings.FLEXPRICE_ENABLED = True + settings.FLEXPRICE_API_KEY = "test-key" + + org_id = uuid4() + share_id = uuid4() + workspace_id = uuid4() + comparison_id = uuid4() + + mock_client = MagicMock() + mock_flexprice.return_value.__enter__.return_value = mock_client + + svc.record_blind_test_share_created( + org_id, + share_id, + workspace_id=workspace_id, + comparison_id=comparison_id, + ) + + mock_client.events.ingest_event.assert_called_once_with( + event_name="blind_test.share_created", + external_customer_id=str(org_id), + event_id=str(share_id), + source="efficientai", + properties={ + "share_id": str(share_id), + "workspace_id": str(workspace_id), + "comparison_id": str(comparison_id), + "feature": "voice_playground", + }, + ) + + +@patch("flexprice.Flexprice") +def test_record_blind_test_share_created_logs_and_swallows_errors(mock_flexprice): + settings.FLEXPRICE_ENABLED = True + settings.FLEXPRICE_API_KEY = "test-key" + + mock_client = MagicMock() + mock_client.events.ingest_event.side_effect = RuntimeError("network down") + mock_flexprice.return_value.__enter__.return_value = mock_client + + svc.record_blind_test_share_created(uuid4(), uuid4(), workspace_id=uuid4(), comparison_id=uuid4()) + + +@patch("flexprice.Flexprice") +def test_ingest_usage_event_falls_back_to_request_dict(mock_flexprice): + settings.FLEXPRICE_ENABLED = True + settings.FLEXPRICE_API_KEY = "test-key" + + mock_client = MagicMock() + mock_client.events.ingest_event.side_effect = [ + TypeError("ingest_event() got an unexpected keyword argument 'event_name'"), + None, + ] + mock_flexprice.return_value.__enter__.return_value = mock_client + + svc.record_blind_test_share_created(uuid4(), uuid4(), workspace_id=uuid4(), comparison_id=uuid4()) + + assert mock_client.events.ingest_event.call_count == 2 + assert "request" in mock_client.events.ingest_event.call_args_list[1].kwargs + + +@patch("flexprice.Flexprice") +def test_record_call_import_batch_created_includes_volume_properties(mock_flexprice): + settings.FLEXPRICE_ENABLED = True + settings.FLEXPRICE_API_KEY = "test-key" + + org_id = uuid4() + call_import_id = uuid4() + workspace_id = uuid4() + + mock_client = MagicMock() + mock_flexprice.return_value.__enter__.return_value = mock_client + + svc.record_call_import_batch_created( + org_id, + call_import_id, + workspace_id=workspace_id, + total_rows=42, + source="csv", + provider="exotel", + ) + + mock_client.events.ingest_event.assert_called_once() + payload = mock_client.events.ingest_event.call_args.kwargs + assert payload["event_name"] == "call_import.batch_created" + assert payload["properties"]["total_rows"] == 42 + assert payload["properties"]["feature"] == "call_imports" + + +@patch("flexprice.Flexprice") +def test_record_call_import_evaluation_row_completed_uses_composite_event_id(mock_flexprice): + settings.FLEXPRICE_ENABLED = True + settings.FLEXPRICE_API_KEY = "test-key" + + org_id = uuid4() + evaluation_id = uuid4() + row_id = uuid4() + workspace_id = uuid4() + call_import_id = uuid4() + + mock_client = MagicMock() + mock_flexprice.return_value.__enter__.return_value = mock_client + + svc.record_call_import_evaluation_row_completed( + org_id, + evaluation_id, + row_id, + workspace_id=workspace_id, + call_import_id=call_import_id, + metrics_scored=3, + ) + + mock_client.events.ingest_event.assert_called_once_with( + event_name="call_import.evaluation_row_completed", + external_customer_id=str(org_id), + event_id=f"{evaluation_id}:{row_id}", + source="efficientai", + properties={ + "workspace_id": str(workspace_id), + "feature": "call_imports", + "call_import_id": str(call_import_id), + "evaluation_id": str(evaluation_id), + "row_id": str(row_id), + "metrics_scored": 3, + "quantity": 1, + }, + ) From be29d97b2c6ee7984a199222d304729f5d9414c6 Mon Sep 17 00:00:00 2001 From: Tejas Narayan Date: Mon, 29 Jun 2026 12:44:46 +0000 Subject: [PATCH 06/11] feat: updating some details --- app/api/v1/routes/call_import_evaluations.py | 11 - app/api/v1/routes/call_imports.py | 15 +- .../052_call_import_evaluation_billed_rows.py | 59 +++++ app/models/database.py | 5 + app/services/billing/flexprice_service.py | 77 ++---- app/workers/tasks/evaluate_call_import_row.py | 35 ++- app/workers/tasks/process_call_import_row.py | 9 - scripts/remove_legacy_flexprice_meter.py | 95 +++++++ scripts/setup_flexprice_meters.py | 242 ++++++++++++++++++ tests/test_services/test_flexprice_service.py | 24 +- .../test_evaluate_call_import_row.py | 94 +++++++ 11 files changed, 550 insertions(+), 116 deletions(-) create mode 100644 app/migrations/052_call_import_evaluation_billed_rows.py create mode 100644 scripts/remove_legacy_flexprice_meter.py create mode 100644 scripts/setup_flexprice_meters.py diff --git a/app/api/v1/routes/call_import_evaluations.py b/app/api/v1/routes/call_import_evaluations.py index 5e3b9504..53b0a11e 100644 --- a/app/api/v1/routes/call_import_evaluations.py +++ b/app/api/v1/routes/call_import_evaluations.py @@ -32,7 +32,6 @@ get_workspace_id, require_enterprise_feature, ) -from app.services.billing.flexprice_service import record_call_import_evaluation_started from app.services.workspace_rbac import resolve_workspace_capabilities from app.models.database import ( AIProvider, @@ -1092,16 +1091,6 @@ def _name_for_source(source: str) -> Optional[str]: db.commit() for evaluation in created_evaluations: db.refresh(evaluation) - if evaluation.status == "running": - background_tasks.add_task( - record_call_import_evaluation_started, - organization_id, - evaluation.id, - workspace_id=call_import.workspace_id, - call_import_id=call_import.id, - row_count=evaluation.total_rows, - metric_count=len(leaf_metric_ids), - ) return _serialize_eval( db, primary_evaluation, sibling_evaluation_ids=sibling_ids ) diff --git a/app/api/v1/routes/call_imports.py b/app/api/v1/routes/call_imports.py index 538ae494..6c061b37 100644 --- a/app/api/v1/routes/call_imports.py +++ b/app/api/v1/routes/call_imports.py @@ -32,10 +32,7 @@ get_workspace_id, require_enterprise_feature, ) -from app.services.billing.flexprice_service import ( - record_call_import_batch_created, - record_call_import_row_imported, -) +from app.services.billing.flexprice_service import record_call_import_batch_created from app.models.database import ( CallImport, CallImportRow, @@ -2026,7 +2023,6 @@ async def upload_call_import_audio( ) total_size = sum(len(item["contents"]) for item in prepared) uploaded_keys: List[str] = [] - imported_row_ids: List[UUID] = [] from app.services.storage.s3_service import s3_service @@ -2086,7 +2082,6 @@ async def upload_call_import_audio( row.recording_s3_key = key row.recording_content_type = item["content_type"] row.recording_size_bytes = len(item["contents"]) - imported_row_ids.append(row.id) db.commit() except Exception as exc: @@ -2114,14 +2109,6 @@ async def upload_call_import_audio( source="audio", provider=None, ) - for row_id in imported_row_ids: - background_tasks.add_task( - record_call_import_row_imported, - organization_id, - row_id, - workspace_id=workspace_id, - call_import_id=call_import.id, - ) return CallImportUploadResponse( id=call_import.id, total_rows=call_import.total_rows, diff --git a/app/migrations/052_call_import_evaluation_billed_rows.py b/app/migrations/052_call_import_evaluation_billed_rows.py new file mode 100644 index 00000000..b4b3ede2 --- /dev/null +++ b/app/migrations/052_call_import_evaluation_billed_rows.py @@ -0,0 +1,59 @@ +""" +Migration: Flexprice billing watermark on call_import_evaluations. + +Adds ``billed_completed_rows`` so evaluation billing can emit one +``call_import.evaluation_completed`` event per pass with a delta +``quantity`` (newly completed rows since the last pass). +""" + +from sqlalchemy import text +from sqlalchemy.orm import Session + +description = ( + "Add billed_completed_rows on call_import_evaluations for pass-level " + "delta usage metering." +) + + +def _column_exists(db: Session, table_name: str, column_name: str) -> bool: + row = db.execute( + text( + """ + SELECT 1 + FROM information_schema.columns + WHERE table_name = :table_name AND column_name = :column_name + """ + ), + {"table_name": table_name, "column_name": column_name}, + ).first() + return row is not None + + +def upgrade(db: Session): + if _column_exists(db, "call_import_evaluations", "billed_completed_rows"): + print( + "call_import_evaluations.billed_completed_rows already exists, skipping..." + ) + return + + db.execute( + text( + """ + ALTER TABLE call_import_evaluations + ADD COLUMN billed_completed_rows INTEGER NOT NULL DEFAULT 0 + """ + ) + ) + print("Added call_import_evaluations.billed_completed_rows (integer, default 0)") + + +def downgrade(db: Session): + if not _column_exists(db, "call_import_evaluations", "billed_completed_rows"): + print("call_import_evaluations.billed_completed_rows missing, skipping...") + return + db.execute( + text( + "ALTER TABLE call_import_evaluations DROP COLUMN billed_completed_rows" + ) + ) + print("Dropped call_import_evaluations.billed_completed_rows") diff --git a/app/models/database.py b/app/models/database.py index 476fb7ed..dcc5ea85 100644 --- a/app/models/database.py +++ b/app/models/database.py @@ -2248,6 +2248,11 @@ class CallImportEvaluation(Base): total_rows = Column(Integer, nullable=False, default=0) completed_rows = Column(Integer, nullable=False, default=0) failed_rows = Column(Integer, nullable=False, default=0) + # Flexprice pass-level delta billing watermark: rows already emitted + # on ``call_import.evaluation_completed`` for this evaluation run. + billed_completed_rows = Column( + Integer, nullable=False, default=0, server_default="0" + ) error_message = Column(Text, nullable=True) celery_group_id = Column(String(255), nullable=True) diff --git a/app/services/billing/flexprice_service.py b/app/services/billing/flexprice_service.py index 4091fdda..adfd6b04 100644 --- a/app/services/billing/flexprice_service.py +++ b/app/services/billing/flexprice_service.py @@ -30,6 +30,7 @@ CALL_IMPORT_BATCH_CREATED = "call_import.batch_created" CALL_IMPORT_ROW_IMPORTED = "call_import.row_imported" CALL_IMPORT_EVALUATION_STARTED = "call_import.evaluation_started" +CALL_IMPORT_EVALUATION_COMPLETED = "call_import.evaluation_completed" CALL_IMPORT_EVALUATION_ROW_COMPLETED = "call_import.evaluation_row_completed" PLAYGROUND_WEB_CALL_STARTED = "playground.web_call_started" PLAYGROUND_WEBSOCKET_SESSION_STARTED = "playground.websocket_session_started" @@ -77,17 +78,20 @@ def _ingest_usage_event(client, payload: dict) -> None: ingest(request=payload) -def _coerce_properties(properties: Optional[dict[str, Any]]) -> dict[str, Any]: +def _coerce_properties(properties: Optional[dict[str, Any]]) -> dict[str, str]: + """Normalize event properties for Flexprice ingest (SDK expects string values).""" if not properties: return {} - out: dict[str, Any] = {} + out: dict[str, str] = {} for key, value in properties.items(): if value is None: continue - if isinstance(value, UUID): + if isinstance(value, bool): + out[key] = "true" if value else "false" + elif isinstance(value, (int, float, UUID)): out[key] = str(value) else: - out[key] = value + out[key] = str(value) return out @@ -313,81 +317,40 @@ def record_call_import_batch_created( "feature": FEATURE_CALL_IMPORTS, "call_import_id": call_import_id, "total_rows": total_rows, + "quantity": total_rows, "source": source, "provider": provider, }, ) -def record_call_import_row_imported( - organization_id: UUID, - row_id: UUID, - *, - workspace_id: UUID, - call_import_id: UUID, - duration_seconds: Optional[float] = None, -) -> None: - record_event( - CALL_IMPORT_ROW_IMPORTED, - organization_id, - row_id, - properties={ - "workspace_id": workspace_id, - "feature": FEATURE_CALL_IMPORTS, - "call_import_id": call_import_id, - "row_id": row_id, - "duration_seconds": duration_seconds, - "quantity": 1, - }, - ) +# --- Call imports (evaluations) --- -def record_call_import_evaluation_started( +def record_call_import_evaluation_completed( organization_id: UUID, evaluation_id: UUID, *, workspace_id: UUID, call_import_id: UUID, - row_count: int, - metric_count: int, + rows_billed: int, + completed_total: int, + metric_count: int = 0, ) -> None: + """Bill one pass of an evaluation run for newly completed rows.""" record_event( - CALL_IMPORT_EVALUATION_STARTED, + CALL_IMPORT_EVALUATION_COMPLETED, organization_id, - evaluation_id, + f"{evaluation_id}:{completed_total}", properties={ "workspace_id": workspace_id, "feature": FEATURE_CALL_IMPORTS, "call_import_id": call_import_id, "evaluation_id": evaluation_id, - "row_count": row_count, + "rows_billed": rows_billed, + "completed_total": completed_total, "metric_count": metric_count, - "quantity": row_count * metric_count, - }, - ) - - -def record_call_import_evaluation_row_completed( - organization_id: UUID, - evaluation_id: UUID, - row_id: UUID, - *, - workspace_id: UUID, - call_import_id: UUID, - metrics_scored: int, -) -> None: - record_event( - CALL_IMPORT_EVALUATION_ROW_COMPLETED, - organization_id, - f"{evaluation_id}:{row_id}", - properties={ - "workspace_id": workspace_id, - "feature": FEATURE_CALL_IMPORTS, - "call_import_id": call_import_id, - "evaluation_id": evaluation_id, - "row_id": row_id, - "metrics_scored": metrics_scored, - "quantity": 1, + "quantity": rows_billed, }, ) diff --git a/app/workers/tasks/evaluate_call_import_row.py b/app/workers/tasks/evaluate_call_import_row.py index 7641feb9..8d2706ee 100644 --- a/app/workers/tasks/evaluate_call_import_row.py +++ b/app/workers/tasks/evaluate_call_import_row.py @@ -364,6 +364,7 @@ def _build_parent_groups( def _rollup_parent(db, evaluation: CallImportEvaluation) -> None: + previous_status = evaluation.status rows = ( db.query(CallImportEvaluationRow.status) .filter(CallImportEvaluationRow.evaluation_id == evaluation.id) @@ -394,6 +395,26 @@ def _rollup_parent(db, evaluation: CallImportEvaluation) -> None: else: evaluation.status = "partial" + if previous_status == "running": + already_billed = int(getattr(evaluation, "billed_completed_rows", 0) or 0) + delta = completed - already_billed + if delta > 0: + from app.services.billing.flexprice_service import ( + record_call_import_evaluation_completed, + ) + + metric_count = len(evaluation.selected_metric_ids or []) + record_call_import_evaluation_completed( + evaluation.organization_id, + evaluation.id, + workspace_id=evaluation.workspace_id, + call_import_id=evaluation.call_import_id, + rows_billed=delta, + completed_total=completed, + metric_count=metric_count, + ) + evaluation.billed_completed_rows = completed + # Per-task time limits keep a wedged audio evaluation (e.g. torch.hub UTMOS # download stuck on a network hiccup, or libgomp deadlock in a prefork child) @@ -1059,20 +1080,6 @@ def _resolve_pm( _rollup_parent(db, evaluation) db.commit() - if eval_row.status == "completed": - from app.services.billing.flexprice_service import ( - record_call_import_evaluation_row_completed, - ) - - record_call_import_evaluation_row_completed( - evaluation.organization_id, - evaluation.id, - source_row.id, - workspace_id=evaluation.workspace_id, - call_import_id=evaluation.call_import_id, - metrics_scored=len(eval_row.metric_scores or {}), - ) - return { "status": eval_row.status, "eval_row_id": eval_row_id, diff --git a/app/workers/tasks/process_call_import_row.py b/app/workers/tasks/process_call_import_row.py index d30490b5..c02d2fad 100644 --- a/app/workers/tasks/process_call_import_row.py +++ b/app/workers/tasks/process_call_import_row.py @@ -433,15 +433,6 @@ def process_call_import_row_task(self, row_id: str): _rollup_parent_status(db, call_import) db.commit() - from app.services.billing.flexprice_service import record_call_import_row_imported - - record_call_import_row_imported( - row.organization_id, - row.id, - workspace_id=call_import.workspace_id, - call_import_id=call_import.id, - ) - return { "status": "completed", "row_id": row_id, diff --git a/scripts/remove_legacy_flexprice_meter.py b/scripts/remove_legacy_flexprice_meter.py new file mode 100644 index 00000000..a11656f6 --- /dev/null +++ b/scripts/remove_legacy_flexprice_meter.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python3 +"""Remove legacy Flexprice meters/features superseded by the metering catalog.""" + +from __future__ import annotations + +import json +import sys +from pathlib import Path + +import httpx +from flexprice import Flexprice + +from app.config import load_config_from_file, settings + +CONFIG_PATH = Path(__file__).resolve().parent.parent / "config.yml" + +LEGACY_EVENT_NAMES = { + "BLIND_TEST_SHARE_CREATED_EVENT", +} + + +def _headers() -> dict[str, str]: + return {"x-api-key": settings.FLEXPRICE_API_KEY or ""} + + +def _base_url() -> str: + return (settings.FLEXPRICE_API_HOST or "").rstrip("/") + + +def _delete_meter(client: httpx.Client, meter_id: str) -> dict: + for method, suffix in (("DELETE", ""), ("POST", "/archive")): + url = f"{_base_url()}/meters/{meter_id}{suffix}" + if method == "DELETE": + resp = client.delete(url, headers=_headers()) + else: + resp = client.post(url, headers=_headers()) + if resp.status_code in (200, 204, 404): + return {"meter_id": meter_id, "method": method, "status": resp.status_code} + resp.raise_for_status() + return {"meter_id": meter_id, "status": resp.status_code} + + +def main() -> int: + if CONFIG_PATH.exists(): + load_config_from_file(str(CONFIG_PATH)) + + if not settings.FLEXPRICE_API_KEY: + print("FLEXPRICE_API_KEY is missing.", file=sys.stderr) + return 1 + + results: dict[str, list] = {"meters_removed": [], "features_removed": [], "errors": []} + + with httpx.Client(timeout=60.0) as client: + resp = client.get(f"{_base_url()}/meters", headers=_headers(), params={"limit": 200}) + resp.raise_for_status() + for meter in resp.json().get("items") or []: + event_name = meter.get("event_name") or "" + if event_name not in LEGACY_EVENT_NAMES: + continue + try: + results["meters_removed"].append( + {"event_name": event_name, **_delete_meter(client, meter["id"])} + ) + except Exception as exc: + results["errors"].append(f"meter {event_name}: {exc}") + + with Flexprice( + server_url=settings.FLEXPRICE_API_HOST, + api_key_auth=settings.FLEXPRICE_API_KEY, + ) as sdk: + resp = sdk.features.query_feature(limit=200) + for item in resp.items or []: + lookup = (item.lookup_key or "").lower() + name = (item.name or "").upper() + meter = item.meter + meter_event = (meter.event_name if meter else "") or "" + if ( + meter_event in LEGACY_EVENT_NAMES + or lookup == "feat-blind_test_share_created_event" + or name in LEGACY_EVENT_NAMES + ): + try: + sdk.features.delete_feature(id=item.id) + results["features_removed"].append( + {"id": item.id, "lookup_key": item.lookup_key, "name": item.name} + ) + except Exception as exc: + results["errors"].append(f"feature {item.lookup_key or item.name}: {exc}") + + print(json.dumps(results, indent=2)) + return 1 if results["errors"] else 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/setup_flexprice_meters.py b/scripts/setup_flexprice_meters.py new file mode 100644 index 00000000..533526c7 --- /dev/null +++ b/scripts/setup_flexprice_meters.py @@ -0,0 +1,242 @@ +#!/usr/bin/env python3 +"""Create Flexprice features and meters for EfficientAI usage metering catalog.""" + +from __future__ import annotations + +import json +import sys +from pathlib import Path +from typing import Any + +import httpx +from flexprice import Flexprice + +from app.config import load_config_from_file, settings + +CONFIG_PATH = Path(__file__).resolve().parent.parent / "config.yml" + +# (event_name, display_name, aggregation_type, aggregation_field|None) +METERS: list[tuple[str, str, str, str | None]] = [ + # Voice playground + ("blind_test.share_created", "Blind Test Share Created", "COUNT", None), + ("blind_test.response_submitted", "Blind Test Response Submitted", "COUNT", None), + ("tts.generation_started", "TTS Generation Started", "COUNT", None), + ("tts.sample_synthesized", "TTS Sample Synthesized", "SUM", "quantity"), + ("tts.report_requested", "TTS Report Requested", "COUNT", None), + ("tts.report_completed", "TTS Report Completed", "COUNT", None), + # Call imports + ("call_import.batch_created", "Call Import Batch Created", "SUM", "quantity"), + ("call_import.evaluation_completed", "Call Import Evaluation Completed", "SUM", "quantity"), + # Agent playground + ("playground.web_call_started", "Playground Web Call Started", "COUNT", None), + ("playground.websocket_session_started", "Playground Websocket Session Started", "COUNT", None), + ("playground.call_evaluated", "Playground Call Evaluated", "COUNT", None), + ("playground.evaluation_completed", "Playground Evaluation Completed", "COUNT", None), + # Evaluators + ("evaluator.run_requested", "Evaluator Run Requested", "SUM", "quantity"), + ("evaluator.run_completed", "Evaluator Run Completed", "COUNT", None), + # Legacy evaluations + ("evaluation.created", "Evaluation Created", "COUNT", None), + ("evaluation.completed", "Evaluation Completed", "COUNT", None), + # Prompt optimization + ("prompt_optimization.run_started", "Prompt Optimization Run Started", "COUNT", None), + ("prompt_optimization.run_completed", "Prompt Optimization Run Completed", "COUNT", None), + # Judge alignment + ("judge_alignment.run_started", "Judge Alignment Run Started", "COUNT", None), + ("judge_alignment.run_completed", "Judge Alignment Run Completed", "COUNT", None), + # Observability + ("observability.call_ingested", "Observability Call Ingested", "COUNT", None), + ("observability.call_evaluated", "Observability Call Evaluated", "COUNT", None), + # Test agents + ("test_agent.conversation_started", "Test Agent Conversation Started", "COUNT", None), + ("test_agent.conversation_ended", "Test Agent Conversation Ended", "SUM", "quantity"), + # LLM assist + ("metrics.llm_assist", "Metrics LLM Assist", "COUNT", None), + ("chat.completion", "Chat Completion", "SUM", "quantity"), +] + +LICENSE_FEATURES: list[dict[str, Any]] = [ + { + "name": "Call Imports", + "lookup_key": "call_imports", + "description": "CSV/audio call import batches, row processing, and evaluations", + "unit_singular": "batch", + "unit_plural": "batches", + "event_name": "call_import.batch_created", + "aggregation": {"type": "COUNT"}, + }, + { + "name": "Voice Playground", + "lookup_key": "voice_playground", + "description": "TTS comparisons, blind tests, and voice quality reports", + "unit_singular": "share", + "unit_plural": "shares", + "event_name": "blind_test.share_created", + "aggregation": {"type": "COUNT"}, + }, + { + "name": "GEPA Optimization", + "lookup_key": "gepa_optimization", + "description": "Prompt optimization (GEPA) runs", + "unit_singular": "run", + "unit_plural": "runs", + "event_name": "prompt_optimization.run_started", + "aggregation": {"type": "COUNT"}, + }, +] + + +def _headers() -> dict[str, str]: + return {"x-api-key": settings.FLEXPRICE_API_KEY or "", "Content-Type": "application/json"} + + +def _base_url() -> str: + return (settings.FLEXPRICE_API_HOST or "").rstrip("/") + + +def _meter_payload(name: str, event_name: str, agg_type: str, field: str | None) -> dict[str, Any]: + aggregation: dict[str, str] = {"type": agg_type} + if field and agg_type in {"SUM", "MAX", "LATEST", "COUNT_UNIQUE", "AVG"}: + aggregation["field"] = field + return { + "name": name, + "event_name": event_name, + "aggregation": aggregation, + "reset_usage": "BILLING_PERIOD", + } + + +def _list_meters(client: httpx.Client) -> dict[str, dict]: + existing: dict[str, dict] = {} + offset = 0 + while True: + resp = client.get( + f"{_base_url()}/meters", + headers=_headers(), + params={"limit": 200, "offset": offset}, + ) + resp.raise_for_status() + data = resp.json() + items = data.get("items") or [] + for item in items: + event_name = item.get("event_name") + if event_name: + existing[event_name] = item + pagination = data.get("pagination") or {} + total = pagination.get("total") + offset += len(items) + if not items or (total is not None and offset >= total): + break + return existing + + +def _create_meter(client: httpx.Client, event_name: str, name: str, agg_type: str, field: str | None) -> dict: + payload = _meter_payload(name, event_name, agg_type, field) + resp = client.post(f"{_base_url()}/meters", headers=_headers(), json=payload) + if resp.status_code == 409 or ( + resp.status_code == 400 and "exist" in resp.text.lower() + ): + return {"skipped": True, "event_name": event_name, "detail": resp.text} + resp.raise_for_status() + return resp.json() + + +def _list_feature_lookup_keys(sdk: Flexprice) -> set[str]: + keys: set[str] = set() + offset = 0 + while True: + resp = sdk.features.query_feature(limit=200, offset=offset) + items = resp.items or [] + for item in items: + if item.lookup_key: + keys.add(item.lookup_key) + offset += len(items) + if not items: + break + return keys + + +def _create_license_feature(sdk: Flexprice, spec: dict[str, Any]) -> dict: + meter = { + "name": spec["name"], + "event_name": spec["event_name"], + "aggregation": spec["aggregation"], + "reset_usage": "BILLING_PERIOD", + } + try: + result = sdk.features.create_feature( + name=spec["name"], + type_="metered", + lookup_key=spec["lookup_key"], + description=spec["description"], + unit_singular=spec.get("unit_singular"), + unit_plural=spec.get("unit_plural"), + meter=meter, + ) + return {"created": True, "lookup_key": spec["lookup_key"], "id": result.id} + except Exception as exc: + message = str(exc).lower() + if "already exist" in message or "duplicate" in message: + return {"skipped": True, "lookup_key": spec["lookup_key"], "detail": str(exc)} + raise + + +def main() -> int: + if CONFIG_PATH.exists(): + load_config_from_file(str(CONFIG_PATH)) + + if not settings.FLEXPRICE_ENABLED or not settings.FLEXPRICE_API_KEY: + print("Flexprice is not enabled or FLEXPRICE_API_KEY is missing.", file=sys.stderr) + return 1 + + created_meters: list[str] = [] + skipped_meters: list[str] = [] + failed_meters: list[str] = [] + created_features: list[str] = [] + skipped_features: list[str] = [] + failed_features: list[str] = [] + + with httpx.Client(timeout=60.0) as http_client: + existing_meters = _list_meters(http_client) + for event_name, name, agg_type, field in METERS: + if event_name in existing_meters: + skipped_meters.append(event_name) + continue + try: + result = _create_meter(http_client, event_name, name, agg_type, field) + if result.get("skipped"): + skipped_meters.append(event_name) + else: + created_meters.append(event_name) + except Exception as exc: + failed_meters.append(f"{event_name}: {exc}") + + with Flexprice( + server_url=settings.FLEXPRICE_API_HOST, + api_key_auth=settings.FLEXPRICE_API_KEY, + ) as sdk: + existing_feature_keys = _list_feature_lookup_keys(sdk) + for spec in LICENSE_FEATURES: + key = spec["lookup_key"] + if key in existing_feature_keys: + skipped_features.append(key) + continue + try: + result = _create_license_feature(sdk, spec) + if result.get("skipped"): + skipped_features.append(key) + else: + created_features.append(key) + except Exception as exc: + failed_features.append(f"{key}: {exc}") + + summary = { + "meters": {"created": created_meters, "skipped": skipped_meters, "failed": failed_meters}, + "features": {"created": created_features, "skipped": skipped_features, "failed": failed_features}, + } + print(json.dumps(summary, indent=2)) + return 1 if (failed_meters or failed_features) else 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tests/test_services/test_flexprice_service.py b/tests/test_services/test_flexprice_service.py index 78c60eb9..db028f9c 100644 --- a/tests/test_services/test_flexprice_service.py +++ b/tests/test_services/test_flexprice_service.py @@ -185,45 +185,47 @@ def test_record_call_import_batch_created_includes_volume_properties(mock_flexpr mock_client.events.ingest_event.assert_called_once() payload = mock_client.events.ingest_event.call_args.kwargs assert payload["event_name"] == "call_import.batch_created" - assert payload["properties"]["total_rows"] == 42 + assert payload["properties"]["total_rows"] == "42" + assert payload["properties"]["quantity"] == "42" assert payload["properties"]["feature"] == "call_imports" @patch("flexprice.Flexprice") -def test_record_call_import_evaluation_row_completed_uses_composite_event_id(mock_flexprice): +def test_record_call_import_evaluation_completed_meters_pass_delta(mock_flexprice): settings.FLEXPRICE_ENABLED = True settings.FLEXPRICE_API_KEY = "test-key" org_id = uuid4() evaluation_id = uuid4() - row_id = uuid4() workspace_id = uuid4() call_import_id = uuid4() mock_client = MagicMock() mock_flexprice.return_value.__enter__.return_value = mock_client - svc.record_call_import_evaluation_row_completed( + svc.record_call_import_evaluation_completed( org_id, evaluation_id, - row_id, workspace_id=workspace_id, call_import_id=call_import_id, - metrics_scored=3, + rows_billed=50, + completed_total=1950, + metric_count=5, ) mock_client.events.ingest_event.assert_called_once_with( - event_name="call_import.evaluation_row_completed", + event_name="call_import.evaluation_completed", external_customer_id=str(org_id), - event_id=f"{evaluation_id}:{row_id}", + event_id=f"{evaluation_id}:1950", source="efficientai", properties={ "workspace_id": str(workspace_id), "feature": "call_imports", "call_import_id": str(call_import_id), "evaluation_id": str(evaluation_id), - "row_id": str(row_id), - "metrics_scored": 3, - "quantity": 1, + "rows_billed": "50", + "completed_total": "1950", + "metric_count": "5", + "quantity": "50", }, ) diff --git a/tests/test_workers/test_evaluate_call_import_row.py b/tests/test_workers/test_evaluate_call_import_row.py index 8ea87e09..8ef04cde 100644 --- a/tests/test_workers/test_evaluate_call_import_row.py +++ b/tests/test_workers/test_evaluate_call_import_row.py @@ -840,3 +840,97 @@ def _simulate_cancel(*_args, **kwargs): refreshed_eval = db_session.get(CallImportEvaluation, evaluation.id) assert refreshed_eval.failed_rows == 1 assert refreshed_eval.completed_rows == 0 + + +def test_rollup_parent_emits_pass_delta_when_leaving_running(db_session, monkeypatch): + from app.workers.tasks.evaluate_call_import_row import _rollup_parent + + recorded = [] + + def _capture(*_args, **kwargs): + recorded.append(kwargs) + + monkeypatch.setattr( + "app.services.billing.flexprice_service.record_call_import_evaluation_completed", + _capture, + ) + + _, _, _, _, evaluation, eval_rows = _seed(db_session, row_count=2) + evaluation.status = "running" + eval_rows[0].status = "completed" + eval_rows[1].status = "failed" + db_session.commit() + + _rollup_parent(db_session, evaluation) + + assert evaluation.status == "partial" + assert evaluation.billed_completed_rows == 1 + assert len(recorded) == 1 + assert recorded[0]["rows_billed"] == 1 + assert recorded[0]["completed_total"] == 1 + + +def test_rollup_parent_bills_only_retry_delta(db_session, monkeypatch): + from app.workers.tasks.evaluate_call_import_row import _rollup_parent + + recorded = [] + + def _capture(*_args, **kwargs): + recorded.append(kwargs) + + monkeypatch.setattr( + "app.services.billing.flexprice_service.record_call_import_evaluation_completed", + _capture, + ) + + _, _, _, _, evaluation, eval_rows = _seed(db_session, row_count=10) + evaluation.status = "running" + evaluation.billed_completed_rows = 8 + for idx in range(8): + eval_rows[idx].status = "completed" + for idx in range(8, 10): + eval_rows[idx].status = "pending" + db_session.commit() + + _rollup_parent(db_session, evaluation) + assert evaluation.status == "running" + assert recorded == [] + + for idx in range(8, 10): + eval_rows[idx].status = "completed" + db_session.flush() + _rollup_parent(db_session, evaluation) + + assert evaluation.status == "completed" + assert evaluation.billed_completed_rows == 10 + assert len(recorded) == 1 + assert recorded[0]["rows_billed"] == 2 + assert recorded[0]["completed_total"] == 10 + + +def test_rollup_parent_skips_billing_when_metric_rerun_unchanged( + db_session, monkeypatch +): + from app.workers.tasks.evaluate_call_import_row import _rollup_parent + + recorded = [] + + def _capture(*_args, **kwargs): + recorded.append(kwargs) + + monkeypatch.setattr( + "app.services.billing.flexprice_service.record_call_import_evaluation_completed", + _capture, + ) + + _, _, _, _, evaluation, eval_rows = _seed(db_session, row_count=5) + evaluation.status = "running" + evaluation.billed_completed_rows = 5 + for eval_row in eval_rows: + eval_row.status = "completed" + db_session.commit() + + _rollup_parent(db_session, evaluation) + + assert evaluation.billed_completed_rows == 5 + assert recorded == [] From 69e802f2cf27ac9c21e5a6a8ba54e1cb37de08b1 Mon Sep 17 00:00:00 2001 From: Tejas Narayan Date: Mon, 29 Jun 2026 14:34:03 +0000 Subject: [PATCH 07/11] fix: testcases --- tests/test_api/test_voice_playground_routes.py | 2 +- .../test_evaluation/test_evaluation_service.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/test_api/test_voice_playground_routes.py b/tests/test_api/test_voice_playground_routes.py index c627a686..97397f3e 100644 --- a/tests/test_api/test_voice_playground_routes.py +++ b/tests/test_api/test_voice_playground_routes.py @@ -307,7 +307,7 @@ def _seed_blind_test_ready_comparison(db_session, org_id, workspace_id): id=uuid4(), organization_id=org_id, workspace_id=workspace_id, - simulation_id="sim-blind-share", + simulation_id="sim003", name="Blind share metering test", status=TTSComparisonStatus.COMPLETED.value, mode="benchmark", diff --git a/tests/test_services/test_evaluation/test_evaluation_service.py b/tests/test_services/test_evaluation/test_evaluation_service.py index 542e6588..8b71c747 100644 --- a/tests/test_services/test_evaluation/test_evaluation_service.py +++ b/tests/test_services/test_evaluation/test_evaluation_service.py @@ -49,9 +49,13 @@ def __init__(self, **kwargs): def test_process_evaluation_success(monkeypatch): evaluation_id = uuid4() audio_id = uuid4() + org_id = uuid4() + workspace_id = uuid4() evaluation = SimpleNamespace( id=evaluation_id, audio_id=audio_id, + organization_id=org_id, + workspace_id=workspace_id, model_name="base", metrics_requested=["wer", "latency"], reference_text="hello world", @@ -71,6 +75,10 @@ def test_process_evaluation_success(monkeypatch): lambda **_kwargs: {"wer": 0.0, "latency_s": 0.5, "latency_ms": 500.0}, ) monkeypatch.setattr(evaluation_module, "EvaluationResult", _FakeEvaluationResult) + monkeypatch.setattr( + "app.services.billing.flexprice_service.record_evaluation_completed", + lambda *_a, **_k: None, + ) service = EvaluationService() result = service.process_evaluation(evaluation_id, db) From 248452fb0cdfcb5fa7a9a06958e7f0fe4ba696c2 Mon Sep 17 00:00:00 2001 From: Tejas Narayan Date: Mon, 29 Jun 2026 19:27:03 +0000 Subject: [PATCH 08/11] feat: updating logs --- app/main.py | 4 + app/services/billing/flexprice_service.py | 121 ++++++- app/workers/config.py | 8 + config.yml.example | 2 + docker-compose.yml | 296 +++++++++--------- docker/Dockerfile.api | 154 ++++----- tests/test_services/test_flexprice_service.py | 10 + 7 files changed, 355 insertions(+), 240 deletions(-) diff --git a/app/main.py b/app/main.py index ee29a224..a1747da4 100644 --- a/app/main.py +++ b/app/main.py @@ -87,6 +87,10 @@ async def lifespan(app: FastAPI): else: logger.info("All migrations are up to date") + from app.services.billing.flexprice_service import log_startup_status + + log_startup_status(component="api") + logger.info("=" * 60) logger.info("Application startup complete - Ready to serve requests") logger.info("=" * 60) diff --git a/app/services/billing/flexprice_service.py b/app/services/billing/flexprice_service.py index adfd6b04..665ea9b1 100644 --- a/app/services/billing/flexprice_service.py +++ b/app/services/billing/flexprice_service.py @@ -8,6 +8,7 @@ from __future__ import annotations +import os from typing import Any, Optional, Union from uuid import UUID @@ -20,6 +21,9 @@ FEATURE_VOICE_PLAYGROUND = "voice_playground" FEATURE_GEPA = "gepa_optimization" +# Log once when metering is inactive so AWS/worker misconfig is obvious. +_disabled_skip_logged = False + # Event names BLIND_TEST_SHARE_CREATED = "blind_test.share_created" BLIND_TEST_RESPONSE_SUBMITTED = "blind_test.response_submitted" @@ -52,9 +56,51 @@ CHAT_COMPLETION = "chat.completion" +def _verbose_logging() -> bool: + """Extra per-event logs when FLEXPRICE_VERBOSE=1 (useful on AWS).""" + return os.getenv("FLEXPRICE_VERBOSE", "").lower() in {"1", "true", "yes"} + + +def _mask_api_key(api_key: Optional[str]) -> str: + if not api_key: + return "(missing)" + if len(api_key) <= 8: + return "****" + return f"****{api_key[-4:]}" + + +def disabled_reason() -> Optional[str]: + """Human-readable reason metering is off, or None when active.""" + if not settings.FLEXPRICE_ENABLED: + return "flexprice.enabled is false (or FLEXPRICE_ENABLED unset)" + if not settings.FLEXPRICE_API_KEY: + return "flexprice.api_key is unset (or FLEXPRICE_API_KEY env missing)" + return None + + def is_enabled() -> bool: """Return True only when Flexprice is explicitly enabled with an API key.""" - return bool(settings.FLEXPRICE_ENABLED and settings.FLEXPRICE_API_KEY) + return disabled_reason() is None + + +def log_startup_status(*, component: str = "app") -> None: + """Log Flexprice config at process start (API + Celery worker).""" + reason = disabled_reason() + if reason: + logger.info( + "Flexprice metering INACTIVE for {} — {} (api_host={})", + component, + reason, + settings.FLEXPRICE_API_HOST, + ) + return + + logger.info( + "Flexprice metering ACTIVE for {} — api_host={} api_key={}", + component, + settings.FLEXPRICE_API_HOST, + _mask_api_key(settings.FLEXPRICE_API_KEY), + ) def _is_customer_already_exists(exc: Exception) -> bool: @@ -103,9 +149,29 @@ def record_event( properties: Optional[dict[str, Any]] = None, ) -> None: """Ingest a usage event. No-op when Flexprice is disabled; never raises.""" - if not is_enabled(): + global _disabled_skip_logged + + inactive_reason = disabled_reason() + if inactive_reason: + if not _disabled_skip_logged: + logger.warning( + "Flexprice metering inactive — {}. Usage events will be dropped until fixed.", + inactive_reason, + ) + _disabled_skip_logged = True + if _verbose_logging(): + logger.info( + "Flexprice SKIP {} org={} event_id={} ({})", + event_name, + organization_id, + event_id, + inactive_reason, + ) return + coerced = _coerce_properties(properties) + quantity = coerced.get("quantity") + try: from flexprice import Flexprice @@ -113,21 +179,29 @@ def record_event( server_url=settings.FLEXPRICE_API_HOST, api_key_auth=settings.FLEXPRICE_API_KEY, ) as client: - _ingest_usage_event( - client, - { - "event_name": event_name, - "external_customer_id": str(organization_id), - "event_id": str(event_id), - "source": EVENT_SOURCE, - "properties": _coerce_properties(properties), - }, - ) + payload = { + "event_name": event_name, + "external_customer_id": str(organization_id), + "event_id": str(event_id), + "source": EVENT_SOURCE, + "properties": coerced, + } + _ingest_usage_event(client, payload) + + logger.info( + "Flexprice ingested {} org={} event_id={} quantity={}", + event_name, + organization_id, + event_id, + quantity if quantity is not None else "n/a", + ) except Exception as exc: logger.warning( - "Flexprice {} ingest failed (event_id={}): {}", + "Flexprice {} ingest FAILED org={} event_id={} host={} error={}", event_name, + organization_id, event_id, + settings.FLEXPRICE_API_HOST, exc, ) @@ -139,7 +213,14 @@ def ensure_customer( email: Optional[str] = None, ) -> None: """Register an organization as a Flexprice customer. No-op when disabled.""" - if not is_enabled(): + inactive_reason = disabled_reason() + if inactive_reason: + if _verbose_logging(): + logger.info( + "Flexprice SKIP ensure_customer org={} ({})", + organization_id, + inactive_reason, + ) return try: @@ -154,12 +235,22 @@ def ensure_customer( name=name, email=email, ) + logger.info( + "Flexprice ensure_customer ok org={} name={}", + organization_id, + name, + ) except Exception as exc: if _is_customer_already_exists(exc): + logger.debug( + "Flexprice ensure_customer already exists org={}", + organization_id, + ) return logger.warning( - "Flexprice ensure_customer failed for org {}: {}", + "Flexprice ensure_customer FAILED org={} host={} error={}", organization_id, + settings.FLEXPRICE_API_HOST, exc, ) diff --git a/app/workers/config.py b/app/workers/config.py index 250840fe..bdc8db4b 100644 --- a/app/workers/config.py +++ b/app/workers/config.py @@ -116,3 +116,11 @@ "generate_agent_flowchart": {"queue": "imports"}, "map_agent_flowchart_prompt_sections": {"queue": "imports"}, } + + +@celery_app.on_after_configure.connect +def _log_flexprice_on_worker_configure(sender, **kwargs) -> None: + del sender, kwargs + from app.services.billing.flexprice_service import log_startup_status + + log_startup_status(component="celery-worker") diff --git a/config.yml.example b/config.yml.example index 7851a801..6b502edf 100644 --- a/config.yml.example +++ b/config.yml.example @@ -179,4 +179,6 @@ judge_alignment: # enabled: false # api_key: null # api_host: "https://us.api.flexprice.io/v1" +# Env override: FLEXPRICE_ENABLED, FLEXPRICE_API_KEY, FLEXPRICE_API_HOST +# Debug: FLEXPRICE_VERBOSE=1 logs every skipped event (when inactive or verbose) diff --git a/docker-compose.yml b/docker-compose.yml index cb0795c1..e0c6a7b6 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,148 +1,148 @@ -# EfficientAI Docker Compose -# -# Usage: -# docker compose up # Uses latest images -# EFFICIENTAI_VERSION=1.0.0 docker compose up # Uses specific version -# -# For local development, uncomment the 'build' sections below. - -services: - db: - image: postgres:15 - container_name: efficientai_db - environment: - POSTGRES_USER: ${POSTGRES_USER:-efficientai} - POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-password} - POSTGRES_DB: ${POSTGRES_DB:-efficientai} - volumes: - - postgres_data:/var/lib/postgresql/data - ports: - - "5432:5432" - healthcheck: - test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-efficientai}"] - interval: 10s - timeout: 5s - retries: 5 - - redis: - image: redis:7-alpine - container_name: efficientai_redis - ports: - - "6379:6379" - healthcheck: - test: ["CMD", "redis-cli", "ping"] - interval: 10s - timeout: 5s - retries: 5 - - api: - # Pre-built image from GitHub Container Registry - # Use EFFICIENTAI_VERSION env var to pin to a specific version (e.g., 1.0.0) - image: ghcr.io/efficientai-tech/efficientai-api:${EFFICIENTAI_VERSION:-latest} - # For local development, uncomment below and comment out the image line: - build: - context: . - dockerfile: docker/Dockerfile.api - container_name: efficientai_api - environment: - DATABASE_URL: postgresql://${POSTGRES_USER:-efficientai}:${POSTGRES_PASSWORD:-password}@db:5432/${POSTGRES_DB:-efficientai} - REDIS_URL: redis://redis:6379/0 - CELERY_BROKER_URL: redis://redis:6379/0 - CELERY_RESULT_BACKEND: redis://redis:6379/0 - SECRET_KEY: ${SECRET_KEY:-your-secret-key-change-in-production} - UPLOAD_DIR: /app/uploads - DEBUG: ${DEBUG:-True} - FRONTEND_DIR: /app/frontend/dist - ENCRYPTION_KEY: ${ENCRYPTION_KEY:-} - # Optional GCS blob storage (set storage.blob_provider: gcs in config.docker.yml): - # BLOB_STORAGE_PROVIDER: gcs - # GCS_BUCKET_NAME: your-gcs-bucket - # GCS_PROJECT_ID: your-gcp-project - # GOOGLE_APPLICATION_CREDENTIALS: /app/secrets/gcp-sa.json - # Optional Azure Blob Storage (set storage.blob_provider: azure in config.docker.yml): - # BLOB_STORAGE_PROVIDER: azure - # AZURE_BLOB_ENABLED: "true" - # AZURE_CONNECTION_STRING: DefaultEndpointsProtocol=https;AccountName=... - # AZURE_CONTAINER_NAME: your-container - volumes: - - ./uploads:/app/uploads - - ./.data:/app/.data - - ./config.docker.yml:/app/config.yml:ro - # Optional: mount GCP service account for GCS auth - # - ./secrets/gcp-sa.json:/app/secrets/gcp-sa.json:ro - ports: - - "8000:8000" - depends_on: - db: - condition: service_healthy - redis: - condition: service_healthy - command: eai start --config /app/config.yml --host 0.0.0.0 --port 8000 --no-build-frontend --no-reload - - worker: - # Pre-built image from GitHub Container Registry - # Use EFFICIENTAI_VERSION env var to pin to a specific version (e.g., 1.0.0) - image: ghcr.io/efficientai-tech/efficientai-worker:${EFFICIENTAI_VERSION:-latest} - # For local development, uncomment below and comment out the image line: - build: - context: . - dockerfile: docker/Dockerfile.worker - args: - INSTALL_EXTRAS: "qualitative-voice" - container_name: efficientai_worker - environment: - # Worker uses Docker network, so it reaches DB/Redis via service names - DATABASE_URL: postgresql://${POSTGRES_USER:-efficientai}:${POSTGRES_PASSWORD:-password}@db:5432/${POSTGRES_DB:-efficientai} - REDIS_URL: redis://redis:6379/0 - CELERY_BROKER_URL: redis://redis:6379/0 - CELERY_RESULT_BACKEND: redis://redis:6379/0 - ENCRYPTION_KEY: ${ENCRYPTION_KEY:-} - # Optional GCS blob storage (match api service env when using gcs): - # GOOGLE_APPLICATION_CREDENTIALS: /app/secrets/gcp-sa.json - depends_on: - db: - condition: service_healthy - redis: - condition: service_healthy - volumes: - - ./uploads:/app/uploads - - ./.data:/app/.data - - ./config.docker.yml:/app/config.yml:ro - # Optional: mount GCP service account for GCS auth - # - ./secrets/gcp-sa.json:/app/secrets/gcp-sa.json:ro - command: eai worker --config /app/config.yml --loglevel info - - # Dedicated worker for the call-import fan-out so large CSV imports - # don't starve synthetic-calling, audio generation, and evaluation jobs - # running on the default queue handled by the `worker` service above. - worker-imports: - image: ghcr.io/efficientai-tech/efficientai-worker:${EFFICIENTAI_VERSION:-latest} - build: - context: . - dockerfile: docker/Dockerfile.worker - args: - INSTALL_EXTRAS: "qualitative-voice" - container_name: efficientai_worker_imports - environment: - DATABASE_URL: postgresql://${POSTGRES_USER:-efficientai}:${POSTGRES_PASSWORD:-password}@db:5432/${POSTGRES_DB:-efficientai} - REDIS_URL: redis://redis:6379/0 - CELERY_BROKER_URL: redis://redis:6379/0 - CELERY_RESULT_BACKEND: redis://redis:6379/0 - ENCRYPTION_KEY: ${ENCRYPTION_KEY:-} - # Optional GCS blob storage (match api service env when using gcs): - # GOOGLE_APPLICATION_CREDENTIALS: /app/secrets/gcp-sa.json - depends_on: - db: - condition: service_healthy - redis: - condition: service_healthy - volumes: - - ./uploads:/app/uploads - - ./.data:/app/.data - - ./config.docker.yml:/app/config.yml:ro - # Optional: mount GCP service account for GCS auth - # - ./secrets/gcp-sa.json:/app/secrets/gcp-sa.json:ro - command: eai worker --config /app/config.yml --loglevel info --queues imports --concurrency 4 - -volumes: - postgres_data: +# EfficientAI Docker Compose +# +# Usage: +# docker compose up # Uses latest images +# EFFICIENTAI_VERSION=1.0.0 docker compose up # Uses specific version +# +# For local development, uncomment the 'build' sections below. + +services: + db: + image: postgres:15 + container_name: efficientai_db + environment: + POSTGRES_USER: ${POSTGRES_USER:-efficientai} + POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-password} + POSTGRES_DB: ${POSTGRES_DB:-efficientai} + volumes: + - postgres_data:/var/lib/postgresql/data + ports: + - "5432:5432" + healthcheck: + test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-efficientai}"] + interval: 10s + timeout: 5s + retries: 5 + + redis: + image: redis:7-alpine + container_name: efficientai_redis + ports: + - "6379:6379" + healthcheck: + test: ["CMD", "redis-cli", "ping"] + interval: 10s + timeout: 5s + retries: 5 + + api: + # Pre-built image from GitHub Container Registry + # Use EFFICIENTAI_VERSION env var to pin to a specific version (e.g., 1.0.0) + image: ghcr.io/efficientai-tech/efficientai-api:${EFFICIENTAI_VERSION:-latest} + # For local development, uncomment below and comment out the image line: + build: + context: . + dockerfile: docker/Dockerfile.api + container_name: efficientai_api + environment: + DATABASE_URL: postgresql://${POSTGRES_USER:-efficientai}:${POSTGRES_PASSWORD:-password}@db:5432/${POSTGRES_DB:-efficientai} + REDIS_URL: redis://redis:6379/0 + CELERY_BROKER_URL: redis://redis:6379/0 + CELERY_RESULT_BACKEND: redis://redis:6379/0 + SECRET_KEY: ${SECRET_KEY:-your-secret-key-change-in-production} + UPLOAD_DIR: /app/uploads + DEBUG: ${DEBUG:-True} + FRONTEND_DIR: /app/frontend/dist + ENCRYPTION_KEY: ${ENCRYPTION_KEY:-} + # Optional GCS blob storage (set storage.blob_provider: gcs in config.docker.yml): + # BLOB_STORAGE_PROVIDER: gcs + # GCS_BUCKET_NAME: your-gcs-bucket + # GCS_PROJECT_ID: your-gcp-project + # GOOGLE_APPLICATION_CREDENTIALS: /app/secrets/gcp-sa.json + # Optional Azure Blob Storage (set storage.blob_provider: azure in config.docker.yml): + # BLOB_STORAGE_PROVIDER: azure + # AZURE_BLOB_ENABLED: "true" + # AZURE_CONNECTION_STRING: DefaultEndpointsProtocol=https;AccountName=... + # AZURE_CONTAINER_NAME: your-container + volumes: + - ./uploads:/app/uploads + - ./.data:/app/.data + - ./config.docker.yml:/app/config.yml:ro + # Optional: mount GCP service account for GCS auth + # - ./secrets/gcp-sa.json:/app/secrets/gcp-sa.json:ro + ports: + - "8000:8000" + depends_on: + db: + condition: service_healthy + redis: + condition: service_healthy + command: eai start --config /app/config.yml --host 0.0.0.0 --port 8000 --no-build-frontend --no-reload + + worker: + # Pre-built image from GitHub Container Registry + # Use EFFICIENTAI_VERSION env var to pin to a specific version (e.g., 1.0.0) + image: ghcr.io/efficientai-tech/efficientai-worker:${EFFICIENTAI_VERSION:-latest} + # For local development, uncomment below and comment out the image line: + build: + context: . + dockerfile: docker/Dockerfile.worker + args: + INSTALL_EXTRAS: "qualitative-voice" + container_name: efficientai_worker + environment: + # Worker uses Docker network, so it reaches DB/Redis via service names + DATABASE_URL: postgresql://${POSTGRES_USER:-efficientai}:${POSTGRES_PASSWORD:-password}@db:5432/${POSTGRES_DB:-efficientai} + REDIS_URL: redis://redis:6379/0 + CELERY_BROKER_URL: redis://redis:6379/0 + CELERY_RESULT_BACKEND: redis://redis:6379/0 + ENCRYPTION_KEY: ${ENCRYPTION_KEY:-} + # Optional GCS blob storage (match api service env when using gcs): + # GOOGLE_APPLICATION_CREDENTIALS: /app/secrets/gcp-sa.json + depends_on: + db: + condition: service_healthy + redis: + condition: service_healthy + volumes: + - ./uploads:/app/uploads + - ./.data:/app/.data + - ./config.docker.yml:/app/config.yml:ro + # Optional: mount GCP service account for GCS auth + # - ./secrets/gcp-sa.json:/app/secrets/gcp-sa.json:ro + command: eai worker --config /app/config.yml --loglevel info + + # Dedicated worker for the call-import fan-out so large CSV imports + # don't starve synthetic-calling, audio generation, and evaluation jobs + # running on the default queue handled by the `worker` service above. + worker-imports: + image: ghcr.io/efficientai-tech/efficientai-worker:${EFFICIENTAI_VERSION:-latest} + build: + context: . + dockerfile: docker/Dockerfile.worker + args: + INSTALL_EXTRAS: "qualitative-voice" + container_name: efficientai_worker_imports + environment: + DATABASE_URL: postgresql://${POSTGRES_USER:-efficientai}:${POSTGRES_PASSWORD:-password}@db:5432/${POSTGRES_DB:-efficientai} + REDIS_URL: redis://redis:6379/0 + CELERY_BROKER_URL: redis://redis:6379/0 + CELERY_RESULT_BACKEND: redis://redis:6379/0 + ENCRYPTION_KEY: ${ENCRYPTION_KEY:-} + # Optional GCS blob storage (match api service env when using gcs): + # GOOGLE_APPLICATION_CREDENTIALS: /app/secrets/gcp-sa.json + depends_on: + db: + condition: service_healthy + redis: + condition: service_healthy + volumes: + - ./uploads:/app/uploads + - ./.data:/app/.data + - ./config.docker.yml:/app/config.yml:ro + # Optional: mount GCP service account for GCS auth + # - ./secrets/gcp-sa.json:/app/secrets/gcp-sa.json:ro + command: eai worker --config /app/config.yml --loglevel info --queues imports --concurrency 4 + +volumes: + postgres_data: diff --git a/docker/Dockerfile.api b/docker/Dockerfile.api index fb79cc57..22845648 100644 --- a/docker/Dockerfile.api +++ b/docker/Dockerfile.api @@ -1,77 +1,77 @@ -# ============================================================ -# Stage 1: Frontend build (Node only — not shipped to runtime) -# ============================================================ -FROM node:18-bookworm-slim AS frontend-builder - -WORKDIR /app/frontend - -COPY frontend/package.json frontend/package-lock.json* ./ -RUN --mount=type=cache,target=/root/.npm \ - npm ci --legacy-peer-deps || npm install --legacy-peer-deps - -COPY frontend/ ./ -RUN npm run build - -# ============================================================ -# Stage 2: API runtime (Python only — no node_modules / esbuild) -# ============================================================ -FROM python:3.11-slim - -# Install build tools and WeasyPrint system libraries. -# build-essential + cmake: packages like praat-parselmouth on arm64/Apple Silicon -# libgobject/libpango/libcairo/libgdk-pixbuf/libffi/shared-mime-info: PDF rendering -RUN apt-get update && apt-get install -y \ - build-essential \ - cmake \ - libgobject-2.0-0 \ - libpango-1.0-0 \ - libpangocairo-1.0-0 \ - libcairo2 \ - libgdk-pixbuf-2.0-0 \ - libffi-dev \ - shared-mime-info \ - && rm -rf /var/lib/apt/lists/* - -# Install uv via pip to get the glibc-linked binary (the musl-linked binary -# from ghcr.io/astral-sh/uv has DNS resolution issues inside Docker) -RUN pip install --no-cache-dir uv - -WORKDIR /app - -# ============================================================ -# LAYER 1: Python dependencies (cached unless pyproject.toml changes) -# ============================================================ -COPY pyproject.toml README.md ./ - -# Create minimal src structure for editable install -RUN mkdir -p src/efficientai && touch src/efficientai/__init__.py - -RUN --mount=type=cache,target=/root/.cache/uv \ - --mount=type=cache,target=/root/.cache/pip \ - uv pip install --system -e . - -# Security: fail build if compromised litellm .pth file is present (CVE: supply chain attack on litellm 1.82.8) -RUN if find /usr/local/lib/ -name "litellm_init.pth" 2>/dev/null | grep -q .; then \ - echo "SECURITY: Malicious litellm_init.pth detected! Aborting build." && exit 1; \ - fi - -# ============================================================ -# LAYER 2: Frontend static assets (built in stage 1, dist only) -# ============================================================ -COPY --from=frontend-builder /app/frontend/dist /app/frontend/dist - -# ============================================================ -# LAYER 3: Backend code (rebuilds on backend code changes) -# ============================================================ -COPY src/ ./src/ -COPY app/ ./app/ -COPY scripts/ ./scripts/ - -RUN uv pip install --system -e . - -# Create uploads directory -RUN mkdir -p /app/uploads - -EXPOSE 8000 - -CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] +# ============================================================ +# Stage 1: Frontend build (Node only — not shipped to runtime) +# ============================================================ +FROM node:18-bookworm-slim AS frontend-builder + +WORKDIR /app/frontend + +COPY frontend/package.json frontend/package-lock.json* ./ +RUN --mount=type=cache,target=/root/.npm \ + npm ci --legacy-peer-deps || npm install --legacy-peer-deps + +COPY frontend/ ./ +RUN npm run build + +# ============================================================ +# Stage 2: API runtime (Python only — no node_modules / esbuild) +# ============================================================ +FROM python:3.11-slim + +# Install build tools and WeasyPrint system libraries. +# build-essential + cmake: packages like praat-parselmouth on arm64/Apple Silicon +# libgobject/libpango/libcairo/libgdk-pixbuf/libffi/shared-mime-info: PDF rendering +RUN apt-get update && apt-get install -y \ + build-essential \ + cmake \ + libgobject-2.0-0 \ + libpango-1.0-0 \ + libpangocairo-1.0-0 \ + libcairo2 \ + libgdk-pixbuf-2.0-0 \ + libffi-dev \ + shared-mime-info \ + && rm -rf /var/lib/apt/lists/* + +# Install uv via pip to get the glibc-linked binary (the musl-linked binary +# from ghcr.io/astral-sh/uv has DNS resolution issues inside Docker) +RUN pip install --no-cache-dir uv + +WORKDIR /app + +# ============================================================ +# LAYER 1: Python dependencies (cached unless pyproject.toml changes) +# ============================================================ +COPY pyproject.toml README.md ./ + +# Create minimal src structure for editable install +RUN mkdir -p src/efficientai && touch src/efficientai/__init__.py + +RUN --mount=type=cache,target=/root/.cache/uv \ + --mount=type=cache,target=/root/.cache/pip \ + uv pip install --system -e . + +# Security: fail build if compromised litellm .pth file is present (CVE: supply chain attack on litellm 1.82.8) +RUN if find /usr/local/lib/ -name "litellm_init.pth" 2>/dev/null | grep -q .; then \ + echo "SECURITY: Malicious litellm_init.pth detected! Aborting build." && exit 1; \ + fi + +# ============================================================ +# LAYER 2: Frontend static assets (built in stage 1, dist only) +# ============================================================ +COPY --from=frontend-builder /app/frontend/dist /app/frontend/dist + +# ============================================================ +# LAYER 3: Backend code (rebuilds on backend code changes) +# ============================================================ +COPY src/ ./src/ +COPY app/ ./app/ +COPY scripts/ ./scripts/ + +RUN uv pip install --system -e . + +# Create uploads directory +RUN mkdir -p /app/uploads + +EXPOSE 8000 + +CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/tests/test_services/test_flexprice_service.py b/tests/test_services/test_flexprice_service.py index db028f9c..7c6778f1 100644 --- a/tests/test_services/test_flexprice_service.py +++ b/tests/test_services/test_flexprice_service.py @@ -15,12 +15,15 @@ def reset_flexprice_settings(): settings.FLEXPRICE_ENABLED, settings.FLEXPRICE_API_KEY, settings.FLEXPRICE_API_HOST, + svc._disabled_skip_logged, ) + svc._disabled_skip_logged = False yield ( settings.FLEXPRICE_ENABLED, settings.FLEXPRICE_API_KEY, settings.FLEXPRICE_API_HOST, + svc._disabled_skip_logged, ) = previous @@ -42,6 +45,13 @@ def test_is_enabled_true_when_configured(): assert svc.is_enabled() is True +def test_disabled_reason_when_api_key_missing(): + settings.FLEXPRICE_ENABLED = True + settings.FLEXPRICE_API_KEY = None + assert svc.disabled_reason() is not None + assert "api_key" in svc.disabled_reason() + + @patch("flexprice.Flexprice") def test_ensure_customer_no_op_when_disabled(mock_flexprice): settings.FLEXPRICE_ENABLED = False From 212075cf83a6ac42da343e31fd078bfa78bc11e8 Mon Sep 17 00:00:00 2001 From: Tejas Narayan Date: Mon, 29 Jun 2026 20:10:26 +0000 Subject: [PATCH 09/11] fix: adding more verbose logs --- app/services/billing/flexprice_service.py | 35 +++++++++++++++++++++++ app/workers/config.py | 12 +++----- 2 files changed, 39 insertions(+), 8 deletions(-) diff --git a/app/services/billing/flexprice_service.py b/app/services/billing/flexprice_service.py index 665ea9b1..a0de32cb 100644 --- a/app/services/billing/flexprice_service.py +++ b/app/services/billing/flexprice_service.py @@ -102,6 +102,41 @@ def log_startup_status(*, component: str = "app") -> None: _mask_api_key(settings.FLEXPRICE_API_KEY), ) + connectivity_error = _verify_connectivity() + if connectivity_error: + logger.warning( + "Flexprice connectivity check FAILED for {} — host={} error={}. " + "Config looks valid but outbound calls may be blocked (check AWS egress/NAT/SG).", + component, + settings.FLEXPRICE_API_HOST, + connectivity_error, + ) + else: + logger.info( + "Flexprice connectivity check OK for {} — host={}", + component, + settings.FLEXPRICE_API_HOST, + ) + + +def _verify_connectivity() -> Optional[str]: + """Best-effort reachability probe; returns error text or None when OK.""" + try: + import httpx + + base = settings.FLEXPRICE_API_HOST.rstrip("/") + response = httpx.get( + f"{base}/customers", + headers={"x-api-key": settings.FLEXPRICE_API_KEY or ""}, + params={"limit": 1}, + timeout=10.0, + ) + if response.status_code < 400: + return None + return f"HTTP {response.status_code}: {response.text[:200]}" + except Exception as exc: + return str(exc) + def _is_customer_already_exists(exc: Exception) -> bool: message = str(exc).lower() diff --git a/app/workers/config.py b/app/workers/config.py index bdc8db4b..65b92c69 100644 --- a/app/workers/config.py +++ b/app/workers/config.py @@ -66,6 +66,10 @@ except Exception as e: logger.warning(f"⚠️ Celery worker: Could not load config.yml: {e}") +from app.services.billing.flexprice_service import log_startup_status # noqa: E402 + +log_startup_status(component="celery-worker") + # Create Celery app celery_app = Celery( "efficientai", @@ -116,11 +120,3 @@ "generate_agent_flowchart": {"queue": "imports"}, "map_agent_flowchart_prompt_sections": {"queue": "imports"}, } - - -@celery_app.on_after_configure.connect -def _log_flexprice_on_worker_configure(sender, **kwargs) -> None: - del sender, kwargs - from app.services.billing.flexprice_service import log_startup_status - - log_startup_status(component="celery-worker") From d9729b01243b5cee643ac3342294ff8d8c28f265 Mon Sep 17 00:00:00 2001 From: Tejas Narayan Date: Tue, 30 Jun 2026 08:48:35 +0000 Subject: [PATCH 10/11] feat: adding sarvam models --- app/config/models.json | 10 ++++ app/core/operational_access_middleware.py | 7 ++- app/services/ai/llm_service.py | 55 ++++++++++++++----- app/services/judge_alignment/model_catalog.py | 37 +++++++++++-- config.docker.yml | 7 +++ config.yml.example | 6 +- .../providers/ProviderModelPicker.tsx | 43 ++++++++++----- frontend/src/config/providers.ts | 4 +- .../src/pages/configurations/Integrations.tsx | 2 +- .../test_operational_access_middleware.py | 41 ++++++++++++++ .../test_services/test_ai/test_llm_service.py | 19 +++---- 11 files changed, 178 insertions(+), 53 deletions(-) create mode 100644 tests/test_core/test_operational_access_middleware.py diff --git a/app/config/models.json b/app/config/models.json index b2f928d4..23c0005c 100644 --- a/app/config/models.json +++ b/app/config/models.json @@ -510,6 +510,16 @@ "model_type": "tts", "description": "Smallest Lightning v3.1 low-latency neural TTS" }, + "sarvam-30b": { + "provider": "sarvam", + "model_type": "llm", + "description": "Sarvam 30B — balanced Indian-language + English chat model (64K context)" + }, + "sarvam-105b": { + "provider": "sarvam", + "model_type": "llm", + "description": "Sarvam 105B — flagship MoE for complex reasoning, coding, and agentic workflows (128K context)" + }, "saarika:v2.5": { "provider": "sarvam", "model_type": "stt", diff --git a/app/core/operational_access_middleware.py b/app/core/operational_access_middleware.py index 1d53c9ad..d329b572 100644 --- a/app/core/operational_access_middleware.py +++ b/app/core/operational_access_middleware.py @@ -1,4 +1,4 @@ -"""Restrict operational endpoints (/health, /metrics) from the public internet.""" +"""Restrict /metrics from the public internet (/health stays open for load balancers).""" from __future__ import annotations @@ -107,11 +107,12 @@ def is_operational_access_allowed(request: Request) -> bool: def is_operational_path(path: str) -> bool: - return path == _HEALTH_PATH or path.startswith(_METRICS_PREFIX) + # /health must stay reachable for load balancers (ALB/kube) without auth. + return path.startswith(_METRICS_PREFIX) class OperationalAccessMiddleware(BaseHTTPMiddleware): - """Block anonymous public access to /health and /metrics.""" + """Block anonymous public access to /metrics (/health is always allowed for LB probes).""" async def dispatch(self, request: Request, call_next) -> Response: if is_operational_path(request.url.path) and not is_operational_access_allowed(request): diff --git a/app/services/ai/llm_service.py b/app/services/ai/llm_service.py index 407ee51d..40b1c056 100644 --- a/app/services/ai/llm_service.py +++ b/app/services/ai/llm_service.py @@ -18,7 +18,7 @@ from sqlalchemy.orm import Session from app.models.database import ModelProvider, AIProvider -from app.services.credentials import resolve_ai_provider +from app.services.credentials import resolve_ai_provider, resolve_integration from app.services.ai.llm_generation_config import build_litellm_kwargs # LiteLLM will silently drop params the target provider doesn't support @@ -36,6 +36,7 @@ "groq": "groq", "xai": "xai", "fireworks": "fireworks_ai", + "sarvam": "sarvam", } # Matches the model-name half of the Gemini 2.5 family: ``gemini-2.5-pro``, @@ -160,6 +161,43 @@ def _get_ai_provider( provider, db, organization_id, credential_id=credential_id ) + def _resolve_api_key( + self, + provider: ModelProvider, + db: Session, + organization_id: UUID, + credential_id: Optional[UUID] = None, + ) -> str: + """Resolve and decrypt an API key from AIProvider or Integration tables.""" + from app.core.encryption import decrypt_api_key + + ai_provider = self._get_ai_provider( + provider, db, organization_id, credential_id=credential_id + ) + if ai_provider: + try: + return decrypt_api_key(ai_provider.api_key) + except Exception as e: + raise RuntimeError( + f"Failed to decrypt API key for provider {provider}: {e}" + ) + + integration = resolve_integration( + provider, db, organization_id, credential_id=credential_id + ) + if integration: + try: + return decrypt_api_key(integration.api_key) + except Exception as e: + raise RuntimeError( + f"Failed to decrypt API key for provider {provider}: {e}" + ) + + provider_label = provider.value if hasattr(provider, "value") else str(provider) + raise RuntimeError( + f"AI provider {provider_label} not configured for this organization." + ) + @staticmethod def _litellm_model_name(provider: ModelProvider, model: str) -> str: """Build the ``provider/model`` string that LiteLLM expects.""" @@ -208,22 +246,9 @@ def generate_response( start_time = time.time() # --- resolve API key from database -------------------------------- - ai_provider = self._get_ai_provider( + api_key = self._resolve_api_key( llm_provider, db, organization_id, credential_id=credential_id ) - if not ai_provider: - raise RuntimeError( - f"AI provider {llm_provider} not configured for this organization." - ) - - from app.core.encryption import decrypt_api_key - - try: - api_key = decrypt_api_key(ai_provider.api_key) - except Exception as e: - raise RuntimeError( - f"Failed to decrypt API key for provider {llm_provider}: {e}" - ) # --- call LiteLLM -------------------------------------------------- model_str = self._litellm_model_name(llm_provider, llm_model) diff --git a/app/services/judge_alignment/model_catalog.py b/app/services/judge_alignment/model_catalog.py index e7f4afda..e727bffe 100644 --- a/app/services/judge_alignment/model_catalog.py +++ b/app/services/judge_alignment/model_catalog.py @@ -20,14 +20,15 @@ from loguru import logger from sqlalchemy.orm import Session -from app.models.database import AIProvider, ModelProvider +from app.models.database import AIProvider, Integration, ModelProvider +from app.models.enums import IntegrationPlatform from app.services.ai.model_config_service import model_config_service # Providers that ship LLM models we can use as a judge. We deliberately # exclude STT/TTS-only vendors (Deepgram, Cartesia, ElevenLabs, Murf, -# Sarvam, Voicemaker, Smallest) because LiteLLM has no LLM completion -# route for them. +# Voicemaker, Smallest) because LiteLLM has no LLM completion route for +# them. Sarvam is included because it exposes chat models via LiteLLM. _LLM_CAPABLE_PROVIDERS = { ModelProvider.OPENAI.value, ModelProvider.ANTHROPIC.value, @@ -43,6 +44,13 @@ ModelProvider.AWS.value, ModelProvider.OPENROUTER.value, ModelProvider.CUSTOM.value, + ModelProvider.SARVAM.value, +} + +# Voice-platform integrations that also expose LLM models (credentials +# live in the Integration table rather than AIProvider). +_INTEGRATION_LLM_PLATFORMS = { + IntegrationPlatform.SARVAM.value: ModelProvider.SARVAM.value, } @@ -63,6 +71,7 @@ def _provider_label(provider_value: str) -> str: "aws": "AWS", "openrouter": "OpenRouter", "custom": "Custom", + "sarvam": "Sarvam", } return overrides.get(provider_value.lower(), provider_value.title()) @@ -98,11 +107,29 @@ def list_judge_capable_models( .all() ) + configured_provider_values: set[str] = { + (ai_provider.provider or "").lower() + for ai_provider in providers + } + + integrations: List[Integration] = ( + db.query(Integration) + .filter( + Integration.organization_id == organization_id, + Integration.is_active == True, # noqa: E712 - SQLAlchemy comparison + ) + .all() + ) + for integration in integrations: + platform = (integration.platform or "").lower() + mapped = _INTEGRATION_LLM_PLATFORMS.get(platform) + if mapped: + configured_provider_values.add(mapped) + catalog: List[Dict[str, str]] = [] seen: set = set() - for ai_provider in providers: - provider_value = (ai_provider.provider or "").lower() + for provider_value in configured_provider_values: if provider_value not in _LLM_CAPABLE_PROVIDERS: continue diff --git a/config.docker.yml b/config.docker.yml index a1a7e964..0438d86e 100644 --- a/config.docker.yml +++ b/config.docker.yml @@ -13,6 +13,13 @@ server: host: "0.0.0.0" port: 8000 +# Operational endpoints (/metrics). /health is always open for load balancers. +# Add VPC CIDRs here if Prometheus scrapes /metrics from inside the VPC. +operational: + public: false + trusted_ips: + - "10.0.0.0/8" + # Database Configuration (Docker service name) database: url: "postgresql://efficientai:password@db:5432/efficientai" diff --git a/config.yml.example b/config.yml.example index 6b502edf..1dad08dc 100644 --- a/config.yml.example +++ b/config.yml.example @@ -12,10 +12,8 @@ server: host: "0.0.0.0" port: 8000 -# Operational endpoints (/health, /metrics) -# Keep public: false in production/sandbox. ALB health checks connect -# directly from VPC addresses (no X-Forwarded-For). Include your VPC/LB CIDRs -# in trusted_ips so probes and internal scrapers are allowed. +# Operational endpoints (/metrics). /health is always open for ALB/kube probes. +# Include VPC CIDRs in trusted_ips if Prometheus scrapes /metrics from inside the VPC. operational: public: false trusted_ips: diff --git a/frontend/src/components/providers/ProviderModelPicker.tsx b/frontend/src/components/providers/ProviderModelPicker.tsx index c2f26032..8e364288 100644 --- a/frontend/src/components/providers/ProviderModelPicker.tsx +++ b/frontend/src/components/providers/ProviderModelPicker.tsx @@ -131,6 +131,11 @@ const AUDIO_CAPABLE_MODEL_MATCHERS: Record = { } const AUDIO_CAPABLE_PROVIDERS = Object.keys(AUDIO_CAPABLE_MODEL_MATCHERS) +// Voice-platform integrations that also expose LLM models. Credentials +// for these providers live in the Integration table (Configurations → +// Integrations → Voice Platform), not AIProvider. +const INTEGRATION_LLM_PLATFORMS = new Set(['sarvam']) + function isAudioCapableModel(provider: string, model: string): boolean { const matchers = AUDIO_CAPABLE_MODEL_MATCHERS[provider] if (!matchers) return false @@ -161,9 +166,32 @@ export default function ProviderModelPicker({ const { data: integrations = [] } = useQuery({ queryKey: ['integrations'], queryFn: () => apiClient.listIntegrations(), - enabled: kind === 'stt', + enabled: kind === 'stt' || kind === 'llm', }) + const integrationCredentialRows: CredentialRow[] = + kind === 'stt' + ? integrations.map((i) => ({ + id: i.id, + provider: (i.platform || '').toLowerCase(), + is_active: i.is_active, + is_default: i.is_default, + name: i.name ?? null, + source: 'integration', + })) + : integrations + .filter((i) => + INTEGRATION_LLM_PLATFORMS.has((i.platform || '').toLowerCase()), + ) + .map((i) => ({ + id: i.id, + provider: (i.platform || '').toLowerCase(), + is_active: i.is_active, + is_default: i.is_default, + name: i.name ?? null, + source: 'integration', + })) + const allCredentials: CredentialRow[] = [ ...aiProviders.map((p) => ({ id: p.id, @@ -173,18 +201,7 @@ export default function ProviderModelPicker({ name: p.name ?? null, source: 'aiprovider', })), - // Only merge integrations for STT — LLM credentials never live in - // that table, so skipping it keeps the LLM picker unchanged. - ...(kind === 'stt' - ? integrations.map((i) => ({ - id: i.id, - provider: (i.platform || '').toLowerCase(), - is_active: i.is_active, - is_default: i.is_default, - name: i.name ?? null, - source: 'integration', - })) - : []), + ...integrationCredentialRows, ] const allowSet = providerAllowList diff --git a/frontend/src/config/providers.ts b/frontend/src/config/providers.ts index 6f948b80..50e4d04e 100644 --- a/frontend/src/config/providers.ts +++ b/frontend/src/config/providers.ts @@ -100,7 +100,7 @@ export const MODEL_PROVIDER_CONFIG: Record = { [ModelProvider.SARVAM]: { label: 'Sarvam', logo: '/sarvam.png', - description: 'Sarvam STT & TTS', + description: 'Sarvam STT, TTS & LLM', }, [ModelProvider.VOICEMAKER]: { label: 'VoiceMaker', @@ -162,7 +162,7 @@ export const INTEGRATION_PLATFORM_CONFIG: Record Date: Tue, 30 Jun 2026 09:08:13 +0000 Subject: [PATCH 11/11] feat: updating test cases --- app/workers/tasks/helpers/llm_diarisation.py | 83 ++++++++++++++++++- ...st_call_import_diarization_and_eval_llm.py | 80 ++++++++++++++++++ tests/test_core/test_operational_endpoints.py | 5 +- 3 files changed, 162 insertions(+), 6 deletions(-) diff --git a/app/workers/tasks/helpers/llm_diarisation.py b/app/workers/tasks/helpers/llm_diarisation.py index ef9c3673..0768b919 100644 --- a/app/workers/tasks/helpers/llm_diarisation.py +++ b/app/workers/tasks/helpers/llm_diarisation.py @@ -128,6 +128,79 @@ # transcript so a 1 MB CSV cell can't blow up the prompt window. _MAX_TRANSCRIPT_CHARS = 60_000 +# Output token budget for the diariser LLM. The JSON mirrors the input +# transcript (split into agent/user turns) plus structural overhead per +# turn. Without an explicit ``max_tokens`` many providers default to a +# low cap (e.g. 4096) and truncate mid-JSON on longer calls. +_DIARISATION_MIN_MAX_TOKENS = 4096 +_DIARISATION_MAX_MAX_TOKENS = 16_384 + + +def _estimate_diarisation_max_tokens( + *, text_length: int = 0, audio_bytes: int = 0 +) -> int: + """Scale the diariser output budget to the input payload size.""" + if text_length > 0: + # ~0.45 tokens/char covers JSON keys + per-turn wrapping. + estimated = int(text_length * 0.45) + 512 + elif audio_bytes > 0: + # Rough proxy: ~150 spoken chars/sec, ~320 KiB/min mono 16 kHz. + estimated_chars = (audio_bytes / 320_000) * 150 * 60 + estimated = int(estimated_chars * 0.45) + 512 + else: + estimated = _DIARISATION_MIN_MAX_TOKENS + return min( + _DIARISATION_MAX_MAX_TOKENS, + max(_DIARISATION_MIN_MAX_TOKENS, estimated), + ) + + +def _generate_diarisation_response( + *, + messages: List[Dict[str, Any]], + provider_enum: ModelProvider, + llm_model: str, + organization_id: UUID, + db: Session, + temperature: float, + credential_id: Optional[UUID], + config: Dict[str, Any], + content_length: int = 0, + audio_bytes: int = 0, +) -> Dict[str, Any]: + """Call the diariser LLM with a scaled ``max_tokens`` and one retry.""" + base_budget = _estimate_diarisation_max_tokens( + text_length=content_length, + audio_bytes=audio_bytes, + ) + retry_budget = min(_DIARISATION_MAX_MAX_TOKENS, base_budget * 2) + budgets = [base_budget] if retry_budget <= base_budget else [base_budget, retry_budget] + + response: Dict[str, Any] = {} + for attempt_idx, max_tokens in enumerate(budgets): + response = llm_service.generate_response( + messages=messages, + llm_provider=provider_enum, + llm_model=llm_model, + organization_id=organization_id, + db=db, + temperature=temperature, + credential_id=credential_id, + config=config, + task_defaults={"max_tokens": max_tokens}, + ) + if not response.get("truncated"): + return response + if attempt_idx < len(budgets) - 1: + logger.warning( + "Diariser LLM response truncated at max_tokens={} " + "(model={}); retrying with max_tokens={}.", + max_tokens, + llm_model, + budgets[attempt_idx + 1], + ) + return response + # Synonym keys we accept from over-creative LLMs that don't follow the # canonical ``{"speaker": ..., "text": ...}`` schema verbatim. Without @@ -460,15 +533,16 @@ def diarize_transcript_with_llm( # other providers. config = {"response_format": {"type": "json_object"}} - response = llm_service.generate_response( + response = _generate_diarisation_response( messages=messages, - llm_provider=provider_enum, + provider_enum=provider_enum, llm_model=llm_model, organization_id=organization_id, db=db, temperature=temperature, credential_id=credential_id, config=config, + content_length=len(cleaned), ) turns = _parse_turns_from_response(response) @@ -1042,15 +1116,16 @@ def diarize_audio_with_llm( config = {"response_format": {"type": "json_object"}} try: - response = llm_service.generate_response( + response = _generate_diarisation_response( messages=messages, - llm_provider=provider_enum, + provider_enum=provider_enum, llm_model=llm_model, organization_id=organization_id, db=db, temperature=temperature, credential_id=credential_id, config=config, + audio_bytes=len(audio_bytes), ) except Exception as exc: # LiteLLM surfaces "model does not support audio input" as a diff --git a/tests/test_api/test_call_import_diarization_and_eval_llm.py b/tests/test_api/test_call_import_diarization_and_eval_llm.py index e85e73fa..ffd32385 100644 --- a/tests/test_api/test_call_import_diarization_and_eval_llm.py +++ b/tests/test_api/test_call_import_diarization_and_eval_llm.py @@ -1313,3 +1313,83 @@ def test_compact_diarisation_error_strips_details_block(): assert _compact_diarisation_error(raw) == ( "Diarisation LLM provider/model not configured." ) + + +def test_estimate_diarisation_max_tokens_scales_with_transcript(): + from app.workers.tasks.helpers import llm_diarisation + + short = llm_diarisation._estimate_diarisation_max_tokens(text_length=500) + long = llm_diarisation._estimate_diarisation_max_tokens(text_length=40_000) + assert short >= llm_diarisation._DIARISATION_MIN_MAX_TOKENS + assert long > short + assert long <= llm_diarisation._DIARISATION_MAX_MAX_TOKENS + + +def test_diariser_passes_scaled_max_tokens(monkeypatch): + from app.workers.tasks.helpers import llm_diarisation + + captured: dict = {} + + def _fake_generate_response(**kwargs): + captured.update(kwargs) + return { + "text": '{"turns": [{"speaker": "agent", "text": "Hi"}]}', + "truncated": False, + } + + monkeypatch.setattr( + llm_diarisation.llm_service, + "generate_response", + _fake_generate_response, + ) + + llm_diarisation.diarize_transcript_with_llm( + "Hi there.", + llm_provider="openai", + llm_model="gpt-4o-mini", + organization_id=uuid4(), + db=SimpleNamespace(), + ) + + assert captured["task_defaults"]["max_tokens"] >= ( + llm_diarisation._DIARISATION_MIN_MAX_TOKENS + ) + + +def test_diariser_retries_when_first_response_truncated(monkeypatch): + from app.workers.tasks.helpers import llm_diarisation + + budgets: list[int] = [] + + def _fake_generate_response(**kwargs): + budgets.append(kwargs["task_defaults"]["max_tokens"]) + if len(budgets) == 1: + return { + "text": '{"turns": [{"speaker": "agent", "text": "partial', + "truncated": True, + } + return { + "text": '{"turns": [{"speaker": "agent", "text": "Done"}]}', + "truncated": False, + } + + monkeypatch.setattr( + llm_diarisation.llm_service, + "generate_response", + _fake_generate_response, + ) + + turns = llm_diarisation.diarize_transcript_with_llm( + "word " * 500, + llm_provider="openai", + llm_model="gpt-4o-mini", + organization_id=uuid4(), + db=SimpleNamespace(), + ) + + assert len(budgets) == 2 + assert budgets[1] == min( + llm_diarisation._DIARISATION_MAX_MAX_TOKENS, + budgets[0] * 2, + ) + assert turns[0]["text"] == "Done" diff --git a/tests/test_core/test_operational_endpoints.py b/tests/test_core/test_operational_endpoints.py index 4fd116f6..d15416c4 100644 --- a/tests/test_core/test_operational_endpoints.py +++ b/tests/test_core/test_operational_endpoints.py @@ -66,13 +66,14 @@ def metrics(): yield client -def test_public_health_via_alb_is_blocked(operational_client): +def test_public_health_via_alb_is_allowed_for_lb_probes(operational_client): + """/health stays open for ALB/kube probes; only /metrics is gated.""" response = operational_client.get( "/health", headers={"X-Forwarded-For": "203.0.113.1", "User-Agent": "SecurityScanner/1.0"}, ) - assert response.status_code == 404 + assert response.status_code == 200 def test_spoofed_user_agent_does_not_bypass_gate(monkeypatch):