arize-phoenix 11.23.1__py3-none-any.whl → 12.28.1__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (221)
  1. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +61 -36
  2. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/RECORD +212 -162
  3. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
  4. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
  5. phoenix/__generated__/__init__.py +0 -0
  6. phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
  7. phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
  8. phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
  9. phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
  10. phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
  11. phoenix/__init__.py +2 -1
  12. phoenix/auth.py +27 -2
  13. phoenix/config.py +1594 -81
  14. phoenix/db/README.md +546 -28
  15. phoenix/db/bulk_inserter.py +119 -116
  16. phoenix/db/engines.py +140 -33
  17. phoenix/db/facilitator.py +22 -1
  18. phoenix/db/helpers.py +818 -65
  19. phoenix/db/iam_auth.py +64 -0
  20. phoenix/db/insertion/dataset.py +133 -1
  21. phoenix/db/insertion/document_annotation.py +9 -6
  22. phoenix/db/insertion/evaluation.py +2 -3
  23. phoenix/db/insertion/helpers.py +2 -2
  24. phoenix/db/insertion/session_annotation.py +176 -0
  25. phoenix/db/insertion/span_annotation.py +3 -4
  26. phoenix/db/insertion/trace_annotation.py +3 -4
  27. phoenix/db/insertion/types.py +41 -18
  28. phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
  29. phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
  30. phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
  31. phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
  32. phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
  33. phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
  34. phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
  35. phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
  36. phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
  37. phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
  38. phoenix/db/models.py +364 -56
  39. phoenix/db/pg_config.py +10 -0
  40. phoenix/db/types/trace_retention.py +7 -6
  41. phoenix/experiments/functions.py +69 -19
  42. phoenix/inferences/inferences.py +1 -2
  43. phoenix/server/api/auth.py +9 -0
  44. phoenix/server/api/auth_messages.py +46 -0
  45. phoenix/server/api/context.py +60 -0
  46. phoenix/server/api/dataloaders/__init__.py +36 -0
  47. phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
  48. phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
  49. phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
  50. phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
  51. phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
  52. phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
  53. phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
  54. phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
  55. phoenix/server/api/dataloaders/dataset_labels.py +36 -0
  56. phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
  57. phoenix/server/api/dataloaders/document_evaluations.py +6 -9
  58. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
  59. phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
  60. phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
  61. phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
  62. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
  63. phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
  64. phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
  65. phoenix/server/api/dataloaders/record_counts.py +37 -10
  66. phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
  67. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
  68. phoenix/server/api/dataloaders/span_cost_summary_by_project.py +28 -14
  69. phoenix/server/api/dataloaders/span_costs.py +3 -9
  70. phoenix/server/api/dataloaders/table_fields.py +2 -2
  71. phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
  72. phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
  73. phoenix/server/api/exceptions.py +5 -1
  74. phoenix/server/api/helpers/playground_clients.py +263 -83
  75. phoenix/server/api/helpers/playground_spans.py +2 -1
  76. phoenix/server/api/helpers/playground_users.py +26 -0
  77. phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
  78. phoenix/server/api/helpers/prompts/models.py +61 -19
  79. phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
  80. phoenix/server/api/input_types/ChatCompletionInput.py +3 -0
  81. phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
  82. phoenix/server/api/input_types/DatasetFilter.py +5 -2
  83. phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
  84. phoenix/server/api/input_types/GenerativeModelInput.py +3 -0
  85. phoenix/server/api/input_types/ProjectSessionSort.py +158 -1
  86. phoenix/server/api/input_types/PromptVersionInput.py +47 -1
  87. phoenix/server/api/input_types/SpanSort.py +3 -2
  88. phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
  89. phoenix/server/api/input_types/UserRoleInput.py +1 -0
  90. phoenix/server/api/mutations/__init__.py +8 -0
  91. phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
  92. phoenix/server/api/mutations/api_key_mutations.py +15 -20
  93. phoenix/server/api/mutations/chat_mutations.py +106 -37
  94. phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
  95. phoenix/server/api/mutations/dataset_mutations.py +21 -16
  96. phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
  97. phoenix/server/api/mutations/experiment_mutations.py +2 -2
  98. phoenix/server/api/mutations/export_events_mutations.py +3 -3
  99. phoenix/server/api/mutations/model_mutations.py +11 -9
  100. phoenix/server/api/mutations/project_mutations.py +4 -4
  101. phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
  102. phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
  103. phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
  104. phoenix/server/api/mutations/prompt_mutations.py +65 -129
  105. phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
  106. phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
  107. phoenix/server/api/mutations/trace_annotations_mutations.py +13 -8
  108. phoenix/server/api/mutations/trace_mutations.py +3 -3
  109. phoenix/server/api/mutations/user_mutations.py +55 -26
  110. phoenix/server/api/queries.py +501 -617
  111. phoenix/server/api/routers/__init__.py +2 -2
  112. phoenix/server/api/routers/auth.py +141 -87
  113. phoenix/server/api/routers/ldap.py +229 -0
  114. phoenix/server/api/routers/oauth2.py +349 -101
  115. phoenix/server/api/routers/v1/__init__.py +22 -4
  116. phoenix/server/api/routers/v1/annotation_configs.py +19 -30
  117. phoenix/server/api/routers/v1/annotations.py +455 -13
  118. phoenix/server/api/routers/v1/datasets.py +355 -68
  119. phoenix/server/api/routers/v1/documents.py +142 -0
  120. phoenix/server/api/routers/v1/evaluations.py +20 -28
  121. phoenix/server/api/routers/v1/experiment_evaluations.py +16 -6
  122. phoenix/server/api/routers/v1/experiment_runs.py +335 -59
  123. phoenix/server/api/routers/v1/experiments.py +475 -47
  124. phoenix/server/api/routers/v1/projects.py +16 -50
  125. phoenix/server/api/routers/v1/prompts.py +50 -39
  126. phoenix/server/api/routers/v1/sessions.py +108 -0
  127. phoenix/server/api/routers/v1/spans.py +156 -96
  128. phoenix/server/api/routers/v1/traces.py +51 -77
  129. phoenix/server/api/routers/v1/users.py +64 -24
  130. phoenix/server/api/routers/v1/utils.py +3 -7
  131. phoenix/server/api/subscriptions.py +257 -93
  132. phoenix/server/api/types/Annotation.py +90 -23
  133. phoenix/server/api/types/ApiKey.py +13 -17
  134. phoenix/server/api/types/AuthMethod.py +1 -0
  135. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
  136. phoenix/server/api/types/Dataset.py +199 -72
  137. phoenix/server/api/types/DatasetExample.py +88 -18
  138. phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
  139. phoenix/server/api/types/DatasetLabel.py +57 -0
  140. phoenix/server/api/types/DatasetSplit.py +98 -0
  141. phoenix/server/api/types/DatasetVersion.py +49 -4
  142. phoenix/server/api/types/DocumentAnnotation.py +212 -0
  143. phoenix/server/api/types/Experiment.py +215 -68
  144. phoenix/server/api/types/ExperimentComparison.py +3 -9
  145. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
  146. phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
  147. phoenix/server/api/types/ExperimentRun.py +120 -70
  148. phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
  149. phoenix/server/api/types/GenerativeModel.py +95 -42
  150. phoenix/server/api/types/GenerativeProvider.py +1 -1
  151. phoenix/server/api/types/ModelInterface.py +7 -2
  152. phoenix/server/api/types/PlaygroundModel.py +12 -2
  153. phoenix/server/api/types/Project.py +218 -185
  154. phoenix/server/api/types/ProjectSession.py +146 -29
  155. phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
  156. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
  157. phoenix/server/api/types/Prompt.py +119 -39
  158. phoenix/server/api/types/PromptLabel.py +42 -25
  159. phoenix/server/api/types/PromptVersion.py +11 -8
  160. phoenix/server/api/types/PromptVersionTag.py +65 -25
  161. phoenix/server/api/types/Span.py +130 -123
  162. phoenix/server/api/types/SpanAnnotation.py +189 -42
  163. phoenix/server/api/types/SystemApiKey.py +65 -1
  164. phoenix/server/api/types/Trace.py +184 -53
  165. phoenix/server/api/types/TraceAnnotation.py +149 -50
  166. phoenix/server/api/types/User.py +128 -33
  167. phoenix/server/api/types/UserApiKey.py +73 -26
  168. phoenix/server/api/types/node.py +10 -0
  169. phoenix/server/api/types/pagination.py +11 -2
  170. phoenix/server/app.py +154 -36
  171. phoenix/server/authorization.py +5 -4
  172. phoenix/server/bearer_auth.py +13 -5
  173. phoenix/server/cost_tracking/cost_model_lookup.py +42 -14
  174. phoenix/server/cost_tracking/model_cost_manifest.json +1085 -194
  175. phoenix/server/daemons/generative_model_store.py +61 -9
  176. phoenix/server/daemons/span_cost_calculator.py +10 -8
  177. phoenix/server/dml_event.py +13 -0
  178. phoenix/server/email/sender.py +29 -2
  179. phoenix/server/grpc_server.py +9 -9
  180. phoenix/server/jwt_store.py +8 -6
  181. phoenix/server/ldap.py +1449 -0
  182. phoenix/server/main.py +9 -3
  183. phoenix/server/oauth2.py +330 -12
  184. phoenix/server/prometheus.py +43 -6
  185. phoenix/server/rate_limiters.py +4 -9
  186. phoenix/server/retention.py +33 -20
  187. phoenix/server/session_filters.py +49 -0
  188. phoenix/server/static/.vite/manifest.json +51 -53
  189. phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
  190. phoenix/server/static/assets/{index-BPCwGQr8.js → index-CTQoemZv.js} +42 -35
  191. phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
  192. phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
  193. phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
  194. phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
  195. phoenix/server/static/assets/{vendor-recharts-Bw30oz1A.js → vendor-recharts-V9cwpXsm.js} +7 -7
  196. phoenix/server/static/assets/{vendor-shiki-DZajAPeq.js → vendor-shiki-Do--csgv.js} +1 -1
  197. phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
  198. phoenix/server/templates/index.html +7 -1
  199. phoenix/server/thread_server.py +1 -2
  200. phoenix/server/utils.py +74 -0
  201. phoenix/session/client.py +55 -1
  202. phoenix/session/data_extractor.py +5 -0
  203. phoenix/session/evaluation.py +8 -4
  204. phoenix/session/session.py +44 -8
  205. phoenix/settings.py +2 -0
  206. phoenix/trace/attributes.py +80 -13
  207. phoenix/trace/dsl/query.py +2 -0
  208. phoenix/trace/projects.py +5 -0
  209. phoenix/utilities/template_formatters.py +1 -1
  210. phoenix/version.py +1 -1
  211. phoenix/server/api/types/Evaluation.py +0 -39
  212. phoenix/server/static/assets/components-D0DWAf0l.js +0 -5650
  213. phoenix/server/static/assets/pages-Creyamao.js +0 -8612
  214. phoenix/server/static/assets/vendor-CU36oj8y.js +0 -905
  215. phoenix/server/static/assets/vendor-CqDb5u4o.css +0 -1
  216. phoenix/server/static/assets/vendor-arizeai-Ctgw0e1G.js +0 -168
  217. phoenix/server/static/assets/vendor-codemirror-Cojjzqb9.js +0 -25
  218. phoenix/server/static/assets/vendor-three-BLWp5bic.js +0 -2998
  219. phoenix/utilities/deprecation.py +0 -31
  220. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
  221. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
--- a/phoenix/server/api/dataloaders/annotation_summaries.py
+++ b/phoenix/server/api/dataloaders/annotation_summaries.py
@@ -13,6 +13,7 @@ from phoenix.db import models
 from phoenix.server.api.dataloaders.cache import TwoTierCache
 from phoenix.server.api.input_types.TimeRange import TimeRange
 from phoenix.server.api.types.AnnotationSummary import AnnotationSummary
+from phoenix.server.session_filters import get_filtered_session_rowids_subquery
 from phoenix.server.types import DbSessionFactory
 from phoenix.trace.dsl import SpanFilter
 
@@ -20,27 +21,41 @@ Kind: TypeAlias = Literal["span", "trace"]
 ProjectRowId: TypeAlias = int
 TimeInterval: TypeAlias = tuple[Optional[datetime], Optional[datetime]]
 FilterCondition: TypeAlias = Optional[str]
+SessionFilterCondition: TypeAlias = Optional[str]
 AnnotationName: TypeAlias = str
 
-Segment: TypeAlias = tuple[Kind, ProjectRowId, TimeInterval, FilterCondition]
+Segment: TypeAlias = tuple[
+    Kind,
+    ProjectRowId,
+    TimeInterval,
+    FilterCondition,
+    SessionFilterCondition,
+]
 Param: TypeAlias = AnnotationName
 
-Key: TypeAlias = tuple[Kind, ProjectRowId, Optional[TimeRange], FilterCondition, AnnotationName]
+Key: TypeAlias = tuple[
+    Kind,
+    ProjectRowId,
+    Optional[TimeRange],
+    FilterCondition,
+    SessionFilterCondition,
+    AnnotationName,
+]
 Result: TypeAlias = Optional[AnnotationSummary]
 ResultPosition: TypeAlias = int
 DEFAULT_VALUE: Result = None
 
 
 def _cache_key_fn(key: Key) -> tuple[Segment, Param]:
-    kind, project_rowid, time_range, filter_condition, eval_name = key
+    kind, project_rowid, time_range, filter_condition, session_filter_condition, eval_name = key
     interval = (
         (time_range.start, time_range.end) if isinstance(time_range, TimeRange) else (None, None)
     )
-    return (kind, project_rowid, interval, filter_condition), eval_name
+    return (kind, project_rowid, interval, filter_condition, session_filter_condition), eval_name
 
 
 _Section: TypeAlias = tuple[ProjectRowId, AnnotationName, Kind]
-_SubKey: TypeAlias = tuple[TimeInterval, FilterCondition]
+_SubKey: TypeAlias = tuple[TimeInterval, FilterCondition, SessionFilterCondition]
 
 
 class AnnotationSummaryCache(
@@ -61,8 +76,21 @@ class AnnotationSummaryCache(
         del self._cache[section]
 
     def _cache_key(self, key: Key) -> tuple[_Section, _SubKey]:
-        (kind, project_rowid, interval, filter_condition), annotation_name = _cache_key_fn(key)
-        return (project_rowid, annotation_name, kind), (interval, filter_condition)
+        (
+            (
+                kind,
+                project_rowid,
+                interval,
+                filter_condition,
+                session_filter_condition,
+            ),
+            annotation_name,
+        ) = _cache_key_fn(key)
+        return (project_rowid, annotation_name, kind), (
+            interval,
+            filter_condition,
+            session_filter_condition,
+        )
 
 
 class AnnotationSummaryDataLoader(DataLoader[Key, Result]):
@@ -102,7 +130,9 @@ def _get_stmt(
     segment: Segment,
     *annotation_names: Param,
 ) -> Select[Any]:
-    kind, project_rowid, (start_time, end_time), filter_condition = segment
+    kind, project_rowid, (start_time, end_time), filter_condition, session_filter_condition = (
+        segment
+    )
 
     annotation_model: Union[Type[models.SpanAnnotation], Type[models.TraceAnnotation]]
    entity_model: Union[Type[models.Span], Type[models.Trace]]
@@ -144,6 +174,19 @@ def _get_stmt(
         entity_count_query = entity_count_query.where(
             cast(Type[models.Trace], entity_model).project_rowid == project_rowid
         )
+    else:
+        assert_never(kind)
+
+    if session_filter_condition:
+        filtered_session_rowids = get_filtered_session_rowids_subquery(
+            session_filter_condition=session_filter_condition,
+            project_rowids=[project_rowid],
+            start_time=start_time,
+            end_time=end_time,
+        )
+        entity_count_query = entity_count_query.where(
+            models.Trace.project_session_rowid.in_(filtered_session_rowids)
+        )
 
     entity_count_query = entity_count_query.where(
         or_(score_column.is_not(None), label_column.is_not(None))
@@ -186,6 +229,15 @@ def _get_stmt(
     else:
         assert_never(kind)
 
+    if session_filter_condition:
+        filtered_session_rowids = get_filtered_session_rowids_subquery(
+            session_filter_condition=session_filter_condition,
+            project_rowids=[project_rowid],
+            start_time=start_time,
+            end_time=end_time,
+        )
+        base_stmt = base_stmt.where(models.Trace.project_session_rowid.in_(filtered_session_rowids))
+
     base_stmt = base_stmt.where(or_(score_column.is_not(None), label_column.is_not(None)))
     base_stmt = base_stmt.where(name_column.in_(annotation_names))
 
--- /dev/null
+++ b/phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py
@@ -0,0 +1,50 @@
+from typing import Optional
+
+from sqlalchemy import func, select, tuple_
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+from phoenix.server.types import DbSessionFactory
+
+ExperimentID: TypeAlias = int
+DatasetExampleID: TypeAlias = int
+RunLatency: TypeAlias = float
+Key: TypeAlias = tuple[ExperimentID, DatasetExampleID]
+Result: TypeAlias = Optional[RunLatency]
+
+
+class AverageExperimentRepeatedRunGroupLatencyDataLoader(DataLoader[Key, Result]):
+    def __init__(
+        self,
+        db: DbSessionFactory,
+    ) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: list[Key]) -> list[Result]:
+        average_latency_query = (
+            select(
+                models.ExperimentRun.experiment_id.label("experiment_id"),
+                models.ExperimentRun.dataset_example_id.label("example_id"),
+                func.avg(models.ExperimentRun.latency_ms).label("average_repetition_latency_ms"),
+            )
+            .select_from(models.ExperimentRun)
+            .where(
+                tuple_(
+                    models.ExperimentRun.experiment_id, models.ExperimentRun.dataset_example_id
+                ).in_(set(keys))
+            )
+            .group_by(models.ExperimentRun.experiment_id, models.ExperimentRun.dataset_example_id)
+        )
+        async with self._db() as session:
+            average_run_latencies_ms = {
+                (experiment_id, example_id): average_run_latency_ms
+                async for experiment_id, example_id, average_run_latency_ms in await session.stream(
+                    average_latency_query
+                )
+            }
+        return [
+            average_run_latencies_ms.get((experiment_id, example_id))
+            for experiment_id, example_id in keys
+        ]
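
For context: a Strawberry DataLoader like the one above coalesces every key requested in the same event-loop tick into one batched query. A minimal usage sketch, assuming an already-configured DbSessionFactory named db (the IDs below are hypothetical):

import asyncio

async def resolve_latencies(db):  # `db`: assumed, pre-configured DbSessionFactory
    loader = AverageExperimentRepeatedRunGroupLatencyDataLoader(db)
    # Both load() calls land in the same tick, so _load_fn receives
    # [(1, 10), (1, 11)] and issues a single SQL statement.
    return await asyncio.gather(
        loader.load((1, 10)),  # (experiment_id, dataset_example_id)
        loader.load((1, 11)),  # resolves to None when the pair has no runs
    )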
--- a/phoenix/server/api/dataloaders/average_experiment_run_latency.py
+++ b/phoenix/server/api/dataloaders/average_experiment_run_latency.py
@@ -23,32 +23,25 @@ class AverageExperimentRunLatencyDataLoader(DataLoader[Key, Result]):
 
     async def _load_fn(self, keys: list[Key]) -> list[Result]:
         experiment_ids = keys
-        resolved_experiment_ids = (
-            select(models.Experiment.id)
-            .where(models.Experiment.id.in_(set(experiment_ids)))
-            .subquery()
-        )
-        query = (
+        average_repetition_latency_ms = (
             select(
-                resolved_experiment_ids.c.id,
-                func.avg(
-                    func.extract("epoch", models.ExperimentRun.end_time)
-                    - func.extract("epoch", models.ExperimentRun.start_time)
-                ),
+                models.ExperimentRun.experiment_id.label("experiment_id"),
+                func.avg(models.ExperimentRun.latency_ms).label("average_repetition_latency_ms"),
             )
-            .outerjoin_from(
-                from_=resolved_experiment_ids,
-                target=models.ExperimentRun,
-                onclause=resolved_experiment_ids.c.id == models.ExperimentRun.experiment_id,
-            )
-            .group_by(resolved_experiment_ids.c.id)
+            .select_from(models.ExperimentRun)
+            .where(models.ExperimentRun.experiment_id.in_(experiment_ids))
+            .group_by(models.ExperimentRun.dataset_example_id, models.ExperimentRun.experiment_id)
+            .subquery()
         )
+        query = select(
+            average_repetition_latency_ms.c.experiment_id,
+            func.avg(average_repetition_latency_ms.c.average_repetition_latency_ms).label(
+                "average_run_latency_ms"
+            ),
+        ).group_by(average_repetition_latency_ms.c.experiment_id)
         async with self._db() as session:
-            avg_latencies = {
-                experiment_id: avg_latency
-                async for experiment_id, avg_latency in await session.stream(query)
+            average_run_latencies_ms = {
+                experiment_id: average_run_latency_ms
+                async for experiment_id, average_run_latency_ms in await session.stream(query)
             }
-        return [
-            avg_latencies.get(experiment_id, ValueError(f"Unknown experiment: {experiment_id}"))
-            for experiment_id in keys
-        ]
+        return [average_run_latencies_ms.get(experiment_id) for experiment_id in keys]
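
The rewrite above changes the semantics, not just the shape, of the query: instead of a flat average over all runs (derived from epoch timestamps), it averages the stored latency_ms per dataset example first and then averages those means, so an example with many repetitions no longer dominates the experiment-level figure. A small self-contained illustration of the difference, with made-up numbers (the experiment_annotation_summaries rewrite further below applies the same per-example averaging to annotation scores):

# Example "A" ran twice, example "B" once.
latencies = {"A": [100.0, 300.0], "B": [50.0]}

# Old behavior: flat mean over all runs -> (100 + 300 + 50) / 3 = 150.0
flat_mean = sum(sum(v) for v in latencies.values()) / sum(len(v) for v in latencies.values())

# New behavior: mean of per-example means -> (200 + 50) / 2 = 125.0
per_example_means = [sum(v) / len(v) for v in latencies.values()]
mean_of_means = sum(per_example_means) / len(per_example_means)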
--- a/phoenix/server/api/dataloaders/cache/two_tier_cache.py
+++ b/phoenix/server/api/dataloaders/cache/two_tier_cache.py
@@ -7,7 +7,7 @@ single-tier system we would need to check all the keys to see if they are in the
 subset that we want to invalidate.
 """
 
-from abc import ABC, abstractmethod
+from abc import abstractmethod
 from asyncio import Future
 from collections.abc import Callable
 from typing import Any, Generic, Optional, TypeVar
@@ -25,7 +25,6 @@ _SubKey = TypeVar("_SubKey")
 class TwoTierCache(
     AbstractCache[_Key, _Result],
     Generic[_Key, _Result, _Section, _SubKey],
-    ABC,
 ):
     def __init__(
         self,
--- /dev/null
+++ b/phoenix/server/api/dataloaders/dataset_dataset_splits.py
@@ -0,0 +1,52 @@
+from sqlalchemy import select
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+from phoenix.server.types import DbSessionFactory
+
+DatasetID: TypeAlias = int
+Key: TypeAlias = DatasetID
+Result: TypeAlias = list[models.DatasetSplit]
+
+
+class DatasetDatasetSplitsDataLoader(DataLoader[Key, Result]):
+    def __init__(self, db: DbSessionFactory) -> None:
+        super().__init__(
+            load_fn=self._load_fn,
+        )
+        self._db = db
+
+    async def _load_fn(self, keys: list[Key]) -> list[Result]:
+        dataset_ids = keys
+        async with self._db() as session:
+            splits: dict[DatasetID, dict[int, models.DatasetSplit]] = {
+                dataset_id: {} for dataset_id in dataset_ids
+            }
+
+            async for dataset_id, split in await session.stream(
+                select(models.DatasetExample.dataset_id, models.DatasetSplit)
+                .select_from(models.DatasetSplit)
+                .join(
+                    models.DatasetSplitDatasetExample,
+                    onclause=(
+                        models.DatasetSplit.id == models.DatasetSplitDatasetExample.dataset_split_id
+                    ),
+                )
+                .join(
+                    models.DatasetExample,
+                    onclause=(
+                        models.DatasetSplitDatasetExample.dataset_example_id
+                        == models.DatasetExample.id
+                    ),
+                )
+                .where(models.DatasetExample.dataset_id.in_(dataset_ids))
+            ):
+                # Use dict to deduplicate splits by split.id
+                if dataset_id in splits:
+                    splits[dataset_id][split.id] = split
+
+        return [
+            sorted(splits.get(dataset_id, {}).values(), key=lambda x: x.name)
+            for dataset_id in keys
+        ]
--- a/phoenix/server/api/dataloaders/dataset_example_revisions.py
+++ b/phoenix/server/api/dataloaders/dataset_example_revisions.py
@@ -91,7 +91,6 @@ class DatasetExampleRevisionsDataLoader(DataLoader[Key, Result]):
                 onclause=revision_ids.c.version_id == models.DatasetVersion.id,
                 isouter=True,  # keep rows where the version id is null
             )
-            .where(models.DatasetExampleRevision.revision_kind != "DELETE")
         )
         async with self._db() as session:
             results = {
--- /dev/null
+++ b/phoenix/server/api/dataloaders/dataset_example_splits.py
@@ -0,0 +1,40 @@
+from sqlalchemy import select
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+from phoenix.server.types import DbSessionFactory
+
+ExampleID: TypeAlias = int
+Key: TypeAlias = ExampleID
+Result: TypeAlias = list[models.DatasetSplit]
+
+
+class DatasetExampleSplitsDataLoader(DataLoader[Key, Result]):
+    def __init__(self, db: DbSessionFactory) -> None:
+        super().__init__(
+            load_fn=self._load_fn,
+        )
+        self._db = db
+
+    async def _load_fn(self, keys: list[Key]) -> list[Result]:
+        example_ids = keys
+        async with self._db() as session:
+            splits: dict[ExampleID, list[models.DatasetSplit]] = {}
+
+            async for example_id, split in await session.stream(
+                select(models.DatasetSplitDatasetExample.dataset_example_id, models.DatasetSplit)
+                .select_from(models.DatasetSplit)
+                .join(
+                    models.DatasetSplitDatasetExample,
+                    onclause=(
+                        models.DatasetSplit.id == models.DatasetSplitDatasetExample.dataset_split_id
+                    ),
+                )
+                .where(models.DatasetSplitDatasetExample.dataset_example_id.in_(example_ids))
+            ):
+                if example_id not in splits:
+                    splits[example_id] = []
+                splits[example_id].append(split)
+
+        return [sorted(splits.get(example_id, []), key=lambda x: x.name) for example_id in keys]
--- /dev/null
+++ b/phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py
@@ -0,0 +1,47 @@
+from sqlalchemy import select
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+from phoenix.server.types import DbSessionFactory
+
+ExperimentRunID: TypeAlias = int
+DatasetExampleID: TypeAlias = int
+DatasetVersionID: TypeAlias = int
+Key: TypeAlias = ExperimentRunID
+Result: TypeAlias = tuple[models.DatasetExample, DatasetVersionID]
+
+
+class DatasetExamplesAndVersionsByExperimentRunDataLoader(DataLoader[Key, Result]):
+    def __init__(self, db: DbSessionFactory) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: list[Key]) -> list[Result]:
+        experiment_run_ids = set(keys)
+        examples_and_versions_query = (
+            select(
+                models.ExperimentRun.id.label("experiment_run_id"),
+                models.DatasetExample,
+                models.Experiment.dataset_version_id.label("dataset_version_id"),
+            )
+            .select_from(models.ExperimentRun)
+            .join(
+                models.DatasetExample,
+                models.DatasetExample.id == models.ExperimentRun.dataset_example_id,
+            )
+            .join(
+                models.Experiment,
+                models.Experiment.id == models.ExperimentRun.experiment_id,
+            )
+            .where(models.ExperimentRun.id.in_(experiment_run_ids))
+        )
+        async with self._db() as session:
+            examples_and_versions = {
+                experiment_run_id: (example, version_id)
+                for experiment_run_id, example, version_id in (
+                    await session.execute(examples_and_versions_query)
+                ).all()
+            }
+
+        return [examples_and_versions[key] for key in keys]
--- /dev/null
+++ b/phoenix/server/api/dataloaders/dataset_labels.py
@@ -0,0 +1,36 @@
+from sqlalchemy import select
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+from phoenix.server.types import DbSessionFactory
+
+DatasetID: TypeAlias = int
+Key: TypeAlias = DatasetID
+Result: TypeAlias = list[models.DatasetLabel]
+
+
+class DatasetLabelsDataLoader(DataLoader[Key, Result]):
+    def __init__(self, db: DbSessionFactory) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: list[Key]) -> list[Result]:
+        dataset_ids = keys
+        async with self._db() as session:
+            labels: dict[Key, Result] = {}
+            for dataset_id, label in await session.execute(
+                select(models.DatasetsDatasetLabel.dataset_id, models.DatasetLabel)
+                .select_from(models.DatasetLabel)
+                .join(
+                    models.DatasetsDatasetLabel,
+                    models.DatasetLabel.id == models.DatasetsDatasetLabel.dataset_label_id,
+                )
+                .where(models.DatasetsDatasetLabel.dataset_id.in_(dataset_ids))
+            ):
+                if dataset_id not in labels:
+                    labels[dataset_id] = []
+                labels[dataset_id].append(label)
+        return [
+            sorted(labels.get(dataset_id, []), key=lambda label: label.name) for dataset_id in keys
+        ]
--- a/phoenix/server/api/dataloaders/document_evaluation_summaries.py
+++ b/phoenix/server/api/dataloaders/document_evaluation_summaries.py
@@ -10,7 +10,7 @@ from strawberry.dataloader import AbstractCache, DataLoader
 from typing_extensions import TypeAlias
 
 from phoenix.db import models
-from phoenix.db.helpers import SupportedSQLDialect, num_docs_col
+from phoenix.db.helpers import SupportedSQLDialect
 from phoenix.metrics.retrieval_metrics import RetrievalMetrics
 from phoenix.server.api.dataloaders.cache import TwoTierCache
 from phoenix.server.api.input_types.TimeRange import TimeRange
@@ -122,7 +122,7 @@ def _get_stmt(
         select(
             mda.name,
             models.Span.id,
-            num_docs_col(dialect),
+            models.Span.num_documents.label("num_docs"),
             mda.score,
             mda.document_position,
         )
--- a/phoenix/server/api/dataloaders/document_evaluations.py
+++ b/phoenix/server/api/dataloaders/document_evaluations.py
@@ -5,11 +5,10 @@ from strawberry.dataloader import DataLoader
 from typing_extensions import TypeAlias
 
 from phoenix.db import models
-from phoenix.server.api.types.Evaluation import DocumentEvaluation
 from phoenix.server.types import DbSessionFactory
 
 Key: TypeAlias = int
-Result: TypeAlias = list[DocumentEvaluation]
+Result: TypeAlias = list[models.DocumentAnnotation]
 
 
 class DocumentEvaluationsDataLoader(DataLoader[Key, Result]):
@@ -18,14 +17,12 @@ class DocumentEvaluationsDataLoader(DataLoader[Key, Result]):
         self._db = db
 
     async def _load_fn(self, keys: list[Key]) -> list[Result]:
-        document_evaluations_by_id: defaultdict[Key, Result] = defaultdict(list)
+        document_annotations_by_id: defaultdict[Key, Result] = defaultdict(list)
         mda = models.DocumentAnnotation
         async with self._db() as session:
-            data = await session.stream_scalars(
-                select(mda).where(mda.span_rowid.in_(keys)).where(mda.annotator_kind == "LLM")
-            )
+            data = await session.stream_scalars(select(mda).where(mda.span_rowid.in_(keys)))
             async for document_evaluation in data:
-                document_evaluations_by_id[document_evaluation.span_rowid].append(
-                    DocumentEvaluation.from_sql_document_annotation(document_evaluation)
+                document_annotations_by_id[document_evaluation.span_rowid].append(
+                    document_evaluation
                 )
-        return [document_evaluations_by_id[key] for key in keys]
+        return [document_annotations_by_id[key] for key in keys]
--- a/phoenix/server/api/dataloaders/experiment_annotation_summaries.py
+++ b/phoenix/server/api/dataloaders/experiment_annotation_summaries.py
@@ -2,7 +2,7 @@ from collections import defaultdict
 from dataclasses import dataclass
 from typing import Optional
 
-from sqlalchemy import func, select
+from sqlalchemy import and_, func, select
 from strawberry.dataloader import AbstractCache, DataLoader
 from typing_extensions import TypeAlias
 
@@ -37,43 +37,97 @@ class ExperimentAnnotationSummaryDataLoader(DataLoader[Key, Result]):
     async def _load_fn(self, keys: list[Key]) -> list[Result]:
         experiment_ids = keys
         summaries: defaultdict[ExperimentID, Result] = defaultdict(list)
+        repetition_mean_scores_by_example_subquery = (
+            select(
+                models.ExperimentRun.experiment_id.label("experiment_id"),
+                models.ExperimentRunAnnotation.name.label("annotation_name"),
+                func.avg(models.ExperimentRunAnnotation.score).label("mean_repetition_score"),
+            )
+            .select_from(models.ExperimentRunAnnotation)
+            .join(
+                models.ExperimentRun,
+                models.ExperimentRunAnnotation.experiment_run_id == models.ExperimentRun.id,
+            )
+            .where(models.ExperimentRun.experiment_id.in_(experiment_ids))
+            .group_by(
+                models.ExperimentRun.experiment_id,
+                models.ExperimentRun.dataset_example_id,
+                models.ExperimentRunAnnotation.name,
+            )
+            .subquery()
+            .alias("repetition_mean_scores_by_example")
+        )
+        repetition_mean_scores_subquery = (
+            select(
+                repetition_mean_scores_by_example_subquery.c.experiment_id.label("experiment_id"),
+                repetition_mean_scores_by_example_subquery.c.annotation_name.label(
+                    "annotation_name"
+                ),
+                func.avg(repetition_mean_scores_by_example_subquery.c.mean_repetition_score).label(
+                    "mean_score"
+                ),
+            )
+            .select_from(repetition_mean_scores_by_example_subquery)
+            .group_by(
+                repetition_mean_scores_by_example_subquery.c.experiment_id,
+                repetition_mean_scores_by_example_subquery.c.annotation_name,
+            )
+            .subquery()
+            .alias("repetition_mean_scores")
+        )
+        repetitions_subquery = (
+            select(
+                models.ExperimentRun.experiment_id.label("experiment_id"),
+                models.ExperimentRunAnnotation.name.label("annotation_name"),
+                func.min(models.ExperimentRunAnnotation.score).label("min_score"),
+                func.max(models.ExperimentRunAnnotation.score).label("max_score"),
+                func.count().label("count"),
+                func.count(models.ExperimentRunAnnotation.error).label("error_count"),
+            )
+            .select_from(models.ExperimentRunAnnotation)
+            .join(
+                models.ExperimentRun,
+                models.ExperimentRunAnnotation.experiment_run_id == models.ExperimentRun.id,
+            )
+            .where(models.ExperimentRun.experiment_id.in_(experiment_ids))
+            .group_by(models.ExperimentRun.experiment_id, models.ExperimentRunAnnotation.name)
+            .subquery()
+        )
+        run_scores_query = (
+            select(
+                repetition_mean_scores_subquery.c.experiment_id.label("experiment_id"),
+                repetition_mean_scores_subquery.c.annotation_name.label("annotation_name"),
+                repetition_mean_scores_subquery.c.mean_score.label("mean_score"),
+                repetitions_subquery.c.min_score.label("min_score"),
+                repetitions_subquery.c.max_score.label("max_score"),
+                repetitions_subquery.c.count.label("count_"),
+                repetitions_subquery.c.error_count.label("error_count"),
+            )
+            .select_from(repetition_mean_scores_subquery)
+            .join(
+                repetitions_subquery,
+                and_(
+                    repetitions_subquery.c.experiment_id
+                    == repetition_mean_scores_subquery.c.experiment_id,
+                    repetitions_subquery.c.annotation_name
+                    == repetition_mean_scores_subquery.c.annotation_name,
+                ),
+            )
+            .order_by(repetition_mean_scores_subquery.c.annotation_name)
+        )
         async with self._db() as session:
-            async for (
-                experiment_id,
-                annotation_name,
-                min_score,
-                max_score,
-                mean_score,
-                count,
-                error_count,
-            ) in await session.stream(
-                select(
-                    models.ExperimentRun.experiment_id,
-                    models.ExperimentRunAnnotation.name,
-                    func.min(models.ExperimentRunAnnotation.score),
-                    func.max(models.ExperimentRunAnnotation.score),
-                    func.avg(models.ExperimentRunAnnotation.score),
-                    func.count(),
-                    func.count(models.ExperimentRunAnnotation.error),
-                )
-                .join(
-                    models.ExperimentRun,
-                    models.ExperimentRunAnnotation.experiment_run_id == models.ExperimentRun.id,
-                )
-                .where(models.ExperimentRun.experiment_id.in_(experiment_ids))
-                .group_by(models.ExperimentRun.experiment_id, models.ExperimentRunAnnotation.name)
-            ):
-                summaries[experiment_id].append(
+            async for scores_tuple in await session.stream(run_scores_query):
+                summaries[scores_tuple.experiment_id].append(
                     ExperimentAnnotationSummary(
-                        annotation_name=annotation_name,
-                        min_score=min_score,
-                        max_score=max_score,
-                        mean_score=mean_score,
-                        count=count,
-                        error_count=error_count,
+                        annotation_name=scores_tuple.annotation_name,
+                        min_score=scores_tuple.min_score,
+                        max_score=scores_tuple.max_score,
+                        mean_score=scores_tuple.mean_score,
+                        count=scores_tuple.count_,
+                        error_count=scores_tuple.error_count,
                     )
                 )
         return [
             sorted(summaries[experiment_id], key=lambda summary: summary.annotation_name)
-            for experiment_id in keys
+            for experiment_id in experiment_ids
         ]
--- /dev/null
+++ b/phoenix/server/api/dataloaders/experiment_dataset_splits.py
@@ -0,0 +1,43 @@
+from sqlalchemy import select
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+from phoenix.server.types import DbSessionFactory
+
+ExperimentID: TypeAlias = int
+Key: TypeAlias = ExperimentID
+Result: TypeAlias = list[models.DatasetSplit]
+
+
+class ExperimentDatasetSplitsDataLoader(DataLoader[Key, Result]):
+    def __init__(self, db: DbSessionFactory) -> None:
+        super().__init__(
+            load_fn=self._load_fn,
+        )
+        self._db = db
+
+    async def _load_fn(self, keys: list[Key]) -> list[Result]:
+        experiment_ids = keys
+        async with self._db() as session:
+            splits: dict[ExperimentID, list[models.DatasetSplit]] = {}
+
+            async for experiment_id, split in await session.stream(
+                select(models.ExperimentDatasetSplit.experiment_id, models.DatasetSplit)
+                .select_from(models.DatasetSplit)
+                .join(
+                    models.ExperimentDatasetSplit,
+                    onclause=(
+                        models.DatasetSplit.id == models.ExperimentDatasetSplit.dataset_split_id
+                    ),
+                )
+                .where(models.ExperimentDatasetSplit.experiment_id.in_(experiment_ids))
+            ):
+                if experiment_id not in splits:
+                    splits[experiment_id] = []
+                splits[experiment_id].append(split)
+
+        return [
+            sorted(splits.get(experiment_id, []), key=lambda x: x.name)
+            for experiment_id in keys
+        ]