arize-phoenix 3.16.1__py3-none-any.whl → 7.7.0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Files changed (338)
  1. arize_phoenix-7.7.0.dist-info/METADATA +261 -0
  2. arize_phoenix-7.7.0.dist-info/RECORD +345 -0
  3. {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.0.dist-info}/WHEEL +1 -1
  4. arize_phoenix-7.7.0.dist-info/entry_points.txt +3 -0
  5. phoenix/__init__.py +86 -14
  6. phoenix/auth.py +309 -0
  7. phoenix/config.py +675 -45
  8. phoenix/core/model.py +32 -30
  9. phoenix/core/model_schema.py +102 -109
  10. phoenix/core/model_schema_adapter.py +48 -45
  11. phoenix/datetime_utils.py +24 -3
  12. phoenix/db/README.md +54 -0
  13. phoenix/db/__init__.py +4 -0
  14. phoenix/db/alembic.ini +85 -0
  15. phoenix/db/bulk_inserter.py +294 -0
  16. phoenix/db/engines.py +208 -0
  17. phoenix/db/enums.py +20 -0
  18. phoenix/db/facilitator.py +113 -0
  19. phoenix/db/helpers.py +159 -0
  20. phoenix/db/insertion/constants.py +2 -0
  21. phoenix/db/insertion/dataset.py +227 -0
  22. phoenix/db/insertion/document_annotation.py +171 -0
  23. phoenix/db/insertion/evaluation.py +191 -0
  24. phoenix/db/insertion/helpers.py +98 -0
  25. phoenix/db/insertion/span.py +193 -0
  26. phoenix/db/insertion/span_annotation.py +158 -0
  27. phoenix/db/insertion/trace_annotation.py +158 -0
  28. phoenix/db/insertion/types.py +256 -0
  29. phoenix/db/migrate.py +86 -0
  30. phoenix/db/migrations/data_migration_scripts/populate_project_sessions.py +199 -0
  31. phoenix/db/migrations/env.py +114 -0
  32. phoenix/db/migrations/script.py.mako +26 -0
  33. phoenix/db/migrations/versions/10460e46d750_datasets.py +317 -0
  34. phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +126 -0
  35. phoenix/db/migrations/versions/4ded9e43755f_create_project_sessions_table.py +66 -0
  36. phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py +157 -0
  37. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +280 -0
  38. phoenix/db/models.py +807 -0
  39. phoenix/exceptions.py +5 -1
  40. phoenix/experiments/__init__.py +6 -0
  41. phoenix/experiments/evaluators/__init__.py +29 -0
  42. phoenix/experiments/evaluators/base.py +158 -0
  43. phoenix/experiments/evaluators/code_evaluators.py +184 -0
  44. phoenix/experiments/evaluators/llm_evaluators.py +473 -0
  45. phoenix/experiments/evaluators/utils.py +236 -0
  46. phoenix/experiments/functions.py +772 -0
  47. phoenix/experiments/tracing.py +86 -0
  48. phoenix/experiments/types.py +726 -0
  49. phoenix/experiments/utils.py +25 -0
  50. phoenix/inferences/__init__.py +0 -0
  51. phoenix/{datasets → inferences}/errors.py +6 -5
  52. phoenix/{datasets → inferences}/fixtures.py +49 -42
  53. phoenix/{datasets/dataset.py → inferences/inferences.py} +121 -105
  54. phoenix/{datasets → inferences}/schema.py +11 -11
  55. phoenix/{datasets → inferences}/validation.py +13 -14
  56. phoenix/logging/__init__.py +3 -0
  57. phoenix/logging/_config.py +90 -0
  58. phoenix/logging/_filter.py +6 -0
  59. phoenix/logging/_formatter.py +69 -0
  60. phoenix/metrics/__init__.py +5 -4
  61. phoenix/metrics/binning.py +4 -3
  62. phoenix/metrics/metrics.py +2 -1
  63. phoenix/metrics/mixins.py +7 -6
  64. phoenix/metrics/retrieval_metrics.py +2 -1
  65. phoenix/metrics/timeseries.py +5 -4
  66. phoenix/metrics/wrappers.py +9 -3
  67. phoenix/pointcloud/clustering.py +5 -5
  68. phoenix/pointcloud/pointcloud.py +7 -5
  69. phoenix/pointcloud/projectors.py +5 -6
  70. phoenix/pointcloud/umap_parameters.py +53 -52
  71. phoenix/server/api/README.md +28 -0
  72. phoenix/server/api/auth.py +44 -0
  73. phoenix/server/api/context.py +152 -9
  74. phoenix/server/api/dataloaders/__init__.py +91 -0
  75. phoenix/server/api/dataloaders/annotation_summaries.py +139 -0
  76. phoenix/server/api/dataloaders/average_experiment_run_latency.py +54 -0
  77. phoenix/server/api/dataloaders/cache/__init__.py +3 -0
  78. phoenix/server/api/dataloaders/cache/two_tier_cache.py +68 -0
  79. phoenix/server/api/dataloaders/dataset_example_revisions.py +131 -0
  80. phoenix/server/api/dataloaders/dataset_example_spans.py +38 -0
  81. phoenix/server/api/dataloaders/document_evaluation_summaries.py +144 -0
  82. phoenix/server/api/dataloaders/document_evaluations.py +31 -0
  83. phoenix/server/api/dataloaders/document_retrieval_metrics.py +89 -0
  84. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +79 -0
  85. phoenix/server/api/dataloaders/experiment_error_rates.py +58 -0
  86. phoenix/server/api/dataloaders/experiment_run_annotations.py +36 -0
  87. phoenix/server/api/dataloaders/experiment_run_counts.py +49 -0
  88. phoenix/server/api/dataloaders/experiment_sequence_number.py +44 -0
  89. phoenix/server/api/dataloaders/latency_ms_quantile.py +188 -0
  90. phoenix/server/api/dataloaders/min_start_or_max_end_times.py +85 -0
  91. phoenix/server/api/dataloaders/project_by_name.py +31 -0
  92. phoenix/server/api/dataloaders/record_counts.py +116 -0
  93. phoenix/server/api/dataloaders/session_io.py +79 -0
  94. phoenix/server/api/dataloaders/session_num_traces.py +30 -0
  95. phoenix/server/api/dataloaders/session_num_traces_with_error.py +32 -0
  96. phoenix/server/api/dataloaders/session_token_usages.py +41 -0
  97. phoenix/server/api/dataloaders/session_trace_latency_ms_quantile.py +55 -0
  98. phoenix/server/api/dataloaders/span_annotations.py +26 -0
  99. phoenix/server/api/dataloaders/span_dataset_examples.py +31 -0
  100. phoenix/server/api/dataloaders/span_descendants.py +57 -0
  101. phoenix/server/api/dataloaders/span_projects.py +33 -0
  102. phoenix/server/api/dataloaders/token_counts.py +124 -0
  103. phoenix/server/api/dataloaders/trace_by_trace_ids.py +25 -0
  104. phoenix/server/api/dataloaders/trace_root_spans.py +32 -0
  105. phoenix/server/api/dataloaders/user_roles.py +30 -0
  106. phoenix/server/api/dataloaders/users.py +33 -0
  107. phoenix/server/api/exceptions.py +48 -0
  108. phoenix/server/api/helpers/__init__.py +12 -0
  109. phoenix/server/api/helpers/dataset_helpers.py +217 -0
  110. phoenix/server/api/helpers/experiment_run_filters.py +763 -0
  111. phoenix/server/api/helpers/playground_clients.py +948 -0
  112. phoenix/server/api/helpers/playground_registry.py +70 -0
  113. phoenix/server/api/helpers/playground_spans.py +455 -0
  114. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
  115. phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
  116. phoenix/server/api/input_types/ChatCompletionInput.py +38 -0
  117. phoenix/server/api/input_types/ChatCompletionMessageInput.py +24 -0
  118. phoenix/server/api/input_types/ClearProjectInput.py +15 -0
  119. phoenix/server/api/input_types/ClusterInput.py +2 -2
  120. phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
  121. phoenix/server/api/input_types/CreateSpanAnnotationInput.py +18 -0
  122. phoenix/server/api/input_types/CreateTraceAnnotationInput.py +18 -0
  123. phoenix/server/api/input_types/DataQualityMetricInput.py +5 -2
  124. phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
  125. phoenix/server/api/input_types/DatasetSort.py +17 -0
  126. phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
  127. phoenix/server/api/input_types/DeleteAnnotationsInput.py +7 -0
  128. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
  129. phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
  130. phoenix/server/api/input_types/DeleteExperimentsInput.py +7 -0
  131. phoenix/server/api/input_types/DimensionFilter.py +4 -4
  132. phoenix/server/api/input_types/GenerativeModelInput.py +17 -0
  133. phoenix/server/api/input_types/Granularity.py +1 -1
  134. phoenix/server/api/input_types/InvocationParameters.py +162 -0
  135. phoenix/server/api/input_types/PatchAnnotationInput.py +19 -0
  136. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
  137. phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
  138. phoenix/server/api/input_types/PerformanceMetricInput.py +5 -2
  139. phoenix/server/api/input_types/ProjectSessionSort.py +29 -0
  140. phoenix/server/api/input_types/SpanAnnotationSort.py +17 -0
  141. phoenix/server/api/input_types/SpanSort.py +134 -69
  142. phoenix/server/api/input_types/TemplateOptions.py +10 -0
  143. phoenix/server/api/input_types/TraceAnnotationSort.py +17 -0
  144. phoenix/server/api/input_types/UserRoleInput.py +9 -0
  145. phoenix/server/api/mutations/__init__.py +28 -0
  146. phoenix/server/api/mutations/api_key_mutations.py +167 -0
  147. phoenix/server/api/mutations/chat_mutations.py +593 -0
  148. phoenix/server/api/mutations/dataset_mutations.py +591 -0
  149. phoenix/server/api/mutations/experiment_mutations.py +75 -0
  150. phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +21 -18
  151. phoenix/server/api/mutations/project_mutations.py +57 -0
  152. phoenix/server/api/mutations/span_annotations_mutations.py +128 -0
  153. phoenix/server/api/mutations/trace_annotations_mutations.py +127 -0
  154. phoenix/server/api/mutations/user_mutations.py +329 -0
  155. phoenix/server/api/openapi/__init__.py +0 -0
  156. phoenix/server/api/openapi/main.py +17 -0
  157. phoenix/server/api/openapi/schema.py +16 -0
  158. phoenix/server/api/queries.py +738 -0
  159. phoenix/server/api/routers/__init__.py +11 -0
  160. phoenix/server/api/routers/auth.py +284 -0
  161. phoenix/server/api/routers/embeddings.py +26 -0
  162. phoenix/server/api/routers/oauth2.py +488 -0
  163. phoenix/server/api/routers/v1/__init__.py +64 -0
  164. phoenix/server/api/routers/v1/datasets.py +1017 -0
  165. phoenix/server/api/routers/v1/evaluations.py +362 -0
  166. phoenix/server/api/routers/v1/experiment_evaluations.py +115 -0
  167. phoenix/server/api/routers/v1/experiment_runs.py +167 -0
  168. phoenix/server/api/routers/v1/experiments.py +308 -0
  169. phoenix/server/api/routers/v1/pydantic_compat.py +78 -0
  170. phoenix/server/api/routers/v1/spans.py +267 -0
  171. phoenix/server/api/routers/v1/traces.py +208 -0
  172. phoenix/server/api/routers/v1/utils.py +95 -0
  173. phoenix/server/api/schema.py +44 -241
  174. phoenix/server/api/subscriptions.py +597 -0
  175. phoenix/server/api/types/Annotation.py +21 -0
  176. phoenix/server/api/types/AnnotationSummary.py +55 -0
  177. phoenix/server/api/types/AnnotatorKind.py +16 -0
  178. phoenix/server/api/types/ApiKey.py +27 -0
  179. phoenix/server/api/types/AuthMethod.py +9 -0
  180. phoenix/server/api/types/ChatCompletionMessageRole.py +11 -0
  181. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +46 -0
  182. phoenix/server/api/types/Cluster.py +25 -24
  183. phoenix/server/api/types/CreateDatasetPayload.py +8 -0
  184. phoenix/server/api/types/DataQualityMetric.py +31 -13
  185. phoenix/server/api/types/Dataset.py +288 -63
  186. phoenix/server/api/types/DatasetExample.py +85 -0
  187. phoenix/server/api/types/DatasetExampleRevision.py +34 -0
  188. phoenix/server/api/types/DatasetVersion.py +14 -0
  189. phoenix/server/api/types/Dimension.py +32 -31
  190. phoenix/server/api/types/DocumentEvaluationSummary.py +9 -8
  191. phoenix/server/api/types/EmbeddingDimension.py +56 -49
  192. phoenix/server/api/types/Evaluation.py +25 -31
  193. phoenix/server/api/types/EvaluationSummary.py +30 -50
  194. phoenix/server/api/types/Event.py +20 -20
  195. phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
  196. phoenix/server/api/types/Experiment.py +152 -0
  197. phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
  198. phoenix/server/api/types/ExperimentComparison.py +17 -0
  199. phoenix/server/api/types/ExperimentRun.py +119 -0
  200. phoenix/server/api/types/ExperimentRunAnnotation.py +56 -0
  201. phoenix/server/api/types/GenerativeModel.py +9 -0
  202. phoenix/server/api/types/GenerativeProvider.py +85 -0
  203. phoenix/server/api/types/Inferences.py +80 -0
  204. phoenix/server/api/types/InferencesRole.py +23 -0
  205. phoenix/server/api/types/LabelFraction.py +7 -0
  206. phoenix/server/api/types/MimeType.py +2 -2
  207. phoenix/server/api/types/Model.py +54 -54
  208. phoenix/server/api/types/PerformanceMetric.py +8 -5
  209. phoenix/server/api/types/Project.py +407 -142
  210. phoenix/server/api/types/ProjectSession.py +139 -0
  211. phoenix/server/api/types/Segments.py +4 -4
  212. phoenix/server/api/types/Span.py +221 -176
  213. phoenix/server/api/types/SpanAnnotation.py +43 -0
  214. phoenix/server/api/types/SpanIOValue.py +15 -0
  215. phoenix/server/api/types/SystemApiKey.py +9 -0
  216. phoenix/server/api/types/TemplateLanguage.py +10 -0
  217. phoenix/server/api/types/TimeSeries.py +19 -15
  218. phoenix/server/api/types/TokenUsage.py +11 -0
  219. phoenix/server/api/types/Trace.py +154 -0
  220. phoenix/server/api/types/TraceAnnotation.py +45 -0
  221. phoenix/server/api/types/UMAPPoints.py +7 -7
  222. phoenix/server/api/types/User.py +60 -0
  223. phoenix/server/api/types/UserApiKey.py +45 -0
  224. phoenix/server/api/types/UserRole.py +15 -0
  225. phoenix/server/api/types/node.py +4 -112
  226. phoenix/server/api/types/pagination.py +156 -57
  227. phoenix/server/api/utils.py +34 -0
  228. phoenix/server/app.py +864 -115
  229. phoenix/server/bearer_auth.py +163 -0
  230. phoenix/server/dml_event.py +136 -0
  231. phoenix/server/dml_event_handler.py +256 -0
  232. phoenix/server/email/__init__.py +0 -0
  233. phoenix/server/email/sender.py +97 -0
  234. phoenix/server/email/templates/__init__.py +0 -0
  235. phoenix/server/email/templates/password_reset.html +19 -0
  236. phoenix/server/email/types.py +11 -0
  237. phoenix/server/grpc_server.py +102 -0
  238. phoenix/server/jwt_store.py +505 -0
  239. phoenix/server/main.py +305 -116
  240. phoenix/server/oauth2.py +52 -0
  241. phoenix/server/openapi/__init__.py +0 -0
  242. phoenix/server/prometheus.py +111 -0
  243. phoenix/server/rate_limiters.py +188 -0
  244. phoenix/server/static/.vite/manifest.json +87 -0
  245. phoenix/server/static/assets/components-Cy9nwIvF.js +2125 -0
  246. phoenix/server/static/assets/index-BKvHIxkk.js +113 -0
  247. phoenix/server/static/assets/pages-CUi2xCVQ.js +4449 -0
  248. phoenix/server/static/assets/vendor-DvC8cT4X.js +894 -0
  249. phoenix/server/static/assets/vendor-DxkFTwjz.css +1 -0
  250. phoenix/server/static/assets/vendor-arizeai-Do1793cv.js +662 -0
  251. phoenix/server/static/assets/vendor-codemirror-BzwZPyJM.js +24 -0
  252. phoenix/server/static/assets/vendor-recharts-_Jb7JjhG.js +59 -0
  253. phoenix/server/static/assets/vendor-shiki-Cl9QBraO.js +5 -0
  254. phoenix/server/static/assets/vendor-three-DwGkEfCM.js +2998 -0
  255. phoenix/server/telemetry.py +68 -0
  256. phoenix/server/templates/index.html +82 -23
  257. phoenix/server/thread_server.py +3 -3
  258. phoenix/server/types.py +275 -0
  259. phoenix/services.py +27 -18
  260. phoenix/session/client.py +743 -68
  261. phoenix/session/data_extractor.py +31 -7
  262. phoenix/session/evaluation.py +3 -9
  263. phoenix/session/session.py +263 -219
  264. phoenix/settings.py +22 -0
  265. phoenix/trace/__init__.py +2 -22
  266. phoenix/trace/attributes.py +338 -0
  267. phoenix/trace/dsl/README.md +116 -0
  268. phoenix/trace/dsl/filter.py +663 -213
  269. phoenix/trace/dsl/helpers.py +73 -21
  270. phoenix/trace/dsl/query.py +574 -201
  271. phoenix/trace/exporter.py +24 -19
  272. phoenix/trace/fixtures.py +368 -32
  273. phoenix/trace/otel.py +71 -219
  274. phoenix/trace/projects.py +3 -2
  275. phoenix/trace/schemas.py +33 -11
  276. phoenix/trace/span_evaluations.py +21 -16
  277. phoenix/trace/span_json_decoder.py +6 -4
  278. phoenix/trace/span_json_encoder.py +2 -2
  279. phoenix/trace/trace_dataset.py +47 -32
  280. phoenix/trace/utils.py +21 -4
  281. phoenix/utilities/__init__.py +0 -26
  282. phoenix/utilities/client.py +132 -0
  283. phoenix/utilities/deprecation.py +31 -0
  284. phoenix/utilities/error_handling.py +3 -2
  285. phoenix/utilities/json.py +109 -0
  286. phoenix/utilities/logging.py +8 -0
  287. phoenix/utilities/project.py +2 -2
  288. phoenix/utilities/re.py +49 -0
  289. phoenix/utilities/span_store.py +0 -23
  290. phoenix/utilities/template_formatters.py +99 -0
  291. phoenix/version.py +1 -1
  292. arize_phoenix-3.16.1.dist-info/METADATA +0 -495
  293. arize_phoenix-3.16.1.dist-info/RECORD +0 -178
  294. phoenix/core/project.py +0 -619
  295. phoenix/core/traces.py +0 -96
  296. phoenix/experimental/evals/__init__.py +0 -73
  297. phoenix/experimental/evals/evaluators.py +0 -413
  298. phoenix/experimental/evals/functions/__init__.py +0 -4
  299. phoenix/experimental/evals/functions/classify.py +0 -453
  300. phoenix/experimental/evals/functions/executor.py +0 -353
  301. phoenix/experimental/evals/functions/generate.py +0 -138
  302. phoenix/experimental/evals/functions/processing.py +0 -76
  303. phoenix/experimental/evals/models/__init__.py +0 -14
  304. phoenix/experimental/evals/models/anthropic.py +0 -175
  305. phoenix/experimental/evals/models/base.py +0 -170
  306. phoenix/experimental/evals/models/bedrock.py +0 -221
  307. phoenix/experimental/evals/models/litellm.py +0 -134
  308. phoenix/experimental/evals/models/openai.py +0 -448
  309. phoenix/experimental/evals/models/rate_limiters.py +0 -246
  310. phoenix/experimental/evals/models/vertex.py +0 -173
  311. phoenix/experimental/evals/models/vertexai.py +0 -186
  312. phoenix/experimental/evals/retrievals.py +0 -96
  313. phoenix/experimental/evals/templates/__init__.py +0 -50
  314. phoenix/experimental/evals/templates/default_templates.py +0 -472
  315. phoenix/experimental/evals/templates/template.py +0 -195
  316. phoenix/experimental/evals/utils/__init__.py +0 -172
  317. phoenix/experimental/evals/utils/threads.py +0 -27
  318. phoenix/server/api/helpers.py +0 -11
  319. phoenix/server/api/routers/evaluation_handler.py +0 -109
  320. phoenix/server/api/routers/span_handler.py +0 -70
  321. phoenix/server/api/routers/trace_handler.py +0 -60
  322. phoenix/server/api/types/DatasetRole.py +0 -23
  323. phoenix/server/static/index.css +0 -6
  324. phoenix/server/static/index.js +0 -7447
  325. phoenix/storage/span_store/__init__.py +0 -23
  326. phoenix/storage/span_store/text_file.py +0 -85
  327. phoenix/trace/dsl/missing.py +0 -60
  328. phoenix/trace/langchain/__init__.py +0 -3
  329. phoenix/trace/langchain/instrumentor.py +0 -35
  330. phoenix/trace/llama_index/__init__.py +0 -3
  331. phoenix/trace/llama_index/callback.py +0 -102
  332. phoenix/trace/openai/__init__.py +0 -3
  333. phoenix/trace/openai/instrumentor.py +0 -30
  334. {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.0.dist-info}/licenses/IP_NOTICE +0 -0
  335. {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.0.dist-info}/licenses/LICENSE +0 -0
  336. /phoenix/{datasets → db/insertion}/__init__.py +0 -0
  337. /phoenix/{experimental → db/migrations}/__init__.py +0 -0
  338. /phoenix/{storage → db/migrations/data_migration_scripts}/__init__.py +0 -0
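
The hunks below cover the new dataset GraphQL types and the dataset-to-inferences rename in the existing types. A note on naming: 7.x repurposes the word "dataset" for the new examples/experiments feature, while the old primary/reference inference data moves from phoenix/datasets/* to phoenix/inferences/* (entries 50-55 above). A minimal before/after sketch of what those module moves imply for downstream imports; the class names are inferred from the file renames (dataset.py → inferences.py) and should be treated as illustrative rather than confirmed API:

    # 3.16.1: the inference-data abstraction lived under phoenix.datasets
    # from phoenix.datasets.dataset import Dataset
    # from phoenix.datasets.schema import Schema

    # 7.7.0: the same modules now live under phoenix.inferences
    # (entry 53: phoenix/datasets/dataset.py -> phoenix/inferences/inferences.py)
    from phoenix.inferences.inferences import Inferences  # illustrative symbol name
    from phoenix.inferences.schema import Schema
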
phoenix/server/api/types/DatasetExample.py

@@ -0,0 +1,85 @@
+ from datetime import datetime
+ from typing import Optional
+
+ import strawberry
+ from sqlalchemy import select
+ from sqlalchemy.orm import joinedload
+ from strawberry import UNSET
+ from strawberry.relay.types import Connection, GlobalID, Node, NodeID
+ from strawberry.types import Info
+
+ from phoenix.db import models
+ from phoenix.server.api.context import Context
+ from phoenix.server.api.types.DatasetExampleRevision import DatasetExampleRevision
+ from phoenix.server.api.types.DatasetVersion import DatasetVersion
+ from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
+ from phoenix.server.api.types.node import from_global_id_with_expected_type
+ from phoenix.server.api.types.pagination import (
+     ConnectionArgs,
+     CursorString,
+     connection_from_list,
+ )
+ from phoenix.server.api.types.Span import Span, to_gql_span
+
+
+ @strawberry.type
+ class DatasetExample(Node):
+     id_attr: NodeID[int]
+     created_at: datetime
+     version_id: strawberry.Private[Optional[int]] = None
+
+     @strawberry.field
+     async def revision(
+         self,
+         info: Info[Context, None],
+         dataset_version_id: Optional[GlobalID] = UNSET,
+     ) -> DatasetExampleRevision:
+         example_id = self.id_attr
+         version_id: Optional[int] = None
+         if dataset_version_id:
+             version_id = from_global_id_with_expected_type(
+                 global_id=dataset_version_id, expected_type_name=DatasetVersion.__name__
+             )
+         elif self.version_id is not None:
+             version_id = self.version_id
+         return await info.context.data_loaders.dataset_example_revisions.load(
+             (example_id, version_id)
+         )
+
+     @strawberry.field
+     async def span(
+         self,
+         info: Info[Context, None],
+     ) -> Optional[Span]:
+         return (
+             to_gql_span(span)
+             if (span := await info.context.data_loaders.dataset_example_spans.load(self.id_attr))
+             else None
+         )
+
+     @strawberry.field
+     async def experiment_runs(
+         self,
+         info: Info[Context, None],
+         first: Optional[int] = 50,
+         last: Optional[int] = UNSET,
+         after: Optional[CursorString] = UNSET,
+         before: Optional[CursorString] = UNSET,
+     ) -> Connection[ExperimentRun]:
+         args = ConnectionArgs(
+             first=first,
+             after=after if isinstance(after, CursorString) else None,
+             last=last,
+             before=before if isinstance(before, CursorString) else None,
+         )
+         example_id = self.id_attr
+         query = (
+             select(models.ExperimentRun)
+             .options(joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id))
+             .join(models.Experiment, models.Experiment.id == models.ExperimentRun.experiment_id)
+             .where(models.ExperimentRun.dataset_example_id == example_id)
+             .order_by(models.Experiment.id.desc())
+         )
+         async with info.context.db() as session:
+             runs = (await session.scalars(query)).all()
+         return connection_from_list([to_gql_experiment_run(run) for run in runs], args)
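
The experiment_runs resolver above follows the Relay connection convention: first/after page forward, last/before page backward, and connection_from_list windows the fetched rows into a Connection. A standalone sketch of that windowing, with plain integer indices standing in for the helper's opaque CursorString encoding (the function below is hypothetical, not Phoenix API):

    from typing import List, Optional

    def relay_window(
        items: List[int],
        first: Optional[int] = None,
        after: Optional[int] = None,   # cursor: skip items up to and including this index
        last: Optional[int] = None,
        before: Optional[int] = None,  # cursor: stop before this index
    ) -> List[int]:
        start = after + 1 if after is not None else 0
        stop = before if before is not None else len(items)
        window = items[start:stop]
        if first is not None:
            window = window[:first]  # forward page: first N after the cursor
        if last is not None:
            window = window[-last:]  # backward page: last N before the cursor
        return window

    assert relay_window(list(range(10)), first=3, after=4) == [5, 6, 7]
    assert relay_window(list(range(10)), last=2, before=5) == [3, 4]
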
phoenix/server/api/types/DatasetExampleRevision.py

@@ -0,0 +1,34 @@
+ from datetime import datetime
+ from enum import Enum
+
+ import strawberry
+
+ from phoenix.db import models
+ from phoenix.server.api.types.ExampleRevisionInterface import ExampleRevision
+
+
+ @strawberry.enum
+ class RevisionKind(Enum):
+     CREATE = "CREATE"
+     PATCH = "PATCH"
+     DELETE = "DELETE"
+
+
+ @strawberry.type
+ class DatasetExampleRevision(ExampleRevision):
+     """
+     Represents a revision (i.e., update or alteration) of a dataset example.
+     """
+
+     revision_kind: RevisionKind
+     created_at: datetime
+
+     @classmethod
+     def from_orm_revision(cls, revision: models.DatasetExampleRevision) -> "DatasetExampleRevision":
+         return cls(
+             input=revision.input,
+             output=revision.output,
+             metadata=revision.metadata_,
+             revision_kind=RevisionKind(revision.revision_kind),
+             created_at=revision.created_at,
+         )

phoenix/server/api/types/DatasetVersion.py

@@ -0,0 +1,14 @@
+ from datetime import datetime
+ from typing import Optional
+
+ import strawberry
+ from strawberry.relay import Node, NodeID
+ from strawberry.scalars import JSON
+
+
+ @strawberry.type
+ class DatasetVersion(Node):
+     id_attr: NodeID[int]
+     description: Optional[str]
+     metadata: JSON
+     created_at: datetime

phoenix/server/api/types/Dimension.py

@@ -1,9 +1,10 @@
  from collections import defaultdict
- from typing import Any, Dict, List, Optional
+ from typing import Any, Optional

  import pandas as pd
  import strawberry
  from strawberry import UNSET
+ from strawberry.relay import Node, NodeID
  from strawberry.types import Info
  from typing_extensions import Annotated

@@ -17,12 +18,11 @@ from ..context import Context
  from ..input_types.Granularity import Granularity
  from ..input_types.TimeRange import TimeRange
  from .DataQualityMetric import DataQualityMetric
- from .DatasetRole import DatasetRole
  from .DatasetValues import DatasetValues
  from .DimensionDataType import DimensionDataType
  from .DimensionShape import DimensionShape
  from .DimensionType import DimensionType
- from .node import Node
+ from .InferencesRole import InferencesRole
  from .ScalarDriftMetricEnum import ScalarDriftMetric
  from .Segments import (
      GqlBinFactory,
@@ -40,6 +40,7 @@ from .TimeSeries import (

  @strawberry.type
  class Dimension(Node):
+     id_attr: NodeID[int]
      name: str = strawberry.field(description="The name of the dimension (a.k.a. the column name)")
      type: DimensionType = strawberry.field(
          description="Whether the dimension represents a feature, tag, prediction, or actual."
@@ -62,16 +63,16 @@ class Dimension(Node):
          """
          Computes a drift metric between all reference data and the primary data
          belonging to the input time range (inclusive of the time range start and
-         exclusive of the time range end). Returns None if no reference dataset
-         exists, if no primary data exists in the input time range, or if the
+         exclusive of the time range end). Returns None if no reference inferences
+         exist, if no primary data exists in the input time range, or if the
          input time range is invalid.
          """
          model = info.context.model
          if model[REFERENCE].empty:
              return None
-         dataset = model[PRIMARY]
+         inferences = model[PRIMARY]
          time_range, granularity = ensure_timeseries_parameters(
-             dataset,
+             inferences,
              time_range,
          )
          data = get_drift_timeseries_data(
@@ -92,18 +93,18 @@ class Dimension(Node):
          info: Info[Context, None],
          metric: DataQualityMetric,
          time_range: Optional[TimeRange] = UNSET,
-         dataset_role: Annotated[
-             Optional[DatasetRole],
+         inferences_role: Annotated[
+             Optional[InferencesRole],
              strawberry.argument(
-                 description="The dataset (primary or reference) to query",
+                 description="The inferences (primary or reference) to query",
              ),
-         ] = DatasetRole.primary,
+         ] = InferencesRole.primary,
      ) -> Optional[float]:
-         if not isinstance(dataset_role, DatasetRole):
-             dataset_role = DatasetRole.primary
-         dataset = info.context.model[dataset_role.value]
+         if not isinstance(inferences_role, InferencesRole):
+             inferences_role = InferencesRole.primary
+         inferences = info.context.model[inferences_role.value]
          time_range, granularity = ensure_timeseries_parameters(
-             dataset,
+             inferences,
              time_range,
          )
          data = get_data_quality_timeseries_data(
@@ -111,7 +112,7 @@ class Dimension(Node):
              metric,
              time_range,
              granularity,
-             dataset_role,
+             inferences_role,
          )
          return data[0].value if len(data) else None

@@ -122,7 +123,7 @@ class Dimension(Node):
              " Missing values are excluded. Non-categorical dimensions return an empty list."
          )
      ) # type: ignore # https://github.com/strawberry-graphql/strawberry/issues/1929
-     def categories(self) -> List[str]:
+     def categories(self) -> list[str]:
          return list(self.dimension.categories)

      @strawberry.field(
@@ -139,18 +140,18 @@
          metric: DataQualityMetric,
          time_range: TimeRange,
          granularity: Granularity,
-         dataset_role: Annotated[
-             Optional[DatasetRole],
+         inferences_role: Annotated[
+             Optional[InferencesRole],
              strawberry.argument(
-                 description="The dataset (primary or reference) to query",
+                 description="The inferences (primary or reference) to query",
              ),
-         ] = DatasetRole.primary,
+         ] = InferencesRole.primary,
      ) -> DataQualityTimeSeries:
-         if not isinstance(dataset_role, DatasetRole):
-             dataset_role = DatasetRole.primary
-         dataset = info.context.model[dataset_role.value]
+         if not isinstance(inferences_role, InferencesRole):
+             inferences_role = InferencesRole.primary
+         inferences = info.context.model[inferences_role.value]
          time_range, granularity = ensure_timeseries_parameters(
-             dataset,
+             inferences,
              time_range,
              granularity,
          )
@@ -160,7 +161,7 @@
                  metric,
                  time_range,
                  granularity,
-                 dataset_role,
+                 inferences_role,
              )
          )

@@ -182,9 +183,9 @@
          model = info.context.model
          if model[REFERENCE].empty:
              return DriftTimeSeries(data=[])
-         dataset = model[PRIMARY]
+         inferences = model[PRIMARY]
          time_range, granularity = ensure_timeseries_parameters(
-             dataset,
+             inferences,
              time_range,
              granularity,
          )
@@ -202,7 +203,7 @@
          )

      @strawberry.field(
-         description="Returns the segments across both datasets and returns the counts per segment",
+         description="The segments across both inference sets and returns the counts per segment",
      ) # type: ignore
      def segments_comparison(
          self,
@@ -249,8 +250,8 @@
          if isinstance(binning_method, binning.IntervalBinning) and binning_method.bins is not None:
              all_bins = all_bins.union(binning_method.bins)
          for bin in all_bins:
-             values: Dict[ms.DatasetRole, Any] = defaultdict(lambda: None)
-             for role in ms.DatasetRole:
+             values: dict[ms.InferencesRole, Any] = defaultdict(lambda: None)
+             for role in ms.InferencesRole:
                  if model[role].empty:
                      continue
                  try:

phoenix/server/api/types/DocumentEvaluationSummary.py

@@ -1,6 +1,7 @@
  import math
+ from collections.abc import Iterable
  from functools import cached_property
- from typing import Any, Dict, Iterable, Optional, Tuple
+ from typing import Any, Optional

  import pandas as pd
  import strawberry
@@ -24,8 +25,8 @@ class DocumentEvaluationSummary:
      ) -> None:
          self.evaluation_name = evaluation_name
          self.metrics_collection = pd.Series(metrics_collection, dtype=object)
-         self._cached_average_ndcg_results: Dict[Optional[int], Tuple[float, int]] = {}
-         self._cached_average_precision_results: Dict[Optional[int], Tuple[float, int]] = {}
+         self._cached_average_ndcg_results: dict[Optional[int], tuple[float, int]] = {}
+         self._cached_average_precision_results: dict[Optional[int], tuple[float, int]] = {}

      @strawberry.field
      def average_ndcg(self, k: Optional[int] = UNSET) -> Optional[float]:
@@ -67,7 +68,7 @@
          _, count = self._average_hit
          return count

-     def _average_ndcg(self, k: Optional[int] = None) -> Tuple[float, int]:
+     def _average_ndcg(self, k: Optional[int] = None) -> tuple[float, int]:
          if (result := self._cached_average_ndcg_results.get(k)) is not None:
              return result
          values = self.metrics_collection.apply(lambda m: m.ndcg(k))
@@ -75,20 +76,20 @@
          self._cached_average_ndcg_results[k] = result
          return result

-     def _average_precision(self, k: Optional[int] = None) -> Tuple[float, int]:
+     def _average_precision(self, k: Optional[int] = None) -> tuple[float, int]:
          if (result := self._cached_average_precision_results.get(k)) is not None:
              return result
          values = self.metrics_collection.apply(lambda m: m.precision(k))
          result = (values.mean(), values.count())
-         self._cached_average_ndcg_results[k] = result
+         self._cached_average_precision_results[k] = result
          return result

      @cached_property
-     def _average_reciprocal_rank(self) -> Tuple[float, int]:
+     def _average_reciprocal_rank(self) -> tuple[float, int]:
          values = self.metrics_collection.apply(lambda m: m.reciprocal_rank())
          return values.mean(), values.count()

      @cached_property
-     def _average_hit(self) -> Tuple[float, int]:
+     def _average_hit(self) -> tuple[float, int]:
          values = self.metrics_collection.apply(lambda m: m.hit())
          return values.mean(), values.count()
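
The substantive fix in this file is at old line 83: _average_precision wrote its result into _cached_average_ndcg_results, so precision results for a given k populated (and could be shadowed by) the NDCG cache instead of the precision cache. The corrected pattern, one memo dict per metric keyed by k, in a minimal standalone form (not the Phoenix class itself; the computation is a stand-in):

    from typing import Optional

    class MetricCache:
        def __init__(self) -> None:
            # one cache per metric; sharing a single dict across metrics caused the bug
            self._ndcg: dict[Optional[int], float] = {}
            self._precision: dict[Optional[int], float] = {}

        def precision(self, k: Optional[int] = None) -> float:
            if (cached := self._precision.get(k)) is not None:
                return cached
            result = 0.75  # stand-in for the real aggregation over the metrics collection
            self._precision[k] = result  # write to the precision cache, not the ndcg one
            return result
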
phoenix/server/api/types/EmbeddingDimension.py

@@ -1,13 +1,15 @@
  from collections import defaultdict
+ from collections.abc import Iterable, Iterator
  from datetime import timedelta
  from itertools import chain, repeat
- from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union, cast
+ from typing import Any, Optional, Union, cast

  import numpy as np
  import numpy.typing as npt
  import pandas as pd
  import strawberry
  from strawberry import UNSET
+ from strawberry.relay import GlobalID, Node, NodeID
  from strawberry.scalars import ID
  from strawberry.types import Info
  from typing_extensions import Annotated
@@ -22,7 +24,7 @@ from phoenix.core.model_schema import (
      PRIMARY,
      PROMPT,
      REFERENCE,
-     Dataset,
+     Inferences,
  )
  from phoenix.metrics.timeseries import row_interval_from_sorted_time_index
  from phoenix.pointcloud.clustering import Hdbscan
@@ -31,7 +33,7 @@ from phoenix.pointcloud.projectors import Umap
  from phoenix.server.api.context import Context
  from phoenix.server.api.input_types.TimeRange import TimeRange
  from phoenix.server.api.types.Cluster import to_gql_clusters
- from phoenix.server.api.types.DatasetRole import AncillaryDatasetRole, DatasetRole
+ from phoenix.server.api.types.InferencesRole import AncillaryInferencesRole, InferencesRole
  from phoenix.server.api.types.VectorDriftMetricEnum import VectorDriftMetric

  from ..input_types.Granularity import Granularity
@@ -39,7 +41,6 @@ from .DataQualityMetric import DataQualityMetric
  from .EmbeddingMetadata import EmbeddingMetadata
  from .Event import create_event_id, unpack_event_id
  from .EventMetadata import EventMetadata
- from .node import GlobalID, Node
  from .Retrieval import Retrieval
  from .TimeSeries import (
      DataQualityTimeSeries,
@@ -70,6 +71,7 @@ CORPUS = "CORPUS"
  class EmbeddingDimension(Node):
      """A embedding dimension of a model. Represents unstructured data"""

+     id_attr: NodeID[int]
      name: str
      dimension: strawberry.Private[ms.EmbeddingDimension]

@@ -155,16 +157,16 @@
          metric: DataQualityMetric,
          time_range: TimeRange,
          granularity: Granularity,
-         dataset_role: Annotated[
-             Optional[DatasetRole],
+         inferences_role: Annotated[
+             Optional[InferencesRole],
              strawberry.argument(
                  description="The dataset (primary or reference) to query",
              ),
-         ] = DatasetRole.primary,
+         ] = InferencesRole.primary,
      ) -> DataQualityTimeSeries:
-         if not isinstance(dataset_role, DatasetRole):
-             dataset_role = DatasetRole.primary
-         dataset = info.context.model[dataset_role.value]
+         if not isinstance(inferences_role, InferencesRole):
+             inferences_role = InferencesRole.primary
+         dataset = info.context.model[inferences_role.value]
          time_range, granularity = ensure_timeseries_parameters(
              dataset,
              time_range,
@@ -176,7 +178,7 @@
                  metric,
                  time_range,
                  granularity,
-                 dataset_role,
+                 inferences_role,
              )
          )

@@ -312,18 +314,18 @@
          ] = DEFAULT_CLUSTER_SELECTION_EPSILON,
      ) -> UMAPPoints:
          model = info.context.model
-         data: Dict[ID, npt.NDArray[np.float64]] = {}
-         retrievals: List[Tuple[ID, Any, Any]] = []
-         for dataset in model[Dataset]:
-             dataset_id = dataset.role
-             row_id_start, row_id_stop = 0, len(dataset)
-             if dataset_id is PRIMARY:
+         data: dict[ID, npt.NDArray[np.float64]] = {}
+         retrievals: list[tuple[ID, Any, Any]] = []
+         for inferences in model[Inferences]:
+             inferences_id = inferences.role
+             row_id_start, row_id_stop = 0, len(inferences)
+             if inferences_id is PRIMARY:
                  row_id_start, row_id_stop = row_interval_from_sorted_time_index(
-                     time_index=cast(pd.DatetimeIndex, dataset.index),
+                     time_index=cast(pd.DatetimeIndex, inferences.index),
                      time_start=time_range.start,
                      time_stop=time_range.end,
                  )
-             vector_column = self.dimension[dataset_id]
+             vector_column = self.dimension[inferences_id]
              samples_collected = 0
              for row_id in _row_indices(
                  row_id_start,
@@ -337,7 +339,7 @@
                  # of dunder method __len__.
                  if not hasattr(embedding_vector, "__len__"):
                      continue
-                 event_id = create_event_id(row_id, dataset_id)
+                 event_id = create_event_id(row_id, inferences_id)
                  data[event_id] = embedding_vector
                  samples_collected += 1
                  if isinstance(
@@ -347,23 +349,23 @@
                      retrievals.append(
                          (
                              event_id,
-                             self.dimension.context_retrieval_ids(dataset).iloc[row_id],
-                             self.dimension.context_retrieval_scores(dataset).iloc[row_id],
+                             self.dimension.context_retrieval_ids(inferences).iloc[row_id],
+                             self.dimension.context_retrieval_scores(inferences).iloc[row_id],
                          )
                      )

-         context_retrievals: List[Retrieval] = []
+         context_retrievals: list[Retrieval] = []
          if isinstance(
              self.dimension,
              ms.RetrievalEmbeddingDimension,
          ) and (corpus := info.context.corpus):
-             corpus_dataset = corpus[PRIMARY]
-             for row_id, document_embedding_vector in enumerate(corpus_dataset[PROMPT]):
+             corpus_inferences = corpus[PRIMARY]
+             for row_id, document_embedding_vector in enumerate(corpus_inferences[PROMPT]):
                  if not hasattr(document_embedding_vector, "__len__"):
                      continue
-                 event_id = create_event_id(row_id, AncillaryDatasetRole.corpus)
+                 event_id = create_event_id(row_id, AncillaryInferencesRole.corpus)
                  data[event_id] = document_embedding_vector
-             corpus_primary_key = corpus_dataset.primary_key
+             corpus_primary_key = corpus_inferences.primary_key
              for event_id, retrieval_ids, retrieval_scores in retrievals:
                  if not isinstance(retrieval_ids, Iterable):
                      continue
@@ -385,7 +387,7 @@
                          )
                      except KeyError:
                          continue
-                     document_embedding_vector = corpus_dataset[PROMPT].iloc[document_row_id]
+                     document_embedding_vector = corpus_inferences[PROMPT].iloc[document_row_id]
                      if not hasattr(document_embedding_vector, "__len__"):
                          continue
                      context_retrievals.append(
@@ -393,7 +395,7 @@
                              query_id=event_id,
                              document_id=create_event_id(
                                  document_row_id,
-                                 AncillaryDatasetRole.corpus,
+                                 AncillaryInferencesRole.corpus,
                              ),
                              relevance=document_score,
                          )
@@ -413,48 +415,53 @@
              ),
          ).generate(data, n_components=n_components)

-         points: Dict[Union[DatasetRole, AncillaryDatasetRole], List[UMAPPoint]] = defaultdict(list)
+         points: dict[Union[InferencesRole, AncillaryInferencesRole], list[UMAPPoint]] = defaultdict(
+             list
+         )
          for event_id, vector in vectors.items():
-             row_id, dataset_role = unpack_event_id(event_id)
-             if isinstance(dataset_role, DatasetRole):
-                 dataset = model[dataset_role.value]
+             row_id, inferences_role = unpack_event_id(event_id)
+             if isinstance(inferences_role, InferencesRole):
+                 dataset = model[inferences_role.value]
                  embedding_metadata = EmbeddingMetadata(
-                     prediction_id=dataset[PREDICTION_ID][row_id],
-                     link_to_data=dataset[self.dimension.link_to_data][row_id],
-                     raw_data=dataset[self.dimension.raw_data][row_id],
+                     prediction_id=dataset[PREDICTION_ID].iloc[row_id],
+                     link_to_data=dataset[self.dimension.link_to_data].iloc[row_id],
+                     raw_data=dataset[self.dimension.raw_data].iloc[row_id],
                  )
             elif (corpus := info.context.corpus) is not None:
                  dataset = corpus[PRIMARY]
                  dimension = cast(ms.EmbeddingDimension, corpus[PROMPT])
                  embedding_metadata = EmbeddingMetadata(
-                     prediction_id=dataset[PREDICTION_ID][row_id],
-                     link_to_data=dataset[dimension.link_to_data][row_id],
-                     raw_data=dataset[dimension.raw_data][row_id],
+                     prediction_id=dataset[PREDICTION_ID].iloc[row_id],
+                     link_to_data=dataset[dimension.link_to_data].iloc[row_id],
+                     raw_data=dataset[dimension.raw_data].iloc[row_id],
                  )
              else:
                  continue
-             points[dataset_role].append(
+             points[inferences_role].append(
                  UMAPPoint(
-                     id=GlobalID(f"{type(self).__name__}:{str(dataset_role)}", row_id),
+                     id=GlobalID(
+                         type_name=f"{type(self).__name__}:{str(inferences_role)}",
+                         node_id=str(row_id),
+                     ),
                      event_id=event_id,
                      coordinates=to_gql_coordinates(vector),
                      event_metadata=EventMetadata(
-                         prediction_label=dataset[PREDICTION_LABEL][row_id],
-                         prediction_score=dataset[PREDICTION_SCORE][row_id],
-                         actual_label=dataset[ACTUAL_LABEL][row_id],
-                         actual_score=dataset[ACTUAL_SCORE][row_id],
+                         prediction_label=dataset[PREDICTION_LABEL].iloc[row_id],
+                         prediction_score=dataset[PREDICTION_SCORE].iloc[row_id],
+                         actual_label=dataset[ACTUAL_LABEL].iloc[row_id],
+                         actual_score=dataset[ACTUAL_SCORE].iloc[row_id],
                      ),
                      embedding_metadata=embedding_metadata,
                  )
             )

          return UMAPPoints(
-             data=points[DatasetRole.primary],
-             reference_data=points[DatasetRole.reference],
+             data=points[InferencesRole.primary],
+             reference_data=points[InferencesRole.reference],
              clusters=to_gql_clusters(
                  clustered_events=clustered_events,
              ),
-             corpus_data=points[AncillaryDatasetRole.corpus],
+             corpus_data=points[AncillaryInferencesRole.corpus],
              context_retrievals=context_retrievals,
          )

@@ -470,7 +477,7 @@ def _row_indices(
          return
      shuffled_indices = np.arange(start, stop)
      np.random.shuffle(shuffled_indices)
-     yield from shuffled_indices
+     yield from shuffled_indices  # type: ignore[misc,unused-ignore]


  def to_gql_embedding_dimension(
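
Besides the rename, the recurring change in this last file is swapping series[row_id] for series.iloc[row_id]. That is a correctness fix, not cosmetics: plain [] on a pandas Series is label-based, so an integer lookup against a non-default index can return the wrong row or raise KeyError, while .iloc is strictly positional. A quick illustration:

    import pandas as pd

    s = pd.Series(["a", "b", "c"], index=[2, 0, 1])
    print(s[0])       # label lookup: "b", the row labeled 0
    print(s.iloc[0])  # positional lookup: "a", the first row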