Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion app/api/v1/routes/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@
RoleEnum,
User,
)
from app.services.organization_provisioning import provision_default_workspace
from app.services.organization_provisioning import (
provision_billing_customer,
provision_default_workspace,
)

router = APIRouter(prefix="/auth", tags=["Authentication"])

Expand Down Expand Up @@ -324,6 +327,11 @@ def signup(payload: SignupRequest, db: Session = Depends(get_db)) -> TokenRespon
organization_id=organization.id,
created_by_user_id=user.id,
)
provision_billing_customer(
organization_id=organization.id,
name=org_name,
email=payload.email,
)
user.last_login_at = datetime.now(timezone.utc)
db.commit()
db.refresh(user)
Expand Down
3 changes: 2 additions & 1 deletion app/api/v1/routes/call_import_evaluations.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from datetime import date, datetime, timedelta, timezone

from fastapi import APIRouter, Body, Depends, HTTPException, Query, Response, status
from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException, Query, Response, status
from fastapi.responses import StreamingResponse
from loguru import logger
from pydantic import BaseModel, Field, field_validator
Expand Down Expand Up @@ -557,6 +557,7 @@ def _rollup_evaluation_status(evaluation: CallImportEvaluation, db: Session) ->
async def create_call_import_evaluation(
call_import_id: UUID,
payload: CallImportEvaluationCreate,
background_tasks: BackgroundTasks,
api_key: str = Depends(get_api_key),
organization_id: UUID = Depends(get_organization_id),
db: Session = Depends(get_db),
Expand Down
35 changes: 34 additions & 1 deletion app/api/v1/routes/call_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing import Any, Dict, Iterable, List, Optional, Tuple
from uuid import UUID

from fastapi import APIRouter, Body, Depends, File, Form, HTTPException, Query, Response, UploadFile, status
from fastapi import APIRouter, Body, BackgroundTasks, Depends, File, Form, HTTPException, Query, Response, UploadFile, status
from loguru import logger
from sqlalchemy import desc, func, or_
from sqlalchemy.orm import Session
Expand All @@ -32,6 +32,7 @@
get_workspace_id,
require_enterprise_feature,
)
from app.services.billing.flexprice_service import record_call_import_batch_created
from app.models.database import (
CallImport,
CallImportRow,
Expand Down Expand Up @@ -1556,6 +1557,7 @@ async def update_call_import_mapping(
async def start_call_import(
call_import_id: UUID,
payload: CallImportStartRequest,
background_tasks: BackgroundTasks,
api_key: str = Depends(get_api_key),
organization_id: UUID = Depends(get_organization_id),
workspace_id: UUID = Depends(get_workspace_id),
Expand Down Expand Up @@ -1682,6 +1684,16 @@ async def start_call_import(
db.commit()
db.refresh(call_import)

background_tasks.add_task(
record_call_import_batch_created,
organization_id,
call_import.id,
workspace_id=workspace_id,
total_rows=call_import.total_rows,
source="csv",
provider=call_import.provider,
)

_enqueue_row_tasks(db, call_import, row_models)

return CallImportUploadResponse(
Expand All @@ -1705,6 +1717,7 @@ async def start_call_import(
deprecated=True,
)
async def upload_call_import_csv(
background_tasks: BackgroundTasks,
file: UploadFile = File(...),
provider: Optional[str] = Form(
None,
Expand Down Expand Up @@ -1886,6 +1899,16 @@ async def upload_call_import_csv(
db.commit()
db.refresh(call_import)

background_tasks.add_task(
record_call_import_batch_created,
organization_id,
call_import.id,
workspace_id=workspace_id,
total_rows=call_import.total_rows,
source="csv",
provider=call_import.provider,
)

_enqueue_row_tasks(db, call_import, row_models)

return CallImportUploadResponse(
Expand All @@ -1908,6 +1931,7 @@ async def upload_call_import_csv(
operation_id="uploadCallImportAudio",
)
async def upload_call_import_audio(
background_tasks: BackgroundTasks,
files: List[UploadFile] = File(
...,
description="One or more manual call recording audio files.",
Expand Down Expand Up @@ -2076,6 +2100,15 @@ async def upload_call_import_audio(
) from exc

db.refresh(call_import)
background_tasks.add_task(
record_call_import_batch_created,
organization_id,
call_import.id,
workspace_id=workspace_id,
total_rows=call_import.total_rows,
source="audio",
provider=None,
)
return CallImportUploadResponse(
id=call_import.id,
total_rows=call_import.total_rows,
Expand Down
17 changes: 14 additions & 3 deletions app/api/v1/routes/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
Chat/Inference API Routes
For generating responses from AI models
"""
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, status
from sqlalchemy.orm import Session
from typing import List, Dict, Optional, Any
from uuid import UUID
from uuid import UUID, uuid4
from pydantic import BaseModel

from app.dependencies import get_db, get_organization_id
from app.dependencies import get_db, get_organization_id, get_workspace_id
from app.services.ai.llm_service import llm_service
from app.services.billing.flexprice_service import record_chat_completion
from app.models.schemas import ModelProvider

router = APIRouter(prefix="/chat", tags=["chat"])
Expand Down Expand Up @@ -39,7 +40,9 @@ class ChatResponse(BaseModel):
@router.post("/completion", response_model=ChatResponse)
async def chat_completion(
request: ChatRequest,
background_tasks: BackgroundTasks,
organization_id: UUID = Depends(get_organization_id),
workspace_id: UUID = Depends(get_workspace_id),
db: Session = Depends(get_db)
):
"""Generate a chat completion using the specified AI provider and model."""
Expand All @@ -59,6 +62,14 @@ async def chat_completion(
task_defaults={"temperature": 0.7},
)

background_tasks.add_task(
record_chat_completion,
organization_id,
uuid4(),
workspace_id=workspace_id,
model=result.get("model", request.model),
)

return ChatResponse(
text=result.get("text", ""),
model=result.get("model", request.model),
Expand Down
13 changes: 12 additions & 1 deletion app/api/v1/routes/evaluations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Evaluation routes."""

from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from sqlalchemy.orm import Session
from typing import List
from uuid import UUID
Expand All @@ -13,6 +13,7 @@
EvaluationStatusResponse,
MessageResponse,
)
from app.services.billing.flexprice_service import record_evaluation_created
from app.workers.celery_app import process_evaluation_task
from app.core.exceptions import EvaluationNotFoundError

Expand All @@ -22,6 +23,7 @@
@router.post("/create", response_model=EvaluationResponse, status_code=201)
def create_evaluation(
evaluation_data: EvaluationCreate,
background_tasks: BackgroundTasks,
api_key: str = Depends(get_api_key),
organization_id: UUID = Depends(get_organization_id),
workspace_id: UUID = Depends(get_workspace_id),
Expand Down Expand Up @@ -55,6 +57,15 @@ def create_evaluation(
db.commit()
db.refresh(evaluation)

background_tasks.add_task(
record_evaluation_created,
organization_id,
evaluation.id,
workspace_id=workspace_id,
audio_id=evaluation.audio_id,
metrics_requested=len(evaluation.metrics_requested or []),
)

# Queue async task
process_evaluation_task.delay(str(evaluation.id))

Expand Down
16 changes: 13 additions & 3 deletions app/api/v1/routes/evaluators.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
"""Evaluator routes."""

from fastapi import APIRouter, Depends, HTTPException, status, Query
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, status, Query
from fastapi.responses import JSONResponse, Response
from sqlalchemy.orm import Session
from sqlalchemy import and_
from uuid import UUID
from uuid import UUID, uuid4
import random
from typing import List
from pydantic import BaseModel
from loguru import logger

from app.database import get_db
from app.dependencies import get_organization_id, get_workspace_id, get_api_key
from app.services.billing.flexprice_service import record_evaluator_run_requested
from app.models.database import Evaluator, Agent, Persona, Scenario, EvaluatorResult, EvaluatorResultStatus, VoiceBundle, Metric
from app.models.schemas import (
EvaluatorCreate,
Expand Down Expand Up @@ -548,6 +549,7 @@ def delete_evaluator(
@router.post("/run", response_model=RunEvaluatorsResponse, status_code=200)
def run_evaluators(
request: RunEvaluatorsRequest,
background_tasks: BackgroundTasks,
organization_id: UUID = Depends(get_organization_id),
workspace_id: UUID = Depends(get_workspace_id),
db: Session = Depends(get_db),
Expand Down Expand Up @@ -649,7 +651,15 @@ def run_evaluators(

if not task_ids:
raise HTTPException(status_code=500, detail="Failed to create any tasks")


background_tasks.add_task(
record_evaluator_run_requested,
organization_id,
uuid4(),
workspace_id=workspace_id,
quantity=len(request.evaluator_ids),

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Partial Dispatch Overbills When one evaluator queues successfully and a later evaluator fails inside the dispatch loop, the endpoint still returns success because task_ids is non-empty. This line then records the full requested count, so Flexprice bills evaluators that were never queued and can never produce results. Use the count of successfully queued tasks for this event.

Suggested change
quantity=len(request.evaluator_ids),
quantity=len(task_ids),

)
Comment on lines +655 to +661

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Failed Dispatches Are Billed

The dispatch loop catches per-evaluator queue failures and continues, but this event bills len(request.evaluator_ids) instead of the number of tasks that were actually queued. A request where one evaluator queues and another raises in run_evaluator_task.delay(...) returns success with one task while recording usage for both requested evaluators.

Suggested change
background_tasks.add_task(
record_evaluator_run_requested,
organization_id,
uuid4(),
workspace_id=workspace_id,
quantity=len(request.evaluator_ids),
)
background_tasks.add_task(
record_evaluator_run_requested,
organization_id,
uuid4(),
workspace_id=workspace_id,
quantity=len(task_ids),
)

Comment on lines +655 to +661

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Partial dispatch overbills The dispatch loop only appends to task_ids after a Celery task is queued, but it catches per-evaluator failures and continues. If one evaluator queues and a later one fails, this billing event still records the full requested count even though only the queued tasks can run. Billing should use the successful dispatch count.

Suggested change
background_tasks.add_task(
record_evaluator_run_requested,
organization_id,
uuid4(),
workspace_id=workspace_id,
quantity=len(request.evaluator_ids),
)
background_tasks.add_task(
record_evaluator_run_requested,
organization_id,
uuid4(),
workspace_id=workspace_id,
quantity=len(task_ids),
)


return RunEvaluatorsResponse(
task_ids=task_ids,
evaluator_results=evaluator_results
Expand Down
29 changes: 24 additions & 5 deletions app/api/v1/routes/judge_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing import Any, Dict, List, Optional
from uuid import UUID

from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status
from fastapi import APIRouter, BackgroundTasks, Depends, File, Form, HTTPException, UploadFile, status
from loguru import logger
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy.orm import Session
Expand All @@ -29,6 +29,7 @@
get_workspace_id,
Principal,
)
from app.services.billing.flexprice_service import record_judge_alignment_run_started
from app.models.database import (
Agent,
Evaluator,
Expand Down Expand Up @@ -617,6 +618,7 @@ def list_runs(
def trigger_judge_run(
dataset_id: UUID,
body: JudgeRunCreate,
background_tasks: BackgroundTasks,
principal: Principal = Depends(get_principal),
workspace_id: UUID = Depends(get_workspace_id),
db: Session = Depends(get_db),
Expand Down Expand Up @@ -662,6 +664,19 @@ def trigger_judge_run(
),
)

from app.services.judge_alignment.judge_runner import select_samples_for_split

sample_id_strs = (
[str(s) for s in (body.sample_ids or [])] if body.split != "all" else None
)
selected_samples = select_samples_for_split(
dataset_id=dataset.id,
split=body.split,
db=db,
sample_ids=sample_id_strs,
)
sample_count = len(selected_samples)

judge_run = JudgeRun(
dataset_id=dataset.id,
organization_id=organization_id,
Expand All @@ -679,10 +694,6 @@ def trigger_judge_run(
db.commit()
db.refresh(judge_run)

sample_id_strs = (
[str(s) for s in (body.sample_ids or [])] if body.split != "all" else None
)

try:
from app.workers.tasks.run_judge_alignment import run_judge_alignment_task

Expand All @@ -692,6 +703,14 @@ def trigger_judge_run(
judge_run.celery_task_id = async_result.id
db.commit()
db.refresh(judge_run)
background_tasks.add_task(
record_judge_alignment_run_started,
organization_id,
judge_run.id,
workspace_id=workspace_id,
dataset_id=dataset.id,
sample_count=sample_count,
)
except Exception as exc:
logger.error(f"[JudgeAlignment] Failed to enqueue judge task: {exc}")
judge_run.status = "failed"
Expand Down
Loading
Loading