arize-phoenix 11.23.1__py3-none-any.whl → 12.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (221)
  1. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +61 -36
  2. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/RECORD +212 -162
  3. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
  4. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
  5. phoenix/__generated__/__init__.py +0 -0
  6. phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
  7. phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
  8. phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
  9. phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
  10. phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
  11. phoenix/__init__.py +2 -1
  12. phoenix/auth.py +27 -2
  13. phoenix/config.py +1594 -81
  14. phoenix/db/README.md +546 -28
  15. phoenix/db/bulk_inserter.py +119 -116
  16. phoenix/db/engines.py +140 -33
  17. phoenix/db/facilitator.py +22 -1
  18. phoenix/db/helpers.py +818 -65
  19. phoenix/db/iam_auth.py +64 -0
  20. phoenix/db/insertion/dataset.py +133 -1
  21. phoenix/db/insertion/document_annotation.py +9 -6
  22. phoenix/db/insertion/evaluation.py +2 -3
  23. phoenix/db/insertion/helpers.py +2 -2
  24. phoenix/db/insertion/session_annotation.py +176 -0
  25. phoenix/db/insertion/span_annotation.py +3 -4
  26. phoenix/db/insertion/trace_annotation.py +3 -4
  27. phoenix/db/insertion/types.py +41 -18
  28. phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
  29. phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
  30. phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
  31. phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
  32. phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
  33. phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
  34. phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
  35. phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
  36. phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
  37. phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
  38. phoenix/db/models.py +364 -56
  39. phoenix/db/pg_config.py +10 -0
  40. phoenix/db/types/trace_retention.py +7 -6
  41. phoenix/experiments/functions.py +69 -19
  42. phoenix/inferences/inferences.py +1 -2
  43. phoenix/server/api/auth.py +9 -0
  44. phoenix/server/api/auth_messages.py +46 -0
  45. phoenix/server/api/context.py +60 -0
  46. phoenix/server/api/dataloaders/__init__.py +36 -0
  47. phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
  48. phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
  49. phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
  50. phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
  51. phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
  52. phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
  53. phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
  54. phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
  55. phoenix/server/api/dataloaders/dataset_labels.py +36 -0
  56. phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
  57. phoenix/server/api/dataloaders/document_evaluations.py +6 -9
  58. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
  59. phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
  60. phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
  61. phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
  62. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
  63. phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
  64. phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
  65. phoenix/server/api/dataloaders/record_counts.py +37 -10
  66. phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
  67. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
  68. phoenix/server/api/dataloaders/span_cost_summary_by_project.py +28 -14
  69. phoenix/server/api/dataloaders/span_costs.py +3 -9
  70. phoenix/server/api/dataloaders/table_fields.py +2 -2
  71. phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
  72. phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
  73. phoenix/server/api/exceptions.py +5 -1
  74. phoenix/server/api/helpers/playground_clients.py +263 -83
  75. phoenix/server/api/helpers/playground_spans.py +2 -1
  76. phoenix/server/api/helpers/playground_users.py +26 -0
  77. phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
  78. phoenix/server/api/helpers/prompts/models.py +61 -19
  79. phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
  80. phoenix/server/api/input_types/ChatCompletionInput.py +3 -0
  81. phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
  82. phoenix/server/api/input_types/DatasetFilter.py +5 -2
  83. phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
  84. phoenix/server/api/input_types/GenerativeModelInput.py +3 -0
  85. phoenix/server/api/input_types/ProjectSessionSort.py +158 -1
  86. phoenix/server/api/input_types/PromptVersionInput.py +47 -1
  87. phoenix/server/api/input_types/SpanSort.py +3 -2
  88. phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
  89. phoenix/server/api/input_types/UserRoleInput.py +1 -0
  90. phoenix/server/api/mutations/__init__.py +8 -0
  91. phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
  92. phoenix/server/api/mutations/api_key_mutations.py +15 -20
  93. phoenix/server/api/mutations/chat_mutations.py +106 -37
  94. phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
  95. phoenix/server/api/mutations/dataset_mutations.py +21 -16
  96. phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
  97. phoenix/server/api/mutations/experiment_mutations.py +2 -2
  98. phoenix/server/api/mutations/export_events_mutations.py +3 -3
  99. phoenix/server/api/mutations/model_mutations.py +11 -9
  100. phoenix/server/api/mutations/project_mutations.py +4 -4
  101. phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
  102. phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
  103. phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
  104. phoenix/server/api/mutations/prompt_mutations.py +65 -129
  105. phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
  106. phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
  107. phoenix/server/api/mutations/trace_annotations_mutations.py +13 -8
  108. phoenix/server/api/mutations/trace_mutations.py +3 -3
  109. phoenix/server/api/mutations/user_mutations.py +55 -26
  110. phoenix/server/api/queries.py +501 -617
  111. phoenix/server/api/routers/__init__.py +2 -2
  112. phoenix/server/api/routers/auth.py +141 -87
  113. phoenix/server/api/routers/ldap.py +229 -0
  114. phoenix/server/api/routers/oauth2.py +349 -101
  115. phoenix/server/api/routers/v1/__init__.py +22 -4
  116. phoenix/server/api/routers/v1/annotation_configs.py +19 -30
  117. phoenix/server/api/routers/v1/annotations.py +455 -13
  118. phoenix/server/api/routers/v1/datasets.py +355 -68
  119. phoenix/server/api/routers/v1/documents.py +142 -0
  120. phoenix/server/api/routers/v1/evaluations.py +20 -28
  121. phoenix/server/api/routers/v1/experiment_evaluations.py +16 -6
  122. phoenix/server/api/routers/v1/experiment_runs.py +335 -59
  123. phoenix/server/api/routers/v1/experiments.py +475 -47
  124. phoenix/server/api/routers/v1/projects.py +16 -50
  125. phoenix/server/api/routers/v1/prompts.py +50 -39
  126. phoenix/server/api/routers/v1/sessions.py +108 -0
  127. phoenix/server/api/routers/v1/spans.py +156 -96
  128. phoenix/server/api/routers/v1/traces.py +51 -77
  129. phoenix/server/api/routers/v1/users.py +64 -24
  130. phoenix/server/api/routers/v1/utils.py +3 -7
  131. phoenix/server/api/subscriptions.py +257 -93
  132. phoenix/server/api/types/Annotation.py +90 -23
  133. phoenix/server/api/types/ApiKey.py +13 -17
  134. phoenix/server/api/types/AuthMethod.py +1 -0
  135. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
  136. phoenix/server/api/types/Dataset.py +199 -72
  137. phoenix/server/api/types/DatasetExample.py +88 -18
  138. phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
  139. phoenix/server/api/types/DatasetLabel.py +57 -0
  140. phoenix/server/api/types/DatasetSplit.py +98 -0
  141. phoenix/server/api/types/DatasetVersion.py +49 -4
  142. phoenix/server/api/types/DocumentAnnotation.py +212 -0
  143. phoenix/server/api/types/Experiment.py +215 -68
  144. phoenix/server/api/types/ExperimentComparison.py +3 -9
  145. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
  146. phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
  147. phoenix/server/api/types/ExperimentRun.py +120 -70
  148. phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
  149. phoenix/server/api/types/GenerativeModel.py +95 -42
  150. phoenix/server/api/types/GenerativeProvider.py +1 -1
  151. phoenix/server/api/types/ModelInterface.py +7 -2
  152. phoenix/server/api/types/PlaygroundModel.py +12 -2
  153. phoenix/server/api/types/Project.py +218 -185
  154. phoenix/server/api/types/ProjectSession.py +146 -29
  155. phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
  156. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
  157. phoenix/server/api/types/Prompt.py +119 -39
  158. phoenix/server/api/types/PromptLabel.py +42 -25
  159. phoenix/server/api/types/PromptVersion.py +11 -8
  160. phoenix/server/api/types/PromptVersionTag.py +65 -25
  161. phoenix/server/api/types/Span.py +130 -123
  162. phoenix/server/api/types/SpanAnnotation.py +189 -42
  163. phoenix/server/api/types/SystemApiKey.py +65 -1
  164. phoenix/server/api/types/Trace.py +184 -53
  165. phoenix/server/api/types/TraceAnnotation.py +149 -50
  166. phoenix/server/api/types/User.py +128 -33
  167. phoenix/server/api/types/UserApiKey.py +73 -26
  168. phoenix/server/api/types/node.py +10 -0
  169. phoenix/server/api/types/pagination.py +11 -2
  170. phoenix/server/app.py +154 -36
  171. phoenix/server/authorization.py +5 -4
  172. phoenix/server/bearer_auth.py +13 -5
  173. phoenix/server/cost_tracking/cost_model_lookup.py +42 -14
  174. phoenix/server/cost_tracking/model_cost_manifest.json +1085 -194
  175. phoenix/server/daemons/generative_model_store.py +61 -9
  176. phoenix/server/daemons/span_cost_calculator.py +10 -8
  177. phoenix/server/dml_event.py +13 -0
  178. phoenix/server/email/sender.py +29 -2
  179. phoenix/server/grpc_server.py +9 -9
  180. phoenix/server/jwt_store.py +8 -6
  181. phoenix/server/ldap.py +1449 -0
  182. phoenix/server/main.py +9 -3
  183. phoenix/server/oauth2.py +330 -12
  184. phoenix/server/prometheus.py +43 -6
  185. phoenix/server/rate_limiters.py +4 -9
  186. phoenix/server/retention.py +33 -20
  187. phoenix/server/session_filters.py +49 -0
  188. phoenix/server/static/.vite/manifest.json +51 -53
  189. phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
  190. phoenix/server/static/assets/{index-BPCwGQr8.js → index-CTQoemZv.js} +42 -35
  191. phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
  192. phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
  193. phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
  194. phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
  195. phoenix/server/static/assets/{vendor-recharts-Bw30oz1A.js → vendor-recharts-V9cwpXsm.js} +7 -7
  196. phoenix/server/static/assets/{vendor-shiki-DZajAPeq.js → vendor-shiki-Do--csgv.js} +1 -1
  197. phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
  198. phoenix/server/templates/index.html +7 -1
  199. phoenix/server/thread_server.py +1 -2
  200. phoenix/server/utils.py +74 -0
  201. phoenix/session/client.py +55 -1
  202. phoenix/session/data_extractor.py +5 -0
  203. phoenix/session/evaluation.py +8 -4
  204. phoenix/session/session.py +44 -8
  205. phoenix/settings.py +2 -0
  206. phoenix/trace/attributes.py +80 -13
  207. phoenix/trace/dsl/query.py +2 -0
  208. phoenix/trace/projects.py +5 -0
  209. phoenix/utilities/template_formatters.py +1 -1
  210. phoenix/version.py +1 -1
  211. phoenix/server/api/types/Evaluation.py +0 -39
  212. phoenix/server/static/assets/components-D0DWAf0l.js +0 -5650
  213. phoenix/server/static/assets/pages-Creyamao.js +0 -8612
  214. phoenix/server/static/assets/vendor-CU36oj8y.js +0 -905
  215. phoenix/server/static/assets/vendor-CqDb5u4o.css +0 -1
  216. phoenix/server/static/assets/vendor-arizeai-Ctgw0e1G.js +0 -168
  217. phoenix/server/static/assets/vendor-codemirror-Cojjzqb9.js +0 -25
  218. phoenix/server/static/assets/vendor-three-BLWp5bic.js +0 -2998
  219. phoenix/utilities/deprecation.py +0 -31
  220. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
  221. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,6 +1,6 @@
1
1
  from typing import Optional
2
2
 
3
- from sqlalchemy import case, func, select
3
+ from sqlalchemy import func, select
4
4
  from strawberry.dataloader import DataLoader
5
5
  from typing_extensions import TypeAlias
6
6
 
@@ -23,36 +23,29 @@ class ExperimentErrorRatesDataLoader(DataLoader[Key, Result]):
23
23
 
24
24
  async def _load_fn(self, keys: list[Key]) -> list[Result]:
25
25
  experiment_ids = keys
26
- resolved_experiment_ids = (
27
- select(models.Experiment.id)
28
- .where(models.Experiment.id.in_(set(experiment_ids)))
29
- .subquery()
30
- )
31
- query = (
26
+ average_repetition_error_rates_subquery = (
32
27
  select(
33
- resolved_experiment_ids.c.id,
34
- case(
35
- (
36
- func.count(models.ExperimentRun.id) != 0,
37
- func.count(models.ExperimentRun.error)
38
- / func.count(models.ExperimentRun.id),
39
- ),
40
- else_=None,
41
- ),
28
+ models.ExperimentRun.experiment_id.label("experiment_id"),
29
+ (
30
+ func.count(models.ExperimentRun.error) / func.count(models.ExperimentRun.id)
31
+ ).label("average_repetition_error_rate"),
42
32
  )
43
- .outerjoin_from(
44
- from_=resolved_experiment_ids,
45
- target=models.ExperimentRun,
46
- onclause=resolved_experiment_ids.c.id == models.ExperimentRun.experiment_id,
47
- )
48
- .group_by(resolved_experiment_ids.c.id)
33
+ .where(models.ExperimentRun.experiment_id.in_(experiment_ids))
34
+ .group_by(models.ExperimentRun.dataset_example_id, models.ExperimentRun.experiment_id)
35
+ .subquery()
36
+ .alias("average_repetition_error_rates")
49
37
  )
38
+ average_run_error_rates_query = select(
39
+ average_repetition_error_rates_subquery.c.experiment_id,
40
+ func.avg(average_repetition_error_rates_subquery.c.average_repetition_error_rate).label(
41
+ "average_run_error_rates"
42
+ ),
43
+ ).group_by(average_repetition_error_rates_subquery.c.experiment_id)
50
44
  async with self._db() as session:
51
- error_rates = {
45
+ average_run_error_rates = {
52
46
  experiment_id: error_rate
53
- async for experiment_id, error_rate in await session.stream(query)
47
+ async for experiment_id, error_rate in await session.stream(
48
+ average_run_error_rates_query
49
+ )
54
50
  }
55
- return [
56
- error_rates.get(experiment_id, ValueError(f"Unknown experiment ID: {experiment_id}"))
57
- for experiment_id in keys
58
- ]
51
+ return [average_run_error_rates.get(experiment_id) for experiment_id in experiment_ids]
@@ -0,0 +1,77 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ from sqlalchemy import func, select, tuple_
5
+ from strawberry.dataloader import DataLoader
6
+ from typing_extensions import TypeAlias
7
+
8
+ from phoenix.db import models
9
+ from phoenix.server.types import DbSessionFactory
10
+
11
+ ExperimentID: TypeAlias = int
12
+ DatasetExampleID: TypeAlias = int
13
+ AnnotationName: TypeAlias = str
14
+ MeanAnnotationScore: TypeAlias = float
15
+
16
+
17
+ @dataclass
18
+ class AnnotationSummary:
19
+ annotation_name: AnnotationName
20
+ mean_score: Optional[MeanAnnotationScore]
21
+
22
+
23
+ Key: TypeAlias = tuple[ExperimentID, DatasetExampleID]
24
+ Result: TypeAlias = list[AnnotationSummary]
25
+
26
+
27
+ class ExperimentRepeatedRunGroupAnnotationSummariesDataLoader(DataLoader[Key, Result]):
28
+ def __init__(
29
+ self,
30
+ db: DbSessionFactory,
31
+ ) -> None:
32
+ super().__init__(load_fn=self._load_fn)
33
+ self._db = db
34
+
35
+ async def _load_fn(self, keys: list[Key]) -> list[Result]:
36
+ annotation_summaries_query = (
37
+ select(
38
+ models.ExperimentRun.experiment_id.label("experiment_id"),
39
+ models.ExperimentRun.dataset_example_id.label("dataset_example_id"),
40
+ models.ExperimentRunAnnotation.name.label("annotation_name"),
41
+ func.avg(models.ExperimentRunAnnotation.score).label("mean_score"),
42
+ )
43
+ .select_from(models.ExperimentRunAnnotation)
44
+ .join(
45
+ models.ExperimentRun,
46
+ models.ExperimentRunAnnotation.experiment_run_id == models.ExperimentRun.id,
47
+ )
48
+ .where(
49
+ tuple_(
50
+ models.ExperimentRun.experiment_id, models.ExperimentRun.dataset_example_id
51
+ ).in_(set(keys))
52
+ )
53
+ .group_by(
54
+ models.ExperimentRun.experiment_id,
55
+ models.ExperimentRun.dataset_example_id,
56
+ models.ExperimentRunAnnotation.name,
57
+ )
58
+ )
59
+ async with self._db() as session:
60
+ annotation_summaries = (await session.execute(annotation_summaries_query)).all()
61
+ annotation_summaries_by_key: dict[Key, list[AnnotationSummary]] = {}
62
+ for summary in annotation_summaries:
63
+ key = (summary.experiment_id, summary.dataset_example_id)
64
+ gql_summary = AnnotationSummary(
65
+ annotation_name=summary.annotation_name,
66
+ mean_score=summary.mean_score,
67
+ )
68
+ if key not in annotation_summaries_by_key:
69
+ annotation_summaries_by_key[key] = []
70
+ annotation_summaries_by_key[key].append(gql_summary)
71
+ return [
72
+ sorted(
73
+ annotation_summaries_by_key.get(key, []),
74
+ key=lambda summary: summary.annotation_name,
75
+ )
76
+ for key in keys
77
+ ]
@@ -0,0 +1,57 @@
1
+ from dataclasses import dataclass
2
+
3
+ from sqlalchemy import select, tuple_
4
+ from strawberry.dataloader import DataLoader
5
+ from typing_extensions import TypeAlias
6
+
7
+ from phoenix.db import models
8
+ from phoenix.server.types import DbSessionFactory
9
+
10
+ ExperimentID: TypeAlias = int
11
+ DatasetExampleID: TypeAlias = int
12
+ Key: TypeAlias = tuple[ExperimentID, DatasetExampleID]
13
+
14
+
15
+ @dataclass
16
+ class ExperimentRepeatedRunGroup:
17
+ experiment_rowid: int
18
+ dataset_example_rowid: int
19
+ runs: list[models.ExperimentRun]
20
+
21
+
22
+ Result: TypeAlias = ExperimentRepeatedRunGroup
23
+
24
+
25
+ class ExperimentRepeatedRunGroupsDataLoader(DataLoader[Key, Result]):
26
+ def __init__(self, db: DbSessionFactory) -> None:
27
+ super().__init__(load_fn=self._load_fn)
28
+ self._db = db
29
+
30
+ async def _load_fn(self, keys: list[Key]) -> list[Result]:
31
+ repeated_run_groups_query = (
32
+ select(models.ExperimentRun)
33
+ .where(
34
+ tuple_(
35
+ models.ExperimentRun.experiment_id,
36
+ models.ExperimentRun.dataset_example_id,
37
+ ).in_(set(keys))
38
+ )
39
+ .order_by(models.ExperimentRun.repetition_number)
40
+ )
41
+
42
+ async with self._db() as session:
43
+ runs_by_key: dict[Key, list[models.ExperimentRun]] = {}
44
+ for run in (await session.scalars(repeated_run_groups_query)).all():
45
+ key = (run.experiment_id, run.dataset_example_id)
46
+ if key not in runs_by_key:
47
+ runs_by_key[key] = []
48
+ runs_by_key[key].append(run)
49
+
50
+ return [
51
+ ExperimentRepeatedRunGroup(
52
+ experiment_rowid=experiment_id,
53
+ dataset_example_rowid=dataset_example_id,
54
+ runs=runs_by_key.get((experiment_id, dataset_example_id), []),
55
+ )
56
+ for (experiment_id, dataset_example_id) in keys
57
+ ]
@@ -0,0 +1,44 @@
1
+ from collections import defaultdict
2
+
3
+ from sqlalchemy import select, tuple_
4
+ from strawberry.dataloader import DataLoader
5
+ from typing_extensions import TypeAlias
6
+
7
+ from phoenix.db import models
8
+ from phoenix.server.types import DbSessionFactory
9
+
10
+ ExperimentId: TypeAlias = int
11
+ DatasetExampleId: TypeAlias = int
12
+ Key: TypeAlias = tuple[ExperimentId, DatasetExampleId]
13
+ Result: TypeAlias = list[models.ExperimentRun]
14
+
15
+
16
+ class ExperimentRunsByExperimentAndExampleDataLoader(DataLoader[Key, Result]):
17
+ def __init__(self, db: DbSessionFactory) -> None:
18
+ super().__init__(load_fn=self._load_fn)
19
+ self._db = db
20
+
21
+ async def _load_fn(self, keys: list[Key]) -> list[Result]:
22
+ runs_by_key: defaultdict[Key, Result] = defaultdict(list)
23
+
24
+ async with self._db() as session:
25
+ stmt = (
26
+ select(models.ExperimentRun)
27
+ .where(
28
+ tuple_(
29
+ models.ExperimentRun.experiment_id,
30
+ models.ExperimentRun.dataset_example_id,
31
+ ).in_(keys)
32
+ )
33
+ .order_by(
34
+ models.ExperimentRun.experiment_id,
35
+ models.ExperimentRun.dataset_example_id,
36
+ models.ExperimentRun.repetition_number,
37
+ )
38
+ )
39
+ result = await session.stream_scalars(stmt)
40
+ async for run in result:
41
+ key = (run.experiment_id, run.dataset_example_id)
42
+ runs_by_key[key].append(run)
43
+
44
+ return [runs_by_key[key] for key in keys]
@@ -25,6 +25,7 @@ from phoenix.db import models
25
25
  from phoenix.db.helpers import SupportedSQLDialect
26
26
  from phoenix.server.api.dataloaders.cache import TwoTierCache
27
27
  from phoenix.server.api.input_types.TimeRange import TimeRange
28
+ from phoenix.server.session_filters import get_filtered_session_rowids_subquery
28
29
  from phoenix.server.types import DbSessionFactory
29
30
  from phoenix.trace.dsl import SpanFilter
30
31
 
@@ -32,13 +33,16 @@ Kind: TypeAlias = Literal["span", "trace"]
32
33
  ProjectRowId: TypeAlias = int
33
34
  TimeInterval: TypeAlias = tuple[Optional[datetime], Optional[datetime]]
34
35
  FilterCondition: TypeAlias = Optional[str]
36
+ SessionFilterCondition: TypeAlias = Optional[str]
35
37
  Probability: TypeAlias = float
36
38
  QuantileValue: TypeAlias = float
37
39
 
38
- Segment: TypeAlias = tuple[Kind, TimeInterval, FilterCondition]
40
+ Segment: TypeAlias = tuple[Kind, TimeInterval, FilterCondition, SessionFilterCondition]
39
41
  Param: TypeAlias = tuple[ProjectRowId, Probability]
40
42
 
41
- Key: TypeAlias = tuple[Kind, ProjectRowId, Optional[TimeRange], FilterCondition, Probability]
43
+ Key: TypeAlias = tuple[
44
+ Kind, ProjectRowId, Optional[TimeRange], FilterCondition, SessionFilterCondition, Probability
45
+ ]
42
46
  Result: TypeAlias = Optional[QuantileValue]
43
47
  ResultPosition: TypeAlias = int
44
48
  DEFAULT_VALUE: Result = None
@@ -47,15 +51,18 @@ FloatCol: TypeAlias = SQLColumnExpression[Float[float]]
47
51
 
48
52
 
49
53
  def _cache_key_fn(key: Key) -> tuple[Segment, Param]:
50
- kind, project_rowid, time_range, filter_condition, probability = key
54
+ kind, project_rowid, time_range, filter_condition, session_filter_condition, probability = key
51
55
  interval = (
52
56
  (time_range.start, time_range.end) if isinstance(time_range, TimeRange) else (None, None)
53
57
  )
54
- return (kind, interval, filter_condition), (project_rowid, probability)
58
+ return (kind, interval, filter_condition, session_filter_condition), (
59
+ project_rowid,
60
+ probability,
61
+ )
55
62
 
56
63
 
57
64
  _Section: TypeAlias = ProjectRowId
58
- _SubKey: TypeAlias = tuple[TimeInterval, FilterCondition, Kind, Probability]
65
+ _SubKey: TypeAlias = tuple[TimeInterval, FilterCondition, SessionFilterCondition, Kind, Probability]
59
66
 
60
67
 
61
68
  class LatencyMsQuantileCache(
@@ -71,8 +78,17 @@ class LatencyMsQuantileCache(
71
78
  )
72
79
 
73
80
  def _cache_key(self, key: Key) -> tuple[_Section, _SubKey]:
74
- (kind, interval, filter_condition), (project_rowid, probability) = _cache_key_fn(key)
75
- return project_rowid, (interval, filter_condition, kind, probability)
81
+ (
82
+ (kind, interval, filter_condition, session_filter_condition),
83
+ (project_rowid, probability),
84
+ ) = _cache_key_fn(key)
85
+ return project_rowid, (
86
+ interval,
87
+ filter_condition,
88
+ session_filter_condition,
89
+ kind,
90
+ probability,
91
+ )
76
92
 
77
93
 
78
94
  class LatencyMsQuantileDataLoader(DataLoader[Key, Result]):
@@ -113,11 +129,18 @@ async def _get_results(
113
129
  segment: Segment,
114
130
  params: Mapping[Param, list[ResultPosition]],
115
131
  ) -> AsyncIterator[tuple[ResultPosition, QuantileValue]]:
116
- kind, (start_time, end_time), filter_condition = segment
132
+ kind, (start_time, end_time), filter_condition, session_filter_condition = segment
117
133
  stmt = select(models.Trace.project_rowid)
118
134
  if kind == "trace":
119
135
  latency_column = cast(FloatCol, models.Trace.latency_ms)
120
136
  time_column = models.Trace.start_time
137
+ if filter_condition:
138
+ sf = SpanFilter(filter_condition)
139
+ stmt = stmt.where(
140
+ models.Trace.id.in_(
141
+ sf(select(models.Span.trace_rowid).distinct()).scalar_subquery()
142
+ )
143
+ )
121
144
  elif kind == "span":
122
145
  latency_column = cast(FloatCol, models.Span.latency_ms)
123
146
  time_column = models.Span.start_time
@@ -127,6 +150,15 @@ async def _get_results(
127
150
  stmt = sf(stmt)
128
151
  else:
129
152
  assert_never(kind)
153
+ if session_filter_condition:
154
+ project_rowids = [project_rowid for project_rowid, _ in params]
155
+ filtered_session_rowids = get_filtered_session_rowids_subquery(
156
+ session_filter_condition=session_filter_condition,
157
+ project_rowids=project_rowids,
158
+ start_time=start_time,
159
+ end_time=end_time,
160
+ )
161
+ stmt = stmt.where(models.Trace.project_session_rowid.in_(filtered_session_rowids))
130
162
  if start_time:
131
163
  stmt = stmt.where(start_time <= time_column)
132
164
  if end_time:
@@ -3,13 +3,14 @@ from datetime import datetime
3
3
  from typing import Any, Literal, Optional
4
4
 
5
5
  from cachetools import LFUCache, TTLCache
6
- from sqlalchemy import Select, func, select
6
+ from sqlalchemy import Select, distinct, func, select
7
7
  from strawberry.dataloader import AbstractCache, DataLoader
8
8
  from typing_extensions import TypeAlias, assert_never
9
9
 
10
10
  from phoenix.db import models
11
11
  from phoenix.server.api.dataloaders.cache import TwoTierCache
12
12
  from phoenix.server.api.input_types.TimeRange import TimeRange
13
+ from phoenix.server.session_filters import get_filtered_session_rowids_subquery
13
14
  from phoenix.server.types import DbSessionFactory
14
15
  from phoenix.trace.dsl import SpanFilter
15
16
 
@@ -17,27 +18,35 @@ Kind: TypeAlias = Literal["span", "trace"]
17
18
  ProjectRowId: TypeAlias = int
18
19
  TimeInterval: TypeAlias = tuple[Optional[datetime], Optional[datetime]]
19
20
  FilterCondition: TypeAlias = Optional[str]
21
+ SessionFilterCondition: TypeAlias = Optional[str]
20
22
  SpanCount: TypeAlias = int
21
23
 
22
- Segment: TypeAlias = tuple[Kind, TimeInterval, FilterCondition]
24
+ Segment: TypeAlias = tuple[Kind, TimeInterval, FilterCondition, SessionFilterCondition]
23
25
  Param: TypeAlias = ProjectRowId
24
26
 
25
- Key: TypeAlias = tuple[Kind, ProjectRowId, Optional[TimeRange], FilterCondition]
27
+ Key: TypeAlias = tuple[
28
+ Kind, ProjectRowId, Optional[TimeRange], FilterCondition, SessionFilterCondition
29
+ ]
26
30
  Result: TypeAlias = SpanCount
27
31
  ResultPosition: TypeAlias = int
28
32
  DEFAULT_VALUE: Result = 0
29
33
 
30
34
 
31
35
  def _cache_key_fn(key: Key) -> tuple[Segment, Param]:
32
- kind, project_rowid, time_range, filter_condition = key
36
+ kind, project_rowid, time_range, filter_condition, session_filter_condition = key
33
37
  interval = (
34
38
  (time_range.start, time_range.end) if isinstance(time_range, TimeRange) else (None, None)
35
39
  )
36
- return (kind, interval, filter_condition), project_rowid
40
+ return (
41
+ kind,
42
+ interval,
43
+ filter_condition,
44
+ session_filter_condition,
45
+ ), project_rowid
37
46
 
38
47
 
39
48
  _Section: TypeAlias = ProjectRowId
40
- _SubKey: TypeAlias = tuple[TimeInterval, FilterCondition, Kind]
49
+ _SubKey: TypeAlias = tuple[TimeInterval, FilterCondition, SessionFilterCondition, Kind]
41
50
 
42
51
 
43
52
  class RecordCountCache(
@@ -53,8 +62,10 @@ class RecordCountCache(
53
62
  )
54
63
 
55
64
  def _cache_key(self, key: Key) -> tuple[_Section, _SubKey]:
56
- (kind, interval, filter_condition), project_rowid = _cache_key_fn(key)
57
- return project_rowid, (interval, filter_condition, kind)
65
+ (kind, interval, filter_condition, session_filter_condition), project_rowid = _cache_key_fn(
66
+ key
67
+ )
68
+ return project_rowid, (interval, filter_condition, session_filter_condition, kind)
58
69
 
59
70
 
60
71
  class RecordCountDataLoader(DataLoader[Key, Result]):
@@ -93,7 +104,7 @@ def _get_stmt(
93
104
  segment: Segment,
94
105
  *project_rowids: Param,
95
106
  ) -> Select[Any]:
96
- kind, (start_time, end_time), filter_condition = segment
107
+ kind, (start_time, end_time), filter_condition, session_filter_condition = segment
97
108
  pid = models.Trace.project_rowid
98
109
  stmt = select(pid)
99
110
  if kind == "span":
@@ -102,12 +113,28 @@ def _get_stmt(
102
113
  if filter_condition:
103
114
  sf = SpanFilter(filter_condition)
104
115
  stmt = sf(stmt)
116
+ stmt = stmt.add_columns(func.count().label("count"))
105
117
  elif kind == "trace":
106
118
  time_column = models.Trace.start_time
119
+ if filter_condition:
120
+ stmt = stmt.join(models.Span, models.Trace.id == models.Span.trace_rowid)
121
+ stmt = stmt.add_columns(func.count(distinct(models.Trace.id)).label("count"))
122
+ sf = SpanFilter(filter_condition)
123
+ stmt = sf(stmt)
124
+ else:
125
+ stmt = stmt.add_columns(func.count().label("count"))
107
126
  else:
108
127
  assert_never(kind)
109
- stmt = stmt.add_columns(func.count().label("count"))
110
128
  stmt = stmt.where(pid.in_(project_rowids))
129
+
130
+ if session_filter_condition:
131
+ filtered_session_rowids = get_filtered_session_rowids_subquery(
132
+ session_filter_condition=session_filter_condition,
133
+ project_rowids=project_rowids,
134
+ start_time=start_time,
135
+ end_time=end_time,
136
+ )
137
+ stmt = stmt.where(models.Trace.project_session_rowid.in_(filtered_session_rowids))
111
138
  stmt = stmt.group_by(pid)
112
139
  if start_time:
113
140
  stmt = stmt.where(start_time <= time_column)
@@ -0,0 +1,29 @@
1
+ from collections import defaultdict
2
+
3
+ from sqlalchemy import select
4
+ from strawberry.dataloader import DataLoader
5
+ from typing_extensions import TypeAlias
6
+
7
+ from phoenix.db.models import ProjectSessionAnnotation
8
+ from phoenix.server.types import DbSessionFactory
9
+
10
+ ProjectSessionId: TypeAlias = int
11
+ Key: TypeAlias = ProjectSessionId
12
+ Result: TypeAlias = list[ProjectSessionAnnotation]
13
+
14
+
15
+ class SessionAnnotationsBySessionDataLoader(DataLoader[Key, Result]):
16
+ def __init__(self, db: DbSessionFactory) -> None:
17
+ super().__init__(load_fn=self._load_fn)
18
+ self._db = db
19
+
20
+ async def _load_fn(self, keys: list[Key]) -> list[Result]:
21
+ annotations_by_id: defaultdict[Key, Result] = defaultdict(list)
22
+ async with self._db() as session:
23
+ async for annotation in await session.stream_scalars(
24
+ select(ProjectSessionAnnotation).where(
25
+ ProjectSessionAnnotation.project_session_id.in_(keys)
26
+ )
27
+ ):
28
+ annotations_by_id[annotation.project_session_id].append(annotation)
29
+ return [annotations_by_id[key] for key in keys]
@@ -0,0 +1,64 @@
1
+ from collections import defaultdict
2
+
3
+ from sqlalchemy import func, select, tuple_
4
+ from strawberry.dataloader import DataLoader
5
+ from typing_extensions import TypeAlias
6
+
7
+ from phoenix.db import models
8
+ from phoenix.server.api.dataloaders.types import CostBreakdown, SpanCostSummary
9
+ from phoenix.server.types import DbSessionFactory
10
+
11
+ ExperimentId: TypeAlias = int
12
+ DatasetExampleId: TypeAlias = int
13
+ Key: TypeAlias = tuple[ExperimentId, DatasetExampleId]
14
+ Result: TypeAlias = SpanCostSummary
15
+
16
+
17
class SpanCostSummaryByExperimentRepeatedRunGroupDataLoader(DataLoader[Key, Result]):
    """Batch-loads aggregated span cost summaries keyed by
    ``(experiment_id, dataset_example_id)``.

    Keys with no matching runs/costs resolve to an empty ``SpanCostSummary``.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        run = models.ExperimentRun
        cost = models.SpanCost
        # One grouped query aggregates costs/tokens per (experiment, example) pair.
        stmt = (
            select(
                run.experiment_id,
                run.dataset_example_id,
                func.sum(cost.prompt_cost).label("prompt_cost"),
                func.sum(cost.completion_cost).label("completion_cost"),
                func.sum(cost.total_cost).label("total_cost"),
                func.sum(cost.prompt_tokens).label("prompt_tokens"),
                func.sum(cost.completion_tokens).label("completion_tokens"),
                func.sum(cost.total_tokens).label("total_tokens"),
            )
            .select_from(run)
            # Experiment runs reference traces by trace id string, not rowid,
            # so join through models.Trace to reach the per-trace costs.
            .join(models.Trace, run.trace_id == models.Trace.trace_id)
            .join(cost, cost.trace_rowid == models.Trace.id)
            .where(tuple_(run.experiment_id, run.dataset_example_id).in_(set(keys)))
            .group_by(run.experiment_id, run.dataset_example_id)
        )

        summaries: dict[Key, Result] = {}
        async with self._db() as session:
            rows = await session.stream(stmt)
            async for row in rows:
                (
                    experiment_id,
                    example_id,
                    prompt_cost,
                    completion_cost,
                    total_cost,
                    prompt_tokens,
                    completion_tokens,
                    total_tokens,
                ) = row
                summaries[(experiment_id, example_id)] = SpanCostSummary(
                    prompt=CostBreakdown(tokens=prompt_tokens, cost=prompt_cost),
                    completion=CostBreakdown(tokens=completion_tokens, cost=completion_cost),
                    total=CostBreakdown(tokens=total_tokens, cost=total_cost),
                )
        # Keys absent from the query result get a default (empty) summary.
        return [summaries.get(key, SpanCostSummary()) for key in keys]
@@ -12,32 +12,38 @@ from phoenix.db import models
12
12
  from phoenix.server.api.dataloaders.cache import TwoTierCache
13
13
  from phoenix.server.api.dataloaders.types import CostBreakdown, SpanCostSummary
14
14
  from phoenix.server.api.input_types.TimeRange import TimeRange
15
+ from phoenix.server.session_filters import get_filtered_session_rowids_subquery
15
16
  from phoenix.server.types import DbSessionFactory
16
17
  from phoenix.trace.dsl import SpanFilter
17
18
 
18
19
  ProjectRowId: TypeAlias = int
19
20
  TimeInterval: TypeAlias = tuple[Optional[datetime], Optional[datetime]]
20
21
  FilterCondition: TypeAlias = Optional[str]
22
+ SessionFilterCondition: TypeAlias = Optional[str]
21
23
 
22
- Segment: TypeAlias = tuple[TimeInterval, FilterCondition]
24
+ Segment: TypeAlias = tuple[
25
+ TimeInterval,
26
+ FilterCondition,
27
+ SessionFilterCondition,
28
+ ]
23
29
  Param: TypeAlias = ProjectRowId
24
30
 
25
- Key: TypeAlias = tuple[ProjectRowId, Optional[TimeRange], FilterCondition]
31
+ Key: TypeAlias = tuple[ProjectRowId, Optional[TimeRange], FilterCondition, SessionFilterCondition]
26
32
  Result: TypeAlias = SpanCostSummary
27
33
  ResultPosition: TypeAlias = int
28
34
  DEFAULT_VALUE: Result = SpanCostSummary()
29
35
 
30
36
 
31
37
  def _cache_key_fn(key: Key) -> tuple[Segment, Param]:
32
- project_rowid, time_range, filter_condition = key
38
+ project_rowid, time_range, filter_condition, session_filter_condition = key
33
39
  interval = (
34
40
  (time_range.start, time_range.end) if isinstance(time_range, TimeRange) else (None, None)
35
41
  )
36
- return (interval, filter_condition), project_rowid
42
+ return (interval, filter_condition, session_filter_condition), project_rowid
37
43
 
38
44
 
39
45
  _Section: TypeAlias = ProjectRowId
40
- _SubKey: TypeAlias = tuple[TimeInterval, FilterCondition]
46
+ _SubKey: TypeAlias = tuple[TimeInterval, FilterCondition, SessionFilterCondition]
41
47
 
42
48
 
43
49
  class SpanCostSummaryCache(
@@ -53,8 +59,8 @@ class SpanCostSummaryCache(
53
59
  )
54
60
 
55
61
  def _cache_key(self, key: Key) -> tuple[_Section, _SubKey]:
56
- (interval, filter_condition), project_rowid = _cache_key_fn(key)
57
- return project_rowid, (interval, filter_condition)
62
+ (interval, filter_condition, session_filter_condition), project_rowid = _cache_key_fn(key)
63
+ return project_rowid, (interval, filter_condition, session_filter_condition)
58
64
 
59
65
 
60
66
  class SpanCostSummaryByProjectDataLoader(DataLoader[Key, Result]):
@@ -106,12 +112,12 @@ def _get_stmt(
106
112
  segment: Segment,
107
113
  *params: Param,
108
114
  ) -> Select[Any]:
109
- (start_time, end_time), filter_condition = segment
110
- pid = models.Trace.project_rowid
115
+ project_rowids = params
116
+ (start_time, end_time), filter_condition, session_filter_condition = segment
111
117
 
112
118
  stmt: Select[Any] = (
113
119
  select(
114
- pid,
120
+ models.Trace.project_rowid,
115
121
  coalesce(func.sum(models.SpanCost.prompt_cost), 0).label("prompt_cost"),
116
122
  coalesce(func.sum(models.SpanCost.completion_cost), 0).label("completion_cost"),
117
123
  coalesce(func.sum(models.SpanCost.total_cost), 0).label("total_cost"),
@@ -119,8 +125,10 @@ def _get_stmt(
119
125
  coalesce(func.sum(models.SpanCost.completion_tokens), 0).label("completion_tokens"),
120
126
  coalesce(func.sum(models.SpanCost.total_tokens), 0).label("total_tokens"),
121
127
  )
122
- .join_from(models.SpanCost, models.Trace)
123
- .group_by(pid)
128
+ .select_from(models.Trace)
129
+ .join(models.SpanCost, models.Trace.id == models.SpanCost.trace_rowid)
130
+ .where(models.Trace.project_rowid.in_(project_rowids))
131
+ .group_by(models.Trace.project_rowid)
124
132
  )
125
133
 
126
134
  if start_time:
@@ -132,7 +140,13 @@ def _get_stmt(
132
140
  sf = SpanFilter(filter_condition)
133
141
  stmt = sf(stmt.join_from(models.SpanCost, models.Span))
134
142
 
135
- project_ids = [rowid for rowid in params]
136
- stmt = stmt.where(pid.in_(project_ids))
143
+ if session_filter_condition:
144
+ filtered_session_rowids = get_filtered_session_rowids_subquery(
145
+ session_filter_condition=session_filter_condition,
146
+ project_rowids=project_rowids,
147
+ start_time=start_time,
148
+ end_time=end_time,
149
+ )
150
+ stmt = stmt.where(models.Trace.project_session_rowid.in_(filtered_session_rowids))
137
151
 
138
152
  return stmt
@@ -1,7 +1,6 @@
1
1
  from typing import Optional
2
2
 
3
3
  from sqlalchemy import select
4
- from sqlalchemy.orm import joinedload, load_only
5
4
  from strawberry.dataloader import DataLoader
6
5
  from typing_extensions import TypeAlias
7
6
 
@@ -22,14 +21,9 @@ class SpanCostsDataLoader(DataLoader[Key, Result]):
22
21
  span_ids = list(set(keys))
23
22
  async with self._db() as session:
24
23
  costs = {
25
- span.id: span.span_cost
26
- async for span in await session.stream_scalars(
27
- select(models.Span)
28
- .where(models.Span.id.in_(span_ids))
29
- .options(
30
- load_only(models.Span.id),
31
- joinedload(models.Span.span_cost),
32
- )
24
+ span_cost.span_rowid: span_cost
25
+ async for span_cost in await session.stream_scalars(
26
+ select(models.SpanCost).where(models.SpanCost.span_rowid.in_(span_ids))
33
27
  )
34
28
  }
35
29
  return [costs.get(span_id) for span_id in keys]