refactor(agents-api): centralize usage tracking and remove deprecated metrics #1513

Status: Open
Wants to merge 6 commits into base: dev
Changes from all commits
7 changes: 4 additions & 3 deletions src/agents-api/agents_api/clients/litellm.py
@@ -15,7 +15,8 @@
 )
 from ..common.utils.llm_providers import get_api_key_env_var_name
 from ..common.utils.secrets import get_secret_by_name
-from ..common.utils.usage import track_embedding_usage, track_usage
+from ..common.utils.usage import track_embedding_usage
+from ..common.utils.usage_tracker import track_completion_usage
 from ..env import (
     embedding_dimensions,
     embedding_model_id,
@@ -104,12 +105,12 @@ async def acompletion(

     response = patch_litellm_response(model_response)

-    # Track usage in database if we have a user ID (which should be the developer ID)
+    # Track usage if we have a user ID (which should be the developer ID)
     user = settings.get("user")
     if user and isinstance(response, ModelResponse):
         try:
             model = response.model
-            await track_usage(
+            await track_completion_usage(
                 developer_id=UUID(user),
                 model=model,
                 messages=messages,
122 changes: 122 additions & 0 deletions src/agents-api/agents_api/common/utils/usage_tracker.py
@@ -0,0 +1,122 @@
"""
Centralized usage tracking utilities for LLM API calls.
Handles both Prometheus metrics and database tracking.
"""

from typing import Any
from uuid import UUID

from beartype import beartype
from litellm.utils import ModelResponse, Choices, Message
from prometheus_client import Counter

from .usage import track_usage

# Prometheus metrics
total_tokens_per_user = Counter(
"total_tokens_per_user",
"Total token count per user",
labelnames=("developer_id",),
)


@beartype
async def track_completion_usage(
*,
developer_id: UUID,
model: str,
messages: list[dict],
response: ModelResponse,
custom_api_used: bool = False,
metadata: Optional[dict[str, Any]] = None,
connection_pool: Any = None,
) -> None:
"""
Tracks usage for completion responses (both streaming and non-streaming).

Args:
developer_id: The developer ID for usage tracking
model: The model name used for the response
messages: The original messages sent to the model
response: The model response
custom_api_used: Whether a custom API key was used
metadata: Additional metadata for tracking
connection_pool: Connection pool for testing purposes
"""
# Track Prometheus metrics
if response.usage and response.usage.total_tokens > 0:
total_tokens_per_user.labels(str(developer_id)).inc(
amount=response.usage.total_tokens
)

# Track usage in database
await track_usage(
developer_id=developer_id,
model=model,
messages=messages,
response=response,
custom_api_used=custom_api_used,
metadata=metadata,
connection_pool=connection_pool,
)


@beartype
async def track_streaming_usage(
*,
developer_id: UUID,
model: str,
messages: list[dict],
usage_data: dict[str, Any] | None,
collected_output: list[dict],
response_id: str,
custom_api_used: bool = False,
metadata: dict[str, Any] = None,
connection_pool: Any = None,
) -> None:
"""
Tracks usage for streaming responses.

Args:
developer_id: The developer ID for usage tracking
model: The model name used for the response
messages: The original messages sent to the model
usage_data: Usage data from the streaming response
collected_output: The complete collected output from streaming
response_id: The response ID
custom_api_used: Whether a custom API key was used
metadata: Additional metadata for tracking
connection_pool: Connection pool for testing purposes
"""
# Track Prometheus metrics if usage data is available
if usage_data and usage_data.get("total_tokens", 0) > 0:
total_tokens_per_user.labels(str(developer_id)).inc(
amount=usage_data.get("total_tokens", 0)
)

# Only track usage in database if we have collected output
if not collected_output:
return

# Track usage in database
await track_usage(
developer_id=developer_id,
model=model,
messages=messages,
response=ModelResponse(
id=response_id,
choices=[
Choices(
message=Message(
content=choice.get("content", ""),
tool_calls=choice.get("tool_calls"),
),
)
for choice in collected_output
],
usage=usage_data,
),
custom_api_used=custom_api_used,
metadata=metadata,
connection_pool=connection_pool,
)
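
For reference, a minimal caller sketch of the new non-streaming helper. The handle_chat wrapper, model name, and message content below are illustrative and not part of this PR:

from uuid import UUID

from litellm import acompletion

from agents_api.common.utils.usage_tracker import track_completion_usage


async def handle_chat(developer_id: UUID) -> None:
    messages = [{"role": "user", "content": "Hello"}]
    # Illustrative direct litellm call; in this PR the request goes through
    # agents_api.clients.litellm.acompletion, which now tracks usage itself.
    response = await acompletion(model="openai/gpt-4o-mini", messages=messages)
    # A single call records both the Prometheus counter and the database usage row.
    await track_completion_usage(
        developer_id=developer_id,
        model=response.model,
        messages=messages,
        response=response,
    )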
@@ -24,7 +24,7 @@
     convert_chat_response_to_response,
     convert_create_response,
 )
-from ..sessions.metrics import total_tokens_per_user
+
 from ..sessions.render import render_chat_input
 from .router import router

@@ -315,9 +315,7 @@ async def create_response(
         choices=[choice.model_dump() for choice in model_response.choices],
     )

-    total_tokens_per_user.labels(str(developer.id)).inc(
-        amount=chat_response.usage.total_tokens if chat_response.usage is not None else 0,
-    )
+

     # End chat function
     return convert_chat_response_to_response(
35 changes: 7 additions & 28 deletions src/agents-api/agents_api/routers/sessions/chat.py
@@ -4,7 +4,7 @@

 from fastapi import BackgroundTasks, Depends, Header
 from fastapi.responses import StreamingResponse
-from litellm.utils import Choices, Message, ModelResponse
+from litellm.utils import ModelResponse
 from starlette.status import HTTP_201_CREATED
 from uuid_extensions import uuid7

@@ -18,10 +18,9 @@
 from ...clients import litellm
 from ...common.protocol.developers import Developer
 from ...common.utils.datetime import utcnow
-from ...common.utils.usage import track_usage
+from ...common.utils.usage_tracker import track_streaming_usage
 from ...dependencies.developer_id import get_developer_data
 from ...queries.entries.create_entries import create_entries
-from .metrics import total_tokens_per_user
 from .render import render_chat_input
 from .router import router

@@ -118,30 +117,14 @@ async def stream_chat_response(
         # Forward the chunk as a proper ChunkChatResponse
         yield f"data: {chunk_response.model_dump_json()}\n\n"

-    # Track token usage with Prometheus metrics if available
-    if usage_data and usage_data.get("total_tokens", 0) > 0:
-        total_tokens_per_user.labels(str(developer_id)).inc(
-            amount=usage_data.get("total_tokens", 0)
-        )
-
-    # Track usage in database
-    await track_usage(
+    # Track usage using centralized tracker
+    await track_streaming_usage(
         developer_id=developer_id,
         model=model,
         messages=messages or [],
-        response=ModelResponse(
-            id=str(response_id),
-            choices=[
-                Choices(
-                    message=Message(
-                        content=choice.get("content", ""),
-                        tool_calls=choice.get("tool_calls"),
-                    ),
-                )
-                for choice in collected_output
-            ],
-            usage=usage_data,
-        ),
+        usage_data=usage_data,
+        collected_output=collected_output,
+        response_id=str(response_id),
         custom_api_used=custom_api_key_used,
         metadata={
             "tags": developer_tags or [],
@@ -300,8 +283,4 @@ async def chat(
         choices=[choice.model_dump() for choice in model_response.choices],
     )

-    total_tokens_per_user.labels(str(developer.id)).inc(
-        amount=chat_response.usage.total_tokens if chat_response.usage is not None else 0,
-    )
-
     return chat_response
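
For the streaming path, a comparable sketch of a call to the new helper. The finish_stream wrapper, chunk contents, and IDs are illustrative:

from uuid import UUID

from agents_api.common.utils.usage_tracker import track_streaming_usage


async def finish_stream(developer_id: UUID) -> None:
    messages = [{"role": "user", "content": "Hello"}]
    # Shapes mirror what chat.py accumulates while streaming.
    collected_output = [{"content": "Hi there!", "tool_calls": None}]
    usage_data = {"prompt_tokens": 12, "completion_tokens": 4, "total_tokens": 16}
    await track_streaming_usage(
        developer_id=developer_id,
        model="openai/gpt-4o-mini",
        messages=messages,
        usage_data=usage_data,
        collected_output=collected_output,
        response_id="resp_123",
    )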
7 changes: 0 additions & 7 deletions src/agents-api/agents_api/routers/sessions/metrics.py

This file was deleted.