arize-phoenix 10.15.0__py3-none-any.whl → 11.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of arize-phoenix might be problematic.
- {arize_phoenix-10.15.0.dist-info → arize_phoenix-11.1.0.dist-info}/METADATA +2 -2
- {arize_phoenix-10.15.0.dist-info → arize_phoenix-11.1.0.dist-info}/RECORD +77 -46
- phoenix/config.py +5 -2
- phoenix/datetime_utils.py +8 -1
- phoenix/db/bulk_inserter.py +40 -1
- phoenix/db/facilitator.py +263 -4
- phoenix/db/insertion/helpers.py +15 -0
- phoenix/db/insertion/span.py +3 -1
- phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
- phoenix/db/models.py +267 -9
- phoenix/db/types/token_price_customization.py +29 -0
- phoenix/server/api/context.py +38 -4
- phoenix/server/api/dataloaders/__init__.py +41 -5
- phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
- phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
- phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +58 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project.py +140 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
- phoenix/server/api/dataloaders/span_costs.py +35 -0
- phoenix/server/api/dataloaders/types.py +29 -0
- phoenix/server/api/helpers/playground_clients.py +103 -12
- phoenix/server/api/input_types/ProjectSessionSort.py +3 -0
- phoenix/server/api/input_types/SpanSort.py +17 -0
- phoenix/server/api/mutations/__init__.py +2 -0
- phoenix/server/api/mutations/chat_mutations.py +17 -0
- phoenix/server/api/mutations/model_mutations.py +208 -0
- phoenix/server/api/queries.py +82 -41
- phoenix/server/api/routers/v1/traces.py +11 -4
- phoenix/server/api/subscriptions.py +36 -2
- phoenix/server/api/types/CostBreakdown.py +15 -0
- phoenix/server/api/types/Experiment.py +59 -1
- phoenix/server/api/types/ExperimentRun.py +58 -4
- phoenix/server/api/types/GenerativeModel.py +143 -2
- phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
- phoenix/server/api/types/ModelInterface.py +11 -0
- phoenix/server/api/types/PlaygroundModel.py +10 -0
- phoenix/server/api/types/Project.py +42 -0
- phoenix/server/api/types/ProjectSession.py +44 -0
- phoenix/server/api/types/Span.py +137 -0
- phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
- phoenix/server/api/types/SpanCostSummary.py +10 -0
- phoenix/server/api/types/TokenPrice.py +16 -0
- phoenix/server/api/types/TokenUsage.py +3 -3
- phoenix/server/api/types/Trace.py +41 -0
- phoenix/server/app.py +59 -0
- phoenix/server/cost_tracking/cost_details_calculator.py +190 -0
- phoenix/server/cost_tracking/cost_model_lookup.py +151 -0
- phoenix/server/cost_tracking/helpers.py +68 -0
- phoenix/server/cost_tracking/model_cost_manifest.json +59 -329
- phoenix/server/cost_tracking/regex_specificity.py +397 -0
- phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
- phoenix/server/daemons/__init__.py +0 -0
- phoenix/server/daemons/generative_model_store.py +51 -0
- phoenix/server/daemons/span_cost_calculator.py +103 -0
- phoenix/server/dml_event_handler.py +1 -0
- phoenix/server/static/.vite/manifest.json +36 -36
- phoenix/server/static/assets/components-BQWqzM6Z.js +5055 -0
- phoenix/server/static/assets/{index-DIlhmbjB.js → index-t6f0PRIo.js} +13 -13
- phoenix/server/static/assets/{pages-YX47cEoQ.js → pages-B8Uyb2qa.js} +818 -422
- phoenix/server/static/assets/{vendor-DCZoBorz.js → vendor-DqQvHbPa.js} +147 -147
- phoenix/server/static/assets/{vendor-arizeai-Ckci3irT.js → vendor-arizeai-CLX44PFA.js} +1 -1
- phoenix/server/static/assets/{vendor-codemirror-BODM513D.js → vendor-codemirror-Du3XyJnB.js} +1 -1
- phoenix/server/static/assets/{vendor-recharts-C9O2a-N3.js → vendor-recharts-B2PJDrnX.js} +25 -25
- phoenix/server/static/assets/{vendor-shiki-Dq54rRC7.js → vendor-shiki-CNbrFjf9.js} +1 -1
- phoenix/version.py +1 -1
- phoenix/server/cost_tracking/cost_lookup.py +0 -255
- phoenix/server/static/assets/components-SpUMF1qV.js +0 -4509
- {arize_phoenix-10.15.0.dist-info → arize_phoenix-11.1.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-10.15.0.dist-info → arize_phoenix-11.1.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-10.15.0.dist-info → arize_phoenix-11.1.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-10.15.0.dist-info → arize_phoenix-11.1.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/dataloaders/span_costs.py
@@ -0,0 +1,35 @@
+from typing import Optional
+
+from sqlalchemy import select
+from sqlalchemy.orm import joinedload, load_only
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+from phoenix.server.types import DbSessionFactory
+
+SpanID: TypeAlias = int
+Key: TypeAlias = SpanID
+Result: TypeAlias = Optional[models.SpanCost]
+
+
+class SpanCostsDataLoader(DataLoader[Key, Result]):
+    def __init__(self, db: DbSessionFactory) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: list[Key]) -> list[Result]:
+        span_ids = list(set(keys))
+        async with self._db() as session:
+            costs = {
+                span.id: span.span_cost
+                async for span in await session.stream_scalars(
+                    select(models.Span)
+                    .where(models.Span.id.in_(span_ids))
+                    .options(
+                        load_only(models.Span.id),
+                        joinedload(models.Span.span_cost),
+                    )
+                )
+            }
+        return [costs.get(span_id) for span_id in keys]
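
Not part of the diff: a minimal usage sketch of the new SpanCostsDataLoader, assuming a configured DbSessionFactory. Concurrent load() calls are batched into a single _load_fn invocation, keys are deduplicated before the query, and results come back in key order with None for spans that have no cost.

import asyncio

async def demo(db):  # `db` is assumed to be a DbSessionFactory wired up elsewhere
    loader = SpanCostsDataLoader(db)
    # Concurrent load() calls in the same tick are coalesced into one
    # _load_fn([...]) call; duplicate keys are deduplicated before querying.
    costs = await asyncio.gather(loader.load(1), loader.load(2), loader.load(1))
    # Results align positionally with the requested keys; missing costs are None.
    print([cost.total_cost if cost else None for cost in costs])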
phoenix/server/api/dataloaders/types.py
@@ -0,0 +1,29 @@
+from dataclasses import dataclass, field
+from functools import cached_property
+from typing import Optional
+
+
+@dataclass(frozen=True)
+class CostBreakdown:
+    tokens: Optional[float] = None
+    cost: Optional[float] = None
+
+    @cached_property
+    def cost_per_token(self) -> Optional[float]:
+        if self.tokens and self.cost:
+            return self.cost / self.tokens
+        return None
+
+
+@dataclass(frozen=True)
+class SpanCostSummary:
+    prompt: CostBreakdown = field(default_factory=CostBreakdown)
+    completion: CostBreakdown = field(default_factory=CostBreakdown)
+    total: CostBreakdown = field(default_factory=CostBreakdown)
+
+
+@dataclass(frozen=True)
+class SpanCostDetailSummaryEntry:
+    token_type: str
+    is_prompt: bool
+    value: CostBreakdown = field(default_factory=CostBreakdown)
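
Not part of the diff: a small worked example of the dataclasses above with made-up numbers. cost_per_token is simply cost divided by tokens and is None whenever either value is missing or zero.

prompt = CostBreakdown(tokens=1_000, cost=0.0025)      # illustrative values
completion = CostBreakdown(tokens=200, cost=0.002)

assert prompt.cost_per_token == 0.0025 / 1_000          # 2.5e-06 per token
assert CostBreakdown().cost_per_token is None           # nothing recorded

summary = SpanCostSummary(prompt=prompt, completion=completion)
entry = SpanCostDetailSummaryEntry(token_type="input", is_prompt=True, value=prompt)
print(summary.prompt.cost_per_token, entry.value.cost)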
phoenix/server/api/helpers/playground_clients.py
@@ -463,6 +463,35 @@ class OpenAIBaseStreamingClient(PlaygroundStreamingClient):
         yield LLM_TOKEN_COUNT_COMPLETION, usage.completion_tokens
         yield LLM_TOKEN_COUNT_TOTAL, usage.total_tokens
 
+        if hasattr(usage, "prompt_tokens_details") and usage.prompt_tokens_details is not None:
+            prompt_details = usage.prompt_tokens_details
+            if (
+                hasattr(prompt_details, "cached_tokens")
+                and prompt_details.cached_tokens is not None
+            ):
+                yield LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ, prompt_details.cached_tokens
+            if hasattr(prompt_details, "audio_tokens") and prompt_details.audio_tokens is not None:
+                yield LLM_TOKEN_COUNT_PROMPT_DETAILS_AUDIO, prompt_details.audio_tokens
+
+        if (
+            hasattr(usage, "completion_tokens_details")
+            and usage.completion_tokens_details is not None
+        ):
+            completion_details = usage.completion_tokens_details
+            if (
+                hasattr(completion_details, "reasoning_tokens")
+                and completion_details.reasoning_tokens is not None
+            ):
+                yield (
+                    LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING,
+                    completion_details.reasoning_tokens,
+                )
+            if (
+                hasattr(completion_details, "audio_tokens")
+                and completion_details.audio_tokens is not None
+            ):
+                yield LLM_TOKEN_COUNT_COMPLETION_DETAILS_AUDIO, completion_details.audio_tokens
+
 
     def _get_credential_value(
         credentials: Optional[list[PlaygroundClientCredential]], env_var_name: str

@@ -1115,13 +1144,20 @@ class OpenAIStreamingClient(OpenAIBaseStreamingClient):
     provider_key=GenerativeProviderKey.OPENAI,
     model_names=[
         "o1",
+        "o1-pro",
         "o1-2024-12-17",
+        "o1-pro-2025-03-19",
         "o1-mini",
         "o1-mini-2024-09-12",
         "o1-preview",
         "o1-preview-2024-09-12",
+        "o3",
+        "o3-pro",
+        "o3-2025-04-16",
         "o3-mini",
         "o3-mini-2025-01-31",
+        "o4-mini",
+        "o4-mini-2025-04-16",
     ],
 )
 class OpenAIReasoningStreamingClient(OpenAIStreamingClient):

@@ -1258,6 +1294,35 @@ class OpenAIReasoningStreamingClient(OpenAIStreamingClient):
         yield LLM_TOKEN_COUNT_COMPLETION, usage.completion_tokens
         yield LLM_TOKEN_COUNT_TOTAL, usage.total_tokens
 
+        if hasattr(usage, "prompt_tokens_details") and usage.prompt_tokens_details is not None:
+            prompt_details = usage.prompt_tokens_details
+            if (
+                hasattr(prompt_details, "cached_tokens")
+                and prompt_details.cached_tokens is not None
+            ):
+                yield LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ, prompt_details.cached_tokens
+            if hasattr(prompt_details, "audio_tokens") and prompt_details.audio_tokens is not None:
+                yield LLM_TOKEN_COUNT_PROMPT_DETAILS_AUDIO, prompt_details.audio_tokens
+
+        if (
+            hasattr(usage, "completion_tokens_details")
+            and usage.completion_tokens_details is not None
+        ):
+            completion_details = usage.completion_tokens_details
+            if (
+                hasattr(completion_details, "reasoning_tokens")
+                and completion_details.reasoning_tokens is not None
+            ):
+                yield (
+                    LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING,
+                    completion_details.reasoning_tokens,
+                )
+            if (
+                hasattr(completion_details, "audio_tokens")
+                and completion_details.audio_tokens is not None
+            ):
+                yield LLM_TOKEN_COUNT_COMPLETION_DETAILS_AUDIO, completion_details.audio_tokens
+
 
 @register_llm_client(
     provider_key=GenerativeProviderKey.AZURE_OPENAI,

@@ -1315,12 +1380,6 @@ class AzureOpenAIStreamingClient(OpenAIBaseStreamingClient):
     provider_key=GenerativeProviderKey.ANTHROPIC,
     model_names=[
         PROVIDER_DEFAULT,
-        "claude-sonnet-4-0",
-        "claude-sonnet-4-20250514",
-        "claude-opus-4-0",
-        "claude-opus-4-20250514",
-        "claude-3-7-sonnet-latest",
-        "claude-3-7-sonnet-20250219",
         "claude-3-5-sonnet-latest",
         "claude-3-5-haiku-latest",
         "claude-3-5-sonnet-20241022",

@@ -1421,15 +1480,34 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
         async with await throttled_stream(**anthropic_params) as stream:
             async for event in stream:
                 if isinstance(event, anthropic_types.RawMessageStartEvent):
-
-
-
+                    usage = event.message.usage
+
+                    token_counts: dict[str, Any] = {}
+                    if prompt_tokens := (
+                        (usage.input_tokens or 0)
+                        + (getattr(usage, "cache_creation_input_tokens", 0) or 0)
+                        + (getattr(usage, "cache_read_input_tokens", 0) or 0)
+                    ):
+                        token_counts[LLM_TOKEN_COUNT_PROMPT] = prompt_tokens
+                    if cache_creation_tokens := getattr(usage, "cache_creation_input_tokens", None):
+                        if cache_creation_tokens is not None:
+                            token_counts[LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_WRITE] = (
+                                cache_creation_tokens
+                            )
+                    self._attributes.update(token_counts)
                 elif isinstance(event, anthropic_streaming.TextEvent):
                     yield TextChunk(content=event.text)
                 elif isinstance(event, anthropic_streaming.MessageStopEvent):
-
-
-
+                    usage = event.message.usage
+                    output_token_counts: dict[str, Any] = {}
+                    if usage.output_tokens:
+                        output_token_counts[LLM_TOKEN_COUNT_COMPLETION] = usage.output_tokens
+                    if cache_read_tokens := getattr(usage, "cache_read_input_tokens", None):
+                        if cache_read_tokens is not None:
+                            output_token_counts[LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ] = (
+                                cache_read_tokens
+                            )
+                    self._attributes.update(output_token_counts)
                 elif (
                     isinstance(event, anthropic_streaming.ContentBlockStopEvent)
                     and event.content_block.type == "tool_use"

@@ -1514,6 +1592,10 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
 @register_llm_client(
     provider_key=GenerativeProviderKey.ANTHROPIC,
     model_names=[
+        "claude-sonnet-4-0",
+        "claude-sonnet-4-20250514",
+        "claude-opus-4-0",
+        "claude-opus-4-20250514",
         "claude-3-7-sonnet-latest",
         "claude-3-7-sonnet-20250219",
     ],

@@ -1698,6 +1780,15 @@ LLM_SYSTEM = SpanAttributes.LLM_SYSTEM
 LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
 LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
 LLM_TOKEN_COUNT_TOTAL = SpanAttributes.LLM_TOKEN_COUNT_TOTAL
+LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ = SpanAttributes.LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ
+LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_WRITE = (
+    SpanAttributes.LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_WRITE
+)
+LLM_TOKEN_COUNT_PROMPT_DETAILS_AUDIO = SpanAttributes.LLM_TOKEN_COUNT_PROMPT_DETAILS_AUDIO
+LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING = (
+    SpanAttributes.LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING
+)
+LLM_TOKEN_COUNT_COMPLETION_DETAILS_AUDIO = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION_DETAILS_AUDIO
 
 
 class _HttpxClient(wrapt.ObjectProxy):  # type: ignore
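
Not part of the diff: a worked illustration (made-up numbers) of how the AnthropicStreamingClient hunk above now attributes prompt tokens. The prompt total sums regular input tokens with cache-write and cache-read tokens, and the cache counts are additionally recorded under the new detail attributes.

input_tokens = 120                   # regular (uncached) prompt tokens
cache_creation_input_tokens = 500    # tokens written to the prompt cache
cache_read_input_tokens = 2_000      # tokens served from the prompt cache

prompt_total = input_tokens + cache_creation_input_tokens + cache_read_input_tokens
assert prompt_total == 2_620         # recorded as LLM_TOKEN_COUNT_PROMPT
# cache_creation_input_tokens -> LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_WRITE
# cache_read_input_tokens     -> LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ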
phoenix/server/api/input_types/ProjectSessionSort.py
@@ -13,6 +13,7 @@ class ProjectSessionColumn(Enum):
     endTime = auto()
     tokenCountTotal = auto()
     numTraces = auto()
+    costTotal = auto()
 
     @property
     def data_type(self) -> CursorSortColumnDataType:
@@ -20,6 +21,8 @@ class ProjectSessionColumn(Enum):
             return CursorSortColumnDataType.INT
         if self is ProjectSessionColumn.startTime or self is ProjectSessionColumn.endTime:
             return CursorSortColumnDataType.DATETIME
+        if self is ProjectSessionColumn.costTotal:
+            return CursorSortColumnDataType.FLOAT
         assert_never(self)
 
 
phoenix/server/api/input_types/SpanSort.py
@@ -27,6 +27,7 @@ class SpanColumn(Enum):
     cumulativeTokenCountTotal = auto()
     cumulativeTokenCountPrompt = auto()
     cumulativeTokenCountCompletion = auto()
+    tokenCostTotal = auto()
 
     @property
     def column_name(self) -> str:
@@ -56,6 +57,8 @@ class SpanColumn(Enum):
             expr = models.Span.cumulative_llm_token_count_prompt
         elif self is SpanColumn.cumulativeTokenCountCompletion:
            expr = models.Span.cumulative_llm_token_count_completion
+        elif self is SpanColumn.tokenCostTotal:
+            expr = models.SpanCost.total_cost
         else:
             assert_never(self)
         return expr.label(self.column_name)
@@ -73,12 +76,25 @@ class SpanColumn(Enum):
             or self is SpanColumn.tokenCountTotal
             or self is SpanColumn.tokenCountPrompt
             or self is SpanColumn.tokenCountCompletion
+            or self is SpanColumn.tokenCostTotal
         ):
             return CursorSortColumnDataType.FLOAT
         if self is SpanColumn.startTime or self is SpanColumn.endTime:
             return CursorSortColumnDataType.DATETIME
         assert_never(self)
 
+    def join_tables(self, stmt: Select[Any]) -> Select[Any]:
+        """
+        If needed, joins tables required for the sort column.
+        """
+        if self is SpanColumn.tokenCostTotal:
+            return stmt.join_from(
+                models.Span,
+                models.SpanCost,
+                onclause=models.SpanCost.span_rowid == models.Span.id,
+            )
+        return stmt
+
 
 @strawberry.enum
 class EvalAttr(Enum):
@@ -140,6 +156,7 @@ class SpanSort:
     def update_orm_expr(self, stmt: Select[Any]) -> SpanSortConfig:
         if (col := self.col) and not self.eval_result_key:
             expr = col.orm_expression
+            stmt = col.join_tables(stmt)
             stmt = stmt.add_columns(expr)
             if self.dir == SortDir.desc:
                 expr = desc(expr)
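
Not part of the diff: sorting spans by cost requires a join that no other sort column needs, which is why join_tables was added and threaded through update_orm_expr. A rough sketch of the statement shape for SpanColumn.tokenCostTotal, assuming the models referenced above (the real statement is assembled by SpanSort and may differ in detail):

from sqlalchemy import desc, select

# Sketch only: mirrors the join and labeled cost column produced by the sort.
stmt = (
    select(models.Span)
    .join_from(
        models.Span,
        models.SpanCost,
        onclause=models.SpanCost.span_rowid == models.Span.id,
    )
    .add_columns(models.SpanCost.total_cost.label("tokenCostTotal"))
    .order_by(desc(models.SpanCost.total_cost))
)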
phoenix/server/api/mutations/__init__.py
@@ -8,6 +8,7 @@ from phoenix.server.api.mutations.chat_mutations import (
 from phoenix.server.api.mutations.dataset_mutations import DatasetMutationMixin
 from phoenix.server.api.mutations.experiment_mutations import ExperimentMutationMixin
 from phoenix.server.api.mutations.export_events_mutations import ExportEventsMutationMixin
+from phoenix.server.api.mutations.model_mutations import ModelMutationMixin
 from phoenix.server.api.mutations.project_mutations import ProjectMutationMixin
 from phoenix.server.api.mutations.project_trace_retention_policy_mutations import (
     ProjectTraceRetentionPolicyMutationMixin,
@@ -29,6 +30,7 @@ class Mutation(
     DatasetMutationMixin,
     ExperimentMutationMixin,
     ExportEventsMutationMixin,
+    ModelMutationMixin,
     ProjectMutationMixin,
     ProjectTraceRetentionPolicyMutationMixin,
     PromptMutationMixin,
phoenix/server/api/mutations/chat_mutations.py
@@ -1,4 +1,5 @@
 import asyncio
+import logging
 from dataclasses import asdict, field
 from datetime import datetime, timezone
 from itertools import chain, islice
@@ -73,6 +74,8 @@ from phoenix.utilities.template_formatters import (
     TemplateFormatter,
 )
 
+logger = logging.getLogger(__name__)
+
 initialize_playground_clients()
 
 ChatCompletionMessage = tuple[ChatCompletionMessageRole, str, Optional[str], Optional[List[Any]]]
@@ -450,6 +453,19 @@ class ChatCompletionMutationMixin:
             session.add(trace)
             session.add(span)
             await session.flush()
+            try:
+                span_cost = info.context.span_cost_calculator.calculate_cost(
+                    start_time=span.start_time,
+                    attributes=span.attributes,
+                )
+            except Exception as e:
+                logger.exception(f"Failed to calculate cost for span {span.id}: {e}")
+                span_cost = None
+            if span_cost:
+                span_cost.span_rowid = span.id
+                span_cost.trace_rowid = trace.id
+                session.add(span_cost)
+                await session.flush()
 
             gql_span = Span(span_rowid=span.id, db_span=span)
 
@@ -605,5 +621,6 @@ TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUME
 TOOL_JSON_SCHEMA = ToolAttributes.TOOL_JSON_SCHEMA
 PROMPT_TEMPLATE_VARIABLES = SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES
 
+LLM_PROVIDER = SpanAttributes.LLM_PROVIDER
 
 PLAYGROUND_PROJECT_NAME = "playground"
phoenix/server/api/mutations/model_mutations.py
@@ -0,0 +1,208 @@
+import re
+from datetime import datetime, timezone
+from typing import Optional
+
+import sqlalchemy as sa
+import strawberry
+from sqlalchemy import delete
+from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
+from sqlalchemy.orm import joinedload
+from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError  # type: ignore[import-untyped]
+from strawberry.relay import GlobalID
+from strawberry.types import Info
+
+from phoenix.db import models
+from phoenix.server.api.auth import IsNotReadOnly
+from phoenix.server.api.context import Context
+from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
+from phoenix.server.api.queries import Query
+from phoenix.server.api.types.GenerativeModel import GenerativeModel, to_gql_generative_model
+from phoenix.server.api.types.node import from_global_id_with_expected_type
+from phoenix.server.api.types.TokenPrice import TokenKind
+
+
+@strawberry.input
+class TokenPriceInput:
+    token_type: str
+    cost_per_million_tokens: float
+    kind: TokenKind
+
+    @property
+    def token_prices(self) -> models.TokenPrice:
+        """Generate TokenPrice instances based on the input."""
+        return models.TokenPrice(
+            token_type=self.token_type,
+            is_prompt=self.kind == TokenKind.PROMPT,
+            base_rate=self.cost_per_million_tokens / 1_000_000,
+        )
+
+
+@strawberry.input
+class CreateModelMutationInput:
+    name: str
+    provider: Optional[str] = None
+    name_pattern: str
+    costs: list[TokenPriceInput]
+    start_time: Optional[datetime] = None
+
+
+@strawberry.type
+class CreateModelMutationPayload:
+    model: GenerativeModel
+    query: Query
+
+
+@strawberry.input
+class UpdateModelMutationInput:
+    id: GlobalID
+    name: str
+    provider: Optional[str]
+    name_pattern: str
+    costs: list[TokenPriceInput]
+    start_time: Optional[datetime] = None
+
+
+@strawberry.type
+class UpdateModelMutationPayload:
+    model: GenerativeModel
+    query: Query
+
+
+@strawberry.input
+class DeleteModelMutationInput:
+    id: GlobalID
+
+
+@strawberry.type
+class DeleteModelMutationPayload:
+    model: GenerativeModel
+    query: Query
+
+
+@strawberry.type
+class ModelMutationMixin:
+    @strawberry.mutation(permission_classes=[IsNotReadOnly])  # type: ignore
+    async def create_model(
+        self,
+        info: Info[Context, None],
+        input: CreateModelMutationInput,
+    ) -> CreateModelMutationPayload:
+        cost_types = set(cost.token_type for cost in input.costs)
+        if "input" not in cost_types:
+            raise BadRequest("input cost is required")
+        if "output" not in cost_types:
+            raise BadRequest("output cost is required")
+        name_pattern = _compile_regular_expression(input.name_pattern)
+        token_prices = [cost.token_prices for cost in input.costs]
+        model = models.GenerativeModel(
+            name=input.name,
+            provider=input.provider,
+            name_pattern=name_pattern,
+            is_built_in=False,
+            token_prices=token_prices,
+            start_time=input.start_time,
+        )
+        async with info.context.db() as session:
+            session.add(model)
+            try:
+                await session.flush()
+            except (PostgreSQLIntegrityError, SQLiteIntegrityError):
+                raise Conflict(f"Model with name '{input.name}' already exists")
+
+        return CreateModelMutationPayload(
+            model=to_gql_generative_model(model),
+            query=Query(),
+        )
+
+    @strawberry.mutation(permission_classes=[IsNotReadOnly])  # type: ignore
+    async def update_model(
+        self,
+        info: Info[Context, None],
+        input: UpdateModelMutationInput,
+    ) -> UpdateModelMutationPayload:
+        try:
+            model_id = from_global_id_with_expected_type(input.id, GenerativeModel.__name__)
+        except ValueError:
+            raise BadRequest(f'Invalid model id: "{input.id}"')
+
+        cost_types = set(cost.token_type for cost in input.costs)
+        if "input" not in cost_types:
+            raise BadRequest("input cost is required")
+        if "output" not in cost_types:
+            raise BadRequest("output cost is required")
+        name_pattern = _compile_regular_expression(input.name_pattern)
+        token_prices = [cost.token_prices for cost in input.costs]
+        async with info.context.db() as session:
+            model = await session.scalar(
+                sa.select(models.GenerativeModel)
+                .where(models.GenerativeModel.deleted_at.is_(None))
+                .where(models.GenerativeModel.id == model_id)
+                .options(joinedload(models.GenerativeModel.token_prices))
+            )
+            if model is None:
+                raise NotFound(f'Model "{input.id}" not found')
+            if model.is_built_in:
+                raise BadRequest("Cannot update built-in model")
+
+            await session.execute(
+                delete(models.TokenPrice).where(models.TokenPrice.model_id == model.id)
+            )
+
+            await session.refresh(model)
+
+            model.name = input.name
+            model.provider = input.provider or ""
+            model.name_pattern = name_pattern
+            model.token_prices = token_prices
+            model.start_time = input.start_time
+            session.add(model)
+            try:
+                await session.flush()
+            except (PostgreSQLIntegrityError, SQLiteIntegrityError):
+                raise Conflict(f"Model with name '{input.name}' already exists")
+            await session.refresh(model)
+
+        return UpdateModelMutationPayload(
+            model=to_gql_generative_model(model),
+            query=Query(),
+        )
+
+    @strawberry.mutation(permission_classes=[IsNotReadOnly])  # type: ignore
+    async def delete_model(
+        self,
+        info: Info[Context, None],
+        input: DeleteModelMutationInput,
+    ) -> DeleteModelMutationPayload:
+        try:
+            model_id = from_global_id_with_expected_type(input.id, GenerativeModel.__name__)
+        except ValueError:
+            raise BadRequest(f'Invalid model id: "{input.id}"')
+
+        async with info.context.db() as session:
+            model = await session.scalar(
+                sa.update(models.GenerativeModel)
+                .values(deleted_at=datetime.now(timezone.utc))
+                .where(models.GenerativeModel.deleted_at.is_(None))
+                .where(models.GenerativeModel.id == model_id)
+                .returning(models.GenerativeModel)
+            )
+            if model is None:
+                raise NotFound(f'Model "{input.id}" not found')
+            if model.is_built_in:
+                await session.rollback()
+                raise BadRequest("Cannot delete built-in model")
+            return DeleteModelMutationPayload(
+                model=to_gql_generative_model(model),
+                query=Query(),
+            )
+
+
+def _compile_regular_expression(maybe_regex: str) -> re.Pattern[str]:
+    """
+    Compile the given string as a regular expression.
+    Raises a BadRequest error if the given string is not a valid regex.
+    """
+    try:
+        return re.compile(maybe_regex)
+    except re.error as error:
+        raise BadRequest(f"Invalid regex: {str(error)}")
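
Not part of the diff: a sketch of how the createModel pricing inputs fit together, using made-up prices. TokenPriceInput takes a price per million tokens and converts it to a per-token base_rate; create_model rejects inputs that do not include both an "input" and an "output" cost.

prompt_price = TokenPriceInput(
    token_type="input",            # create_model requires an "input" cost entry
    cost_per_million_tokens=15.0,  # e.g. $15.00 per 1M prompt tokens (illustrative)
    kind=TokenKind.PROMPT,
)
# Per-million pricing becomes a per-token base_rate on the TokenPrice row:
assert prompt_price.token_prices.base_rate == 15.0 / 1_000_000  # 0.000015 per token
# An "output" cost entry (built the same way with the completion-side TokenKind member)
# must also be supplied before CreateModelMutationInput(name=..., name_pattern=...,
# costs=[...]) passes validation; name_pattern is compiled as a regex and matched
# against model names seen on spans.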