arize-phoenix 10.14.0__py3-none-any.whl → 11.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (84)
  1. {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/METADATA +3 -2
  2. {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/RECORD +82 -50
  3. phoenix/config.py +5 -2
  4. phoenix/datetime_utils.py +8 -1
  5. phoenix/db/bulk_inserter.py +40 -1
  6. phoenix/db/facilitator.py +263 -4
  7. phoenix/db/insertion/helpers.py +15 -0
  8. phoenix/db/insertion/span.py +3 -1
  9. phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
  10. phoenix/db/models.py +267 -9
  11. phoenix/db/types/model_provider.py +1 -0
  12. phoenix/db/types/token_price_customization.py +29 -0
  13. phoenix/server/api/context.py +38 -4
  14. phoenix/server/api/dataloaders/__init__.py +41 -5
  15. phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
  16. phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
  17. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
  18. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
  19. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
  20. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
  21. phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
  22. phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +58 -0
  23. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
  24. phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
  25. phoenix/server/api/dataloaders/span_cost_summary_by_project.py +140 -0
  26. phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
  27. phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
  28. phoenix/server/api/dataloaders/span_costs.py +35 -0
  29. phoenix/server/api/dataloaders/types.py +29 -0
  30. phoenix/server/api/helpers/playground_clients.py +562 -12
  31. phoenix/server/api/helpers/prompts/conversions/aws.py +83 -0
  32. phoenix/server/api/helpers/prompts/models.py +67 -0
  33. phoenix/server/api/input_types/GenerativeModelInput.py +2 -0
  34. phoenix/server/api/input_types/ProjectSessionSort.py +3 -0
  35. phoenix/server/api/input_types/SpanSort.py +17 -0
  36. phoenix/server/api/mutations/__init__.py +2 -0
  37. phoenix/server/api/mutations/chat_mutations.py +17 -0
  38. phoenix/server/api/mutations/model_mutations.py +208 -0
  39. phoenix/server/api/queries.py +82 -41
  40. phoenix/server/api/routers/v1/traces.py +11 -4
  41. phoenix/server/api/subscriptions.py +36 -2
  42. phoenix/server/api/types/CostBreakdown.py +15 -0
  43. phoenix/server/api/types/Experiment.py +59 -1
  44. phoenix/server/api/types/ExperimentRun.py +58 -4
  45. phoenix/server/api/types/GenerativeModel.py +143 -2
  46. phoenix/server/api/types/GenerativeProvider.py +33 -20
  47. phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
  48. phoenix/server/api/types/ModelInterface.py +11 -0
  49. phoenix/server/api/types/PlaygroundModel.py +10 -0
  50. phoenix/server/api/types/Project.py +42 -0
  51. phoenix/server/api/types/ProjectSession.py +44 -0
  52. phoenix/server/api/types/Span.py +137 -0
  53. phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
  54. phoenix/server/api/types/SpanCostSummary.py +10 -0
  55. phoenix/server/api/types/TokenPrice.py +16 -0
  56. phoenix/server/api/types/TokenUsage.py +3 -3
  57. phoenix/server/api/types/Trace.py +41 -0
  58. phoenix/server/app.py +59 -0
  59. phoenix/server/cost_tracking/cost_details_calculator.py +190 -0
  60. phoenix/server/cost_tracking/cost_model_lookup.py +151 -0
  61. phoenix/server/cost_tracking/helpers.py +68 -0
  62. phoenix/server/cost_tracking/model_cost_manifest.json +59 -329
  63. phoenix/server/cost_tracking/regex_specificity.py +397 -0
  64. phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
  65. phoenix/server/daemons/__init__.py +0 -0
  66. phoenix/server/daemons/generative_model_store.py +51 -0
  67. phoenix/server/daemons/span_cost_calculator.py +103 -0
  68. phoenix/server/dml_event_handler.py +1 -0
  69. phoenix/server/static/.vite/manifest.json +36 -36
  70. phoenix/server/static/assets/components-BnK9kodr.js +5055 -0
  71. phoenix/server/static/assets/{index-qiubV_74.js → index-S3YKLmbo.js} +13 -13
  72. phoenix/server/static/assets/{pages-C4V07ozl.js → pages-BW6PBHZb.js} +809 -417
  73. phoenix/server/static/assets/{vendor-Bfsiga8H.js → vendor-DqQvHbPa.js} +147 -147
  74. phoenix/server/static/assets/{vendor-arizeai-CQOWsrzm.js → vendor-arizeai-CLX44PFA.js} +1 -1
  75. phoenix/server/static/assets/{vendor-codemirror-CrcGVhB2.js → vendor-codemirror-Du3XyJnB.js} +1 -1
  76. phoenix/server/static/assets/{vendor-recharts-Yyg3G-Rq.js → vendor-recharts-B2PJDrnX.js} +25 -25
  77. phoenix/server/static/assets/{vendor-shiki-OPjag7Hm.js → vendor-shiki-CNbrFjf9.js} +1 -1
  78. phoenix/version.py +1 -1
  79. phoenix/server/cost_tracking/cost_lookup.py +0 -255
  80. phoenix/server/static/assets/components-CUUWyAMo.js +0 -4509
  81. {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/WHEEL +0 -0
  82. {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/entry_points.txt +0 -0
  83. {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/licenses/IP_NOTICE +0 -0
  84. {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/types/GenerativeModel.py +143 -2
@@ -1,9 +1,150 @@
1
+ from datetime import datetime
2
+ from enum import Enum
3
+ from typing import Optional
4
+
1
5
  import strawberry
6
+ from openinference.semconv.trace import OpenInferenceLLMProviderValues
7
+ from sqlalchemy import inspect
8
+ from strawberry.relay import Node, NodeID
9
+ from strawberry.types import Info
10
+ from typing_extensions import assert_never
2
11
 
12
+ from phoenix.db import models
13
+ from phoenix.server.api.context import Context
14
+ from phoenix.server.api.types.CostBreakdown import CostBreakdown
3
15
  from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
16
+ from phoenix.server.api.types.ModelInterface import ModelInterface
17
+ from phoenix.server.api.types.SpanCostDetailSummaryEntry import SpanCostDetailSummaryEntry
18
+ from phoenix.server.api.types.SpanCostSummary import SpanCostSummary
19
+ from phoenix.server.api.types.TokenPrice import TokenKind, TokenPrice
20
+
21
+
22
+ @strawberry.enum
23
+ class GenerativeModelKind(Enum):
24
+ CUSTOM = "CUSTOM"
25
+ BUILT_IN = "BUILT_IN"
4
26
 
5
27
 
6
28
  @strawberry.type
7
- class GenerativeModel:
29
+ class GenerativeModel(Node, ModelInterface):
30
+ id_attr: NodeID[int]
8
31
  name: str
9
- provider_key: GenerativeProviderKey
32
+ provider: Optional[str]
33
+ name_pattern: str
34
+ kind: GenerativeModelKind
35
+ created_at: datetime
36
+ updated_at: datetime
37
+ provider_key: Optional[GenerativeProviderKey]
38
+ costs: strawberry.Private[Optional[list[models.TokenPrice]]] = None
39
+ start_time: Optional[datetime] = None
40
+
41
+ @strawberry.field
42
+ async def token_prices(self) -> list[TokenPrice]:
43
+ if self.costs is None:
44
+ raise NotImplementedError
45
+ token_prices: list[TokenPrice] = list()
46
+ for cost in self.costs:
47
+ token_prices.append(
48
+ TokenPrice(
49
+ token_type=cost.token_type,
50
+ kind=TokenKind.PROMPT if cost.is_prompt else TokenKind.COMPLETION,
51
+ cost_per_million_tokens=cost.base_rate * 1_000_000,
52
+ cost_per_token=cost.base_rate,
53
+ )
54
+ )
55
+ return token_prices
56
+
57
+ @strawberry.field
58
+ async def cost_summary(self, info: Info[Context, None]) -> SpanCostSummary:
59
+ loader = info.context.data_loaders.span_cost_summary_by_generative_model
60
+ summary = await loader.load(self.id_attr)
61
+ return SpanCostSummary(
62
+ prompt=CostBreakdown(
63
+ tokens=summary.prompt.tokens,
64
+ cost=summary.prompt.cost,
65
+ ),
66
+ completion=CostBreakdown(
67
+ tokens=summary.completion.tokens,
68
+ cost=summary.completion.cost,
69
+ ),
70
+ total=CostBreakdown(
71
+ tokens=summary.total.tokens,
72
+ cost=summary.total.cost,
73
+ ),
74
+ )
75
+
76
+ @strawberry.field
77
+ async def cost_detail_summary_entries(
78
+ self,
79
+ info: Info[Context, None],
80
+ ) -> list[SpanCostDetailSummaryEntry]:
81
+ loader = info.context.data_loaders.span_cost_detail_summary_entries_by_generative_model
82
+ summary = await loader.load(self.id_attr)
83
+ return [
84
+ SpanCostDetailSummaryEntry(
85
+ token_type=entry.token_type,
86
+ is_prompt=entry.is_prompt,
87
+ value=CostBreakdown(
88
+ tokens=entry.value.tokens,
89
+ cost=entry.value.cost,
90
+ ),
91
+ )
92
+ for entry in summary
93
+ ]
94
+
95
+ @strawberry.field
96
+ async def last_used_at(self, info: Info[Context, None]) -> Optional[datetime]:
97
+ model_id = self.id_attr
98
+ return await info.context.data_loaders.last_used_times_by_generative_model_id.load(model_id)
99
+
100
+
101
+ def to_gql_generative_model(model: models.GenerativeModel) -> GenerativeModel:
102
+ costs_are_loaded = isinstance(inspect(model).attrs.token_prices.loaded_value, list)
103
+ name_pattern = model.name_pattern.pattern
104
+ assert isinstance(name_pattern, str)
105
+ return GenerativeModel(
106
+ id_attr=model.id,
107
+ name=model.name,
108
+ provider=model.provider or None,
109
+ name_pattern=name_pattern,
110
+ kind=GenerativeModelKind.BUILT_IN if model.is_built_in else GenerativeModelKind.CUSTOM,
111
+ created_at=model.created_at,
112
+ updated_at=model.updated_at,
113
+ start_time=model.start_time,
114
+ provider_key=_semconv_provider_to_gql_generative_provider_key(model.provider)
115
+ if model.provider
116
+ else None,
117
+ costs=model.token_prices if costs_are_loaded else None,
118
+ )
119
+
120
+
121
+ def _semconv_provider_to_gql_generative_provider_key(
122
+ semconv_provider_str: str,
123
+ ) -> Optional[GenerativeProviderKey]:
124
+ """
125
+ Translates a semconv provider string to a GQL GenerativeProviderKey.
126
+ """
127
+
128
+ try:
129
+ semconv_provider = OpenInferenceLLMProviderValues(semconv_provider_str)
130
+ except Exception:
131
+ return None
132
+ if semconv_provider == OpenInferenceLLMProviderValues.OPENAI:
133
+ return GenerativeProviderKey.OPENAI
134
+ if semconv_provider == OpenInferenceLLMProviderValues.ANTHROPIC:
135
+ return GenerativeProviderKey.ANTHROPIC
136
+ if semconv_provider == OpenInferenceLLMProviderValues.AZURE:
137
+ return GenerativeProviderKey.AZURE_OPENAI
138
+ if semconv_provider == OpenInferenceLLMProviderValues.GOOGLE:
139
+ return GenerativeProviderKey.GOOGLE
140
+ if semconv_provider == OpenInferenceLLMProviderValues.DEEPSEEK:
141
+ return GenerativeProviderKey.DEEPSEEK
142
+ if semconv_provider == OpenInferenceLLMProviderValues.XAI:
143
+ return GenerativeProviderKey.XAI
144
+ if semconv_provider == OpenInferenceLLMProviderValues.AWS:
145
+ raise NotImplementedError("AWS models are not yet supported")
146
+ if semconv_provider == OpenInferenceLLMProviderValues.COHERE:
147
+ raise NotImplementedError("Cohere models are not yet supported")
148
+ if semconv_provider == OpenInferenceLLMProviderValues.MISTRALAI:
149
+ raise NotImplementedError("Mistral AI models are not yet supported")
150
+ assert_never(semconv_provider)
phoenix/server/api/types/GenerativeProvider.py +33 -20
@@ -17,6 +17,7 @@ class GenerativeProviderKey(Enum):
17
17
  DEEPSEEK = "DeepSeek"
18
18
  XAI = "xAI"
19
19
  OLLAMA = "Ollama"
20
+ AWS = "AWS Bedrock"
20
21
 
21
22
 
22
23
  @strawberry.type
@@ -38,6 +39,7 @@ class GenerativeProvider:
38
39
  GenerativeProviderKey.DEEPSEEK: ["deepseek"],
39
40
  GenerativeProviderKey.XAI: ["grok"],
40
41
  GenerativeProviderKey.OLLAMA: ["llama", "mistral", "codellama", "phi", "qwen", "gemma"],
42
+ GenerativeProviderKey.AWS: ["nova", "titan"],
41
43
  }
42
44
 
43
45
  attribute_provider_to_generative_provider_map: ClassVar[dict[str, GenerativeProviderKey]] = {
@@ -45,6 +47,7 @@ class GenerativeProvider:
45
47
  OpenInferenceLLMProviderValues.ANTHROPIC.value: GenerativeProviderKey.ANTHROPIC,
46
48
  OpenInferenceLLMProviderValues.AZURE.value: GenerativeProviderKey.AZURE_OPENAI,
47
49
  OpenInferenceLLMProviderValues.GOOGLE.value: GenerativeProviderKey.GOOGLE,
50
+ OpenInferenceLLMProviderValues.AWS.value: GenerativeProviderKey.AWS,
48
51
  # Note: DeepSeek uses OpenAI compatibility but we can't duplicate the key in the dict
49
52
  # The provider will be determined through model name prefix matching instead
50
53
  # Note: xAI uses OpenAI compatibility but we can't duplicate the key in the dict
@@ -58,26 +61,36 @@ class GenerativeProvider:
58
61
  E.x. OpenAI requires a single API key
59
62
  """
60
63
  model_provider_to_credential_requirements_map: ClassVar[
61
- dict[GenerativeProviderKey, GenerativeProviderCredentialConfig]
64
+ dict[GenerativeProviderKey, list[GenerativeProviderCredentialConfig]]
62
65
  ] = {
63
- GenerativeProviderKey.AZURE_OPENAI: GenerativeProviderCredentialConfig(
64
- env_var_name="AZURE_OPENAI_API_KEY", is_required=True
65
- ),
66
- GenerativeProviderKey.ANTHROPIC: GenerativeProviderCredentialConfig(
67
- env_var_name="ANTHROPIC_API_KEY", is_required=True
68
- ),
69
- GenerativeProviderKey.OPENAI: GenerativeProviderCredentialConfig(
70
- env_var_name="OPENAI_API_KEY", is_required=True
71
- ),
72
- GenerativeProviderKey.GOOGLE: GenerativeProviderCredentialConfig(
73
- env_var_name="GEMINI_API_KEY", is_required=True
74
- ),
75
- GenerativeProviderKey.DEEPSEEK: GenerativeProviderCredentialConfig(
76
- env_var_name="DEEPSEEK_API_KEY", is_required=True
77
- ),
78
- GenerativeProviderKey.XAI: GenerativeProviderCredentialConfig(
79
- env_var_name="XAI_API_KEY", is_required=True
80
- ),
66
+ GenerativeProviderKey.AZURE_OPENAI: [
67
+ GenerativeProviderCredentialConfig(
68
+ env_var_name="AZURE_OPENAI_API_KEY", is_required=True
69
+ )
70
+ ],
71
+ GenerativeProviderKey.ANTHROPIC: [
72
+ GenerativeProviderCredentialConfig(env_var_name="ANTHROPIC_API_KEY", is_required=True)
73
+ ],
74
+ GenerativeProviderKey.OPENAI: [
75
+ GenerativeProviderCredentialConfig(env_var_name="OPENAI_API_KEY", is_required=True)
76
+ ],
77
+ GenerativeProviderKey.GOOGLE: [
78
+ GenerativeProviderCredentialConfig(env_var_name="GEMINI_API_KEY", is_required=True)
79
+ ],
80
+ GenerativeProviderKey.DEEPSEEK: [
81
+ GenerativeProviderCredentialConfig(env_var_name="DEEPSEEK_API_KEY", is_required=True)
82
+ ],
83
+ GenerativeProviderKey.XAI: [
84
+ GenerativeProviderCredentialConfig(env_var_name="XAI_API_KEY", is_required=True)
85
+ ],
86
+ GenerativeProviderKey.OLLAMA: [],
87
+ GenerativeProviderKey.AWS: [
88
+ GenerativeProviderCredentialConfig(env_var_name="AWS_ACCESS_KEY_ID", is_required=True),
89
+ GenerativeProviderCredentialConfig(
90
+ env_var_name="AWS_SECRET_ACCESS_KEY", is_required=True
91
+ ),
92
+ GenerativeProviderCredentialConfig(env_var_name="AWS_SESSION_TOKEN", is_required=False),
93
+ ],
81
94
  }
82
95
 
83
96
  @strawberry.field
@@ -110,7 +123,7 @@ class GenerativeProvider:
110
123
  credential_requirements = self.model_provider_to_credential_requirements_map.get(self.key)
111
124
  if credential_requirements is None:
112
125
  return []
113
- return [credential_requirements]
126
+ return self.model_provider_to_credential_requirements_map[self.key]
114
127
 
115
128
  @strawberry.field(description="Whether the credentials are set on the server for the provider") # type: ignore
116
129
  async def credentials_set(self) -> bool:
phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
@@ -28,7 +28,7 @@ from .TimeSeries import (
28
28
 
29
29
 
30
30
  @strawberry.type
31
- class Model:
31
+ class InferenceModel:
32
32
  @strawberry.field
33
33
  def dimensions(
34
34
  self,
phoenix/server/api/types/ModelInterface.py +11 -0
@@ -0,0 +1,11 @@
1
+ from typing import Optional
2
+
3
+ import strawberry
4
+
5
+ from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
6
+
7
+
8
+ @strawberry.interface
9
+ class ModelInterface:
10
+ name: str
11
+ provider_key: Optional[GenerativeProviderKey]
phoenix/server/api/types/PlaygroundModel.py +10 -0
@@ -0,0 +1,10 @@
1
+ import strawberry
2
+
3
+ from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
4
+ from phoenix.server.api.types.ModelInterface import ModelInterface
5
+
6
+
7
+ @strawberry.type
8
+ class PlaygroundModel(ModelInterface):
9
+ name: str
10
+ provider_key: GenerativeProviderKey # PlaygroundModel always has a provider_key
phoenix/server/api/types/Project.py +42 -0
@@ -28,6 +28,7 @@ from phoenix.server.api.input_types.SpanSort import SpanSort, SpanSortConfig
28
28
  from phoenix.server.api.input_types.TimeRange import TimeRange
29
29
  from phoenix.server.api.types.AnnotationConfig import AnnotationConfig, to_gql_annotation_config
30
30
  from phoenix.server.api.types.AnnotationSummary import AnnotationSummary
31
+ from phoenix.server.api.types.CostBreakdown import CostBreakdown
31
32
  from phoenix.server.api.types.DocumentEvaluationSummary import DocumentEvaluationSummary
32
33
  from phoenix.server.api.types.pagination import (
33
34
  ConnectionArgs,
@@ -40,6 +41,7 @@ from phoenix.server.api.types.pagination import (
40
41
  from phoenix.server.api.types.ProjectSession import ProjectSession, to_gql_project_session
41
42
  from phoenix.server.api.types.SortDir import SortDir
42
43
  from phoenix.server.api.types.Span import Span
44
+ from phoenix.server.api.types.SpanCostSummary import SpanCostSummary
43
45
  from phoenix.server.api.types.TimeSeries import TimeSeries, TimeSeriesDataPoint
44
46
  from phoenix.server.api.types.Trace import Trace
45
47
  from phoenix.server.api.types.ValidationResult import ValidationResult
@@ -175,6 +177,30 @@ class Project(Node):
175
177
  ("completion", self.project_rowid, time_range, filter_condition),
176
178
  )
177
179
 
180
+ @strawberry.field
181
+ async def cost_summary(
182
+ self,
183
+ info: Info[Context, None],
184
+ time_range: Optional[TimeRange] = UNSET,
185
+ filter_condition: Optional[str] = UNSET,
186
+ ) -> SpanCostSummary:
187
+ loader = info.context.data_loaders.span_cost_summary_by_project
188
+ summary = await loader.load((self.project_rowid, time_range, filter_condition))
189
+ return SpanCostSummary(
190
+ prompt=CostBreakdown(
191
+ tokens=summary.prompt.tokens,
192
+ cost=summary.prompt.cost,
193
+ ),
194
+ completion=CostBreakdown(
195
+ tokens=summary.completion.tokens,
196
+ cost=summary.completion.cost,
197
+ ),
198
+ total=CostBreakdown(
199
+ tokens=summary.total.tokens,
200
+ cost=summary.total.cost,
201
+ ),
202
+ )
203
+
178
204
  @strawberry.field
179
205
  async def latency_ms_quantile(
180
206
  self,
@@ -238,6 +264,7 @@ class Project(Node):
238
264
  ) -> Connection[Span]:
239
265
  stmt = (
240
266
  select(models.Span.id)
267
+ .select_from(models.Span)
241
268
  .join(models.Trace)
242
269
  .where(models.Trace.project_rowid == self.project_rowid)
243
270
  )
@@ -410,6 +437,21 @@ class Project(Node):
410
437
  assert_never(sort.col)
411
438
  key = sort_subq.c.key
412
439
  stmt = stmt.join(sort_subq, table.id == sort_subq.c.id)
440
+ elif sort.col is ProjectSessionColumn.costTotal:
441
+ sort_subq = (
442
+ select(
443
+ models.Trace.project_session_rowid.label("id"),
444
+ func.sum(models.SpanCost.total_cost).label("key"),
445
+ )
446
+ .join_from(
447
+ models.Trace,
448
+ models.SpanCost,
449
+ models.Trace.id == models.SpanCost.trace_rowid,
450
+ )
451
+ .group_by(models.Trace.project_session_rowid)
452
+ ).subquery()
453
+ key = sort_subq.c.key
454
+ stmt = stmt.join(sort_subq, table.id == sort_subq.c.id)
413
455
  else:
414
456
  assert_never(sort.col)
415
457
  stmt = stmt.add_columns(key)
phoenix/server/api/types/ProjectSession.py +44 -0
@@ -9,8 +9,11 @@ from strawberry.relay import Connection, GlobalID, Node, NodeID
9
9
 
10
10
  from phoenix.db import models
11
11
  from phoenix.server.api.context import Context
12
+ from phoenix.server.api.types.CostBreakdown import CostBreakdown
12
13
  from phoenix.server.api.types.MimeType import MimeType
13
14
  from phoenix.server.api.types.pagination import ConnectionArgs, CursorString, connection_from_list
15
+ from phoenix.server.api.types.SpanCostDetailSummaryEntry import SpanCostDetailSummaryEntry
16
+ from phoenix.server.api.types.SpanCostSummary import SpanCostSummary
14
17
  from phoenix.server.api.types.SpanIOValue import SpanIOValue
15
18
  from phoenix.server.api.types.TokenUsage import TokenUsage
16
19
 
@@ -122,6 +125,47 @@ class ProjectSession(Node):
122
125
  (self.id_attr, probability)
123
126
  )
124
127
 
128
+ @strawberry.field
129
+ async def cost_summary(
130
+ self,
131
+ info: Info[Context, None],
132
+ ) -> SpanCostSummary:
133
+ loader = info.context.data_loaders.span_cost_summary_by_project_session
134
+ summary = await loader.load(self.id_attr)
135
+ return SpanCostSummary(
136
+ prompt=CostBreakdown(
137
+ tokens=summary.prompt.tokens,
138
+ cost=summary.prompt.cost,
139
+ ),
140
+ completion=CostBreakdown(
141
+ tokens=summary.completion.tokens,
142
+ cost=summary.completion.cost,
143
+ ),
144
+ total=CostBreakdown(
145
+ tokens=summary.total.tokens,
146
+ cost=summary.total.cost,
147
+ ),
148
+ )
149
+
150
+ @strawberry.field
151
+ async def cost_detail_summary_entries(
152
+ self,
153
+ info: Info[Context, None],
154
+ ) -> list[SpanCostDetailSummaryEntry]:
155
+ loader = info.context.data_loaders.span_cost_detail_summary_entries_by_project_session
156
+ summary = await loader.load(self.id_attr)
157
+ return [
158
+ SpanCostDetailSummaryEntry(
159
+ token_type=entry.token_type,
160
+ is_prompt=entry.is_prompt,
161
+ value=CostBreakdown(
162
+ tokens=entry.value.tokens,
163
+ cost=entry.value.cost,
164
+ ),
165
+ )
166
+ for entry in summary
167
+ ]
168
+
125
169
 
126
170
  def to_gql_project_session(project_session: models.ProjectSession) -> ProjectSession:
127
171
  return ProjectSession(
phoenix/server/api/types/Span.py +137 -0
@@ -19,6 +19,7 @@ from typing_extensions import Annotated, TypeAlias
19
19
  import phoenix.trace.schemas as trace_schema
20
20
  from phoenix.db import models
21
21
  from phoenix.server.api.context import Context
22
+ from phoenix.server.api.dataloaders import types as dataloader_types
22
23
  from phoenix.server.api.helpers.dataset_helpers import (
23
24
  get_dataset_example_input,
24
25
  get_dataset_example_output,
@@ -33,6 +34,7 @@ from phoenix.server.api.input_types.SpanAnnotationSort import (
33
34
  SpanAnnotationSort,
34
35
  )
35
36
  from phoenix.server.api.types.AnnotationSummary import AnnotationSummary
37
+ from phoenix.server.api.types.CostBreakdown import CostBreakdown
36
38
  from phoenix.server.api.types.DocumentRetrievalMetrics import DocumentRetrievalMetrics
37
39
  from phoenix.server.api.types.Evaluation import DocumentEvaluation
38
40
  from phoenix.server.api.types.ExampleRevisionInterface import ExampleRevision
@@ -41,6 +43,8 @@ from phoenix.server.api.types.MimeType import MimeType
41
43
  from phoenix.server.api.types.pagination import ConnectionArgs, CursorString, connection_from_list
42
44
  from phoenix.server.api.types.SortDir import SortDir
43
45
  from phoenix.server.api.types.SpanAnnotation import SpanAnnotation, to_gql_span_annotation
46
+ from phoenix.server.api.types.SpanCostDetailSummaryEntry import SpanCostDetailSummaryEntry
47
+ from phoenix.server.api.types.SpanCostSummary import SpanCostSummary
44
48
  from phoenix.server.api.types.SpanIOValue import SpanIOValue, truncate_value
45
49
  from phoenix.trace.attributes import get_attribute_value
46
50
 
@@ -790,6 +794,139 @@ class Span(Node):
790
794
  )
791
795
  ]
792
796
 
797
+ @strawberry.field
798
+ async def cost_summary(self, info: Info[Context, None]) -> Optional[SpanCostSummary]:
799
+ span_cost = await info.context.data_loaders.span_cost_by_span.load(self.span_rowid)
800
+ if span_cost is None:
801
+ return None
802
+ return SpanCostSummary(
803
+ prompt=CostBreakdown(
804
+ tokens=span_cost.prompt_tokens,
805
+ cost=span_cost.prompt_cost,
806
+ ),
807
+ completion=CostBreakdown(
808
+ tokens=span_cost.completion_tokens,
809
+ cost=span_cost.completion_cost,
810
+ ),
811
+ total=CostBreakdown(
812
+ tokens=span_cost.total_tokens,
813
+ cost=span_cost.total_cost,
814
+ ),
815
+ )
816
+
817
+ @strawberry.field
818
+ async def cost_detail_summary_entries(
819
+ self, info: Info[Context, None]
820
+ ) -> list[SpanCostDetailSummaryEntry]:
821
+ loader = info.context.data_loaders.span_cost_detail_summary_entries_by_span
822
+ entries = await loader.load(self.span_rowid)
823
+ return [
824
+ SpanCostDetailSummaryEntry(
825
+ token_type=entry.token_type,
826
+ is_prompt=entry.is_prompt,
827
+ value=CostBreakdown(tokens=entry.value.tokens, cost=entry.value.cost),
828
+ )
829
+ for entry in entries
830
+ ]
831
+
832
+ @strawberry.field
833
+ async def cumulative_cost_summary(self, info: Info[Context, None]) -> Optional[SpanCostSummary]:
834
+ max_depth = 0
835
+ descendant_rowids = await info.context.data_loaders.span_descendants.load(
836
+ (self.span_rowid, max_depth)
837
+ )
838
+ span_costs = await info.context.data_loaders.span_cost_by_span.load_many(
839
+ (self.span_rowid, *descendant_rowids)
840
+ )
841
+ total_cost: Optional[float] = None
842
+ total_tokens: Optional[float] = None
843
+ prompt_cost: Optional[float] = None
844
+ prompt_tokens: Optional[float] = None
845
+ completion_cost: Optional[float] = None
846
+ completion_tokens: Optional[float] = None
847
+ for span_cost in span_costs:
848
+ if span_cost is None:
849
+ continue
850
+ if span_cost.total_cost is not None:
851
+ total_cost = (total_cost or 0) + span_cost.total_cost
852
+ if span_cost.total_tokens is not None:
853
+ total_tokens = (total_tokens or 0) + span_cost.total_tokens
854
+ if span_cost.prompt_cost is not None:
855
+ prompt_cost = (prompt_cost or 0) + span_cost.prompt_cost
856
+ if span_cost.prompt_tokens is not None:
857
+ prompt_tokens = (prompt_tokens or 0) + span_cost.prompt_tokens
858
+ if span_cost.completion_cost is not None:
859
+ completion_cost = (completion_cost or 0) + span_cost.completion_cost
860
+ if span_cost.completion_tokens is not None:
861
+ completion_tokens = (completion_tokens or 0) + span_cost.completion_tokens
862
+ return SpanCostSummary(
863
+ prompt=CostBreakdown(
864
+ tokens=prompt_tokens,
865
+ cost=prompt_cost,
866
+ ),
867
+ completion=CostBreakdown(
868
+ tokens=completion_tokens,
869
+ cost=completion_cost,
870
+ ),
871
+ total=CostBreakdown(
872
+ tokens=total_tokens,
873
+ cost=total_cost,
874
+ ),
875
+ )
876
+
877
+ @strawberry.field
878
+ async def cumulative_cost_detail_summary_entries(
879
+ self, info: Info[Context, None]
880
+ ) -> list[SpanCostDetailSummaryEntry]:
881
+ max_depth = 0
882
+ descendant_rowids = await info.context.data_loaders.span_descendants.load(
883
+ (self.span_rowid, max_depth)
884
+ )
885
+ entry_lists = (
886
+ await info.context.data_loaders.span_cost_detail_summary_entries_by_span.load_many(
887
+ (self.span_rowid, *descendant_rowids)
888
+ )
889
+ )
890
+
891
+ TokenType: TypeAlias = str
892
+ IsPrompt: TypeAlias = bool
893
+ grouped_entries: dict[
894
+ IsPrompt, dict[TokenType, list[dataloader_types.SpanCostDetailSummaryEntry]]
895
+ ] = {}
896
+
897
+ for entries in entry_lists:
898
+ for entry in entries:
899
+ is_prompt = entry.is_prompt
900
+ token_type = entry.token_type
901
+ if is_prompt not in grouped_entries:
902
+ grouped_entries[is_prompt] = {}
903
+ if token_type not in grouped_entries[is_prompt]:
904
+ grouped_entries[is_prompt][token_type] = []
905
+ grouped_entries[is_prompt][token_type].append(entry)
906
+
907
+ result: list[SpanCostDetailSummaryEntry] = []
908
+ for is_prompt in (True, False):
909
+ entries_by_token_type = grouped_entries[is_prompt]
910
+ for token_type, entries in sorted(entries_by_token_type.items()):
911
+ cost: Optional[float] = None
912
+ tokens: Optional[float] = None
913
+ for entry in entries:
914
+ if entry.value.cost is not None:
915
+ cost = (cost or 0) + entry.value.cost
916
+ if entry.value.tokens is not None:
917
+ tokens = (tokens or 0) + entry.value.tokens
918
+ result.append(
919
+ SpanCostDetailSummaryEntry(
920
+ token_type=token_type,
921
+ is_prompt=is_prompt,
922
+ value=CostBreakdown(
923
+ tokens=tokens,
924
+ cost=cost,
925
+ ),
926
+ )
927
+ )
928
+ return result
929
+
793
930
 
794
931
  def _hide_embedding_vectors(attributes: Mapping[str, Any]) -> Mapping[str, Any]:
795
932
  if not (
phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
@@ -0,0 +1,10 @@
1
+ import strawberry
2
+
3
+ from phoenix.server.api.types.CostBreakdown import CostBreakdown
4
+
5
+
6
+ @strawberry.type
7
+ class SpanCostDetailSummaryEntry:
8
+ token_type: str
9
+ is_prompt: bool
10
+ value: CostBreakdown = strawberry.field(default_factory=CostBreakdown)
phoenix/server/api/types/SpanCostSummary.py +10 -0
@@ -0,0 +1,10 @@
1
+ import strawberry
2
+
3
+ from phoenix.server.api.types.CostBreakdown import CostBreakdown
4
+
5
+
6
+ @strawberry.type
7
+ class SpanCostSummary:
8
+ prompt: CostBreakdown = strawberry.field(default_factory=CostBreakdown)
9
+ completion: CostBreakdown = strawberry.field(default_factory=CostBreakdown)
10
+ total: CostBreakdown = strawberry.field(default_factory=CostBreakdown)
phoenix/server/api/types/TokenPrice.py +16 -0
@@ -0,0 +1,16 @@
1
+ from enum import Enum
2
+
3
+ import strawberry
4
+
5
+
6
+ class TokenKind(Enum):
7
+ PROMPT = "prompt"
8
+ COMPLETION = "completion"
9
+
10
+
11
+ @strawberry.type
12
+ class TokenPrice:
13
+ token_type: str
14
+ kind: TokenKind
15
+ cost_per_million_tokens: float
16
+ cost_per_token: float
phoenix/server/api/types/TokenUsage.py +3 -3
@@ -3,9 +3,9 @@ import strawberry
3
3
 
4
4
  @strawberry.type
5
5
  class TokenUsage:
6
- prompt: int = 0
7
- completion: int = 0
6
+ prompt: float = 0
7
+ completion: float = 0
8
8
 
9
9
  @strawberry.field
10
- async def total(self) -> int:
10
+ async def total(self) -> float:
11
11
  return self.prompt + self.completion