arize-phoenix 3.25.0__py3-none-any.whl → 4.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (113)
  1. {arize_phoenix-3.25.0.dist-info → arize_phoenix-4.0.1.dist-info}/METADATA +26 -4
  2. {arize_phoenix-3.25.0.dist-info → arize_phoenix-4.0.1.dist-info}/RECORD +80 -75
  3. phoenix/__init__.py +9 -5
  4. phoenix/config.py +109 -53
  5. phoenix/datetime_utils.py +18 -1
  6. phoenix/db/README.md +25 -0
  7. phoenix/db/__init__.py +4 -0
  8. phoenix/db/alembic.ini +119 -0
  9. phoenix/db/bulk_inserter.py +206 -0
  10. phoenix/db/engines.py +152 -0
  11. phoenix/db/helpers.py +47 -0
  12. phoenix/db/insertion/evaluation.py +209 -0
  13. phoenix/db/insertion/helpers.py +51 -0
  14. phoenix/db/insertion/span.py +142 -0
  15. phoenix/db/migrate.py +71 -0
  16. phoenix/db/migrations/env.py +121 -0
  17. phoenix/db/migrations/script.py.mako +26 -0
  18. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +280 -0
  19. phoenix/db/models.py +371 -0
  20. phoenix/exceptions.py +5 -1
  21. phoenix/server/api/context.py +40 -3
  22. phoenix/server/api/dataloaders/__init__.py +97 -0
  23. phoenix/server/api/dataloaders/cache/__init__.py +3 -0
  24. phoenix/server/api/dataloaders/cache/two_tier_cache.py +67 -0
  25. phoenix/server/api/dataloaders/document_evaluation_summaries.py +152 -0
  26. phoenix/server/api/dataloaders/document_evaluations.py +37 -0
  27. phoenix/server/api/dataloaders/document_retrieval_metrics.py +98 -0
  28. phoenix/server/api/dataloaders/evaluation_summaries.py +151 -0
  29. phoenix/server/api/dataloaders/latency_ms_quantile.py +198 -0
  30. phoenix/server/api/dataloaders/min_start_or_max_end_times.py +93 -0
  31. phoenix/server/api/dataloaders/record_counts.py +125 -0
  32. phoenix/server/api/dataloaders/span_descendants.py +64 -0
  33. phoenix/server/api/dataloaders/span_evaluations.py +37 -0
  34. phoenix/server/api/dataloaders/token_counts.py +138 -0
  35. phoenix/server/api/dataloaders/trace_evaluations.py +37 -0
  36. phoenix/server/api/input_types/SpanSort.py +138 -68
  37. phoenix/server/api/routers/v1/__init__.py +11 -0
  38. phoenix/server/api/routers/v1/evaluations.py +275 -0
  39. phoenix/server/api/routers/v1/spans.py +126 -0
  40. phoenix/server/api/routers/v1/traces.py +82 -0
  41. phoenix/server/api/schema.py +112 -48
  42. phoenix/server/api/types/DocumentEvaluationSummary.py +1 -1
  43. phoenix/server/api/types/Evaluation.py +29 -12
  44. phoenix/server/api/types/EvaluationSummary.py +29 -44
  45. phoenix/server/api/types/MimeType.py +2 -2
  46. phoenix/server/api/types/Model.py +9 -9
  47. phoenix/server/api/types/Project.py +240 -171
  48. phoenix/server/api/types/Span.py +87 -131
  49. phoenix/server/api/types/Trace.py +29 -20
  50. phoenix/server/api/types/pagination.py +151 -10
  51. phoenix/server/app.py +263 -35
  52. phoenix/server/grpc_server.py +93 -0
  53. phoenix/server/main.py +75 -60
  54. phoenix/server/openapi/docs.py +218 -0
  55. phoenix/server/prometheus.py +23 -7
  56. phoenix/server/static/index.js +662 -643
  57. phoenix/server/telemetry.py +68 -0
  58. phoenix/services.py +4 -0
  59. phoenix/session/client.py +34 -30
  60. phoenix/session/data_extractor.py +8 -3
  61. phoenix/session/session.py +176 -155
  62. phoenix/settings.py +13 -0
  63. phoenix/trace/attributes.py +349 -0
  64. phoenix/trace/dsl/README.md +116 -0
  65. phoenix/trace/dsl/filter.py +660 -192
  66. phoenix/trace/dsl/helpers.py +24 -5
  67. phoenix/trace/dsl/query.py +562 -185
  68. phoenix/trace/fixtures.py +69 -7
  69. phoenix/trace/otel.py +44 -200
  70. phoenix/trace/schemas.py +14 -8
  71. phoenix/trace/span_evaluations.py +5 -2
  72. phoenix/utilities/__init__.py +0 -26
  73. phoenix/utilities/span_store.py +0 -23
  74. phoenix/version.py +1 -1
  75. phoenix/core/project.py +0 -773
  76. phoenix/core/traces.py +0 -96
  77. phoenix/datasets/dataset.py +0 -214
  78. phoenix/datasets/fixtures.py +0 -24
  79. phoenix/datasets/schema.py +0 -31
  80. phoenix/experimental/evals/__init__.py +0 -73
  81. phoenix/experimental/evals/evaluators.py +0 -413
  82. phoenix/experimental/evals/functions/__init__.py +0 -4
  83. phoenix/experimental/evals/functions/classify.py +0 -453
  84. phoenix/experimental/evals/functions/executor.py +0 -353
  85. phoenix/experimental/evals/functions/generate.py +0 -138
  86. phoenix/experimental/evals/functions/processing.py +0 -76
  87. phoenix/experimental/evals/models/__init__.py +0 -14
  88. phoenix/experimental/evals/models/anthropic.py +0 -175
  89. phoenix/experimental/evals/models/base.py +0 -170
  90. phoenix/experimental/evals/models/bedrock.py +0 -221
  91. phoenix/experimental/evals/models/litellm.py +0 -134
  92. phoenix/experimental/evals/models/openai.py +0 -453
  93. phoenix/experimental/evals/models/rate_limiters.py +0 -246
  94. phoenix/experimental/evals/models/vertex.py +0 -173
  95. phoenix/experimental/evals/models/vertexai.py +0 -186
  96. phoenix/experimental/evals/retrievals.py +0 -96
  97. phoenix/experimental/evals/templates/__init__.py +0 -50
  98. phoenix/experimental/evals/templates/default_templates.py +0 -472
  99. phoenix/experimental/evals/templates/template.py +0 -195
  100. phoenix/experimental/evals/utils/__init__.py +0 -172
  101. phoenix/experimental/evals/utils/threads.py +0 -27
  102. phoenix/server/api/routers/evaluation_handler.py +0 -110
  103. phoenix/server/api/routers/span_handler.py +0 -70
  104. phoenix/server/api/routers/trace_handler.py +0 -60
  105. phoenix/storage/span_store/__init__.py +0 -23
  106. phoenix/storage/span_store/text_file.py +0 -85
  107. phoenix/trace/dsl/missing.py +0 -60
  108. {arize_phoenix-3.25.0.dist-info → arize_phoenix-4.0.1.dist-info}/WHEEL +0 -0
  109. {arize_phoenix-3.25.0.dist-info → arize_phoenix-4.0.1.dist-info}/licenses/IP_NOTICE +0 -0
  110. {arize_phoenix-3.25.0.dist-info → arize_phoenix-4.0.1.dist-info}/licenses/LICENSE +0 -0
  111. /phoenix/{datasets → db/insertion}/__init__.py +0 -0
  112. /phoenix/{experimental → db/migrations}/__init__.py +0 -0
  113. /phoenix/{storage → server/openapi}/__init__.py +0 -0

phoenix/server/api/routers/v1/traces.py
@@ -0,0 +1,82 @@
+import gzip
+import zlib
+
+from google.protobuf.message import DecodeError
+from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import (
+    ExportTraceServiceRequest,
+)
+from starlette.background import BackgroundTask
+from starlette.concurrency import run_in_threadpool
+from starlette.datastructures import State
+from starlette.requests import Request
+from starlette.responses import Response
+from starlette.status import (
+    HTTP_403_FORBIDDEN,
+    HTTP_415_UNSUPPORTED_MEDIA_TYPE,
+    HTTP_422_UNPROCESSABLE_ENTITY,
+)
+
+from phoenix.trace.otel import decode_otlp_span
+from phoenix.utilities.project import get_project_name
+
+
+async def post_traces(request: Request) -> Response:
+    """
+    summary: Send traces to Phoenix
+    operationId: addTraces
+    tags:
+      - traces
+    requestBody:
+      required: true
+      content:
+        application/x-protobuf:
+          schema:
+            type: string
+            format: binary
+    responses:
+      200:
+        description: Success
+      403:
+        description: Forbidden
+      415:
+        description: Unsupported content type, only gzipped protobuf
+      422:
+        description: Request body is invalid
+    """
+    if request.app.state.read_only:
+        return Response(status_code=HTTP_403_FORBIDDEN)
+    content_type = request.headers.get("content-type")
+    if content_type != "application/x-protobuf":
+        return Response(
+            content=f"Unsupported content type: {content_type}",
+            status_code=HTTP_415_UNSUPPORTED_MEDIA_TYPE,
+        )
+    content_encoding = request.headers.get("content-encoding")
+    if content_encoding and content_encoding not in ("gzip", "deflate"):
+        return Response(
+            content=f"Unsupported content encoding: {content_encoding}",
+            status_code=HTTP_415_UNSUPPORTED_MEDIA_TYPE,
+        )
+    body = await request.body()
+    if content_encoding == "gzip":
+        body = await run_in_threadpool(gzip.decompress, body)
+    elif content_encoding == "deflate":
+        body = await run_in_threadpool(zlib.decompress, body)
+    req = ExportTraceServiceRequest()
+    try:
+        await run_in_threadpool(req.ParseFromString, body)
+    except DecodeError:
+        return Response(
+            content="Request body is invalid ExportTraceServiceRequest",
+            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+        )
+    return Response(background=BackgroundTask(_add_spans, req, request.state))
+
+
+async def _add_spans(req: ExportTraceServiceRequest, state: State) -> None:
+    for resource_spans in req.resource_spans:
+        project_name = get_project_name(resource_spans.resource.attributes)
+        for scope_span in resource_spans.scope_spans:
+            for otlp_span in scope_span.spans:
+                span = await run_in_threadpool(decode_otlp_span, otlp_span)
+                await state.queue_span_for_bulk_insert(span, project_name)
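
For orientation, a minimal client-side sketch of exercising this new endpoint. The /v1/traces path and the localhost:6006 address are assumptions about the deployment; the handler above constrains only the headers and the protobuf payload.

    import gzip

    import requests
    from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import (
        ExportTraceServiceRequest,
    )

    # Populate resource_spans before sending; an empty request is still valid.
    req = ExportTraceServiceRequest()
    body = gzip.compress(req.SerializeToString())
    response = requests.post(
        "http://localhost:6006/v1/traces",  # assumed host, port, and path
        data=body,
        headers={
            "content-type": "application/x-protobuf",  # required by the handler
            "content-encoding": "gzip",  # "deflate" or no encoding also accepted
        },
    )
    response.raise_for_status()  # 403 if read-only, 415/422 on bad input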

phoenix/server/api/schema.py
@@ -4,11 +4,17 @@ from typing import Dict, List, Optional, Set, Union
 import numpy as np
 import numpy.typing as npt
 import strawberry
+from sqlalchemy import delete, select
+from sqlalchemy.orm import contains_eager, load_only
 from strawberry import ID, UNSET
 from strawberry.types import Info
 from typing_extensions import Annotated
 
+from phoenix.config import DEFAULT_PROJECT_NAME
+from phoenix.db import models
+from phoenix.db.insertion.span import ClearProjectSpansEvent
 from phoenix.pointcloud.clustering import Hdbscan
+from phoenix.server.api.context import Context
 from phoenix.server.api.helpers import ensure_list
 from phoenix.server.api.input_types.ClusterInput import ClusterInput
 from phoenix.server.api.input_types.Coordinates import (
@@ -16,56 +22,70 @@ from phoenix.server.api.input_types.Coordinates import (
     InputCoordinate3D,
 )
 from phoenix.server.api.types.Cluster import Cluster, to_gql_clusters
-from phoenix.server.api.types.Project import Project
-
-from .context import Context
-from .types.DatasetRole import AncillaryDatasetRole, DatasetRole
-from .types.Dimension import to_gql_dimension
-from .types.EmbeddingDimension import (
+from phoenix.server.api.types.DatasetRole import AncillaryDatasetRole, DatasetRole
+from phoenix.server.api.types.Dimension import to_gql_dimension
+from phoenix.server.api.types.EmbeddingDimension import (
     DEFAULT_CLUSTER_SELECTION_EPSILON,
     DEFAULT_MIN_CLUSTER_SIZE,
     DEFAULT_MIN_SAMPLES,
     to_gql_embedding_dimension,
 )
-from .types.Event import create_event_id, unpack_event_id
-from .types.ExportEventsMutation import ExportEventsMutation
-from .types.Functionality import Functionality
-from .types.Model import Model
-from .types.node import GlobalID, Node, from_global_id, from_global_id_with_expected_type
-from .types.pagination import Connection, ConnectionArgs, Cursor, connection_from_list
+from phoenix.server.api.types.Event import create_event_id, unpack_event_id
+from phoenix.server.api.types.ExportEventsMutation import ExportEventsMutation
+from phoenix.server.api.types.Functionality import Functionality
+from phoenix.server.api.types.Model import Model
+from phoenix.server.api.types.node import (
+    GlobalID,
+    Node,
+    from_global_id,
+    from_global_id_with_expected_type,
+)
+from phoenix.server.api.types.pagination import (
+    Connection,
+    ConnectionArgs,
+    CursorString,
+    connection_from_list,
+)
+from phoenix.server.api.types.Project import Project
+from phoenix.server.api.types.Span import to_gql_span
+from phoenix.server.api.types.Trace import Trace
 
 
 @strawberry.type
 class Query:
     @strawberry.field
-    def projects(
+    async def projects(
         self,
         info: Info[Context, None],
         first: Optional[int] = 50,
         last: Optional[int] = UNSET,
-        after: Optional[Cursor] = UNSET,
-        before: Optional[Cursor] = UNSET,
+        after: Optional[CursorString] = UNSET,
+        before: Optional[CursorString] = UNSET,
     ) -> Connection[Project]:
         args = ConnectionArgs(
             first=first,
-            after=after if isinstance(after, Cursor) else None,
+            after=after if isinstance(after, CursorString) else None,
             last=last,
-            before=before if isinstance(before, Cursor) else None,
-        )
-        data = (
-            []
-            if (traces := info.context.traces) is None
-            else [
-                Project(id_attr=project_id, name=project_name, project=project)
-                for project_id, project_name, project in traces.get_projects()
-            ]
+            before=before if isinstance(before, CursorString) else None,
         )
+        async with info.context.db() as session:
+            projects = await session.scalars(select(models.Project))
+        data = [
+            Project(
+                id_attr=project.id,
+                name=project.name,
+                gradient_start_color=project.gradient_start_color,
+                gradient_end_color=project.gradient_end_color,
+            )
+            for project in projects
+        ]
         return connection_from_list(data=data, args=args)
 
     @strawberry.field
-    def functionality(self, info: Info[Context, None]) -> "Functionality":
+    async def functionality(self, info: Info[Context, None]) -> "Functionality":
         has_model_inferences = not info.context.model.is_empty
-        has_traces = info.context.traces is not None
+        async with info.context.db() as session:
+            has_traces = (await session.scalar(select(models.Trace).limit(1))) is not None
         return Functionality(
             model_inferences=has_model_inferences,
             tracing=has_traces,
@@ -76,7 +96,7 @@ class Query:
         return Model()
 
     @strawberry.field
-    def node(self, id: GlobalID, info: Info[Context, None]) -> Node:
+    async def node(self, id: GlobalID, info: Info[Context, None]) -> Node:
         type_name, node_id = from_global_id(str(id))
         if type_name == "Dimension":
             dimension = info.context.model.scalar_dimensions[node_id]
@@ -85,17 +105,42 @@ class Query:
             embedding_dimension = info.context.model.embedding_dimensions[node_id]
             return to_gql_embedding_dimension(node_id, embedding_dimension)
         elif type_name == "Project":
-            if (traces := info.context.traces) is not None:
-                projects = {
-                    project_id: (project_name, project)
-                    for project_id, project_name, project in traces.get_projects()
-                }
-                if node_id in projects:
-                    name, project = projects[node_id]
-                    return Project(id_attr=node_id, name=name, project=project)
-            raise Exception(f"Unknown project: {id}")
-
-        raise Exception(f"Unknown node type: {type}")
+            project_stmt = select(
+                models.Project.id,
+                models.Project.name,
+                models.Project.gradient_start_color,
+                models.Project.gradient_end_color,
+            ).where(models.Project.id == node_id)
+            async with info.context.db() as session:
+                project = (await session.execute(project_stmt)).first()
+            if project is None:
+                raise ValueError(f"Unknown project: {id}")
+            return Project(
+                id_attr=project.id,
+                name=project.name,
+                gradient_start_color=project.gradient_start_color,
+                gradient_end_color=project.gradient_end_color,
+            )
+        elif type_name == "Trace":
+            trace_stmt = select(models.Trace.id).where(models.Trace.id == node_id)
+            async with info.context.db() as session:
+                id_attr = await session.scalar(trace_stmt)
+            if id_attr is None:
+                raise ValueError(f"Unknown trace: {id}")
+            return Trace(id_attr=id_attr)
+        elif type_name == "Span":
+            span_stmt = (
+                select(models.Span)
+                .join(models.Trace)
+                .options(contains_eager(models.Span.trace))
+                .where(models.Span.id == node_id)
+            )
+            async with info.context.db() as session:
+                span = await session.scalar(span_stmt)
+            if span is None:
+                raise ValueError(f"Unknown span: {id}")
+            return to_gql_span(span)
+        raise Exception(f"Unknown node type: {type_name}")
 
     @strawberry.field
     def clusters(
@@ -229,18 +274,37 @@ class Query:
 @strawberry.type
 class Mutation(ExportEventsMutation):
     @strawberry.mutation
-    def delete_project(self, info: Info[Context, None], id: GlobalID) -> Query:
-        if (traces := info.context.traces) is not None:
-            node_id = from_global_id_with_expected_type(str(id), "Project")
-            traces.archive_project(node_id)
+    async def delete_project(self, info: Info[Context, None], id: GlobalID) -> Query:
+        node_id = from_global_id_with_expected_type(str(id), "Project")
+        async with info.context.db() as session:
+            project = await session.scalar(
+                select(models.Project)
+                .where(models.Project.id == node_id)
+                .options(load_only(models.Project.name))
+            )
+            if project is None:
+                raise ValueError(f"Unknown project: {id}")
+            if project.name == DEFAULT_PROJECT_NAME:
+                raise ValueError(f"Cannot delete the {DEFAULT_PROJECT_NAME} project")
+            await session.delete(project)
         return Query()
 
     @strawberry.mutation
-    def archive_project(self, info: Info[Context, None], id: GlobalID) -> Query:
-        if (traces := info.context.traces) is not None:
-            node_id = from_global_id_with_expected_type(str(id), "Project")
-            traces.archive_project(node_id)
+    async def clear_project(self, info: Info[Context, None], id: GlobalID) -> Query:
+        project_id = from_global_id_with_expected_type(str(id), "Project")
+        delete_statement = delete(models.Trace).where(models.Trace.project_rowid == project_id)
+        async with info.context.db() as session:
+            await session.execute(delete_statement)
+        if cache := info.context.cache_for_dataloaders:
+            cache.invalidate(ClearProjectSpansEvent(project_rowid=project_id))
        return Query()
 
 
-schema = strawberry.Schema(query=Query, mutation=Mutation)
+# This is the schema for generating `schema.graphql`.
+# See https://strawberry.rocks/docs/guides/schema-export
+# It should be kept in sync with the server's runtime-initialized
+# instance. To do so, search for the usage of `strawberry.Schema(...)`.
+schema = strawberry.Schema(
+    query=Query,
+    mutation=Mutation,
+)
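
To make the now-async resolvers concrete, here is a minimal sketch of querying the projects connection over HTTP. The /graphql path and localhost:6006 address are assumptions about the deployment, and the camelCased field names follow strawberry's default naming of the Project attributes shown above.

    import requests

    query = """
    query {
      projects(first: 10) {
        edges {
          node {
            name
            gradientStartColor
            gradientEndColor
          }
        }
      }
    }
    """
    resp = requests.post(
        "http://localhost:6006/graphql",  # assumed endpoint path and port
        json={"query": query},
    )
    resp.raise_for_status()
    print(resp.json()["data"]["projects"]["edges"])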

phoenix/server/api/types/DocumentEvaluationSummary.py
@@ -80,7 +80,7 @@ class DocumentEvaluationSummary:
             return result
         values = self.metrics_collection.apply(lambda m: m.precision(k))
         result = (values.mean(), values.count())
-        self._cached_average_ndcg_results[k] = result
+        self._cached_average_precision_results[k] = result
         return result
 
     @cached_property

phoenix/server/api/types/Evaluation.py
@@ -3,7 +3,7 @@ from typing import Optional
 import strawberry
 
 import phoenix.trace.v1 as pb
-from phoenix.trace.schemas import SpanID, TraceID
+from phoenix.db.models import DocumentAnnotation, SpanAnnotation, TraceAnnotation
 
 
 @strawberry.interface
@@ -26,47 +26,56 @@ class Evaluation:
 
 @strawberry.type
 class TraceEvaluation(Evaluation):
-    trace_id: strawberry.Private[TraceID]
-
     @staticmethod
     def from_pb_evaluation(evaluation: pb.Evaluation) -> "TraceEvaluation":
         result = evaluation.result
         score = result.score.value if result.HasField("score") else None
         label = result.label.value if result.HasField("label") else None
         explanation = result.explanation.value if result.HasField("explanation") else None
-        trace_id = TraceID(evaluation.subject_id.trace_id)
         return TraceEvaluation(
             name=evaluation.name,
             score=score,
             label=label,
             explanation=explanation,
-            trace_id=trace_id,
+        )
+
+    @staticmethod
+    def from_sql_trace_annotation(annotation: TraceAnnotation) -> "TraceEvaluation":
+        return TraceEvaluation(
+            name=annotation.name,
+            score=annotation.score,
+            label=annotation.label,
+            explanation=annotation.explanation,
         )
 
 
 @strawberry.type
 class SpanEvaluation(Evaluation):
-    span_id: strawberry.Private[SpanID]
-
     @staticmethod
     def from_pb_evaluation(evaluation: pb.Evaluation) -> "SpanEvaluation":
         result = evaluation.result
         score = result.score.value if result.HasField("score") else None
         label = result.label.value if result.HasField("label") else None
         explanation = result.explanation.value if result.HasField("explanation") else None
-        span_id = SpanID(evaluation.subject_id.span_id)
         return SpanEvaluation(
             name=evaluation.name,
             score=score,
             label=label,
             explanation=explanation,
-            span_id=span_id,
+        )
+
+    @staticmethod
+    def from_sql_span_annotation(annotation: SpanAnnotation) -> "SpanEvaluation":
+        return SpanEvaluation(
+            name=annotation.name,
+            score=annotation.score,
+            label=annotation.label,
+            explanation=annotation.explanation,
         )
 
 
 @strawberry.type
 class DocumentEvaluation(Evaluation):
-    span_id: strawberry.Private[SpanID]
     document_position: int = strawberry.field(
         description="The zero-based index among retrieved documents, which "
         "is collected as a list (even when ordering is not inherently meaningful)."
@@ -80,12 +89,20 @@ class DocumentEvaluation(Evaluation):
         explanation = result.explanation.value if result.HasField("explanation") else None
         document_retrieval_id = evaluation.subject_id.document_retrieval_id
         document_position = document_retrieval_id.document_position
-        span_id = SpanID(document_retrieval_id.span_id)
         return DocumentEvaluation(
             name=evaluation.name,
             score=score,
             label=label,
             explanation=explanation,
             document_position=document_position,
-            span_id=span_id,
+        )
+
+    @staticmethod
+    def from_sql_document_annotation(annotation: DocumentAnnotation) -> "DocumentEvaluation":
+        return DocumentEvaluation(
+            name=annotation.name,
+            score=annotation.score,
+            label=annotation.label,
+            explanation=annotation.explanation,
+            document_position=annotation.document_position,
         )
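
A short sketch of how the new from_sql_* constructors would be used: mapping ORM annotation rows straight onto the GraphQL evaluation types, replacing the protobuf-only path. The helper below is hypothetical; only the constructor and model names come from the diff.

    from typing import List

    from sqlalchemy import select
    from sqlalchemy.ext.asyncio import AsyncSession

    from phoenix.db import models
    from phoenix.server.api.types.Evaluation import SpanEvaluation

    # Hypothetical helper: fetch annotation rows and convert to GraphQL types.
    async def load_span_evaluations(session: AsyncSession) -> List[SpanEvaluation]:
        annotations = await session.scalars(select(models.SpanAnnotation))
        return [
            SpanEvaluation.from_sql_span_annotation(annotation)
            for annotation in annotations
        ]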

phoenix/server/api/types/EvaluationSummary.py
@@ -1,13 +1,12 @@
-import math
-from functools import cached_property
-from typing import List, Optional, Tuple, cast
+from typing import List, Optional, Union, cast
 
 import pandas as pd
 import strawberry
-from pandas.api.types import CategoricalDtype
 from strawberry import Private
 
-import phoenix.trace.v1 as pb
+from phoenix.db import models
+
+AnnotationType = Union[models.SpanAnnotation, models.TraceAnnotation]
 
 
 @strawberry.type
@@ -18,58 +17,44 @@ class LabelFraction:
 
 @strawberry.type
 class EvaluationSummary:
-    count: int
-    labels: Tuple[str, ...]
-    evaluations: Private[Tuple[pb.Evaluation, ...]]
+    df: Private[pd.DataFrame]
+
+    def __init__(self, dataframe: pd.DataFrame) -> None:
+        self.df = dataframe
 
-    def __init__(
-        self,
-        evaluations: Tuple[pb.Evaluation, ...],
-        labels: Tuple[str, ...],
-    ) -> None:
-        self.evaluations = evaluations
-        self.labels = labels
-        self.count = len(evaluations)
+    @strawberry.field
+    def count(self) -> int:
+        return cast(int, self.df.record_count.sum())
+
+    @strawberry.field
+    def labels(self) -> List[str]:
+        return self.df.label.dropna().tolist()
 
     @strawberry.field
     def label_fractions(self) -> List[LabelFraction]:
-        if not self.labels or not (n := len(self._eval_labels)):
+        if not (n := self.df.label_count.sum()):
             return []
-        counts = self._eval_labels.value_counts(dropna=True)
         return [
-            LabelFraction(label=cast(str, label), fraction=count / n)
-            for label, count in counts.items()
+            LabelFraction(
+                label=cast(str, row.label),
+                fraction=row.label_count / n,
+            )
+            for row in self.df.loc[
+                self.df.label.notna(),
+                ["label", "label_count"],
+            ].itertuples()
         ]
 
     @strawberry.field
     def mean_score(self) -> Optional[float]:
-        value = self._eval_scores.mean()
-        return None if math.isnan(value) else value
+        if not (n := self.df.score_count.sum()):
+            return None
+        return cast(float, self.df.score_sum.sum() / n)
 
     @strawberry.field
     def score_count(self) -> int:
-        return self._eval_scores.count()
+        return cast(int, self.df.score_count.sum())
 
     @strawberry.field
     def label_count(self) -> int:
-        return self._eval_labels.count()
-
-    @cached_property
-    def _eval_scores(self) -> "pd.Series[float]":
-        return pd.Series(
-            (
-                evaluation.result.score.value if evaluation.result.HasField("score") else None
-                for evaluation in self.evaluations
-            ),
-            dtype=float,
-        )
-
-    @cached_property
-    def _eval_labels(self) -> "pd.Series[CategoricalDtype]":
-        return pd.Series(
-            (
-                evaluation.result.label.value if evaluation.result.HasField("label") else None
-                for evaluation in self.evaluations
-            ),
-            dtype=CategoricalDtype(categories=self.labels),  # type: ignore
-        )
+        return cast(int, self.df.label_count.sum())
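
For orientation, a minimal sketch (with made-up numbers) of the pre-aggregated dataframe the rewritten EvaluationSummary expects: one row per label, with count and sum columns computed upstream, presumably by the new evaluation_summaries dataloader listed above.

    import pandas as pd

    from phoenix.server.api.types.EvaluationSummary import EvaluationSummary

    # Illustrative aggregates only; real values come from the database query.
    df = pd.DataFrame(
        {
            "label": ["relevant", "irrelevant"],
            "record_count": [8, 2],
            "label_count": [8, 2],
            "score_count": [8, 2],
            "score_sum": [6.5, 0.5],
        }
    )
    summary = EvaluationSummary(df)
    # Resolved via GraphQL: count == 10, scoreCount == 10,
    # meanScore == (6.5 + 0.5) / 10 == 0.7, labelFractions == [0.8, 0.2].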

phoenix/server/api/types/MimeType.py
@@ -8,8 +8,8 @@ import phoenix.trace.schemas as trace_schemas
 
 @strawberry.enum
 class MimeType(Enum):
-    text = trace_schemas.MimeType.TEXT
-    json = trace_schemas.MimeType.JSON
+    text = trace_schemas.MimeType.TEXT.value
+    json = trace_schemas.MimeType.JSON.value
 
     @classmethod
     def _missing_(cls, v: Any) -> Optional["MimeType"]:
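
The .value change is subtle: the strawberry enum now wraps plain strings rather than nested enum members, so lookup by raw value resolves directly. A small sketch, assuming TEXT and JSON carry string values in phoenix.trace.schemas:

    import phoenix.trace.schemas as trace_schemas

    from phoenix.server.api.types.MimeType import MimeType

    # Member values are now the raw strings, so value-based lookup succeeds.
    assert MimeType.text.value == trace_schemas.MimeType.TEXT.value
    assert MimeType(trace_schemas.MimeType.TEXT.value) is MimeType.text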

phoenix/server/api/types/Model.py
@@ -19,7 +19,7 @@ from .DatasetRole import AncillaryDatasetRole, DatasetRole
 from .Dimension import Dimension, to_gql_dimension
 from .EmbeddingDimension import EmbeddingDimension, to_gql_embedding_dimension
 from .ExportedFile import ExportedFile
-from .pagination import Connection, ConnectionArgs, Cursor, connection_from_list
+from .pagination import Connection, ConnectionArgs, CursorString, connection_from_list
 from .TimeSeries import (
     PerformanceTimeSeries,
     ensure_timeseries_parameters,
@@ -35,8 +35,8 @@ class Model:
         info: Info[Context, None],
         first: Optional[int] = 50,
         last: Optional[int] = UNSET,
-        after: Optional[Cursor] = UNSET,
-        before: Optional[Cursor] = UNSET,
+        after: Optional[CursorString] = UNSET,
+        before: Optional[CursorString] = UNSET,
         include: Optional[DimensionFilter] = UNSET,
         exclude: Optional[DimensionFilter] = UNSET,
     ) -> Connection[Dimension]:
@@ -50,9 +50,9 @@ class Model:
             ],
             args=ConnectionArgs(
                 first=first,
-                after=after if isinstance(after, Cursor) else None,
+                after=after if isinstance(after, CursorString) else None,
                 last=last,
-                before=before if isinstance(before, Cursor) else None,
+                before=before if isinstance(before, CursorString) else None,
             ),
         )
 
@@ -105,8 +105,8 @@ class Model:
         info: Info[Context, None],
         first: Optional[int] = 50,
         last: Optional[int] = UNSET,
-        after: Optional[Cursor] = UNSET,
-        before: Optional[Cursor] = UNSET,
+        after: Optional[CursorString] = UNSET,
+        before: Optional[CursorString] = UNSET,
     ) -> Connection[EmbeddingDimension]:
         """
         A non-trivial implementation should efficiently fetch only
@@ -123,9 +123,9 @@ class Model:
             ],
             args=ConnectionArgs(
                 first=first,
-                after=after if isinstance(after, Cursor) else None,
+                after=after if isinstance(after, CursorString) else None,
                 last=last,
-                before=before if isinstance(before, Cursor) else None,
+                before=before if isinstance(before, CursorString) else None,
             ),
         )
 
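
The recurring Cursor → CursorString rename above is mechanical, and the pagination helpers keep the same shape. A minimal sketch, assuming only the signatures visible in this diff:

    from phoenix.server.api.types.pagination import ConnectionArgs, connection_from_list

    # Page through a plain list; after/before accept opaque cursor strings.
    args = ConnectionArgs(first=2, after=None, last=None, before=None)
    connection = connection_from_list(data=["a", "b", "c"], args=args)
    # The relay-style Connection presumably exposes edges and page info
    # for resuming from the returned cursors.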