arize-phoenix 10.15.0__py3-none-any.whl → 11.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (79)
  1. {arize_phoenix-10.15.0.dist-info → arize_phoenix-11.1.0.dist-info}/METADATA +2 -2
  2. {arize_phoenix-10.15.0.dist-info → arize_phoenix-11.1.0.dist-info}/RECORD +77 -46
  3. phoenix/config.py +5 -2
  4. phoenix/datetime_utils.py +8 -1
  5. phoenix/db/bulk_inserter.py +40 -1
  6. phoenix/db/facilitator.py +263 -4
  7. phoenix/db/insertion/helpers.py +15 -0
  8. phoenix/db/insertion/span.py +3 -1
  9. phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
  10. phoenix/db/models.py +267 -9
  11. phoenix/db/types/token_price_customization.py +29 -0
  12. phoenix/server/api/context.py +38 -4
  13. phoenix/server/api/dataloaders/__init__.py +41 -5
  14. phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
  15. phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
  16. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
  17. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
  18. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
  19. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
  20. phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
  21. phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +58 -0
  22. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
  23. phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
  24. phoenix/server/api/dataloaders/span_cost_summary_by_project.py +140 -0
  25. phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
  26. phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
  27. phoenix/server/api/dataloaders/span_costs.py +35 -0
  28. phoenix/server/api/dataloaders/types.py +29 -0
  29. phoenix/server/api/helpers/playground_clients.py +103 -12
  30. phoenix/server/api/input_types/ProjectSessionSort.py +3 -0
  31. phoenix/server/api/input_types/SpanSort.py +17 -0
  32. phoenix/server/api/mutations/__init__.py +2 -0
  33. phoenix/server/api/mutations/chat_mutations.py +17 -0
  34. phoenix/server/api/mutations/model_mutations.py +208 -0
  35. phoenix/server/api/queries.py +82 -41
  36. phoenix/server/api/routers/v1/traces.py +11 -4
  37. phoenix/server/api/subscriptions.py +36 -2
  38. phoenix/server/api/types/CostBreakdown.py +15 -0
  39. phoenix/server/api/types/Experiment.py +59 -1
  40. phoenix/server/api/types/ExperimentRun.py +58 -4
  41. phoenix/server/api/types/GenerativeModel.py +143 -2
  42. phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
  43. phoenix/server/api/types/ModelInterface.py +11 -0
  44. phoenix/server/api/types/PlaygroundModel.py +10 -0
  45. phoenix/server/api/types/Project.py +42 -0
  46. phoenix/server/api/types/ProjectSession.py +44 -0
  47. phoenix/server/api/types/Span.py +137 -0
  48. phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
  49. phoenix/server/api/types/SpanCostSummary.py +10 -0
  50. phoenix/server/api/types/TokenPrice.py +16 -0
  51. phoenix/server/api/types/TokenUsage.py +3 -3
  52. phoenix/server/api/types/Trace.py +41 -0
  53. phoenix/server/app.py +59 -0
  54. phoenix/server/cost_tracking/cost_details_calculator.py +190 -0
  55. phoenix/server/cost_tracking/cost_model_lookup.py +151 -0
  56. phoenix/server/cost_tracking/helpers.py +68 -0
  57. phoenix/server/cost_tracking/model_cost_manifest.json +59 -329
  58. phoenix/server/cost_tracking/regex_specificity.py +397 -0
  59. phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
  60. phoenix/server/daemons/__init__.py +0 -0
  61. phoenix/server/daemons/generative_model_store.py +51 -0
  62. phoenix/server/daemons/span_cost_calculator.py +103 -0
  63. phoenix/server/dml_event_handler.py +1 -0
  64. phoenix/server/static/.vite/manifest.json +36 -36
  65. phoenix/server/static/assets/components-BQWqzM6Z.js +5055 -0
  66. phoenix/server/static/assets/{index-DIlhmbjB.js → index-t6f0PRIo.js} +13 -13
  67. phoenix/server/static/assets/{pages-YX47cEoQ.js → pages-B8Uyb2qa.js} +818 -422
  68. phoenix/server/static/assets/{vendor-DCZoBorz.js → vendor-DqQvHbPa.js} +147 -147
  69. phoenix/server/static/assets/{vendor-arizeai-Ckci3irT.js → vendor-arizeai-CLX44PFA.js} +1 -1
  70. phoenix/server/static/assets/{vendor-codemirror-BODM513D.js → vendor-codemirror-Du3XyJnB.js} +1 -1
  71. phoenix/server/static/assets/{vendor-recharts-C9O2a-N3.js → vendor-recharts-B2PJDrnX.js} +25 -25
  72. phoenix/server/static/assets/{vendor-shiki-Dq54rRC7.js → vendor-shiki-CNbrFjf9.js} +1 -1
  73. phoenix/version.py +1 -1
  74. phoenix/server/cost_tracking/cost_lookup.py +0 -255
  75. phoenix/server/static/assets/components-SpUMF1qV.js +0 -4509
  76. {arize_phoenix-10.15.0.dist-info → arize_phoenix-11.1.0.dist-info}/WHEEL +0 -0
  77. {arize_phoenix-10.15.0.dist-info → arize_phoenix-11.1.0.dist-info}/entry_points.txt +0 -0
  78. {arize_phoenix-10.15.0.dist-info → arize_phoenix-11.1.0.dist-info}/licenses/IP_NOTICE +0 -0
  79. {arize_phoenix-10.15.0.dist-info → arize_phoenix-11.1.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,35 @@
1
+ from typing import Optional
2
+
3
+ from sqlalchemy import select
4
+ from sqlalchemy.orm import joinedload, load_only
5
+ from strawberry.dataloader import DataLoader
6
+ from typing_extensions import TypeAlias
7
+
8
+ from phoenix.db import models
9
+ from phoenix.server.types import DbSessionFactory
10
+
11
# Integer row id of a span (compared against models.Span.id below).
SpanID: TypeAlias = int
# The loader is keyed directly by span row id.
Key: TypeAlias = SpanID
# A span may have no associated cost record, so results are Optional.
Result: TypeAlias = Optional[models.SpanCost]
14
+
15
+
16
class SpanCostsDataLoader(DataLoader[Key, Result]):
    """Batch-load the ``SpanCost`` attached to each requested span.

    Keys are span row ids. Each result is the span's ``span_cost`` relationship,
    or ``None`` when the span was not found or has no cost attached.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        # De-duplicate so each span is fetched at most once per batch.
        unique_ids = list(set(keys))
        stmt = (
            select(models.Span)
            .where(models.Span.id.in_(unique_ids))
            .options(
                # Only the span's id column is needed...
                load_only(models.Span.id),
                # ...plus its cost, eagerly fetched in the same query.
                joinedload(models.Span.span_cost),
            )
        )
        cost_by_span_id: dict[Key, Result] = {}
        async with self._db() as session:
            async for span in await session.stream_scalars(stmt):
                cost_by_span_id[span.id] = span.span_cost
        # Answer in the caller's key order (duplicates included); misses map to None.
        return [cost_by_span_id.get(key) for key in keys]
@@ -0,0 +1,29 @@
1
+ from dataclasses import dataclass, field
2
+ from functools import cached_property
3
+ from typing import Optional
4
+
5
+
6
@dataclass(frozen=True)
class CostBreakdown:
    """Token count and cost for one category of tokens.

    Either value may be ``None`` when it is unknown.
    """

    tokens: Optional[float] = None
    cost: Optional[float] = None

    @cached_property
    def cost_per_token(self) -> Optional[float]:
        """Average cost of one token, or ``None`` if it cannot be computed.

        A known, non-zero token count is required (to avoid dividing by zero),
        and the cost must be known. A cost of exactly ``0.0`` is a real value
        and yields ``0.0`` rather than ``None``.
        """
        if not self.tokens or self.cost is None:
            return None
        return self.cost / self.tokens
16
+
17
+
18
@dataclass(frozen=True)
class SpanCostSummary:
    """Aggregate cost figures split into prompt, completion, and overall totals."""

    # Each default_factory call builds a fresh, empty CostBreakdown.
    prompt: CostBreakdown = field(default_factory=CostBreakdown)
    completion: CostBreakdown = field(default_factory=CostBreakdown)
    total: CostBreakdown = field(default_factory=CostBreakdown)
23
+
24
+
25
@dataclass(frozen=True)
class SpanCostDetailSummaryEntry:
    """Cost summary for a single token type."""

    token_type: str
    # Whether this token type counts toward the prompt side (vs. the completion).
    is_prompt: bool
    value: CostBreakdown = field(default_factory=CostBreakdown)
@@ -463,6 +463,35 @@ class OpenAIBaseStreamingClient(PlaygroundStreamingClient):
463
463
  yield LLM_TOKEN_COUNT_COMPLETION, usage.completion_tokens
464
464
  yield LLM_TOKEN_COUNT_TOTAL, usage.total_tokens
465
465
 
466
+ if hasattr(usage, "prompt_tokens_details") and usage.prompt_tokens_details is not None:
467
+ prompt_details = usage.prompt_tokens_details
468
+ if (
469
+ hasattr(prompt_details, "cached_tokens")
470
+ and prompt_details.cached_tokens is not None
471
+ ):
472
+ yield LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ, prompt_details.cached_tokens
473
+ if hasattr(prompt_details, "audio_tokens") and prompt_details.audio_tokens is not None:
474
+ yield LLM_TOKEN_COUNT_PROMPT_DETAILS_AUDIO, prompt_details.audio_tokens
475
+
476
+ if (
477
+ hasattr(usage, "completion_tokens_details")
478
+ and usage.completion_tokens_details is not None
479
+ ):
480
+ completion_details = usage.completion_tokens_details
481
+ if (
482
+ hasattr(completion_details, "reasoning_tokens")
483
+ and completion_details.reasoning_tokens is not None
484
+ ):
485
+ yield (
486
+ LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING,
487
+ completion_details.reasoning_tokens,
488
+ )
489
+ if (
490
+ hasattr(completion_details, "audio_tokens")
491
+ and completion_details.audio_tokens is not None
492
+ ):
493
+ yield LLM_TOKEN_COUNT_COMPLETION_DETAILS_AUDIO, completion_details.audio_tokens
494
+
466
495
 
467
496
  def _get_credential_value(
468
497
  credentials: Optional[list[PlaygroundClientCredential]], env_var_name: str
@@ -1115,13 +1144,20 @@ class OpenAIStreamingClient(OpenAIBaseStreamingClient):
1115
1144
  provider_key=GenerativeProviderKey.OPENAI,
1116
1145
  model_names=[
1117
1146
  "o1",
1147
+ "o1-pro",
1118
1148
  "o1-2024-12-17",
1149
+ "o1-pro-2025-03-19",
1119
1150
  "o1-mini",
1120
1151
  "o1-mini-2024-09-12",
1121
1152
  "o1-preview",
1122
1153
  "o1-preview-2024-09-12",
1154
+ "o3",
1155
+ "o3-pro",
1156
+ "o3-2025-04-16",
1123
1157
  "o3-mini",
1124
1158
  "o3-mini-2025-01-31",
1159
+ "o4-mini",
1160
+ "o4-mini-2025-04-16",
1125
1161
  ],
1126
1162
  )
1127
1163
  class OpenAIReasoningStreamingClient(OpenAIStreamingClient):
@@ -1258,6 +1294,35 @@ class OpenAIReasoningStreamingClient(OpenAIStreamingClient):
1258
1294
  yield LLM_TOKEN_COUNT_COMPLETION, usage.completion_tokens
1259
1295
  yield LLM_TOKEN_COUNT_TOTAL, usage.total_tokens
1260
1296
 
1297
+ if hasattr(usage, "prompt_tokens_details") and usage.prompt_tokens_details is not None:
1298
+ prompt_details = usage.prompt_tokens_details
1299
+ if (
1300
+ hasattr(prompt_details, "cached_tokens")
1301
+ and prompt_details.cached_tokens is not None
1302
+ ):
1303
+ yield LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ, prompt_details.cached_tokens
1304
+ if hasattr(prompt_details, "audio_tokens") and prompt_details.audio_tokens is not None:
1305
+ yield LLM_TOKEN_COUNT_PROMPT_DETAILS_AUDIO, prompt_details.audio_tokens
1306
+
1307
+ if (
1308
+ hasattr(usage, "completion_tokens_details")
1309
+ and usage.completion_tokens_details is not None
1310
+ ):
1311
+ completion_details = usage.completion_tokens_details
1312
+ if (
1313
+ hasattr(completion_details, "reasoning_tokens")
1314
+ and completion_details.reasoning_tokens is not None
1315
+ ):
1316
+ yield (
1317
+ LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING,
1318
+ completion_details.reasoning_tokens,
1319
+ )
1320
+ if (
1321
+ hasattr(completion_details, "audio_tokens")
1322
+ and completion_details.audio_tokens is not None
1323
+ ):
1324
+ yield LLM_TOKEN_COUNT_COMPLETION_DETAILS_AUDIO, completion_details.audio_tokens
1325
+
1261
1326
 
1262
1327
  @register_llm_client(
1263
1328
  provider_key=GenerativeProviderKey.AZURE_OPENAI,
@@ -1315,12 +1380,6 @@ class AzureOpenAIStreamingClient(OpenAIBaseStreamingClient):
1315
1380
  provider_key=GenerativeProviderKey.ANTHROPIC,
1316
1381
  model_names=[
1317
1382
  PROVIDER_DEFAULT,
1318
- "claude-sonnet-4-0",
1319
- "claude-sonnet-4-20250514",
1320
- "claude-opus-4-0",
1321
- "claude-opus-4-20250514",
1322
- "claude-3-7-sonnet-latest",
1323
- "claude-3-7-sonnet-20250219",
1324
1383
  "claude-3-5-sonnet-latest",
1325
1384
  "claude-3-5-haiku-latest",
1326
1385
  "claude-3-5-sonnet-20241022",
@@ -1421,15 +1480,34 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
1421
1480
  async with await throttled_stream(**anthropic_params) as stream:
1422
1481
  async for event in stream:
1423
1482
  if isinstance(event, anthropic_types.RawMessageStartEvent):
1424
- self._attributes.update(
1425
- {LLM_TOKEN_COUNT_PROMPT: event.message.usage.input_tokens}
1426
- )
1483
+ usage = event.message.usage
1484
+
1485
+ token_counts: dict[str, Any] = {}
1486
+ if prompt_tokens := (
1487
+ (usage.input_tokens or 0)
1488
+ + (getattr(usage, "cache_creation_input_tokens", 0) or 0)
1489
+ + (getattr(usage, "cache_read_input_tokens", 0) or 0)
1490
+ ):
1491
+ token_counts[LLM_TOKEN_COUNT_PROMPT] = prompt_tokens
1492
+ if cache_creation_tokens := getattr(usage, "cache_creation_input_tokens", None):
1493
+ if cache_creation_tokens is not None:
1494
+ token_counts[LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_WRITE] = (
1495
+ cache_creation_tokens
1496
+ )
1497
+ self._attributes.update(token_counts)
1427
1498
  elif isinstance(event, anthropic_streaming.TextEvent):
1428
1499
  yield TextChunk(content=event.text)
1429
1500
  elif isinstance(event, anthropic_streaming.MessageStopEvent):
1430
- self._attributes.update(
1431
- {LLM_TOKEN_COUNT_COMPLETION: event.message.usage.output_tokens}
1432
- )
1501
+ usage = event.message.usage
1502
+ output_token_counts: dict[str, Any] = {}
1503
+ if usage.output_tokens:
1504
+ output_token_counts[LLM_TOKEN_COUNT_COMPLETION] = usage.output_tokens
1505
+ if cache_read_tokens := getattr(usage, "cache_read_input_tokens", None):
1506
+ if cache_read_tokens is not None:
1507
+ output_token_counts[LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ] = (
1508
+ cache_read_tokens
1509
+ )
1510
+ self._attributes.update(output_token_counts)
1433
1511
  elif (
1434
1512
  isinstance(event, anthropic_streaming.ContentBlockStopEvent)
1435
1513
  and event.content_block.type == "tool_use"
@@ -1514,6 +1592,10 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
1514
1592
  @register_llm_client(
1515
1593
  provider_key=GenerativeProviderKey.ANTHROPIC,
1516
1594
  model_names=[
1595
+ "claude-sonnet-4-0",
1596
+ "claude-sonnet-4-20250514",
1597
+ "claude-opus-4-0",
1598
+ "claude-opus-4-20250514",
1517
1599
  "claude-3-7-sonnet-latest",
1518
1600
  "claude-3-7-sonnet-20250219",
1519
1601
  ],
@@ -1698,6 +1780,15 @@ LLM_SYSTEM = SpanAttributes.LLM_SYSTEM
1698
1780
  LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
1699
1781
  LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
1700
1782
  LLM_TOKEN_COUNT_TOTAL = SpanAttributes.LLM_TOKEN_COUNT_TOTAL
1783
+ LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ = SpanAttributes.LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ
1784
+ LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_WRITE = (
1785
+ SpanAttributes.LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_WRITE
1786
+ )
1787
+ LLM_TOKEN_COUNT_PROMPT_DETAILS_AUDIO = SpanAttributes.LLM_TOKEN_COUNT_PROMPT_DETAILS_AUDIO
1788
+ LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING = (
1789
+ SpanAttributes.LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING
1790
+ )
1791
+ LLM_TOKEN_COUNT_COMPLETION_DETAILS_AUDIO = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION_DETAILS_AUDIO
1701
1792
 
1702
1793
 
1703
1794
  class _HttpxClient(wrapt.ObjectProxy): # type: ignore
@@ -13,6 +13,7 @@ class ProjectSessionColumn(Enum):
13
13
  endTime = auto()
14
14
  tokenCountTotal = auto()
15
15
  numTraces = auto()
16
+ costTotal = auto()
16
17
 
17
18
  @property
18
19
  def data_type(self) -> CursorSortColumnDataType:
@@ -20,6 +21,8 @@ class ProjectSessionColumn(Enum):
20
21
  return CursorSortColumnDataType.INT
21
22
  if self is ProjectSessionColumn.startTime or self is ProjectSessionColumn.endTime:
22
23
  return CursorSortColumnDataType.DATETIME
24
+ if self is ProjectSessionColumn.costTotal:
25
+ return CursorSortColumnDataType.FLOAT
23
26
  assert_never(self)
24
27
 
25
28
 
@@ -27,6 +27,7 @@ class SpanColumn(Enum):
27
27
  cumulativeTokenCountTotal = auto()
28
28
  cumulativeTokenCountPrompt = auto()
29
29
  cumulativeTokenCountCompletion = auto()
30
+ tokenCostTotal = auto()
30
31
 
31
32
  @property
32
33
  def column_name(self) -> str:
@@ -56,6 +57,8 @@ class SpanColumn(Enum):
56
57
  expr = models.Span.cumulative_llm_token_count_prompt
57
58
  elif self is SpanColumn.cumulativeTokenCountCompletion:
58
59
  expr = models.Span.cumulative_llm_token_count_completion
60
+ elif self is SpanColumn.tokenCostTotal:
61
+ expr = models.SpanCost.total_cost
59
62
  else:
60
63
  assert_never(self)
61
64
  return expr.label(self.column_name)
@@ -73,12 +76,25 @@ class SpanColumn(Enum):
73
76
  or self is SpanColumn.tokenCountTotal
74
77
  or self is SpanColumn.tokenCountPrompt
75
78
  or self is SpanColumn.tokenCountCompletion
79
+ or self is SpanColumn.tokenCostTotal
76
80
  ):
77
81
  return CursorSortColumnDataType.FLOAT
78
82
  if self is SpanColumn.startTime or self is SpanColumn.endTime:
79
83
  return CursorSortColumnDataType.DATETIME
80
84
  assert_never(self)
81
85
 
86
def join_tables(self, stmt: Select[Any]) -> Select[Any]:
    """
    Return ``stmt`` with any extra table this sort column needs joined in.

    Only ``tokenCostTotal`` requires a join (to ``models.SpanCost`` on
    ``span_rowid``); for every other column ``stmt`` is returned unchanged.
    """
    if self is not SpanColumn.tokenCostTotal:
        return stmt
    # NOTE(review): ``join_from`` is an INNER join, so spans with no SpanCost row
    # drop out of the result set when sorting by cost. Confirm this is intended;
    # an outer join would keep them but introduce NULL sort values for the cursor.
    return stmt.join_from(
        models.Span,
        models.SpanCost,
        onclause=models.SpanCost.span_rowid == models.Span.id,
    )
97
+
82
98
 
83
99
  @strawberry.enum
84
100
  class EvalAttr(Enum):
@@ -140,6 +156,7 @@ class SpanSort:
140
156
  def update_orm_expr(self, stmt: Select[Any]) -> SpanSortConfig:
141
157
  if (col := self.col) and not self.eval_result_key:
142
158
  expr = col.orm_expression
159
+ stmt = col.join_tables(stmt)
143
160
  stmt = stmt.add_columns(expr)
144
161
  if self.dir == SortDir.desc:
145
162
  expr = desc(expr)
@@ -8,6 +8,7 @@ from phoenix.server.api.mutations.chat_mutations import (
8
8
  from phoenix.server.api.mutations.dataset_mutations import DatasetMutationMixin
9
9
  from phoenix.server.api.mutations.experiment_mutations import ExperimentMutationMixin
10
10
  from phoenix.server.api.mutations.export_events_mutations import ExportEventsMutationMixin
11
+ from phoenix.server.api.mutations.model_mutations import ModelMutationMixin
11
12
  from phoenix.server.api.mutations.project_mutations import ProjectMutationMixin
12
13
  from phoenix.server.api.mutations.project_trace_retention_policy_mutations import (
13
14
  ProjectTraceRetentionPolicyMutationMixin,
@@ -29,6 +30,7 @@ class Mutation(
29
30
  DatasetMutationMixin,
30
31
  ExperimentMutationMixin,
31
32
  ExportEventsMutationMixin,
33
+ ModelMutationMixin,
32
34
  ProjectMutationMixin,
33
35
  ProjectTraceRetentionPolicyMutationMixin,
34
36
  PromptMutationMixin,
@@ -1,4 +1,5 @@
1
1
  import asyncio
2
+ import logging
2
3
  from dataclasses import asdict, field
3
4
  from datetime import datetime, timezone
4
5
  from itertools import chain, islice
@@ -73,6 +74,8 @@ from phoenix.utilities.template_formatters import (
73
74
  TemplateFormatter,
74
75
  )
75
76
 
77
+ logger = logging.getLogger(__name__)
78
+
76
79
  initialize_playground_clients()
77
80
 
78
81
  ChatCompletionMessage = tuple[ChatCompletionMessageRole, str, Optional[str], Optional[List[Any]]]
@@ -450,6 +453,19 @@ class ChatCompletionMutationMixin:
450
453
  session.add(trace)
451
454
  session.add(span)
452
455
  await session.flush()
456
+ try:
457
+ span_cost = info.context.span_cost_calculator.calculate_cost(
458
+ start_time=span.start_time,
459
+ attributes=span.attributes,
460
+ )
461
+ except Exception as e:
462
+ logger.exception(f"Failed to calculate cost for span {span.id}: {e}")
463
+ span_cost = None
464
+ if span_cost:
465
+ span_cost.span_rowid = span.id
466
+ span_cost.trace_rowid = trace.id
467
+ session.add(span_cost)
468
+ await session.flush()
453
469
 
454
470
  gql_span = Span(span_rowid=span.id, db_span=span)
455
471
 
@@ -605,5 +621,6 @@ TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUME
605
621
  TOOL_JSON_SCHEMA = ToolAttributes.TOOL_JSON_SCHEMA
606
622
  PROMPT_TEMPLATE_VARIABLES = SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES
607
623
 
624
+ LLM_PROVIDER = SpanAttributes.LLM_PROVIDER
608
625
 
609
626
  PLAYGROUND_PROJECT_NAME = "playground"
@@ -0,0 +1,208 @@
1
+ import re
2
+ from datetime import datetime, timezone
3
+ from typing import Optional
4
+
5
+ import sqlalchemy as sa
6
+ import strawberry
7
+ from sqlalchemy import delete
8
+ from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
9
+ from sqlalchemy.orm import joinedload
10
+ from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore[import-untyped]
11
+ from strawberry.relay import GlobalID
12
+ from strawberry.types import Info
13
+
14
+ from phoenix.db import models
15
+ from phoenix.server.api.auth import IsNotReadOnly
16
+ from phoenix.server.api.context import Context
17
+ from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
18
+ from phoenix.server.api.queries import Query
19
+ from phoenix.server.api.types.GenerativeModel import GenerativeModel, to_gql_generative_model
20
+ from phoenix.server.api.types.node import from_global_id_with_expected_type
21
+ from phoenix.server.api.types.TokenPrice import TokenKind
22
+
23
+
24
@strawberry.input
class TokenPriceInput:
    """A single per-token-type price supplied when creating or updating a model."""

    token_type: str
    cost_per_million_tokens: float
    kind: TokenKind

    @property
    def token_prices(self) -> models.TokenPrice:
        """Build an (unsaved) ``models.TokenPrice`` row from this input.

        The price arrives per million tokens and is stored as a per-token base rate.
        """
        per_token_rate = self.cost_per_million_tokens / 1_000_000
        prompt_side = self.kind == TokenKind.PROMPT
        return models.TokenPrice(
            token_type=self.token_type,
            is_prompt=prompt_side,
            base_rate=per_token_rate,
        )
38
+
39
+
40
@strawberry.input
class CreateModelMutationInput:
    """Arguments for ``createModel``: a user-defined model and its token prices."""

    name: str
    provider: Optional[str] = None
    # Regular-expression source matched against model names.
    name_pattern: str
    costs: list[TokenPriceInput]
    start_time: Optional[datetime] = None
47
+
48
+
49
@strawberry.type
class CreateModelMutationPayload:
    """Result of ``createModel``: the new model plus a root ``Query`` for refetching."""

    model: GenerativeModel
    query: Query
53
+
54
+
55
@strawberry.input
class UpdateModelMutationInput:
    """Arguments for ``updateModel``: the model's id and its replacement values.

    ``provider`` defaults to ``None`` to match ``CreateModelMutationInput``;
    the mutation stores a missing provider as an empty string.
    """

    id: GlobalID
    name: str
    provider: Optional[str] = None
    # Regular-expression source matched against model names.
    name_pattern: str
    # Replaces the model's existing token prices wholesale.
    costs: list[TokenPriceInput]
    start_time: Optional[datetime] = None
63
+
64
+
65
@strawberry.type
class UpdateModelMutationPayload:
    """Result of ``updateModel``: the updated model plus a root ``Query``."""

    model: GenerativeModel
    query: Query
69
+
70
+
71
@strawberry.input
class DeleteModelMutationInput:
    """Arguments for ``deleteModel``: the global id of the model to delete."""

    id: GlobalID
74
+
75
+
76
@strawberry.type
class DeleteModelMutationPayload:
    """Result of ``deleteModel``: the soft-deleted model plus a root ``Query``."""

    model: GenerativeModel
    query: Query
80
+
81
+
82
@strawberry.type
class ModelMutationMixin:
    """GraphQL mutations for creating, updating, and soft-deleting user-defined models.

    Built-in models (``is_built_in``) cannot be updated or deleted.
    """

    @strawberry.mutation(permission_classes=[IsNotReadOnly])  # type: ignore
    async def create_model(
        self,
        info: Info[Context, None],
        input: CreateModelMutationInput,
    ) -> CreateModelMutationPayload:
        """Create a user-defined (non built-in) model together with its token prices.

        Raises:
            BadRequest: if ``costs`` is invalid (see ``_validate_costs``) or
                ``name_pattern`` is not a valid regular expression.
            Conflict: if the insert violates a database integrity constraint,
                reported as a duplicate model name.
        """
        _validate_costs(input.costs)
        name_pattern = _compile_regular_expression(input.name_pattern)
        token_prices = [cost.token_prices for cost in input.costs]
        model = models.GenerativeModel(
            name=input.name,
            provider=input.provider,
            name_pattern=name_pattern,
            is_built_in=False,
            token_prices=token_prices,
            start_time=input.start_time,
        )
        async with info.context.db() as session:
            session.add(model)
            try:
                await session.flush()
            except (PostgreSQLIntegrityError, SQLiteIntegrityError):
                raise Conflict(f"Model with name '{input.name}' already exists")

        return CreateModelMutationPayload(
            model=to_gql_generative_model(model),
            query=Query(),
        )

    @strawberry.mutation(permission_classes=[IsNotReadOnly])  # type: ignore
    async def update_model(
        self,
        info: Info[Context, None],
        input: UpdateModelMutationInput,
    ) -> UpdateModelMutationPayload:
        """Replace the fields and token prices of an existing user-defined model.

        Raises:
            BadRequest: if the id is not a GenerativeModel id, ``costs`` or
                ``name_pattern`` is invalid, or the model is built in.
            NotFound: if no non-deleted model has the given id.
            Conflict: if the update violates a database integrity constraint,
                reported as a duplicate model name.
        """
        try:
            model_id = from_global_id_with_expected_type(input.id, GenerativeModel.__name__)
        except ValueError:
            raise BadRequest(f'Invalid model id: "{input.id}"')

        _validate_costs(input.costs)
        name_pattern = _compile_regular_expression(input.name_pattern)
        token_prices = [cost.token_prices for cost in input.costs]
        async with info.context.db() as session:
            model = await session.scalar(
                sa.select(models.GenerativeModel)
                .where(models.GenerativeModel.deleted_at.is_(None))
                .where(models.GenerativeModel.id == model_id)
                .options(joinedload(models.GenerativeModel.token_prices))
            )
            if model is None:
                raise NotFound(f'Model "{input.id}" not found')
            if model.is_built_in:
                raise BadRequest("Cannot update built-in model")

            # Prices are replaced wholesale: first drop every existing row for this model.
            await session.execute(
                delete(models.TokenPrice).where(models.TokenPrice.model_id == model.id)
            )
            # NOTE(review): presumably refreshed so the ORM no longer tracks the
            # bulk-deleted rows in ``token_prices`` before the new list is assigned.
            await session.refresh(model)

            model.name = input.name
            model.provider = input.provider or ""
            model.name_pattern = name_pattern
            model.token_prices = token_prices
            model.start_time = input.start_time
            session.add(model)
            try:
                await session.flush()
            except (PostgreSQLIntegrityError, SQLiteIntegrityError):
                raise Conflict(f"Model with name '{input.name}' already exists")
            await session.refresh(model)

        return UpdateModelMutationPayload(
            model=to_gql_generative_model(model),
            query=Query(),
        )

    @strawberry.mutation(permission_classes=[IsNotReadOnly])  # type: ignore
    async def delete_model(
        self,
        info: Info[Context, None],
        input: DeleteModelMutationInput,
    ) -> DeleteModelMutationPayload:
        """Soft-delete a user-defined model by setting its ``deleted_at`` timestamp.

        Raises:
            BadRequest: if the id is not a GenerativeModel id or the model is built in.
            NotFound: if no non-deleted model has the given id.
        """
        try:
            model_id = from_global_id_with_expected_type(input.id, GenerativeModel.__name__)
        except ValueError:
            raise BadRequest(f'Invalid model id: "{input.id}"')

        async with info.context.db() as session:
            # Mark the row deleted and inspect what the UPDATE returned; a built-in
            # model is detected afterwards and the update is rolled back.
            model = await session.scalar(
                sa.update(models.GenerativeModel)
                .values(deleted_at=datetime.now(timezone.utc))
                .where(models.GenerativeModel.deleted_at.is_(None))
                .where(models.GenerativeModel.id == model_id)
                .returning(models.GenerativeModel)
            )
            if model is None:
                raise NotFound(f'Model "{input.id}" not found')
            if model.is_built_in:
                await session.rollback()
                raise BadRequest("Cannot delete built-in model")
        return DeleteModelMutationPayload(
            model=to_gql_generative_model(model),
            query=Query(),
        )


def _validate_costs(costs: list[TokenPriceInput]) -> None:
    """Check that a list of token prices can describe a usable model.

    Raises:
        BadRequest: if no ``input`` or no ``output`` token type is priced, if a
            price is negative, or if the same token type and kind is priced more
            than once (which would leave two conflicting rates for one token type).
    """
    cost_types = {cost.token_type for cost in costs}
    if "input" not in cost_types:
        raise BadRequest("input cost is required")
    if "output" not in cost_types:
        raise BadRequest("output cost is required")
    seen: set[tuple[str, TokenKind]] = set()
    for cost in costs:
        if cost.cost_per_million_tokens < 0:
            raise BadRequest(f"cost for token type '{cost.token_type}' must not be negative")
        key = (cost.token_type, cost.kind)
        if key in seen:
            raise BadRequest(f"duplicate cost for token type '{cost.token_type}'")
        seen.add(key)
198
+
199
+
200
+ def _compile_regular_expression(maybe_regex: str) -> re.Pattern[str]:
201
+ """
202
+ Compile the given string as a regular expression.
203
+ Raises a BadRequest error if the given string is not a valid regex.
204
+ """
205
+ try:
206
+ return re.compile(maybe_regex)
207
+ except re.error as error:
208
+ raise BadRequest(f"Invalid regex: {str(error)}")