arize-phoenix 8.32.1__py3-none-any.whl → 9.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of arize-phoenix has been flagged as possibly problematic. See the release details page for more information.
- {arize_phoenix-8.32.1.dist-info → arize_phoenix-9.0.1.dist-info}/METADATA +5 -5
- {arize_phoenix-8.32.1.dist-info → arize_phoenix-9.0.1.dist-info}/RECORD +76 -56
- phoenix/db/constants.py +1 -0
- phoenix/db/facilitator.py +55 -0
- phoenix/db/insertion/document_annotation.py +31 -13
- phoenix/db/insertion/evaluation.py +15 -3
- phoenix/db/insertion/helpers.py +2 -1
- phoenix/db/insertion/span_annotation.py +26 -9
- phoenix/db/insertion/trace_annotation.py +25 -9
- phoenix/db/insertion/types.py +7 -0
- phoenix/db/migrations/versions/2f9d1a65945f_annotation_config_migration.py +322 -0
- phoenix/db/migrations/versions/8a3764fe7f1a_change_jsonb_to_json_for_prompts.py +76 -0
- phoenix/db/migrations/versions/bb8139330879_create_project_trace_retention_policies_table.py +77 -0
- phoenix/db/models.py +151 -10
- phoenix/db/types/annotation_configs.py +97 -0
- phoenix/db/types/db_models.py +41 -0
- phoenix/db/types/trace_retention.py +267 -0
- phoenix/experiments/functions.py +5 -1
- phoenix/server/api/auth.py +9 -0
- phoenix/server/api/context.py +5 -0
- phoenix/server/api/dataloaders/__init__.py +4 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +203 -24
- phoenix/server/api/dataloaders/project_ids_by_trace_retention_policy_id.py +42 -0
- phoenix/server/api/dataloaders/trace_retention_policy_id_by_project_id.py +34 -0
- phoenix/server/api/helpers/annotations.py +9 -0
- phoenix/server/api/helpers/prompts/models.py +34 -67
- phoenix/server/api/input_types/CreateSpanAnnotationInput.py +9 -0
- phoenix/server/api/input_types/CreateTraceAnnotationInput.py +3 -0
- phoenix/server/api/input_types/PatchAnnotationInput.py +3 -0
- phoenix/server/api/input_types/SpanAnnotationFilter.py +67 -0
- phoenix/server/api/mutations/__init__.py +6 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +413 -0
- phoenix/server/api/mutations/dataset_mutations.py +62 -39
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +245 -0
- phoenix/server/api/mutations/span_annotations_mutations.py +272 -70
- phoenix/server/api/mutations/trace_annotations_mutations.py +203 -74
- phoenix/server/api/queries.py +86 -0
- phoenix/server/api/routers/v1/__init__.py +4 -0
- phoenix/server/api/routers/v1/annotation_configs.py +449 -0
- phoenix/server/api/routers/v1/annotations.py +161 -0
- phoenix/server/api/routers/v1/evaluations.py +6 -0
- phoenix/server/api/routers/v1/projects.py +1 -50
- phoenix/server/api/routers/v1/spans.py +35 -8
- phoenix/server/api/routers/v1/traces.py +22 -13
- phoenix/server/api/routers/v1/utils.py +60 -0
- phoenix/server/api/types/Annotation.py +7 -0
- phoenix/server/api/types/AnnotationConfig.py +124 -0
- phoenix/server/api/types/AnnotationSource.py +9 -0
- phoenix/server/api/types/AnnotationSummary.py +28 -14
- phoenix/server/api/types/AnnotatorKind.py +1 -0
- phoenix/server/api/types/CronExpression.py +15 -0
- phoenix/server/api/types/Evaluation.py +4 -30
- phoenix/server/api/types/Project.py +50 -2
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +110 -0
- phoenix/server/api/types/Span.py +78 -0
- phoenix/server/api/types/SpanAnnotation.py +24 -0
- phoenix/server/api/types/Trace.py +2 -2
- phoenix/server/api/types/TraceAnnotation.py +23 -0
- phoenix/server/app.py +20 -0
- phoenix/server/retention.py +76 -0
- phoenix/server/static/.vite/manifest.json +36 -36
- phoenix/server/static/assets/components-B2MWTXnm.js +4326 -0
- phoenix/server/static/assets/{index-B0CbpsxD.js → index-Bfvpea_-.js} +10 -10
- phoenix/server/static/assets/pages-CZ2vKu8H.js +7268 -0
- phoenix/server/static/assets/vendor-BRDkBC5J.js +903 -0
- phoenix/server/static/assets/{vendor-arizeai-CxXYQNUl.js → vendor-arizeai-BvTqp_W8.js} +3 -3
- phoenix/server/static/assets/{vendor-codemirror-B0NIFPOL.js → vendor-codemirror-COt9UfW7.js} +1 -1
- phoenix/server/static/assets/{vendor-recharts-CrrDFWK1.js → vendor-recharts-BoHX9Hvs.js} +2 -2
- phoenix/server/static/assets/{vendor-shiki-C5bJ-RPf.js → vendor-shiki-Cw1dsDAz.js} +1 -1
- phoenix/trace/dsl/filter.py +25 -5
- phoenix/utilities/__init__.py +18 -0
- phoenix/version.py +1 -1
- phoenix/server/static/assets/components-x-gKFJ8C.js +0 -3414
- phoenix/server/static/assets/pages-BU4VdyeH.js +0 -5867
- phoenix/server/static/assets/vendor-BfhM_F1u.js +0 -902
- {arize_phoenix-8.32.1.dist-info → arize_phoenix-9.0.1.dist-info}/WHEEL +0 -0
- {arize_phoenix-8.32.1.dist-info → arize_phoenix-9.0.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-8.32.1.dist-info → arize_phoenix-9.0.1.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-8.32.1.dist-info → arize_phoenix-9.0.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,267 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from datetime import datetime, timedelta, timezone
|
|
4
|
+
from typing import Annotated, Iterable, Literal, Optional, Union
|
|
5
|
+
|
|
6
|
+
import sqlalchemy as sa
|
|
7
|
+
from pydantic import AfterValidator, BaseModel, Field, RootModel
|
|
8
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
9
|
+
|
|
10
|
+
from phoenix.utilities import hour_of_week
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class _MaxDays(BaseModel):
    """Mixin holding a non-negative ``max_days`` horizon plus its SQL filter."""

    max_days: Annotated[float, Field(ge=0)]

    @property
    def max_days_filter(self) -> sa.ColumnElement[bool]:
        """SQL predicate matching traces older than ``max_days``; FALSE when disabled."""
        # A zero horizon disables the rule entirely.
        if self.max_days <= 0:
            return sa.literal(False)
        # Imported lazily to avoid a circular dependency with phoenix.db.models.
        from phoenix.db.models import Trace

        cutoff = datetime.now(timezone.utc) - timedelta(days=self.max_days)
        return Trace.start_time < cutoff
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class _MaxCount(BaseModel):
    """Mixin holding a non-negative ``max_count`` cap plus its SQL filter."""

    max_count: Annotated[int, Field(ge=0)]

    @property
    def max_count_filter(self) -> sa.ColumnElement[bool]:
        """SQL predicate matching traces beyond the ``max_count`` newest; FALSE when disabled."""
        # A cap of zero disables the rule entirely.
        if self.max_count <= 0:
            return sa.literal(False)
        # Imported lazily to avoid a circular dependency with phoenix.db.models.
        from phoenix.db.models import Trace

        # Start time of the max_count-th newest trace; anything strictly older matches.
        # NOTE(review): this subquery is not scoped to a project, so the cutoff is
        # computed over all traces — confirm callers always combine it with a
        # project filter as delete_traces does.
        nth_newest_start = (
            sa.select(Trace.start_time)
            .order_by(Trace.start_time.desc())
            .offset(self.max_count - 1)
            .limit(1)
            .scalar_subquery()
        )
        return Trace.start_time < nth_newest_start
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class MaxDaysRule(_MaxDays, BaseModel):
    """Retention rule: purge traces older than ``max_days`` days."""

    type: Literal["max_days"] = "max_days"

    def __bool__(self) -> bool:
        # Active only for a strictly positive horizon.
        return self.max_days > 0

    async def delete_traces(
        self,
        session: AsyncSession,
        project_rowids: Union[Iterable[int], sa.ScalarSelect[int]],
    ) -> set[int]:
        """Delete matching traces for the given projects; return affected project rowids."""
        if not self:
            # Disabled rule: nothing to delete.
            return set()
        # Imported lazily to avoid a circular dependency with phoenix.db.models.
        from phoenix.db.models import Trace

        stmt = (
            sa.delete(Trace)
            .where(Trace.project_rowid.in_(project_rowids), self.max_days_filter)
            .returning(Trace.project_rowid)
        )
        return set(await session.scalars(stmt))
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class MaxCountRule(_MaxCount, BaseModel):
    """Retention rule: keep only the ``max_count`` newest traces."""

    type: Literal["max_count"] = "max_count"

    def __bool__(self) -> bool:
        # Active only for a strictly positive cap.
        return self.max_count > 0

    async def delete_traces(
        self,
        session: AsyncSession,
        project_rowids: Union[Iterable[int], sa.ScalarSelect[int]],
    ) -> set[int]:
        """Delete matching traces for the given projects; return affected project rowids."""
        if not self:
            # Disabled rule: nothing to delete.
            return set()
        # Imported lazily to avoid a circular dependency with phoenix.db.models.
        from phoenix.db.models import Trace

        stmt = (
            sa.delete(Trace)
            .where(Trace.project_rowid.in_(project_rowids), self.max_count_filter)
            .returning(Trace.project_rowid)
        )
        return set(await session.scalars(stmt))
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class MaxDaysOrCountRule(_MaxDays, _MaxCount, BaseModel):
    """Retention rule: purge traces exceeding either the age or the count limit."""

    type: Literal["max_days_or_count"] = "max_days_or_count"

    def __bool__(self) -> bool:
        # Active when at least one of the two limits is enabled.
        return self.max_days > 0 or self.max_count > 0

    async def delete_traces(
        self,
        session: AsyncSession,
        project_rowids: Union[Iterable[int], sa.ScalarSelect[int]],
    ) -> set[int]:
        """Delete traces matching either limit for the given projects; return affected project rowids."""
        # Equivalent (by De Morgan) to: max_days <= 0 and max_count <= 0.
        if not self:
            return set()
        # Imported lazily to avoid a circular dependency with phoenix.db.models.
        from phoenix.db.models import Trace

        either_limit = sa.or_(self.max_days_filter, self.max_count_filter)
        stmt = (
            sa.delete(Trace)
            .where(Trace.project_rowid.in_(project_rowids))
            .where(either_limit)
            .returning(Trace.project_rowid)
        )
        return set(await session.scalars(stmt))
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class TraceRetentionRule(RootModel[Union[MaxDaysRule, MaxCountRule, MaxDaysOrCountRule]]):
    """Discriminated-union wrapper over the concrete retention rules."""

    root: Annotated[
        Union[MaxDaysRule, MaxCountRule, MaxDaysOrCountRule], Field(discriminator="type")
    ]

    def __bool__(self) -> bool:
        # The wrapper is active iff the wrapped rule is active.
        return bool(self.root)

    async def delete_traces(
        self,
        session: AsyncSession,
        project_rowids: Union[Iterable[int], sa.ScalarSelect[int]],
    ) -> set[int]:
        """Delegate deletion to the wrapped rule; return affected project rowids."""
        rule = self.root
        return await rule.delete_traces(session, project_rowids)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def _time_of_next_run(
    cron_expression: str,
    after: Optional[datetime] = None,
) -> datetime:
    """
    Parse a cron expression and calculate the UTC datetime of the next run.

    Only processes hour, and day of week fields; day-of-month and
    month fields must be '*'; minute field must be 0.

    Args:
        cron_expression (str): Standard cron expression with 5 fields:
            minute hour day-of-month month day-of-week
            (minute must be '0'; day-of-month and month must be '*')
        after: Optional[datetime]: The datetime to start searching from. If None,
            the current time is used. Must be timezone-aware.

    Returns:
        datetime: The datetime of the next run. Timezone is UTC.

    Raises:
        ValueError: If the expression has non-wildcard values for day-of-month or month, if the
            minute field is not '0', or if no match is found within the next 7 days (168 hours).
    """
    fields: list[str] = cron_expression.strip().split()
    if len(fields) != 5:
        raise ValueError(
            "Invalid cron expression. Expected 5 fields "
            "(minute hour day-of-month month day-of-week)."
        )
    if fields[0] != "0":
        raise ValueError("Invalid cron expression. Minute field must be '0'.")
    if fields[2] != "*" or fields[3] != "*":
        raise ValueError("Invalid cron expression. Day-of-month and month fields must be '*'.")
    hours: set[int] = _parse_field(fields[1], 0, 23)
    # Parse days of week (0-6, where 0 is Sunday)
    days_of_week: set[int] = _parse_field(fields[4], 0, 6)
    # Convert to Python's weekday format (0-6, where 0 is Monday)
    # Sunday (0 in cron) becomes 6 in Python's weekday()
    python_days_of_week = {(day_of_week + 6) % 7 for day_of_week in days_of_week}
    # BUG FIX: use astimezone() rather than replace(tzinfo=...). replace() discarded
    # the UTC offset of a non-UTC timezone-aware input, silently reinterpreting its
    # wall-clock time as UTC; astimezone() converts correctly and is a no-op for
    # UTC inputs, so existing callers are unaffected.
    t = after.astimezone(timezone.utc) if after else datetime.now(timezone.utc)
    # Truncate to the top of the hour; the scan below advances in whole hours.
    t = t.replace(minute=0, second=0, microsecond=0)
    for _ in range(168):  # Check up to 7 days (168 hours)
        t += timedelta(hours=1)
        if t.hour in hours and t.weekday() in python_days_of_week:
            return t
    raise ValueError("No matching execution time found within the next 7 days.")
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
class TraceRetentionCronExpression(RootModel[str]):
    """A cron string validated to be acceptable to ``_time_of_next_run``."""

    # The validator calls _time_of_next_run purely for its exception (rejecting
    # malformed expressions) and passes the original string through unchanged.
    root: Annotated[str, AfterValidator(lambda x: (_time_of_next_run(x), x)[1])]

    def get_hour_of_prev_run(self) -> int:
        """
        Calculate the hour of the previous run before now.

        Returns:
            int: The hour of the previous run (0-167), where 0 is midnight Sunday UTC.
        """
        # Searching forward from one hour ago lands on the most recent scheduled hour.
        one_hour_ago = datetime.now(timezone.utc) - timedelta(hours=1)
        prev_run = _time_of_next_run(self.root, one_hour_ago)
        return hour_of_week(prev_run)
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def _parse_field(field: str, min_val: int, max_val: int) -> set[int]:
|
|
194
|
+
"""
|
|
195
|
+
Parse a cron field and return the set of matching values.
|
|
196
|
+
|
|
197
|
+
Args:
|
|
198
|
+
field (str): The cron field to parse
|
|
199
|
+
min_val (int): Minimum allowed value for this field
|
|
200
|
+
max_val (int): Maximum allowed value for this field
|
|
201
|
+
|
|
202
|
+
Returns:
|
|
203
|
+
set[int]: Set of all valid values represented by the field expression
|
|
204
|
+
|
|
205
|
+
Raises:
|
|
206
|
+
ValueError: If the field contains invalid values or formats
|
|
207
|
+
"""
|
|
208
|
+
if field == "*":
|
|
209
|
+
return set(range(min_val, max_val + 1))
|
|
210
|
+
values: set[int] = set()
|
|
211
|
+
for part in field.split(","):
|
|
212
|
+
if "/" in part:
|
|
213
|
+
# Handle steps
|
|
214
|
+
range_part, step_str = part.split("/")
|
|
215
|
+
try:
|
|
216
|
+
step = int(step_str)
|
|
217
|
+
except ValueError:
|
|
218
|
+
raise ValueError(f"Invalid step value: {step_str}")
|
|
219
|
+
if step <= 0:
|
|
220
|
+
raise ValueError(f"Step value must be positive: {step}")
|
|
221
|
+
if range_part == "*":
|
|
222
|
+
start, end = min_val, max_val
|
|
223
|
+
elif "-" in range_part:
|
|
224
|
+
try:
|
|
225
|
+
start_str, end_str = range_part.split("-")
|
|
226
|
+
start, end = int(start_str), int(end_str)
|
|
227
|
+
except ValueError:
|
|
228
|
+
raise ValueError(f"Invalid range format: {range_part}")
|
|
229
|
+
if start < min_val or end > max_val:
|
|
230
|
+
raise ValueError(
|
|
231
|
+
f"Range {start}-{end} outside allowed values ({min_val}-{max_val})"
|
|
232
|
+
)
|
|
233
|
+
if start > end:
|
|
234
|
+
raise ValueError(f"Invalid range: {start}-{end} (start > end)")
|
|
235
|
+
else:
|
|
236
|
+
try:
|
|
237
|
+
start = int(range_part)
|
|
238
|
+
except ValueError:
|
|
239
|
+
raise ValueError(f"Invalid value: {range_part}")
|
|
240
|
+
if start < min_val or start > max_val:
|
|
241
|
+
raise ValueError(f"Value {start} out of range ({min_val}-{max_val})")
|
|
242
|
+
end = max_val
|
|
243
|
+
values.update(range(start, end + 1, step))
|
|
244
|
+
elif "-" in part:
|
|
245
|
+
# Handle ranges
|
|
246
|
+
try:
|
|
247
|
+
start_str, end_str = part.split("-")
|
|
248
|
+
start, end = int(start_str), int(end_str)
|
|
249
|
+
except ValueError:
|
|
250
|
+
raise ValueError(f"Invalid range format: {part}")
|
|
251
|
+
if start < min_val or end > max_val:
|
|
252
|
+
raise ValueError(
|
|
253
|
+
f"Range {start}-{end} outside allowed values ({min_val}-{max_val})"
|
|
254
|
+
)
|
|
255
|
+
if start > end:
|
|
256
|
+
raise ValueError(f"Invalid range: {start}-{end} (start > end)")
|
|
257
|
+
values.update(range(start, end + 1))
|
|
258
|
+
else:
|
|
259
|
+
# Handle single values
|
|
260
|
+
try:
|
|
261
|
+
value = int(part)
|
|
262
|
+
except ValueError:
|
|
263
|
+
raise ValueError(f"Invalid value: {part}")
|
|
264
|
+
if value < min_val or value > max_val:
|
|
265
|
+
raise ValueError(f"Value {value} out of range ({min_val}-{max_val})")
|
|
266
|
+
values.add(value)
|
|
267
|
+
return values
|
phoenix/experiments/functions.py
CHANGED
|
@@ -95,6 +95,7 @@ def run_experiment(
|
|
|
95
95
|
dry_run: Union[bool, int] = False,
|
|
96
96
|
print_summary: bool = True,
|
|
97
97
|
concurrency: int = 3,
|
|
98
|
+
timeout: Optional[int] = None,
|
|
98
99
|
) -> RanExperiment:
|
|
99
100
|
"""
|
|
100
101
|
Runs an experiment using a given set of dataset of examples.
|
|
@@ -148,6 +149,8 @@ def run_experiment(
|
|
|
148
149
|
concurrency (int): Specifies the concurrency for task execution. In order to enable
|
|
149
150
|
concurrent task execution, the task callable must be a coroutine function.
|
|
150
151
|
Defaults to 3.
|
|
152
|
+
timeout (Optional[int]): The timeout for the task execution in seconds. Use this to run
|
|
153
|
+
longer tasks to avoid re-queuing the same task multiple times. Defaults to None.
|
|
151
154
|
|
|
152
155
|
Returns:
|
|
153
156
|
RanExperiment: The results of the experiment and evaluation. Additional evaluations can be
|
|
@@ -380,6 +383,7 @@ def run_experiment(
|
|
|
380
383
|
fallback_return_value=None,
|
|
381
384
|
tqdm_bar_format=get_tqdm_progress_bar_formatter("running tasks"),
|
|
382
385
|
concurrency=concurrency,
|
|
386
|
+
timeout=timeout,
|
|
383
387
|
)
|
|
384
388
|
|
|
385
389
|
test_cases = [
|
|
@@ -752,7 +756,7 @@ def _print_experiment_error(
|
|
|
752
756
|
Prints an experiment error.
|
|
753
757
|
"""
|
|
754
758
|
display_error = RuntimeError(
|
|
755
|
-
f"{kind} failed for example id {repr(example_id)},
|
|
759
|
+
f"{kind} failed for example id {repr(example_id)}, repetition {repr(repetition_number)}"
|
|
756
760
|
)
|
|
757
761
|
display_error.__cause__ = error
|
|
758
762
|
formatted_exception = "".join(
|
phoenix/server/api/auth.py
CHANGED
|
@@ -42,3 +42,12 @@ class IsAdmin(Authorization):
|
|
|
42
42
|
if not info.context.auth_enabled:
|
|
43
43
|
return False
|
|
44
44
|
return isinstance((user := info.context.user), PhoenixUser) and user.is_admin
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class IsAdminIfAuthEnabled(Authorization):
    """Grants access to admin users — and to everyone when auth is disabled."""

    message = MSG_ADMIN_ONLY

    def has_permission(self, source: Any, info: Info, **kwargs: Any) -> bool:
        if info.context.auth_enabled:
            user = info.context.user
            return isinstance(user, PhoenixUser) and user.is_admin
        # With auth disabled there is no user to check, so permit the operation.
        return True
|
phoenix/server/api/context.py
CHANGED
|
@@ -32,6 +32,7 @@ from phoenix.server.api.dataloaders import (
|
|
|
32
32
|
NumChildSpansDataLoader,
|
|
33
33
|
NumSpansPerTraceDataLoader,
|
|
34
34
|
ProjectByNameDataLoader,
|
|
35
|
+
ProjectIdsByTraceRetentionPolicyIdDataLoader,
|
|
35
36
|
PromptVersionSequenceNumberDataLoader,
|
|
36
37
|
RecordCountDataLoader,
|
|
37
38
|
SessionIODataLoader,
|
|
@@ -47,6 +48,7 @@ from phoenix.server.api.dataloaders import (
|
|
|
47
48
|
TableFieldsDataLoader,
|
|
48
49
|
TokenCountDataLoader,
|
|
49
50
|
TraceByTraceIdsDataLoader,
|
|
51
|
+
TraceRetentionPolicyIdByProjectIdDataLoader,
|
|
50
52
|
TraceRootSpansDataLoader,
|
|
51
53
|
UserRolesDataLoader,
|
|
52
54
|
UsersDataLoader,
|
|
@@ -82,6 +84,7 @@ class DataLoaders:
|
|
|
82
84
|
num_child_spans: NumChildSpansDataLoader
|
|
83
85
|
num_spans_per_trace: NumSpansPerTraceDataLoader
|
|
84
86
|
project_fields: TableFieldsDataLoader
|
|
87
|
+
projects_by_trace_retention_policy_id: ProjectIdsByTraceRetentionPolicyIdDataLoader
|
|
85
88
|
prompt_version_sequence_number: PromptVersionSequenceNumberDataLoader
|
|
86
89
|
record_counts: RecordCountDataLoader
|
|
87
90
|
session_first_inputs: SessionIODataLoader
|
|
@@ -99,6 +102,8 @@ class DataLoaders:
|
|
|
99
102
|
token_counts: TokenCountDataLoader
|
|
100
103
|
trace_by_trace_ids: TraceByTraceIdsDataLoader
|
|
101
104
|
trace_fields: TableFieldsDataLoader
|
|
105
|
+
trace_retention_policy_id_by_project_id: TraceRetentionPolicyIdByProjectIdDataLoader
|
|
106
|
+
project_trace_retention_policy_fields: TableFieldsDataLoader
|
|
102
107
|
trace_root_spans: TraceRootSpansDataLoader
|
|
103
108
|
project_by_name: ProjectByNameDataLoader
|
|
104
109
|
users: UsersDataLoader
|
|
@@ -20,6 +20,7 @@ from .min_start_or_max_end_times import MinStartOrMaxEndTimeCache, MinStartOrMax
|
|
|
20
20
|
from .num_child_spans import NumChildSpansDataLoader
|
|
21
21
|
from .num_spans_per_trace import NumSpansPerTraceDataLoader
|
|
22
22
|
from .project_by_name import ProjectByNameDataLoader
|
|
23
|
+
from .project_ids_by_trace_retention_policy_id import ProjectIdsByTraceRetentionPolicyIdDataLoader
|
|
23
24
|
from .prompt_version_sequence_number import PromptVersionSequenceNumberDataLoader
|
|
24
25
|
from .record_counts import RecordCountCache, RecordCountDataLoader
|
|
25
26
|
from .session_io import SessionIODataLoader
|
|
@@ -35,6 +36,7 @@ from .span_projects import SpanProjectsDataLoader
|
|
|
35
36
|
from .table_fields import TableFieldsDataLoader
|
|
36
37
|
from .token_counts import TokenCountCache, TokenCountDataLoader
|
|
37
38
|
from .trace_by_trace_ids import TraceByTraceIdsDataLoader
|
|
39
|
+
from .trace_retention_policy_id_by_project_id import TraceRetentionPolicyIdByProjectIdDataLoader
|
|
38
40
|
from .trace_root_spans import TraceRootSpansDataLoader
|
|
39
41
|
from .user_roles import UserRolesDataLoader
|
|
40
42
|
from .users import UsersDataLoader
|
|
@@ -57,6 +59,7 @@ __all__ = [
|
|
|
57
59
|
"MinStartOrMaxEndTimeDataLoader",
|
|
58
60
|
"NumChildSpansDataLoader",
|
|
59
61
|
"NumSpansPerTraceDataLoader",
|
|
62
|
+
"ProjectIdsByTraceRetentionPolicyIdDataLoader",
|
|
60
63
|
"PromptVersionSequenceNumberDataLoader",
|
|
61
64
|
"RecordCountDataLoader",
|
|
62
65
|
"SessionIODataLoader",
|
|
@@ -71,6 +74,7 @@ __all__ = [
|
|
|
71
74
|
"TableFieldsDataLoader",
|
|
72
75
|
"TokenCountDataLoader",
|
|
73
76
|
"TraceByTraceIdsDataLoader",
|
|
77
|
+
"TraceRetentionPolicyIdByProjectIdDataLoader",
|
|
74
78
|
"TraceRootSpansDataLoader",
|
|
75
79
|
"ProjectByNameDataLoader",
|
|
76
80
|
"SpanAnnotationsDataLoader",
|
|
@@ -1,11 +1,11 @@
|
|
|
1
1
|
from collections import defaultdict
|
|
2
2
|
from datetime import datetime
|
|
3
|
-
from typing import Any, Literal, Optional
|
|
3
|
+
from typing import Any, Literal, Optional, Type, Union, cast
|
|
4
4
|
|
|
5
5
|
import pandas as pd
|
|
6
6
|
from aioitertools.itertools import groupby
|
|
7
7
|
from cachetools import LFUCache, TTLCache
|
|
8
|
-
from sqlalchemy import Select, func, or_, select
|
|
8
|
+
from sqlalchemy import Select, and_, case, distinct, func, or_, select
|
|
9
9
|
from strawberry.dataloader import AbstractCache, DataLoader
|
|
10
10
|
from typing_extensions import TypeAlias, assert_never
|
|
11
11
|
|
|
@@ -92,7 +92,7 @@ class AnnotationSummaryDataLoader(DataLoader[Key, Result]):
|
|
|
92
92
|
async with self._db() as session:
|
|
93
93
|
data = await session.stream(stmt)
|
|
94
94
|
async for annotation_name, group in groupby(data, lambda row: row.name):
|
|
95
|
-
summary = AnnotationSummary(pd.DataFrame(group))
|
|
95
|
+
summary = AnnotationSummary(name=annotation_name, df=pd.DataFrame(group))
|
|
96
96
|
for position in params[annotation_name]:
|
|
97
97
|
results[position] = summary
|
|
98
98
|
return results
|
|
@@ -103,23 +103,64 @@ def _get_stmt(
|
|
|
103
103
|
*annotation_names: Param,
|
|
104
104
|
) -> Select[Any]:
|
|
105
105
|
kind, project_rowid, (start_time, end_time), filter_condition = segment
|
|
106
|
-
|
|
106
|
+
|
|
107
|
+
annotation_model: Union[Type[models.SpanAnnotation], Type[models.TraceAnnotation]]
|
|
108
|
+
entity_model: Union[Type[models.Span], Type[models.Trace]]
|
|
109
|
+
entity_join_model: Optional[Type[models.Base]]
|
|
110
|
+
entity_id_column: Any
|
|
111
|
+
|
|
107
112
|
if kind == "span":
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
if filter_condition:
|
|
113
|
-
sf = SpanFilter(filter_condition)
|
|
114
|
-
stmt = sf(stmt)
|
|
113
|
+
annotation_model = models.SpanAnnotation
|
|
114
|
+
entity_model = models.Span
|
|
115
|
+
entity_join_model = models.Trace
|
|
116
|
+
entity_id_column = models.Span.id.label("entity_id")
|
|
115
117
|
elif kind == "trace":
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
118
|
+
annotation_model = models.TraceAnnotation
|
|
119
|
+
entity_model = models.Trace
|
|
120
|
+
entity_join_model = None
|
|
121
|
+
entity_id_column = models.Trace.id.label("entity_id")
|
|
120
122
|
else:
|
|
121
123
|
assert_never(kind)
|
|
122
|
-
|
|
124
|
+
|
|
125
|
+
name_column = annotation_model.name
|
|
126
|
+
label_column = annotation_model.label
|
|
127
|
+
score_column = annotation_model.score
|
|
128
|
+
time_column = entity_model.start_time
|
|
129
|
+
|
|
130
|
+
# First query: count distinct entities per annotation name
|
|
131
|
+
# This is used later to calculate accurate fractions that account for entities without labels
|
|
132
|
+
entity_count_query = select(
|
|
133
|
+
name_column, func.count(distinct(entity_id_column)).label("entity_count")
|
|
134
|
+
)
|
|
135
|
+
|
|
136
|
+
if kind == "span":
|
|
137
|
+
entity_count_query = entity_count_query.join(cast(Type[models.Span], entity_model))
|
|
138
|
+
entity_count_query = entity_count_query.join_from(
|
|
139
|
+
cast(Type[models.Span], entity_model), cast(Type[models.Trace], entity_join_model)
|
|
140
|
+
)
|
|
141
|
+
entity_count_query = entity_count_query.where(models.Trace.project_rowid == project_rowid)
|
|
142
|
+
elif kind == "trace":
|
|
143
|
+
entity_count_query = entity_count_query.join(cast(Type[models.Trace], entity_model))
|
|
144
|
+
entity_count_query = entity_count_query.where(
|
|
145
|
+
cast(Type[models.Trace], entity_model).project_rowid == project_rowid
|
|
146
|
+
)
|
|
147
|
+
|
|
148
|
+
entity_count_query = entity_count_query.where(
|
|
149
|
+
or_(score_column.is_not(None), label_column.is_not(None))
|
|
150
|
+
)
|
|
151
|
+
entity_count_query = entity_count_query.where(name_column.in_(annotation_names))
|
|
152
|
+
|
|
153
|
+
if start_time:
|
|
154
|
+
entity_count_query = entity_count_query.where(start_time <= time_column)
|
|
155
|
+
if end_time:
|
|
156
|
+
entity_count_query = entity_count_query.where(time_column < end_time)
|
|
157
|
+
|
|
158
|
+
entity_count_query = entity_count_query.group_by(name_column)
|
|
159
|
+
entity_count_subquery = entity_count_query.subquery()
|
|
160
|
+
|
|
161
|
+
# Main query: gets raw annotation data with counts per (span/trace)+name+label
|
|
162
|
+
base_stmt = select(
|
|
163
|
+
entity_id_column,
|
|
123
164
|
name_column,
|
|
124
165
|
label_column,
|
|
125
166
|
func.count().label("record_count"),
|
|
@@ -127,13 +168,151 @@ def _get_stmt(
|
|
|
127
168
|
func.count(score_column).label("score_count"),
|
|
128
169
|
func.sum(score_column).label("score_sum"),
|
|
129
170
|
)
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
171
|
+
|
|
172
|
+
if kind == "span":
|
|
173
|
+
base_stmt = base_stmt.join(cast(Type[models.Span], entity_model))
|
|
174
|
+
base_stmt = base_stmt.join_from(
|
|
175
|
+
cast(Type[models.Span], entity_model), cast(Type[models.Trace], entity_join_model)
|
|
176
|
+
)
|
|
177
|
+
base_stmt = base_stmt.where(models.Trace.project_rowid == project_rowid)
|
|
178
|
+
if filter_condition:
|
|
179
|
+
sf = SpanFilter(filter_condition)
|
|
180
|
+
base_stmt = sf(base_stmt)
|
|
181
|
+
elif kind == "trace":
|
|
182
|
+
base_stmt = base_stmt.join(cast(Type[models.Trace], entity_model))
|
|
183
|
+
base_stmt = base_stmt.where(
|
|
184
|
+
cast(Type[models.Trace], entity_model).project_rowid == project_rowid
|
|
185
|
+
)
|
|
186
|
+
else:
|
|
187
|
+
assert_never(kind)
|
|
188
|
+
|
|
189
|
+
base_stmt = base_stmt.where(or_(score_column.is_not(None), label_column.is_not(None)))
|
|
190
|
+
base_stmt = base_stmt.where(name_column.in_(annotation_names))
|
|
191
|
+
|
|
135
192
|
if start_time:
|
|
136
|
-
|
|
193
|
+
base_stmt = base_stmt.where(start_time <= time_column)
|
|
137
194
|
if end_time:
|
|
138
|
-
|
|
139
|
-
|
|
195
|
+
base_stmt = base_stmt.where(time_column < end_time)
|
|
196
|
+
|
|
197
|
+
# Group to get one row per (span/trace)+name+label combination
|
|
198
|
+
base_stmt = base_stmt.group_by(entity_id_column, name_column, label_column)
|
|
199
|
+
|
|
200
|
+
base_subquery = base_stmt.subquery()
|
|
201
|
+
|
|
202
|
+
# Calculate total counts per (span/trace)+name for computing fractions
|
|
203
|
+
entity_totals = (
|
|
204
|
+
select(
|
|
205
|
+
base_subquery.c.entity_id,
|
|
206
|
+
base_subquery.c.name,
|
|
207
|
+
func.sum(base_subquery.c.label_count).label("total_label_count"),
|
|
208
|
+
func.sum(base_subquery.c.score_count).label("total_score_count"),
|
|
209
|
+
func.sum(base_subquery.c.score_sum).label("entity_score_sum"),
|
|
210
|
+
)
|
|
211
|
+
.group_by(base_subquery.c.entity_id, base_subquery.c.name)
|
|
212
|
+
.subquery()
|
|
213
|
+
)
|
|
214
|
+
|
|
215
|
+
per_entity_fractions = (
|
|
216
|
+
select(
|
|
217
|
+
base_subquery.c.entity_id,
|
|
218
|
+
base_subquery.c.name,
|
|
219
|
+
base_subquery.c.label,
|
|
220
|
+
base_subquery.c.record_count,
|
|
221
|
+
base_subquery.c.label_count,
|
|
222
|
+
base_subquery.c.score_count,
|
|
223
|
+
base_subquery.c.score_sum,
|
|
224
|
+
# Calculate label fraction, avoiding division by zero when total_label_count is 0
|
|
225
|
+
case(
|
|
226
|
+
(
|
|
227
|
+
entity_totals.c.total_label_count > 0,
|
|
228
|
+
base_subquery.c.label_count * 1.0 / entity_totals.c.total_label_count,
|
|
229
|
+
),
|
|
230
|
+
else_=None,
|
|
231
|
+
).label("label_fraction"),
|
|
232
|
+
# Calculate average score for the entity (if there are any scores)
|
|
233
|
+
case(
|
|
234
|
+
(
|
|
235
|
+
entity_totals.c.total_score_count > 0,
|
|
236
|
+
entity_totals.c.entity_score_sum * 1.0 / entity_totals.c.total_score_count,
|
|
237
|
+
),
|
|
238
|
+
else_=None,
|
|
239
|
+
).label("entity_avg_score"),
|
|
240
|
+
)
|
|
241
|
+
.join(
|
|
242
|
+
entity_totals,
|
|
243
|
+
and_(
|
|
244
|
+
base_subquery.c.entity_id == entity_totals.c.entity_id,
|
|
245
|
+
base_subquery.c.name == entity_totals.c.name,
|
|
246
|
+
),
|
|
247
|
+
)
|
|
248
|
+
.subquery()
|
|
249
|
+
)
|
|
250
|
+
|
|
251
|
+
# Aggregate metrics across (spans/traces) for each name+label combination.
|
|
252
|
+
label_entity_metrics = (
|
|
253
|
+
select(
|
|
254
|
+
per_entity_fractions.c.name,
|
|
255
|
+
per_entity_fractions.c.label,
|
|
256
|
+
func.count(distinct(per_entity_fractions.c.entity_id)).label("entities_with_label"),
|
|
257
|
+
func.sum(per_entity_fractions.c.label_count).label("total_label_count"),
|
|
258
|
+
func.sum(per_entity_fractions.c.score_count).label("total_score_count"),
|
|
259
|
+
func.sum(per_entity_fractions.c.score_sum).label("total_score_sum"),
|
|
260
|
+
# Average of label fractions for entities that have this label
|
|
261
|
+
func.avg(per_entity_fractions.c.label_fraction).label("avg_label_fraction_present"),
|
|
262
|
+
# Average of per-entity average scores (but we handle overall aggregation separately)
|
|
263
|
+
)
|
|
264
|
+
.group_by(per_entity_fractions.c.name, per_entity_fractions.c.label)
|
|
265
|
+
.subquery()
|
|
266
|
+
)
|
|
267
|
+
|
|
268
|
+
# Compute distinct per-entity average scores to ensure each entity counts only once.
|
|
269
|
+
distinct_entity_scores = (
|
|
270
|
+
select(
|
|
271
|
+
per_entity_fractions.c.entity_id,
|
|
272
|
+
per_entity_fractions.c.name,
|
|
273
|
+
per_entity_fractions.c.entity_avg_score,
|
|
274
|
+
)
|
|
275
|
+
.distinct()
|
|
276
|
+
.subquery()
|
|
277
|
+
)
|
|
278
|
+
|
|
279
|
+
overall_score_aggregates = (
|
|
280
|
+
select(
|
|
281
|
+
distinct_entity_scores.c.name,
|
|
282
|
+
func.avg(distinct_entity_scores.c.entity_avg_score).label("overall_avg_score"),
|
|
283
|
+
)
|
|
284
|
+
.group_by(distinct_entity_scores.c.name)
|
|
285
|
+
.subquery()
|
|
286
|
+
)
|
|
287
|
+
|
|
288
|
+
# Final result: adjust label fractions by the proportion of entities reporting this label
|
|
289
|
+
# and include the overall average score per annotation name.
|
|
290
|
+
final_stmt = (
|
|
291
|
+
select(
|
|
292
|
+
label_entity_metrics.c.name,
|
|
293
|
+
label_entity_metrics.c.label,
|
|
294
|
+
# Adjust label fraction, guarding against division by zero in entity_count
|
|
295
|
+
case(
|
|
296
|
+
(
|
|
297
|
+
entity_count_subquery.c.entity_count > 0,
|
|
298
|
+
label_entity_metrics.c.avg_label_fraction_present
|
|
299
|
+
* label_entity_metrics.c.entities_with_label
|
|
300
|
+
/ entity_count_subquery.c.entity_count,
|
|
301
|
+
),
|
|
302
|
+
else_=None,
|
|
303
|
+
).label("avg_label_fraction"),
|
|
304
|
+
overall_score_aggregates.c.overall_avg_score.label("avg_score"), # same for all labels
|
|
305
|
+
label_entity_metrics.c.total_label_count.label("label_count"),
|
|
306
|
+
label_entity_metrics.c.total_score_count.label("score_count"),
|
|
307
|
+
label_entity_metrics.c.total_score_sum.label("score_sum"),
|
|
308
|
+
label_entity_metrics.c.entities_with_label.label("record_count"),
|
|
309
|
+
)
|
|
310
|
+
.join(entity_count_subquery, label_entity_metrics.c.name == entity_count_subquery.c.name)
|
|
311
|
+
.join(
|
|
312
|
+
overall_score_aggregates,
|
|
313
|
+
label_entity_metrics.c.name == overall_score_aggregates.c.name,
|
|
314
|
+
)
|
|
315
|
+
.order_by(label_entity_metrics.c.name, label_entity_metrics.c.label)
|
|
316
|
+
)
|
|
317
|
+
|
|
318
|
+
return final_stmt
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
|
|
3
|
+
from sqlalchemy import or_, select
|
|
4
|
+
from strawberry.dataloader import DataLoader
|
|
5
|
+
from typing_extensions import TypeAlias
|
|
6
|
+
|
|
7
|
+
from phoenix.db.constants import DEFAULT_PROJECT_TRACE_RETENTION_POLICY_ID
|
|
8
|
+
from phoenix.db.models import Project
|
|
9
|
+
from phoenix.server.types import DbSessionFactory
|
|
10
|
+
|
|
11
|
+
PolicyRowId: TypeAlias = int
|
|
12
|
+
ProjectRowId: TypeAlias = int
|
|
13
|
+
|
|
14
|
+
Key: TypeAlias = PolicyRowId
|
|
15
|
+
Result: TypeAlias = list[ProjectRowId]
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ProjectIdsByTraceRetentionPolicyIdDataLoader(DataLoader[Key, Result]):
    """Batch-loads, per trace-retention-policy rowid, the rowids of projects using it."""

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        requested = set(keys)
        condition = Project.trace_retention_policy_id.in_(requested)
        if DEFAULT_PROJECT_TRACE_RETENTION_POLICY_ID in requested:
            # Projects without an explicit policy fall back to the default one,
            # so a NULL policy id counts toward the default policy's key.
            condition = or_(condition, Project.trace_retention_policy_id.is_(None))
        stmt = select(Project.trace_retention_policy_id, Project.id).where(condition)
        grouped: defaultdict[Key, Result] = defaultdict(list)
        async with self._db() as session:
            rows = await session.stream(stmt)
            async for policy_rowid, project_rowid in rows:
                effective_policy = policy_rowid or DEFAULT_PROJECT_TRACE_RETENTION_POLICY_ID
                grouped[effective_policy].append(project_rowid)
        # Return a fresh list per key so callers cannot mutate the loader's state.
        return [list(grouped.get(key, [])) for key in keys]
|