arize-phoenix 11.23.1__py3-none-any.whl → 12.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (221) hide show
  1. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +61 -36
  2. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/RECORD +212 -162
  3. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
  4. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
  5. phoenix/__generated__/__init__.py +0 -0
  6. phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
  7. phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
  8. phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
  9. phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
  10. phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
  11. phoenix/__init__.py +2 -1
  12. phoenix/auth.py +27 -2
  13. phoenix/config.py +1594 -81
  14. phoenix/db/README.md +546 -28
  15. phoenix/db/bulk_inserter.py +119 -116
  16. phoenix/db/engines.py +140 -33
  17. phoenix/db/facilitator.py +22 -1
  18. phoenix/db/helpers.py +818 -65
  19. phoenix/db/iam_auth.py +64 -0
  20. phoenix/db/insertion/dataset.py +133 -1
  21. phoenix/db/insertion/document_annotation.py +9 -6
  22. phoenix/db/insertion/evaluation.py +2 -3
  23. phoenix/db/insertion/helpers.py +2 -2
  24. phoenix/db/insertion/session_annotation.py +176 -0
  25. phoenix/db/insertion/span_annotation.py +3 -4
  26. phoenix/db/insertion/trace_annotation.py +3 -4
  27. phoenix/db/insertion/types.py +41 -18
  28. phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
  29. phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
  30. phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
  31. phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
  32. phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
  33. phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
  34. phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
  35. phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
  36. phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
  37. phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
  38. phoenix/db/models.py +364 -56
  39. phoenix/db/pg_config.py +10 -0
  40. phoenix/db/types/trace_retention.py +7 -6
  41. phoenix/experiments/functions.py +69 -19
  42. phoenix/inferences/inferences.py +1 -2
  43. phoenix/server/api/auth.py +9 -0
  44. phoenix/server/api/auth_messages.py +46 -0
  45. phoenix/server/api/context.py +60 -0
  46. phoenix/server/api/dataloaders/__init__.py +36 -0
  47. phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
  48. phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
  49. phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
  50. phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
  51. phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
  52. phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
  53. phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
  54. phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
  55. phoenix/server/api/dataloaders/dataset_labels.py +36 -0
  56. phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
  57. phoenix/server/api/dataloaders/document_evaluations.py +6 -9
  58. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
  59. phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
  60. phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
  61. phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
  62. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
  63. phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
  64. phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
  65. phoenix/server/api/dataloaders/record_counts.py +37 -10
  66. phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
  67. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
  68. phoenix/server/api/dataloaders/span_cost_summary_by_project.py +28 -14
  69. phoenix/server/api/dataloaders/span_costs.py +3 -9
  70. phoenix/server/api/dataloaders/table_fields.py +2 -2
  71. phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
  72. phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
  73. phoenix/server/api/exceptions.py +5 -1
  74. phoenix/server/api/helpers/playground_clients.py +263 -83
  75. phoenix/server/api/helpers/playground_spans.py +2 -1
  76. phoenix/server/api/helpers/playground_users.py +26 -0
  77. phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
  78. phoenix/server/api/helpers/prompts/models.py +61 -19
  79. phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
  80. phoenix/server/api/input_types/ChatCompletionInput.py +3 -0
  81. phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
  82. phoenix/server/api/input_types/DatasetFilter.py +5 -2
  83. phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
  84. phoenix/server/api/input_types/GenerativeModelInput.py +3 -0
  85. phoenix/server/api/input_types/ProjectSessionSort.py +158 -1
  86. phoenix/server/api/input_types/PromptVersionInput.py +47 -1
  87. phoenix/server/api/input_types/SpanSort.py +3 -2
  88. phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
  89. phoenix/server/api/input_types/UserRoleInput.py +1 -0
  90. phoenix/server/api/mutations/__init__.py +8 -0
  91. phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
  92. phoenix/server/api/mutations/api_key_mutations.py +15 -20
  93. phoenix/server/api/mutations/chat_mutations.py +106 -37
  94. phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
  95. phoenix/server/api/mutations/dataset_mutations.py +21 -16
  96. phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
  97. phoenix/server/api/mutations/experiment_mutations.py +2 -2
  98. phoenix/server/api/mutations/export_events_mutations.py +3 -3
  99. phoenix/server/api/mutations/model_mutations.py +11 -9
  100. phoenix/server/api/mutations/project_mutations.py +4 -4
  101. phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
  102. phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
  103. phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
  104. phoenix/server/api/mutations/prompt_mutations.py +65 -129
  105. phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
  106. phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
  107. phoenix/server/api/mutations/trace_annotations_mutations.py +13 -8
  108. phoenix/server/api/mutations/trace_mutations.py +3 -3
  109. phoenix/server/api/mutations/user_mutations.py +55 -26
  110. phoenix/server/api/queries.py +501 -617
  111. phoenix/server/api/routers/__init__.py +2 -2
  112. phoenix/server/api/routers/auth.py +141 -87
  113. phoenix/server/api/routers/ldap.py +229 -0
  114. phoenix/server/api/routers/oauth2.py +349 -101
  115. phoenix/server/api/routers/v1/__init__.py +22 -4
  116. phoenix/server/api/routers/v1/annotation_configs.py +19 -30
  117. phoenix/server/api/routers/v1/annotations.py +455 -13
  118. phoenix/server/api/routers/v1/datasets.py +355 -68
  119. phoenix/server/api/routers/v1/documents.py +142 -0
  120. phoenix/server/api/routers/v1/evaluations.py +20 -28
  121. phoenix/server/api/routers/v1/experiment_evaluations.py +16 -6
  122. phoenix/server/api/routers/v1/experiment_runs.py +335 -59
  123. phoenix/server/api/routers/v1/experiments.py +475 -47
  124. phoenix/server/api/routers/v1/projects.py +16 -50
  125. phoenix/server/api/routers/v1/prompts.py +50 -39
  126. phoenix/server/api/routers/v1/sessions.py +108 -0
  127. phoenix/server/api/routers/v1/spans.py +156 -96
  128. phoenix/server/api/routers/v1/traces.py +51 -77
  129. phoenix/server/api/routers/v1/users.py +64 -24
  130. phoenix/server/api/routers/v1/utils.py +3 -7
  131. phoenix/server/api/subscriptions.py +257 -93
  132. phoenix/server/api/types/Annotation.py +90 -23
  133. phoenix/server/api/types/ApiKey.py +13 -17
  134. phoenix/server/api/types/AuthMethod.py +1 -0
  135. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
  136. phoenix/server/api/types/Dataset.py +199 -72
  137. phoenix/server/api/types/DatasetExample.py +88 -18
  138. phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
  139. phoenix/server/api/types/DatasetLabel.py +57 -0
  140. phoenix/server/api/types/DatasetSplit.py +98 -0
  141. phoenix/server/api/types/DatasetVersion.py +49 -4
  142. phoenix/server/api/types/DocumentAnnotation.py +212 -0
  143. phoenix/server/api/types/Experiment.py +215 -68
  144. phoenix/server/api/types/ExperimentComparison.py +3 -9
  145. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
  146. phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
  147. phoenix/server/api/types/ExperimentRun.py +120 -70
  148. phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
  149. phoenix/server/api/types/GenerativeModel.py +95 -42
  150. phoenix/server/api/types/GenerativeProvider.py +1 -1
  151. phoenix/server/api/types/ModelInterface.py +7 -2
  152. phoenix/server/api/types/PlaygroundModel.py +12 -2
  153. phoenix/server/api/types/Project.py +218 -185
  154. phoenix/server/api/types/ProjectSession.py +146 -29
  155. phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
  156. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
  157. phoenix/server/api/types/Prompt.py +119 -39
  158. phoenix/server/api/types/PromptLabel.py +42 -25
  159. phoenix/server/api/types/PromptVersion.py +11 -8
  160. phoenix/server/api/types/PromptVersionTag.py +65 -25
  161. phoenix/server/api/types/Span.py +130 -123
  162. phoenix/server/api/types/SpanAnnotation.py +189 -42
  163. phoenix/server/api/types/SystemApiKey.py +65 -1
  164. phoenix/server/api/types/Trace.py +184 -53
  165. phoenix/server/api/types/TraceAnnotation.py +149 -50
  166. phoenix/server/api/types/User.py +128 -33
  167. phoenix/server/api/types/UserApiKey.py +73 -26
  168. phoenix/server/api/types/node.py +10 -0
  169. phoenix/server/api/types/pagination.py +11 -2
  170. phoenix/server/app.py +154 -36
  171. phoenix/server/authorization.py +5 -4
  172. phoenix/server/bearer_auth.py +13 -5
  173. phoenix/server/cost_tracking/cost_model_lookup.py +42 -14
  174. phoenix/server/cost_tracking/model_cost_manifest.json +1085 -194
  175. phoenix/server/daemons/generative_model_store.py +61 -9
  176. phoenix/server/daemons/span_cost_calculator.py +10 -8
  177. phoenix/server/dml_event.py +13 -0
  178. phoenix/server/email/sender.py +29 -2
  179. phoenix/server/grpc_server.py +9 -9
  180. phoenix/server/jwt_store.py +8 -6
  181. phoenix/server/ldap.py +1449 -0
  182. phoenix/server/main.py +9 -3
  183. phoenix/server/oauth2.py +330 -12
  184. phoenix/server/prometheus.py +43 -6
  185. phoenix/server/rate_limiters.py +4 -9
  186. phoenix/server/retention.py +33 -20
  187. phoenix/server/session_filters.py +49 -0
  188. phoenix/server/static/.vite/manifest.json +51 -53
  189. phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
  190. phoenix/server/static/assets/{index-BPCwGQr8.js → index-CTQoemZv.js} +42 -35
  191. phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
  192. phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
  193. phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
  194. phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
  195. phoenix/server/static/assets/{vendor-recharts-Bw30oz1A.js → vendor-recharts-V9cwpXsm.js} +7 -7
  196. phoenix/server/static/assets/{vendor-shiki-DZajAPeq.js → vendor-shiki-Do--csgv.js} +1 -1
  197. phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
  198. phoenix/server/templates/index.html +7 -1
  199. phoenix/server/thread_server.py +1 -2
  200. phoenix/server/utils.py +74 -0
  201. phoenix/session/client.py +55 -1
  202. phoenix/session/data_extractor.py +5 -0
  203. phoenix/session/evaluation.py +8 -4
  204. phoenix/session/session.py +44 -8
  205. phoenix/settings.py +2 -0
  206. phoenix/trace/attributes.py +80 -13
  207. phoenix/trace/dsl/query.py +2 -0
  208. phoenix/trace/projects.py +5 -0
  209. phoenix/utilities/template_formatters.py +1 -1
  210. phoenix/version.py +1 -1
  211. phoenix/server/api/types/Evaluation.py +0 -39
  212. phoenix/server/static/assets/components-D0DWAf0l.js +0 -5650
  213. phoenix/server/static/assets/pages-Creyamao.js +0 -8612
  214. phoenix/server/static/assets/vendor-CU36oj8y.js +0 -905
  215. phoenix/server/static/assets/vendor-CqDb5u4o.css +0 -1
  216. phoenix/server/static/assets/vendor-arizeai-Ctgw0e1G.js +0 -168
  217. phoenix/server/static/assets/vendor-codemirror-Cojjzqb9.js +0 -25
  218. phoenix/server/static/assets/vendor-three-BLWp5bic.js +0 -2998
  219. phoenix/utilities/deprecation.py +0 -31
  220. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
  221. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
phoenix/db/iam_auth.py ADDED
@@ -0,0 +1,64 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+
5
logger = logging.getLogger(__name__)


def generate_aws_rds_token(
    host: str,
    port: int,
    user: str,
) -> str:
    """Create a temporary IAM auth token to use as an RDS/Aurora PostgreSQL password.

    ``boto3`` is imported lazily, so it is only required when IAM database
    authentication is actually used. The RDS client is created with
    ``boto3.client("rds")`` and therefore relies on boto3's own region and
    credential resolution. The token is short-lived (15 minutes).

    Args:
        host: Database endpoint hostname, e.g. ``'mydb.abc123.us-west-2.rds.amazonaws.com'``.
        port: Database port (5432 is the PostgreSQL default).
        user: Database user name; it must be enabled for IAM authentication.

    Returns:
        The generated token as a string, suitable for use as the password.

    Raises:
        ImportError: If ``boto3`` is not installed.
        Exception: Whatever boto3 raises while creating the client or generating
            the token (e.g. missing region/credentials); it is logged and
            re-raised unchanged.

    Example:
        >>> token = generate_aws_rds_token(
        ...     host='mydb.us-west-2.rds.amazonaws.com',
        ...     port=5432,
        ...     user='myuser'
        ... )
    """
    try:
        import boto3  # type: ignore
    except ImportError as exc:
        install_hint = (
            "boto3 is required for AWS RDS IAM authentication. "
            "Install it with: pip install 'arize-phoenix[aws]'"
        )
        raise ImportError(install_hint) from exc

    try:
        rds = boto3.client("rds")
        logger.debug(
            "Generating AWS RDS IAM auth token for user '%s' at %s:%s", user, host, port
        )
        auth_token = rds.generate_db_auth_token(  # pyright: ignore
            DBHostname=host,
            Port=port,
            DBUsername=user,
        )
        return str(auth_token)  # pyright: ignore
    except Exception as exc:
        # Log an actionable hint, then let the original exception propagate.
        logger.error(
            "Failed to generate AWS RDS IAM authentication token: %s. "
            "Ensure AWS credentials are configured and have 'rds-db:connect' permission.",
            exc,
        )
        raise
@@ -11,7 +11,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
11
11
  from typing_extensions import TypeAlias
12
12
 
13
13
  from phoenix.db import models
14
- from phoenix.db.insertion.helpers import DataManipulationEvent
14
+ from phoenix.db.helpers import SupportedSQLDialect
15
+ from phoenix.db.insertion.helpers import DataManipulationEvent, OnConflict, insert_on_conflict
15
16
 
16
17
  logger = logging.getLogger(__name__)
17
18
 
@@ -27,6 +28,7 @@ class ExampleContent:
27
28
  input: dict[str, Any] = field(default_factory=dict)
28
29
  output: dict[str, Any] = field(default_factory=dict)
29
30
  metadata: dict[str, Any] = field(default_factory=dict)
31
+ splits: frozenset[str] = field(default_factory=frozenset) # Set of split names
30
32
 
31
33
 
32
34
  Examples: TypeAlias = Iterable[ExampleContent]
@@ -44,6 +46,7 @@ async def insert_dataset(
44
46
  description: Optional[str] = None,
45
47
  metadata: Optional[Mapping[str, Any]] = None,
46
48
  created_at: Optional[datetime] = None,
49
+ user_id: Optional[int] = None,
47
50
  ) -> DatasetId:
48
51
  id_ = await session.scalar(
49
52
  insert(models.Dataset)
@@ -52,6 +55,7 @@ async def insert_dataset(
52
55
  description=description,
53
56
  metadata_=metadata,
54
57
  created_at=created_at,
58
+ user_id=user_id,
55
59
  )
56
60
  .returning(models.Dataset.id)
57
61
  )
@@ -64,6 +68,7 @@ async def insert_dataset_version(
64
68
  description: Optional[str] = None,
65
69
  metadata: Optional[Mapping[str, Any]] = None,
66
70
  created_at: Optional[datetime] = None,
71
+ user_id: Optional[int] = None,
67
72
  ) -> DatasetVersionId:
68
73
  id_ = await session.scalar(
69
74
  insert(models.DatasetVersion)
@@ -72,6 +77,7 @@ async def insert_dataset_version(
72
77
  description=description,
73
78
  metadata_=metadata,
74
79
  created_at=created_at,
80
+ user_id=user_id,
75
81
  )
76
82
  .returning(models.DatasetVersion.id)
77
83
  )
@@ -134,6 +140,92 @@ async def insert_dataset_example_revision(
134
140
  return cast(DatasetExampleRevisionId, id_)
135
141
 
136
142
 
143
async def bulk_create_dataset_splits(
    session: AsyncSession,
    split_names: set[str],
    user_id: Optional[int] = None,
) -> dict[str, int]:
    """
    Ensure a ``DatasetSplit`` row exists for every name in ``split_names``.

    Missing splits are created with a default gray color (``#808080``), empty
    metadata and the given ``user_id``. Names that already exist are left as
    they are, because the insert uses ``ON CONFLICT DO NOTHING`` on ``name``.

    Returns:
        A mapping from split name to split row id for every requested name that
        exists after the insert (empty when ``split_names`` is empty).
    """
    if not split_names:
        return {}

    dialect = SupportedSQLDialect(session.bind.dialect.name)

    # DO NOTHING (instead of select-then-insert) tolerates concurrent creators
    # of the same split name.
    upsert = insert_on_conflict(
        *(
            {
                "name": split_name,
                "color": "#808080",  # Default gray color
                "metadata_": {},
                "user_id": user_id,
            }
            for split_name in split_names
        ),
        table=models.DatasetSplit,
        dialect=dialect,
        unique_by=["name"],
        on_conflict=OnConflict.DO_NOTHING,
    )
    await session.execute(upsert)

    # Pre-existing rows were skipped by the insert, so re-read ids for all names.
    lookup = select(models.DatasetSplit.name, models.DatasetSplit.id).where(
        models.DatasetSplit.name.in_(split_names)
    )
    rows = (await session.execute(lookup)).all()
    name_to_id: dict[str, int] = {}
    for split_name, split_id in rows:
        name_to_id[split_name] = split_id
    return name_to_id
183
+
184
+
185
async def bulk_assign_examples_to_splits(
    session: AsyncSession,
    assignments: list[tuple[DatasetExampleId, int]],
    *,
    batch_size: int = 400,
) -> None:
    """
    Bulk assign examples to splits.

    ``assignments`` is a list of ``(dataset_example_id, dataset_split_id)``
    tuples. Pairs that are already assigned are skipped via
    ``ON CONFLICT DO NOTHING`` on ``(dataset_split_id, dataset_example_id)``.

    Rows are written in chunks of at most ``batch_size``. A multi-row
    ``INSERT ... VALUES`` binds one parameter per column per row (two per row
    here), and both backends cap the number of bind parameters per statement:
    asyncpg rejects more than 32767, and SQLite's default limit is 999 before
    3.32.0 (32766 afterwards). A single statement for a large upload would
    therefore fail. The default of 400 rows (800 parameters) stays under the
    smallest of these limits. Nothing here commits, so every chunk runs in the
    caller's transaction.

    Args:
        session: Active async session used to execute the inserts.
        assignments: ``(dataset_example_id, dataset_split_id)`` pairs to insert.
        batch_size: Maximum number of rows per INSERT statement; must be >= 1.

    Raises:
        ValueError: If ``assignments`` is non-empty and ``batch_size`` < 1.
    """
    if not assignments:
        return
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    from sqlalchemy.dialects.postgresql import insert as pg_insert
    from sqlalchemy.dialects.sqlite import insert as sqlite_insert
    from typing_extensions import assert_never

    dialect = SupportedSQLDialect(session.bind.dialect.name)

    # Use index_elements instead of constraint name because the table uses
    # a PrimaryKeyConstraint, not a unique constraint
    index_elements = ["dataset_split_id", "dataset_example_id"]

    for start in range(0, len(assignments), batch_size):
        records = [
            {
                "dataset_example_id": example_id,
                "dataset_split_id": split_id,
            }
            for example_id, split_id in assignments[start : start + batch_size]
        ]
        if dialect is SupportedSQLDialect.POSTGRESQL:
            pg_stmt = pg_insert(models.DatasetSplitDatasetExample).values(records)
            await session.execute(pg_stmt.on_conflict_do_nothing(index_elements=index_elements))
        elif dialect is SupportedSQLDialect.SQLITE:
            sqlite_stmt = sqlite_insert(models.DatasetSplitDatasetExample).values(records)
            await session.execute(
                sqlite_stmt.on_conflict_do_nothing(index_elements=index_elements)
            )
        else:
            assert_never(dialect)
227
+
228
+
137
229
  class DatasetAction(Enum):
138
230
  CREATE = "create"
139
231
  APPEND = "append"
@@ -152,6 +244,7 @@ async def add_dataset_examples(
152
244
  description: Optional[str] = None,
153
245
  metadata: Optional[Mapping[str, Any]] = None,
154
246
  action: DatasetAction = DatasetAction.CREATE,
247
+ user_id: Optional[int] = None,
155
248
  ) -> Optional[DatasetExampleAdditionEvent]:
156
249
  created_at = datetime.now(timezone.utc)
157
250
  dataset_id: Optional[DatasetId] = None
@@ -167,6 +260,7 @@ async def add_dataset_examples(
167
260
  description=description,
168
261
  metadata=metadata,
169
262
  created_at=created_at,
263
+ user_id=user_id,
170
264
  )
171
265
  except Exception:
172
266
  logger.exception(f"Failed to insert dataset: {name=}")
@@ -176,10 +270,14 @@ async def add_dataset_examples(
176
270
  session=session,
177
271
  dataset_id=dataset_id,
178
272
  created_at=created_at,
273
+ user_id=user_id,
179
274
  )
180
275
  except Exception:
181
276
  logger.exception(f"Failed to insert dataset version for {dataset_id=}")
182
277
  raise
278
+
279
+ # Process examples and collect split assignments (by name, resolved to IDs after iteration)
280
+ split_assignments: list[tuple[DatasetExampleId, str]] = []
183
281
  for example in (await examples) if isinstance(examples, Awaitable) else examples:
184
282
  try:
185
283
  dataset_example_id = await insert_dataset_example(
@@ -206,6 +304,40 @@ async def add_dataset_examples(
206
304
  f"{dataset_example_id=}"
207
305
  )
208
306
  raise
307
+
308
+ # Collect split assignments by name for bulk insert later
309
+ for split_name in example.splits:
310
+ split_assignments.append((dataset_example_id, split_name))
311
+
312
+ # Bulk create splits and assign examples after iteration
313
+ if split_assignments:
314
+ # Collect all unique split names
315
+ all_split_names = {name for _, name in split_assignments}
316
+ try:
317
+ split_name_to_id = await bulk_create_dataset_splits(
318
+ session=session,
319
+ split_names=all_split_names,
320
+ user_id=user_id,
321
+ )
322
+ except Exception:
323
+ logger.exception(f"Failed to bulk create dataset splits: {all_split_names}")
324
+ raise
325
+
326
+ # Convert name-based assignments to ID-based assignments
327
+ id_assignments = [
328
+ (example_id, split_name_to_id[split_name])
329
+ for example_id, split_name in split_assignments
330
+ ]
331
+
332
+ try:
333
+ await bulk_assign_examples_to_splits(
334
+ session=session,
335
+ assignments=id_assignments,
336
+ )
337
+ except Exception:
338
+ logger.exception("Failed to bulk assign examples to splits")
339
+ raise
340
+
209
341
  return DatasetExampleAdditionEvent(dataset_id=dataset_id, dataset_version_id=dataset_version_id)
210
342
 
211
343
 
@@ -7,7 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
7
7
  from typing_extensions import TypeAlias
8
8
 
9
9
  from phoenix.db import models
10
- from phoenix.db.helpers import dedup, num_docs_col
10
+ from phoenix.db.helpers import dedup
11
11
  from phoenix.db.insertion.helpers import as_kv
12
12
  from phoenix.db.insertion.types import (
13
13
  Insertables,
@@ -63,7 +63,7 @@ class DocumentAnnotationQueueInserter(
63
63
  session: AsyncSession,
64
64
  *insertions: Insertables.DocumentAnnotation,
65
65
  ) -> list[DocumentAnnotationDmlEvent]:
66
- records = [dict(as_kv(ins.row)) for ins in insertions]
66
+ records = [{**dict(as_kv(ins.row)), "updated_at": ins.row.updated_at} for ins in insertions]
67
67
  stmt = self._insert_on_conflict(*records).returning(self.table.id)
68
68
  ids = tuple([_ async for _ in await session.stream_scalars(stmt)])
69
69
  return [DocumentAnnotationDmlEvent(ids)]
@@ -99,7 +99,7 @@ class DocumentAnnotationQueueInserter(
99
99
 
100
100
  for p in parcels:
101
101
  if (anno := existing_annos.get(_key(p))) is not None:
102
- if p.received_at <= anno.updated_at:
102
+ if p.item.updated_at <= anno.updated_at:
103
103
  to_discard.append(p)
104
104
  else:
105
105
  to_insert.append(
@@ -107,7 +107,6 @@ class DocumentAnnotationQueueInserter(
107
107
  received_at=p.received_at,
108
108
  item=p.item.as_insertable(
109
109
  span_rowid=anno.span_rowid,
110
- id_=anno.id_,
111
110
  ),
112
111
  )
113
112
  )
@@ -140,7 +139,11 @@ class DocumentAnnotationQueueInserter(
140
139
  def _select_existing(self, *keys: _Key) -> Select[_Existing]:
141
140
  anno = self.table
142
141
  span = (
143
- select(models.Span.id, models.Span.span_id, num_docs_col(self._db.dialect))
142
+ select(
143
+ models.Span.id,
144
+ models.Span.span_id,
145
+ models.Span.num_documents.label("num_docs"),
146
+ )
144
147
  .where(models.Span.span_id.in_({k.span_id for k in keys}))
145
148
  .cte()
146
149
  )
@@ -182,7 +185,7 @@ def _key(p: Received[Precursors.DocumentAnnotation]) -> _Key:
182
185
 
183
186
 
184
187
  def _unique_by(p: Received[Insertables.DocumentAnnotation]) -> _UniqueBy:
185
- return p.item.obj.name, p.item.span_rowid, p.item.document_position, p.item.identifier
188
+ return p.item.obj.name, p.item.span_rowid, p.item.document_position, p.item.obj.identifier
186
189
 
187
190
 
188
191
  def _time(p: Received[Any]) -> datetime:
@@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
5
5
  from typing_extensions import assert_never
6
6
 
7
7
  from phoenix.db import models
8
- from phoenix.db.helpers import SupportedSQLDialect, num_docs_col
8
+ from phoenix.db.helpers import SupportedSQLDialect
9
9
  from phoenix.db.insertion.helpers import insert_on_conflict
10
10
  from phoenix.exceptions import PhoenixException
11
11
  from phoenix.trace import v1 as pb
@@ -153,12 +153,11 @@ async def _insert_document_evaluation(
153
153
  score: Optional[float],
154
154
  explanation: Optional[str],
155
155
  ) -> EvaluationInsertionEvent:
156
- dialect = SupportedSQLDialect(session.bind.dialect.name)
157
156
  stmt = (
158
157
  select(
159
158
  models.Trace.project_rowid,
160
159
  models.Span.id,
161
- num_docs_col(dialect),
160
+ models.Span.num_documents,
162
161
  )
163
162
  .join_from(models.Span, models.Trace)
164
163
  .where(models.Span.span_id == span_id)
@@ -12,7 +12,7 @@ from sqlalchemy.sql.elements import KeyedColumnElement
12
12
  from typing_extensions import TypeAlias, assert_never
13
13
 
14
14
  from phoenix.db import models
15
- from phoenix.db.helpers import SupportedSQLDialect
15
+ from phoenix.db.helpers import SupportedSQLDialect, truncate_name
16
16
  from phoenix.db.models import Base
17
17
  from phoenix.trace.attributes import get_attribute_value
18
18
 
@@ -53,7 +53,7 @@ def insert_on_conflict(
53
53
  unique_records.append(v)
54
54
  seen.add(k)
55
55
  records = tuple(reversed(unique_records))
56
- constraint = constraint_name or "_".join(("uq", table.__tablename__, *unique_by))
56
+ constraint = constraint_name or truncate_name("_".join(("uq", table.__tablename__, *unique_by)))
57
57
  if dialect is SupportedSQLDialect.POSTGRESQL:
58
58
  stmt_postgresql = insert_postgresql(table).values(records)
59
59
  if on_conflict is OnConflict.DO_NOTHING:
@@ -0,0 +1,176 @@
1
+ from collections.abc import Mapping
2
+ from datetime import datetime
3
+ from typing import Any, NamedTuple, Optional
4
+
5
+ from sqlalchemy import Row, Select, and_, select, tuple_
6
+ from sqlalchemy.ext.asyncio import AsyncSession
7
+ from typing_extensions import TypeAlias
8
+
9
+ from phoenix.db import models
10
+ from phoenix.db.helpers import dedup
11
+ from phoenix.db.insertion.helpers import as_kv
12
+ from phoenix.db.insertion.types import (
13
+ Insertables,
14
+ Postponed,
15
+ Precursors,
16
+ QueueInserter,
17
+ Received,
18
+ )
19
+ from phoenix.server.dml_event import ProjectSessionAnnotationDmlEvent
20
+
21
# Type alias for consistency with other annotation patterns: the span, trace and
# document annotation inserters name their events ``<Kind>AnnotationDmlEvent``,
# so the session-level event is exposed under the same naming scheme here.
SessionAnnotationDmlEvent = ProjectSessionAnnotationDmlEvent
23
+
24
+ _Name: TypeAlias = str
25
+ _SessionId: TypeAlias = str
26
+ _SessionRowId: TypeAlias = int
27
+ _AnnoRowId: TypeAlias = int
28
+ _Identifier: TypeAlias = str
29
+
30
+
31
+ class _Key(NamedTuple):
32
+ annotation_name: _Name
33
+ annotation_identifier: _Identifier
34
+ session_id: _SessionId
35
+
36
+
37
+ _UniqueBy: TypeAlias = tuple[_Name, _SessionRowId, _Identifier]
38
+ _Existing: TypeAlias = tuple[
39
+ _SessionRowId,
40
+ _SessionId,
41
+ Optional[_AnnoRowId],
42
+ Optional[_Name],
43
+ Optional[datetime],
44
+ ]
45
+
46
+
47
class SessionAnnotationQueueInserter(
    QueueInserter[
        Precursors.SessionAnnotation,
        Insertables.SessionAnnotation,
        models.ProjectSessionAnnotation,
        SessionAnnotationDmlEvent,
    ],
    table=models.ProjectSessionAnnotation,
    unique_by=("name", "project_session_id", "identifier"),
):
    """Queue inserter for project-session annotations.

    Queued annotations reference their session by the external ``session_id``
    string. ``_partition`` resolves that string to a ``ProjectSession`` row id
    and filters out stale updates. ``_events`` then writes the surviving rows to
    ``ProjectSessionAnnotation`` through ``self._insert_on_conflict``, which is
    provided by the ``QueueInserter`` base class; presumably it targets the
    ``unique_by`` columns given above (confirm in the base class).
    """

    async def _events(
        self,
        session: AsyncSession,
        *insertions: Insertables.SessionAnnotation,
    ) -> list[SessionAnnotationDmlEvent]:
        """Write ``insertions`` and return one DML event carrying the returned row ids."""
        # ``updated_at`` is merged in explicitly on top of ``as_kv(ins.row)``;
        # presumably ``as_kv`` does not emit it — confirm against
        # phoenix.db.insertion.helpers.
        records = [{**dict(as_kv(ins.row)), "updated_at": ins.row.updated_at} for ins in insertions]
        stmt = self._insert_on_conflict(*records).returning(self.table.id)
        ids = tuple([_ async for _ in await session.stream_scalars(stmt)])
        return [SessionAnnotationDmlEvent(ids)]

    async def _partition(
        self,
        session: AsyncSession,
        *parcels: Received[Precursors.SessionAnnotation],
    ) -> tuple[
        list[Received[Insertables.SessionAnnotation]],
        list[Postponed[Precursors.SessionAnnotation]],
        list[Received[Precursors.SessionAnnotation]],
    ]:
        """Split ``parcels`` into (to_insert, to_postpone, to_discard).

        For each parcel:
          * its annotation already exists and the incoming ``updated_at`` is not
            newer than the stored one -> discard (stale update);
          * its annotation already exists and is newer -> insert, reusing the
            stored session row id;
          * its session exists but the annotation does not -> insert;
          * its session is not found yet -> postpone a fresh parcel with
            ``self._retry_allowance`` retries. An already-postponed parcel is
            retried while ``retries_left > 1`` and discarded after that.

        Every parcel lands in exactly one list (asserted). The inserts are then
        deduplicated on ``_unique_by``, keeping the most recently received one.
        """
        to_insert: list[Received[Insertables.SessionAnnotation]] = []
        to_postpone: list[Postponed[Precursors.SessionAnnotation]] = []
        to_discard: list[Received[Precursors.SessionAnnotation]] = []

        stmt = self._select_existing(*map(_key, parcels))
        existing: list[Row[_Existing]] = [_ async for _ in await session.stream(stmt)]
        # One entry per ProjectSession found, keyed by its external session id.
        existing_sessions: Mapping[_SessionId, _SessionAttr] = {
            e.session_id: _SessionAttr(e.session_rowid) for e in existing
        }
        # Only rows where the outer join matched an annotation.
        existing_annos: Mapping[_Key, _AnnoAttr] = {
            _Key(
                annotation_name=e.name,
                annotation_identifier=e.identifier,
                session_id=e.session_id,
            ): _AnnoAttr(e.session_rowid, e.id, e.updated_at)
            for e in existing
            if e.id is not None and e.name is not None and e.updated_at is not None
        }

        for p in parcels:
            if (anno := existing_annos.get(_key(p))) is not None:
                if p.item.updated_at <= anno.updated_at:
                    to_discard.append(p)
                else:
                    to_insert.append(
                        Received(
                            received_at=p.received_at,
                            item=p.item.as_insertable(
                                project_session_rowid=anno.session_rowid,
                            ),
                        )
                    )
            elif (existing_session := existing_sessions.get(p.item.session_id)) is not None:
                to_insert.append(
                    Received(
                        received_at=p.received_at,
                        item=p.item.as_insertable(
                            project_session_rowid=existing_session.session_rowid,
                        ),
                    )
                )
            # Postponed is checked before Received; presumably Postponed is a
            # subclass of Received, so reversing the order would re-postpone
            # retried parcels forever — confirm in phoenix.db.insertion.types.
            elif isinstance(p, Postponed):
                if p.retries_left > 1:
                    to_postpone.append(p.postpone(p.retries_left - 1))
                else:
                    to_discard.append(p)
            elif isinstance(p, Received):
                to_postpone.append(p.postpone(self._retry_allowance))
            else:
                to_discard.append(p)

        assert len(to_insert) + len(to_postpone) + len(to_discard) == len(parcels)
        # Newest-received first so dedup keeps the latest per key, then reverse
        # back to ascending receive order.
        # NOTE(review): the staleness check above compares ``item.updated_at``,
        # but this in-batch dedup ranks by ``received_at`` (see ``_time``). Two
        # parcels for the same key in one batch that arrive out of ``updated_at``
        # order keep the later-received one — confirm this is intended.
        to_insert = dedup(sorted(to_insert, key=_time, reverse=True), _unique_by)[::-1]
        return to_insert, to_postpone, to_discard

    def _select_existing(self, *keys: _Key) -> Select[_Existing]:
        """Select matching sessions, each outer-joined to its annotations matching ``keys``.

        Returns one row per matching annotation, or a single row with NULL
        annotation columns for a session that has none of the requested ones.
        Columns: session_rowid, session_id, id, name, identifier, updated_at.
        """
        anno = self.table
        session = (
            select(models.ProjectSession.id, models.ProjectSession.session_id)
            .where(models.ProjectSession.session_id.in_({k.session_id for k in keys}))
            .cte()
        )
        onclause = and_(
            session.c.id == anno.project_session_id,
            # Coarse name filter alongside the exact tuple match below.
            anno.name.in_({k.annotation_name for k in keys}),
            # Tuple order matches the _Key field order.
            tuple_(anno.name, anno.identifier, session.c.session_id).in_(keys),
        )
        return select(
            session.c.id.label("session_rowid"),
            session.c.session_id,
            anno.id,
            anno.name,
            anno.identifier,
            anno.updated_at,
        ).outerjoin_from(session, anno, onclause)
151
+
152
+
153
class _SessionAttr(NamedTuple):
    """Attributes of an already-stored ``ProjectSession`` needed to attach an annotation."""

    session_rowid: _SessionRowId  # ProjectSession primary key
155
+
156
+
157
class _AnnoAttr(NamedTuple):
    """Attributes of an already-stored ``ProjectSessionAnnotation`` row."""

    session_rowid: _SessionRowId  # owning ProjectSession primary key
    id_: _AnnoRowId  # annotation primary key
    updated_at: datetime  # last stored update, compared against incoming items
161
+
162
+
163
def _key(p: Received[Precursors.SessionAnnotation]) -> _Key:
    """Build the lookup key used to match a queued annotation against stored rows."""
    annotation = p.item.obj
    # Positional order: annotation_name, annotation_identifier, session_id.
    return _Key(annotation.name, annotation.identifier, p.item.session_id)
169
+
170
+
171
def _unique_by(p: Received[Insertables.SessionAnnotation]) -> _UniqueBy:
    """Return the ``(name, project_session_rowid, identifier)`` key used to dedup inserts."""
    insertable = p.item
    name = insertable.obj.name
    session_rowid = insertable.project_session_rowid
    identifier = insertable.obj.identifier
    return name, session_rowid, identifier
173
+
174
+
175
def _time(p: Received[Any]) -> datetime:
    """Return when the parcel was received; the sort key for in-batch dedup."""
    received: datetime = p.received_at
    return received
@@ -57,7 +57,7 @@ class SpanAnnotationQueueInserter(
57
57
  session: AsyncSession,
58
58
  *insertions: Insertables.SpanAnnotation,
59
59
  ) -> list[SpanAnnotationDmlEvent]:
60
- records = [dict(as_kv(ins.row)) for ins in insertions]
60
+ records = [{**dict(as_kv(ins.row)), "updated_at": ins.row.updated_at} for ins in insertions]
61
61
  stmt = self._insert_on_conflict(*records).returning(self.table.id)
62
62
  ids = tuple([_ async for _ in await session.stream_scalars(stmt)])
63
63
  return [SpanAnnotationDmlEvent(ids)]
@@ -92,7 +92,7 @@ class SpanAnnotationQueueInserter(
92
92
 
93
93
  for p in parcels:
94
94
  if (anno := existing_annos.get(_key(p))) is not None:
95
- if p.received_at <= anno.updated_at:
95
+ if p.item.updated_at <= anno.updated_at:
96
96
  to_discard.append(p)
97
97
  else:
98
98
  to_insert.append(
@@ -100,7 +100,6 @@ class SpanAnnotationQueueInserter(
100
100
  received_at=p.received_at,
101
101
  item=p.item.as_insertable(
102
102
  span_rowid=anno.span_rowid,
103
- id_=anno.id_,
104
103
  ),
105
104
  )
106
105
  )
@@ -168,7 +167,7 @@ def _key(p: Received[Precursors.SpanAnnotation]) -> _Key:
168
167
 
169
168
 
170
169
  def _unique_by(p: Received[Insertables.SpanAnnotation]) -> _UniqueBy:
171
- return p.item.obj.name, p.item.span_rowid, p.item.identifier
170
+ return p.item.obj.name, p.item.span_rowid, p.item.obj.identifier
172
171
 
173
172
 
174
173
  def _time(p: Received[Any]) -> datetime:
@@ -56,7 +56,7 @@ class TraceAnnotationQueueInserter(
56
56
  session: AsyncSession,
57
57
  *insertions: Insertables.TraceAnnotation,
58
58
  ) -> list[TraceAnnotationDmlEvent]:
59
- records = [dict(as_kv(ins.row)) for ins in insertions]
59
+ records = [{**dict(as_kv(ins.row)), "updated_at": ins.row.updated_at} for ins in insertions]
60
60
  stmt = self._insert_on_conflict(*records).returning(self.table.id)
61
61
  ids = tuple([_ async for _ in await session.stream_scalars(stmt)])
62
62
  return [TraceAnnotationDmlEvent(ids)]
@@ -91,7 +91,7 @@ class TraceAnnotationQueueInserter(
91
91
 
92
92
  for p in parcels:
93
93
  if (anno := existing_annos.get(_key(p))) is not None:
94
- if p.received_at <= anno.updated_at:
94
+ if p.item.updated_at <= anno.updated_at:
95
95
  to_discard.append(p)
96
96
  else:
97
97
  to_insert.append(
@@ -99,7 +99,6 @@ class TraceAnnotationQueueInserter(
99
99
  received_at=p.received_at,
100
100
  item=p.item.as_insertable(
101
101
  trace_rowid=anno.trace_rowid,
102
- id_=anno.id_,
103
102
  ),
104
103
  )
105
104
  )
@@ -167,7 +166,7 @@ def _key(p: Received[Precursors.TraceAnnotation]) -> _Key:
167
166
 
168
167
 
169
168
  def _unique_by(p: Received[Insertables.TraceAnnotation]) -> _UniqueBy:
170
- return p.item.obj.name, p.item.trace_rowid, p.item.identifier
169
+ return p.item.obj.name, p.item.trace_rowid, p.item.obj.identifier
171
170
 
172
171
 
173
172
  def _time(p: Received[Any]) -> datetime: