arize-phoenix 3.25.0__py3-none-any.whl → 4.0.1__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (113)
  1. {arize_phoenix-3.25.0.dist-info → arize_phoenix-4.0.1.dist-info}/METADATA +26 -4
  2. {arize_phoenix-3.25.0.dist-info → arize_phoenix-4.0.1.dist-info}/RECORD +80 -75
  3. phoenix/__init__.py +9 -5
  4. phoenix/config.py +109 -53
  5. phoenix/datetime_utils.py +18 -1
  6. phoenix/db/README.md +25 -0
  7. phoenix/db/__init__.py +4 -0
  8. phoenix/db/alembic.ini +119 -0
  9. phoenix/db/bulk_inserter.py +206 -0
  10. phoenix/db/engines.py +152 -0
  11. phoenix/db/helpers.py +47 -0
  12. phoenix/db/insertion/evaluation.py +209 -0
  13. phoenix/db/insertion/helpers.py +51 -0
  14. phoenix/db/insertion/span.py +142 -0
  15. phoenix/db/migrate.py +71 -0
  16. phoenix/db/migrations/env.py +121 -0
  17. phoenix/db/migrations/script.py.mako +26 -0
  18. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +280 -0
  19. phoenix/db/models.py +371 -0
  20. phoenix/exceptions.py +5 -1
  21. phoenix/server/api/context.py +40 -3
  22. phoenix/server/api/dataloaders/__init__.py +97 -0
  23. phoenix/server/api/dataloaders/cache/__init__.py +3 -0
  24. phoenix/server/api/dataloaders/cache/two_tier_cache.py +67 -0
  25. phoenix/server/api/dataloaders/document_evaluation_summaries.py +152 -0
  26. phoenix/server/api/dataloaders/document_evaluations.py +37 -0
  27. phoenix/server/api/dataloaders/document_retrieval_metrics.py +98 -0
  28. phoenix/server/api/dataloaders/evaluation_summaries.py +151 -0
  29. phoenix/server/api/dataloaders/latency_ms_quantile.py +198 -0
  30. phoenix/server/api/dataloaders/min_start_or_max_end_times.py +93 -0
  31. phoenix/server/api/dataloaders/record_counts.py +125 -0
  32. phoenix/server/api/dataloaders/span_descendants.py +64 -0
  33. phoenix/server/api/dataloaders/span_evaluations.py +37 -0
  34. phoenix/server/api/dataloaders/token_counts.py +138 -0
  35. phoenix/server/api/dataloaders/trace_evaluations.py +37 -0
  36. phoenix/server/api/input_types/SpanSort.py +138 -68
  37. phoenix/server/api/routers/v1/__init__.py +11 -0
  38. phoenix/server/api/routers/v1/evaluations.py +275 -0
  39. phoenix/server/api/routers/v1/spans.py +126 -0
  40. phoenix/server/api/routers/v1/traces.py +82 -0
  41. phoenix/server/api/schema.py +112 -48
  42. phoenix/server/api/types/DocumentEvaluationSummary.py +1 -1
  43. phoenix/server/api/types/Evaluation.py +29 -12
  44. phoenix/server/api/types/EvaluationSummary.py +29 -44
  45. phoenix/server/api/types/MimeType.py +2 -2
  46. phoenix/server/api/types/Model.py +9 -9
  47. phoenix/server/api/types/Project.py +240 -171
  48. phoenix/server/api/types/Span.py +87 -131
  49. phoenix/server/api/types/Trace.py +29 -20
  50. phoenix/server/api/types/pagination.py +151 -10
  51. phoenix/server/app.py +263 -35
  52. phoenix/server/grpc_server.py +93 -0
  53. phoenix/server/main.py +75 -60
  54. phoenix/server/openapi/docs.py +218 -0
  55. phoenix/server/prometheus.py +23 -7
  56. phoenix/server/static/index.js +662 -643
  57. phoenix/server/telemetry.py +68 -0
  58. phoenix/services.py +4 -0
  59. phoenix/session/client.py +34 -30
  60. phoenix/session/data_extractor.py +8 -3
  61. phoenix/session/session.py +176 -155
  62. phoenix/settings.py +13 -0
  63. phoenix/trace/attributes.py +349 -0
  64. phoenix/trace/dsl/README.md +116 -0
  65. phoenix/trace/dsl/filter.py +660 -192
  66. phoenix/trace/dsl/helpers.py +24 -5
  67. phoenix/trace/dsl/query.py +562 -185
  68. phoenix/trace/fixtures.py +69 -7
  69. phoenix/trace/otel.py +44 -200
  70. phoenix/trace/schemas.py +14 -8
  71. phoenix/trace/span_evaluations.py +5 -2
  72. phoenix/utilities/__init__.py +0 -26
  73. phoenix/utilities/span_store.py +0 -23
  74. phoenix/version.py +1 -1
  75. phoenix/core/project.py +0 -773
  76. phoenix/core/traces.py +0 -96
  77. phoenix/datasets/dataset.py +0 -214
  78. phoenix/datasets/fixtures.py +0 -24
  79. phoenix/datasets/schema.py +0 -31
  80. phoenix/experimental/evals/__init__.py +0 -73
  81. phoenix/experimental/evals/evaluators.py +0 -413
  82. phoenix/experimental/evals/functions/__init__.py +0 -4
  83. phoenix/experimental/evals/functions/classify.py +0 -453
  84. phoenix/experimental/evals/functions/executor.py +0 -353
  85. phoenix/experimental/evals/functions/generate.py +0 -138
  86. phoenix/experimental/evals/functions/processing.py +0 -76
  87. phoenix/experimental/evals/models/__init__.py +0 -14
  88. phoenix/experimental/evals/models/anthropic.py +0 -175
  89. phoenix/experimental/evals/models/base.py +0 -170
  90. phoenix/experimental/evals/models/bedrock.py +0 -221
  91. phoenix/experimental/evals/models/litellm.py +0 -134
  92. phoenix/experimental/evals/models/openai.py +0 -453
  93. phoenix/experimental/evals/models/rate_limiters.py +0 -246
  94. phoenix/experimental/evals/models/vertex.py +0 -173
  95. phoenix/experimental/evals/models/vertexai.py +0 -186
  96. phoenix/experimental/evals/retrievals.py +0 -96
  97. phoenix/experimental/evals/templates/__init__.py +0 -50
  98. phoenix/experimental/evals/templates/default_templates.py +0 -472
  99. phoenix/experimental/evals/templates/template.py +0 -195
  100. phoenix/experimental/evals/utils/__init__.py +0 -172
  101. phoenix/experimental/evals/utils/threads.py +0 -27
  102. phoenix/server/api/routers/evaluation_handler.py +0 -110
  103. phoenix/server/api/routers/span_handler.py +0 -70
  104. phoenix/server/api/routers/trace_handler.py +0 -60
  105. phoenix/storage/span_store/__init__.py +0 -23
  106. phoenix/storage/span_store/text_file.py +0 -85
  107. phoenix/trace/dsl/missing.py +0 -60
  108. {arize_phoenix-3.25.0.dist-info → arize_phoenix-4.0.1.dist-info}/WHEEL +0 -0
  109. {arize_phoenix-3.25.0.dist-info → arize_phoenix-4.0.1.dist-info}/licenses/IP_NOTICE +0 -0
  110. {arize_phoenix-3.25.0.dist-info → arize_phoenix-4.0.1.dist-info}/licenses/LICENSE +0 -0
  111. /phoenix/{datasets → db/insertion}/__init__.py +0 -0
  112. /phoenix/{experimental → db/migrations}/__init__.py +0 -0
  113. /phoenix/{storage → server/openapi}/__init__.py +0 -0
@@ -1,8 +1,7 @@
1
1
  import json
2
- from collections import defaultdict
3
2
  from datetime import datetime
4
3
  from enum import Enum
5
- from typing import Any, DefaultDict, Dict, Iterable, List, Mapping, Optional, Sized, cast
4
+ from typing import Any, List, Mapping, Optional, Sized, cast
6
5
 
7
6
  import numpy as np
8
7
  import strawberry
@@ -11,13 +10,13 @@ from strawberry import ID, UNSET
11
10
  from strawberry.types import Info
12
11
 
13
12
  import phoenix.trace.schemas as trace_schema
14
- from phoenix.core.project import Project, WrappedSpan
15
- from phoenix.metrics.retrieval_metrics import RetrievalMetrics
13
+ from phoenix.db import models
16
14
  from phoenix.server.api.context import Context
17
15
  from phoenix.server.api.types.DocumentRetrievalMetrics import DocumentRetrievalMetrics
18
16
  from phoenix.server.api.types.Evaluation import DocumentEvaluation, SpanEvaluation
19
17
  from phoenix.server.api.types.MimeType import MimeType
20
- from phoenix.trace.schemas import ComputedAttributes, SpanID
18
+ from phoenix.server.api.types.node import Node
19
+ from phoenix.trace.attributes import get_attribute_value
21
20
 
22
21
  EMBEDDING_EMBEDDINGS = SpanAttributes.EMBEDDING_EMBEDDINGS
23
22
  EMBEDDING_VECTOR = EmbeddingAttributes.EMBEDDING_VECTOR
@@ -40,18 +39,20 @@ class SpanKind(Enum):
40
39
  NB: this is actively under construction
41
40
  """
42
41
 
43
- chain = trace_schema.SpanKind.CHAIN
44
- tool = trace_schema.SpanKind.TOOL
45
- llm = trace_schema.SpanKind.LLM
46
- retriever = trace_schema.SpanKind.RETRIEVER
47
- embedding = trace_schema.SpanKind.EMBEDDING
48
- agent = trace_schema.SpanKind.AGENT
49
- reranker = trace_schema.SpanKind.RERANKER
50
- unknown = trace_schema.SpanKind.UNKNOWN
42
+ chain = "CHAIN"
43
+ tool = "TOOL"
44
+ llm = "LLM"
45
+ retriever = "RETRIEVER"
46
+ embedding = "EMBEDDING"
47
+ agent = "AGENT"
48
+ reranker = "RERANKER"
49
+ unknown = "UNKNOWN"
51
50
 
52
51
  @classmethod
53
52
  def _missing_(cls, v: Any) -> Optional["SpanKind"]:
54
- return None if v else cls.unknown
53
+ if v and isinstance(v, str) and v.isascii() and not v.isupper():
54
+ return cls(v.upper())
55
+ return cls.unknown
55
56
 
56
57
 
57
58
  @strawberry.type
@@ -65,12 +66,18 @@ class SpanIOValue:
65
66
  mime_type: MimeType
66
67
  value: str
67
68
 
69
+ @strawberry.field(
70
+ description="Truncate value up to `chars` characters, appending '...' if truncated.",
71
+ ) # type: ignore
72
+ def truncated_value(self, chars: int = 100) -> str:
73
+ return f"{self.value[: max(0, chars - 3)]}..." if len(self.value) > chars else self.value
74
+
68
75
 
69
76
  @strawberry.enum
70
77
  class SpanStatusCode(Enum):
71
- OK = trace_schema.SpanStatusCode.OK
72
- ERROR = trace_schema.SpanStatusCode.ERROR
73
- UNSET = trace_schema.SpanStatusCode.UNSET
78
+ OK = "OK"
79
+ ERROR = "ERROR"
80
+ UNSET = "UNSET"
74
81
 
75
82
  @classmethod
76
83
  def _missing_(cls, v: Any) -> Optional["SpanStatusCode"]:
@@ -84,19 +91,18 @@ class SpanEvent:
84
91
  timestamp: datetime
85
92
 
86
93
  @staticmethod
87
- def from_event(
88
- event: trace_schema.SpanEvent,
94
+ def from_dict(
95
+ event: Mapping[str, Any],
89
96
  ) -> "SpanEvent":
90
97
  return SpanEvent(
91
- name=event.name,
92
- message=cast(str, event.attributes.get(trace_schema.EXCEPTION_MESSAGE) or ""),
93
- timestamp=event.timestamp,
98
+ name=event["name"],
99
+ message=cast(str, event["attributes"].get(trace_schema.EXCEPTION_MESSAGE) or ""),
100
+ timestamp=datetime.fromisoformat(event["timestamp"]),
94
101
  )
95
102
 
96
103
 
97
104
  @strawberry.type
98
- class Span:
99
- project: strawberry.Private[Project]
105
+ class Span(Node):
100
106
  name: str
101
107
  status_code: SpanStatusCode
102
108
  status_message: str
@@ -143,12 +149,8 @@ class Span:
143
149
  "an LLM, an evaluation may assess the helpfulness of its response with "
144
150
  "respect to its input."
145
151
  ) # type: ignore
146
- def span_evaluations(self) -> List[SpanEvaluation]:
147
- span_id = SpanID(str(self.context.span_id))
148
- return [
149
- SpanEvaluation.from_pb_evaluation(evaluation)
150
- for evaluation in self.project.get_evaluations_by_span_id(span_id)
151
- ]
152
+ async def span_evaluations(self, info: Info[Context, None]) -> List[SpanEvaluation]:
153
+ return await info.context.data_loaders.span_evaluations.load(self.id_attr)
152
154
 
153
155
  @strawberry.field(
154
156
  description="Evaluations of the documents associated with the span, e.g. "
@@ -158,68 +160,43 @@ class Span:
158
160
  "a list, and each evaluation is identified by its document's (zero-based) "
159
161
  "index in that list."
160
162
  ) # type: ignore
161
- def document_evaluations(self) -> List[DocumentEvaluation]:
162
- span_id = SpanID(str(self.context.span_id))
163
- return [
164
- DocumentEvaluation.from_pb_evaluation(evaluation)
165
- for evaluation in self.project.get_document_evaluations_by_span_id(span_id)
166
- ]
163
+ async def document_evaluations(self, info: Info[Context, None]) -> List[DocumentEvaluation]:
164
+ return await info.context.data_loaders.document_evaluations.load(self.id_attr)
167
165
 
168
166
  @strawberry.field(
169
167
  description="Retrieval metrics: NDCG@K, Precision@K, Reciprocal Rank, etc.",
170
168
  ) # type: ignore
171
- def document_retrieval_metrics(
169
+ async def document_retrieval_metrics(
172
170
  self,
171
+ info: Info[Context, None],
173
172
  evaluation_name: Optional[str] = UNSET,
174
173
  ) -> List[DocumentRetrievalMetrics]:
175
174
  if not self.num_documents:
176
175
  return []
177
- span_id = SpanID(str(self.context.span_id))
178
- all_document_evaluation_names = self.project.get_document_evaluation_names(span_id)
179
- if not all_document_evaluation_names:
180
- return []
181
- if evaluation_name is UNSET:
182
- evaluation_names = all_document_evaluation_names
183
- elif evaluation_name not in all_document_evaluation_names:
184
- return []
185
- else:
186
- evaluation_names = [evaluation_name]
187
- retrieval_metrics = []
188
- for name in evaluation_names:
189
- evaluation_scores = self.project.get_document_evaluation_scores(
190
- span_id=span_id,
191
- evaluation_name=name,
192
- num_documents=self.num_documents,
193
- )
194
- retrieval_metrics.append(
195
- DocumentRetrievalMetrics(
196
- evaluation_name=name,
197
- metrics=RetrievalMetrics(evaluation_scores),
198
- )
199
- )
200
- return retrieval_metrics
176
+ return await info.context.data_loaders.document_retrieval_metrics.load(
177
+ (self.id_attr, evaluation_name or None, self.num_documents),
178
+ )
201
179
 
202
180
  @strawberry.field(
203
181
  description="All descendant spans (children, grandchildren, etc.)",
204
182
  ) # type: ignore
205
- def descendants(
183
+ async def descendants(
206
184
  self,
207
185
  info: Info[Context, None],
208
186
  ) -> List["Span"]:
209
- return [
210
- to_gql_span(span, self.project)
211
- for span in self.project.get_descendant_spans(SpanID(self.context.span_id))
212
- ]
187
+ span_id = str(self.context.span_id)
188
+ spans = await info.context.data_loaders.span_descendants.load(span_id)
189
+ return [to_gql_span(span) for span in spans]
213
190
 
214
191
 
215
- def to_gql_span(span: WrappedSpan, project: Project) -> "Span":
216
- events: List[SpanEvent] = list(map(SpanEvent.from_event, span.events))
217
- input_value = cast(Optional[str], span.attributes.get(INPUT_VALUE))
218
- output_value = cast(Optional[str], span.attributes.get(OUTPUT_VALUE))
219
- retrieval_documents = span.attributes.get(RETRIEVAL_DOCUMENTS)
192
+ def to_gql_span(span: models.Span) -> Span:
193
+ events: List[SpanEvent] = list(map(SpanEvent.from_dict, span.events))
194
+ input_value = cast(Optional[str], get_attribute_value(span.attributes, INPUT_VALUE))
195
+ output_value = cast(Optional[str], get_attribute_value(span.attributes, OUTPUT_VALUE))
196
+ retrieval_documents = get_attribute_value(span.attributes, RETRIEVAL_DOCUMENTS)
220
197
  num_documents = len(retrieval_documents) if isinstance(retrieval_documents, Sized) else None
221
198
  return Span(
222
- project=project,
199
+ id_attr=span.id,
223
200
  name=span.name,
224
201
  status_code=SpanStatusCode(span.status_code),
225
202
  status_message=span.status_message,
@@ -227,50 +204,39 @@ def to_gql_span(span: WrappedSpan, project: Project) -> "Span":
227
204
  span_kind=SpanKind(span.span_kind),
228
205
  start_time=span.start_time,
229
206
  end_time=span.end_time,
230
- latency_ms=cast(Optional[float], span[ComputedAttributes.LATENCY_MS]),
207
+ latency_ms=span.latency_ms,
231
208
  context=SpanContext(
232
- trace_id=cast(ID, span.context.trace_id),
233
- span_id=cast(ID, span.context.span_id),
234
- ),
235
- attributes=json.dumps(
236
- _nested_attributes(_hide_embedding_vectors(span.attributes)),
237
- cls=_JSONEncoder,
209
+ trace_id=cast(ID, span.trace.trace_id),
210
+ span_id=cast(ID, span.span_id),
238
211
  ),
239
- metadata=_convert_metadata_to_string(span.attributes.get(METADATA)),
212
+ attributes=json.dumps(_hide_embedding_vectors(span.attributes), cls=_JSONEncoder),
213
+ metadata=_convert_metadata_to_string(get_attribute_value(span.attributes, METADATA)),
240
214
  num_documents=num_documents,
241
215
  token_count_total=cast(
242
216
  Optional[int],
243
- span.attributes.get(LLM_TOKEN_COUNT_TOTAL),
217
+ get_attribute_value(span.attributes, LLM_TOKEN_COUNT_TOTAL),
244
218
  ),
245
219
  token_count_prompt=cast(
246
220
  Optional[int],
247
- span.attributes.get(LLM_TOKEN_COUNT_PROMPT),
221
+ get_attribute_value(span.attributes, LLM_TOKEN_COUNT_PROMPT),
248
222
  ),
249
223
  token_count_completion=cast(
250
224
  Optional[int],
251
- span.attributes.get(LLM_TOKEN_COUNT_COMPLETION),
252
- ),
253
- cumulative_token_count_total=cast(
254
- Optional[int],
255
- span[ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_TOTAL],
256
- ),
257
- cumulative_token_count_prompt=cast(
258
- Optional[int],
259
- span[ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_PROMPT],
260
- ),
261
- cumulative_token_count_completion=cast(
262
- Optional[int],
263
- span[ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_COMPLETION],
225
+ get_attribute_value(span.attributes, LLM_TOKEN_COUNT_COMPLETION),
264
226
  ),
227
+ cumulative_token_count_total=span.cumulative_llm_token_count_prompt
228
+ + span.cumulative_llm_token_count_completion,
229
+ cumulative_token_count_prompt=span.cumulative_llm_token_count_prompt,
230
+ cumulative_token_count_completion=span.cumulative_llm_token_count_completion,
265
231
  propagated_status_code=(
266
232
  SpanStatusCode.ERROR
267
- if span[ComputedAttributes.CUMULATIVE_ERROR_COUNT]
233
+ if span.cumulative_error_count
268
234
  else SpanStatusCode(span.status_code)
269
235
  ),
270
236
  events=events,
271
237
  input=(
272
238
  SpanIOValue(
273
- mime_type=MimeType(span.attributes.get(INPUT_MIME_TYPE)),
239
+ mime_type=MimeType(get_attribute_value(span.attributes, INPUT_MIME_TYPE)),
274
240
  value=input_value,
275
241
  )
276
242
  if input_value is not None
@@ -278,7 +244,7 @@ def to_gql_span(span: WrappedSpan, project: Project) -> "Span":
278
244
  ),
279
245
  output=(
280
246
  SpanIOValue(
281
- mime_type=MimeType(span.attributes.get(OUTPUT_MIME_TYPE)),
247
+ mime_type=MimeType(get_attribute_value(span.attributes, OUTPUT_MIME_TYPE)),
282
248
  value=output_value,
283
249
  )
284
250
  if output_value is not None
@@ -287,6 +253,29 @@ def to_gql_span(span: WrappedSpan, project: Project) -> "Span":
287
253
  )
288
254
 
289
255
 
256
+ def _hide_embedding_vectors(attributes: Mapping[str, Any]) -> Mapping[str, Any]:
257
+ if not (
258
+ isinstance(em := attributes.get("embedding"), dict)
259
+ and isinstance(embeddings := em.get("embeddings"), list)
260
+ and embeddings
261
+ ):
262
+ return attributes
263
+ embeddings = embeddings.copy()
264
+ for i, embedding in enumerate(embeddings):
265
+ if not (
266
+ isinstance(embedding, dict)
267
+ and isinstance(emb := embedding.get("embedding"), dict)
268
+ and isinstance(vector := emb.get("vector"), list)
269
+ and vector
270
+ ):
271
+ continue
272
+ embeddings[i] = {
273
+ **embedding,
274
+ "embedding": {**emb, "vector": f"<{len(vector)} dimensional vector>"},
275
+ }
276
+ return {**attributes, "embedding": {**em, "embeddings": embeddings}}
277
+
278
+
290
279
  class _JSONEncoder(json.JSONEncoder):
291
280
  def default(self, obj: Any) -> Any:
292
281
  if isinstance(obj, datetime):
@@ -302,39 +291,6 @@ class _JSONEncoder(json.JSONEncoder):
302
291
  return super().default(obj)
303
292
 
304
293
 
305
- def _trie() -> DefaultDict[str, Any]:
306
- return defaultdict(_trie)
307
-
308
-
309
- def _nested_attributes(
310
- attributes: Mapping[str, Any],
311
- ) -> DefaultDict[str, Any]:
312
- nested_attributes = _trie()
313
- for attribute_name, attribute_value in attributes.items():
314
- trie = nested_attributes
315
- keys = attribute_name.split(".")
316
- for key in keys[:-1]:
317
- trie = trie[key]
318
- trie[keys[-1]] = attribute_value
319
- return nested_attributes
320
-
321
-
322
- def _hide_embedding_vectors(
323
- attributes: Mapping[str, Any],
324
- ) -> Dict[str, Any]:
325
- _attributes = dict(attributes)
326
- if not isinstance((embeddings := _attributes.get(EMBEDDING_EMBEDDINGS)), Iterable):
327
- return _attributes
328
- _embeddings = []
329
- for embedding in embeddings:
330
- _embedding = dict(embedding)
331
- if isinstance((vector := _embedding.get(EMBEDDING_VECTOR)), Sized):
332
- _embedding[EMBEDDING_VECTOR] = f"<{len(vector)} dimensional vector>"
333
- _embeddings.append(_embedding)
334
- _attributes[EMBEDDING_EMBEDDINGS] = _embeddings
335
- return _attributes
336
-
337
-
338
294
  def _convert_metadata_to_string(metadata: Any) -> Optional[str]:
339
295
  """
340
296
  Converts metadata to a string representation.
@@ -1,47 +1,56 @@
1
1
  from typing import List, Optional
2
2
 
3
3
  import strawberry
4
- from strawberry import ID, UNSET, Private
4
+ from sqlalchemy import desc, select
5
+ from sqlalchemy.orm import contains_eager
6
+ from strawberry import UNSET
7
+ from strawberry.types import Info
5
8
 
6
- from phoenix.core.project import Project
9
+ from phoenix.db import models
10
+ from phoenix.server.api.context import Context
7
11
  from phoenix.server.api.types.Evaluation import TraceEvaluation
12
+ from phoenix.server.api.types.node import Node
8
13
  from phoenix.server.api.types.pagination import (
9
14
  Connection,
10
15
  ConnectionArgs,
11
- Cursor,
16
+ CursorString,
12
17
  connection_from_list,
13
18
  )
14
19
  from phoenix.server.api.types.Span import Span, to_gql_span
15
- from phoenix.trace.schemas import TraceID
16
20
 
17
21
 
18
22
  @strawberry.type
19
- class Trace:
20
- trace_id: ID
21
- project: Private[Project]
22
-
23
+ class Trace(Node):
23
24
  @strawberry.field
24
- def spans(
25
+ async def spans(
25
26
  self,
27
+ info: Info[Context, None],
26
28
  first: Optional[int] = 50,
27
29
  last: Optional[int] = UNSET,
28
- after: Optional[Cursor] = UNSET,
29
- before: Optional[Cursor] = UNSET,
30
+ after: Optional[CursorString] = UNSET,
31
+ before: Optional[CursorString] = UNSET,
30
32
  ) -> Connection[Span]:
31
33
  args = ConnectionArgs(
32
34
  first=first,
33
- after=after if isinstance(after, Cursor) else None,
35
+ after=after if isinstance(after, CursorString) else None,
34
36
  last=last,
35
- before=before if isinstance(before, Cursor) else None,
37
+ before=before if isinstance(before, CursorString) else None,
36
38
  )
37
- spans = sorted(
38
- self.project.get_trace(TraceID(self.trace_id)),
39
- key=lambda span: span.start_time,
39
+ stmt = (
40
+ select(models.Span)
41
+ .join(models.Trace)
42
+ .where(models.Trace.id == self.id_attr)
43
+ .options(contains_eager(models.Span.trace))
44
+ # Sort descending because the root span tends to show up later
45
+ # in the ingestion process.
46
+ .order_by(desc(models.Span.id))
47
+ .limit(first)
40
48
  )
41
- data = [to_gql_span(span, self.project) for span in spans]
49
+ async with info.context.db() as session:
50
+ spans = await session.stream_scalars(stmt)
51
+ data = [to_gql_span(span) async for span in spans]
42
52
  return connection_from_list(data=data, args=args)
43
53
 
44
54
  @strawberry.field(description="Evaluations associated with the trace") # type: ignore
45
- def trace_evaluations(self) -> List[TraceEvaluation]:
46
- evaluations = self.project.get_evaluations_by_trace_id(TraceID(self.trace_id))
47
- return [TraceEvaluation.from_pb_evaluation(evaluation) for evaluation in evaluations]
55
+ async def trace_evaluations(self, info: Info[Context, None]) -> List[TraceEvaluation]:
56
+ return await info.context.data_loaders.trace_evaluations.load(self.id_attr)
@@ -1,11 +1,16 @@
1
1
  import base64
2
2
  from dataclasses import dataclass
3
- from typing import Generic, List, Optional, TypeVar
3
+ from datetime import datetime
4
+ from enum import Enum, auto
5
+ from typing import ClassVar, Generic, List, Optional, Tuple, TypeVar, Union
4
6
 
5
7
  import strawberry
6
8
  from strawberry import UNSET
9
+ from typing_extensions import TypeAlias, assert_never
7
10
 
11
+ ID: TypeAlias = int
8
12
  GenericType = TypeVar("GenericType")
13
+ CursorSortColumnValue: TypeAlias = Union[str, int, float, datetime]
9
14
 
10
15
 
11
16
  @strawberry.type
@@ -35,11 +40,10 @@ class PageInfo:
35
40
  has_previous_page: bool
36
41
  start_cursor: Optional[str]
37
42
  end_cursor: Optional[str]
38
- total_count: int
39
43
 
40
44
 
41
45
  # A type alias for the connection cursor implementation
42
- Cursor = str
46
+ CursorString = str
43
47
 
44
48
 
45
49
  @strawberry.type
@@ -56,14 +60,132 @@ class Edge(Generic[GenericType]):
56
60
  CURSOR_PREFIX = "connection:"
57
61
 
58
62
 
59
- def offset_to_cursor(offset: int) -> Cursor:
63
+ class CursorSortColumnDataType(Enum):
64
+ STRING = auto()
65
+ INT = auto()
66
+ FLOAT = auto()
67
+ DATETIME = auto()
68
+
69
+
70
+ @dataclass
71
+ class CursorSortColumn:
72
+ type: CursorSortColumnDataType
73
+ value: CursorSortColumnValue
74
+
75
+ def __str__(self) -> str:
76
+ if isinstance(self.value, str):
77
+ return self.value
78
+ if isinstance(self.value, (int, float)):
79
+ return str(self.value)
80
+ if isinstance(self.value, datetime):
81
+ return self.value.isoformat()
82
+ assert_never(self.type)
83
+
84
+ @classmethod
85
+ def from_string(cls, type: CursorSortColumnDataType, cursor_string: str) -> "CursorSortColumn":
86
+ value: CursorSortColumnValue
87
+ if type is CursorSortColumnDataType.STRING:
88
+ value = cursor_string
89
+ elif type is CursorSortColumnDataType.INT:
90
+ value = int(cursor_string)
91
+ elif type is CursorSortColumnDataType.FLOAT:
92
+ value = float(cursor_string)
93
+ elif type is CursorSortColumnDataType.DATETIME:
94
+ value = datetime.fromisoformat(cursor_string)
95
+ else:
96
+ assert_never(type)
97
+ return cls(type=type, value=value)
98
+
99
+
100
+ @dataclass
101
+ class Cursor:
102
+ """
103
+ Serializes and deserializes cursor strings for ID-based pagination.
104
+
105
+ In the simplest case, a cursor encodes the rowid of a record. In the case
106
+ that a sort has been applied, the cursor additionally encodes the data type
107
+ and value of the column indexed for sorting so that the sort position can be
108
+ efficiently found. The encoding ensures that the cursor string is opaque to
109
+ the client and discourages the client from making use of the encoded
110
+ content.
111
+
112
+ Examples:
113
+ # encodes "10"
114
+ Cursor(rowid=10)
115
+
116
+ # encodes "11:STRING:abc"
117
+ Cursor(
118
+ rowid=11,
119
+ sort_column=CursorSortColumn(
120
+ type=CursorSortColumnDataType.STRING,
121
+ value="abc"
122
+ )
123
+ )
124
+
125
+ # encodes "10:INT:5"
126
+ Cursor(
127
+ rowid=10,
128
+ sort_column=CursorSortColumn(
129
+ type=CursorSortColumnDataType.INT,
130
+ value=5
131
+ )
132
+ )
133
+
134
+ # encodes "17:FLOAT:5.7"
135
+ Cursor(
136
+ rowid=17,
137
+ sort_column=CursorSortColumn(
138
+ type=CursorSortColumnDataType.FLOAT,
139
+ value=5.7
140
+ )
141
+ )
142
+
143
+ # encodes "20:DATETIME:2024-05-05T04:25:29.911245+00:00"
144
+ Cursor(
145
+ rowid=20,
146
+ sort_column=CursorSortColumn(
147
+ type=CursorSortColumnDataType.DATETIME,
148
+ value=datetime.fromisoformat("2024-05-05T04:25:29.911245+00:00")
149
+ )
150
+ )
151
+ """
152
+
153
+ rowid: int
154
+ sort_column: Optional[CursorSortColumn] = None
155
+
156
+ _DELIMITER: ClassVar[str] = ":"
157
+
158
+ def __str__(self) -> str:
159
+ cursor_parts = [str(self.rowid)]
160
+ if (sort_column := self.sort_column) is not None:
161
+ cursor_parts.extend([sort_column.type.name, str(sort_column)])
162
+ return base64.b64encode(self._DELIMITER.join(cursor_parts).encode()).decode()
163
+
164
+ @classmethod
165
+ def from_string(cls, cursor: str) -> "Cursor":
166
+ decoded = base64.b64decode(cursor).decode()
167
+ rowid_string = decoded
168
+ sort_column = None
169
+ if (first_delimiter_index := decoded.find(cls._DELIMITER)) > -1:
170
+ rowid_string = decoded[:first_delimiter_index]
171
+ second_delimiter_index = decoded.index(cls._DELIMITER, first_delimiter_index + 1)
172
+ sort_column = CursorSortColumn.from_string(
173
+ type=CursorSortColumnDataType[
174
+ decoded[first_delimiter_index + 1 : second_delimiter_index]
175
+ ],
176
+ cursor_string=decoded[second_delimiter_index + 1 :],
177
+ )
178
+ return cls(rowid=int(rowid_string), sort_column=sort_column)
179
+
180
+
181
+ def offset_to_cursor(offset: int) -> CursorString:
60
182
  """
61
183
  Creates the cursor string from an offset.
62
184
  """
63
185
  return base64.b64encode(f"{CURSOR_PREFIX}{offset}".encode("utf-8")).decode()
64
186
 
65
187
 
66
- def cursor_to_offset(cursor: Cursor) -> int:
188
+ def cursor_to_offset(cursor: CursorString) -> int:
67
189
  """
68
190
  Extracts the offset from the cursor string.
69
191
  """
@@ -71,13 +193,13 @@ def cursor_to_offset(cursor: Cursor) -> int:
71
193
  return int(offset)
72
194
 
73
195
 
74
- def get_offset_with_default(cursor: Optional[Cursor], default_offset: int) -> int:
196
+ def get_offset_with_default(cursor: Optional[CursorString], default_offset: int) -> int:
75
197
  """
76
198
  Given an optional cursor and a default offset, returns the offset
77
199
  to use; if the cursor contains a valid offset, that will be used,
78
200
  otherwise it will be the default.
79
201
  """
80
- if not isinstance(cursor, Cursor):
202
+ if not isinstance(cursor, CursorString):
81
203
  return default_offset
82
204
  offset = cursor_to_offset(cursor)
83
205
  return offset if isinstance(offset, int) else default_offset
@@ -90,9 +212,9 @@ class ConnectionArgs:
90
212
  """
91
213
 
92
214
  first: Optional[int] = UNSET
93
- after: Optional[Cursor] = UNSET
215
+ after: Optional[CursorString] = UNSET
94
216
  last: Optional[int] = UNSET
95
- before: Optional[Cursor] = UNSET
217
+ before: Optional[CursorString] = UNSET
96
218
 
97
219
 
98
220
  def connection_from_list(
@@ -169,6 +291,25 @@ def connection_from_list_slice(
169
291
  end_cursor=last_edge.cursor if last_edge else None,
170
292
  has_previous_page=start_offset > lower_bound if isinstance(args.last, int) else False,
171
293
  has_next_page=end_offset < upper_bound if isinstance(args.first, int) else False,
172
- total_count=list_length,
294
+ ),
295
+ )
296
+
297
+
298
+ def connections(
299
+ data: List[Tuple[Cursor, GenericType]],
300
+ has_previous_page: bool,
301
+ has_next_page: bool,
302
+ ) -> Connection[GenericType]:
303
+ edges = [Edge(node=node, cursor=str(cursor)) for cursor, node in data]
304
+ has_edges = len(edges) > 0
305
+ first_edge = edges[0] if has_edges else None
306
+ last_edge = edges[-1] if has_edges else None
307
+ return Connection(
308
+ edges=edges,
309
+ page_info=PageInfo(
310
+ start_cursor=first_edge.cursor if first_edge else None,
311
+ end_cursor=last_edge.cursor if last_edge else None,
312
+ has_previous_page=has_previous_page,
313
+ has_next_page=has_next_page,
173
314
  ),
174
315
  )