arize-phoenix 3.16.0__py3-none-any.whl → 7.7.0__py3-none-any.whl

This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only and reflects the changes between the two package versions.

Potentially problematic release: this version of arize-phoenix has been flagged as possibly problematic.

Files changed (338)
  1. arize_phoenix-7.7.0.dist-info/METADATA +261 -0
  2. arize_phoenix-7.7.0.dist-info/RECORD +345 -0
  3. {arize_phoenix-3.16.0.dist-info → arize_phoenix-7.7.0.dist-info}/WHEEL +1 -1
  4. arize_phoenix-7.7.0.dist-info/entry_points.txt +3 -0
  5. phoenix/__init__.py +86 -14
  6. phoenix/auth.py +309 -0
  7. phoenix/config.py +675 -45
  8. phoenix/core/model.py +32 -30
  9. phoenix/core/model_schema.py +102 -109
  10. phoenix/core/model_schema_adapter.py +48 -45
  11. phoenix/datetime_utils.py +24 -3
  12. phoenix/db/README.md +54 -0
  13. phoenix/db/__init__.py +4 -0
  14. phoenix/db/alembic.ini +85 -0
  15. phoenix/db/bulk_inserter.py +294 -0
  16. phoenix/db/engines.py +208 -0
  17. phoenix/db/enums.py +20 -0
  18. phoenix/db/facilitator.py +113 -0
  19. phoenix/db/helpers.py +159 -0
  20. phoenix/db/insertion/constants.py +2 -0
  21. phoenix/db/insertion/dataset.py +227 -0
  22. phoenix/db/insertion/document_annotation.py +171 -0
  23. phoenix/db/insertion/evaluation.py +191 -0
  24. phoenix/db/insertion/helpers.py +98 -0
  25. phoenix/db/insertion/span.py +193 -0
  26. phoenix/db/insertion/span_annotation.py +158 -0
  27. phoenix/db/insertion/trace_annotation.py +158 -0
  28. phoenix/db/insertion/types.py +256 -0
  29. phoenix/db/migrate.py +86 -0
  30. phoenix/db/migrations/data_migration_scripts/populate_project_sessions.py +199 -0
  31. phoenix/db/migrations/env.py +114 -0
  32. phoenix/db/migrations/script.py.mako +26 -0
  33. phoenix/db/migrations/versions/10460e46d750_datasets.py +317 -0
  34. phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +126 -0
  35. phoenix/db/migrations/versions/4ded9e43755f_create_project_sessions_table.py +66 -0
  36. phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py +157 -0
  37. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +280 -0
  38. phoenix/db/models.py +807 -0
  39. phoenix/exceptions.py +5 -1
  40. phoenix/experiments/__init__.py +6 -0
  41. phoenix/experiments/evaluators/__init__.py +29 -0
  42. phoenix/experiments/evaluators/base.py +158 -0
  43. phoenix/experiments/evaluators/code_evaluators.py +184 -0
  44. phoenix/experiments/evaluators/llm_evaluators.py +473 -0
  45. phoenix/experiments/evaluators/utils.py +236 -0
  46. phoenix/experiments/functions.py +772 -0
  47. phoenix/experiments/tracing.py +86 -0
  48. phoenix/experiments/types.py +726 -0
  49. phoenix/experiments/utils.py +25 -0
  50. phoenix/inferences/__init__.py +0 -0
  51. phoenix/{datasets → inferences}/errors.py +6 -5
  52. phoenix/{datasets → inferences}/fixtures.py +49 -42
  53. phoenix/{datasets/dataset.py → inferences/inferences.py} +121 -105
  54. phoenix/{datasets → inferences}/schema.py +11 -11
  55. phoenix/{datasets → inferences}/validation.py +13 -14
  56. phoenix/logging/__init__.py +3 -0
  57. phoenix/logging/_config.py +90 -0
  58. phoenix/logging/_filter.py +6 -0
  59. phoenix/logging/_formatter.py +69 -0
  60. phoenix/metrics/__init__.py +5 -4
  61. phoenix/metrics/binning.py +4 -3
  62. phoenix/metrics/metrics.py +2 -1
  63. phoenix/metrics/mixins.py +7 -6
  64. phoenix/metrics/retrieval_metrics.py +2 -1
  65. phoenix/metrics/timeseries.py +5 -4
  66. phoenix/metrics/wrappers.py +9 -3
  67. phoenix/pointcloud/clustering.py +5 -5
  68. phoenix/pointcloud/pointcloud.py +7 -5
  69. phoenix/pointcloud/projectors.py +5 -6
  70. phoenix/pointcloud/umap_parameters.py +53 -52
  71. phoenix/server/api/README.md +28 -0
  72. phoenix/server/api/auth.py +44 -0
  73. phoenix/server/api/context.py +152 -9
  74. phoenix/server/api/dataloaders/__init__.py +91 -0
  75. phoenix/server/api/dataloaders/annotation_summaries.py +139 -0
  76. phoenix/server/api/dataloaders/average_experiment_run_latency.py +54 -0
  77. phoenix/server/api/dataloaders/cache/__init__.py +3 -0
  78. phoenix/server/api/dataloaders/cache/two_tier_cache.py +68 -0
  79. phoenix/server/api/dataloaders/dataset_example_revisions.py +131 -0
  80. phoenix/server/api/dataloaders/dataset_example_spans.py +38 -0
  81. phoenix/server/api/dataloaders/document_evaluation_summaries.py +144 -0
  82. phoenix/server/api/dataloaders/document_evaluations.py +31 -0
  83. phoenix/server/api/dataloaders/document_retrieval_metrics.py +89 -0
  84. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +79 -0
  85. phoenix/server/api/dataloaders/experiment_error_rates.py +58 -0
  86. phoenix/server/api/dataloaders/experiment_run_annotations.py +36 -0
  87. phoenix/server/api/dataloaders/experiment_run_counts.py +49 -0
  88. phoenix/server/api/dataloaders/experiment_sequence_number.py +44 -0
  89. phoenix/server/api/dataloaders/latency_ms_quantile.py +188 -0
  90. phoenix/server/api/dataloaders/min_start_or_max_end_times.py +85 -0
  91. phoenix/server/api/dataloaders/project_by_name.py +31 -0
  92. phoenix/server/api/dataloaders/record_counts.py +116 -0
  93. phoenix/server/api/dataloaders/session_io.py +79 -0
  94. phoenix/server/api/dataloaders/session_num_traces.py +30 -0
  95. phoenix/server/api/dataloaders/session_num_traces_with_error.py +32 -0
  96. phoenix/server/api/dataloaders/session_token_usages.py +41 -0
  97. phoenix/server/api/dataloaders/session_trace_latency_ms_quantile.py +55 -0
  98. phoenix/server/api/dataloaders/span_annotations.py +26 -0
  99. phoenix/server/api/dataloaders/span_dataset_examples.py +31 -0
  100. phoenix/server/api/dataloaders/span_descendants.py +57 -0
  101. phoenix/server/api/dataloaders/span_projects.py +33 -0
  102. phoenix/server/api/dataloaders/token_counts.py +124 -0
  103. phoenix/server/api/dataloaders/trace_by_trace_ids.py +25 -0
  104. phoenix/server/api/dataloaders/trace_root_spans.py +32 -0
  105. phoenix/server/api/dataloaders/user_roles.py +30 -0
  106. phoenix/server/api/dataloaders/users.py +33 -0
  107. phoenix/server/api/exceptions.py +48 -0
  108. phoenix/server/api/helpers/__init__.py +12 -0
  109. phoenix/server/api/helpers/dataset_helpers.py +217 -0
  110. phoenix/server/api/helpers/experiment_run_filters.py +763 -0
  111. phoenix/server/api/helpers/playground_clients.py +948 -0
  112. phoenix/server/api/helpers/playground_registry.py +70 -0
  113. phoenix/server/api/helpers/playground_spans.py +455 -0
  114. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
  115. phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
  116. phoenix/server/api/input_types/ChatCompletionInput.py +38 -0
  117. phoenix/server/api/input_types/ChatCompletionMessageInput.py +24 -0
  118. phoenix/server/api/input_types/ClearProjectInput.py +15 -0
  119. phoenix/server/api/input_types/ClusterInput.py +2 -2
  120. phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
  121. phoenix/server/api/input_types/CreateSpanAnnotationInput.py +18 -0
  122. phoenix/server/api/input_types/CreateTraceAnnotationInput.py +18 -0
  123. phoenix/server/api/input_types/DataQualityMetricInput.py +5 -2
  124. phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
  125. phoenix/server/api/input_types/DatasetSort.py +17 -0
  126. phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
  127. phoenix/server/api/input_types/DeleteAnnotationsInput.py +7 -0
  128. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
  129. phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
  130. phoenix/server/api/input_types/DeleteExperimentsInput.py +7 -0
  131. phoenix/server/api/input_types/DimensionFilter.py +4 -4
  132. phoenix/server/api/input_types/GenerativeModelInput.py +17 -0
  133. phoenix/server/api/input_types/Granularity.py +1 -1
  134. phoenix/server/api/input_types/InvocationParameters.py +162 -0
  135. phoenix/server/api/input_types/PatchAnnotationInput.py +19 -0
  136. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
  137. phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
  138. phoenix/server/api/input_types/PerformanceMetricInput.py +5 -2
  139. phoenix/server/api/input_types/ProjectSessionSort.py +29 -0
  140. phoenix/server/api/input_types/SpanAnnotationSort.py +17 -0
  141. phoenix/server/api/input_types/SpanSort.py +134 -69
  142. phoenix/server/api/input_types/TemplateOptions.py +10 -0
  143. phoenix/server/api/input_types/TraceAnnotationSort.py +17 -0
  144. phoenix/server/api/input_types/UserRoleInput.py +9 -0
  145. phoenix/server/api/mutations/__init__.py +28 -0
  146. phoenix/server/api/mutations/api_key_mutations.py +167 -0
  147. phoenix/server/api/mutations/chat_mutations.py +593 -0
  148. phoenix/server/api/mutations/dataset_mutations.py +591 -0
  149. phoenix/server/api/mutations/experiment_mutations.py +75 -0
  150. phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +21 -18
  151. phoenix/server/api/mutations/project_mutations.py +57 -0
  152. phoenix/server/api/mutations/span_annotations_mutations.py +128 -0
  153. phoenix/server/api/mutations/trace_annotations_mutations.py +127 -0
  154. phoenix/server/api/mutations/user_mutations.py +329 -0
  155. phoenix/server/api/openapi/__init__.py +0 -0
  156. phoenix/server/api/openapi/main.py +17 -0
  157. phoenix/server/api/openapi/schema.py +16 -0
  158. phoenix/server/api/queries.py +738 -0
  159. phoenix/server/api/routers/__init__.py +11 -0
  160. phoenix/server/api/routers/auth.py +284 -0
  161. phoenix/server/api/routers/embeddings.py +26 -0
  162. phoenix/server/api/routers/oauth2.py +488 -0
  163. phoenix/server/api/routers/v1/__init__.py +64 -0
  164. phoenix/server/api/routers/v1/datasets.py +1017 -0
  165. phoenix/server/api/routers/v1/evaluations.py +362 -0
  166. phoenix/server/api/routers/v1/experiment_evaluations.py +115 -0
  167. phoenix/server/api/routers/v1/experiment_runs.py +167 -0
  168. phoenix/server/api/routers/v1/experiments.py +308 -0
  169. phoenix/server/api/routers/v1/pydantic_compat.py +78 -0
  170. phoenix/server/api/routers/v1/spans.py +267 -0
  171. phoenix/server/api/routers/v1/traces.py +208 -0
  172. phoenix/server/api/routers/v1/utils.py +95 -0
  173. phoenix/server/api/schema.py +44 -247
  174. phoenix/server/api/subscriptions.py +597 -0
  175. phoenix/server/api/types/Annotation.py +21 -0
  176. phoenix/server/api/types/AnnotationSummary.py +55 -0
  177. phoenix/server/api/types/AnnotatorKind.py +16 -0
  178. phoenix/server/api/types/ApiKey.py +27 -0
  179. phoenix/server/api/types/AuthMethod.py +9 -0
  180. phoenix/server/api/types/ChatCompletionMessageRole.py +11 -0
  181. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +46 -0
  182. phoenix/server/api/types/Cluster.py +25 -24
  183. phoenix/server/api/types/CreateDatasetPayload.py +8 -0
  184. phoenix/server/api/types/DataQualityMetric.py +31 -13
  185. phoenix/server/api/types/Dataset.py +288 -63
  186. phoenix/server/api/types/DatasetExample.py +85 -0
  187. phoenix/server/api/types/DatasetExampleRevision.py +34 -0
  188. phoenix/server/api/types/DatasetVersion.py +14 -0
  189. phoenix/server/api/types/Dimension.py +32 -31
  190. phoenix/server/api/types/DocumentEvaluationSummary.py +9 -8
  191. phoenix/server/api/types/EmbeddingDimension.py +56 -49
  192. phoenix/server/api/types/Evaluation.py +25 -31
  193. phoenix/server/api/types/EvaluationSummary.py +30 -50
  194. phoenix/server/api/types/Event.py +20 -20
  195. phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
  196. phoenix/server/api/types/Experiment.py +152 -0
  197. phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
  198. phoenix/server/api/types/ExperimentComparison.py +17 -0
  199. phoenix/server/api/types/ExperimentRun.py +119 -0
  200. phoenix/server/api/types/ExperimentRunAnnotation.py +56 -0
  201. phoenix/server/api/types/GenerativeModel.py +9 -0
  202. phoenix/server/api/types/GenerativeProvider.py +85 -0
  203. phoenix/server/api/types/Inferences.py +80 -0
  204. phoenix/server/api/types/InferencesRole.py +23 -0
  205. phoenix/server/api/types/LabelFraction.py +7 -0
  206. phoenix/server/api/types/MimeType.py +2 -2
  207. phoenix/server/api/types/Model.py +54 -54
  208. phoenix/server/api/types/PerformanceMetric.py +8 -5
  209. phoenix/server/api/types/Project.py +407 -142
  210. phoenix/server/api/types/ProjectSession.py +139 -0
  211. phoenix/server/api/types/Segments.py +4 -4
  212. phoenix/server/api/types/Span.py +221 -176
  213. phoenix/server/api/types/SpanAnnotation.py +43 -0
  214. phoenix/server/api/types/SpanIOValue.py +15 -0
  215. phoenix/server/api/types/SystemApiKey.py +9 -0
  216. phoenix/server/api/types/TemplateLanguage.py +10 -0
  217. phoenix/server/api/types/TimeSeries.py +19 -15
  218. phoenix/server/api/types/TokenUsage.py +11 -0
  219. phoenix/server/api/types/Trace.py +154 -0
  220. phoenix/server/api/types/TraceAnnotation.py +45 -0
  221. phoenix/server/api/types/UMAPPoints.py +7 -7
  222. phoenix/server/api/types/User.py +60 -0
  223. phoenix/server/api/types/UserApiKey.py +45 -0
  224. phoenix/server/api/types/UserRole.py +15 -0
  225. phoenix/server/api/types/node.py +13 -107
  226. phoenix/server/api/types/pagination.py +156 -57
  227. phoenix/server/api/utils.py +34 -0
  228. phoenix/server/app.py +864 -115
  229. phoenix/server/bearer_auth.py +163 -0
  230. phoenix/server/dml_event.py +136 -0
  231. phoenix/server/dml_event_handler.py +256 -0
  232. phoenix/server/email/__init__.py +0 -0
  233. phoenix/server/email/sender.py +97 -0
  234. phoenix/server/email/templates/__init__.py +0 -0
  235. phoenix/server/email/templates/password_reset.html +19 -0
  236. phoenix/server/email/types.py +11 -0
  237. phoenix/server/grpc_server.py +102 -0
  238. phoenix/server/jwt_store.py +505 -0
  239. phoenix/server/main.py +305 -116
  240. phoenix/server/oauth2.py +52 -0
  241. phoenix/server/openapi/__init__.py +0 -0
  242. phoenix/server/prometheus.py +111 -0
  243. phoenix/server/rate_limiters.py +188 -0
  244. phoenix/server/static/.vite/manifest.json +87 -0
  245. phoenix/server/static/assets/components-Cy9nwIvF.js +2125 -0
  246. phoenix/server/static/assets/index-BKvHIxkk.js +113 -0
  247. phoenix/server/static/assets/pages-CUi2xCVQ.js +4449 -0
  248. phoenix/server/static/assets/vendor-DvC8cT4X.js +894 -0
  249. phoenix/server/static/assets/vendor-DxkFTwjz.css +1 -0
  250. phoenix/server/static/assets/vendor-arizeai-Do1793cv.js +662 -0
  251. phoenix/server/static/assets/vendor-codemirror-BzwZPyJM.js +24 -0
  252. phoenix/server/static/assets/vendor-recharts-_Jb7JjhG.js +59 -0
  253. phoenix/server/static/assets/vendor-shiki-Cl9QBraO.js +5 -0
  254. phoenix/server/static/assets/vendor-three-DwGkEfCM.js +2998 -0
  255. phoenix/server/telemetry.py +68 -0
  256. phoenix/server/templates/index.html +82 -23
  257. phoenix/server/thread_server.py +3 -3
  258. phoenix/server/types.py +275 -0
  259. phoenix/services.py +27 -18
  260. phoenix/session/client.py +743 -68
  261. phoenix/session/data_extractor.py +31 -7
  262. phoenix/session/evaluation.py +3 -9
  263. phoenix/session/session.py +263 -219
  264. phoenix/settings.py +22 -0
  265. phoenix/trace/__init__.py +2 -22
  266. phoenix/trace/attributes.py +338 -0
  267. phoenix/trace/dsl/README.md +116 -0
  268. phoenix/trace/dsl/filter.py +663 -213
  269. phoenix/trace/dsl/helpers.py +73 -21
  270. phoenix/trace/dsl/query.py +574 -201
  271. phoenix/trace/exporter.py +24 -19
  272. phoenix/trace/fixtures.py +368 -32
  273. phoenix/trace/otel.py +71 -219
  274. phoenix/trace/projects.py +3 -2
  275. phoenix/trace/schemas.py +33 -11
  276. phoenix/trace/span_evaluations.py +21 -16
  277. phoenix/trace/span_json_decoder.py +6 -4
  278. phoenix/trace/span_json_encoder.py +2 -2
  279. phoenix/trace/trace_dataset.py +47 -32
  280. phoenix/trace/utils.py +21 -4
  281. phoenix/utilities/__init__.py +0 -26
  282. phoenix/utilities/client.py +132 -0
  283. phoenix/utilities/deprecation.py +31 -0
  284. phoenix/utilities/error_handling.py +3 -2
  285. phoenix/utilities/json.py +109 -0
  286. phoenix/utilities/logging.py +8 -0
  287. phoenix/utilities/project.py +2 -2
  288. phoenix/utilities/re.py +49 -0
  289. phoenix/utilities/span_store.py +0 -23
  290. phoenix/utilities/template_formatters.py +99 -0
  291. phoenix/version.py +1 -1
  292. arize_phoenix-3.16.0.dist-info/METADATA +0 -495
  293. arize_phoenix-3.16.0.dist-info/RECORD +0 -178
  294. phoenix/core/project.py +0 -617
  295. phoenix/core/traces.py +0 -100
  296. phoenix/experimental/evals/__init__.py +0 -73
  297. phoenix/experimental/evals/evaluators.py +0 -413
  298. phoenix/experimental/evals/functions/__init__.py +0 -4
  299. phoenix/experimental/evals/functions/classify.py +0 -453
  300. phoenix/experimental/evals/functions/executor.py +0 -353
  301. phoenix/experimental/evals/functions/generate.py +0 -138
  302. phoenix/experimental/evals/functions/processing.py +0 -76
  303. phoenix/experimental/evals/models/__init__.py +0 -14
  304. phoenix/experimental/evals/models/anthropic.py +0 -175
  305. phoenix/experimental/evals/models/base.py +0 -170
  306. phoenix/experimental/evals/models/bedrock.py +0 -221
  307. phoenix/experimental/evals/models/litellm.py +0 -134
  308. phoenix/experimental/evals/models/openai.py +0 -448
  309. phoenix/experimental/evals/models/rate_limiters.py +0 -246
  310. phoenix/experimental/evals/models/vertex.py +0 -173
  311. phoenix/experimental/evals/models/vertexai.py +0 -186
  312. phoenix/experimental/evals/retrievals.py +0 -96
  313. phoenix/experimental/evals/templates/__init__.py +0 -50
  314. phoenix/experimental/evals/templates/default_templates.py +0 -472
  315. phoenix/experimental/evals/templates/template.py +0 -195
  316. phoenix/experimental/evals/utils/__init__.py +0 -172
  317. phoenix/experimental/evals/utils/threads.py +0 -27
  318. phoenix/server/api/helpers.py +0 -11
  319. phoenix/server/api/routers/evaluation_handler.py +0 -109
  320. phoenix/server/api/routers/span_handler.py +0 -70
  321. phoenix/server/api/routers/trace_handler.py +0 -60
  322. phoenix/server/api/types/DatasetRole.py +0 -23
  323. phoenix/server/static/index.css +0 -6
  324. phoenix/server/static/index.js +0 -7447
  325. phoenix/storage/span_store/__init__.py +0 -23
  326. phoenix/storage/span_store/text_file.py +0 -85
  327. phoenix/trace/dsl/missing.py +0 -60
  328. phoenix/trace/langchain/__init__.py +0 -3
  329. phoenix/trace/langchain/instrumentor.py +0 -35
  330. phoenix/trace/llama_index/__init__.py +0 -3
  331. phoenix/trace/llama_index/callback.py +0 -102
  332. phoenix/trace/openai/__init__.py +0 -3
  333. phoenix/trace/openai/instrumentor.py +0 -30
  334. {arize_phoenix-3.16.0.dist-info → arize_phoenix-7.7.0.dist-info}/licenses/IP_NOTICE +0 -0
  335. {arize_phoenix-3.16.0.dist-info → arize_phoenix-7.7.0.dist-info}/licenses/LICENSE +0 -0
  336. /phoenix/{datasets → db/insertion}/__init__.py +0 -0
  337. /phoenix/{experimental → db/migrations}/__init__.py +0 -0
  338. /phoenix/{storage → db/migrations/data_migration_scripts}/__init__.py +0 -0
phoenix/server/api/queries.py
@@ -0,0 +1,738 @@
+ from collections import defaultdict
+ from datetime import datetime
+ from typing import Optional, Union
+
+ import numpy as np
+ import numpy.typing as npt
+ import strawberry
+ from sqlalchemy import and_, distinct, func, select
+ from sqlalchemy.orm import joinedload
+ from starlette.authentication import UnauthenticatedUser
+ from strawberry import ID, UNSET
+ from strawberry.relay import Connection, GlobalID, Node
+ from strawberry.types import Info
+ from typing_extensions import Annotated, TypeAlias
+
+ from phoenix.db import enums, models
+ from phoenix.db.models import (
+     DatasetExample as OrmExample,
+ )
+ from phoenix.db.models import (
+     DatasetExampleRevision as OrmRevision,
+ )
+ from phoenix.db.models import (
+     DatasetVersion as OrmVersion,
+ )
+ from phoenix.db.models import (
+     Experiment as OrmExperiment,
+ )
+ from phoenix.db.models import ExperimentRun as OrmExperimentRun
+ from phoenix.db.models import (
+     Trace as OrmTrace,
+ )
+ from phoenix.pointcloud.clustering import Hdbscan
+ from phoenix.server.api.auth import MSG_ADMIN_ONLY, IsAdmin
+ from phoenix.server.api.context import Context
+ from phoenix.server.api.exceptions import NotFound, Unauthorized
+ from phoenix.server.api.helpers import ensure_list
+ from phoenix.server.api.helpers.experiment_run_filters import (
+     ExperimentRunFilterConditionSyntaxError,
+     compile_sqlalchemy_filter_condition,
+     update_examples_query_with_filter_condition,
+ )
+ from phoenix.server.api.helpers.playground_clients import initialize_playground_clients
+ from phoenix.server.api.helpers.playground_registry import PLAYGROUND_CLIENT_REGISTRY
+ from phoenix.server.api.input_types.ClusterInput import ClusterInput
+ from phoenix.server.api.input_types.Coordinates import (
+     InputCoordinate2D,
+     InputCoordinate3D,
+ )
+ from phoenix.server.api.input_types.DatasetSort import DatasetSort
+ from phoenix.server.api.input_types.InvocationParameters import (
+     InvocationParameter,
+ )
+ from phoenix.server.api.subscriptions import PLAYGROUND_PROJECT_NAME
+ from phoenix.server.api.types.Cluster import Cluster, to_gql_clusters
+ from phoenix.server.api.types.Dataset import Dataset, to_gql_dataset
+ from phoenix.server.api.types.DatasetExample import DatasetExample
+ from phoenix.server.api.types.Dimension import to_gql_dimension
+ from phoenix.server.api.types.EmbeddingDimension import (
+     DEFAULT_CLUSTER_SELECTION_EPSILON,
+     DEFAULT_MIN_CLUSTER_SIZE,
+     DEFAULT_MIN_SAMPLES,
+     to_gql_embedding_dimension,
+ )
+ from phoenix.server.api.types.Event import create_event_id, unpack_event_id
+ from phoenix.server.api.types.Experiment import Experiment
+ from phoenix.server.api.types.ExperimentComparison import ExperimentComparison, RunComparisonItem
+ from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
+ from phoenix.server.api.types.Functionality import Functionality
+ from phoenix.server.api.types.GenerativeModel import GenerativeModel
+ from phoenix.server.api.types.GenerativeProvider import (
+     GenerativeProvider,
+     GenerativeProviderKey,
+ )
+ from phoenix.server.api.types.InferencesRole import AncillaryInferencesRole, InferencesRole
+ from phoenix.server.api.types.Model import Model
+ from phoenix.server.api.types.node import from_global_id, from_global_id_with_expected_type
+ from phoenix.server.api.types.pagination import (
+     ConnectionArgs,
+     CursorString,
+     connection_from_list,
+ )
+ from phoenix.server.api.types.Project import Project
+ from phoenix.server.api.types.ProjectSession import ProjectSession, to_gql_project_session
+ from phoenix.server.api.types.SortDir import SortDir
+ from phoenix.server.api.types.Span import Span, to_gql_span
+ from phoenix.server.api.types.SystemApiKey import SystemApiKey
+ from phoenix.server.api.types.Trace import to_gql_trace
+ from phoenix.server.api.types.User import User, to_gql_user
+ from phoenix.server.api.types.UserApiKey import UserApiKey, to_gql_api_key
+ from phoenix.server.api.types.UserRole import UserRole
+ from phoenix.server.api.types.ValidationResult import ValidationResult
+
+ initialize_playground_clients()
+
+
+ @strawberry.input
+ class ModelsInput:
+     provider_key: Optional[GenerativeProviderKey]
+     model_name: Optional[str] = None
+
+
+ @strawberry.type
+ class Query:
+     @strawberry.field
+     async def model_providers(self) -> list[GenerativeProvider]:
+         available_providers = PLAYGROUND_CLIENT_REGISTRY.list_all_providers()
+         return [
+             GenerativeProvider(
+                 name=provider_key.value,
+                 key=provider_key,
+             )
+             for provider_key in available_providers
+         ]
+
+     @strawberry.field
+     async def models(self, input: Optional[ModelsInput] = None) -> list[GenerativeModel]:
+         if input is not None and input.provider_key is not None:
+             supported_model_names = PLAYGROUND_CLIENT_REGISTRY.list_models(input.provider_key)
+             supported_models = [
+                 GenerativeModel(name=model_name, provider_key=input.provider_key)
+                 for model_name in supported_model_names
+             ]
+             return supported_models
+
+         registered_models = PLAYGROUND_CLIENT_REGISTRY.list_all_models()
+         all_models: list[GenerativeModel] = []
+         for provider_key, model_name in registered_models:
+             if model_name is not None and provider_key is not None:
+                 all_models.append(GenerativeModel(name=model_name, provider_key=provider_key))
+         return all_models
+
+     @strawberry.field
+     async def model_invocation_parameters(
+         self, input: Optional[ModelsInput] = None
+     ) -> list[InvocationParameter]:
+         if input is None:
+             return []
+         provider_key = input.provider_key
+         model_name = input.model_name
+         if provider_key is not None:
+             client = PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, model_name)
+             if client is None:
+                 return []
+             invocation_parameters = client.supported_invocation_parameters()
+             return invocation_parameters
+         else:
+             return []
+
+     @strawberry.field(permission_classes=[IsAdmin])  # type: ignore
+     async def users(
+         self,
+         info: Info[Context, None],
+         first: Optional[int] = 50,
+         last: Optional[int] = UNSET,
+         after: Optional[CursorString] = UNSET,
+         before: Optional[CursorString] = UNSET,
+     ) -> Connection[User]:
+         args = ConnectionArgs(
+             first=first,
+             after=after if isinstance(after, CursorString) else None,
+             last=last,
+             before=before if isinstance(before, CursorString) else None,
+         )
+         stmt = (
+             select(models.User)
+             .join(models.UserRole)
+             .where(models.UserRole.name != enums.UserRole.SYSTEM.value)
+             .order_by(models.User.email)
+             .options(joinedload(models.User.role))
+         )
+         async with info.context.db() as session:
+             users = await session.stream_scalars(stmt)
+             data = [to_gql_user(user) async for user in users]
+         return connection_from_list(data=data, args=args)
+
+     @strawberry.field
+     async def user_roles(
+         self,
+         info: Info[Context, None],
+     ) -> list[UserRole]:
+         async with info.context.db() as session:
+             roles = await session.scalars(
+                 select(models.UserRole).where(models.UserRole.name != enums.UserRole.SYSTEM.value)
+             )
+         return [
+             UserRole(
+                 id_attr=role.id,
+                 name=role.name,
+             )
+             for role in roles
+         ]
+
+     @strawberry.field(permission_classes=[IsAdmin])  # type: ignore
+     async def user_api_keys(self, info: Info[Context, None]) -> list[UserApiKey]:
+         stmt = (
+             select(models.ApiKey)
+             .join(models.User)
+             .join(models.UserRole)
+             .where(models.UserRole.name != enums.UserRole.SYSTEM.value)
+         )
+         async with info.context.db() as session:
+             api_keys = await session.scalars(stmt)
+         return [to_gql_api_key(api_key) for api_key in api_keys]
+
+     @strawberry.field(permission_classes=[IsAdmin])  # type: ignore
+     async def system_api_keys(self, info: Info[Context, None]) -> list[SystemApiKey]:
+         stmt = (
+             select(models.ApiKey)
+             .join(models.User)
+             .join(models.UserRole)
+             .where(models.UserRole.name == enums.UserRole.SYSTEM.value)
+         )
+         async with info.context.db() as session:
+             api_keys = await session.scalars(stmt)
+         return [
+             SystemApiKey(
+                 id_attr=api_key.id,
+                 name=api_key.name,
+                 description=api_key.description,
+                 created_at=api_key.created_at,
+                 expires_at=api_key.expires_at,
+             )
+             for api_key in api_keys
+         ]
+
+     @strawberry.field
+     async def projects(
+         self,
+         info: Info[Context, None],
+         first: Optional[int] = 50,
+         last: Optional[int] = UNSET,
+         after: Optional[CursorString] = UNSET,
+         before: Optional[CursorString] = UNSET,
+     ) -> Connection[Project]:
+         args = ConnectionArgs(
+             first=first,
+             after=after if isinstance(after, CursorString) else None,
+             last=last,
+             before=before if isinstance(before, CursorString) else None,
+         )
+         stmt = (
+             select(models.Project)
+             .outerjoin(
+                 models.Experiment,
+                 and_(
+                     models.Project.name == models.Experiment.project_name,
+                     models.Experiment.project_name != PLAYGROUND_PROJECT_NAME,
+                 ),
+             )
+             .where(models.Experiment.project_name.is_(None))
+             .order_by(models.Project.id)
+         )
+         async with info.context.db() as session:
+             projects = await session.stream_scalars(stmt)
+             data = [
+                 Project(
+                     id_attr=project.id,
+                     name=project.name,
+                     gradient_start_color=project.gradient_start_color,
+                     gradient_end_color=project.gradient_end_color,
+                 )
+                 async for project in projects
+             ]
+         return connection_from_list(data=data, args=args)
+
+     @strawberry.field
+     def projects_last_updated_at(self, info: Info[Context, None]) -> Optional[datetime]:
+         return info.context.last_updated_at.get(models.Project)
+
+     @strawberry.field
+     async def datasets(
+         self,
+         info: Info[Context, None],
+         first: Optional[int] = 50,
+         last: Optional[int] = UNSET,
+         after: Optional[CursorString] = UNSET,
+         before: Optional[CursorString] = UNSET,
+         sort: Optional[DatasetSort] = UNSET,
+     ) -> Connection[Dataset]:
+         args = ConnectionArgs(
+             first=first,
+             after=after if isinstance(after, CursorString) else None,
+             last=last,
+             before=before if isinstance(before, CursorString) else None,
+         )
+         stmt = select(models.Dataset)
+         if sort:
+             sort_col = getattr(models.Dataset, sort.col.value)
+             stmt = stmt.order_by(sort_col.desc() if sort.dir is SortDir.desc else sort_col.asc())
+         async with info.context.db() as session:
+             datasets = await session.scalars(stmt)
+         return connection_from_list(
+             data=[to_gql_dataset(dataset) for dataset in datasets], args=args
+         )
+
+     @strawberry.field
+     def datasets_last_updated_at(self, info: Info[Context, None]) -> Optional[datetime]:
+         return info.context.last_updated_at.get(models.Dataset)
+
+     @strawberry.field
+     async def compare_experiments(
+         self,
+         info: Info[Context, None],
+         experiment_ids: list[GlobalID],
+         filter_condition: Optional[str] = UNSET,
+     ) -> list[ExperimentComparison]:
+         experiment_ids_ = [
+             from_global_id_with_expected_type(experiment_id, OrmExperiment.__name__)
+             for experiment_id in experiment_ids
+         ]
+         if len(set(experiment_ids_)) != len(experiment_ids_):
+             raise ValueError("Experiment IDs must be unique.")
+
+         async with info.context.db() as session:
+             validation_result = (
+                 await session.execute(
+                     select(
+                         func.count(distinct(OrmVersion.dataset_id)),
+                         func.max(OrmVersion.dataset_id),
+                         func.max(OrmVersion.id),
+                         func.count(OrmExperiment.id),
+                     )
+                     .select_from(OrmVersion)
+                     .join(
+                         OrmExperiment,
+                         OrmExperiment.dataset_version_id == OrmVersion.id,
+                     )
+                     .where(
+                         OrmExperiment.id.in_(experiment_ids_),
+                     )
+                 )
+             ).first()
+             if validation_result is None:
+                 raise ValueError("No experiments could be found for input IDs.")
+
+             num_datasets, dataset_id, version_id, num_resolved_experiment_ids = validation_result
+             if num_datasets != 1:
+                 raise ValueError("Experiments must belong to the same dataset.")
+             if num_resolved_experiment_ids != len(experiment_ids_):
+                 raise ValueError("Unable to resolve one or more experiment IDs.")
+
+             revision_ids = (
+                 select(func.max(OrmRevision.id))
+                 .join(OrmExample, OrmExample.id == OrmRevision.dataset_example_id)
+                 .where(
+                     and_(
+                         OrmRevision.dataset_version_id <= version_id,
+                         OrmExample.dataset_id == dataset_id,
+                     )
+                 )
+                 .group_by(OrmRevision.dataset_example_id)
+                 .scalar_subquery()
+             )
+             examples_query = (
+                 select(OrmExample)
+                 .distinct(OrmExample.id)
+                 .join(
+                     OrmRevision,
+                     onclause=and_(
+                         OrmExample.id == OrmRevision.dataset_example_id,
+                         OrmRevision.id.in_(revision_ids),
+                         OrmRevision.revision_kind != "DELETE",
+                     ),
+                 )
+                 .order_by(OrmExample.id.desc())
+             )
+
+             if filter_condition:
+                 examples_query = update_examples_query_with_filter_condition(
+                     query=examples_query,
+                     filter_condition=filter_condition,
+                     experiment_ids=experiment_ids_,
+                 )
+
+             examples = (await session.scalars(examples_query)).all()
+
+             ExampleID: TypeAlias = int
+             ExperimentID: TypeAlias = int
+             runs: defaultdict[ExampleID, defaultdict[ExperimentID, list[OrmExperimentRun]]] = (
+                 defaultdict(lambda: defaultdict(list))
+             )
+             async for run in await session.stream_scalars(
+                 select(OrmExperimentRun)
+                 .where(
+                     and_(
+                         OrmExperimentRun.dataset_example_id.in_(example.id for example in examples),
+                         OrmExperimentRun.experiment_id.in_(experiment_ids_),
+                     )
+                 )
+                 .options(joinedload(OrmExperimentRun.trace).load_only(OrmTrace.trace_id))
+             ):
+                 runs[run.dataset_example_id][run.experiment_id].append(run)
+
+         experiment_comparisons = []
+         for example in examples:
+             run_comparison_items = []
+             for experiment_id in experiment_ids_:
+                 run_comparison_items.append(
+                     RunComparisonItem(
+                         experiment_id=GlobalID(Experiment.__name__, str(experiment_id)),
+                         runs=[
+                             to_gql_experiment_run(run)
+                             for run in sorted(
+                                 runs[example.id][experiment_id], key=lambda run: run.id
+                             )
+                         ],
+                     )
+                 )
+             experiment_comparisons.append(
+                 ExperimentComparison(
+                     example=DatasetExample(
+                         id_attr=example.id,
+                         created_at=example.created_at,
+                         version_id=version_id,
+                     ),
+                     run_comparison_items=run_comparison_items,
+                 )
+             )
+         return experiment_comparisons
+
+     @strawberry.field
+     async def validate_experiment_run_filter_condition(
+         self,
+         condition: str,
+         experiment_ids: list[GlobalID],
+     ) -> ValidationResult:
+         try:
+             compile_sqlalchemy_filter_condition(
+                 filter_condition=condition,
+                 experiment_ids=[
+                     from_global_id_with_expected_type(experiment_id, OrmExperiment.__name__)
+                     for experiment_id in experiment_ids
+                 ],
+             )
+             return ValidationResult(
+                 is_valid=True,
+                 error_message=None,
+             )
+         except ExperimentRunFilterConditionSyntaxError as error:
+             return ValidationResult(
+                 is_valid=False,
+                 error_message=str(error),
+             )
+
+     @strawberry.field
+     async def functionality(self, info: Info[Context, None]) -> "Functionality":
+         has_model_inferences = not info.context.model.is_empty
+         async with info.context.db() as session:
+             has_traces = (await session.scalar(select(models.Trace).limit(1))) is not None
+         return Functionality(
+             model_inferences=has_model_inferences,
+             tracing=has_traces,
+         )
+
+     @strawberry.field
+     def model(self) -> Model:
+         return Model()
+
+     @strawberry.field
+     async def node(self, id: GlobalID, info: Info[Context, None]) -> Node:
+         type_name, node_id = from_global_id(id)
+         if type_name == "Dimension":
+             dimension = info.context.model.scalar_dimensions[node_id]
+             return to_gql_dimension(node_id, dimension)
+         elif type_name == "EmbeddingDimension":
+             embedding_dimension = info.context.model.embedding_dimensions[node_id]
+             return to_gql_embedding_dimension(node_id, embedding_dimension)
+         elif type_name == "Project":
+             project_stmt = select(
+                 models.Project.id,
+                 models.Project.name,
+                 models.Project.gradient_start_color,
+                 models.Project.gradient_end_color,
+             ).where(models.Project.id == node_id)
+             async with info.context.db() as session:
+                 project = (await session.execute(project_stmt)).first()
+             if project is None:
+                 raise NotFound(f"Unknown project: {id}")
+             return Project(
+                 id_attr=project.id,
+                 name=project.name,
+                 gradient_start_color=project.gradient_start_color,
+                 gradient_end_color=project.gradient_end_color,
+             )
+         elif type_name == "Trace":
+             trace_stmt = select(models.Trace).filter_by(id=node_id)
+             async with info.context.db() as session:
+                 trace = await session.scalar(trace_stmt)
+             if trace is None:
+                 raise NotFound(f"Unknown trace: {id}")
+             return to_gql_trace(trace)
+         elif type_name == Span.__name__:
+             span_stmt = (
+                 select(models.Span)
+                 .options(
+                     joinedload(models.Span.trace, innerjoin=True).load_only(models.Trace.trace_id)
+                 )
+                 .where(models.Span.id == node_id)
+             )
+             async with info.context.db() as session:
+                 span = await session.scalar(span_stmt)
+             if span is None:
+                 raise NotFound(f"Unknown span: {id}")
+             return to_gql_span(span)
+         elif type_name == Dataset.__name__:
+             dataset_stmt = select(models.Dataset).where(models.Dataset.id == node_id)
+             async with info.context.db() as session:
+                 if (dataset := await session.scalar(dataset_stmt)) is None:
+                     raise NotFound(f"Unknown dataset: {id}")
+             return to_gql_dataset(dataset)
+         elif type_name == DatasetExample.__name__:
+             example_id = node_id
+             latest_revision_id = (
+                 select(func.max(models.DatasetExampleRevision.id))
+                 .where(models.DatasetExampleRevision.dataset_example_id == example_id)
+                 .scalar_subquery()
+             )
+             async with info.context.db() as session:
+                 example = await session.scalar(
+                     select(models.DatasetExample)
+                     .join(
+                         models.DatasetExampleRevision,
+                         onclause=models.DatasetExampleRevision.dataset_example_id
+                         == models.DatasetExample.id,
+                     )
+                     .where(
+                         and_(
+                             models.DatasetExample.id == example_id,
+                             models.DatasetExampleRevision.id == latest_revision_id,
+                             models.DatasetExampleRevision.revision_kind != "DELETE",
+                         )
+                     )
+                 )
+             if not example:
+                 raise NotFound(f"Unknown dataset example: {id}")
+             return DatasetExample(
+                 id_attr=example.id,
+                 created_at=example.created_at,
+             )
+         elif type_name == Experiment.__name__:
+             async with info.context.db() as session:
+                 experiment = await session.scalar(
+                     select(models.Experiment).where(models.Experiment.id == node_id)
+                 )
+             if not experiment:
+                 raise NotFound(f"Unknown experiment: {id}")
+             return Experiment(
+                 id_attr=experiment.id,
+                 name=experiment.name,
+                 project_name=experiment.project_name,
+                 description=experiment.description,
+                 created_at=experiment.created_at,
+                 updated_at=experiment.updated_at,
+                 metadata=experiment.metadata_,
+             )
+         elif type_name == ExperimentRun.__name__:
+             async with info.context.db() as session:
+                 if not (
+                     run := await session.scalar(
+                         select(models.ExperimentRun)
+                         .where(models.ExperimentRun.id == node_id)
+                         .options(
+                             joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id)
+                         )
+                     )
+                 ):
+                     raise NotFound(f"Unknown experiment run: {id}")
+             return to_gql_experiment_run(run)
+         elif type_name == User.__name__:
+             if int((user := info.context.user).identity) != node_id and not user.is_admin:
+                 raise Unauthorized(MSG_ADMIN_ONLY)
+             async with info.context.db() as session:
+                 if not (
+                     user := await session.scalar(
+                         select(models.User).where(models.User.id == node_id)
+                     )
+                 ):
+                     raise NotFound(f"Unknown user: {id}")
+             return to_gql_user(user)
+         elif type_name == ProjectSession.__name__:
+             async with info.context.db() as session:
+                 if not (
+                     project_session := await session.scalar(
+                         select(models.ProjectSession).filter_by(id=node_id)
+                     )
+                 ):
+                     raise NotFound(f"Unknown user: {id}")
+             return to_gql_project_session(project_session)
+         raise NotFound(f"Unknown node type: {type_name}")
+
+     @strawberry.field
+     async def viewer(self, info: Info[Context, None]) -> Optional[User]:
+         request = info.context.get_request()
+         try:
+             user = request.user
+         except AssertionError:
+             return None
+         if isinstance(user, UnauthenticatedUser):
+             return None
+         async with info.context.db() as session:
+             if (
+                 user := await session.scalar(
+                     select(models.User)
+                     .where(models.User.id == int(user.identity))
+                     .options(joinedload(models.User.role))
+                 )
+             ) is None:
+                 return None
+         return to_gql_user(user)
+
+     @strawberry.field
+     def clusters(
+         self,
+         clusters: list[ClusterInput],
+     ) -> list[Cluster]:
+         clustered_events: dict[str, set[ID]] = defaultdict(set)
+         for i, cluster in enumerate(clusters):
+             clustered_events[cluster.id or str(i)].update(cluster.event_ids)
+         return to_gql_clusters(
+             clustered_events=clustered_events,
+         )
+
+     @strawberry.field
+     def hdbscan_clustering(
+         self,
+         info: Info[Context, None],
+         event_ids: Annotated[
+             list[ID],
+             strawberry.argument(
+                 description="Event ID of the coordinates",
+             ),
+         ],
+         coordinates_2d: Annotated[
+             Optional[list[InputCoordinate2D]],
+             strawberry.argument(
+                 description="Point coordinates. Must be either 2D or 3D.",
+             ),
+         ] = UNSET,
+         coordinates_3d: Annotated[
+             Optional[list[InputCoordinate3D]],
+             strawberry.argument(
+                 description="Point coordinates. Must be either 2D or 3D.",
+             ),
+         ] = UNSET,
+         min_cluster_size: Annotated[
+             int,
+             strawberry.argument(
+                 description="HDBSCAN minimum cluster size",
+             ),
+         ] = DEFAULT_MIN_CLUSTER_SIZE,
+         cluster_min_samples: Annotated[
+             int,
+             strawberry.argument(
+                 description="HDBSCAN minimum samples",
+             ),
+         ] = DEFAULT_MIN_SAMPLES,
+         cluster_selection_epsilon: Annotated[
+             float,
+             strawberry.argument(
+                 description="HDBSCAN cluster selection epsilon",
+             ),
+         ] = DEFAULT_CLUSTER_SELECTION_EPSILON,
+     ) -> list[Cluster]:
+         coordinates_3d = ensure_list(coordinates_3d)
+         coordinates_2d = ensure_list(coordinates_2d)
+
+         if len(coordinates_3d) > 0 and len(coordinates_2d) > 0:
+             raise ValueError("must specify only one of 2D or 3D coordinates")
+
+         if len(coordinates_3d) > 0:
+             coordinates = list(
+                 map(
+                     lambda coord: np.array(
+                         [coord.x, coord.y, coord.z],
+                     ),
+                     coordinates_3d,
+                 )
+             )
+         else:
+             coordinates = list(
+                 map(
+                     lambda coord: np.array(
+                         [coord.x, coord.y],
+                     ),
+                     coordinates_2d,
+                 )
+             )
+
+         if len(event_ids) != len(coordinates):
+             raise ValueError(
+                 f"length mismatch between "
+                 f"event_ids ({len(event_ids)}) "
+                 f"and coordinates ({len(coordinates)})"
+             )
+
+         if len(event_ids) == 0:
+             return []
+
+         grouped_event_ids: dict[
+             Union[InferencesRole, AncillaryInferencesRole],
+             list[ID],
+         ] = defaultdict(list)
+         grouped_coordinates: dict[
+             Union[InferencesRole, AncillaryInferencesRole],
+             list[npt.NDArray[np.float64]],
+         ] = defaultdict(list)
+
+         for event_id, coordinate in zip(event_ids, coordinates):
+             row_id, inferences_role = unpack_event_id(event_id)
+             grouped_coordinates[inferences_role].append(coordinate)
+             grouped_event_ids[inferences_role].append(create_event_id(row_id, inferences_role))
+
+         stacked_event_ids = (
+             grouped_event_ids[InferencesRole.primary]
+             + grouped_event_ids[InferencesRole.reference]
+             + grouped_event_ids[AncillaryInferencesRole.corpus]
+         )
+         stacked_coordinates = np.stack(
+             grouped_coordinates[InferencesRole.primary]
+             + grouped_coordinates[InferencesRole.reference]
+             + grouped_coordinates[AncillaryInferencesRole.corpus]
+         )
+
+         clusters = Hdbscan(
+             min_cluster_size=min_cluster_size,
+             min_samples=cluster_min_samples,
+             cluster_selection_epsilon=cluster_selection_epsilon,
+         ).find_clusters(stacked_coordinates)
+
+         clustered_events = {
+             str(i): {stacked_event_ids[row_idx] for row_idx in cluster}
+             for i, cluster in enumerate(clusters)
+         }
+
+         return to_gql_clusters(
+             clustered_events=clustered_events,
+         )
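
For orientation, the file shown above (phoenix/server/api/queries.py) defines the GraphQL Query root that the Phoenix server mounts (see phoenix/server/app.py in the file list). A minimal client-side sketch of exercising one of its resolvers, assuming a locally running Phoenix server, a GraphQL endpoint at http://localhost:6006/graphql, and the httpx package installed (none of which are part of this diff):

import httpx

# Assumed local endpoint; adjust host/port to wherever the Phoenix server is running.
PHOENIX_GRAPHQL_URL = "http://localhost:6006/graphql"

# Calls the `projects` resolver defined in queries.py and walks the
# relay-style connection it returns.
QUERY = """
query ($first: Int) {
  projects(first: $first) {
    edges {
      node {
        name
      }
    }
  }
}
"""

response = httpx.post(PHOENIX_GRAPHQL_URL, json={"query": QUERY, "variables": {"first": 10}})
response.raise_for_status()
for edge in response.json()["data"]["projects"]["edges"]:
    print(edge["node"]["name"])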