arize-phoenix 3.24.0__py3-none-any.whl → 4.0.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.

Potentially problematic release.

This version of arize-phoenix might be problematic.

Files changed (113)
  1. {arize_phoenix-3.24.0.dist-info → arize_phoenix-4.0.0.dist-info}/METADATA +26 -4
  2. {arize_phoenix-3.24.0.dist-info → arize_phoenix-4.0.0.dist-info}/RECORD +80 -75
  3. phoenix/__init__.py +9 -5
  4. phoenix/config.py +109 -53
  5. phoenix/datetime_utils.py +18 -1
  6. phoenix/db/README.md +25 -0
  7. phoenix/db/__init__.py +4 -0
  8. phoenix/db/alembic.ini +119 -0
  9. phoenix/db/bulk_inserter.py +206 -0
  10. phoenix/db/engines.py +152 -0
  11. phoenix/db/helpers.py +47 -0
  12. phoenix/db/insertion/evaluation.py +209 -0
  13. phoenix/db/insertion/helpers.py +54 -0
  14. phoenix/db/insertion/span.py +142 -0
  15. phoenix/db/migrate.py +71 -0
  16. phoenix/db/migrations/env.py +121 -0
  17. phoenix/db/migrations/script.py.mako +26 -0
  18. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +280 -0
  19. phoenix/db/models.py +371 -0
  20. phoenix/exceptions.py +5 -1
  21. phoenix/server/api/context.py +40 -3
  22. phoenix/server/api/dataloaders/__init__.py +97 -0
  23. phoenix/server/api/dataloaders/cache/__init__.py +3 -0
  24. phoenix/server/api/dataloaders/cache/two_tier_cache.py +67 -0
  25. phoenix/server/api/dataloaders/document_evaluation_summaries.py +152 -0
  26. phoenix/server/api/dataloaders/document_evaluations.py +37 -0
  27. phoenix/server/api/dataloaders/document_retrieval_metrics.py +98 -0
  28. phoenix/server/api/dataloaders/evaluation_summaries.py +151 -0
  29. phoenix/server/api/dataloaders/latency_ms_quantile.py +198 -0
  30. phoenix/server/api/dataloaders/min_start_or_max_end_times.py +93 -0
  31. phoenix/server/api/dataloaders/record_counts.py +125 -0
  32. phoenix/server/api/dataloaders/span_descendants.py +64 -0
  33. phoenix/server/api/dataloaders/span_evaluations.py +37 -0
  34. phoenix/server/api/dataloaders/token_counts.py +138 -0
  35. phoenix/server/api/dataloaders/trace_evaluations.py +37 -0
  36. phoenix/server/api/input_types/SpanSort.py +138 -68
  37. phoenix/server/api/routers/v1/__init__.py +11 -0
  38. phoenix/server/api/routers/v1/evaluations.py +275 -0
  39. phoenix/server/api/routers/v1/spans.py +126 -0
  40. phoenix/server/api/routers/v1/traces.py +82 -0
  41. phoenix/server/api/schema.py +112 -48
  42. phoenix/server/api/types/DocumentEvaluationSummary.py +1 -1
  43. phoenix/server/api/types/Evaluation.py +29 -12
  44. phoenix/server/api/types/EvaluationSummary.py +29 -44
  45. phoenix/server/api/types/MimeType.py +2 -2
  46. phoenix/server/api/types/Model.py +9 -9
  47. phoenix/server/api/types/Project.py +240 -171
  48. phoenix/server/api/types/Span.py +87 -131
  49. phoenix/server/api/types/Trace.py +29 -20
  50. phoenix/server/api/types/pagination.py +151 -10
  51. phoenix/server/app.py +263 -35
  52. phoenix/server/grpc_server.py +93 -0
  53. phoenix/server/main.py +75 -60
  54. phoenix/server/openapi/docs.py +218 -0
  55. phoenix/server/prometheus.py +23 -7
  56. phoenix/server/static/index.js +662 -643
  57. phoenix/server/telemetry.py +68 -0
  58. phoenix/services.py +4 -0
  59. phoenix/session/client.py +34 -30
  60. phoenix/session/data_extractor.py +8 -3
  61. phoenix/session/session.py +176 -155
  62. phoenix/settings.py +13 -0
  63. phoenix/trace/attributes.py +349 -0
  64. phoenix/trace/dsl/README.md +116 -0
  65. phoenix/trace/dsl/filter.py +660 -192
  66. phoenix/trace/dsl/helpers.py +24 -5
  67. phoenix/trace/dsl/query.py +562 -185
  68. phoenix/trace/fixtures.py +69 -7
  69. phoenix/trace/otel.py +33 -199
  70. phoenix/trace/schemas.py +14 -8
  71. phoenix/trace/span_evaluations.py +5 -2
  72. phoenix/utilities/__init__.py +0 -26
  73. phoenix/utilities/span_store.py +0 -23
  74. phoenix/version.py +1 -1
  75. phoenix/core/project.py +0 -773
  76. phoenix/core/traces.py +0 -96
  77. phoenix/datasets/dataset.py +0 -214
  78. phoenix/datasets/fixtures.py +0 -24
  79. phoenix/datasets/schema.py +0 -31
  80. phoenix/experimental/evals/__init__.py +0 -73
  81. phoenix/experimental/evals/evaluators.py +0 -413
  82. phoenix/experimental/evals/functions/__init__.py +0 -4
  83. phoenix/experimental/evals/functions/classify.py +0 -453
  84. phoenix/experimental/evals/functions/executor.py +0 -353
  85. phoenix/experimental/evals/functions/generate.py +0 -138
  86. phoenix/experimental/evals/functions/processing.py +0 -76
  87. phoenix/experimental/evals/models/__init__.py +0 -14
  88. phoenix/experimental/evals/models/anthropic.py +0 -175
  89. phoenix/experimental/evals/models/base.py +0 -170
  90. phoenix/experimental/evals/models/bedrock.py +0 -221
  91. phoenix/experimental/evals/models/litellm.py +0 -134
  92. phoenix/experimental/evals/models/openai.py +0 -453
  93. phoenix/experimental/evals/models/rate_limiters.py +0 -246
  94. phoenix/experimental/evals/models/vertex.py +0 -173
  95. phoenix/experimental/evals/models/vertexai.py +0 -186
  96. phoenix/experimental/evals/retrievals.py +0 -96
  97. phoenix/experimental/evals/templates/__init__.py +0 -50
  98. phoenix/experimental/evals/templates/default_templates.py +0 -472
  99. phoenix/experimental/evals/templates/template.py +0 -195
  100. phoenix/experimental/evals/utils/__init__.py +0 -172
  101. phoenix/experimental/evals/utils/threads.py +0 -27
  102. phoenix/server/api/routers/evaluation_handler.py +0 -110
  103. phoenix/server/api/routers/span_handler.py +0 -70
  104. phoenix/server/api/routers/trace_handler.py +0 -60
  105. phoenix/storage/span_store/__init__.py +0 -23
  106. phoenix/storage/span_store/text_file.py +0 -85
  107. phoenix/trace/dsl/missing.py +0 -60
  108. {arize_phoenix-3.24.0.dist-info → arize_phoenix-4.0.0.dist-info}/WHEEL +0 -0
  109. {arize_phoenix-3.24.0.dist-info → arize_phoenix-4.0.0.dist-info}/licenses/IP_NOTICE +0 -0
  110. {arize_phoenix-3.24.0.dist-info → arize_phoenix-4.0.0.dist-info}/licenses/LICENSE +0 -0
  111. /phoenix/{datasets → db/insertion}/__init__.py +0 -0
  112. /phoenix/{experimental → db/migrations}/__init__.py +0 -0
  113. /phoenix/{storage → server/openapi}/__init__.py +0 -0
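
In the diff below, the Project resolvers no longer read from an in-memory project object; each field delegates to a per-request DataLoader (see the new phoenix/server/api/dataloaders package) that batches keys such as ("span", project_rowid, time_range, filter_condition) into a single database query. The following is a minimal, self-contained sketch of that batching pattern using strawberry's DataLoader; the key shape and the hard-coded lookup are illustrative assumptions, not Phoenix's actual implementation.

import asyncio
from typing import List, Tuple

from strawberry.dataloader import DataLoader

# Hypothetical key shape: ("span" | "trace", project_rowid)
Key = Tuple[str, int]


async def load_record_counts(keys: List[Key]) -> List[int]:
    # In Phoenix this would be one SQLAlchemy query grouped by project and kind;
    # a hard-coded lookup keeps the sketch runnable on its own.
    fake_counts = {("span", 1): 42, ("trace", 1): 7}
    return [fake_counts.get(key, 0) for key in keys]


async def main() -> None:
    record_counts: DataLoader[Key, int] = DataLoader(load_fn=load_record_counts)
    # Both loads issued concurrently are answered by a single batch call.
    span_count, trace_count = await asyncio.gather(
        record_counts.load(("span", 1)),
        record_counts.load(("trace", 1)),
    )
    print(span_count, trace_count)  # 42 7


asyncio.run(main())

Because strawberry collects the pending load() calls made while a request resolves, the many count and summary fields on a project can be answered without issuing one query per field.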
phoenix/server/api/types/Project.py
@@ -1,284 +1,353 @@
+import operator
 from datetime import datetime
-from itertools import chain
-from typing import List, Optional
+from typing import Any, List, Optional

 import strawberry
+from aioitertools.itertools import islice
+from sqlalchemy import and_, desc, distinct, select
+from sqlalchemy.orm import contains_eager
+from sqlalchemy.sql.expression import tuple_
 from strawberry import ID, UNSET
+from strawberry.types import Info

-from phoenix.core.project import Project as CoreProject
-from phoenix.metrics.retrieval_metrics import RetrievalMetrics
-from phoenix.server.api.input_types.SpanSort import SpanSort
+from phoenix.datetime_utils import right_open_time_range
+from phoenix.db import models
+from phoenix.server.api.context import Context
+from phoenix.server.api.input_types.SpanSort import SpanSort, SpanSortConfig
 from phoenix.server.api.input_types.TimeRange import TimeRange
 from phoenix.server.api.types.DocumentEvaluationSummary import DocumentEvaluationSummary
 from phoenix.server.api.types.EvaluationSummary import EvaluationSummary
 from phoenix.server.api.types.node import Node
 from phoenix.server.api.types.pagination import (
     Connection,
-    ConnectionArgs,
     Cursor,
-    connection_from_list,
+    CursorSortColumn,
+    CursorString,
+    connections,
 )
+from phoenix.server.api.types.SortDir import SortDir
 from phoenix.server.api.types.Span import Span, to_gql_span
 from phoenix.server.api.types.Trace import Trace
 from phoenix.server.api.types.ValidationResult import ValidationResult
 from phoenix.trace.dsl import SpanFilter
-from phoenix.trace.schemas import SpanID, TraceID
+
+SPANS_LIMIT = 1000


 @strawberry.type
 class Project(Node):
     name: str
-    project: strawberry.Private[CoreProject]
+    gradient_start_color: str
+    gradient_end_color: str

     @strawberry.field
-    def start_time(self) -> Optional[datetime]:
-        start_time, _ = self.project.right_open_time_range
+    async def start_time(
+        self,
+        info: Info[Context, None],
+    ) -> Optional[datetime]:
+        start_time = await info.context.data_loaders.min_start_or_max_end_times.load(
+            (self.id_attr, "start"),
+        )
+        start_time, _ = right_open_time_range(start_time, None)
         return start_time

     @strawberry.field
-    def end_time(self) -> Optional[datetime]:
-        _, end_time = self.project.right_open_time_range
+    async def end_time(
+        self,
+        info: Info[Context, None],
+    ) -> Optional[datetime]:
+        end_time = await info.context.data_loaders.min_start_or_max_end_times.load(
+            (self.id_attr, "end"),
+        )
+        _, end_time = right_open_time_range(None, end_time)
         return end_time

     @strawberry.field
-    def record_count(
+    async def record_count(
         self,
+        info: Info[Context, None],
         time_range: Optional[TimeRange] = UNSET,
+        filter_condition: Optional[str] = UNSET,
     ) -> int:
-        if not time_range:
-            return self.project.span_count()
-        return self.project.span_count(time_range.start, time_range.end)
+        return await info.context.data_loaders.record_counts.load(
+            ("span", self.id_attr, time_range, filter_condition),
+        )

     @strawberry.field
-    def trace_count(
+    async def trace_count(
         self,
+        info: Info[Context, None],
         time_range: Optional[TimeRange] = UNSET,
     ) -> int:
-        if not time_range:
-            return self.project.trace_count()
-        return self.project.trace_count(time_range.start, time_range.end)
+        return await info.context.data_loaders.record_counts.load(
+            ("trace", self.id_attr, time_range, None),
+        )

     @strawberry.field
-    def token_count_total(self) -> int:
-        return self.project.token_count_total
+    async def token_count_total(
+        self,
+        info: Info[Context, None],
+        time_range: Optional[TimeRange] = UNSET,
+        filter_condition: Optional[str] = UNSET,
+    ) -> int:
+        return await info.context.data_loaders.token_counts.load(
+            ("total", self.id_attr, time_range, filter_condition),
+        )

     @strawberry.field
-    def latency_ms_p50(self) -> Optional[float]:
-        return self.project.root_span_latency_ms_quantiles(0.50)
+    async def token_count_prompt(
+        self,
+        info: Info[Context, None],
+        time_range: Optional[TimeRange] = UNSET,
+        filter_condition: Optional[str] = UNSET,
+    ) -> int:
+        return await info.context.data_loaders.token_counts.load(
+            ("prompt", self.id_attr, time_range, filter_condition),
+        )

     @strawberry.field
-    def latency_ms_p99(self) -> Optional[float]:
-        return self.project.root_span_latency_ms_quantiles(0.99)
+    async def token_count_completion(
+        self,
+        info: Info[Context, None],
+        time_range: Optional[TimeRange] = UNSET,
+        filter_condition: Optional[str] = UNSET,
+    ) -> int:
+        return await info.context.data_loaders.token_counts.load(
+            ("completion", self.id_attr, time_range, filter_condition),
+        )

     @strawberry.field
-    def trace(self, trace_id: ID) -> Optional[Trace]:
-        if self.project.has_trace(TraceID(trace_id)):
-            return Trace(trace_id=trace_id, project=self.project)
-        return None
+    async def latency_ms_quantile(
+        self,
+        info: Info[Context, None],
+        probability: float,
+        time_range: Optional[TimeRange] = UNSET,
+    ) -> Optional[float]:
+        return await info.context.data_loaders.latency_ms_quantile.load(
+            ("trace", self.id_attr, time_range, None, probability),
+        )

     @strawberry.field
-    def spans(
+    async def span_latency_ms_quantile(
         self,
+        info: Info[Context, None],
+        probability: float,
+        time_range: Optional[TimeRange] = UNSET,
+        filter_condition: Optional[str] = UNSET,
+    ) -> Optional[float]:
+        return await info.context.data_loaders.latency_ms_quantile.load(
+            ("span", self.id_attr, time_range, filter_condition, probability),
+        )
+
+    @strawberry.field
+    async def trace(self, trace_id: ID, info: Info[Context, None]) -> Optional[Trace]:
+        stmt = (
+            select(models.Trace.id)
+            .where(models.Trace.trace_id == str(trace_id))
+            .where(models.Trace.project_rowid == self.id_attr)
+        )
+        async with info.context.db() as session:
+            if (id_attr := await session.scalar(stmt)) is None:
+                return None
+        return Trace(id_attr=id_attr)
+
+    @strawberry.field
+    async def spans(
+        self,
+        info: Info[Context, None],
         time_range: Optional[TimeRange] = UNSET,
-        trace_ids: Optional[List[ID]] = UNSET,
         first: Optional[int] = 50,
         last: Optional[int] = UNSET,
-        after: Optional[Cursor] = UNSET,
-        before: Optional[Cursor] = UNSET,
+        after: Optional[CursorString] = UNSET,
+        before: Optional[CursorString] = UNSET,
         sort: Optional[SpanSort] = UNSET,
         root_spans_only: Optional[bool] = UNSET,
         filter_condition: Optional[str] = UNSET,
     ) -> Connection[Span]:
-        args = ConnectionArgs(
-            first=first,
-            after=after if isinstance(after, Cursor) else None,
-            last=last,
-            before=before if isinstance(before, Cursor) else None,
+        stmt = (
+            select(models.Span)
+            .join(models.Trace)
+            .where(models.Trace.project_rowid == self.id_attr)
+            .options(contains_eager(models.Span.trace))
         )
-        start_time = time_range.start if time_range else None
-        stop_time = time_range.end if time_range else None
-        if not (project := self.project).span_count(
-            start_time=start_time,
-            stop_time=stop_time,
-        ):
-            return connection_from_list(data=[], args=args)
-        predicate = (
-            SpanFilter(
-                condition=filter_condition,
-                evals=project,
-            )
-            if filter_condition
-            else None
-        )
-        if not trace_ids:
-            spans = project.get_spans(
-                start_time=start_time,
-                stop_time=stop_time,
-                root_spans_only=root_spans_only,
-            )
-        else:
-            spans = chain.from_iterable(
-                project.get_trace(trace_id) for trace_id in map(TraceID, trace_ids)
+        if time_range:
+            stmt = stmt.where(
+                and_(
+                    time_range.start <= models.Span.start_time,
+                    models.Span.start_time < time_range.end,
+                )
             )
-        if predicate:
-            spans = filter(predicate, spans)
+        if root_spans_only:
+            # A root span is any span whose parent span is missing in the
+            # database, even if its `parent_span_id` may not be NULL.
+            parent = select(models.Span.span_id).alias()
+            stmt = stmt.outerjoin(
+                parent,
+                models.Span.parent_id == parent.c.span_id,
+            ).where(parent.c.span_id.is_(None))
+        if filter_condition:
+            span_filter = SpanFilter(condition=filter_condition)
+            stmt = span_filter(stmt)
+        sort_config: Optional[SpanSortConfig] = None
+        cursor_rowid_column: Any = models.Span.id
         if sort:
-            spans = sort(spans, evals=project)
-        data = [to_gql_span(span, project) for span in spans]
-        return connection_from_list(data=data, args=args)
+            sort_config = sort.update_orm_expr(stmt)
+            stmt = sort_config.stmt
+            if sort_config.dir is SortDir.desc:
+                cursor_rowid_column = desc(cursor_rowid_column)
+        if after:
+            cursor = Cursor.from_string(after)
+            if sort_config and cursor.sort_column:
+                sort_column = cursor.sort_column
+                compare = operator.lt if sort_config.dir is SortDir.desc else operator.gt
+                stmt = stmt.where(
+                    compare(
+                        tuple_(sort_config.orm_expression, models.Span.id),
+                        (sort_column.value, cursor.rowid),
+                    )
+                )
+            else:
+                stmt = stmt.where(models.Span.id > cursor.rowid)
+        if first:
+            stmt = stmt.limit(
+                first + 1  # overfetch by one to determine whether there's a next page
+            )
+        stmt = stmt.order_by(cursor_rowid_column)
+        data = []
+        async with info.context.db() as session:
+            span_records = await session.execute(stmt)
+            async for span_record in islice(span_records, first):
+                span = span_record[0]
+                sort_column_value = span_record[1] if len(span_record) > 1 else None
+                cursor = Cursor(
+                    rowid=span.id,
+                    sort_column=(
+                        CursorSortColumn(
+                            type=sort_config.column_data_type,
+                            value=sort_column_value,
+                        )
+                        if sort_config
+                        else None
+                    ),
+                )
+                data.append((cursor, to_gql_span(span)))
+            has_next_page = True
+            try:
+                next(span_records)
+            except StopIteration:
+                has_next_page = False
+
+        return connections(
+            data,
+            has_previous_page=False,
+            has_next_page=has_next_page,
+        )

     @strawberry.field(
         description="Names of all available evaluations for traces. "
         "(The list contains no duplicates.)"
     ) # type: ignore
-    def trace_evaluation_names(self) -> List[str]:
-        return self.project.get_trace_evaluation_names()
+    async def trace_evaluation_names(
+        self,
+        info: Info[Context, None],
+    ) -> List[str]:
+        stmt = (
+            select(distinct(models.TraceAnnotation.name))
+            .join(models.Trace)
+            .where(models.Trace.project_rowid == self.id_attr)
+            .where(models.TraceAnnotation.annotator_kind == "LLM")
+        )
+        async with info.context.db() as session:
+            return list(await session.scalars(stmt))

     @strawberry.field(
         description="Names of all available evaluations for spans. "
         "(The list contains no duplicates.)"
     ) # type: ignore
-    def span_evaluation_names(self) -> List[str]:
-        return self.project.get_span_evaluation_names()
+    async def span_evaluation_names(
+        self,
+        info: Info[Context, None],
+    ) -> List[str]:
+        stmt = (
+            select(distinct(models.SpanAnnotation.name))
+            .join(models.Span)
+            .join(models.Trace, models.Span.trace_rowid == models.Trace.id)
+            .where(models.Trace.project_rowid == self.id_attr)
+            .where(models.SpanAnnotation.annotator_kind == "LLM")
+        )
+        async with info.context.db() as session:
+            return list(await session.scalars(stmt))

     @strawberry.field(
         description="Names of available document evaluations.",
     ) # type: ignore
-    def document_evaluation_names(
+    async def document_evaluation_names(
         self,
+        info: Info[Context, None],
         span_id: Optional[ID] = UNSET,
     ) -> List[str]:
-        return self.project.get_document_evaluation_names(
-            None if span_id is UNSET else SpanID(span_id),
+        stmt = (
+            select(distinct(models.DocumentAnnotation.name))
+            .join(models.Span)
+            .join(models.Trace, models.Span.trace_rowid == models.Trace.id)
+            .where(models.Trace.project_rowid == self.id_attr)
+            .where(models.DocumentAnnotation.annotator_kind == "LLM")
         )
+        if span_id:
+            stmt = stmt.where(models.Span.span_id == str(span_id))
+        async with info.context.db() as session:
+            return list(await session.scalars(stmt))

     @strawberry.field
-    def trace_evaluation_summary(
+    async def trace_evaluation_summary(
         self,
+        info: Info[Context, None],
         evaluation_name: str,
         time_range: Optional[TimeRange] = UNSET,
     ) -> Optional[EvaluationSummary]:
-        project = self.project
-        eval_trace_ids = project.get_trace_evaluation_trace_ids(evaluation_name)
-        if not eval_trace_ids:
-            return None
-        trace_ids = project.get_trace_ids(
-            start_time=time_range.start if time_range else None,
-            stop_time=time_range.end if time_range else None,
-            trace_ids=eval_trace_ids,
+        return await info.context.data_loaders.evaluation_summaries.load(
+            ("trace", self.id_attr, time_range, None, evaluation_name),
         )
-        evaluations = tuple(
-            evaluation
-            for trace_id in trace_ids
-            if (
-                evaluation := project.get_trace_evaluation(
-                    trace_id,
-                    evaluation_name,
-                )
-            )
-            is not None
-        )
-        if not evaluations:
-            return None
-        labels = project.get_trace_evaluation_labels(evaluation_name)
-        return EvaluationSummary(evaluations, labels)

     @strawberry.field
-    def span_evaluation_summary(
+    async def span_evaluation_summary(
         self,
+        info: Info[Context, None],
         evaluation_name: str,
         time_range: Optional[TimeRange] = UNSET,
         filter_condition: Optional[str] = UNSET,
     ) -> Optional[EvaluationSummary]:
-        project = self.project
-        predicate = (
-            SpanFilter(
-                condition=filter_condition,
-                evals=project,
-            )
-            if filter_condition
-            else None
-        )
-        span_ids = project.get_span_evaluation_span_ids(evaluation_name)
-        if not span_ids:
-            return None
-        spans = project.get_spans(
-            start_time=time_range.start if time_range else None,
-            stop_time=time_range.end if time_range else None,
-            span_ids=span_ids,
-        )
-        if predicate:
-            spans = filter(predicate, spans)
-        evaluations = tuple(
-            evaluation
-            for span in spans
-            if (
-                evaluation := project.get_span_evaluation(
-                    span.context.span_id,
-                    evaluation_name,
-                )
-            )
-            is not None
+        return await info.context.data_loaders.evaluation_summaries.load(
+            ("span", self.id_attr, time_range, filter_condition, evaluation_name),
         )
-        if not evaluations:
-            return None
-        labels = project.get_span_evaluation_labels(evaluation_name)
-        return EvaluationSummary(evaluations, labels)

     @strawberry.field
-    def document_evaluation_summary(
+    async def document_evaluation_summary(
         self,
+        info: Info[Context, None],
         evaluation_name: str,
         time_range: Optional[TimeRange] = UNSET,
         filter_condition: Optional[str] = UNSET,
     ) -> Optional[DocumentEvaluationSummary]:
-        project = self.project
-        predicate = (
-            SpanFilter(condition=filter_condition, evals=project) if filter_condition else None
-        )
-        span_ids = project.get_document_evaluation_span_ids(evaluation_name)
-        if not span_ids:
-            return None
-        spans = project.get_spans(
-            start_time=time_range.start if time_range else None,
-            stop_time=time_range.end if time_range else None,
-            span_ids=span_ids,
-        )
-        if predicate:
-            spans = filter(predicate, spans)
-        metrics_collection = []
-        for span in spans:
-            span_id = span.context.span_id
-            num_documents = project.get_num_documents(span_id)
-            if not num_documents:
-                continue
-            evaluation_scores = project.get_document_evaluation_scores(
-                span_id=span_id,
-                evaluation_name=evaluation_name,
-                num_documents=num_documents,
-            )
-            metrics_collection.append(RetrievalMetrics(evaluation_scores))
-        if not metrics_collection:
-            return None
-        return DocumentEvaluationSummary(
-            evaluation_name=evaluation_name,
-            metrics_collection=metrics_collection,
+        return await info.context.data_loaders.document_evaluation_summaries.load(
+            (self.id_attr, time_range, filter_condition, evaluation_name),
         )

     @strawberry.field
     def streaming_last_updated_at(
         self,
+        info: Info[Context, None],
     ) -> Optional[datetime]:
-        return self.project.last_updated_at
+        return info.context.streaming_last_updated_at(self.id_attr)

     @strawberry.field
-    def validate_span_filter_condition(self, condition: str) -> ValidationResult:
-        valid_eval_names = self.project.get_span_evaluation_names()
+    async def validate_span_filter_condition(self, condition: str) -> ValidationResult:
+        # This query is too expensive to run on every validation
+        # valid_eval_names = await self.span_evaluation_names()
         try:
             SpanFilter(
                 condition=condition,
-                evals=self.project,
-                valid_eval_names=valid_eval_names,
+                # valid_eval_names=valid_eval_names,
             )
             return ValidationResult(is_valid=True, error_message=None)
         except SyntaxError as e:
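
The rewritten spans resolver pages through results with a keyset cursor rather than an offset: rows are ordered by (sort column, rowid), the query overfetches by one row to detect a next page, and the next request resumes with a row-value comparison against the cursor. Below is a standalone sketch of that pattern; the Item table and its latency_ms column are made up for illustration and are not Phoenix's models.

import operator

from sqlalchemy import Column, Integer, asc, create_engine, desc, select, tuple_
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Item(Base):  # hypothetical table standing in for models.Span
    __tablename__ = "items"
    id = Column(Integer, primary_key=True)
    latency_ms = Column(Integer, nullable=False)


def page(session, first, after=None, descending=True):
    """Fetch `first` rows strictly past the (latency_ms, id) cursor `after`."""
    order = desc if descending else asc
    compare = operator.lt if descending else operator.gt
    stmt = select(Item).order_by(order(Item.latency_ms), order(Item.id))
    if after is not None:
        # Keyset predicate: resume past the cursor row using the full sort key.
        stmt = stmt.where(compare(tuple_(Item.latency_ms, Item.id), after))
    stmt = stmt.limit(first + 1)  # overfetch by one to detect a next page
    rows = session.scalars(stmt).all()
    items, has_next = rows[:first], len(rows) > first
    cursor = (items[-1].latency_ms, items[-1].id) if items else None
    return items, cursor, has_next


if __name__ == "__main__":
    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    with Session(engine) as session:
        session.add_all(Item(latency_ms=ms) for ms in (5, 40, 40, 7, 12))
        session.commit()
        items, cursor, has_next = page(session, first=2)
        while items:
            print([(row.id, row.latency_ms) for row in items])
            if not has_next:
                break
            items, cursor, has_next = page(session, first=2, after=cursor)

Unlike offset pagination, the keyset predicate stays cheap as the cursor moves deeper into the result set, which matters once spans live in SQL instead of memory.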