arize-phoenix 11.23.1__py3-none-any.whl → 12.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (221) hide show
  1. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +61 -36
  2. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/RECORD +212 -162
  3. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
  4. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
  5. phoenix/__generated__/__init__.py +0 -0
  6. phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
  7. phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
  8. phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
  9. phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
  10. phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
  11. phoenix/__init__.py +2 -1
  12. phoenix/auth.py +27 -2
  13. phoenix/config.py +1594 -81
  14. phoenix/db/README.md +546 -28
  15. phoenix/db/bulk_inserter.py +119 -116
  16. phoenix/db/engines.py +140 -33
  17. phoenix/db/facilitator.py +22 -1
  18. phoenix/db/helpers.py +818 -65
  19. phoenix/db/iam_auth.py +64 -0
  20. phoenix/db/insertion/dataset.py +133 -1
  21. phoenix/db/insertion/document_annotation.py +9 -6
  22. phoenix/db/insertion/evaluation.py +2 -3
  23. phoenix/db/insertion/helpers.py +2 -2
  24. phoenix/db/insertion/session_annotation.py +176 -0
  25. phoenix/db/insertion/span_annotation.py +3 -4
  26. phoenix/db/insertion/trace_annotation.py +3 -4
  27. phoenix/db/insertion/types.py +41 -18
  28. phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
  29. phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
  30. phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
  31. phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
  32. phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
  33. phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
  34. phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
  35. phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
  36. phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
  37. phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
  38. phoenix/db/models.py +364 -56
  39. phoenix/db/pg_config.py +10 -0
  40. phoenix/db/types/trace_retention.py +7 -6
  41. phoenix/experiments/functions.py +69 -19
  42. phoenix/inferences/inferences.py +1 -2
  43. phoenix/server/api/auth.py +9 -0
  44. phoenix/server/api/auth_messages.py +46 -0
  45. phoenix/server/api/context.py +60 -0
  46. phoenix/server/api/dataloaders/__init__.py +36 -0
  47. phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
  48. phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
  49. phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
  50. phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
  51. phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
  52. phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
  53. phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
  54. phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
  55. phoenix/server/api/dataloaders/dataset_labels.py +36 -0
  56. phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
  57. phoenix/server/api/dataloaders/document_evaluations.py +6 -9
  58. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
  59. phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
  60. phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
  61. phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
  62. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
  63. phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
  64. phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
  65. phoenix/server/api/dataloaders/record_counts.py +37 -10
  66. phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
  67. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
  68. phoenix/server/api/dataloaders/span_cost_summary_by_project.py +28 -14
  69. phoenix/server/api/dataloaders/span_costs.py +3 -9
  70. phoenix/server/api/dataloaders/table_fields.py +2 -2
  71. phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
  72. phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
  73. phoenix/server/api/exceptions.py +5 -1
  74. phoenix/server/api/helpers/playground_clients.py +263 -83
  75. phoenix/server/api/helpers/playground_spans.py +2 -1
  76. phoenix/server/api/helpers/playground_users.py +26 -0
  77. phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
  78. phoenix/server/api/helpers/prompts/models.py +61 -19
  79. phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
  80. phoenix/server/api/input_types/ChatCompletionInput.py +3 -0
  81. phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
  82. phoenix/server/api/input_types/DatasetFilter.py +5 -2
  83. phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
  84. phoenix/server/api/input_types/GenerativeModelInput.py +3 -0
  85. phoenix/server/api/input_types/ProjectSessionSort.py +158 -1
  86. phoenix/server/api/input_types/PromptVersionInput.py +47 -1
  87. phoenix/server/api/input_types/SpanSort.py +3 -2
  88. phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
  89. phoenix/server/api/input_types/UserRoleInput.py +1 -0
  90. phoenix/server/api/mutations/__init__.py +8 -0
  91. phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
  92. phoenix/server/api/mutations/api_key_mutations.py +15 -20
  93. phoenix/server/api/mutations/chat_mutations.py +106 -37
  94. phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
  95. phoenix/server/api/mutations/dataset_mutations.py +21 -16
  96. phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
  97. phoenix/server/api/mutations/experiment_mutations.py +2 -2
  98. phoenix/server/api/mutations/export_events_mutations.py +3 -3
  99. phoenix/server/api/mutations/model_mutations.py +11 -9
  100. phoenix/server/api/mutations/project_mutations.py +4 -4
  101. phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
  102. phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
  103. phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
  104. phoenix/server/api/mutations/prompt_mutations.py +65 -129
  105. phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
  106. phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
  107. phoenix/server/api/mutations/trace_annotations_mutations.py +13 -8
  108. phoenix/server/api/mutations/trace_mutations.py +3 -3
  109. phoenix/server/api/mutations/user_mutations.py +55 -26
  110. phoenix/server/api/queries.py +501 -617
  111. phoenix/server/api/routers/__init__.py +2 -2
  112. phoenix/server/api/routers/auth.py +141 -87
  113. phoenix/server/api/routers/ldap.py +229 -0
  114. phoenix/server/api/routers/oauth2.py +349 -101
  115. phoenix/server/api/routers/v1/__init__.py +22 -4
  116. phoenix/server/api/routers/v1/annotation_configs.py +19 -30
  117. phoenix/server/api/routers/v1/annotations.py +455 -13
  118. phoenix/server/api/routers/v1/datasets.py +355 -68
  119. phoenix/server/api/routers/v1/documents.py +142 -0
  120. phoenix/server/api/routers/v1/evaluations.py +20 -28
  121. phoenix/server/api/routers/v1/experiment_evaluations.py +16 -6
  122. phoenix/server/api/routers/v1/experiment_runs.py +335 -59
  123. phoenix/server/api/routers/v1/experiments.py +475 -47
  124. phoenix/server/api/routers/v1/projects.py +16 -50
  125. phoenix/server/api/routers/v1/prompts.py +50 -39
  126. phoenix/server/api/routers/v1/sessions.py +108 -0
  127. phoenix/server/api/routers/v1/spans.py +156 -96
  128. phoenix/server/api/routers/v1/traces.py +51 -77
  129. phoenix/server/api/routers/v1/users.py +64 -24
  130. phoenix/server/api/routers/v1/utils.py +3 -7
  131. phoenix/server/api/subscriptions.py +257 -93
  132. phoenix/server/api/types/Annotation.py +90 -23
  133. phoenix/server/api/types/ApiKey.py +13 -17
  134. phoenix/server/api/types/AuthMethod.py +1 -0
  135. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
  136. phoenix/server/api/types/Dataset.py +199 -72
  137. phoenix/server/api/types/DatasetExample.py +88 -18
  138. phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
  139. phoenix/server/api/types/DatasetLabel.py +57 -0
  140. phoenix/server/api/types/DatasetSplit.py +98 -0
  141. phoenix/server/api/types/DatasetVersion.py +49 -4
  142. phoenix/server/api/types/DocumentAnnotation.py +212 -0
  143. phoenix/server/api/types/Experiment.py +215 -68
  144. phoenix/server/api/types/ExperimentComparison.py +3 -9
  145. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
  146. phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
  147. phoenix/server/api/types/ExperimentRun.py +120 -70
  148. phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
  149. phoenix/server/api/types/GenerativeModel.py +95 -42
  150. phoenix/server/api/types/GenerativeProvider.py +1 -1
  151. phoenix/server/api/types/ModelInterface.py +7 -2
  152. phoenix/server/api/types/PlaygroundModel.py +12 -2
  153. phoenix/server/api/types/Project.py +218 -185
  154. phoenix/server/api/types/ProjectSession.py +146 -29
  155. phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
  156. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
  157. phoenix/server/api/types/Prompt.py +119 -39
  158. phoenix/server/api/types/PromptLabel.py +42 -25
  159. phoenix/server/api/types/PromptVersion.py +11 -8
  160. phoenix/server/api/types/PromptVersionTag.py +65 -25
  161. phoenix/server/api/types/Span.py +130 -123
  162. phoenix/server/api/types/SpanAnnotation.py +189 -42
  163. phoenix/server/api/types/SystemApiKey.py +65 -1
  164. phoenix/server/api/types/Trace.py +184 -53
  165. phoenix/server/api/types/TraceAnnotation.py +149 -50
  166. phoenix/server/api/types/User.py +128 -33
  167. phoenix/server/api/types/UserApiKey.py +73 -26
  168. phoenix/server/api/types/node.py +10 -0
  169. phoenix/server/api/types/pagination.py +11 -2
  170. phoenix/server/app.py +154 -36
  171. phoenix/server/authorization.py +5 -4
  172. phoenix/server/bearer_auth.py +13 -5
  173. phoenix/server/cost_tracking/cost_model_lookup.py +42 -14
  174. phoenix/server/cost_tracking/model_cost_manifest.json +1085 -194
  175. phoenix/server/daemons/generative_model_store.py +61 -9
  176. phoenix/server/daemons/span_cost_calculator.py +10 -8
  177. phoenix/server/dml_event.py +13 -0
  178. phoenix/server/email/sender.py +29 -2
  179. phoenix/server/grpc_server.py +9 -9
  180. phoenix/server/jwt_store.py +8 -6
  181. phoenix/server/ldap.py +1449 -0
  182. phoenix/server/main.py +9 -3
  183. phoenix/server/oauth2.py +330 -12
  184. phoenix/server/prometheus.py +43 -6
  185. phoenix/server/rate_limiters.py +4 -9
  186. phoenix/server/retention.py +33 -20
  187. phoenix/server/session_filters.py +49 -0
  188. phoenix/server/static/.vite/manifest.json +51 -53
  189. phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
  190. phoenix/server/static/assets/{index-BPCwGQr8.js → index-CTQoemZv.js} +42 -35
  191. phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
  192. phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
  193. phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
  194. phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
  195. phoenix/server/static/assets/{vendor-recharts-Bw30oz1A.js → vendor-recharts-V9cwpXsm.js} +7 -7
  196. phoenix/server/static/assets/{vendor-shiki-DZajAPeq.js → vendor-shiki-Do--csgv.js} +1 -1
  197. phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
  198. phoenix/server/templates/index.html +7 -1
  199. phoenix/server/thread_server.py +1 -2
  200. phoenix/server/utils.py +74 -0
  201. phoenix/session/client.py +55 -1
  202. phoenix/session/data_extractor.py +5 -0
  203. phoenix/session/evaluation.py +8 -4
  204. phoenix/session/session.py +44 -8
  205. phoenix/settings.py +2 -0
  206. phoenix/trace/attributes.py +80 -13
  207. phoenix/trace/dsl/query.py +2 -0
  208. phoenix/trace/projects.py +5 -0
  209. phoenix/utilities/template_formatters.py +1 -1
  210. phoenix/version.py +1 -1
  211. phoenix/server/api/types/Evaluation.py +0 -39
  212. phoenix/server/static/assets/components-D0DWAf0l.js +0 -5650
  213. phoenix/server/static/assets/pages-Creyamao.js +0 -8612
  214. phoenix/server/static/assets/vendor-CU36oj8y.js +0 -905
  215. phoenix/server/static/assets/vendor-CqDb5u4o.css +0 -1
  216. phoenix/server/static/assets/vendor-arizeai-Ctgw0e1G.js +0 -168
  217. phoenix/server/static/assets/vendor-codemirror-Cojjzqb9.js +0 -25
  218. phoenix/server/static/assets/vendor-three-BLWp5bic.js +0 -2998
  219. phoenix/utilities/deprecation.py +0 -31
  220. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
  221. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
@@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, Annotated, Optional
3
3
 
4
4
  import strawberry
5
5
  from sqlalchemy import func, select
6
- from sqlalchemy.orm import load_only
7
6
  from sqlalchemy.sql.functions import coalesce
8
7
  from strawberry import UNSET
9
8
  from strawberry.relay import Connection, GlobalID, Node, NodeID
@@ -13,10 +12,7 @@ from strawberry.types import Info
13
12
  from phoenix.db import models
14
13
  from phoenix.server.api.context import Context
15
14
  from phoenix.server.api.types.CostBreakdown import CostBreakdown
16
- from phoenix.server.api.types.ExperimentRunAnnotation import (
17
- ExperimentRunAnnotation,
18
- to_gql_experiment_run_annotation,
19
- )
15
+ from phoenix.server.api.types.ExperimentRunAnnotation import ExperimentRunAnnotation
20
16
  from phoenix.server.api.types.pagination import (
21
17
  ConnectionArgs,
22
18
  CursorString,
@@ -27,18 +23,100 @@ from phoenix.server.api.types.SpanCostSummary import SpanCostSummary
27
23
  from phoenix.server.api.types.Trace import Trace
28
24
 
29
25
  if TYPE_CHECKING:
30
- from phoenix.server.api.types.DatasetExample import DatasetExample
26
+ from .DatasetExample import DatasetExample
27
+ from .Trace import Trace
31
28
 
32
29
 
33
30
  @strawberry.type
34
31
  class ExperimentRun(Node):
35
- id_attr: NodeID[int]
36
- experiment_id: GlobalID
37
- trace_id: Optional[str]
38
- output: Optional[JSON]
39
- start_time: datetime
40
- end_time: datetime
41
- error: Optional[str]
32
+ id: NodeID[int]
33
+ db_record: strawberry.Private[Optional[models.ExperimentRun]] = None
34
+
35
+ def __post_init__(self) -> None:
36
+ if self.db_record and self.id != self.db_record.id:
37
+ raise ValueError("ExperimentRun ID mismatch")
38
+
39
+ @strawberry.field
40
+ async def experiment_id(self, info: Info[Context, None]) -> GlobalID:
41
+ from .Experiment import Experiment
42
+
43
+ if self.db_record:
44
+ experiment_id = self.db_record.experiment_id
45
+ else:
46
+ experiment_id = await info.context.data_loaders.experiment_run_fields.load(
47
+ (self.id, models.ExperimentRun.experiment_id),
48
+ )
49
+ return GlobalID(Experiment.__name__, str(experiment_id))
50
+
51
+ @strawberry.field
52
+ async def repetition_number(self, info: Info[Context, None]) -> int:
53
+ if self.db_record:
54
+ val = self.db_record.repetition_number
55
+ else:
56
+ val = await info.context.data_loaders.experiment_run_fields.load(
57
+ (self.id, models.ExperimentRun.repetition_number),
58
+ )
59
+ return val
60
+
61
+ @strawberry.field
62
+ async def trace_id(self, info: Info[Context, None]) -> Optional[str]:
63
+ if self.db_record:
64
+ val = self.db_record.trace_id
65
+ else:
66
+ val = await info.context.data_loaders.experiment_run_fields.load(
67
+ (self.id, models.ExperimentRun.trace_id),
68
+ )
69
+ return val
70
+
71
+ @strawberry.field
72
+ async def output(self, info: Info[Context, None]) -> Optional[JSON]:
73
+ if self.db_record:
74
+ output_dict = self.db_record.output
75
+ else:
76
+ output_dict = await info.context.data_loaders.experiment_run_fields.load(
77
+ (self.id, models.ExperimentRun.output),
78
+ )
79
+ return output_dict.get("task_output") if output_dict else None
80
+
81
+ @strawberry.field
82
+ async def start_time(self, info: Info[Context, None]) -> datetime:
83
+ if self.db_record:
84
+ val = self.db_record.start_time
85
+ else:
86
+ val = await info.context.data_loaders.experiment_run_fields.load(
87
+ (self.id, models.ExperimentRun.start_time),
88
+ )
89
+ return val
90
+
91
+ @strawberry.field
92
+ async def end_time(self, info: Info[Context, None]) -> datetime:
93
+ if self.db_record:
94
+ val = self.db_record.end_time
95
+ else:
96
+ val = await info.context.data_loaders.experiment_run_fields.load(
97
+ (self.id, models.ExperimentRun.end_time),
98
+ )
99
+ return val
100
+
101
+ @strawberry.field
102
+ async def error(self, info: Info[Context, None]) -> Optional[str]:
103
+ if self.db_record:
104
+ val = self.db_record.error
105
+ else:
106
+ val = await info.context.data_loaders.experiment_run_fields.load(
107
+ (self.id, models.ExperimentRun.error),
108
+ )
109
+ return val
110
+
111
+ @strawberry.field
112
+ async def latency_ms(self, info: Info[Context, None]) -> float:
113
+ if self.db_record:
114
+ val = self.db_record.latency_ms
115
+ else:
116
+ val = await info.context.data_loaders.experiment_run_fields.load(
117
+ (self.id, models.ExperimentRun.latency_ms),
118
+ )
119
+ return val
42
120
 
43
121
  @strawberry.field
44
122
  async def annotations(
@@ -55,57 +133,49 @@ class ExperimentRun(Node):
55
133
  last=last,
56
134
  before=before if isinstance(before, CursorString) else None,
57
135
  )
58
- run_id = self.id_attr
59
- annotations = await info.context.data_loaders.experiment_run_annotations.load(run_id)
136
+ annotations = await info.context.data_loaders.experiment_run_annotations.load(self.id)
60
137
  return connection_from_list(
61
- [to_gql_experiment_run_annotation(annotation) for annotation in annotations], args
138
+ [
139
+ ExperimentRunAnnotation(id=annotation.id, db_record=annotation)
140
+ for annotation in annotations
141
+ ],
142
+ args,
62
143
  )
63
144
 
64
145
  @strawberry.field
65
- async def trace(self, info: Info) -> Optional[Trace]:
66
- if not self.trace_id:
146
+ async def trace(
147
+ self, info: Info[Context, None]
148
+ ) -> Optional[Annotated["Trace", strawberry.lazy(".Trace")]]:
149
+ if self.db_record:
150
+ trace_id = self.db_record.trace_id
151
+ else:
152
+ trace_id = await info.context.data_loaders.experiment_run_fields.load(
153
+ (self.id, models.ExperimentRun.trace_id),
154
+ )
155
+ if not trace_id:
67
156
  return None
68
- dataloader = info.context.data_loaders.trace_by_trace_ids
69
- if (trace := await dataloader.load(self.trace_id)) is None:
157
+ loader = info.context.data_loaders.trace_by_trace_ids
158
+ if (trace := await loader.load(trace_id)) is None:
70
159
  return None
71
- return Trace(trace_rowid=trace.id, db_trace=trace)
160
+ from .Trace import Trace
161
+
162
+ return Trace(id=trace.id, db_record=trace)
72
163
 
73
164
  @strawberry.field
74
165
  async def example(
75
- self, info: Info
166
+ self, info: Info[Context, None]
76
167
  ) -> Annotated[
77
- "DatasetExample", strawberry.lazy("phoenix.server.api.types.DatasetExample")
168
+ "DatasetExample", strawberry.lazy(".DatasetExample")
78
169
  ]: # use lazy types to avoid circular import: https://strawberry.rocks/docs/types/lazy
79
- from phoenix.server.api.types.DatasetExample import DatasetExample
170
+ from .DatasetExample import DatasetExample
80
171
 
81
- async with info.context.db() as session:
82
- assert (
83
- result := await session.execute(
84
- select(models.DatasetExample, models.Experiment.dataset_version_id)
85
- .select_from(models.ExperimentRun)
86
- .join(
87
- models.DatasetExample,
88
- models.DatasetExample.id == models.ExperimentRun.dataset_example_id,
89
- )
90
- .join(
91
- models.Experiment,
92
- models.Experiment.id == models.ExperimentRun.experiment_id,
93
- )
94
- .where(models.ExperimentRun.id == self.id_attr)
95
- .options(load_only(models.DatasetExample.id, models.DatasetExample.created_at))
96
- )
97
- ) is not None
98
- example, version_id = result.first()
99
- return DatasetExample(
100
- id_attr=example.id,
101
- created_at=example.created_at,
102
- version_id=version_id,
103
- )
172
+ loader = info.context.data_loaders.dataset_examples_and_versions_by_experiment_run
173
+ (example, version_id) = await loader.load(self.id)
174
+ return DatasetExample(id=example.id, db_record=example, version_id=version_id)
104
175
 
105
176
  @strawberry.field
106
177
  async def cost_summary(self, info: Info[Context, None]) -> SpanCostSummary:
107
- run_id = self.id_attr
108
- summary = await info.context.data_loaders.span_cost_summary_by_experiment_run.load(run_id)
178
+ summary = await info.context.data_loaders.span_cost_summary_by_experiment_run.load(self.id)
109
179
  return SpanCostSummary(
110
180
  prompt=CostBreakdown(
111
181
  tokens=summary.prompt.tokens,
@@ -125,8 +195,6 @@ class ExperimentRun(Node):
125
195
  async def cost_detail_summary_entries(
126
196
  self, info: Info[Context, None]
127
197
  ) -> list[SpanCostDetailSummaryEntry]:
128
- run_id = self.id_attr
129
-
130
198
  stmt = (
131
199
  select(
132
200
  models.SpanCostDetail.token_type,
@@ -139,7 +207,7 @@ class ExperimentRun(Node):
139
207
  .join(models.Span, models.SpanCost.span_rowid == models.Span.id)
140
208
  .join(models.Trace, models.Span.trace_rowid == models.Trace.id)
141
209
  .join(models.ExperimentRun, models.ExperimentRun.trace_id == models.Trace.trace_id)
142
- .where(models.ExperimentRun.id == run_id)
210
+ .where(models.ExperimentRun.id == self.id)
143
211
  .group_by(models.SpanCostDetail.token_type, models.SpanCostDetail.is_prompt)
144
212
  )
145
213
 
@@ -153,21 +221,3 @@ class ExperimentRun(Node):
153
221
  )
154
222
  async for token_type, is_prompt, cost, tokens in data
155
223
  ]
156
-
157
-
158
- def to_gql_experiment_run(run: models.ExperimentRun) -> ExperimentRun:
159
- """
160
- Converts an ORM experiment run to a GraphQL ExperimentRun.
161
- """
162
-
163
- from phoenix.server.api.types.Experiment import Experiment
164
-
165
- return ExperimentRun(
166
- id_attr=run.id,
167
- experiment_id=GlobalID(Experiment.__name__, str(run.experiment_id)),
168
- trace_id=run.trace.trace_id if run.trace else None,
169
- output=run.output.get("task_output"),
170
- start_time=run.start_time,
171
- end_time=run.end_time,
172
- error=run.error,
173
- )
@@ -1,56 +1,175 @@
1
1
  from datetime import datetime
2
+ from math import isfinite
2
3
  from typing import Optional
3
4
 
4
5
  import strawberry
5
6
  from strawberry import Info
6
- from strawberry.relay import Node, NodeID
7
+ from strawberry.relay import GlobalID, Node, NodeID
7
8
  from strawberry.scalars import JSON
8
9
 
9
10
  from phoenix.db import models
11
+ from phoenix.server.api.context import Context
10
12
  from phoenix.server.api.types.AnnotatorKind import ExperimentRunAnnotatorKind
11
13
  from phoenix.server.api.types.Trace import Trace
12
14
 
13
15
 
14
16
  @strawberry.type
15
17
  class ExperimentRunAnnotation(Node):
16
- id_attr: NodeID[int]
17
- name: str
18
- annotator_kind: ExperimentRunAnnotatorKind
19
- label: Optional[str]
20
- score: Optional[float]
21
- explanation: Optional[str]
22
- error: Optional[str]
23
- metadata: JSON
24
- start_time: datetime
25
- end_time: datetime
26
- trace_id: Optional[str]
27
-
28
- @strawberry.field
29
- async def trace(self, info: Info) -> Optional[Trace]:
30
- if not self.trace_id:
18
+ id: NodeID[int]
19
+ db_record: strawberry.Private[Optional[models.ExperimentRunAnnotation]] = None
20
+
21
+ def __post_init__(self) -> None:
22
+ if self.db_record and self.id != self.db_record.id:
23
+ raise ValueError("ExperimentRunAnnotation ID mismatch")
24
+
25
+ @strawberry.field(description="Name of the annotation, e.g. 'helpfulness' or 'relevance'.") # type: ignore
26
+ async def name(
27
+ self,
28
+ info: Info[Context, None],
29
+ ) -> str:
30
+ if self.db_record:
31
+ val = self.db_record.name
32
+ else:
33
+ val = await info.context.data_loaders.experiment_run_annotation_fields.load(
34
+ (self.id, models.ExperimentRunAnnotation.name),
35
+ )
36
+ return val
37
+
38
+ @strawberry.field(description="The kind of annotator that produced the annotation.") # type: ignore
39
+ async def annotator_kind(
40
+ self,
41
+ info: Info[Context, None],
42
+ ) -> ExperimentRunAnnotatorKind:
43
+ if self.db_record:
44
+ val = self.db_record.annotator_kind
45
+ else:
46
+ val = await info.context.data_loaders.experiment_run_annotation_fields.load(
47
+ (self.id, models.ExperimentRunAnnotation.annotator_kind),
48
+ )
49
+ return ExperimentRunAnnotatorKind(val)
50
+
51
+ @strawberry.field(
52
+ description="Value of the annotation in the form of a string, e.g. 'helpful' or 'not helpful'. Note that the label is not necessarily binary." # noqa: E501
53
+ ) # type: ignore
54
+ async def label(
55
+ self,
56
+ info: Info[Context, None],
57
+ ) -> Optional[str]:
58
+ if self.db_record:
59
+ val = self.db_record.label
60
+ else:
61
+ val = await info.context.data_loaders.experiment_run_annotation_fields.load(
62
+ (self.id, models.ExperimentRunAnnotation.label),
63
+ )
64
+ return val
65
+
66
+ @strawberry.field(description="Value of the annotation in the form of a numeric score.") # type: ignore
67
+ async def score(
68
+ self,
69
+ info: Info[Context, None],
70
+ ) -> Optional[float]:
71
+ if self.db_record:
72
+ val = self.db_record.score
73
+ else:
74
+ val = await info.context.data_loaders.experiment_run_annotation_fields.load(
75
+ (self.id, models.ExperimentRunAnnotation.score),
76
+ )
77
+ return val if val is not None and isfinite(val) else None
78
+
79
+ @strawberry.field(
80
+ description="The annotator's explanation for the annotation result (i.e. score or label, or both) given to the subject." # noqa: E501
81
+ ) # type: ignore
82
+ async def explanation(
83
+ self,
84
+ info: Info[Context, None],
85
+ ) -> Optional[str]:
86
+ if self.db_record:
87
+ val = self.db_record.explanation
88
+ else:
89
+ val = await info.context.data_loaders.experiment_run_annotation_fields.load(
90
+ (self.id, models.ExperimentRunAnnotation.explanation),
91
+ )
92
+ return val
93
+
94
+ @strawberry.field(description="Error message if the annotation failed to produce a result.") # type: ignore
95
+ async def error(
96
+ self,
97
+ info: Info[Context, None],
98
+ ) -> Optional[str]:
99
+ if self.db_record:
100
+ val = self.db_record.error
101
+ else:
102
+ val = await info.context.data_loaders.experiment_run_annotation_fields.load(
103
+ (self.id, models.ExperimentRunAnnotation.error),
104
+ )
105
+ return val
106
+
107
+ @strawberry.field(description="Metadata about the annotation.") # type: ignore
108
+ async def metadata(
109
+ self,
110
+ info: Info[Context, None],
111
+ ) -> JSON:
112
+ if self.db_record:
113
+ val = self.db_record.metadata_
114
+ else:
115
+ val = await info.context.data_loaders.experiment_run_annotation_fields.load(
116
+ (self.id, models.ExperimentRunAnnotation.metadata_),
117
+ )
118
+ return val
119
+
120
+ @strawberry.field(description="The date and time when the annotation was created.") # type: ignore
121
+ async def start_time(
122
+ self,
123
+ info: Info[Context, None],
124
+ ) -> datetime:
125
+ if self.db_record:
126
+ val = self.db_record.start_time
127
+ else:
128
+ val = await info.context.data_loaders.experiment_run_annotation_fields.load(
129
+ (self.id, models.ExperimentRunAnnotation.start_time),
130
+ )
131
+ return val
132
+
133
+ @strawberry.field(description="The date and time when the annotation was last updated.") # type: ignore
134
+ async def end_time(
135
+ self,
136
+ info: Info[Context, None],
137
+ ) -> datetime:
138
+ if self.db_record:
139
+ val = self.db_record.end_time
140
+ else:
141
+ val = await info.context.data_loaders.experiment_run_annotation_fields.load(
142
+ (self.id, models.ExperimentRunAnnotation.end_time),
143
+ )
144
+ return val
145
+
146
+ @strawberry.field(description="The identifier of the trace associated with the annotation.") # type: ignore
147
+ async def trace_id(
148
+ self,
149
+ info: Info[Context, None],
150
+ ) -> Optional[GlobalID]:
151
+ if self.db_record:
152
+ val = self.db_record.trace_id
153
+ else:
154
+ val = await info.context.data_loaders.experiment_run_annotation_fields.load(
155
+ (self.id, models.ExperimentRunAnnotation.trace_id),
156
+ )
157
+ return None if val is None else GlobalID(type_name=Trace.__name__, node_id=val)
158
+
159
+ @strawberry.field(description="The trace associated with the annotation.") # type: ignore
160
+ async def trace(
161
+ self,
162
+ info: Info[Context, None],
163
+ ) -> Optional[Trace]:
164
+ if self.db_record:
165
+ trace_id = self.db_record.trace_id
166
+ else:
167
+ trace_id = await info.context.data_loaders.experiment_run_annotation_fields.load(
168
+ (self.id, models.ExperimentRunAnnotation.trace_id),
169
+ )
170
+ if not trace_id:
31
171
  return None
32
172
  dataloader = info.context.data_loaders.trace_by_trace_ids
33
- if (trace := await dataloader.load(self.trace_id)) is None:
173
+ if (trace := await dataloader.load(trace_id)) is None:
34
174
  return None
35
- return Trace(trace_rowid=trace.id, db_trace=trace)
36
-
37
-
38
- def to_gql_experiment_run_annotation(
39
- annotation: models.ExperimentRunAnnotation,
40
- ) -> ExperimentRunAnnotation:
41
- """
42
- Converts an ORM experiment run annotation to a GraphQL ExperimentRunAnnotation.
43
- """
44
- return ExperimentRunAnnotation(
45
- id_attr=annotation.id,
46
- name=annotation.name,
47
- annotator_kind=ExperimentRunAnnotatorKind(annotation.annotator_kind),
48
- label=annotation.label,
49
- score=annotation.score,
50
- explanation=annotation.explanation,
51
- error=annotation.error,
52
- metadata=annotation.metadata_,
53
- start_time=annotation.start_time,
54
- end_time=annotation.end_time,
55
- trace_id=annotation.trace_id,
56
- )
175
+ return Trace(id=trace.id, db_record=trace)
@@ -4,7 +4,6 @@ from typing import TYPE_CHECKING, Optional
4
4
 
5
5
  import strawberry
6
6
  from openinference.semconv.trace import OpenInferenceLLMProviderValues
7
- from sqlalchemy import inspect
8
7
  from strawberry.relay import Node, NodeID
9
8
  from strawberry.relay.types import GlobalID
10
9
  from strawberry.types import Info
@@ -37,20 +36,98 @@ CachedCostSummaryKey: TypeAlias = tuple[Optional[ProjectId], TimeRangeKey]
37
36
 
38
37
  @strawberry.type
39
38
  class GenerativeModel(Node, ModelInterface):
40
- id_attr: NodeID[int]
41
- name: str
42
- provider: Optional[str]
43
- name_pattern: str
44
- kind: GenerativeModelKind
45
- created_at: datetime
46
- updated_at: datetime
47
- provider_key: Optional[GenerativeProviderKey]
48
- costs: strawberry.Private[Optional[list[models.TokenPrice]]] = None
49
- start_time: Optional[datetime] = None
39
+ id: NodeID[int]
40
+ db_record: strawberry.Private[Optional[models.GenerativeModel]] = None
50
41
  cached_cost_summary: strawberry.Private[
51
42
  Optional[dict[CachedCostSummaryKey, SpanCostSummary]]
52
43
  ] = None
53
44
 
45
+ def __post_init__(self) -> None:
46
+ if self.db_record and self.id != self.db_record.id:
47
+ raise ValueError("GenerativeModel ID mismatch")
48
+
49
+ @strawberry.field
50
+ async def name(self, info: Info[Context, None]) -> str:
51
+ if self.db_record:
52
+ val = self.db_record.name
53
+ else:
54
+ val = await info.context.data_loaders.generative_model_fields.load(
55
+ (self.id, models.GenerativeModel.name),
56
+ )
57
+ return val
58
+
59
+ @strawberry.field
60
+ async def provider(self, info: Info[Context, None]) -> Optional[str]:
61
+ if self.db_record:
62
+ provider = self.db_record.provider
63
+ else:
64
+ provider = await info.context.data_loaders.generative_model_fields.load(
65
+ (self.id, models.GenerativeModel.provider),
66
+ )
67
+ return provider or None
68
+
69
+ @strawberry.field
70
+ async def name_pattern(self, info: Info[Context, None]) -> str:
71
+ if self.db_record:
72
+ pattern = self.db_record.name_pattern.pattern
73
+ else:
74
+ name_pattern_obj = await info.context.data_loaders.generative_model_fields.load(
75
+ (self.id, models.GenerativeModel.name_pattern),
76
+ )
77
+ pattern = name_pattern_obj.pattern
78
+ assert isinstance(pattern, str)
79
+ return pattern
80
+
81
+ @strawberry.field
82
+ async def kind(self, info: Info[Context, None]) -> GenerativeModelKind:
83
+ if self.db_record:
84
+ is_built_in = self.db_record.is_built_in
85
+ else:
86
+ is_built_in = await info.context.data_loaders.generative_model_fields.load(
87
+ (self.id, models.GenerativeModel.is_built_in),
88
+ )
89
+ return GenerativeModelKind.BUILT_IN if is_built_in else GenerativeModelKind.CUSTOM
90
+
91
+ @strawberry.field
92
+ async def created_at(self, info: Info[Context, None]) -> datetime:
93
+ if self.db_record:
94
+ val = self.db_record.created_at
95
+ else:
96
+ val = await info.context.data_loaders.generative_model_fields.load(
97
+ (self.id, models.GenerativeModel.created_at),
98
+ )
99
+ return val
100
+
101
+ @strawberry.field
102
+ async def updated_at(self, info: Info[Context, None]) -> datetime:
103
+ if self.db_record:
104
+ val = self.db_record.updated_at
105
+ else:
106
+ val = await info.context.data_loaders.generative_model_fields.load(
107
+ (self.id, models.GenerativeModel.updated_at),
108
+ )
109
+ return val
110
+
111
+ @strawberry.field
112
+ async def provider_key(self, info: Info[Context, None]) -> Optional[GenerativeProviderKey]:
113
+ if self.db_record:
114
+ provider = self.db_record.provider
115
+ else:
116
+ provider = await info.context.data_loaders.generative_model_fields.load(
117
+ (self.id, models.GenerativeModel.provider),
118
+ )
119
+ return _semconv_provider_to_gql_generative_provider_key(provider) if provider else None
120
+
121
+ @strawberry.field
122
+ async def start_time(self, info: Info[Context, None]) -> Optional[datetime]:
123
+ if self.db_record:
124
+ val = self.db_record.start_time
125
+ else:
126
+ val = await info.context.data_loaders.generative_model_fields.load(
127
+ (self.id, models.GenerativeModel.start_time),
128
+ )
129
+ return val
130
+
54
131
  def add_cached_cost_summary(
55
132
  self, project_id: Optional[int], time_range: TimeRange, cost_summary: SpanCostSummary
56
133
  ) -> None:
@@ -61,11 +138,10 @@ class GenerativeModel(Node, ModelInterface):
61
138
  self.cached_cost_summary[cache_key] = cost_summary
62
139
 
63
140
  @strawberry.field
64
- async def token_prices(self) -> list[TokenPrice]:
65
- if self.costs is None:
66
- raise NotImplementedError
67
- token_prices: list[TokenPrice] = list()
68
- for cost in self.costs:
141
+ async def token_prices(self, info: Info[Context, None]) -> list[TokenPrice]:
142
+ costs = await info.context.data_loaders.token_prices_by_model.load(self.id)
143
+ token_prices: list[TokenPrice] = []
144
+ for cost in costs:
69
145
  token_prices.append(
70
146
  TokenPrice(
71
147
  token_type=cost.token_type,
@@ -100,7 +176,7 @@ class GenerativeModel(Node, ModelInterface):
100
176
  )
101
177
 
102
178
  loader = info.context.data_loaders.span_cost_summary_by_generative_model
103
- summary = await loader.load(self.id_attr)
179
+ summary = await loader.load(self.id)
104
180
  return SpanCostSummary(
105
181
  prompt=CostBreakdown(
106
182
  tokens=summary.prompt.tokens,
@@ -122,7 +198,7 @@ class GenerativeModel(Node, ModelInterface):
122
198
  info: Info[Context, None],
123
199
  ) -> list[SpanCostDetailSummaryEntry]:
124
200
  loader = info.context.data_loaders.span_cost_detail_summary_entries_by_generative_model
125
- summary = await loader.load(self.id_attr)
201
+ summary = await loader.load(self.id)
126
202
  return [
127
203
  SpanCostDetailSummaryEntry(
128
204
  token_type=entry.token_type,
@@ -137,30 +213,7 @@ class GenerativeModel(Node, ModelInterface):
137
213
 
138
214
  @strawberry.field
139
215
  async def last_used_at(self, info: Info[Context, None]) -> Optional[datetime]:
140
- model_id = self.id_attr
141
- return await info.context.data_loaders.last_used_times_by_generative_model_id.load(model_id)
142
-
143
-
144
- def to_gql_generative_model(
145
- model: models.GenerativeModel,
146
- ) -> GenerativeModel:
147
- costs_are_loaded = isinstance(inspect(model).attrs.token_prices.loaded_value, list)
148
- name_pattern = model.name_pattern.pattern
149
- assert isinstance(name_pattern, str)
150
- return GenerativeModel(
151
- id_attr=model.id,
152
- name=model.name,
153
- provider=model.provider or None,
154
- name_pattern=name_pattern,
155
- kind=GenerativeModelKind.BUILT_IN if model.is_built_in else GenerativeModelKind.CUSTOM,
156
- created_at=model.created_at,
157
- updated_at=model.updated_at,
158
- start_time=model.start_time,
159
- provider_key=_semconv_provider_to_gql_generative_provider_key(model.provider)
160
- if model.provider
161
- else None,
162
- costs=model.token_prices if costs_are_loaded else None,
163
- )
216
+ return await info.context.data_loaders.last_used_times_by_generative_model_id.load(self.id)
164
217
 
165
218
 
166
219
  def _semconv_provider_to_gql_generative_provider_key(
@@ -13,7 +13,7 @@ class GenerativeProviderKey(Enum):
13
13
  OPENAI = "OpenAI"
14
14
  ANTHROPIC = "Anthropic"
15
15
  AZURE_OPENAI = "Azure OpenAI"
16
- GOOGLE = "Google AI Studio"
16
+ GOOGLE = "Google Gemini"
17
17
  DEEPSEEK = "DeepSeek"
18
18
  XAI = "xAI"
19
19
  OLLAMA = "Ollama"
@@ -7,5 +7,10 @@ from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
7
7
 
8
8
  @strawberry.interface
9
9
  class ModelInterface:
10
- name: str
11
- provider_key: Optional[GenerativeProviderKey]
10
+ @strawberry.field
11
+ async def name(self) -> str:
12
+ raise NotImplementedError
13
+
14
+ @strawberry.field
15
+ async def provider_key(self) -> Optional[GenerativeProviderKey]:
16
+ raise NotImplementedError