arize-phoenix 8.32.1__py3-none-any.whl → 9.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (79) hide show
  1. {arize_phoenix-8.32.1.dist-info → arize_phoenix-9.0.1.dist-info}/METADATA +5 -5
  2. {arize_phoenix-8.32.1.dist-info → arize_phoenix-9.0.1.dist-info}/RECORD +76 -56
  3. phoenix/db/constants.py +1 -0
  4. phoenix/db/facilitator.py +55 -0
  5. phoenix/db/insertion/document_annotation.py +31 -13
  6. phoenix/db/insertion/evaluation.py +15 -3
  7. phoenix/db/insertion/helpers.py +2 -1
  8. phoenix/db/insertion/span_annotation.py +26 -9
  9. phoenix/db/insertion/trace_annotation.py +25 -9
  10. phoenix/db/insertion/types.py +7 -0
  11. phoenix/db/migrations/versions/2f9d1a65945f_annotation_config_migration.py +322 -0
  12. phoenix/db/migrations/versions/8a3764fe7f1a_change_jsonb_to_json_for_prompts.py +76 -0
  13. phoenix/db/migrations/versions/bb8139330879_create_project_trace_retention_policies_table.py +77 -0
  14. phoenix/db/models.py +151 -10
  15. phoenix/db/types/annotation_configs.py +97 -0
  16. phoenix/db/types/db_models.py +41 -0
  17. phoenix/db/types/trace_retention.py +267 -0
  18. phoenix/experiments/functions.py +5 -1
  19. phoenix/server/api/auth.py +9 -0
  20. phoenix/server/api/context.py +5 -0
  21. phoenix/server/api/dataloaders/__init__.py +4 -0
  22. phoenix/server/api/dataloaders/annotation_summaries.py +203 -24
  23. phoenix/server/api/dataloaders/project_ids_by_trace_retention_policy_id.py +42 -0
  24. phoenix/server/api/dataloaders/trace_retention_policy_id_by_project_id.py +34 -0
  25. phoenix/server/api/helpers/annotations.py +9 -0
  26. phoenix/server/api/helpers/prompts/models.py +34 -67
  27. phoenix/server/api/input_types/CreateSpanAnnotationInput.py +9 -0
  28. phoenix/server/api/input_types/CreateTraceAnnotationInput.py +3 -0
  29. phoenix/server/api/input_types/PatchAnnotationInput.py +3 -0
  30. phoenix/server/api/input_types/SpanAnnotationFilter.py +67 -0
  31. phoenix/server/api/mutations/__init__.py +6 -0
  32. phoenix/server/api/mutations/annotation_config_mutations.py +413 -0
  33. phoenix/server/api/mutations/dataset_mutations.py +62 -39
  34. phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +245 -0
  35. phoenix/server/api/mutations/span_annotations_mutations.py +272 -70
  36. phoenix/server/api/mutations/trace_annotations_mutations.py +203 -74
  37. phoenix/server/api/queries.py +86 -0
  38. phoenix/server/api/routers/v1/__init__.py +4 -0
  39. phoenix/server/api/routers/v1/annotation_configs.py +449 -0
  40. phoenix/server/api/routers/v1/annotations.py +161 -0
  41. phoenix/server/api/routers/v1/evaluations.py +6 -0
  42. phoenix/server/api/routers/v1/projects.py +1 -50
  43. phoenix/server/api/routers/v1/spans.py +35 -8
  44. phoenix/server/api/routers/v1/traces.py +22 -13
  45. phoenix/server/api/routers/v1/utils.py +60 -0
  46. phoenix/server/api/types/Annotation.py +7 -0
  47. phoenix/server/api/types/AnnotationConfig.py +124 -0
  48. phoenix/server/api/types/AnnotationSource.py +9 -0
  49. phoenix/server/api/types/AnnotationSummary.py +28 -14
  50. phoenix/server/api/types/AnnotatorKind.py +1 -0
  51. phoenix/server/api/types/CronExpression.py +15 -0
  52. phoenix/server/api/types/Evaluation.py +4 -30
  53. phoenix/server/api/types/Project.py +50 -2
  54. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +110 -0
  55. phoenix/server/api/types/Span.py +78 -0
  56. phoenix/server/api/types/SpanAnnotation.py +24 -0
  57. phoenix/server/api/types/Trace.py +2 -2
  58. phoenix/server/api/types/TraceAnnotation.py +23 -0
  59. phoenix/server/app.py +20 -0
  60. phoenix/server/retention.py +76 -0
  61. phoenix/server/static/.vite/manifest.json +36 -36
  62. phoenix/server/static/assets/components-B2MWTXnm.js +4326 -0
  63. phoenix/server/static/assets/{index-B0CbpsxD.js → index-Bfvpea_-.js} +10 -10
  64. phoenix/server/static/assets/pages-CZ2vKu8H.js +7268 -0
  65. phoenix/server/static/assets/vendor-BRDkBC5J.js +903 -0
  66. phoenix/server/static/assets/{vendor-arizeai-CxXYQNUl.js → vendor-arizeai-BvTqp_W8.js} +3 -3
  67. phoenix/server/static/assets/{vendor-codemirror-B0NIFPOL.js → vendor-codemirror-COt9UfW7.js} +1 -1
  68. phoenix/server/static/assets/{vendor-recharts-CrrDFWK1.js → vendor-recharts-BoHX9Hvs.js} +2 -2
  69. phoenix/server/static/assets/{vendor-shiki-C5bJ-RPf.js → vendor-shiki-Cw1dsDAz.js} +1 -1
  70. phoenix/trace/dsl/filter.py +25 -5
  71. phoenix/utilities/__init__.py +18 -0
  72. phoenix/version.py +1 -1
  73. phoenix/server/static/assets/components-x-gKFJ8C.js +0 -3414
  74. phoenix/server/static/assets/pages-BU4VdyeH.js +0 -5867
  75. phoenix/server/static/assets/vendor-BfhM_F1u.js +0 -902
  76. {arize_phoenix-8.32.1.dist-info → arize_phoenix-9.0.1.dist-info}/WHEEL +0 -0
  77. {arize_phoenix-8.32.1.dist-info → arize_phoenix-9.0.1.dist-info}/entry_points.txt +0 -0
  78. {arize_phoenix-8.32.1.dist-info → arize_phoenix-9.0.1.dist-info}/licenses/IP_NOTICE +0 -0
  79. {arize_phoenix-8.32.1.dist-info → arize_phoenix-9.0.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,267 @@
1
+ from __future__ import annotations
2
+
3
+ from datetime import datetime, timedelta, timezone
4
+ from typing import Annotated, Iterable, Literal, Optional, Union
5
+
6
+ import sqlalchemy as sa
7
+ from pydantic import AfterValidator, BaseModel, Field, RootModel
8
+ from sqlalchemy.ext.asyncio import AsyncSession
9
+
10
+ from phoenix.utilities import hour_of_week
11
+
12
+
13
class _MaxDays(BaseModel):
    """Mixin providing an age-based (max days) cutoff predicate for traces."""

    max_days: Annotated[float, Field(ge=0)]

    @property
    def max_days_filter(self) -> sa.ColumnElement[bool]:
        """SQL predicate matching traces that started more than ``max_days`` days ago."""
        if self.max_days <= 0:
            # A zero limit means the rule is inactive: match nothing.
            return sa.literal(False)
        # Local import avoids a circular dependency with phoenix.db.models.
        from phoenix.db.models import Trace

        cutoff = datetime.now(timezone.utc) - timedelta(days=self.max_days)
        return Trace.start_time < cutoff
23
+
24
+
25
class _MaxCount(BaseModel):
    """Mixin providing a count-based cutoff predicate for traces."""

    # Number of most-recent traces to keep; 0 disables the rule.
    max_count: Annotated[int, Field(ge=0)]

    @property
    def max_count_filter(self) -> sa.ColumnElement[bool]:
        """SQL predicate matching traces older than the ``max_count``-th newest trace."""
        if self.max_count <= 0:
            # Inactive rule: a predicate that matches nothing.
            return sa.literal(False)
        # Local import avoids a circular dependency with phoenix.db.models.
        from phoenix.db.models import Trace

        # The cutoff is the start_time of the max_count-th newest trace.
        # NOTE(review): this subquery is not scoped to any project, so the
        # cutoff is computed across ALL traces globally — confirm this is
        # intended when the predicate is applied to a subset of projects.
        return Trace.start_time < (
            sa.select(Trace.start_time)
            .order_by(Trace.start_time.desc())
            .offset(self.max_count - 1)
            .limit(1)
            .scalar_subquery()
        )
41
+
42
+
43
class MaxDaysRule(_MaxDays, BaseModel):
    """Retention rule: delete traces older than ``max_days`` days."""

    type: Literal["max_days"] = "max_days"

    def __bool__(self) -> bool:
        # Active only when the day limit is strictly positive.
        return self.max_days > 0

    async def delete_traces(
        self,
        session: AsyncSession,
        project_rowids: Union[Iterable[int], sa.ScalarSelect[int]],
    ) -> set[int]:
        """Delete traces past the age cutoff; return the affected project rowids."""
        if not self:
            return set()
        from phoenix.db.models import Trace

        stmt = (
            sa.delete(Trace)
            .where(Trace.project_rowid.in_(project_rowids))
            .where(self.max_days_filter)
            .returning(Trace.project_rowid)
        )
        return {rowid for rowid in await session.scalars(stmt)}
65
+
66
+
67
class MaxCountRule(_MaxCount, BaseModel):
    """Retention rule: keep only the newest ``max_count`` traces."""

    type: Literal["max_count"] = "max_count"

    def __bool__(self) -> bool:
        # Active only when the count limit is strictly positive.
        return self.max_count > 0

    async def delete_traces(
        self,
        session: AsyncSession,
        project_rowids: Union[Iterable[int], sa.ScalarSelect[int]],
    ) -> set[int]:
        """Delete traces past the count cutoff; return the affected project rowids."""
        if not self:
            return set()
        from phoenix.db.models import Trace

        stmt = (
            sa.delete(Trace)
            .where(Trace.project_rowid.in_(project_rowids))
            .where(self.max_count_filter)
            .returning(Trace.project_rowid)
        )
        return {rowid for rowid in await session.scalars(stmt)}
89
+
90
+
91
class MaxDaysOrCountRule(_MaxDays, _MaxCount, BaseModel):
    """Retention rule: delete traces that exceed either the age or count limit."""

    type: Literal["max_days_or_count"] = "max_days_or_count"

    def __bool__(self) -> bool:
        # Active when at least one of the two limits is strictly positive.
        return self.max_days > 0 or self.max_count > 0

    async def delete_traces(
        self,
        session: AsyncSession,
        project_rowids: Union[Iterable[int], sa.ScalarSelect[int]],
    ) -> set[int]:
        """Delete traces matching either cutoff; return the affected project rowids."""
        if not self:
            return set()
        from phoenix.db.models import Trace

        stmt = (
            sa.delete(Trace)
            .where(Trace.project_rowid.in_(project_rowids))
            .where(sa.or_(self.max_days_filter, self.max_count_filter))
            .returning(Trace.project_rowid)
        )
        return {rowid for rowid in await session.scalars(stmt)}
113
+
114
+
115
class TraceRetentionRule(RootModel[Union[MaxDaysRule, MaxCountRule, MaxDaysOrCountRule]]):
    """Polymorphic wrapper around the concrete retention rules.

    Pydantic selects the concrete rule class via the ``type`` discriminator
    field when deserializing stored rule data.
    """

    root: Annotated[
        Union[MaxDaysRule, MaxCountRule, MaxDaysOrCountRule], Field(discriminator="type")
    ]

    def __bool__(self) -> bool:
        # Truthiness mirrors the wrapped rule's (active iff limits are set).
        return bool(self.root)

    async def delete_traces(
        self,
        session: AsyncSession,
        project_rowids: Union[Iterable[int], sa.ScalarSelect[int]],
    ) -> set[int]:
        # Delegate deletion to the concrete rule; returns affected project rowids.
        return await self.root.delete_traces(session, project_rowids)
129
+
130
+
131
def _time_of_next_run(
    cron_expression: str,
    after: Optional[datetime] = None,
) -> datetime:
    """
    Parse a cron expression and calculate the UTC datetime of the next run.
    Only processes hour, and day of week fields; day-of-month and
    month fields must be '*'; minute field must be 0.

    Args:
        cron_expression (str): Standard cron expression with 5 fields:
            minute hour day-of-month month day-of-week
            (minute must be '0'; day-of-month and month must be '*')
        after: Optional[datetime]: The datetime to start searching from. If None,
            the current time is used. Aware datetimes are converted to UTC;
            naive datetimes are assumed to already be in UTC.

    Returns:
        datetime: The datetime of the next run. Timezone is UTC.

    Raises:
        ValueError: If the expression has non-wildcard values for day-of-month or month, if the
            minute field is not '0', or if no match is found within the next 7 days (168 hours).
    """
    fields: list[str] = cron_expression.strip().split()
    if len(fields) != 5:
        raise ValueError(
            "Invalid cron expression. Expected 5 fields "
            "(minute hour day-of-month month day-of-week)."
        )
    if fields[0] != "0":
        raise ValueError("Invalid cron expression. Minute field must be '0'.")
    if fields[2] != "*" or fields[3] != "*":
        raise ValueError("Invalid cron expression. Day-of-month and month fields must be '*'.")
    hours: set[int] = _parse_field(fields[1], 0, 23)
    # Parse days of week (0-6, where 0 is Sunday)
    days_of_week: set[int] = _parse_field(fields[4], 0, 6)
    # Convert to Python's weekday format (0-6, where 0 is Monday)
    # Sunday (0 in cron) becomes 6 in Python's weekday()
    python_days_of_week = {(day_of_week + 6) % 7 for day_of_week in days_of_week}
    if after is None:
        t = datetime.now(timezone.utc)
    elif after.tzinfo is None:
        # Naive input: interpret the wall-clock time as UTC (original behavior).
        t = after.replace(tzinfo=timezone.utc)
    else:
        # BUG FIX: `after.replace(tzinfo=timezone.utc)` would reinterpret an
        # aware non-UTC datetime's wall-clock time as UTC; convert it instead.
        t = after.astimezone(timezone.utc)
    t = t.replace(minute=0, second=0, microsecond=0)
    for _ in range(168):  # Check up to 7 days (168 hours)
        t += timedelta(hours=1)
        if t.hour in hours and t.weekday() in python_days_of_week:
            return t
    raise ValueError("No matching execution time found within the next 7 days.")
177
+
178
+
179
def _validated_cron_expression(value: str) -> str:
    """Return *value* unchanged after checking it parses as a schedulable cron."""
    # _time_of_next_run raises ValueError for malformed or unsupported fields.
    _time_of_next_run(value)
    return value


class TraceRetentionCronExpression(RootModel[str]):
    """A cron expression (hour / day-of-week granularity) stored as a string."""

    root: Annotated[str, AfterValidator(_validated_cron_expression)]

    def get_hour_of_prev_run(self) -> int:
        """
        Calculate the hour of the previous run before now.

        Returns:
            int: The hour of the previous run (0-167), where 0 is midnight Sunday UTC.
        """
        # Stepping back one hour and asking for the "next" run yields the most
        # recent run at or before the current hour.
        one_hour_ago = datetime.now(timezone.utc) - timedelta(hours=1)
        return hour_of_week(_time_of_next_run(self.root, one_hour_ago))
191
+
192
+
193
+ def _parse_field(field: str, min_val: int, max_val: int) -> set[int]:
194
+ """
195
+ Parse a cron field and return the set of matching values.
196
+
197
+ Args:
198
+ field (str): The cron field to parse
199
+ min_val (int): Minimum allowed value for this field
200
+ max_val (int): Maximum allowed value for this field
201
+
202
+ Returns:
203
+ set[int]: Set of all valid values represented by the field expression
204
+
205
+ Raises:
206
+ ValueError: If the field contains invalid values or formats
207
+ """
208
+ if field == "*":
209
+ return set(range(min_val, max_val + 1))
210
+ values: set[int] = set()
211
+ for part in field.split(","):
212
+ if "/" in part:
213
+ # Handle steps
214
+ range_part, step_str = part.split("/")
215
+ try:
216
+ step = int(step_str)
217
+ except ValueError:
218
+ raise ValueError(f"Invalid step value: {step_str}")
219
+ if step <= 0:
220
+ raise ValueError(f"Step value must be positive: {step}")
221
+ if range_part == "*":
222
+ start, end = min_val, max_val
223
+ elif "-" in range_part:
224
+ try:
225
+ start_str, end_str = range_part.split("-")
226
+ start, end = int(start_str), int(end_str)
227
+ except ValueError:
228
+ raise ValueError(f"Invalid range format: {range_part}")
229
+ if start < min_val or end > max_val:
230
+ raise ValueError(
231
+ f"Range {start}-{end} outside allowed values ({min_val}-{max_val})"
232
+ )
233
+ if start > end:
234
+ raise ValueError(f"Invalid range: {start}-{end} (start > end)")
235
+ else:
236
+ try:
237
+ start = int(range_part)
238
+ except ValueError:
239
+ raise ValueError(f"Invalid value: {range_part}")
240
+ if start < min_val or start > max_val:
241
+ raise ValueError(f"Value {start} out of range ({min_val}-{max_val})")
242
+ end = max_val
243
+ values.update(range(start, end + 1, step))
244
+ elif "-" in part:
245
+ # Handle ranges
246
+ try:
247
+ start_str, end_str = part.split("-")
248
+ start, end = int(start_str), int(end_str)
249
+ except ValueError:
250
+ raise ValueError(f"Invalid range format: {part}")
251
+ if start < min_val or end > max_val:
252
+ raise ValueError(
253
+ f"Range {start}-{end} outside allowed values ({min_val}-{max_val})"
254
+ )
255
+ if start > end:
256
+ raise ValueError(f"Invalid range: {start}-{end} (start > end)")
257
+ values.update(range(start, end + 1))
258
+ else:
259
+ # Handle single values
260
+ try:
261
+ value = int(part)
262
+ except ValueError:
263
+ raise ValueError(f"Invalid value: {part}")
264
+ if value < min_val or value > max_val:
265
+ raise ValueError(f"Value {value} out of range ({min_val}-{max_val})")
266
+ values.add(value)
267
+ return values
@@ -95,6 +95,7 @@ def run_experiment(
95
95
  dry_run: Union[bool, int] = False,
96
96
  print_summary: bool = True,
97
97
  concurrency: int = 3,
98
+ timeout: Optional[int] = None,
98
99
  ) -> RanExperiment:
99
100
  """
100
101
  Runs an experiment using a given set of dataset of examples.
@@ -148,6 +149,8 @@ def run_experiment(
148
149
  concurrency (int): Specifies the concurrency for task execution. In order to enable
149
150
  concurrent task execution, the task callable must be a coroutine function.
150
151
  Defaults to 3.
152
+ timeout (Optional[int]): The timeout for the task execution in seconds. Use this to run
153
+ longer tasks to avoid re-queuing the same task multiple times. Defaults to None.
151
154
 
152
155
  Returns:
153
156
  RanExperiment: The results of the experiment and evaluation. Additional evaluations can be
@@ -380,6 +383,7 @@ def run_experiment(
380
383
  fallback_return_value=None,
381
384
  tqdm_bar_format=get_tqdm_progress_bar_formatter("running tasks"),
382
385
  concurrency=concurrency,
386
+ timeout=timeout,
383
387
  )
384
388
 
385
389
  test_cases = [
@@ -752,7 +756,7 @@ def _print_experiment_error(
752
756
  Prints an experiment error.
753
757
  """
754
758
  display_error = RuntimeError(
755
- f"{kind} failed for example id {repr(example_id)}, " f"repetition {repr(repetition_number)}"
759
+ f"{kind} failed for example id {repr(example_id)}, repetition {repr(repetition_number)}"
756
760
  )
757
761
  display_error.__cause__ = error
758
762
  formatted_exception = "".join(
@@ -42,3 +42,12 @@ class IsAdmin(Authorization):
42
42
  if not info.context.auth_enabled:
43
43
  return False
44
44
  return isinstance((user := info.context.user), PhoenixUser) and user.is_admin
45
+
46
+
47
class IsAdminIfAuthEnabled(Authorization):
    """Permit everyone when auth is disabled; otherwise require an admin user."""

    message = MSG_ADMIN_ONLY

    def has_permission(self, source: Any, info: Info, **kwargs: Any) -> bool:
        # With auth turned off there is no user identity to check.
        if not info.context.auth_enabled:
            return True
        user = info.context.user
        return isinstance(user, PhoenixUser) and user.is_admin
@@ -32,6 +32,7 @@ from phoenix.server.api.dataloaders import (
32
32
  NumChildSpansDataLoader,
33
33
  NumSpansPerTraceDataLoader,
34
34
  ProjectByNameDataLoader,
35
+ ProjectIdsByTraceRetentionPolicyIdDataLoader,
35
36
  PromptVersionSequenceNumberDataLoader,
36
37
  RecordCountDataLoader,
37
38
  SessionIODataLoader,
@@ -47,6 +48,7 @@ from phoenix.server.api.dataloaders import (
47
48
  TableFieldsDataLoader,
48
49
  TokenCountDataLoader,
49
50
  TraceByTraceIdsDataLoader,
51
+ TraceRetentionPolicyIdByProjectIdDataLoader,
50
52
  TraceRootSpansDataLoader,
51
53
  UserRolesDataLoader,
52
54
  UsersDataLoader,
@@ -82,6 +84,7 @@ class DataLoaders:
82
84
  num_child_spans: NumChildSpansDataLoader
83
85
  num_spans_per_trace: NumSpansPerTraceDataLoader
84
86
  project_fields: TableFieldsDataLoader
87
+ projects_by_trace_retention_policy_id: ProjectIdsByTraceRetentionPolicyIdDataLoader
85
88
  prompt_version_sequence_number: PromptVersionSequenceNumberDataLoader
86
89
  record_counts: RecordCountDataLoader
87
90
  session_first_inputs: SessionIODataLoader
@@ -99,6 +102,8 @@ class DataLoaders:
99
102
  token_counts: TokenCountDataLoader
100
103
  trace_by_trace_ids: TraceByTraceIdsDataLoader
101
104
  trace_fields: TableFieldsDataLoader
105
+ trace_retention_policy_id_by_project_id: TraceRetentionPolicyIdByProjectIdDataLoader
106
+ project_trace_retention_policy_fields: TableFieldsDataLoader
102
107
  trace_root_spans: TraceRootSpansDataLoader
103
108
  project_by_name: ProjectByNameDataLoader
104
109
  users: UsersDataLoader
@@ -20,6 +20,7 @@ from .min_start_or_max_end_times import MinStartOrMaxEndTimeCache, MinStartOrMax
20
20
  from .num_child_spans import NumChildSpansDataLoader
21
21
  from .num_spans_per_trace import NumSpansPerTraceDataLoader
22
22
  from .project_by_name import ProjectByNameDataLoader
23
+ from .project_ids_by_trace_retention_policy_id import ProjectIdsByTraceRetentionPolicyIdDataLoader
23
24
  from .prompt_version_sequence_number import PromptVersionSequenceNumberDataLoader
24
25
  from .record_counts import RecordCountCache, RecordCountDataLoader
25
26
  from .session_io import SessionIODataLoader
@@ -35,6 +36,7 @@ from .span_projects import SpanProjectsDataLoader
35
36
  from .table_fields import TableFieldsDataLoader
36
37
  from .token_counts import TokenCountCache, TokenCountDataLoader
37
38
  from .trace_by_trace_ids import TraceByTraceIdsDataLoader
39
+ from .trace_retention_policy_id_by_project_id import TraceRetentionPolicyIdByProjectIdDataLoader
38
40
  from .trace_root_spans import TraceRootSpansDataLoader
39
41
  from .user_roles import UserRolesDataLoader
40
42
  from .users import UsersDataLoader
@@ -57,6 +59,7 @@ __all__ = [
57
59
  "MinStartOrMaxEndTimeDataLoader",
58
60
  "NumChildSpansDataLoader",
59
61
  "NumSpansPerTraceDataLoader",
62
+ "ProjectIdsByTraceRetentionPolicyIdDataLoader",
60
63
  "PromptVersionSequenceNumberDataLoader",
61
64
  "RecordCountDataLoader",
62
65
  "SessionIODataLoader",
@@ -71,6 +74,7 @@ __all__ = [
71
74
  "TableFieldsDataLoader",
72
75
  "TokenCountDataLoader",
73
76
  "TraceByTraceIdsDataLoader",
77
+ "TraceRetentionPolicyIdByProjectIdDataLoader",
74
78
  "TraceRootSpansDataLoader",
75
79
  "ProjectByNameDataLoader",
76
80
  "SpanAnnotationsDataLoader",
@@ -1,11 +1,11 @@
1
1
  from collections import defaultdict
2
2
  from datetime import datetime
3
- from typing import Any, Literal, Optional
3
+ from typing import Any, Literal, Optional, Type, Union, cast
4
4
 
5
5
  import pandas as pd
6
6
  from aioitertools.itertools import groupby
7
7
  from cachetools import LFUCache, TTLCache
8
- from sqlalchemy import Select, func, or_, select
8
+ from sqlalchemy import Select, and_, case, distinct, func, or_, select
9
9
  from strawberry.dataloader import AbstractCache, DataLoader
10
10
  from typing_extensions import TypeAlias, assert_never
11
11
 
@@ -92,7 +92,7 @@ class AnnotationSummaryDataLoader(DataLoader[Key, Result]):
92
92
  async with self._db() as session:
93
93
  data = await session.stream(stmt)
94
94
  async for annotation_name, group in groupby(data, lambda row: row.name):
95
- summary = AnnotationSummary(pd.DataFrame(group))
95
+ summary = AnnotationSummary(name=annotation_name, df=pd.DataFrame(group))
96
96
  for position in params[annotation_name]:
97
97
  results[position] = summary
98
98
  return results
@@ -103,23 +103,64 @@ def _get_stmt(
103
103
  *annotation_names: Param,
104
104
  ) -> Select[Any]:
105
105
  kind, project_rowid, (start_time, end_time), filter_condition = segment
106
- stmt = select()
106
+
107
+ annotation_model: Union[Type[models.SpanAnnotation], Type[models.TraceAnnotation]]
108
+ entity_model: Union[Type[models.Span], Type[models.Trace]]
109
+ entity_join_model: Optional[Type[models.Base]]
110
+ entity_id_column: Any
111
+
107
112
  if kind == "span":
108
- msa = models.SpanAnnotation
109
- name_column, label_column, score_column = msa.name, msa.label, msa.score
110
- time_column = models.Span.start_time
111
- stmt = stmt.join(models.Span).join_from(models.Span, models.Trace)
112
- if filter_condition:
113
- sf = SpanFilter(filter_condition)
114
- stmt = sf(stmt)
113
+ annotation_model = models.SpanAnnotation
114
+ entity_model = models.Span
115
+ entity_join_model = models.Trace
116
+ entity_id_column = models.Span.id.label("entity_id")
115
117
  elif kind == "trace":
116
- mta = models.TraceAnnotation
117
- name_column, label_column, score_column = mta.name, mta.label, mta.score
118
- time_column = models.Trace.start_time
119
- stmt = stmt.join(models.Trace)
118
+ annotation_model = models.TraceAnnotation
119
+ entity_model = models.Trace
120
+ entity_join_model = None
121
+ entity_id_column = models.Trace.id.label("entity_id")
120
122
  else:
121
123
  assert_never(kind)
122
- stmt = stmt.add_columns(
124
+
125
+ name_column = annotation_model.name
126
+ label_column = annotation_model.label
127
+ score_column = annotation_model.score
128
+ time_column = entity_model.start_time
129
+
130
+ # First query: count distinct entities per annotation name
131
+ # This is used later to calculate accurate fractions that account for entities without labels
132
+ entity_count_query = select(
133
+ name_column, func.count(distinct(entity_id_column)).label("entity_count")
134
+ )
135
+
136
+ if kind == "span":
137
+ entity_count_query = entity_count_query.join(cast(Type[models.Span], entity_model))
138
+ entity_count_query = entity_count_query.join_from(
139
+ cast(Type[models.Span], entity_model), cast(Type[models.Trace], entity_join_model)
140
+ )
141
+ entity_count_query = entity_count_query.where(models.Trace.project_rowid == project_rowid)
142
+ elif kind == "trace":
143
+ entity_count_query = entity_count_query.join(cast(Type[models.Trace], entity_model))
144
+ entity_count_query = entity_count_query.where(
145
+ cast(Type[models.Trace], entity_model).project_rowid == project_rowid
146
+ )
147
+
148
+ entity_count_query = entity_count_query.where(
149
+ or_(score_column.is_not(None), label_column.is_not(None))
150
+ )
151
+ entity_count_query = entity_count_query.where(name_column.in_(annotation_names))
152
+
153
+ if start_time:
154
+ entity_count_query = entity_count_query.where(start_time <= time_column)
155
+ if end_time:
156
+ entity_count_query = entity_count_query.where(time_column < end_time)
157
+
158
+ entity_count_query = entity_count_query.group_by(name_column)
159
+ entity_count_subquery = entity_count_query.subquery()
160
+
161
+ # Main query: gets raw annotation data with counts per (span/trace)+name+label
162
+ base_stmt = select(
163
+ entity_id_column,
123
164
  name_column,
124
165
  label_column,
125
166
  func.count().label("record_count"),
@@ -127,13 +168,151 @@ def _get_stmt(
127
168
  func.count(score_column).label("score_count"),
128
169
  func.sum(score_column).label("score_sum"),
129
170
  )
130
- stmt = stmt.group_by(name_column, label_column)
131
- stmt = stmt.order_by(name_column, label_column)
132
- stmt = stmt.where(models.Trace.project_rowid == project_rowid)
133
- stmt = stmt.where(or_(score_column.is_not(None), label_column.is_not(None)))
134
- stmt = stmt.where(name_column.in_(annotation_names))
171
+
172
+ if kind == "span":
173
+ base_stmt = base_stmt.join(cast(Type[models.Span], entity_model))
174
+ base_stmt = base_stmt.join_from(
175
+ cast(Type[models.Span], entity_model), cast(Type[models.Trace], entity_join_model)
176
+ )
177
+ base_stmt = base_stmt.where(models.Trace.project_rowid == project_rowid)
178
+ if filter_condition:
179
+ sf = SpanFilter(filter_condition)
180
+ base_stmt = sf(base_stmt)
181
+ elif kind == "trace":
182
+ base_stmt = base_stmt.join(cast(Type[models.Trace], entity_model))
183
+ base_stmt = base_stmt.where(
184
+ cast(Type[models.Trace], entity_model).project_rowid == project_rowid
185
+ )
186
+ else:
187
+ assert_never(kind)
188
+
189
+ base_stmt = base_stmt.where(or_(score_column.is_not(None), label_column.is_not(None)))
190
+ base_stmt = base_stmt.where(name_column.in_(annotation_names))
191
+
135
192
  if start_time:
136
- stmt = stmt.where(start_time <= time_column)
193
+ base_stmt = base_stmt.where(start_time <= time_column)
137
194
  if end_time:
138
- stmt = stmt.where(time_column < end_time)
139
- return stmt
195
+ base_stmt = base_stmt.where(time_column < end_time)
196
+
197
+ # Group to get one row per (span/trace)+name+label combination
198
+ base_stmt = base_stmt.group_by(entity_id_column, name_column, label_column)
199
+
200
+ base_subquery = base_stmt.subquery()
201
+
202
+ # Calculate total counts per (span/trace)+name for computing fractions
203
+ entity_totals = (
204
+ select(
205
+ base_subquery.c.entity_id,
206
+ base_subquery.c.name,
207
+ func.sum(base_subquery.c.label_count).label("total_label_count"),
208
+ func.sum(base_subquery.c.score_count).label("total_score_count"),
209
+ func.sum(base_subquery.c.score_sum).label("entity_score_sum"),
210
+ )
211
+ .group_by(base_subquery.c.entity_id, base_subquery.c.name)
212
+ .subquery()
213
+ )
214
+
215
+ per_entity_fractions = (
216
+ select(
217
+ base_subquery.c.entity_id,
218
+ base_subquery.c.name,
219
+ base_subquery.c.label,
220
+ base_subquery.c.record_count,
221
+ base_subquery.c.label_count,
222
+ base_subquery.c.score_count,
223
+ base_subquery.c.score_sum,
224
+ # Calculate label fraction, avoiding division by zero when total_label_count is 0
225
+ case(
226
+ (
227
+ entity_totals.c.total_label_count > 0,
228
+ base_subquery.c.label_count * 1.0 / entity_totals.c.total_label_count,
229
+ ),
230
+ else_=None,
231
+ ).label("label_fraction"),
232
+ # Calculate average score for the entity (if there are any scores)
233
+ case(
234
+ (
235
+ entity_totals.c.total_score_count > 0,
236
+ entity_totals.c.entity_score_sum * 1.0 / entity_totals.c.total_score_count,
237
+ ),
238
+ else_=None,
239
+ ).label("entity_avg_score"),
240
+ )
241
+ .join(
242
+ entity_totals,
243
+ and_(
244
+ base_subquery.c.entity_id == entity_totals.c.entity_id,
245
+ base_subquery.c.name == entity_totals.c.name,
246
+ ),
247
+ )
248
+ .subquery()
249
+ )
250
+
251
+ # Aggregate metrics across (spans/traces) for each name+label combination.
252
+ label_entity_metrics = (
253
+ select(
254
+ per_entity_fractions.c.name,
255
+ per_entity_fractions.c.label,
256
+ func.count(distinct(per_entity_fractions.c.entity_id)).label("entities_with_label"),
257
+ func.sum(per_entity_fractions.c.label_count).label("total_label_count"),
258
+ func.sum(per_entity_fractions.c.score_count).label("total_score_count"),
259
+ func.sum(per_entity_fractions.c.score_sum).label("total_score_sum"),
260
+ # Average of label fractions for entities that have this label
261
+ func.avg(per_entity_fractions.c.label_fraction).label("avg_label_fraction_present"),
262
+ # Average of per-entity average scores (but we handle overall aggregation separately)
263
+ )
264
+ .group_by(per_entity_fractions.c.name, per_entity_fractions.c.label)
265
+ .subquery()
266
+ )
267
+
268
+ # Compute distinct per-entity average scores to ensure each entity counts only once.
269
+ distinct_entity_scores = (
270
+ select(
271
+ per_entity_fractions.c.entity_id,
272
+ per_entity_fractions.c.name,
273
+ per_entity_fractions.c.entity_avg_score,
274
+ )
275
+ .distinct()
276
+ .subquery()
277
+ )
278
+
279
+ overall_score_aggregates = (
280
+ select(
281
+ distinct_entity_scores.c.name,
282
+ func.avg(distinct_entity_scores.c.entity_avg_score).label("overall_avg_score"),
283
+ )
284
+ .group_by(distinct_entity_scores.c.name)
285
+ .subquery()
286
+ )
287
+
288
+ # Final result: adjust label fractions by the proportion of entities reporting this label
289
+ # and include the overall average score per annotation name.
290
+ final_stmt = (
291
+ select(
292
+ label_entity_metrics.c.name,
293
+ label_entity_metrics.c.label,
294
+ # Adjust label fraction, guarding against division by zero in entity_count
295
+ case(
296
+ (
297
+ entity_count_subquery.c.entity_count > 0,
298
+ label_entity_metrics.c.avg_label_fraction_present
299
+ * label_entity_metrics.c.entities_with_label
300
+ / entity_count_subquery.c.entity_count,
301
+ ),
302
+ else_=None,
303
+ ).label("avg_label_fraction"),
304
+ overall_score_aggregates.c.overall_avg_score.label("avg_score"), # same for all labels
305
+ label_entity_metrics.c.total_label_count.label("label_count"),
306
+ label_entity_metrics.c.total_score_count.label("score_count"),
307
+ label_entity_metrics.c.total_score_sum.label("score_sum"),
308
+ label_entity_metrics.c.entities_with_label.label("record_count"),
309
+ )
310
+ .join(entity_count_subquery, label_entity_metrics.c.name == entity_count_subquery.c.name)
311
+ .join(
312
+ overall_score_aggregates,
313
+ label_entity_metrics.c.name == overall_score_aggregates.c.name,
314
+ )
315
+ .order_by(label_entity_metrics.c.name, label_entity_metrics.c.label)
316
+ )
317
+
318
+ return final_stmt
@@ -0,0 +1,42 @@
1
+ from collections import defaultdict
2
+
3
+ from sqlalchemy import or_, select
4
+ from strawberry.dataloader import DataLoader
5
+ from typing_extensions import TypeAlias
6
+
7
+ from phoenix.db.constants import DEFAULT_PROJECT_TRACE_RETENTION_POLICY_ID
8
+ from phoenix.db.models import Project
9
+ from phoenix.server.types import DbSessionFactory
10
+
11
+ PolicyRowId: TypeAlias = int
12
+ ProjectRowId: TypeAlias = int
13
+
14
+ Key: TypeAlias = PolicyRowId
15
+ Result: TypeAlias = list[ProjectRowId]
16
+
17
+
18
class ProjectIdsByTraceRetentionPolicyIdDataLoader(DataLoader[Key, Result]):
    """Batch-load the project rowids attached to each trace retention policy."""

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        wanted_policy_ids = set(keys)
        stmt = select(Project.trace_retention_policy_id, Project.id)
        # Projects with a NULL policy implicitly use the default policy, so
        # include them whenever the default policy id is among the keys.
        if DEFAULT_PROJECT_TRACE_RETENTION_POLICY_ID in wanted_policy_ids:
            stmt = stmt.where(
                or_(
                    Project.trace_retention_policy_id.in_(wanted_policy_ids),
                    Project.trace_retention_policy_id.is_(None),
                )
            )
        else:
            stmt = stmt.where(Project.trace_retention_policy_id.in_(wanted_policy_ids))
        grouped: defaultdict[Key, Result] = defaultdict(list)
        async with self._db() as session:
            rows = await session.stream(stmt)
            async for policy_rowid, project_rowid in rows:
                # NULL policy rows are folded into the default policy bucket.
                bucket = policy_rowid or DEFAULT_PROJECT_TRACE_RETENTION_POLICY_ID
                grouped[bucket].append(project_rowid)
        # Copy each list so callers cannot mutate shared loader state.
        return [grouped.get(policy_id, []).copy() for policy_id in keys]