diff --git a/app/api/v1/routes/auth.py b/app/api/v1/routes/auth.py index b813e33f..97dbc838 100644 --- a/app/api/v1/routes/auth.py +++ b/app/api/v1/routes/auth.py @@ -50,7 +50,10 @@ RoleEnum, User, ) -from app.services.organization_provisioning import provision_default_workspace +from app.services.organization_provisioning import ( + provision_billing_customer, + provision_default_workspace, +) router = APIRouter(prefix="/auth", tags=["Authentication"]) @@ -324,6 +327,11 @@ def signup(payload: SignupRequest, db: Session = Depends(get_db)) -> TokenRespon organization_id=organization.id, created_by_user_id=user.id, ) + provision_billing_customer( + organization_id=organization.id, + name=org_name, + email=payload.email, + ) user.last_login_at = datetime.now(timezone.utc) db.commit() db.refresh(user) diff --git a/app/api/v1/routes/call_import_evaluations.py b/app/api/v1/routes/call_import_evaluations.py index 33562c96..53b0a11e 100644 --- a/app/api/v1/routes/call_import_evaluations.py +++ b/app/api/v1/routes/call_import_evaluations.py @@ -15,7 +15,7 @@ from datetime import date, datetime, timedelta, timezone -from fastapi import APIRouter, Body, Depends, HTTPException, Query, Response, status +from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException, Query, Response, status from fastapi.responses import StreamingResponse from loguru import logger from pydantic import BaseModel, Field, field_validator @@ -557,6 +557,7 @@ def _rollup_evaluation_status(evaluation: CallImportEvaluation, db: Session) -> async def create_call_import_evaluation( call_import_id: UUID, payload: CallImportEvaluationCreate, + background_tasks: BackgroundTasks, api_key: str = Depends(get_api_key), organization_id: UUID = Depends(get_organization_id), db: Session = Depends(get_db), diff --git a/app/api/v1/routes/call_imports.py b/app/api/v1/routes/call_imports.py index 184fad53..56daa19d 100644 --- a/app/api/v1/routes/call_imports.py +++ b/app/api/v1/routes/call_imports.py @@ -19,7 +19,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple from uuid import UUID -from fastapi import APIRouter, Body, Depends, File, Form, HTTPException, Query, Response, UploadFile, status +from fastapi import APIRouter, Body, BackgroundTasks, Depends, File, Form, HTTPException, Query, Response, UploadFile, status from loguru import logger from sqlalchemy import desc, func, or_ from sqlalchemy.orm import Session @@ -32,6 +32,7 @@ get_workspace_id, require_enterprise_feature, ) +from app.services.billing.flexprice_service import record_call_import_batch_created from app.models.database import ( CallImport, CallImportRow, @@ -1556,6 +1557,7 @@ async def update_call_import_mapping( async def start_call_import( call_import_id: UUID, payload: CallImportStartRequest, + background_tasks: BackgroundTasks, api_key: str = Depends(get_api_key), organization_id: UUID = Depends(get_organization_id), workspace_id: UUID = Depends(get_workspace_id), @@ -1682,6 +1684,16 @@ async def start_call_import( db.commit() db.refresh(call_import) + background_tasks.add_task( + record_call_import_batch_created, + organization_id, + call_import.id, + workspace_id=workspace_id, + total_rows=call_import.total_rows, + source="csv", + provider=call_import.provider, + ) + _enqueue_row_tasks(db, call_import, row_models) return CallImportUploadResponse( @@ -1705,6 +1717,7 @@ async def start_call_import( deprecated=True, ) async def upload_call_import_csv( + background_tasks: BackgroundTasks, file: UploadFile = File(...), provider: Optional[str] = Form( None, @@ -1886,6 +1899,16 @@ async def upload_call_import_csv( db.commit() db.refresh(call_import) + background_tasks.add_task( + record_call_import_batch_created, + organization_id, + call_import.id, + workspace_id=workspace_id, + total_rows=call_import.total_rows, + source="csv", + provider=call_import.provider, + ) + _enqueue_row_tasks(db, call_import, row_models) return CallImportUploadResponse( @@ -1908,6 +1931,7 @@ async def upload_call_import_csv( operation_id="uploadCallImportAudio", ) async def upload_call_import_audio( + background_tasks: BackgroundTasks, files: List[UploadFile] = File( ..., description="One or more manual call recording audio files.", @@ -2076,6 +2100,15 @@ async def upload_call_import_audio( ) from exc db.refresh(call_import) + background_tasks.add_task( + record_call_import_batch_created, + organization_id, + call_import.id, + workspace_id=workspace_id, + total_rows=call_import.total_rows, + source="audio", + provider=None, + ) return CallImportUploadResponse( id=call_import.id, total_rows=call_import.total_rows, diff --git a/app/api/v1/routes/chat.py b/app/api/v1/routes/chat.py index b25d731e..a8317844 100644 --- a/app/api/v1/routes/chat.py +++ b/app/api/v1/routes/chat.py @@ -2,14 +2,15 @@ Chat/Inference API Routes For generating responses from AI models """ -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, status from sqlalchemy.orm import Session from typing import List, Dict, Optional, Any -from uuid import UUID +from uuid import UUID, uuid4 from pydantic import BaseModel -from app.dependencies import get_db, get_organization_id +from app.dependencies import get_db, get_organization_id, get_workspace_id from app.services.ai.llm_service import llm_service +from app.services.billing.flexprice_service import record_chat_completion from app.models.schemas import ModelProvider router = APIRouter(prefix="/chat", tags=["chat"]) @@ -39,7 +40,9 @@ class ChatResponse(BaseModel): @router.post("/completion", response_model=ChatResponse) async def chat_completion( request: ChatRequest, + background_tasks: BackgroundTasks, organization_id: UUID = Depends(get_organization_id), + workspace_id: UUID = Depends(get_workspace_id), db: Session = Depends(get_db) ): """Generate a chat completion using the specified AI provider and model.""" @@ -59,6 +62,14 @@ async def chat_completion( task_defaults={"temperature": 0.7}, ) + background_tasks.add_task( + record_chat_completion, + organization_id, + uuid4(), + workspace_id=workspace_id, + model=result.get("model", request.model), + ) + return ChatResponse( text=result.get("text", ""), model=result.get("model", request.model), diff --git a/app/api/v1/routes/evaluations.py b/app/api/v1/routes/evaluations.py index 840d3803..10b5ea3d 100644 --- a/app/api/v1/routes/evaluations.py +++ b/app/api/v1/routes/evaluations.py @@ -1,6 +1,6 @@ """Evaluation routes.""" -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException from sqlalchemy.orm import Session from typing import List from uuid import UUID @@ -13,6 +13,7 @@ EvaluationStatusResponse, MessageResponse, ) +from app.services.billing.flexprice_service import record_evaluation_created from app.workers.celery_app import process_evaluation_task from app.core.exceptions import EvaluationNotFoundError @@ -22,6 +23,7 @@ @router.post("/create", response_model=EvaluationResponse, status_code=201) def create_evaluation( evaluation_data: EvaluationCreate, + background_tasks: BackgroundTasks, api_key: str = Depends(get_api_key), organization_id: UUID = Depends(get_organization_id), workspace_id: UUID = Depends(get_workspace_id), @@ -55,6 +57,15 @@ def create_evaluation( db.commit() db.refresh(evaluation) + background_tasks.add_task( + record_evaluation_created, + organization_id, + evaluation.id, + workspace_id=workspace_id, + audio_id=evaluation.audio_id, + metrics_requested=len(evaluation.metrics_requested or []), + ) + # Queue async task process_evaluation_task.delay(str(evaluation.id)) diff --git a/app/api/v1/routes/evaluators.py b/app/api/v1/routes/evaluators.py index 610ba16d..0e13b1f8 100644 --- a/app/api/v1/routes/evaluators.py +++ b/app/api/v1/routes/evaluators.py @@ -1,10 +1,10 @@ """Evaluator routes.""" -from fastapi import APIRouter, Depends, HTTPException, status, Query +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, status, Query from fastapi.responses import JSONResponse, Response from sqlalchemy.orm import Session from sqlalchemy import and_ -from uuid import UUID +from uuid import UUID, uuid4 import random from typing import List from pydantic import BaseModel @@ -12,6 +12,7 @@ from app.database import get_db from app.dependencies import get_organization_id, get_workspace_id, get_api_key +from app.services.billing.flexprice_service import record_evaluator_run_requested from app.models.database import Evaluator, Agent, Persona, Scenario, EvaluatorResult, EvaluatorResultStatus, VoiceBundle, Metric from app.models.schemas import ( EvaluatorCreate, @@ -548,6 +549,7 @@ def delete_evaluator( @router.post("/run", response_model=RunEvaluatorsResponse, status_code=200) def run_evaluators( request: RunEvaluatorsRequest, + background_tasks: BackgroundTasks, organization_id: UUID = Depends(get_organization_id), workspace_id: UUID = Depends(get_workspace_id), db: Session = Depends(get_db), @@ -649,7 +651,15 @@ def run_evaluators( if not task_ids: raise HTTPException(status_code=500, detail="Failed to create any tasks") - + + background_tasks.add_task( + record_evaluator_run_requested, + organization_id, + uuid4(), + workspace_id=workspace_id, + quantity=len(request.evaluator_ids), + ) + return RunEvaluatorsResponse( task_ids=task_ids, evaluator_results=evaluator_results diff --git a/app/api/v1/routes/judge_alignment.py b/app/api/v1/routes/judge_alignment.py index 9fab1ee6..1bb96006 100644 --- a/app/api/v1/routes/judge_alignment.py +++ b/app/api/v1/routes/judge_alignment.py @@ -16,7 +16,7 @@ from typing import Any, Dict, List, Optional from uuid import UUID -from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status +from fastapi import APIRouter, BackgroundTasks, Depends, File, Form, HTTPException, UploadFile, status from loguru import logger from pydantic import BaseModel, ConfigDict, Field from sqlalchemy.orm import Session @@ -29,6 +29,7 @@ get_workspace_id, Principal, ) +from app.services.billing.flexprice_service import record_judge_alignment_run_started from app.models.database import ( Agent, Evaluator, @@ -617,6 +618,7 @@ def list_runs( def trigger_judge_run( dataset_id: UUID, body: JudgeRunCreate, + background_tasks: BackgroundTasks, principal: Principal = Depends(get_principal), workspace_id: UUID = Depends(get_workspace_id), db: Session = Depends(get_db), @@ -662,6 +664,19 @@ def trigger_judge_run( ), ) + from app.services.judge_alignment.judge_runner import select_samples_for_split + + sample_id_strs = ( + [str(s) for s in (body.sample_ids or [])] if body.split != "all" else None + ) + selected_samples = select_samples_for_split( + dataset_id=dataset.id, + split=body.split, + db=db, + sample_ids=sample_id_strs, + ) + sample_count = len(selected_samples) + judge_run = JudgeRun( dataset_id=dataset.id, organization_id=organization_id, @@ -679,10 +694,6 @@ def trigger_judge_run( db.commit() db.refresh(judge_run) - sample_id_strs = ( - [str(s) for s in (body.sample_ids or [])] if body.split != "all" else None - ) - try: from app.workers.tasks.run_judge_alignment import run_judge_alignment_task @@ -692,6 +703,14 @@ def trigger_judge_run( judge_run.celery_task_id = async_result.id db.commit() db.refresh(judge_run) + background_tasks.add_task( + record_judge_alignment_run_started, + organization_id, + judge_run.id, + workspace_id=workspace_id, + dataset_id=dataset.id, + sample_count=sample_count, + ) except Exception as exc: logger.error(f"[JudgeAlignment] Failed to enqueue judge task: {exc}") judge_run.status = "failed" diff --git a/app/api/v1/routes/metrics.py b/app/api/v1/routes/metrics.py index 8e7be9aa..f1aea458 100644 --- a/app/api/v1/routes/metrics.py +++ b/app/api/v1/routes/metrics.py @@ -2,16 +2,17 @@ import json import re -from fastapi import APIRouter, Depends, HTTPException, Query, status +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query, status from sqlalchemy.orm import Session from sqlalchemy import and_, or_ -from uuid import UUID +from uuid import UUID, uuid4 from typing import List, Optional, Dict, Any, Literal from pydantic import BaseModel, Field from loguru import logger from app.database import get_db from app.dependencies import get_organization_id, get_api_key, get_workspace_id +from app.services.billing.flexprice_service import record_metrics_llm_assist from app.models.database import Metric, MetricCategory, MetricType, MetricTrigger, ModelProvider from app.models.schemas import ( MetricCreate, @@ -1539,7 +1540,9 @@ def _parse_metric_generation_response(text: str) -> Dict[str, Any]: @router.post("/generate", response_model=MetricGenerateResponse) def generate_metric( req: MetricGenerateRequest, + background_tasks: BackgroundTasks, organization_id: UUID = Depends(get_organization_id), + workspace_id: UUID = Depends(get_workspace_id), db: Session = Depends(get_db), ): """Use an LLM to suggest a metric definition. Does NOT persist anything.""" @@ -1631,6 +1634,14 @@ def generate_metric( suffix += 1 name = f"{name} ({suffix})" + background_tasks.add_task( + record_metrics_llm_assist, + organization_id, + uuid4(), + workspace_id=workspace_id, + mode=req.mode, + ) + return MetricGenerateResponse( name=name, description=(parsed.get("description") or "").strip()[:1000], @@ -1953,7 +1964,9 @@ def _taken(name: str) -> bool: @router.post("/parse-bulk", response_model=MetricParseBulkResponse) def parse_bulk_metric( req: MetricParseBulkRequest, + background_tasks: BackgroundTasks, organization_id: UUID = Depends(get_organization_id), + workspace_id: UUID = Depends(get_workspace_id), db: Session = Depends(get_db), ): """Parse a multi-label rubric into a *list* of independent metric drafts. @@ -2077,6 +2090,14 @@ def parse_bulk_metric( ) ) + background_tasks.add_task( + record_metrics_llm_assist, + organization_id, + uuid4(), + workspace_id=workspace_id, + mode="parse-bulk", + ) + return MetricParseBulkResponse(metrics=drafts, parent=parent_payload) diff --git a/app/api/v1/routes/observability.py b/app/api/v1/routes/observability.py index 61770404..94ac5d0a 100644 --- a/app/api/v1/routes/observability.py +++ b/app/api/v1/routes/observability.py @@ -5,11 +5,15 @@ from typing import Any, Dict, List, Optional, Union from uuid import UUID -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, status from pydantic import BaseModel, ConfigDict from sqlalchemy.orm import Session from app.dependencies import get_api_key, get_db, get_organization_id, get_workspace_id +from app.services.billing.flexprice_service import ( + record_observability_call_evaluated, + record_observability_call_ingested, +) from app.models.database import ( Agent, APIKey, CallRecording, CallRecordingStatus, CallRecordingSource, Evaluator, EvaluatorResult, EvaluatorResultStatus, Scenario, Workspace, @@ -188,6 +192,13 @@ def _upsert_call_recording( response = _serialize_call_recording(call_recording, include_data=True) response["action"] = action + if action == "created": + record_observability_call_ingested( + organization_id, + call_recording.call_short_id, + workspace_id=workspace_id, + provider=provider_platform, + ) return response @@ -466,6 +477,7 @@ def _messages_to_speaker_segments(messages: List[Dict[str, Any]]) -> List[Dict[s async def evaluate_call( call_short_id: str, payload: EvaluateCallPayload, + background_tasks: BackgroundTasks, organization_id: UUID = Depends(get_organization_id), workspace_id: UUID = Depends(get_workspace_id), api_key: str = Depends(get_api_key), @@ -575,6 +587,13 @@ async def evaluate_call( except Exception: pass + background_tasks.add_task( + record_observability_call_evaluated, + organization_id, + call_short_id, + workspace_id=workspace_id, + ) + return { "evaluator_result_id": str(evaluator_result.id), "result_id": evaluator_result.result_id, diff --git a/app/api/v1/routes/playground.py b/app/api/v1/routes/playground.py index 02651ab1..ee39eec8 100644 --- a/app/api/v1/routes/playground.py +++ b/app/api/v1/routes/playground.py @@ -14,7 +14,13 @@ import uuid as _uuid from datetime import datetime -from app.dependencies import get_db, get_organization_id, get_workspace_id, get_api_key +from app.services.billing.flexprice_service import ( + record_playground_call_evaluated, + record_playground_web_call_started, + record_playground_websocket_session_started, +) +from app.database import get_db +from app.dependencies import get_organization_id, get_workspace_id, get_api_key from app.models.database import ( Agent, Integration, @@ -671,6 +677,14 @@ async def create_web_call( db.add(call_recording) db.commit() db.refresh(call_recording) + + background_tasks.add_task( + record_playground_web_call_started, + organization_id, + call_short_id, + workspace_id=workspace_id, + agent_id=agent.id, + ) # Start background task to poll for call metrics # Note: We need to pass the decrypted API key, but we should be careful with security @@ -777,6 +791,7 @@ async def list_call_recordings( @router.post("/custom-websocket-sessions", response_model=Dict[str, Any]) async def create_custom_websocket_session( + background_tasks: BackgroundTasks, agent_id: str = Form(...), websocket_url: str = Form(...), transcript_entries: str = Form("[]"), @@ -902,6 +917,13 @@ async def create_custom_websocket_session( db.commit() db.refresh(call_recording) + background_tasks.add_task( + record_playground_websocket_session_started, + organization_id, + call_short_id, + workspace_id=workspace_id, + ) + return { "message": "Custom websocket session saved", "call_short_id": call_short_id, @@ -913,6 +935,7 @@ async def create_custom_websocket_session( @router.post("/custom-websocket-sessions/{call_short_id}/evaluate", response_model=Dict[str, Any]) async def evaluate_custom_websocket_session( call_short_id: str, + background_tasks: BackgroundTasks, organization_id: UUID = Depends(get_organization_id), workspace_id: UUID = Depends(get_workspace_id), api_key: str = Depends(get_api_key), @@ -999,6 +1022,14 @@ async def evaluate_custom_websocket_session( logger.error(f"[Custom WebSocket] Failed to trigger evaluation worker: {e}") raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to trigger evaluation worker") + background_tasks.add_task( + record_playground_call_evaluated, + organization_id, + call_short_id, + workspace_id=workspace_id, + metric_count=0, + ) + return { "message": "Evaluation queued", "evaluator_result_id": str(evaluator_result.id), @@ -1367,6 +1398,14 @@ def _download_audio_from_payload(payload: Dict[str, Any]): logger.error(f"[Re-evaluate] Failed to trigger Celery task: {e}") raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to trigger evaluation worker") + background_tasks.add_task( + record_playground_call_evaluated, + organization_id, + call_short_id, + workspace_id=workspace_id, + metric_count=0, + ) + return { "message": "Re-evaluation started", "evaluator_result_id": str(evaluator_result.id), diff --git a/app/api/v1/routes/prompt_optimization.py b/app/api/v1/routes/prompt_optimization.py index f0805797..cae668e2 100644 --- a/app/api/v1/routes/prompt_optimization.py +++ b/app/api/v1/routes/prompt_optimization.py @@ -9,7 +9,7 @@ from typing import Dict, Any, List, Optional from uuid import UUID -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException from pydantic import BaseModel, ConfigDict, Field from sqlalchemy.orm import Session from loguru import logger @@ -21,6 +21,7 @@ get_api_key, require_enterprise_feature, ) +from app.services.billing.flexprice_service import record_prompt_optimization_run_started from app.models.database import ( Agent, Evaluator, @@ -99,6 +100,7 @@ class CandidateResponse(BaseModel): @router.post("/runs", response_model=OptimizationRunResponse, status_code=201) def create_optimization_run( data: OptimizationRunCreate, + background_tasks: BackgroundTasks, organization_id: UUID = Depends(get_organization_id), workspace_id: UUID = Depends(get_workspace_id), api_key: str = Depends(get_api_key), @@ -156,6 +158,18 @@ def create_optimization_run( run.celery_task_id = task.id db.commit() + config = run.config if isinstance(run.config, dict) else {} + max_metric_calls = config.get("max_metric_calls") + + background_tasks.add_task( + record_prompt_optimization_run_started, + organization_id, + run.id, + workspace_id=workspace_id, + agent_id=agent.id, + max_metric_calls=max_metric_calls, + ) + logger.info(f"[GEPA] Created optimization run {run.id} for agent {agent.name}") return run diff --git a/app/api/v1/routes/public_blind_test.py b/app/api/v1/routes/public_blind_test.py index 07982c7e..79bdef77 100644 --- a/app/api/v1/routes/public_blind_test.py +++ b/app/api/v1/routes/public_blind_test.py @@ -25,7 +25,7 @@ import time from typing import Any, Dict, List, Optional -from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request from pydantic import BaseModel from sqlalchemy.orm import Session @@ -40,6 +40,7 @@ TTSSampleStatus, ) from app.services.storage.s3_service import s3_service +from app.services.billing.flexprice_service import record_blind_test_response_submitted router = APIRouter( @@ -225,6 +226,7 @@ async def submit_public_blind_test_response( share_token: str, data: PublicBlindResponseSubmit, request: Request, + background_tasks: BackgroundTasks, db: Session = Depends(get_db), ): """Accept a rater's submission and merge into the comparison summary.""" @@ -351,8 +353,19 @@ def _clean_ratings(raw: Dict[str, float]) -> Dict[str, float]: detail="This email has already submitted a response for this blind test.", ) + db.refresh(record) + # Re-aggregate into the comparison's evaluation_summary from app.api.v1.routes.voice_playground import _recompute_summary _recompute_summary(comparison, db) + background_tasks.add_task( + record_blind_test_response_submitted, + share.organization_id, + record.id, + share_id=share.id, + workspace_id=share.workspace_id, + response_count=len(cleaned_entries), + ) + return {"message": "Thanks for your response!"} diff --git a/app/api/v1/routes/test_agents.py b/app/api/v1/routes/test_agents.py index 0c402156..f6b6ea1b 100644 --- a/app/api/v1/routes/test_agents.py +++ b/app/api/v1/routes/test_agents.py @@ -3,7 +3,7 @@ API endpoints for managing test agent conversations. """ -from fastapi import APIRouter, Depends, HTTPException, status, UploadFile, File +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, status, UploadFile, File from sqlalchemy.orm import Session from uuid import UUID from typing import List, Optional @@ -16,6 +16,10 @@ TestAgentConversationUpdate, TestAgentConversationResponse ) +from app.services.billing.flexprice_service import ( + record_test_agent_conversation_ended, + record_test_agent_conversation_started, +) from app.services.testing.test_agent_service import test_agent_service router = APIRouter(prefix="/test-agents", tags=["test-agents"]) @@ -86,6 +90,7 @@ async def get_conversation( @router.post("/conversations/{conversation_id}/start", response_model=TestAgentConversationResponse, operation_id="startTestAgentConversation") async def start_conversation( conversation_id: UUID, + background_tasks: BackgroundTasks, organization_id: UUID = Depends(get_organization_id), workspace_id: UUID = Depends(get_workspace_id), db: Session = Depends(get_db) @@ -97,6 +102,12 @@ async def start_conversation( organization_id=organization_id, db=db ) + background_tasks.add_task( + record_test_agent_conversation_started, + organization_id, + conversation_id, + workspace_id=workspace_id, + ) return conversation except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -209,8 +220,10 @@ async def get_response_audio( @router.post("/conversations/{conversation_id}/end", response_model=TestAgentConversationResponse, operation_id="endTestAgentConversation") async def end_conversation( conversation_id: UUID, + background_tasks: BackgroundTasks, final_audio_key: Optional[str] = None, organization_id: UUID = Depends(get_organization_id), + workspace_id: UUID = Depends(get_workspace_id), db: Session = Depends(get_db) ): """End a test agent conversation.""" @@ -221,6 +234,15 @@ async def end_conversation( db=db, final_audio_key=final_audio_key ) + turn_count = len(conversation.live_transcription or []) + background_tasks.add_task( + record_test_agent_conversation_ended, + organization_id, + conversation_id, + workspace_id=workspace_id, + duration_seconds=conversation.duration_seconds, + turn_count=turn_count, + ) return conversation except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) diff --git a/app/api/v1/routes/voice_playground.py b/app/api/v1/routes/voice_playground.py index 85d79baa..8cc4519d 100644 --- a/app/api/v1/routes/voice_playground.py +++ b/app/api/v1/routes/voice_playground.py @@ -8,7 +8,7 @@ import uuid as _uuid -from fastapi import APIRouter, Depends, HTTPException, Response, UploadFile, File, status +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Response, UploadFile, File, status from sqlalchemy.orm import Session from sqlalchemy import func from typing import Dict, Any, List, Literal, Optional @@ -17,6 +17,11 @@ from loguru import logger from app.dependencies import get_db, get_organization_id, get_workspace_id, get_api_key, require_enterprise_feature +from app.services.billing.flexprice_service import ( + record_blind_test_share_created, + record_tts_generation_started, + record_tts_report_requested, +) from app.models.database import ( AIProvider, CallImport, @@ -1371,6 +1376,7 @@ async def get_comparison( @router.post("/comparisons/{comparison_id}/generate", operation_id="generateTTSComparison") async def generate_comparison( comparison_id: UUID, + background_tasks: BackgroundTasks, organization_id: UUID = Depends(get_organization_id), workspace_id: UUID = Depends(get_workspace_id), api_key: str = Depends(get_api_key), @@ -1432,6 +1438,13 @@ async def generate_comparison( comparison.celery_task_id = task.id db.commit() logger.info(f"[VoicePlayground] Dispatched generation task {task.id} for comparison {comparison.id}") + background_tasks.add_task( + record_tts_generation_started, + organization_id, + comparison.id, + workspace_id=workspace_id, + sample_count=pending_tts_samples, + ) except Exception as e: comparison.status = TTSComparisonStatus.FAILED.value comparison.error_message = str(e) @@ -1468,6 +1481,7 @@ async def submit_blind_test( async def create_blind_test_share( comparison_id: UUID, data: BlindTestShareCreate, + background_tasks: BackgroundTasks, organization_id: UUID = Depends(get_organization_id), workspace_id: UUID = Depends(get_workspace_id), api_key: str = Depends(get_api_key), @@ -1540,6 +1554,15 @@ async def create_blind_test_share( db.commit() db.refresh(share) + if not existing: + background_tasks.add_task( + record_blind_test_share_created, + organization_id, + share.id, + workspace_id=workspace_id, + comparison_id=comparison.id, + ) + _recompute_summary(comparison, db) return _serialize_share(share, db, include_aggregates=True) @@ -1891,6 +1914,7 @@ async def download_tts_comparison_report( @router.post("/comparisons/{comparison_id}/reports", operation_id="createTTSComparisonReportJob") async def create_tts_comparison_report_job( comparison_id: UUID, + background_tasks: BackgroundTasks, data: Optional[TTSReportJobCreate] = None, organization_id: UUID = Depends(get_organization_id), workspace_id: UUID = Depends(get_workspace_id), @@ -1924,6 +1948,13 @@ async def create_tts_comparison_report_job( task = generate_tts_report_pdf_task.delay(str(report_job.id), options_dict) report_job.celery_task_id = task.id db.commit() + background_tasks.add_task( + record_tts_report_requested, + organization_id, + report_job.id, + workspace_id=workspace_id, + comparison_id=comparison.id, + ) except Exception as e: report_job.status = TTSReportJobStatus.FAILED.value report_job.error_message = f"Failed to queue report task: {str(e)}" diff --git a/app/config.py b/app/config.py index 6ec3ebd6..ff1772f1 100644 --- a/app/config.py +++ b/app/config.py @@ -163,6 +163,11 @@ class Settings(BaseSettings): JUDGE_ALIGNMENT_ENABLED: bool = True JUDGE_ALIGNMENT_CSV_MAX_ROWS: int = 5000 + # Flexprice usage-based billing (optional; disabled when unset) + FLEXPRICE_ENABLED: bool = False + FLEXPRICE_API_KEY: Optional[str] = None + FLEXPRICE_API_HOST: str = "https://us.api.flexprice.io/v1" + model_config = SettingsConfigDict( env_file=".env", env_file_encoding="utf-8", @@ -607,6 +612,15 @@ def load_config_from_file(config_path: str) -> None: if "trusted_ips" in operational_config: settings.OPERATIONAL_TRUSTED_IPS = operational_config["trusted_ips"] + if "flexprice" in config_data: + flexprice_config = config_data["flexprice"] + if "enabled" in flexprice_config: + settings.FLEXPRICE_ENABLED = bool(flexprice_config["enabled"]) + if flexprice_config.get("api_key"): + settings.FLEXPRICE_API_KEY = flexprice_config["api_key"] + if flexprice_config.get("api_host"): + settings.FLEXPRICE_API_HOST = flexprice_config["api_host"] + # Update Celery URLs if they weren't explicitly set if not settings.CELERY_BROKER_URL: settings.CELERY_BROKER_URL = settings.REDIS_URL diff --git a/app/config/models.json b/app/config/models.json index b2f928d4..23c0005c 100644 --- a/app/config/models.json +++ b/app/config/models.json @@ -510,6 +510,16 @@ "model_type": "tts", "description": "Smallest Lightning v3.1 low-latency neural TTS" }, + "sarvam-30b": { + "provider": "sarvam", + "model_type": "llm", + "description": "Sarvam 30B — balanced Indian-language + English chat model (64K context)" + }, + "sarvam-105b": { + "provider": "sarvam", + "model_type": "llm", + "description": "Sarvam 105B — flagship MoE for complex reasoning, coding, and agentic workflows (128K context)" + }, "saarika:v2.5": { "provider": "sarvam", "model_type": "stt", diff --git a/app/core/auth/oidc_common.py b/app/core/auth/oidc_common.py index a455db23..6c2542c7 100644 --- a/app/core/auth/oidc_common.py +++ b/app/core/auth/oidc_common.py @@ -24,7 +24,10 @@ from app.core.auth.providers import AuthError from app.models.database import Organization, OrganizationMember, User from app.models.enums import RoleEnum -from app.services.organization_provisioning import provision_default_workspace +from app.services.organization_provisioning import ( + provision_billing_customer, + provision_default_workspace, +) # -- JWKS cache --------------------------------------------------------------- @@ -196,6 +199,11 @@ def upsert_user_and_membership( organization_id=organization.id, created_by_user_id=user.id, ) + provision_billing_customer( + organization_id=organization.id, + name=organization_name, + email=email, + ) logger.info(f"Provisioned new organization '{organization_name}' for {email}") member = OrganizationMember( diff --git a/app/core/operational_access_middleware.py b/app/core/operational_access_middleware.py index 1d53c9ad..d329b572 100644 --- a/app/core/operational_access_middleware.py +++ b/app/core/operational_access_middleware.py @@ -1,4 +1,4 @@ -"""Restrict operational endpoints (/health, /metrics) from the public internet.""" +"""Restrict /metrics from the public internet (/health stays open for load balancers).""" from __future__ import annotations @@ -107,11 +107,12 @@ def is_operational_access_allowed(request: Request) -> bool: def is_operational_path(path: str) -> bool: - return path == _HEALTH_PATH or path.startswith(_METRICS_PREFIX) + # /health must stay reachable for load balancers (ALB/kube) without auth. + return path.startswith(_METRICS_PREFIX) class OperationalAccessMiddleware(BaseHTTPMiddleware): - """Block anonymous public access to /health and /metrics.""" + """Block anonymous public access to /metrics (/health is always allowed for LB probes).""" async def dispatch(self, request: Request, call_next) -> Response: if is_operational_path(request.url.path) and not is_operational_access_allowed(request): diff --git a/app/main.py b/app/main.py index ee29a224..a1747da4 100644 --- a/app/main.py +++ b/app/main.py @@ -87,6 +87,10 @@ async def lifespan(app: FastAPI): else: logger.info("All migrations are up to date") + from app.services.billing.flexprice_service import log_startup_status + + log_startup_status(component="api") + logger.info("=" * 60) logger.info("Application startup complete - Ready to serve requests") logger.info("=" * 60) diff --git a/app/migrations/052_call_import_evaluation_billed_rows.py b/app/migrations/052_call_import_evaluation_billed_rows.py new file mode 100644 index 00000000..b4b3ede2 --- /dev/null +++ b/app/migrations/052_call_import_evaluation_billed_rows.py @@ -0,0 +1,59 @@ +""" +Migration: Flexprice billing watermark on call_import_evaluations. + +Adds ``billed_completed_rows`` so evaluation billing can emit one +``call_import.evaluation_completed`` event per pass with a delta +``quantity`` (newly completed rows since the last pass). +""" + +from sqlalchemy import text +from sqlalchemy.orm import Session + +description = ( + "Add billed_completed_rows on call_import_evaluations for pass-level " + "delta usage metering." +) + + +def _column_exists(db: Session, table_name: str, column_name: str) -> bool: + row = db.execute( + text( + """ + SELECT 1 + FROM information_schema.columns + WHERE table_name = :table_name AND column_name = :column_name + """ + ), + {"table_name": table_name, "column_name": column_name}, + ).first() + return row is not None + + +def upgrade(db: Session): + if _column_exists(db, "call_import_evaluations", "billed_completed_rows"): + print( + "call_import_evaluations.billed_completed_rows already exists, skipping..." + ) + return + + db.execute( + text( + """ + ALTER TABLE call_import_evaluations + ADD COLUMN billed_completed_rows INTEGER NOT NULL DEFAULT 0 + """ + ) + ) + print("Added call_import_evaluations.billed_completed_rows (integer, default 0)") + + +def downgrade(db: Session): + if not _column_exists(db, "call_import_evaluations", "billed_completed_rows"): + print("call_import_evaluations.billed_completed_rows missing, skipping...") + return + db.execute( + text( + "ALTER TABLE call_import_evaluations DROP COLUMN billed_completed_rows" + ) + ) + print("Dropped call_import_evaluations.billed_completed_rows") diff --git a/app/models/database.py b/app/models/database.py index 476fb7ed..dcc5ea85 100644 --- a/app/models/database.py +++ b/app/models/database.py @@ -2248,6 +2248,11 @@ class CallImportEvaluation(Base): total_rows = Column(Integer, nullable=False, default=0) completed_rows = Column(Integer, nullable=False, default=0) failed_rows = Column(Integer, nullable=False, default=0) + # Flexprice pass-level delta billing watermark: rows already emitted + # on ``call_import.evaluation_completed`` for this evaluation run. + billed_completed_rows = Column( + Integer, nullable=False, default=0, server_default="0" + ) error_message = Column(Text, nullable=True) celery_group_id = Column(String(255), nullable=True) diff --git a/app/services/ai/llm_service.py b/app/services/ai/llm_service.py index 407ee51d..40b1c056 100644 --- a/app/services/ai/llm_service.py +++ b/app/services/ai/llm_service.py @@ -18,7 +18,7 @@ from sqlalchemy.orm import Session from app.models.database import ModelProvider, AIProvider -from app.services.credentials import resolve_ai_provider +from app.services.credentials import resolve_ai_provider, resolve_integration from app.services.ai.llm_generation_config import build_litellm_kwargs # LiteLLM will silently drop params the target provider doesn't support @@ -36,6 +36,7 @@ "groq": "groq", "xai": "xai", "fireworks": "fireworks_ai", + "sarvam": "sarvam", } # Matches the model-name half of the Gemini 2.5 family: ``gemini-2.5-pro``, @@ -160,6 +161,43 @@ def _get_ai_provider( provider, db, organization_id, credential_id=credential_id ) + def _resolve_api_key( + self, + provider: ModelProvider, + db: Session, + organization_id: UUID, + credential_id: Optional[UUID] = None, + ) -> str: + """Resolve and decrypt an API key from AIProvider or Integration tables.""" + from app.core.encryption import decrypt_api_key + + ai_provider = self._get_ai_provider( + provider, db, organization_id, credential_id=credential_id + ) + if ai_provider: + try: + return decrypt_api_key(ai_provider.api_key) + except Exception as e: + raise RuntimeError( + f"Failed to decrypt API key for provider {provider}: {e}" + ) + + integration = resolve_integration( + provider, db, organization_id, credential_id=credential_id + ) + if integration: + try: + return decrypt_api_key(integration.api_key) + except Exception as e: + raise RuntimeError( + f"Failed to decrypt API key for provider {provider}: {e}" + ) + + provider_label = provider.value if hasattr(provider, "value") else str(provider) + raise RuntimeError( + f"AI provider {provider_label} not configured for this organization." + ) + @staticmethod def _litellm_model_name(provider: ModelProvider, model: str) -> str: """Build the ``provider/model`` string that LiteLLM expects.""" @@ -208,22 +246,9 @@ def generate_response( start_time = time.time() # --- resolve API key from database -------------------------------- - ai_provider = self._get_ai_provider( + api_key = self._resolve_api_key( llm_provider, db, organization_id, credential_id=credential_id ) - if not ai_provider: - raise RuntimeError( - f"AI provider {llm_provider} not configured for this organization." - ) - - from app.core.encryption import decrypt_api_key - - try: - api_key = decrypt_api_key(ai_provider.api_key) - except Exception as e: - raise RuntimeError( - f"Failed to decrypt API key for provider {llm_provider}: {e}" - ) # --- call LiteLLM -------------------------------------------------- model_str = self._litellm_model_name(llm_provider, llm_model) diff --git a/app/services/billing/__init__.py b/app/services/billing/__init__.py new file mode 100644 index 00000000..dd930582 --- /dev/null +++ b/app/services/billing/__init__.py @@ -0,0 +1 @@ +"""Billing and usage metering integrations.""" diff --git a/app/services/billing/flexprice_service.py b/app/services/billing/flexprice_service.py new file mode 100644 index 00000000..a0de32cb --- /dev/null +++ b/app/services/billing/flexprice_service.py @@ -0,0 +1,850 @@ +"""Flexprice usage metering (optional; no-op when disabled). + +Event naming: ``product.action`` snake_case (e.g. ``call_import.batch_created``). +Every event uses ``external_customer_id=str(organization.id)`` and a stable +``event_id`` for idempotency. ``properties`` should include ``workspace_id`` and +``feature`` (license key) when the surface is gated. +""" + +from __future__ import annotations + +import os +from typing import Any, Optional, Union +from uuid import UUID + +from loguru import logger + +from app.config import settings + +EVENT_SOURCE = "efficientai" +FEATURE_CALL_IMPORTS = "call_imports" +FEATURE_VOICE_PLAYGROUND = "voice_playground" +FEATURE_GEPA = "gepa_optimization" + +# Log once when metering is inactive so AWS/worker misconfig is obvious. +_disabled_skip_logged = False + +# Event names +BLIND_TEST_SHARE_CREATED = "blind_test.share_created" +BLIND_TEST_RESPONSE_SUBMITTED = "blind_test.response_submitted" +TTS_GENERATION_STARTED = "tts.generation_started" +TTS_SAMPLE_SYNTHESIZED = "tts.sample_synthesized" +TTS_REPORT_REQUESTED = "tts.report_requested" +TTS_REPORT_COMPLETED = "tts.report_completed" +CALL_IMPORT_BATCH_CREATED = "call_import.batch_created" +CALL_IMPORT_ROW_IMPORTED = "call_import.row_imported" +CALL_IMPORT_EVALUATION_STARTED = "call_import.evaluation_started" +CALL_IMPORT_EVALUATION_COMPLETED = "call_import.evaluation_completed" +CALL_IMPORT_EVALUATION_ROW_COMPLETED = "call_import.evaluation_row_completed" +PLAYGROUND_WEB_CALL_STARTED = "playground.web_call_started" +PLAYGROUND_WEBSOCKET_SESSION_STARTED = "playground.websocket_session_started" +PLAYGROUND_CALL_EVALUATED = "playground.call_evaluated" +PLAYGROUND_EVALUATION_COMPLETED = "playground.evaluation_completed" +EVALUATOR_RUN_REQUESTED = "evaluator.run_requested" +EVALUATOR_RUN_COMPLETED = "evaluator.run_completed" +EVALUATION_CREATED = "evaluation.created" +EVALUATION_COMPLETED = "evaluation.completed" +PROMPT_OPTIMIZATION_RUN_STARTED = "prompt_optimization.run_started" +PROMPT_OPTIMIZATION_RUN_COMPLETED = "prompt_optimization.run_completed" +JUDGE_ALIGNMENT_RUN_STARTED = "judge_alignment.run_started" +JUDGE_ALIGNMENT_RUN_COMPLETED = "judge_alignment.run_completed" +OBSERVABILITY_CALL_INGESTED = "observability.call_ingested" +OBSERVABILITY_CALL_EVALUATED = "observability.call_evaluated" +TEST_AGENT_CONVERSATION_STARTED = "test_agent.conversation_started" +TEST_AGENT_CONVERSATION_ENDED = "test_agent.conversation_ended" +METRICS_LLM_ASSIST = "metrics.llm_assist" +CHAT_COMPLETION = "chat.completion" + + +def _verbose_logging() -> bool: + """Extra per-event logs when FLEXPRICE_VERBOSE=1 (useful on AWS).""" + return os.getenv("FLEXPRICE_VERBOSE", "").lower() in {"1", "true", "yes"} + + +def _mask_api_key(api_key: Optional[str]) -> str: + if not api_key: + return "(missing)" + if len(api_key) <= 8: + return "****" + return f"****{api_key[-4:]}" + + +def disabled_reason() -> Optional[str]: + """Human-readable reason metering is off, or None when active.""" + if not settings.FLEXPRICE_ENABLED: + return "flexprice.enabled is false (or FLEXPRICE_ENABLED unset)" + if not settings.FLEXPRICE_API_KEY: + return "flexprice.api_key is unset (or FLEXPRICE_API_KEY env missing)" + return None + + +def is_enabled() -> bool: + """Return True only when Flexprice is explicitly enabled with an API key.""" + return disabled_reason() is None + + +def log_startup_status(*, component: str = "app") -> None: + """Log Flexprice config at process start (API + Celery worker).""" + reason = disabled_reason() + if reason: + logger.info( + "Flexprice metering INACTIVE for {} — {} (api_host={})", + component, + reason, + settings.FLEXPRICE_API_HOST, + ) + return + + logger.info( + "Flexprice metering ACTIVE for {} — api_host={} api_key={}", + component, + settings.FLEXPRICE_API_HOST, + _mask_api_key(settings.FLEXPRICE_API_KEY), + ) + + connectivity_error = _verify_connectivity() + if connectivity_error: + logger.warning( + "Flexprice connectivity check FAILED for {} — host={} error={}. " + "Config looks valid but outbound calls may be blocked (check AWS egress/NAT/SG).", + component, + settings.FLEXPRICE_API_HOST, + connectivity_error, + ) + else: + logger.info( + "Flexprice connectivity check OK for {} — host={}", + component, + settings.FLEXPRICE_API_HOST, + ) + + +def _verify_connectivity() -> Optional[str]: + """Best-effort reachability probe; returns error text or None when OK.""" + try: + import httpx + + base = settings.FLEXPRICE_API_HOST.rstrip("/") + response = httpx.get( + f"{base}/customers", + headers={"x-api-key": settings.FLEXPRICE_API_KEY or ""}, + params={"limit": 1}, + timeout=10.0, + ) + if response.status_code < 400: + return None + return f"HTTP {response.status_code}: {response.text[:200]}" + except Exception as exc: + return str(exc) + + +def _is_customer_already_exists(exc: Exception) -> bool: + message = str(exc).lower() + if "already exist" in message or "duplicate" in message: + return True + status_code = getattr(exc, "status_code", None) + return status_code == 409 + + +def _ingest_usage_event(client, payload: dict) -> None: + """Call Flexprice event ingest across SDK versions (flat kwargs vs request=).""" + events = client.events + ingest = getattr(events, "ingest_event", None) or getattr(events, "ingest", None) + if ingest is None: + raise AttributeError("Flexprice SDK has no events.ingest_event or events.ingest") + + try: + ingest(**payload) + except TypeError: + ingest(request=payload) + + +def _coerce_properties(properties: Optional[dict[str, Any]]) -> dict[str, str]: + """Normalize event properties for Flexprice ingest (SDK expects string values).""" + if not properties: + return {} + out: dict[str, str] = {} + for key, value in properties.items(): + if value is None: + continue + if isinstance(value, bool): + out[key] = "true" if value else "false" + elif isinstance(value, (int, float, UUID)): + out[key] = str(value) + else: + out[key] = str(value) + return out + + +def record_event( + event_name: str, + organization_id: UUID, + event_id: Union[str, UUID], + *, + properties: Optional[dict[str, Any]] = None, +) -> None: + """Ingest a usage event. No-op when Flexprice is disabled; never raises.""" + global _disabled_skip_logged + + inactive_reason = disabled_reason() + if inactive_reason: + if not _disabled_skip_logged: + logger.warning( + "Flexprice metering inactive — {}. Usage events will be dropped until fixed.", + inactive_reason, + ) + _disabled_skip_logged = True + if _verbose_logging(): + logger.info( + "Flexprice SKIP {} org={} event_id={} ({})", + event_name, + organization_id, + event_id, + inactive_reason, + ) + return + + coerced = _coerce_properties(properties) + quantity = coerced.get("quantity") + + try: + from flexprice import Flexprice + + with Flexprice( + server_url=settings.FLEXPRICE_API_HOST, + api_key_auth=settings.FLEXPRICE_API_KEY, + ) as client: + payload = { + "event_name": event_name, + "external_customer_id": str(organization_id), + "event_id": str(event_id), + "source": EVENT_SOURCE, + "properties": coerced, + } + _ingest_usage_event(client, payload) + + logger.info( + "Flexprice ingested {} org={} event_id={} quantity={}", + event_name, + organization_id, + event_id, + quantity if quantity is not None else "n/a", + ) + except Exception as exc: + logger.warning( + "Flexprice {} ingest FAILED org={} event_id={} host={} error={}", + event_name, + organization_id, + event_id, + settings.FLEXPRICE_API_HOST, + exc, + ) + + +def ensure_customer( + organization_id: UUID, + *, + name: str, + email: Optional[str] = None, +) -> None: + """Register an organization as a Flexprice customer. No-op when disabled.""" + inactive_reason = disabled_reason() + if inactive_reason: + if _verbose_logging(): + logger.info( + "Flexprice SKIP ensure_customer org={} ({})", + organization_id, + inactive_reason, + ) + return + + try: + from flexprice import Flexprice + + with Flexprice( + server_url=settings.FLEXPRICE_API_HOST, + api_key_auth=settings.FLEXPRICE_API_KEY, + ) as client: + client.customers.create_customer( + external_id=str(organization_id), + name=name, + email=email, + ) + logger.info( + "Flexprice ensure_customer ok org={} name={}", + organization_id, + name, + ) + except Exception as exc: + if _is_customer_already_exists(exc): + logger.debug( + "Flexprice ensure_customer already exists org={}", + organization_id, + ) + return + logger.warning( + "Flexprice ensure_customer FAILED org={} host={} error={}", + organization_id, + settings.FLEXPRICE_API_HOST, + exc, + ) + + +# --- Voice playground --- + + +def record_blind_test_share_created( + organization_id: UUID, + share_id: UUID, + *, + workspace_id: UUID, + comparison_id: UUID, +) -> None: + record_event( + BLIND_TEST_SHARE_CREATED, + organization_id, + share_id, + properties={ + "workspace_id": workspace_id, + "feature": FEATURE_VOICE_PLAYGROUND, + "share_id": share_id, + "comparison_id": comparison_id, + }, + ) + + +def record_blind_test_response_submitted( + organization_id: UUID, + response_id: UUID, + *, + share_id: UUID, + workspace_id: UUID, + response_count: int, +) -> None: + record_event( + BLIND_TEST_RESPONSE_SUBMITTED, + organization_id, + response_id, + properties={ + "workspace_id": workspace_id, + "feature": FEATURE_VOICE_PLAYGROUND, + "share_id": share_id, + "response_count": response_count, + "quantity": response_count, + }, + ) + + +def record_tts_generation_started( + organization_id: UUID, + comparison_id: UUID, + *, + workspace_id: UUID, + sample_count: int, +) -> None: + record_event( + TTS_GENERATION_STARTED, + organization_id, + comparison_id, + properties={ + "workspace_id": workspace_id, + "feature": FEATURE_VOICE_PLAYGROUND, + "comparison_id": comparison_id, + "sample_count": sample_count, + }, + ) + + +def record_tts_sample_synthesized( + organization_id: UUID, + sample_id: UUID, + *, + workspace_id: UUID, + comparison_id: UUID, + provider: Optional[str] = None, + side: Optional[str] = None, + duration_seconds: Optional[float] = None, +) -> None: + record_event( + TTS_SAMPLE_SYNTHESIZED, + organization_id, + sample_id, + properties={ + "workspace_id": workspace_id, + "feature": FEATURE_VOICE_PLAYGROUND, + "comparison_id": comparison_id, + "sample_id": sample_id, + "provider": provider, + "side": side, + "duration_seconds": duration_seconds, + "quantity": 1, + }, + ) + + +def record_tts_report_requested( + organization_id: UUID, + report_job_id: UUID, + *, + workspace_id: UUID, + comparison_id: UUID, +) -> None: + record_event( + TTS_REPORT_REQUESTED, + organization_id, + report_job_id, + properties={ + "workspace_id": workspace_id, + "feature": FEATURE_VOICE_PLAYGROUND, + "comparison_id": comparison_id, + "report_job_id": report_job_id, + }, + ) + + +def record_tts_report_completed( + organization_id: UUID, + report_job_id: UUID, + *, + workspace_id: UUID, + comparison_id: UUID, +) -> None: + record_event( + TTS_REPORT_COMPLETED, + organization_id, + report_job_id, + properties={ + "workspace_id": workspace_id, + "feature": FEATURE_VOICE_PLAYGROUND, + "comparison_id": comparison_id, + "report_job_id": report_job_id, + }, + ) + + +# --- Call imports --- + + +def record_call_import_batch_created( + organization_id: UUID, + call_import_id: UUID, + *, + workspace_id: UUID, + total_rows: int, + source: str, + provider: Optional[str] = None, +) -> None: + record_event( + CALL_IMPORT_BATCH_CREATED, + organization_id, + call_import_id, + properties={ + "workspace_id": workspace_id, + "feature": FEATURE_CALL_IMPORTS, + "call_import_id": call_import_id, + "total_rows": total_rows, + "quantity": total_rows, + "source": source, + "provider": provider, + }, + ) + + +# --- Call imports (evaluations) --- + + +def record_call_import_evaluation_completed( + organization_id: UUID, + evaluation_id: UUID, + *, + workspace_id: UUID, + call_import_id: UUID, + rows_billed: int, + completed_total: int, + metric_count: int = 0, +) -> None: + """Bill one pass of an evaluation run for newly completed rows.""" + record_event( + CALL_IMPORT_EVALUATION_COMPLETED, + organization_id, + f"{evaluation_id}:{completed_total}", + properties={ + "workspace_id": workspace_id, + "feature": FEATURE_CALL_IMPORTS, + "call_import_id": call_import_id, + "evaluation_id": evaluation_id, + "rows_billed": rows_billed, + "completed_total": completed_total, + "metric_count": metric_count, + "quantity": rows_billed, + }, + ) + + +# --- Agent playground --- + + +def record_playground_web_call_started( + organization_id: UUID, + call_short_id: str, + *, + workspace_id: UUID, + agent_id: UUID, +) -> None: + record_event( + PLAYGROUND_WEB_CALL_STARTED, + organization_id, + call_short_id, + properties={ + "workspace_id": workspace_id, + "agent_id": agent_id, + "call_short_id": call_short_id, + }, + ) + + +def record_playground_websocket_session_started( + organization_id: UUID, + call_short_id: str, + *, + workspace_id: UUID, +) -> None: + record_event( + PLAYGROUND_WEBSOCKET_SESSION_STARTED, + organization_id, + call_short_id, + properties={ + "workspace_id": workspace_id, + "call_short_id": call_short_id, + }, + ) + + +def record_playground_call_evaluated( + organization_id: UUID, + call_short_id: str, + *, + workspace_id: UUID, + metric_count: int, +) -> None: + record_event( + PLAYGROUND_CALL_EVALUATED, + organization_id, + call_short_id, + properties={ + "workspace_id": workspace_id, + "call_short_id": call_short_id, + "metric_count": metric_count, + }, + ) + + +def record_playground_evaluation_completed( + organization_id: UUID, + call_short_id: str, + *, + workspace_id: UUID, + duration_seconds: Optional[float] = None, +) -> None: + record_event( + PLAYGROUND_EVALUATION_COMPLETED, + organization_id, + call_short_id, + properties={ + "workspace_id": workspace_id, + "call_short_id": call_short_id, + "duration_seconds": duration_seconds, + }, + ) + + +# --- Evaluators --- + + +def record_evaluator_run_requested( + organization_id: UUID, + request_id: UUID, + *, + workspace_id: UUID, + quantity: int, +) -> None: + record_event( + EVALUATOR_RUN_REQUESTED, + organization_id, + request_id, + properties={ + "workspace_id": workspace_id, + "quantity": quantity, + }, + ) + + +def record_evaluator_run_completed( + organization_id: UUID, + result_id: str, + *, + workspace_id: UUID, + evaluator_id: UUID, + call_count: int = 1, +) -> None: + record_event( + EVALUATOR_RUN_COMPLETED, + organization_id, + result_id, + properties={ + "workspace_id": workspace_id, + "evaluator_id": evaluator_id, + "result_id": result_id, + "call_count": call_count, + }, + ) + + +# --- Legacy evaluations --- + + +def record_evaluation_created( + organization_id: UUID, + evaluation_id: UUID, + *, + workspace_id: UUID, + audio_id: UUID, + metrics_requested: int, +) -> None: + record_event( + EVALUATION_CREATED, + organization_id, + evaluation_id, + properties={ + "workspace_id": workspace_id, + "audio_id": audio_id, + "metrics_requested": metrics_requested, + }, + ) + + +def record_evaluation_completed( + organization_id: UUID, + evaluation_id: UUID, + *, + workspace_id: UUID, +) -> None: + record_event( + EVALUATION_COMPLETED, + organization_id, + evaluation_id, + properties={"workspace_id": workspace_id}, + ) + + +# --- Prompt optimization --- + + +def record_prompt_optimization_run_started( + organization_id: UUID, + run_id: UUID, + *, + workspace_id: UUID, + agent_id: UUID, + max_metric_calls: Optional[int] = None, +) -> None: + record_event( + PROMPT_OPTIMIZATION_RUN_STARTED, + organization_id, + run_id, + properties={ + "workspace_id": workspace_id, + "feature": FEATURE_GEPA, + "run_id": run_id, + "agent_id": agent_id, + "max_metric_calls": max_metric_calls, + }, + ) + + +def record_prompt_optimization_run_completed( + organization_id: UUID, + run_id: UUID, + *, + workspace_id: UUID, + agent_id: UUID, + candidates_count: int = 0, +) -> None: + record_event( + PROMPT_OPTIMIZATION_RUN_COMPLETED, + organization_id, + run_id, + properties={ + "workspace_id": workspace_id, + "feature": FEATURE_GEPA, + "run_id": run_id, + "agent_id": agent_id, + "candidates_count": candidates_count, + }, + ) + + +# --- Judge alignment --- + + +def record_judge_alignment_run_started( + organization_id: UUID, + run_id: UUID, + *, + workspace_id: UUID, + dataset_id: UUID, + sample_count: int, +) -> None: + record_event( + JUDGE_ALIGNMENT_RUN_STARTED, + organization_id, + run_id, + properties={ + "workspace_id": workspace_id, + "run_id": run_id, + "dataset_id": dataset_id, + "sample_count": sample_count, + }, + ) + + +def record_judge_alignment_run_completed( + organization_id: UUID, + run_id: UUID, + *, + workspace_id: UUID, + dataset_id: UUID, + samples_scored: int, +) -> None: + record_event( + JUDGE_ALIGNMENT_RUN_COMPLETED, + organization_id, + run_id, + properties={ + "workspace_id": workspace_id, + "run_id": run_id, + "dataset_id": dataset_id, + "samples_scored": samples_scored, + }, + ) + + +# --- Observability --- + + +def record_observability_call_ingested( + organization_id: UUID, + call_short_id: str, + *, + workspace_id: UUID, + provider: Optional[str] = None, +) -> None: + record_event( + OBSERVABILITY_CALL_INGESTED, + organization_id, + call_short_id, + properties={ + "workspace_id": workspace_id, + "call_short_id": call_short_id, + "provider": provider, + }, + ) + + +def record_observability_call_evaluated( + organization_id: UUID, + call_short_id: str, + *, + workspace_id: UUID, +) -> None: + record_event( + OBSERVABILITY_CALL_EVALUATED, + organization_id, + call_short_id, + properties={ + "workspace_id": workspace_id, + "call_short_id": call_short_id, + }, + ) + + +# --- Test agents --- + + +def record_test_agent_conversation_started( + organization_id: UUID, + conversation_id: UUID, + *, + workspace_id: UUID, +) -> None: + record_event( + TEST_AGENT_CONVERSATION_STARTED, + organization_id, + conversation_id, + properties={ + "workspace_id": workspace_id, + "conversation_id": conversation_id, + }, + ) + + +def record_test_agent_conversation_ended( + organization_id: UUID, + conversation_id: UUID, + *, + workspace_id: UUID, + duration_seconds: Optional[float] = None, + turn_count: int = 0, +) -> None: + record_event( + TEST_AGENT_CONVERSATION_ENDED, + organization_id, + conversation_id, + properties={ + "workspace_id": workspace_id, + "conversation_id": conversation_id, + "duration_seconds": duration_seconds, + "turn_count": turn_count, + "quantity": duration_seconds or 1, + }, + ) + + +# --- LLM assist --- + + +def record_metrics_llm_assist( + organization_id: UUID, + request_id: UUID, + *, + workspace_id: Optional[UUID], + mode: str, +) -> None: + record_event( + METRICS_LLM_ASSIST, + organization_id, + request_id, + properties={ + "workspace_id": workspace_id, + "mode": mode, + }, + ) + + +def record_chat_completion( + organization_id: UUID, + request_id: UUID, + *, + workspace_id: Optional[UUID], + model: Optional[str] = None, +) -> None: + record_event( + CHAT_COMPLETION, + organization_id, + request_id, + properties={ + "workspace_id": workspace_id, + "model": model, + "quantity": 1, + }, + ) diff --git a/app/services/evaluation/evaluation_service.py b/app/services/evaluation/evaluation_service.py index c5a07ffb..dd3c5760 100644 --- a/app/services/evaluation/evaluation_service.py +++ b/app/services/evaluation/evaluation_service.py @@ -135,6 +135,14 @@ def process_evaluation( evaluation.completed_at = datetime.now(UTC) db.commit() + from app.services.billing.flexprice_service import record_evaluation_completed + + record_evaluation_completed( + evaluation.organization_id, + evaluation.id, + workspace_id=evaluation.workspace_id, + ) + return { "evaluation_id": str(evaluation.id), "status": "completed", diff --git a/app/services/judge_alignment/model_catalog.py b/app/services/judge_alignment/model_catalog.py index e7f4afda..e727bffe 100644 --- a/app/services/judge_alignment/model_catalog.py +++ b/app/services/judge_alignment/model_catalog.py @@ -20,14 +20,15 @@ from loguru import logger from sqlalchemy.orm import Session -from app.models.database import AIProvider, ModelProvider +from app.models.database import AIProvider, Integration, ModelProvider +from app.models.enums import IntegrationPlatform from app.services.ai.model_config_service import model_config_service # Providers that ship LLM models we can use as a judge. We deliberately # exclude STT/TTS-only vendors (Deepgram, Cartesia, ElevenLabs, Murf, -# Sarvam, Voicemaker, Smallest) because LiteLLM has no LLM completion -# route for them. +# Voicemaker, Smallest) because LiteLLM has no LLM completion route for +# them. Sarvam is included because it exposes chat models via LiteLLM. _LLM_CAPABLE_PROVIDERS = { ModelProvider.OPENAI.value, ModelProvider.ANTHROPIC.value, @@ -43,6 +44,13 @@ ModelProvider.AWS.value, ModelProvider.OPENROUTER.value, ModelProvider.CUSTOM.value, + ModelProvider.SARVAM.value, +} + +# Voice-platform integrations that also expose LLM models (credentials +# live in the Integration table rather than AIProvider). +_INTEGRATION_LLM_PLATFORMS = { + IntegrationPlatform.SARVAM.value: ModelProvider.SARVAM.value, } @@ -63,6 +71,7 @@ def _provider_label(provider_value: str) -> str: "aws": "AWS", "openrouter": "OpenRouter", "custom": "Custom", + "sarvam": "Sarvam", } return overrides.get(provider_value.lower(), provider_value.title()) @@ -98,11 +107,29 @@ def list_judge_capable_models( .all() ) + configured_provider_values: set[str] = { + (ai_provider.provider or "").lower() + for ai_provider in providers + } + + integrations: List[Integration] = ( + db.query(Integration) + .filter( + Integration.organization_id == organization_id, + Integration.is_active == True, # noqa: E712 - SQLAlchemy comparison + ) + .all() + ) + for integration in integrations: + platform = (integration.platform or "").lower() + mapped = _INTEGRATION_LLM_PLATFORMS.get(platform) + if mapped: + configured_provider_values.add(mapped) + catalog: List[Dict[str, str]] = [] seen: set = set() - for ai_provider in providers: - provider_value = (ai_provider.provider or "").lower() + for provider_value in configured_provider_values: if provider_value not in _LLM_CAPABLE_PROVIDERS: continue diff --git a/app/services/organization_provisioning.py b/app/services/organization_provisioning.py index 54981021..a839ca67 100644 --- a/app/services/organization_provisioning.py +++ b/app/services/organization_provisioning.py @@ -7,6 +7,7 @@ from sqlalchemy.orm import Session from app.models.database import Workspace +from app.services.billing.flexprice_service import ensure_customer from app.services.workspace_rbac import ( backfill_org_workspace_memberships, ensure_creator_workspace_admin, @@ -51,3 +52,13 @@ def provision_default_workspace( ) backfill_org_workspace_memberships(db, organization_id=organization_id) return workspace + + +def provision_billing_customer( + *, + organization_id: UUID, + name: str, + email: str | None = None, +) -> None: + """Register the org with Flexprice when billing is enabled (no-op otherwise).""" + ensure_customer(organization_id, name=name, email=email) diff --git a/app/workers/config.py b/app/workers/config.py index 250840fe..65b92c69 100644 --- a/app/workers/config.py +++ b/app/workers/config.py @@ -66,6 +66,10 @@ except Exception as e: logger.warning(f"⚠️ Celery worker: Could not load config.yml: {e}") +from app.services.billing.flexprice_service import log_startup_status # noqa: E402 + +log_startup_status(component="celery-worker") + # Create Celery app celery_app = Celery( "efficientai", diff --git a/app/workers/tasks/evaluate_call_import_row.py b/app/workers/tasks/evaluate_call_import_row.py index 6653e48a..8d2706ee 100644 --- a/app/workers/tasks/evaluate_call_import_row.py +++ b/app/workers/tasks/evaluate_call_import_row.py @@ -364,6 +364,7 @@ def _build_parent_groups( def _rollup_parent(db, evaluation: CallImportEvaluation) -> None: + previous_status = evaluation.status rows = ( db.query(CallImportEvaluationRow.status) .filter(CallImportEvaluationRow.evaluation_id == evaluation.id) @@ -394,6 +395,26 @@ def _rollup_parent(db, evaluation: CallImportEvaluation) -> None: else: evaluation.status = "partial" + if previous_status == "running": + already_billed = int(getattr(evaluation, "billed_completed_rows", 0) or 0) + delta = completed - already_billed + if delta > 0: + from app.services.billing.flexprice_service import ( + record_call_import_evaluation_completed, + ) + + metric_count = len(evaluation.selected_metric_ids or []) + record_call_import_evaluation_completed( + evaluation.organization_id, + evaluation.id, + workspace_id=evaluation.workspace_id, + call_import_id=evaluation.call_import_id, + rows_billed=delta, + completed_total=completed, + metric_count=metric_count, + ) + evaluation.billed_completed_rows = completed + # Per-task time limits keep a wedged audio evaluation (e.g. torch.hub UTMOS # download stuck on a network hiccup, or libgomp deadlock in a prefork child) diff --git a/app/workers/tasks/helpers/llm_diarisation.py b/app/workers/tasks/helpers/llm_diarisation.py index ef9c3673..0768b919 100644 --- a/app/workers/tasks/helpers/llm_diarisation.py +++ b/app/workers/tasks/helpers/llm_diarisation.py @@ -128,6 +128,79 @@ # transcript so a 1 MB CSV cell can't blow up the prompt window. _MAX_TRANSCRIPT_CHARS = 60_000 +# Output token budget for the diariser LLM. The JSON mirrors the input +# transcript (split into agent/user turns) plus structural overhead per +# turn. Without an explicit ``max_tokens`` many providers default to a +# low cap (e.g. 4096) and truncate mid-JSON on longer calls. +_DIARISATION_MIN_MAX_TOKENS = 4096 +_DIARISATION_MAX_MAX_TOKENS = 16_384 + + +def _estimate_diarisation_max_tokens( + *, text_length: int = 0, audio_bytes: int = 0 +) -> int: + """Scale the diariser output budget to the input payload size.""" + if text_length > 0: + # ~0.45 tokens/char covers JSON keys + per-turn wrapping. + estimated = int(text_length * 0.45) + 512 + elif audio_bytes > 0: + # Rough proxy: ~150 spoken chars/sec, ~320 KiB/min mono 16 kHz. + estimated_chars = (audio_bytes / 320_000) * 150 * 60 + estimated = int(estimated_chars * 0.45) + 512 + else: + estimated = _DIARISATION_MIN_MAX_TOKENS + return min( + _DIARISATION_MAX_MAX_TOKENS, + max(_DIARISATION_MIN_MAX_TOKENS, estimated), + ) + + +def _generate_diarisation_response( + *, + messages: List[Dict[str, Any]], + provider_enum: ModelProvider, + llm_model: str, + organization_id: UUID, + db: Session, + temperature: float, + credential_id: Optional[UUID], + config: Dict[str, Any], + content_length: int = 0, + audio_bytes: int = 0, +) -> Dict[str, Any]: + """Call the diariser LLM with a scaled ``max_tokens`` and one retry.""" + base_budget = _estimate_diarisation_max_tokens( + text_length=content_length, + audio_bytes=audio_bytes, + ) + retry_budget = min(_DIARISATION_MAX_MAX_TOKENS, base_budget * 2) + budgets = [base_budget] if retry_budget <= base_budget else [base_budget, retry_budget] + + response: Dict[str, Any] = {} + for attempt_idx, max_tokens in enumerate(budgets): + response = llm_service.generate_response( + messages=messages, + llm_provider=provider_enum, + llm_model=llm_model, + organization_id=organization_id, + db=db, + temperature=temperature, + credential_id=credential_id, + config=config, + task_defaults={"max_tokens": max_tokens}, + ) + if not response.get("truncated"): + return response + if attempt_idx < len(budgets) - 1: + logger.warning( + "Diariser LLM response truncated at max_tokens={} " + "(model={}); retrying with max_tokens={}.", + max_tokens, + llm_model, + budgets[attempt_idx + 1], + ) + return response + # Synonym keys we accept from over-creative LLMs that don't follow the # canonical ``{"speaker": ..., "text": ...}`` schema verbatim. Without @@ -460,15 +533,16 @@ def diarize_transcript_with_llm( # other providers. config = {"response_format": {"type": "json_object"}} - response = llm_service.generate_response( + response = _generate_diarisation_response( messages=messages, - llm_provider=provider_enum, + provider_enum=provider_enum, llm_model=llm_model, organization_id=organization_id, db=db, temperature=temperature, credential_id=credential_id, config=config, + content_length=len(cleaned), ) turns = _parse_turns_from_response(response) @@ -1042,15 +1116,16 @@ def diarize_audio_with_llm( config = {"response_format": {"type": "json_object"}} try: - response = llm_service.generate_response( + response = _generate_diarisation_response( messages=messages, - llm_provider=provider_enum, + provider_enum=provider_enum, llm_model=llm_model, organization_id=organization_id, db=db, temperature=temperature, credential_id=credential_id, config=config, + audio_bytes=len(audio_bytes), ) except Exception as exc: # LiteLLM surfaces "model does not support audio input" as a diff --git a/app/workers/tasks/process_evaluator_result.py b/app/workers/tasks/process_evaluator_result.py index 51036586..71b7c486 100644 --- a/app/workers/tasks/process_evaluator_result.py +++ b/app/workers/tasks/process_evaluator_result.py @@ -591,6 +591,27 @@ def process_evaluator_result_task(self, result_id: str): result.status = EvaluatorResultStatus.COMPLETED.value db.commit() + from app.models.database import CallRecording, CallRecordingSource + from app.services.billing.flexprice_service import ( + record_playground_evaluation_completed, + ) + + call_recording = ( + db.query(CallRecording) + .filter( + CallRecording.evaluator_result_id == result.id, + CallRecording.source == CallRecordingSource.PLAYGROUND, + ) + .first() + ) + if call_recording: + record_playground_evaluation_completed( + result.organization_id, + call_recording.call_short_id, + workspace_id=result.workspace_id, + duration_seconds=result.duration_seconds, + ) + total_time = time.time() - task_start_time logger.info( f"[EvaluatorResult {result.result_id}] Completed in {total_time:.2f}s, " diff --git a/app/workers/tasks/run_evaluator.py b/app/workers/tasks/run_evaluator.py index 536e8c53..a82f0199 100644 --- a/app/workers/tasks/run_evaluator.py +++ b/app/workers/tasks/run_evaluator.py @@ -91,6 +91,15 @@ def run_evaluator_task(self, evaluator_id: str, evaluator_result_id: str): db.commit() + from app.services.billing.flexprice_service import record_evaluator_run_completed + + record_evaluator_run_completed( + evaluator.organization_id, + result.result_id, + workspace_id=evaluator.workspace_id, + evaluator_id=evaluator.id, + ) + return { "evaluator_id": evaluator_id, "result_id": evaluator_result_id, diff --git a/app/workers/tasks/run_judge_alignment.py b/app/workers/tasks/run_judge_alignment.py index cf39652b..6ffbb655 100644 --- a/app/workers/tasks/run_judge_alignment.py +++ b/app/workers/tasks/run_judge_alignment.py @@ -88,6 +88,18 @@ def run_judge_alignment_task( _fail(db, run, str(exc)) return {"error": str(exc)} + from app.services.billing.flexprice_service import ( + record_judge_alignment_run_completed, + ) + + record_judge_alignment_run_completed( + run.organization_id, + run.id, + workspace_id=run.workspace_id, + dataset_id=run.dataset_id, + samples_scored=len(samples), + ) + return {"judge_run_id": judge_run_id, "metrics": metrics} finally: diff --git a/app/workers/tasks/run_prompt_optimization.py b/app/workers/tasks/run_prompt_optimization.py index 1b388465..dbd9c869 100644 --- a/app/workers/tasks/run_prompt_optimization.py +++ b/app/workers/tasks/run_prompt_optimization.py @@ -141,6 +141,18 @@ def run_prompt_optimization_task(self, optimization_run_id: str): f"Best score: {run.best_score}" ) + from app.services.billing.flexprice_service import ( + record_prompt_optimization_run_completed, + ) + + record_prompt_optimization_run_completed( + run.organization_id, + run.id, + workspace_id=run.workspace_id, + agent_id=run.agent_id, + candidates_count=len(result.get("candidates", [])), + ) + except Exception as e: logger.error(f"[GEPA] Optimization run {optimization_run_id} failed: {e}", exc_info=True) try: diff --git a/app/workers/tasks/tts_comparison.py b/app/workers/tasks/tts_comparison.py index 76092147..8ae42ea9 100644 --- a/app/workers/tasks/tts_comparison.py +++ b/app/workers/tasks/tts_comparison.py @@ -291,6 +291,18 @@ def _resolve_voice_meta(sample_obj): sample.status = TTSSampleStatus.COMPLETED.value db.commit() + from app.services.billing.flexprice_service import record_tts_sample_synthesized + + record_tts_sample_synthesized( + comp.organization_id, + sample.id, + workspace_id=comp.workspace_id, + comparison_id=comp.id, + provider=sample.provider, + side=sample.side, + duration_seconds=sample.duration_seconds, + ) + logger.info( f"[TTS Generate] Sample {sample.id} done – " f"{sample.provider}/{sample.voice_name} ttfb={ttfb_ms:.0f}ms total={latency_ms:.0f}ms" diff --git a/app/workers/tasks/tts_report.py b/app/workers/tasks/tts_report.py index c738244f..f61d5e1e 100644 --- a/app/workers/tasks/tts_report.py +++ b/app/workers/tasks/tts_report.py @@ -79,6 +79,15 @@ def generate_tts_report_pdf_task(self, report_job_id: str, report_options: dict report_job.error_message = None db.commit() + from app.services.billing.flexprice_service import record_tts_report_completed + + record_tts_report_completed( + report_job.organization_id, + report_job.id, + workspace_id=report_job.workspace_id, + comparison_id=comparison.id, + ) + return {"status": "completed", "s3_key": s3_key} except Exception as exc: logger.error(f"[TTS Report] Task failed: {exc}", exc_info=True) diff --git a/config.docker.yml b/config.docker.yml index a1a7e964..0438d86e 100644 --- a/config.docker.yml +++ b/config.docker.yml @@ -13,6 +13,13 @@ server: host: "0.0.0.0" port: 8000 +# Operational endpoints (/metrics). /health is always open for load balancers. +# Add VPC CIDRs here if Prometheus scrapes /metrics from inside the VPC. +operational: + public: false + trusted_ips: + - "10.0.0.0/8" + # Database Configuration (Docker service name) database: url: "postgresql://efficientai:password@db:5432/efficientai" diff --git a/config.yml.example b/config.yml.example index 32a1dd02..1dad08dc 100644 --- a/config.yml.example +++ b/config.yml.example @@ -12,10 +12,8 @@ server: host: "0.0.0.0" port: 8000 -# Operational endpoints (/health, /metrics) -# Keep public: false in production/sandbox. ALB health checks connect -# directly from VPC addresses (no X-Forwarded-For). Include your VPC/LB CIDRs -# in trusted_ips so probes and internal scrapers are allowed. +# Operational endpoints (/metrics). /health is always open for ALB/kube probes. +# Include VPC CIDRs in trusted_ips if Prometheus scrapes /metrics from inside the VPC. operational: public: false trusted_ips: @@ -172,3 +170,13 @@ judge_alignment: # license: # key: "eyJhbGciOi..." +# Flexprice usage-based billing (optional). When disabled or api_key is unset, +# no SDK calls are made and the app behaves as today. Cloud SaaS: set enabled +# true and provide FLEXPRICE_API_KEY (env var overrides api_key below). +# flexprice: +# enabled: false +# api_key: null +# api_host: "https://us.api.flexprice.io/v1" +# Env override: FLEXPRICE_ENABLED, FLEXPRICE_API_KEY, FLEXPRICE_API_HOST +# Debug: FLEXPRICE_VERBOSE=1 logs every skipped event (when inactive or verbose) + diff --git a/docker-compose.yml b/docker-compose.yml index cb0795c1..e0c6a7b6 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,148 +1,148 @@ -# EfficientAI Docker Compose -# -# Usage: -# docker compose up # Uses latest images -# EFFICIENTAI_VERSION=1.0.0 docker compose up # Uses specific version -# -# For local development, uncomment the 'build' sections below. - -services: - db: - image: postgres:15 - container_name: efficientai_db - environment: - POSTGRES_USER: ${POSTGRES_USER:-efficientai} - POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-password} - POSTGRES_DB: ${POSTGRES_DB:-efficientai} - volumes: - - postgres_data:/var/lib/postgresql/data - ports: - - "5432:5432" - healthcheck: - test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-efficientai}"] - interval: 10s - timeout: 5s - retries: 5 - - redis: - image: redis:7-alpine - container_name: efficientai_redis - ports: - - "6379:6379" - healthcheck: - test: ["CMD", "redis-cli", "ping"] - interval: 10s - timeout: 5s - retries: 5 - - api: - # Pre-built image from GitHub Container Registry - # Use EFFICIENTAI_VERSION env var to pin to a specific version (e.g., 1.0.0) - image: ghcr.io/efficientai-tech/efficientai-api:${EFFICIENTAI_VERSION:-latest} - # For local development, uncomment below and comment out the image line: - build: - context: . - dockerfile: docker/Dockerfile.api - container_name: efficientai_api - environment: - DATABASE_URL: postgresql://${POSTGRES_USER:-efficientai}:${POSTGRES_PASSWORD:-password}@db:5432/${POSTGRES_DB:-efficientai} - REDIS_URL: redis://redis:6379/0 - CELERY_BROKER_URL: redis://redis:6379/0 - CELERY_RESULT_BACKEND: redis://redis:6379/0 - SECRET_KEY: ${SECRET_KEY:-your-secret-key-change-in-production} - UPLOAD_DIR: /app/uploads - DEBUG: ${DEBUG:-True} - FRONTEND_DIR: /app/frontend/dist - ENCRYPTION_KEY: ${ENCRYPTION_KEY:-} - # Optional GCS blob storage (set storage.blob_provider: gcs in config.docker.yml): - # BLOB_STORAGE_PROVIDER: gcs - # GCS_BUCKET_NAME: your-gcs-bucket - # GCS_PROJECT_ID: your-gcp-project - # GOOGLE_APPLICATION_CREDENTIALS: /app/secrets/gcp-sa.json - # Optional Azure Blob Storage (set storage.blob_provider: azure in config.docker.yml): - # BLOB_STORAGE_PROVIDER: azure - # AZURE_BLOB_ENABLED: "true" - # AZURE_CONNECTION_STRING: DefaultEndpointsProtocol=https;AccountName=... - # AZURE_CONTAINER_NAME: your-container - volumes: - - ./uploads:/app/uploads - - ./.data:/app/.data - - ./config.docker.yml:/app/config.yml:ro - # Optional: mount GCP service account for GCS auth - # - ./secrets/gcp-sa.json:/app/secrets/gcp-sa.json:ro - ports: - - "8000:8000" - depends_on: - db: - condition: service_healthy - redis: - condition: service_healthy - command: eai start --config /app/config.yml --host 0.0.0.0 --port 8000 --no-build-frontend --no-reload - - worker: - # Pre-built image from GitHub Container Registry - # Use EFFICIENTAI_VERSION env var to pin to a specific version (e.g., 1.0.0) - image: ghcr.io/efficientai-tech/efficientai-worker:${EFFICIENTAI_VERSION:-latest} - # For local development, uncomment below and comment out the image line: - build: - context: . - dockerfile: docker/Dockerfile.worker - args: - INSTALL_EXTRAS: "qualitative-voice" - container_name: efficientai_worker - environment: - # Worker uses Docker network, so it reaches DB/Redis via service names - DATABASE_URL: postgresql://${POSTGRES_USER:-efficientai}:${POSTGRES_PASSWORD:-password}@db:5432/${POSTGRES_DB:-efficientai} - REDIS_URL: redis://redis:6379/0 - CELERY_BROKER_URL: redis://redis:6379/0 - CELERY_RESULT_BACKEND: redis://redis:6379/0 - ENCRYPTION_KEY: ${ENCRYPTION_KEY:-} - # Optional GCS blob storage (match api service env when using gcs): - # GOOGLE_APPLICATION_CREDENTIALS: /app/secrets/gcp-sa.json - depends_on: - db: - condition: service_healthy - redis: - condition: service_healthy - volumes: - - ./uploads:/app/uploads - - ./.data:/app/.data - - ./config.docker.yml:/app/config.yml:ro - # Optional: mount GCP service account for GCS auth - # - ./secrets/gcp-sa.json:/app/secrets/gcp-sa.json:ro - command: eai worker --config /app/config.yml --loglevel info - - # Dedicated worker for the call-import fan-out so large CSV imports - # don't starve synthetic-calling, audio generation, and evaluation jobs - # running on the default queue handled by the `worker` service above. - worker-imports: - image: ghcr.io/efficientai-tech/efficientai-worker:${EFFICIENTAI_VERSION:-latest} - build: - context: . - dockerfile: docker/Dockerfile.worker - args: - INSTALL_EXTRAS: "qualitative-voice" - container_name: efficientai_worker_imports - environment: - DATABASE_URL: postgresql://${POSTGRES_USER:-efficientai}:${POSTGRES_PASSWORD:-password}@db:5432/${POSTGRES_DB:-efficientai} - REDIS_URL: redis://redis:6379/0 - CELERY_BROKER_URL: redis://redis:6379/0 - CELERY_RESULT_BACKEND: redis://redis:6379/0 - ENCRYPTION_KEY: ${ENCRYPTION_KEY:-} - # Optional GCS blob storage (match api service env when using gcs): - # GOOGLE_APPLICATION_CREDENTIALS: /app/secrets/gcp-sa.json - depends_on: - db: - condition: service_healthy - redis: - condition: service_healthy - volumes: - - ./uploads:/app/uploads - - ./.data:/app/.data - - ./config.docker.yml:/app/config.yml:ro - # Optional: mount GCP service account for GCS auth - # - ./secrets/gcp-sa.json:/app/secrets/gcp-sa.json:ro - command: eai worker --config /app/config.yml --loglevel info --queues imports --concurrency 4 - -volumes: - postgres_data: +# EfficientAI Docker Compose +# +# Usage: +# docker compose up # Uses latest images +# EFFICIENTAI_VERSION=1.0.0 docker compose up # Uses specific version +# +# For local development, uncomment the 'build' sections below. + +services: + db: + image: postgres:15 + container_name: efficientai_db + environment: + POSTGRES_USER: ${POSTGRES_USER:-efficientai} + POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-password} + POSTGRES_DB: ${POSTGRES_DB:-efficientai} + volumes: + - postgres_data:/var/lib/postgresql/data + ports: + - "5432:5432" + healthcheck: + test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-efficientai}"] + interval: 10s + timeout: 5s + retries: 5 + + redis: + image: redis:7-alpine + container_name: efficientai_redis + ports: + - "6379:6379" + healthcheck: + test: ["CMD", "redis-cli", "ping"] + interval: 10s + timeout: 5s + retries: 5 + + api: + # Pre-built image from GitHub Container Registry + # Use EFFICIENTAI_VERSION env var to pin to a specific version (e.g., 1.0.0) + image: ghcr.io/efficientai-tech/efficientai-api:${EFFICIENTAI_VERSION:-latest} + # For local development, uncomment below and comment out the image line: + build: + context: . + dockerfile: docker/Dockerfile.api + container_name: efficientai_api + environment: + DATABASE_URL: postgresql://${POSTGRES_USER:-efficientai}:${POSTGRES_PASSWORD:-password}@db:5432/${POSTGRES_DB:-efficientai} + REDIS_URL: redis://redis:6379/0 + CELERY_BROKER_URL: redis://redis:6379/0 + CELERY_RESULT_BACKEND: redis://redis:6379/0 + SECRET_KEY: ${SECRET_KEY:-your-secret-key-change-in-production} + UPLOAD_DIR: /app/uploads + DEBUG: ${DEBUG:-True} + FRONTEND_DIR: /app/frontend/dist + ENCRYPTION_KEY: ${ENCRYPTION_KEY:-} + # Optional GCS blob storage (set storage.blob_provider: gcs in config.docker.yml): + # BLOB_STORAGE_PROVIDER: gcs + # GCS_BUCKET_NAME: your-gcs-bucket + # GCS_PROJECT_ID: your-gcp-project + # GOOGLE_APPLICATION_CREDENTIALS: /app/secrets/gcp-sa.json + # Optional Azure Blob Storage (set storage.blob_provider: azure in config.docker.yml): + # BLOB_STORAGE_PROVIDER: azure + # AZURE_BLOB_ENABLED: "true" + # AZURE_CONNECTION_STRING: DefaultEndpointsProtocol=https;AccountName=... + # AZURE_CONTAINER_NAME: your-container + volumes: + - ./uploads:/app/uploads + - ./.data:/app/.data + - ./config.docker.yml:/app/config.yml:ro + # Optional: mount GCP service account for GCS auth + # - ./secrets/gcp-sa.json:/app/secrets/gcp-sa.json:ro + ports: + - "8000:8000" + depends_on: + db: + condition: service_healthy + redis: + condition: service_healthy + command: eai start --config /app/config.yml --host 0.0.0.0 --port 8000 --no-build-frontend --no-reload + + worker: + # Pre-built image from GitHub Container Registry + # Use EFFICIENTAI_VERSION env var to pin to a specific version (e.g., 1.0.0) + image: ghcr.io/efficientai-tech/efficientai-worker:${EFFICIENTAI_VERSION:-latest} + # For local development, uncomment below and comment out the image line: + build: + context: . + dockerfile: docker/Dockerfile.worker + args: + INSTALL_EXTRAS: "qualitative-voice" + container_name: efficientai_worker + environment: + # Worker uses Docker network, so it reaches DB/Redis via service names + DATABASE_URL: postgresql://${POSTGRES_USER:-efficientai}:${POSTGRES_PASSWORD:-password}@db:5432/${POSTGRES_DB:-efficientai} + REDIS_URL: redis://redis:6379/0 + CELERY_BROKER_URL: redis://redis:6379/0 + CELERY_RESULT_BACKEND: redis://redis:6379/0 + ENCRYPTION_KEY: ${ENCRYPTION_KEY:-} + # Optional GCS blob storage (match api service env when using gcs): + # GOOGLE_APPLICATION_CREDENTIALS: /app/secrets/gcp-sa.json + depends_on: + db: + condition: service_healthy + redis: + condition: service_healthy + volumes: + - ./uploads:/app/uploads + - ./.data:/app/.data + - ./config.docker.yml:/app/config.yml:ro + # Optional: mount GCP service account for GCS auth + # - ./secrets/gcp-sa.json:/app/secrets/gcp-sa.json:ro + command: eai worker --config /app/config.yml --loglevel info + + # Dedicated worker for the call-import fan-out so large CSV imports + # don't starve synthetic-calling, audio generation, and evaluation jobs + # running on the default queue handled by the `worker` service above. + worker-imports: + image: ghcr.io/efficientai-tech/efficientai-worker:${EFFICIENTAI_VERSION:-latest} + build: + context: . + dockerfile: docker/Dockerfile.worker + args: + INSTALL_EXTRAS: "qualitative-voice" + container_name: efficientai_worker_imports + environment: + DATABASE_URL: postgresql://${POSTGRES_USER:-efficientai}:${POSTGRES_PASSWORD:-password}@db:5432/${POSTGRES_DB:-efficientai} + REDIS_URL: redis://redis:6379/0 + CELERY_BROKER_URL: redis://redis:6379/0 + CELERY_RESULT_BACKEND: redis://redis:6379/0 + ENCRYPTION_KEY: ${ENCRYPTION_KEY:-} + # Optional GCS blob storage (match api service env when using gcs): + # GOOGLE_APPLICATION_CREDENTIALS: /app/secrets/gcp-sa.json + depends_on: + db: + condition: service_healthy + redis: + condition: service_healthy + volumes: + - ./uploads:/app/uploads + - ./.data:/app/.data + - ./config.docker.yml:/app/config.yml:ro + # Optional: mount GCP service account for GCS auth + # - ./secrets/gcp-sa.json:/app/secrets/gcp-sa.json:ro + command: eai worker --config /app/config.yml --loglevel info --queues imports --concurrency 4 + +volumes: + postgres_data: diff --git a/docker/Dockerfile.api b/docker/Dockerfile.api index fb79cc57..22845648 100644 --- a/docker/Dockerfile.api +++ b/docker/Dockerfile.api @@ -1,77 +1,77 @@ -# ============================================================ -# Stage 1: Frontend build (Node only — not shipped to runtime) -# ============================================================ -FROM node:18-bookworm-slim AS frontend-builder - -WORKDIR /app/frontend - -COPY frontend/package.json frontend/package-lock.json* ./ -RUN --mount=type=cache,target=/root/.npm \ - npm ci --legacy-peer-deps || npm install --legacy-peer-deps - -COPY frontend/ ./ -RUN npm run build - -# ============================================================ -# Stage 2: API runtime (Python only — no node_modules / esbuild) -# ============================================================ -FROM python:3.11-slim - -# Install build tools and WeasyPrint system libraries. -# build-essential + cmake: packages like praat-parselmouth on arm64/Apple Silicon -# libgobject/libpango/libcairo/libgdk-pixbuf/libffi/shared-mime-info: PDF rendering -RUN apt-get update && apt-get install -y \ - build-essential \ - cmake \ - libgobject-2.0-0 \ - libpango-1.0-0 \ - libpangocairo-1.0-0 \ - libcairo2 \ - libgdk-pixbuf-2.0-0 \ - libffi-dev \ - shared-mime-info \ - && rm -rf /var/lib/apt/lists/* - -# Install uv via pip to get the glibc-linked binary (the musl-linked binary -# from ghcr.io/astral-sh/uv has DNS resolution issues inside Docker) -RUN pip install --no-cache-dir uv - -WORKDIR /app - -# ============================================================ -# LAYER 1: Python dependencies (cached unless pyproject.toml changes) -# ============================================================ -COPY pyproject.toml README.md ./ - -# Create minimal src structure for editable install -RUN mkdir -p src/efficientai && touch src/efficientai/__init__.py - -RUN --mount=type=cache,target=/root/.cache/uv \ - --mount=type=cache,target=/root/.cache/pip \ - uv pip install --system -e . - -# Security: fail build if compromised litellm .pth file is present (CVE: supply chain attack on litellm 1.82.8) -RUN if find /usr/local/lib/ -name "litellm_init.pth" 2>/dev/null | grep -q .; then \ - echo "SECURITY: Malicious litellm_init.pth detected! Aborting build." && exit 1; \ - fi - -# ============================================================ -# LAYER 2: Frontend static assets (built in stage 1, dist only) -# ============================================================ -COPY --from=frontend-builder /app/frontend/dist /app/frontend/dist - -# ============================================================ -# LAYER 3: Backend code (rebuilds on backend code changes) -# ============================================================ -COPY src/ ./src/ -COPY app/ ./app/ -COPY scripts/ ./scripts/ - -RUN uv pip install --system -e . - -# Create uploads directory -RUN mkdir -p /app/uploads - -EXPOSE 8000 - -CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] +# ============================================================ +# Stage 1: Frontend build (Node only — not shipped to runtime) +# ============================================================ +FROM node:18-bookworm-slim AS frontend-builder + +WORKDIR /app/frontend + +COPY frontend/package.json frontend/package-lock.json* ./ +RUN --mount=type=cache,target=/root/.npm \ + npm ci --legacy-peer-deps || npm install --legacy-peer-deps + +COPY frontend/ ./ +RUN npm run build + +# ============================================================ +# Stage 2: API runtime (Python only — no node_modules / esbuild) +# ============================================================ +FROM python:3.11-slim + +# Install build tools and WeasyPrint system libraries. +# build-essential + cmake: packages like praat-parselmouth on arm64/Apple Silicon +# libgobject/libpango/libcairo/libgdk-pixbuf/libffi/shared-mime-info: PDF rendering +RUN apt-get update && apt-get install -y \ + build-essential \ + cmake \ + libgobject-2.0-0 \ + libpango-1.0-0 \ + libpangocairo-1.0-0 \ + libcairo2 \ + libgdk-pixbuf-2.0-0 \ + libffi-dev \ + shared-mime-info \ + && rm -rf /var/lib/apt/lists/* + +# Install uv via pip to get the glibc-linked binary (the musl-linked binary +# from ghcr.io/astral-sh/uv has DNS resolution issues inside Docker) +RUN pip install --no-cache-dir uv + +WORKDIR /app + +# ============================================================ +# LAYER 1: Python dependencies (cached unless pyproject.toml changes) +# ============================================================ +COPY pyproject.toml README.md ./ + +# Create minimal src structure for editable install +RUN mkdir -p src/efficientai && touch src/efficientai/__init__.py + +RUN --mount=type=cache,target=/root/.cache/uv \ + --mount=type=cache,target=/root/.cache/pip \ + uv pip install --system -e . + +# Security: fail build if compromised litellm .pth file is present (CVE: supply chain attack on litellm 1.82.8) +RUN if find /usr/local/lib/ -name "litellm_init.pth" 2>/dev/null | grep -q .; then \ + echo "SECURITY: Malicious litellm_init.pth detected! Aborting build." && exit 1; \ + fi + +# ============================================================ +# LAYER 2: Frontend static assets (built in stage 1, dist only) +# ============================================================ +COPY --from=frontend-builder /app/frontend/dist /app/frontend/dist + +# ============================================================ +# LAYER 3: Backend code (rebuilds on backend code changes) +# ============================================================ +COPY src/ ./src/ +COPY app/ ./app/ +COPY scripts/ ./scripts/ + +RUN uv pip install --system -e . + +# Create uploads directory +RUN mkdir -p /app/uploads + +EXPOSE 8000 + +CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/frontend/src/components/providers/ProviderModelPicker.tsx b/frontend/src/components/providers/ProviderModelPicker.tsx index c2f26032..8e364288 100644 --- a/frontend/src/components/providers/ProviderModelPicker.tsx +++ b/frontend/src/components/providers/ProviderModelPicker.tsx @@ -131,6 +131,11 @@ const AUDIO_CAPABLE_MODEL_MATCHERS: Record = { } const AUDIO_CAPABLE_PROVIDERS = Object.keys(AUDIO_CAPABLE_MODEL_MATCHERS) +// Voice-platform integrations that also expose LLM models. Credentials +// for these providers live in the Integration table (Configurations → +// Integrations → Voice Platform), not AIProvider. +const INTEGRATION_LLM_PLATFORMS = new Set(['sarvam']) + function isAudioCapableModel(provider: string, model: string): boolean { const matchers = AUDIO_CAPABLE_MODEL_MATCHERS[provider] if (!matchers) return false @@ -161,9 +166,32 @@ export default function ProviderModelPicker({ const { data: integrations = [] } = useQuery({ queryKey: ['integrations'], queryFn: () => apiClient.listIntegrations(), - enabled: kind === 'stt', + enabled: kind === 'stt' || kind === 'llm', }) + const integrationCredentialRows: CredentialRow[] = + kind === 'stt' + ? integrations.map((i) => ({ + id: i.id, + provider: (i.platform || '').toLowerCase(), + is_active: i.is_active, + is_default: i.is_default, + name: i.name ?? null, + source: 'integration', + })) + : integrations + .filter((i) => + INTEGRATION_LLM_PLATFORMS.has((i.platform || '').toLowerCase()), + ) + .map((i) => ({ + id: i.id, + provider: (i.platform || '').toLowerCase(), + is_active: i.is_active, + is_default: i.is_default, + name: i.name ?? null, + source: 'integration', + })) + const allCredentials: CredentialRow[] = [ ...aiProviders.map((p) => ({ id: p.id, @@ -173,18 +201,7 @@ export default function ProviderModelPicker({ name: p.name ?? null, source: 'aiprovider', })), - // Only merge integrations for STT — LLM credentials never live in - // that table, so skipping it keeps the LLM picker unchanged. - ...(kind === 'stt' - ? integrations.map((i) => ({ - id: i.id, - provider: (i.platform || '').toLowerCase(), - is_active: i.is_active, - is_default: i.is_default, - name: i.name ?? null, - source: 'integration', - })) - : []), + ...integrationCredentialRows, ] const allowSet = providerAllowList diff --git a/frontend/src/config/providers.ts b/frontend/src/config/providers.ts index 6f948b80..50e4d04e 100644 --- a/frontend/src/config/providers.ts +++ b/frontend/src/config/providers.ts @@ -100,7 +100,7 @@ export const MODEL_PROVIDER_CONFIG: Record = { [ModelProvider.SARVAM]: { label: 'Sarvam', logo: '/sarvam.png', - description: 'Sarvam STT & TTS', + description: 'Sarvam STT, TTS & LLM', }, [ModelProvider.VOICEMAKER]: { label: 'VoiceMaker', @@ -162,7 +162,7 @@ export const INTEGRATION_PLATFORM_CONFIG: Record=4.47.0", "openpyxl>=3.1.0", "weasyprint>=62.0", + "flexprice>=2.1.0", ] [project.optional-dependencies] diff --git a/scripts/create_api_key.py b/scripts/create_api_key.py index 631015ee..a2b0e33b 100644 --- a/scripts/create_api_key.py +++ b/scripts/create_api_key.py @@ -23,7 +23,10 @@ from app.database import SessionLocal, init_db from app.models.database import APIKey, Organization -from app.services.organization_provisioning import provision_default_workspace +from app.services.organization_provisioning import ( + provision_billing_customer, + provision_default_workspace, +) def main() -> int: @@ -61,6 +64,10 @@ def main() -> int: db.add(org) db.flush() provision_default_workspace(db, organization_id=org.id) + provision_billing_customer( + organization_id=org.id, + name=args.new_org, + ) api_key = secrets.token_urlsafe(32) db_key = APIKey( diff --git a/scripts/remove_legacy_flexprice_meter.py b/scripts/remove_legacy_flexprice_meter.py new file mode 100644 index 00000000..a11656f6 --- /dev/null +++ b/scripts/remove_legacy_flexprice_meter.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python3 +"""Remove legacy Flexprice meters/features superseded by the metering catalog.""" + +from __future__ import annotations + +import json +import sys +from pathlib import Path + +import httpx +from flexprice import Flexprice + +from app.config import load_config_from_file, settings + +CONFIG_PATH = Path(__file__).resolve().parent.parent / "config.yml" + +LEGACY_EVENT_NAMES = { + "BLIND_TEST_SHARE_CREATED_EVENT", +} + + +def _headers() -> dict[str, str]: + return {"x-api-key": settings.FLEXPRICE_API_KEY or ""} + + +def _base_url() -> str: + return (settings.FLEXPRICE_API_HOST or "").rstrip("/") + + +def _delete_meter(client: httpx.Client, meter_id: str) -> dict: + for method, suffix in (("DELETE", ""), ("POST", "/archive")): + url = f"{_base_url()}/meters/{meter_id}{suffix}" + if method == "DELETE": + resp = client.delete(url, headers=_headers()) + else: + resp = client.post(url, headers=_headers()) + if resp.status_code in (200, 204, 404): + return {"meter_id": meter_id, "method": method, "status": resp.status_code} + resp.raise_for_status() + return {"meter_id": meter_id, "status": resp.status_code} + + +def main() -> int: + if CONFIG_PATH.exists(): + load_config_from_file(str(CONFIG_PATH)) + + if not settings.FLEXPRICE_API_KEY: + print("FLEXPRICE_API_KEY is missing.", file=sys.stderr) + return 1 + + results: dict[str, list] = {"meters_removed": [], "features_removed": [], "errors": []} + + with httpx.Client(timeout=60.0) as client: + resp = client.get(f"{_base_url()}/meters", headers=_headers(), params={"limit": 200}) + resp.raise_for_status() + for meter in resp.json().get("items") or []: + event_name = meter.get("event_name") or "" + if event_name not in LEGACY_EVENT_NAMES: + continue + try: + results["meters_removed"].append( + {"event_name": event_name, **_delete_meter(client, meter["id"])} + ) + except Exception as exc: + results["errors"].append(f"meter {event_name}: {exc}") + + with Flexprice( + server_url=settings.FLEXPRICE_API_HOST, + api_key_auth=settings.FLEXPRICE_API_KEY, + ) as sdk: + resp = sdk.features.query_feature(limit=200) + for item in resp.items or []: + lookup = (item.lookup_key or "").lower() + name = (item.name or "").upper() + meter = item.meter + meter_event = (meter.event_name if meter else "") or "" + if ( + meter_event in LEGACY_EVENT_NAMES + or lookup == "feat-blind_test_share_created_event" + or name in LEGACY_EVENT_NAMES + ): + try: + sdk.features.delete_feature(id=item.id) + results["features_removed"].append( + {"id": item.id, "lookup_key": item.lookup_key, "name": item.name} + ) + except Exception as exc: + results["errors"].append(f"feature {item.lookup_key or item.name}: {exc}") + + print(json.dumps(results, indent=2)) + return 1 if results["errors"] else 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/setup_flexprice_meters.py b/scripts/setup_flexprice_meters.py new file mode 100644 index 00000000..533526c7 --- /dev/null +++ b/scripts/setup_flexprice_meters.py @@ -0,0 +1,242 @@ +#!/usr/bin/env python3 +"""Create Flexprice features and meters for EfficientAI usage metering catalog.""" + +from __future__ import annotations + +import json +import sys +from pathlib import Path +from typing import Any + +import httpx +from flexprice import Flexprice + +from app.config import load_config_from_file, settings + +CONFIG_PATH = Path(__file__).resolve().parent.parent / "config.yml" + +# (event_name, display_name, aggregation_type, aggregation_field|None) +METERS: list[tuple[str, str, str, str | None]] = [ + # Voice playground + ("blind_test.share_created", "Blind Test Share Created", "COUNT", None), + ("blind_test.response_submitted", "Blind Test Response Submitted", "COUNT", None), + ("tts.generation_started", "TTS Generation Started", "COUNT", None), + ("tts.sample_synthesized", "TTS Sample Synthesized", "SUM", "quantity"), + ("tts.report_requested", "TTS Report Requested", "COUNT", None), + ("tts.report_completed", "TTS Report Completed", "COUNT", None), + # Call imports + ("call_import.batch_created", "Call Import Batch Created", "SUM", "quantity"), + ("call_import.evaluation_completed", "Call Import Evaluation Completed", "SUM", "quantity"), + # Agent playground + ("playground.web_call_started", "Playground Web Call Started", "COUNT", None), + ("playground.websocket_session_started", "Playground Websocket Session Started", "COUNT", None), + ("playground.call_evaluated", "Playground Call Evaluated", "COUNT", None), + ("playground.evaluation_completed", "Playground Evaluation Completed", "COUNT", None), + # Evaluators + ("evaluator.run_requested", "Evaluator Run Requested", "SUM", "quantity"), + ("evaluator.run_completed", "Evaluator Run Completed", "COUNT", None), + # Legacy evaluations + ("evaluation.created", "Evaluation Created", "COUNT", None), + ("evaluation.completed", "Evaluation Completed", "COUNT", None), + # Prompt optimization + ("prompt_optimization.run_started", "Prompt Optimization Run Started", "COUNT", None), + ("prompt_optimization.run_completed", "Prompt Optimization Run Completed", "COUNT", None), + # Judge alignment + ("judge_alignment.run_started", "Judge Alignment Run Started", "COUNT", None), + ("judge_alignment.run_completed", "Judge Alignment Run Completed", "COUNT", None), + # Observability + ("observability.call_ingested", "Observability Call Ingested", "COUNT", None), + ("observability.call_evaluated", "Observability Call Evaluated", "COUNT", None), + # Test agents + ("test_agent.conversation_started", "Test Agent Conversation Started", "COUNT", None), + ("test_agent.conversation_ended", "Test Agent Conversation Ended", "SUM", "quantity"), + # LLM assist + ("metrics.llm_assist", "Metrics LLM Assist", "COUNT", None), + ("chat.completion", "Chat Completion", "SUM", "quantity"), +] + +LICENSE_FEATURES: list[dict[str, Any]] = [ + { + "name": "Call Imports", + "lookup_key": "call_imports", + "description": "CSV/audio call import batches, row processing, and evaluations", + "unit_singular": "batch", + "unit_plural": "batches", + "event_name": "call_import.batch_created", + "aggregation": {"type": "COUNT"}, + }, + { + "name": "Voice Playground", + "lookup_key": "voice_playground", + "description": "TTS comparisons, blind tests, and voice quality reports", + "unit_singular": "share", + "unit_plural": "shares", + "event_name": "blind_test.share_created", + "aggregation": {"type": "COUNT"}, + }, + { + "name": "GEPA Optimization", + "lookup_key": "gepa_optimization", + "description": "Prompt optimization (GEPA) runs", + "unit_singular": "run", + "unit_plural": "runs", + "event_name": "prompt_optimization.run_started", + "aggregation": {"type": "COUNT"}, + }, +] + + +def _headers() -> dict[str, str]: + return {"x-api-key": settings.FLEXPRICE_API_KEY or "", "Content-Type": "application/json"} + + +def _base_url() -> str: + return (settings.FLEXPRICE_API_HOST or "").rstrip("/") + + +def _meter_payload(name: str, event_name: str, agg_type: str, field: str | None) -> dict[str, Any]: + aggregation: dict[str, str] = {"type": agg_type} + if field and agg_type in {"SUM", "MAX", "LATEST", "COUNT_UNIQUE", "AVG"}: + aggregation["field"] = field + return { + "name": name, + "event_name": event_name, + "aggregation": aggregation, + "reset_usage": "BILLING_PERIOD", + } + + +def _list_meters(client: httpx.Client) -> dict[str, dict]: + existing: dict[str, dict] = {} + offset = 0 + while True: + resp = client.get( + f"{_base_url()}/meters", + headers=_headers(), + params={"limit": 200, "offset": offset}, + ) + resp.raise_for_status() + data = resp.json() + items = data.get("items") or [] + for item in items: + event_name = item.get("event_name") + if event_name: + existing[event_name] = item + pagination = data.get("pagination") or {} + total = pagination.get("total") + offset += len(items) + if not items or (total is not None and offset >= total): + break + return existing + + +def _create_meter(client: httpx.Client, event_name: str, name: str, agg_type: str, field: str | None) -> dict: + payload = _meter_payload(name, event_name, agg_type, field) + resp = client.post(f"{_base_url()}/meters", headers=_headers(), json=payload) + if resp.status_code == 409 or ( + resp.status_code == 400 and "exist" in resp.text.lower() + ): + return {"skipped": True, "event_name": event_name, "detail": resp.text} + resp.raise_for_status() + return resp.json() + + +def _list_feature_lookup_keys(sdk: Flexprice) -> set[str]: + keys: set[str] = set() + offset = 0 + while True: + resp = sdk.features.query_feature(limit=200, offset=offset) + items = resp.items or [] + for item in items: + if item.lookup_key: + keys.add(item.lookup_key) + offset += len(items) + if not items: + break + return keys + + +def _create_license_feature(sdk: Flexprice, spec: dict[str, Any]) -> dict: + meter = { + "name": spec["name"], + "event_name": spec["event_name"], + "aggregation": spec["aggregation"], + "reset_usage": "BILLING_PERIOD", + } + try: + result = sdk.features.create_feature( + name=spec["name"], + type_="metered", + lookup_key=spec["lookup_key"], + description=spec["description"], + unit_singular=spec.get("unit_singular"), + unit_plural=spec.get("unit_plural"), + meter=meter, + ) + return {"created": True, "lookup_key": spec["lookup_key"], "id": result.id} + except Exception as exc: + message = str(exc).lower() + if "already exist" in message or "duplicate" in message: + return {"skipped": True, "lookup_key": spec["lookup_key"], "detail": str(exc)} + raise + + +def main() -> int: + if CONFIG_PATH.exists(): + load_config_from_file(str(CONFIG_PATH)) + + if not settings.FLEXPRICE_ENABLED or not settings.FLEXPRICE_API_KEY: + print("Flexprice is not enabled or FLEXPRICE_API_KEY is missing.", file=sys.stderr) + return 1 + + created_meters: list[str] = [] + skipped_meters: list[str] = [] + failed_meters: list[str] = [] + created_features: list[str] = [] + skipped_features: list[str] = [] + failed_features: list[str] = [] + + with httpx.Client(timeout=60.0) as http_client: + existing_meters = _list_meters(http_client) + for event_name, name, agg_type, field in METERS: + if event_name in existing_meters: + skipped_meters.append(event_name) + continue + try: + result = _create_meter(http_client, event_name, name, agg_type, field) + if result.get("skipped"): + skipped_meters.append(event_name) + else: + created_meters.append(event_name) + except Exception as exc: + failed_meters.append(f"{event_name}: {exc}") + + with Flexprice( + server_url=settings.FLEXPRICE_API_HOST, + api_key_auth=settings.FLEXPRICE_API_KEY, + ) as sdk: + existing_feature_keys = _list_feature_lookup_keys(sdk) + for spec in LICENSE_FEATURES: + key = spec["lookup_key"] + if key in existing_feature_keys: + skipped_features.append(key) + continue + try: + result = _create_license_feature(sdk, spec) + if result.get("skipped"): + skipped_features.append(key) + else: + created_features.append(key) + except Exception as exc: + failed_features.append(f"{key}: {exc}") + + summary = { + "meters": {"created": created_meters, "skipped": skipped_meters, "failed": failed_meters}, + "features": {"created": created_features, "skipped": skipped_features, "failed": failed_features}, + } + print(json.dumps(summary, indent=2)) + return 1 if (failed_meters or failed_features) else 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tests/test_api/test_call_import_diarization_and_eval_llm.py b/tests/test_api/test_call_import_diarization_and_eval_llm.py index e85e73fa..ffd32385 100644 --- a/tests/test_api/test_call_import_diarization_and_eval_llm.py +++ b/tests/test_api/test_call_import_diarization_and_eval_llm.py @@ -1313,3 +1313,83 @@ def test_compact_diarisation_error_strips_details_block(): assert _compact_diarisation_error(raw) == ( "Diarisation LLM provider/model not configured." ) + + +def test_estimate_diarisation_max_tokens_scales_with_transcript(): + from app.workers.tasks.helpers import llm_diarisation + + short = llm_diarisation._estimate_diarisation_max_tokens(text_length=500) + long = llm_diarisation._estimate_diarisation_max_tokens(text_length=40_000) + assert short >= llm_diarisation._DIARISATION_MIN_MAX_TOKENS + assert long > short + assert long <= llm_diarisation._DIARISATION_MAX_MAX_TOKENS + + +def test_diariser_passes_scaled_max_tokens(monkeypatch): + from app.workers.tasks.helpers import llm_diarisation + + captured: dict = {} + + def _fake_generate_response(**kwargs): + captured.update(kwargs) + return { + "text": '{"turns": [{"speaker": "agent", "text": "Hi"}]}', + "truncated": False, + } + + monkeypatch.setattr( + llm_diarisation.llm_service, + "generate_response", + _fake_generate_response, + ) + + llm_diarisation.diarize_transcript_with_llm( + "Hi there.", + llm_provider="openai", + llm_model="gpt-4o-mini", + organization_id=uuid4(), + db=SimpleNamespace(), + ) + + assert captured["task_defaults"]["max_tokens"] >= ( + llm_diarisation._DIARISATION_MIN_MAX_TOKENS + ) + + +def test_diariser_retries_when_first_response_truncated(monkeypatch): + from app.workers.tasks.helpers import llm_diarisation + + budgets: list[int] = [] + + def _fake_generate_response(**kwargs): + budgets.append(kwargs["task_defaults"]["max_tokens"]) + if len(budgets) == 1: + return { + "text": '{"turns": [{"speaker": "agent", "text": "partial', + "truncated": True, + } + return { + "text": '{"turns": [{"speaker": "agent", "text": "Done"}]}', + "truncated": False, + } + + monkeypatch.setattr( + llm_diarisation.llm_service, + "generate_response", + _fake_generate_response, + ) + + turns = llm_diarisation.diarize_transcript_with_llm( + "word " * 500, + llm_provider="openai", + llm_model="gpt-4o-mini", + organization_id=uuid4(), + db=SimpleNamespace(), + ) + + assert len(budgets) == 2 + assert budgets[1] == min( + llm_diarisation._DIARISATION_MAX_MAX_TOKENS, + budgets[0] * 2, + ) + assert turns[0]["text"] == "Done" diff --git a/tests/test_api/test_voice_playground_routes.py b/tests/test_api/test_voice_playground_routes.py index 126da011..97397f3e 100644 --- a/tests/test_api/test_voice_playground_routes.py +++ b/tests/test_api/test_voice_playground_routes.py @@ -301,3 +301,99 @@ def test_blind_test_only_rejects_cross_workspace_tts_sample( assert response.status_code == 404 + +def _seed_blind_test_ready_comparison(db_session, org_id, workspace_id): + comparison = TTSComparison( + id=uuid4(), + organization_id=org_id, + workspace_id=workspace_id, + simulation_id="sim003", + name="Blind share metering test", + status=TTSComparisonStatus.COMPLETED.value, + mode="benchmark", + provider_a="openai", + provider_b="elevenlabs", + model_a="gpt-4o-mini-tts", + model_b="eleven_multilingual_v2", + voices_a=[{"id": "alloy", "name": "Alloy"}], + voices_b=[{"id": "rachel", "name": "Rachel"}], + sample_texts=["hello"], + num_runs=1, + ) + db_session.add(comparison) + db_session.flush() + + db_session.add_all( + [ + TTSSample( + id=uuid4(), + comparison_id=comparison.id, + organization_id=org_id, + workspace_id=workspace_id, + provider="openai", + model="gpt-4o-mini-tts", + voice_id="alloy", + voice_name="Alloy", + side="A", + sample_index=0, + run_index=0, + text="hello", + audio_s3_key="organizations/test/a.wav", + status=TTSSampleStatus.COMPLETED.value, + source_type="tts", + ), + TTSSample( + id=uuid4(), + comparison_id=comparison.id, + organization_id=org_id, + workspace_id=workspace_id, + provider="elevenlabs", + model="eleven_multilingual_v2", + voice_id="rachel", + voice_name="Rachel", + side="B", + sample_index=0, + run_index=0, + text="hello", + audio_s3_key="organizations/test/b.wav", + status=TTSSampleStatus.COMPLETED.value, + source_type="tts", + ), + ] + ) + db_session.commit() + return comparison + + +def test_create_blind_test_share_meters_only_on_first_create( + authenticated_client, db_session, org_id, default_workspace, monkeypatch +): + from app.api.v1.routes import voice_playground as vp_routes + + comparison = _seed_blind_test_ready_comparison(db_session, org_id, default_workspace.id) + calls = [] + + def _record_metering(*args, **kwargs): + calls.append((args, kwargs)) + + monkeypatch.setattr(vp_routes, "record_blind_test_share_created", _record_metering) + + payload = {"title": "Public blind test", "custom_metrics": []} + first = authenticated_client.post( + f"/api/v1/voice-playground/comparisons/{comparison.id}/share", + json=payload, + ) + assert first.status_code == 200 + + second = authenticated_client.post( + f"/api/v1/voice-playground/comparisons/{comparison.id}/share", + json={"title": "Updated title", "custom_metrics": []}, + ) + assert second.status_code == 200 + + assert len(calls) == 1 + args, kwargs = calls[0] + assert args[0] == org_id + assert kwargs["comparison_id"] == comparison.id + assert kwargs["workspace_id"] == default_workspace.id + diff --git a/tests/test_core/test_operational_access_middleware.py b/tests/test_core/test_operational_access_middleware.py new file mode 100644 index 00000000..9c7956e3 --- /dev/null +++ b/tests/test_core/test_operational_access_middleware.py @@ -0,0 +1,41 @@ +"""Tests for operational endpoint access control.""" + +from __future__ import annotations + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from app.config import settings +from app.core.operational_access_middleware import OperationalAccessMiddleware + + +@pytest.fixture +def operational_client(monkeypatch): + monkeypatch.setattr(settings, "OPERATIONAL_PUBLIC", False) + monkeypatch.setattr(settings, "OPERATIONAL_TRUSTED_IPS", []) + + app = FastAPI() + app.add_middleware(OperationalAccessMiddleware) + + @app.get("/health") + def health(): + return {"status": "healthy"} + + @app.get("/metrics") + def metrics(): + return {"ok": True} + + with TestClient(app) as client: + yield client + + +def test_health_allowed_without_trusted_ips(operational_client): + response = operational_client.get("/health") + assert response.status_code == 200 + assert response.json() == {"status": "healthy"} + + +def test_metrics_blocked_without_trusted_ips(operational_client): + response = operational_client.get("/metrics") + assert response.status_code == 404 diff --git a/tests/test_core/test_operational_endpoints.py b/tests/test_core/test_operational_endpoints.py index 4fd116f6..d15416c4 100644 --- a/tests/test_core/test_operational_endpoints.py +++ b/tests/test_core/test_operational_endpoints.py @@ -66,13 +66,14 @@ def metrics(): yield client -def test_public_health_via_alb_is_blocked(operational_client): +def test_public_health_via_alb_is_allowed_for_lb_probes(operational_client): + """/health stays open for ALB/kube probes; only /metrics is gated.""" response = operational_client.get( "/health", headers={"X-Forwarded-For": "203.0.113.1", "User-Agent": "SecurityScanner/1.0"}, ) - assert response.status_code == 404 + assert response.status_code == 200 def test_spoofed_user_agent_does_not_bypass_gate(monkeypatch): diff --git a/tests/test_services/test_ai/test_llm_service.py b/tests/test_services/test_ai/test_llm_service.py index 580bd2ad..ca612c18 100644 --- a/tests/test_services/test_ai/test_llm_service.py +++ b/tests/test_services/test_ai/test_llm_service.py @@ -21,11 +21,15 @@ def test_litellm_model_name_maps_known_provider_prefixes(): LLMService._litellm_model_name(ModelProvider.FIREWORKS, "deepseek-v4-pro") == "fireworks_ai/accounts/fireworks/models/deepseek-v4-pro" ) + assert ( + LLMService._litellm_model_name(ModelProvider.SARVAM, "sarvam-30b") + == "sarvam/sarvam-30b" + ) def test_generate_response_raises_when_provider_not_configured(monkeypatch): service = LLMService() - monkeypatch.setattr(service, "_get_ai_provider", lambda *_args, **_kwargs: None) + monkeypatch.setattr(service, "_resolve_api_key", lambda *_args, **_kwargs: (_ for _ in ()).throw(RuntimeError("AI provider sarvam not configured for this organization."))) with pytest.raises(RuntimeError, match="not configured"): service.generate_response( @@ -41,13 +45,10 @@ def test_generate_response_success_with_normalized_usage(monkeypatch): service = LLMService() monkeypatch.setattr( service, - "_get_ai_provider", - lambda *_args, **_kwargs: SimpleNamespace(api_key="encrypted-key"), + "_resolve_api_key", + lambda *_args, **_kwargs: "decrypted-key", ) - encryption_module = importlib.import_module("app.core.encryption") - monkeypatch.setattr(encryption_module, "decrypt_api_key", lambda value: f"decrypted::{value}") - # ``llm_service.generate_response`` reads ``finish_reason`` off the # first choice to flag truncated outputs (so JSON-parsing callers # can blame the right thing). The stub has to include it; "stop" @@ -82,12 +83,10 @@ def test_generate_response_wraps_litellm_errors(monkeypatch): service = LLMService() monkeypatch.setattr( service, - "_get_ai_provider", - lambda *_args, **_kwargs: SimpleNamespace(api_key="encrypted-key"), + "_resolve_api_key", + lambda *_args, **_kwargs: "decrypted-key", ) - encryption_module = importlib.import_module("app.core.encryption") - monkeypatch.setattr(encryption_module, "decrypt_api_key", lambda value: value) monkeypatch.setattr( llm_module.litellm, "completion", diff --git a/tests/test_services/test_evaluation/test_evaluation_service.py b/tests/test_services/test_evaluation/test_evaluation_service.py index 542e6588..8b71c747 100644 --- a/tests/test_services/test_evaluation/test_evaluation_service.py +++ b/tests/test_services/test_evaluation/test_evaluation_service.py @@ -49,9 +49,13 @@ def __init__(self, **kwargs): def test_process_evaluation_success(monkeypatch): evaluation_id = uuid4() audio_id = uuid4() + org_id = uuid4() + workspace_id = uuid4() evaluation = SimpleNamespace( id=evaluation_id, audio_id=audio_id, + organization_id=org_id, + workspace_id=workspace_id, model_name="base", metrics_requested=["wer", "latency"], reference_text="hello world", @@ -71,6 +75,10 @@ def test_process_evaluation_success(monkeypatch): lambda **_kwargs: {"wer": 0.0, "latency_s": 0.5, "latency_ms": 500.0}, ) monkeypatch.setattr(evaluation_module, "EvaluationResult", _FakeEvaluationResult) + monkeypatch.setattr( + "app.services.billing.flexprice_service.record_evaluation_completed", + lambda *_a, **_k: None, + ) service = EvaluationService() result = service.process_evaluation(evaluation_id, db) diff --git a/tests/test_services/test_flexprice_service.py b/tests/test_services/test_flexprice_service.py new file mode 100644 index 00000000..7c6778f1 --- /dev/null +++ b/tests/test_services/test_flexprice_service.py @@ -0,0 +1,241 @@ +"""Unit tests for Flexprice billing service (optional metering).""" + +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest + +from app.config import settings +from app.services.billing import flexprice_service as svc + + +@pytest.fixture(autouse=True) +def reset_flexprice_settings(): + previous = ( + settings.FLEXPRICE_ENABLED, + settings.FLEXPRICE_API_KEY, + settings.FLEXPRICE_API_HOST, + svc._disabled_skip_logged, + ) + svc._disabled_skip_logged = False + yield + ( + settings.FLEXPRICE_ENABLED, + settings.FLEXPRICE_API_KEY, + settings.FLEXPRICE_API_HOST, + svc._disabled_skip_logged, + ) = previous + + +def test_is_enabled_false_when_disabled(): + settings.FLEXPRICE_ENABLED = False + settings.FLEXPRICE_API_KEY = "test-key" + assert svc.is_enabled() is False + + +def test_is_enabled_false_when_api_key_missing(): + settings.FLEXPRICE_ENABLED = True + settings.FLEXPRICE_API_KEY = None + assert svc.is_enabled() is False + + +def test_is_enabled_true_when_configured(): + settings.FLEXPRICE_ENABLED = True + settings.FLEXPRICE_API_KEY = "test-key" + assert svc.is_enabled() is True + + +def test_disabled_reason_when_api_key_missing(): + settings.FLEXPRICE_ENABLED = True + settings.FLEXPRICE_API_KEY = None + assert svc.disabled_reason() is not None + assert "api_key" in svc.disabled_reason() + + +@patch("flexprice.Flexprice") +def test_ensure_customer_no_op_when_disabled(mock_flexprice): + settings.FLEXPRICE_ENABLED = False + settings.FLEXPRICE_API_KEY = "test-key" + + svc.ensure_customer(uuid4(), name="Acme") + + mock_flexprice.assert_not_called() + + +@patch("flexprice.Flexprice") +def test_ensure_customer_calls_create_customer(mock_flexprice): + settings.FLEXPRICE_ENABLED = True + settings.FLEXPRICE_API_KEY = "test-key" + settings.FLEXPRICE_API_HOST = "https://us.api.flexprice.io/v1" + + org_id = uuid4() + mock_client = MagicMock() + mock_flexprice.return_value.__enter__.return_value = mock_client + + svc.ensure_customer(org_id, name="Acme Inc", email="admin@acme.com") + + mock_flexprice.assert_called_once_with( + server_url="https://us.api.flexprice.io/v1", + api_key_auth="test-key", + ) + mock_client.customers.create_customer.assert_called_once_with( + external_id=str(org_id), + name="Acme Inc", + email="admin@acme.com", + ) + + +@patch("flexprice.Flexprice") +def test_ensure_customer_swallows_already_exists(mock_flexprice): + settings.FLEXPRICE_ENABLED = True + settings.FLEXPRICE_API_KEY = "test-key" + + mock_client = MagicMock() + mock_client.customers.create_customer.side_effect = Exception("Customer already exists") + mock_flexprice.return_value.__enter__.return_value = mock_client + + svc.ensure_customer(uuid4(), name="Acme Inc") + + +@patch("flexprice.Flexprice") +def test_record_blind_test_share_created_no_op_when_disabled(mock_flexprice): + settings.FLEXPRICE_ENABLED = False + settings.FLEXPRICE_API_KEY = "test-key" + + svc.record_blind_test_share_created(uuid4(), uuid4(), workspace_id=uuid4(), comparison_id=uuid4()) + + mock_flexprice.assert_not_called() + + +@patch("flexprice.Flexprice") +def test_record_blind_test_share_created_ingests_event(mock_flexprice): + settings.FLEXPRICE_ENABLED = True + settings.FLEXPRICE_API_KEY = "test-key" + + org_id = uuid4() + share_id = uuid4() + workspace_id = uuid4() + comparison_id = uuid4() + + mock_client = MagicMock() + mock_flexprice.return_value.__enter__.return_value = mock_client + + svc.record_blind_test_share_created( + org_id, + share_id, + workspace_id=workspace_id, + comparison_id=comparison_id, + ) + + mock_client.events.ingest_event.assert_called_once_with( + event_name="blind_test.share_created", + external_customer_id=str(org_id), + event_id=str(share_id), + source="efficientai", + properties={ + "share_id": str(share_id), + "workspace_id": str(workspace_id), + "comparison_id": str(comparison_id), + "feature": "voice_playground", + }, + ) + + +@patch("flexprice.Flexprice") +def test_record_blind_test_share_created_logs_and_swallows_errors(mock_flexprice): + settings.FLEXPRICE_ENABLED = True + settings.FLEXPRICE_API_KEY = "test-key" + + mock_client = MagicMock() + mock_client.events.ingest_event.side_effect = RuntimeError("network down") + mock_flexprice.return_value.__enter__.return_value = mock_client + + svc.record_blind_test_share_created(uuid4(), uuid4(), workspace_id=uuid4(), comparison_id=uuid4()) + + +@patch("flexprice.Flexprice") +def test_ingest_usage_event_falls_back_to_request_dict(mock_flexprice): + settings.FLEXPRICE_ENABLED = True + settings.FLEXPRICE_API_KEY = "test-key" + + mock_client = MagicMock() + mock_client.events.ingest_event.side_effect = [ + TypeError("ingest_event() got an unexpected keyword argument 'event_name'"), + None, + ] + mock_flexprice.return_value.__enter__.return_value = mock_client + + svc.record_blind_test_share_created(uuid4(), uuid4(), workspace_id=uuid4(), comparison_id=uuid4()) + + assert mock_client.events.ingest_event.call_count == 2 + assert "request" in mock_client.events.ingest_event.call_args_list[1].kwargs + + +@patch("flexprice.Flexprice") +def test_record_call_import_batch_created_includes_volume_properties(mock_flexprice): + settings.FLEXPRICE_ENABLED = True + settings.FLEXPRICE_API_KEY = "test-key" + + org_id = uuid4() + call_import_id = uuid4() + workspace_id = uuid4() + + mock_client = MagicMock() + mock_flexprice.return_value.__enter__.return_value = mock_client + + svc.record_call_import_batch_created( + org_id, + call_import_id, + workspace_id=workspace_id, + total_rows=42, + source="csv", + provider="exotel", + ) + + mock_client.events.ingest_event.assert_called_once() + payload = mock_client.events.ingest_event.call_args.kwargs + assert payload["event_name"] == "call_import.batch_created" + assert payload["properties"]["total_rows"] == "42" + assert payload["properties"]["quantity"] == "42" + assert payload["properties"]["feature"] == "call_imports" + + +@patch("flexprice.Flexprice") +def test_record_call_import_evaluation_completed_meters_pass_delta(mock_flexprice): + settings.FLEXPRICE_ENABLED = True + settings.FLEXPRICE_API_KEY = "test-key" + + org_id = uuid4() + evaluation_id = uuid4() + workspace_id = uuid4() + call_import_id = uuid4() + + mock_client = MagicMock() + mock_flexprice.return_value.__enter__.return_value = mock_client + + svc.record_call_import_evaluation_completed( + org_id, + evaluation_id, + workspace_id=workspace_id, + call_import_id=call_import_id, + rows_billed=50, + completed_total=1950, + metric_count=5, + ) + + mock_client.events.ingest_event.assert_called_once_with( + event_name="call_import.evaluation_completed", + external_customer_id=str(org_id), + event_id=f"{evaluation_id}:1950", + source="efficientai", + properties={ + "workspace_id": str(workspace_id), + "feature": "call_imports", + "call_import_id": str(call_import_id), + "evaluation_id": str(evaluation_id), + "rows_billed": "50", + "completed_total": "1950", + "metric_count": "5", + "quantity": "50", + }, + ) diff --git a/tests/test_workers/test_evaluate_call_import_row.py b/tests/test_workers/test_evaluate_call_import_row.py index 8ea87e09..8ef04cde 100644 --- a/tests/test_workers/test_evaluate_call_import_row.py +++ b/tests/test_workers/test_evaluate_call_import_row.py @@ -840,3 +840,97 @@ def _simulate_cancel(*_args, **kwargs): refreshed_eval = db_session.get(CallImportEvaluation, evaluation.id) assert refreshed_eval.failed_rows == 1 assert refreshed_eval.completed_rows == 0 + + +def test_rollup_parent_emits_pass_delta_when_leaving_running(db_session, monkeypatch): + from app.workers.tasks.evaluate_call_import_row import _rollup_parent + + recorded = [] + + def _capture(*_args, **kwargs): + recorded.append(kwargs) + + monkeypatch.setattr( + "app.services.billing.flexprice_service.record_call_import_evaluation_completed", + _capture, + ) + + _, _, _, _, evaluation, eval_rows = _seed(db_session, row_count=2) + evaluation.status = "running" + eval_rows[0].status = "completed" + eval_rows[1].status = "failed" + db_session.commit() + + _rollup_parent(db_session, evaluation) + + assert evaluation.status == "partial" + assert evaluation.billed_completed_rows == 1 + assert len(recorded) == 1 + assert recorded[0]["rows_billed"] == 1 + assert recorded[0]["completed_total"] == 1 + + +def test_rollup_parent_bills_only_retry_delta(db_session, monkeypatch): + from app.workers.tasks.evaluate_call_import_row import _rollup_parent + + recorded = [] + + def _capture(*_args, **kwargs): + recorded.append(kwargs) + + monkeypatch.setattr( + "app.services.billing.flexprice_service.record_call_import_evaluation_completed", + _capture, + ) + + _, _, _, _, evaluation, eval_rows = _seed(db_session, row_count=10) + evaluation.status = "running" + evaluation.billed_completed_rows = 8 + for idx in range(8): + eval_rows[idx].status = "completed" + for idx in range(8, 10): + eval_rows[idx].status = "pending" + db_session.commit() + + _rollup_parent(db_session, evaluation) + assert evaluation.status == "running" + assert recorded == [] + + for idx in range(8, 10): + eval_rows[idx].status = "completed" + db_session.flush() + _rollup_parent(db_session, evaluation) + + assert evaluation.status == "completed" + assert evaluation.billed_completed_rows == 10 + assert len(recorded) == 1 + assert recorded[0]["rows_billed"] == 2 + assert recorded[0]["completed_total"] == 10 + + +def test_rollup_parent_skips_billing_when_metric_rerun_unchanged( + db_session, monkeypatch +): + from app.workers.tasks.evaluate_call_import_row import _rollup_parent + + recorded = [] + + def _capture(*_args, **kwargs): + recorded.append(kwargs) + + monkeypatch.setattr( + "app.services.billing.flexprice_service.record_call_import_evaluation_completed", + _capture, + ) + + _, _, _, _, evaluation, eval_rows = _seed(db_session, row_count=5) + evaluation.status = "running" + evaluation.billed_completed_rows = 5 + for eval_row in eval_rows: + eval_row.status = "completed" + db_session.commit() + + _rollup_parent(db_session, evaluation) + + assert evaluation.billed_completed_rows == 5 + assert recorded == []