arize-phoenix 8.32.1__py3-none-any.whl → 9.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of arize-phoenix might be problematic; see the registry listing for details.

Files changed (79)
  1. {arize_phoenix-8.32.1.dist-info → arize_phoenix-9.0.0.dist-info}/METADATA +2 -2
  2. {arize_phoenix-8.32.1.dist-info → arize_phoenix-9.0.0.dist-info}/RECORD +76 -56
  3. phoenix/db/constants.py +1 -0
  4. phoenix/db/facilitator.py +55 -0
  5. phoenix/db/insertion/document_annotation.py +31 -13
  6. phoenix/db/insertion/evaluation.py +15 -3
  7. phoenix/db/insertion/helpers.py +2 -1
  8. phoenix/db/insertion/span_annotation.py +26 -9
  9. phoenix/db/insertion/trace_annotation.py +25 -9
  10. phoenix/db/insertion/types.py +7 -0
  11. phoenix/db/migrations/versions/2f9d1a65945f_annotation_config_migration.py +322 -0
  12. phoenix/db/migrations/versions/8a3764fe7f1a_change_jsonb_to_json_for_prompts.py +76 -0
  13. phoenix/db/migrations/versions/bb8139330879_create_project_trace_retention_policies_table.py +77 -0
  14. phoenix/db/models.py +151 -10
  15. phoenix/db/types/annotation_configs.py +97 -0
  16. phoenix/db/types/db_models.py +41 -0
  17. phoenix/db/types/trace_retention.py +267 -0
  18. phoenix/experiments/functions.py +5 -1
  19. phoenix/server/api/auth.py +9 -0
  20. phoenix/server/api/context.py +5 -0
  21. phoenix/server/api/dataloaders/__init__.py +4 -0
  22. phoenix/server/api/dataloaders/annotation_summaries.py +203 -24
  23. phoenix/server/api/dataloaders/project_ids_by_trace_retention_policy_id.py +42 -0
  24. phoenix/server/api/dataloaders/trace_retention_policy_id_by_project_id.py +34 -0
  25. phoenix/server/api/helpers/annotations.py +9 -0
  26. phoenix/server/api/helpers/prompts/models.py +34 -67
  27. phoenix/server/api/input_types/CreateSpanAnnotationInput.py +9 -0
  28. phoenix/server/api/input_types/CreateTraceAnnotationInput.py +3 -0
  29. phoenix/server/api/input_types/PatchAnnotationInput.py +3 -0
  30. phoenix/server/api/input_types/SpanAnnotationFilter.py +67 -0
  31. phoenix/server/api/mutations/__init__.py +6 -0
  32. phoenix/server/api/mutations/annotation_config_mutations.py +413 -0
  33. phoenix/server/api/mutations/dataset_mutations.py +62 -39
  34. phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +245 -0
  35. phoenix/server/api/mutations/span_annotations_mutations.py +272 -70
  36. phoenix/server/api/mutations/trace_annotations_mutations.py +203 -74
  37. phoenix/server/api/queries.py +86 -0
  38. phoenix/server/api/routers/v1/__init__.py +4 -0
  39. phoenix/server/api/routers/v1/annotation_configs.py +449 -0
  40. phoenix/server/api/routers/v1/annotations.py +161 -0
  41. phoenix/server/api/routers/v1/evaluations.py +6 -0
  42. phoenix/server/api/routers/v1/projects.py +1 -50
  43. phoenix/server/api/routers/v1/spans.py +35 -8
  44. phoenix/server/api/routers/v1/traces.py +22 -13
  45. phoenix/server/api/routers/v1/utils.py +60 -0
  46. phoenix/server/api/types/Annotation.py +7 -0
  47. phoenix/server/api/types/AnnotationConfig.py +124 -0
  48. phoenix/server/api/types/AnnotationSource.py +9 -0
  49. phoenix/server/api/types/AnnotationSummary.py +28 -14
  50. phoenix/server/api/types/AnnotatorKind.py +1 -0
  51. phoenix/server/api/types/CronExpression.py +15 -0
  52. phoenix/server/api/types/Evaluation.py +4 -30
  53. phoenix/server/api/types/Project.py +50 -2
  54. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +110 -0
  55. phoenix/server/api/types/Span.py +78 -0
  56. phoenix/server/api/types/SpanAnnotation.py +24 -0
  57. phoenix/server/api/types/Trace.py +2 -2
  58. phoenix/server/api/types/TraceAnnotation.py +23 -0
  59. phoenix/server/app.py +20 -0
  60. phoenix/server/retention.py +76 -0
  61. phoenix/server/static/.vite/manifest.json +36 -36
  62. phoenix/server/static/assets/components-B2MWTXnm.js +4326 -0
  63. phoenix/server/static/assets/{index-B0CbpsxD.js → index-Bfvpea_-.js} +10 -10
  64. phoenix/server/static/assets/pages-CZ2vKu8H.js +7268 -0
  65. phoenix/server/static/assets/vendor-BRDkBC5J.js +903 -0
  66. phoenix/server/static/assets/{vendor-arizeai-CxXYQNUl.js → vendor-arizeai-BvTqp_W8.js} +3 -3
  67. phoenix/server/static/assets/{vendor-codemirror-B0NIFPOL.js → vendor-codemirror-COt9UfW7.js} +1 -1
  68. phoenix/server/static/assets/{vendor-recharts-CrrDFWK1.js → vendor-recharts-BoHX9Hvs.js} +2 -2
  69. phoenix/server/static/assets/{vendor-shiki-C5bJ-RPf.js → vendor-shiki-Cw1dsDAz.js} +1 -1
  70. phoenix/trace/dsl/filter.py +25 -5
  71. phoenix/utilities/__init__.py +18 -0
  72. phoenix/version.py +1 -1
  73. phoenix/server/static/assets/components-x-gKFJ8C.js +0 -3414
  74. phoenix/server/static/assets/pages-BU4VdyeH.js +0 -5867
  75. phoenix/server/static/assets/vendor-BfhM_F1u.js +0 -902
  76. {arize_phoenix-8.32.1.dist-info → arize_phoenix-9.0.0.dist-info}/WHEEL +0 -0
  77. {arize_phoenix-8.32.1.dist-info → arize_phoenix-9.0.0.dist-info}/entry_points.txt +0 -0
  78. {arize_phoenix-8.32.1.dist-info → arize_phoenix-9.0.0.dist-info}/licenses/IP_NOTICE +0 -0
  79. {arize_phoenix-8.32.1.dist-info → arize_phoenix-9.0.0.dist-info}/licenses/LICENSE +0 -0
@@ -95,6 +95,7 @@ def run_experiment(
95
95
  dry_run: Union[bool, int] = False,
96
96
  print_summary: bool = True,
97
97
  concurrency: int = 3,
98
+ timeout: Optional[int] = None,
98
99
  ) -> RanExperiment:
99
100
  """
100
101
  Runs an experiment using a given set of dataset of examples.
@@ -148,6 +149,8 @@ def run_experiment(
148
149
  concurrency (int): Specifies the concurrency for task execution. In order to enable
149
150
  concurrent task execution, the task callable must be a coroutine function.
150
151
  Defaults to 3.
152
+ timeout (Optional[int]): The timeout for the task execution in seconds. Use this to run
153
+ longer tasks to avoid re-queuing the same task multiple times. Defaults to None.
151
154
 
152
155
  Returns:
153
156
  RanExperiment: The results of the experiment and evaluation. Additional evaluations can be
@@ -380,6 +383,7 @@ def run_experiment(
380
383
  fallback_return_value=None,
381
384
  tqdm_bar_format=get_tqdm_progress_bar_formatter("running tasks"),
382
385
  concurrency=concurrency,
386
+ timeout=timeout,
383
387
  )
384
388
 
385
389
  test_cases = [
@@ -752,7 +756,7 @@ def _print_experiment_error(
752
756
  Prints an experiment error.
753
757
  """
754
758
  display_error = RuntimeError(
755
- f"{kind} failed for example id {repr(example_id)}, " f"repetition {repr(repetition_number)}"
759
+ f"{kind} failed for example id {repr(example_id)}, repetition {repr(repetition_number)}"
756
760
  )
757
761
  display_error.__cause__ = error
758
762
  formatted_exception = "".join(
@@ -42,3 +42,12 @@ class IsAdmin(Authorization):
42
42
  if not info.context.auth_enabled:
43
43
  return False
44
44
  return isinstance((user := info.context.user), PhoenixUser) and user.is_admin
45
+
46
+
47
class IsAdminIfAuthEnabled(Authorization):
    """Permits admins; when authentication is disabled, permits everyone."""

    message = MSG_ADMIN_ONLY

    def has_permission(self, source: Any, info: Info, **kwargs: Any) -> bool:
        # With auth disabled there is no user to inspect, so access is open.
        if not info.context.auth_enabled:
            return True
        user = info.context.user
        return isinstance(user, PhoenixUser) and user.is_admin
@@ -32,6 +32,7 @@ from phoenix.server.api.dataloaders import (
32
32
  NumChildSpansDataLoader,
33
33
  NumSpansPerTraceDataLoader,
34
34
  ProjectByNameDataLoader,
35
+ ProjectIdsByTraceRetentionPolicyIdDataLoader,
35
36
  PromptVersionSequenceNumberDataLoader,
36
37
  RecordCountDataLoader,
37
38
  SessionIODataLoader,
@@ -47,6 +48,7 @@ from phoenix.server.api.dataloaders import (
47
48
  TableFieldsDataLoader,
48
49
  TokenCountDataLoader,
49
50
  TraceByTraceIdsDataLoader,
51
+ TraceRetentionPolicyIdByProjectIdDataLoader,
50
52
  TraceRootSpansDataLoader,
51
53
  UserRolesDataLoader,
52
54
  UsersDataLoader,
@@ -82,6 +84,7 @@ class DataLoaders:
82
84
  num_child_spans: NumChildSpansDataLoader
83
85
  num_spans_per_trace: NumSpansPerTraceDataLoader
84
86
  project_fields: TableFieldsDataLoader
87
+ projects_by_trace_retention_policy_id: ProjectIdsByTraceRetentionPolicyIdDataLoader
85
88
  prompt_version_sequence_number: PromptVersionSequenceNumberDataLoader
86
89
  record_counts: RecordCountDataLoader
87
90
  session_first_inputs: SessionIODataLoader
@@ -99,6 +102,8 @@ class DataLoaders:
99
102
  token_counts: TokenCountDataLoader
100
103
  trace_by_trace_ids: TraceByTraceIdsDataLoader
101
104
  trace_fields: TableFieldsDataLoader
105
+ trace_retention_policy_id_by_project_id: TraceRetentionPolicyIdByProjectIdDataLoader
106
+ project_trace_retention_policy_fields: TableFieldsDataLoader
102
107
  trace_root_spans: TraceRootSpansDataLoader
103
108
  project_by_name: ProjectByNameDataLoader
104
109
  users: UsersDataLoader
@@ -20,6 +20,7 @@ from .min_start_or_max_end_times import MinStartOrMaxEndTimeCache, MinStartOrMax
20
20
  from .num_child_spans import NumChildSpansDataLoader
21
21
  from .num_spans_per_trace import NumSpansPerTraceDataLoader
22
22
  from .project_by_name import ProjectByNameDataLoader
23
+ from .project_ids_by_trace_retention_policy_id import ProjectIdsByTraceRetentionPolicyIdDataLoader
23
24
  from .prompt_version_sequence_number import PromptVersionSequenceNumberDataLoader
24
25
  from .record_counts import RecordCountCache, RecordCountDataLoader
25
26
  from .session_io import SessionIODataLoader
@@ -35,6 +36,7 @@ from .span_projects import SpanProjectsDataLoader
35
36
  from .table_fields import TableFieldsDataLoader
36
37
  from .token_counts import TokenCountCache, TokenCountDataLoader
37
38
  from .trace_by_trace_ids import TraceByTraceIdsDataLoader
39
+ from .trace_retention_policy_id_by_project_id import TraceRetentionPolicyIdByProjectIdDataLoader
38
40
  from .trace_root_spans import TraceRootSpansDataLoader
39
41
  from .user_roles import UserRolesDataLoader
40
42
  from .users import UsersDataLoader
@@ -57,6 +59,7 @@ __all__ = [
57
59
  "MinStartOrMaxEndTimeDataLoader",
58
60
  "NumChildSpansDataLoader",
59
61
  "NumSpansPerTraceDataLoader",
62
+ "ProjectIdsByTraceRetentionPolicyIdDataLoader",
60
63
  "PromptVersionSequenceNumberDataLoader",
61
64
  "RecordCountDataLoader",
62
65
  "SessionIODataLoader",
@@ -71,6 +74,7 @@ __all__ = [
71
74
  "TableFieldsDataLoader",
72
75
  "TokenCountDataLoader",
73
76
  "TraceByTraceIdsDataLoader",
77
+ "TraceRetentionPolicyIdByProjectIdDataLoader",
74
78
  "TraceRootSpansDataLoader",
75
79
  "ProjectByNameDataLoader",
76
80
  "SpanAnnotationsDataLoader",
@@ -1,11 +1,11 @@
1
1
  from collections import defaultdict
2
2
  from datetime import datetime
3
- from typing import Any, Literal, Optional
3
+ from typing import Any, Literal, Optional, Type, Union, cast
4
4
 
5
5
  import pandas as pd
6
6
  from aioitertools.itertools import groupby
7
7
  from cachetools import LFUCache, TTLCache
8
- from sqlalchemy import Select, func, or_, select
8
+ from sqlalchemy import Select, and_, case, distinct, func, or_, select
9
9
  from strawberry.dataloader import AbstractCache, DataLoader
10
10
  from typing_extensions import TypeAlias, assert_never
11
11
 
@@ -92,7 +92,7 @@ class AnnotationSummaryDataLoader(DataLoader[Key, Result]):
92
92
  async with self._db() as session:
93
93
  data = await session.stream(stmt)
94
94
  async for annotation_name, group in groupby(data, lambda row: row.name):
95
- summary = AnnotationSummary(pd.DataFrame(group))
95
+ summary = AnnotationSummary(name=annotation_name, df=pd.DataFrame(group))
96
96
  for position in params[annotation_name]:
97
97
  results[position] = summary
98
98
  return results
@@ -103,23 +103,64 @@ def _get_stmt(
103
103
  *annotation_names: Param,
104
104
  ) -> Select[Any]:
105
105
  kind, project_rowid, (start_time, end_time), filter_condition = segment
106
- stmt = select()
106
+
107
+ annotation_model: Union[Type[models.SpanAnnotation], Type[models.TraceAnnotation]]
108
+ entity_model: Union[Type[models.Span], Type[models.Trace]]
109
+ entity_join_model: Optional[Type[models.Base]]
110
+ entity_id_column: Any
111
+
107
112
  if kind == "span":
108
- msa = models.SpanAnnotation
109
- name_column, label_column, score_column = msa.name, msa.label, msa.score
110
- time_column = models.Span.start_time
111
- stmt = stmt.join(models.Span).join_from(models.Span, models.Trace)
112
- if filter_condition:
113
- sf = SpanFilter(filter_condition)
114
- stmt = sf(stmt)
113
+ annotation_model = models.SpanAnnotation
114
+ entity_model = models.Span
115
+ entity_join_model = models.Trace
116
+ entity_id_column = models.Span.id.label("entity_id")
115
117
  elif kind == "trace":
116
- mta = models.TraceAnnotation
117
- name_column, label_column, score_column = mta.name, mta.label, mta.score
118
- time_column = models.Trace.start_time
119
- stmt = stmt.join(models.Trace)
118
+ annotation_model = models.TraceAnnotation
119
+ entity_model = models.Trace
120
+ entity_join_model = None
121
+ entity_id_column = models.Trace.id.label("entity_id")
120
122
  else:
121
123
  assert_never(kind)
122
- stmt = stmt.add_columns(
124
+
125
+ name_column = annotation_model.name
126
+ label_column = annotation_model.label
127
+ score_column = annotation_model.score
128
+ time_column = entity_model.start_time
129
+
130
+ # First query: count distinct entities per annotation name
131
+ # This is used later to calculate accurate fractions that account for entities without labels
132
+ entity_count_query = select(
133
+ name_column, func.count(distinct(entity_id_column)).label("entity_count")
134
+ )
135
+
136
+ if kind == "span":
137
+ entity_count_query = entity_count_query.join(cast(Type[models.Span], entity_model))
138
+ entity_count_query = entity_count_query.join_from(
139
+ cast(Type[models.Span], entity_model), cast(Type[models.Trace], entity_join_model)
140
+ )
141
+ entity_count_query = entity_count_query.where(models.Trace.project_rowid == project_rowid)
142
+ elif kind == "trace":
143
+ entity_count_query = entity_count_query.join(cast(Type[models.Trace], entity_model))
144
+ entity_count_query = entity_count_query.where(
145
+ cast(Type[models.Trace], entity_model).project_rowid == project_rowid
146
+ )
147
+
148
+ entity_count_query = entity_count_query.where(
149
+ or_(score_column.is_not(None), label_column.is_not(None))
150
+ )
151
+ entity_count_query = entity_count_query.where(name_column.in_(annotation_names))
152
+
153
+ if start_time:
154
+ entity_count_query = entity_count_query.where(start_time <= time_column)
155
+ if end_time:
156
+ entity_count_query = entity_count_query.where(time_column < end_time)
157
+
158
+ entity_count_query = entity_count_query.group_by(name_column)
159
+ entity_count_subquery = entity_count_query.subquery()
160
+
161
+ # Main query: gets raw annotation data with counts per (span/trace)+name+label
162
+ base_stmt = select(
163
+ entity_id_column,
123
164
  name_column,
124
165
  label_column,
125
166
  func.count().label("record_count"),
@@ -127,13 +168,151 @@ def _get_stmt(
127
168
  func.count(score_column).label("score_count"),
128
169
  func.sum(score_column).label("score_sum"),
129
170
  )
130
- stmt = stmt.group_by(name_column, label_column)
131
- stmt = stmt.order_by(name_column, label_column)
132
- stmt = stmt.where(models.Trace.project_rowid == project_rowid)
133
- stmt = stmt.where(or_(score_column.is_not(None), label_column.is_not(None)))
134
- stmt = stmt.where(name_column.in_(annotation_names))
171
+
172
+ if kind == "span":
173
+ base_stmt = base_stmt.join(cast(Type[models.Span], entity_model))
174
+ base_stmt = base_stmt.join_from(
175
+ cast(Type[models.Span], entity_model), cast(Type[models.Trace], entity_join_model)
176
+ )
177
+ base_stmt = base_stmt.where(models.Trace.project_rowid == project_rowid)
178
+ if filter_condition:
179
+ sf = SpanFilter(filter_condition)
180
+ base_stmt = sf(base_stmt)
181
+ elif kind == "trace":
182
+ base_stmt = base_stmt.join(cast(Type[models.Trace], entity_model))
183
+ base_stmt = base_stmt.where(
184
+ cast(Type[models.Trace], entity_model).project_rowid == project_rowid
185
+ )
186
+ else:
187
+ assert_never(kind)
188
+
189
+ base_stmt = base_stmt.where(or_(score_column.is_not(None), label_column.is_not(None)))
190
+ base_stmt = base_stmt.where(name_column.in_(annotation_names))
191
+
135
192
  if start_time:
136
- stmt = stmt.where(start_time <= time_column)
193
+ base_stmt = base_stmt.where(start_time <= time_column)
137
194
  if end_time:
138
- stmt = stmt.where(time_column < end_time)
139
- return stmt
195
+ base_stmt = base_stmt.where(time_column < end_time)
196
+
197
+ # Group to get one row per (span/trace)+name+label combination
198
+ base_stmt = base_stmt.group_by(entity_id_column, name_column, label_column)
199
+
200
+ base_subquery = base_stmt.subquery()
201
+
202
+ # Calculate total counts per (span/trace)+name for computing fractions
203
+ entity_totals = (
204
+ select(
205
+ base_subquery.c.entity_id,
206
+ base_subquery.c.name,
207
+ func.sum(base_subquery.c.label_count).label("total_label_count"),
208
+ func.sum(base_subquery.c.score_count).label("total_score_count"),
209
+ func.sum(base_subquery.c.score_sum).label("entity_score_sum"),
210
+ )
211
+ .group_by(base_subquery.c.entity_id, base_subquery.c.name)
212
+ .subquery()
213
+ )
214
+
215
+ per_entity_fractions = (
216
+ select(
217
+ base_subquery.c.entity_id,
218
+ base_subquery.c.name,
219
+ base_subquery.c.label,
220
+ base_subquery.c.record_count,
221
+ base_subquery.c.label_count,
222
+ base_subquery.c.score_count,
223
+ base_subquery.c.score_sum,
224
+ # Calculate label fraction, avoiding division by zero when total_label_count is 0
225
+ case(
226
+ (
227
+ entity_totals.c.total_label_count > 0,
228
+ base_subquery.c.label_count * 1.0 / entity_totals.c.total_label_count,
229
+ ),
230
+ else_=None,
231
+ ).label("label_fraction"),
232
+ # Calculate average score for the entity (if there are any scores)
233
+ case(
234
+ (
235
+ entity_totals.c.total_score_count > 0,
236
+ entity_totals.c.entity_score_sum * 1.0 / entity_totals.c.total_score_count,
237
+ ),
238
+ else_=None,
239
+ ).label("entity_avg_score"),
240
+ )
241
+ .join(
242
+ entity_totals,
243
+ and_(
244
+ base_subquery.c.entity_id == entity_totals.c.entity_id,
245
+ base_subquery.c.name == entity_totals.c.name,
246
+ ),
247
+ )
248
+ .subquery()
249
+ )
250
+
251
+ # Aggregate metrics across (spans/traces) for each name+label combination.
252
+ label_entity_metrics = (
253
+ select(
254
+ per_entity_fractions.c.name,
255
+ per_entity_fractions.c.label,
256
+ func.count(distinct(per_entity_fractions.c.entity_id)).label("entities_with_label"),
257
+ func.sum(per_entity_fractions.c.label_count).label("total_label_count"),
258
+ func.sum(per_entity_fractions.c.score_count).label("total_score_count"),
259
+ func.sum(per_entity_fractions.c.score_sum).label("total_score_sum"),
260
+ # Average of label fractions for entities that have this label
261
+ func.avg(per_entity_fractions.c.label_fraction).label("avg_label_fraction_present"),
262
+ # Average of per-entity average scores (but we handle overall aggregation separately)
263
+ )
264
+ .group_by(per_entity_fractions.c.name, per_entity_fractions.c.label)
265
+ .subquery()
266
+ )
267
+
268
+ # Compute distinct per-entity average scores to ensure each entity counts only once.
269
+ distinct_entity_scores = (
270
+ select(
271
+ per_entity_fractions.c.entity_id,
272
+ per_entity_fractions.c.name,
273
+ per_entity_fractions.c.entity_avg_score,
274
+ )
275
+ .distinct()
276
+ .subquery()
277
+ )
278
+
279
+ overall_score_aggregates = (
280
+ select(
281
+ distinct_entity_scores.c.name,
282
+ func.avg(distinct_entity_scores.c.entity_avg_score).label("overall_avg_score"),
283
+ )
284
+ .group_by(distinct_entity_scores.c.name)
285
+ .subquery()
286
+ )
287
+
288
+ # Final result: adjust label fractions by the proportion of entities reporting this label
289
+ # and include the overall average score per annotation name.
290
+ final_stmt = (
291
+ select(
292
+ label_entity_metrics.c.name,
293
+ label_entity_metrics.c.label,
294
+ # Adjust label fraction, guarding against division by zero in entity_count
295
+ case(
296
+ (
297
+ entity_count_subquery.c.entity_count > 0,
298
+ label_entity_metrics.c.avg_label_fraction_present
299
+ * label_entity_metrics.c.entities_with_label
300
+ / entity_count_subquery.c.entity_count,
301
+ ),
302
+ else_=None,
303
+ ).label("avg_label_fraction"),
304
+ overall_score_aggregates.c.overall_avg_score.label("avg_score"), # same for all labels
305
+ label_entity_metrics.c.total_label_count.label("label_count"),
306
+ label_entity_metrics.c.total_score_count.label("score_count"),
307
+ label_entity_metrics.c.total_score_sum.label("score_sum"),
308
+ label_entity_metrics.c.entities_with_label.label("record_count"),
309
+ )
310
+ .join(entity_count_subquery, label_entity_metrics.c.name == entity_count_subquery.c.name)
311
+ .join(
312
+ overall_score_aggregates,
313
+ label_entity_metrics.c.name == overall_score_aggregates.c.name,
314
+ )
315
+ .order_by(label_entity_metrics.c.name, label_entity_metrics.c.label)
316
+ )
317
+
318
+ return final_stmt
@@ -0,0 +1,42 @@
1
+ from collections import defaultdict
2
+
3
+ from sqlalchemy import or_, select
4
+ from strawberry.dataloader import DataLoader
5
+ from typing_extensions import TypeAlias
6
+
7
+ from phoenix.db.constants import DEFAULT_PROJECT_TRACE_RETENTION_POLICY_ID
8
+ from phoenix.db.models import Project
9
+ from phoenix.server.types import DbSessionFactory
10
+
11
+ PolicyRowId: TypeAlias = int
12
+ ProjectRowId: TypeAlias = int
13
+
14
+ Key: TypeAlias = PolicyRowId
15
+ Result: TypeAlias = list[ProjectRowId]
16
+
17
+
18
class ProjectIdsByTraceRetentionPolicyIdDataLoader(DataLoader[Key, Result]):
    """Batch-loads the project row ids governed by each trace retention policy.

    Projects whose ``trace_retention_policy_id`` column is NULL are treated as
    belonging to the default policy
    (``DEFAULT_PROJECT_TRACE_RETENTION_POLICY_ID``).
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        requested_policy_ids = set(keys)
        stmt = select(Project.trace_retention_policy_id, Project.id)
        if DEFAULT_PROJECT_TRACE_RETENTION_POLICY_ID in requested_policy_ids:
            # The default policy also covers projects with no explicit policy,
            # so NULL rows must be fetched alongside the requested ids.
            stmt = stmt.where(
                or_(
                    Project.trace_retention_policy_id.in_(requested_policy_ids),
                    Project.trace_retention_policy_id.is_(None),
                )
            )
        else:
            stmt = stmt.where(Project.trace_retention_policy_id.in_(requested_policy_ids))
        grouped: defaultdict[Key, Result] = defaultdict(list)
        async with self._db() as session:
            data = await session.stream(stmt)
            async for policy_rowid, project_rowid in data:
                # NULL policy ids map back onto the default policy bucket.
                bucket = policy_rowid or DEFAULT_PROJECT_TRACE_RETENTION_POLICY_ID
                grouped[bucket].append(project_rowid)
        return [grouped.get(policy_id, []).copy() for policy_id in keys]
@@ -0,0 +1,34 @@
1
+ from sqlalchemy import select
2
+ from strawberry.dataloader import DataLoader
3
+ from typing_extensions import TypeAlias
4
+
5
+ from phoenix.db.constants import DEFAULT_PROJECT_TRACE_RETENTION_POLICY_ID
6
+ from phoenix.db.models import Project
7
+ from phoenix.server.types import DbSessionFactory
8
+
9
+ PolicyRowId: TypeAlias = int
10
+ ProjectRowId: TypeAlias = int
11
+
12
+ Key: TypeAlias = ProjectRowId
13
+ Result: TypeAlias = PolicyRowId
14
+
15
+
16
class TraceRetentionPolicyIdByProjectIdDataLoader(DataLoader[Key, Result]):
    """Batch-loads the trace retention policy row id for each project.

    Projects with no explicit policy (NULL column) resolve to
    ``DEFAULT_PROJECT_TRACE_RETENTION_POLICY_ID``.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        project_ids = set(keys)
        stmt = (
            select(Project.id, Project.trace_retention_policy_id)
            .where(Project.trace_retention_policy_id.isnot(None))
            .where(Project.id.in_(project_ids))
        )
        async with self._db() as session:
            data = await session.execute(stmt)
            policy_by_project = dict(data.all())
        # Fall back to the default policy for projects that had a NULL
        # policy id (and were therefore filtered out of the query above).
        return [
            policy_by_project.get(project_rowid, DEFAULT_PROJECT_TRACE_RETENTION_POLICY_ID)
            for project_rowid in keys
        ]
@@ -0,0 +1,9 @@
1
+ from strawberry.relay import GlobalID
2
+
3
+
4
def get_user_identifier(user_id: int) -> str:
    """Return an annotation identifier unique to the given user.

    The identifier embeds the user's relay ``GlobalID`` (type ``User``) and
    carries a ``px-app:`` prefix marking it as application-generated.
    """
    gid = GlobalID(type_name="User", node_id=str(user_id))
    return f"px-app:{str(gid)}"