arize-phoenix 3.16.0__py3-none-any.whl → 7.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic; consult the registry's advisory page for details.

Files changed (338)
  1. arize_phoenix-7.7.0.dist-info/METADATA +261 -0
  2. arize_phoenix-7.7.0.dist-info/RECORD +345 -0
  3. {arize_phoenix-3.16.0.dist-info → arize_phoenix-7.7.0.dist-info}/WHEEL +1 -1
  4. arize_phoenix-7.7.0.dist-info/entry_points.txt +3 -0
  5. phoenix/__init__.py +86 -14
  6. phoenix/auth.py +309 -0
  7. phoenix/config.py +675 -45
  8. phoenix/core/model.py +32 -30
  9. phoenix/core/model_schema.py +102 -109
  10. phoenix/core/model_schema_adapter.py +48 -45
  11. phoenix/datetime_utils.py +24 -3
  12. phoenix/db/README.md +54 -0
  13. phoenix/db/__init__.py +4 -0
  14. phoenix/db/alembic.ini +85 -0
  15. phoenix/db/bulk_inserter.py +294 -0
  16. phoenix/db/engines.py +208 -0
  17. phoenix/db/enums.py +20 -0
  18. phoenix/db/facilitator.py +113 -0
  19. phoenix/db/helpers.py +159 -0
  20. phoenix/db/insertion/constants.py +2 -0
  21. phoenix/db/insertion/dataset.py +227 -0
  22. phoenix/db/insertion/document_annotation.py +171 -0
  23. phoenix/db/insertion/evaluation.py +191 -0
  24. phoenix/db/insertion/helpers.py +98 -0
  25. phoenix/db/insertion/span.py +193 -0
  26. phoenix/db/insertion/span_annotation.py +158 -0
  27. phoenix/db/insertion/trace_annotation.py +158 -0
  28. phoenix/db/insertion/types.py +256 -0
  29. phoenix/db/migrate.py +86 -0
  30. phoenix/db/migrations/data_migration_scripts/populate_project_sessions.py +199 -0
  31. phoenix/db/migrations/env.py +114 -0
  32. phoenix/db/migrations/script.py.mako +26 -0
  33. phoenix/db/migrations/versions/10460e46d750_datasets.py +317 -0
  34. phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +126 -0
  35. phoenix/db/migrations/versions/4ded9e43755f_create_project_sessions_table.py +66 -0
  36. phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py +157 -0
  37. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +280 -0
  38. phoenix/db/models.py +807 -0
  39. phoenix/exceptions.py +5 -1
  40. phoenix/experiments/__init__.py +6 -0
  41. phoenix/experiments/evaluators/__init__.py +29 -0
  42. phoenix/experiments/evaluators/base.py +158 -0
  43. phoenix/experiments/evaluators/code_evaluators.py +184 -0
  44. phoenix/experiments/evaluators/llm_evaluators.py +473 -0
  45. phoenix/experiments/evaluators/utils.py +236 -0
  46. phoenix/experiments/functions.py +772 -0
  47. phoenix/experiments/tracing.py +86 -0
  48. phoenix/experiments/types.py +726 -0
  49. phoenix/experiments/utils.py +25 -0
  50. phoenix/inferences/__init__.py +0 -0
  51. phoenix/{datasets → inferences}/errors.py +6 -5
  52. phoenix/{datasets → inferences}/fixtures.py +49 -42
  53. phoenix/{datasets/dataset.py → inferences/inferences.py} +121 -105
  54. phoenix/{datasets → inferences}/schema.py +11 -11
  55. phoenix/{datasets → inferences}/validation.py +13 -14
  56. phoenix/logging/__init__.py +3 -0
  57. phoenix/logging/_config.py +90 -0
  58. phoenix/logging/_filter.py +6 -0
  59. phoenix/logging/_formatter.py +69 -0
  60. phoenix/metrics/__init__.py +5 -4
  61. phoenix/metrics/binning.py +4 -3
  62. phoenix/metrics/metrics.py +2 -1
  63. phoenix/metrics/mixins.py +7 -6
  64. phoenix/metrics/retrieval_metrics.py +2 -1
  65. phoenix/metrics/timeseries.py +5 -4
  66. phoenix/metrics/wrappers.py +9 -3
  67. phoenix/pointcloud/clustering.py +5 -5
  68. phoenix/pointcloud/pointcloud.py +7 -5
  69. phoenix/pointcloud/projectors.py +5 -6
  70. phoenix/pointcloud/umap_parameters.py +53 -52
  71. phoenix/server/api/README.md +28 -0
  72. phoenix/server/api/auth.py +44 -0
  73. phoenix/server/api/context.py +152 -9
  74. phoenix/server/api/dataloaders/__init__.py +91 -0
  75. phoenix/server/api/dataloaders/annotation_summaries.py +139 -0
  76. phoenix/server/api/dataloaders/average_experiment_run_latency.py +54 -0
  77. phoenix/server/api/dataloaders/cache/__init__.py +3 -0
  78. phoenix/server/api/dataloaders/cache/two_tier_cache.py +68 -0
  79. phoenix/server/api/dataloaders/dataset_example_revisions.py +131 -0
  80. phoenix/server/api/dataloaders/dataset_example_spans.py +38 -0
  81. phoenix/server/api/dataloaders/document_evaluation_summaries.py +144 -0
  82. phoenix/server/api/dataloaders/document_evaluations.py +31 -0
  83. phoenix/server/api/dataloaders/document_retrieval_metrics.py +89 -0
  84. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +79 -0
  85. phoenix/server/api/dataloaders/experiment_error_rates.py +58 -0
  86. phoenix/server/api/dataloaders/experiment_run_annotations.py +36 -0
  87. phoenix/server/api/dataloaders/experiment_run_counts.py +49 -0
  88. phoenix/server/api/dataloaders/experiment_sequence_number.py +44 -0
  89. phoenix/server/api/dataloaders/latency_ms_quantile.py +188 -0
  90. phoenix/server/api/dataloaders/min_start_or_max_end_times.py +85 -0
  91. phoenix/server/api/dataloaders/project_by_name.py +31 -0
  92. phoenix/server/api/dataloaders/record_counts.py +116 -0
  93. phoenix/server/api/dataloaders/session_io.py +79 -0
  94. phoenix/server/api/dataloaders/session_num_traces.py +30 -0
  95. phoenix/server/api/dataloaders/session_num_traces_with_error.py +32 -0
  96. phoenix/server/api/dataloaders/session_token_usages.py +41 -0
  97. phoenix/server/api/dataloaders/session_trace_latency_ms_quantile.py +55 -0
  98. phoenix/server/api/dataloaders/span_annotations.py +26 -0
  99. phoenix/server/api/dataloaders/span_dataset_examples.py +31 -0
  100. phoenix/server/api/dataloaders/span_descendants.py +57 -0
  101. phoenix/server/api/dataloaders/span_projects.py +33 -0
  102. phoenix/server/api/dataloaders/token_counts.py +124 -0
  103. phoenix/server/api/dataloaders/trace_by_trace_ids.py +25 -0
  104. phoenix/server/api/dataloaders/trace_root_spans.py +32 -0
  105. phoenix/server/api/dataloaders/user_roles.py +30 -0
  106. phoenix/server/api/dataloaders/users.py +33 -0
  107. phoenix/server/api/exceptions.py +48 -0
  108. phoenix/server/api/helpers/__init__.py +12 -0
  109. phoenix/server/api/helpers/dataset_helpers.py +217 -0
  110. phoenix/server/api/helpers/experiment_run_filters.py +763 -0
  111. phoenix/server/api/helpers/playground_clients.py +948 -0
  112. phoenix/server/api/helpers/playground_registry.py +70 -0
  113. phoenix/server/api/helpers/playground_spans.py +455 -0
  114. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
  115. phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
  116. phoenix/server/api/input_types/ChatCompletionInput.py +38 -0
  117. phoenix/server/api/input_types/ChatCompletionMessageInput.py +24 -0
  118. phoenix/server/api/input_types/ClearProjectInput.py +15 -0
  119. phoenix/server/api/input_types/ClusterInput.py +2 -2
  120. phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
  121. phoenix/server/api/input_types/CreateSpanAnnotationInput.py +18 -0
  122. phoenix/server/api/input_types/CreateTraceAnnotationInput.py +18 -0
  123. phoenix/server/api/input_types/DataQualityMetricInput.py +5 -2
  124. phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
  125. phoenix/server/api/input_types/DatasetSort.py +17 -0
  126. phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
  127. phoenix/server/api/input_types/DeleteAnnotationsInput.py +7 -0
  128. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
  129. phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
  130. phoenix/server/api/input_types/DeleteExperimentsInput.py +7 -0
  131. phoenix/server/api/input_types/DimensionFilter.py +4 -4
  132. phoenix/server/api/input_types/GenerativeModelInput.py +17 -0
  133. phoenix/server/api/input_types/Granularity.py +1 -1
  134. phoenix/server/api/input_types/InvocationParameters.py +162 -0
  135. phoenix/server/api/input_types/PatchAnnotationInput.py +19 -0
  136. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
  137. phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
  138. phoenix/server/api/input_types/PerformanceMetricInput.py +5 -2
  139. phoenix/server/api/input_types/ProjectSessionSort.py +29 -0
  140. phoenix/server/api/input_types/SpanAnnotationSort.py +17 -0
  141. phoenix/server/api/input_types/SpanSort.py +134 -69
  142. phoenix/server/api/input_types/TemplateOptions.py +10 -0
  143. phoenix/server/api/input_types/TraceAnnotationSort.py +17 -0
  144. phoenix/server/api/input_types/UserRoleInput.py +9 -0
  145. phoenix/server/api/mutations/__init__.py +28 -0
  146. phoenix/server/api/mutations/api_key_mutations.py +167 -0
  147. phoenix/server/api/mutations/chat_mutations.py +593 -0
  148. phoenix/server/api/mutations/dataset_mutations.py +591 -0
  149. phoenix/server/api/mutations/experiment_mutations.py +75 -0
  150. phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +21 -18
  151. phoenix/server/api/mutations/project_mutations.py +57 -0
  152. phoenix/server/api/mutations/span_annotations_mutations.py +128 -0
  153. phoenix/server/api/mutations/trace_annotations_mutations.py +127 -0
  154. phoenix/server/api/mutations/user_mutations.py +329 -0
  155. phoenix/server/api/openapi/__init__.py +0 -0
  156. phoenix/server/api/openapi/main.py +17 -0
  157. phoenix/server/api/openapi/schema.py +16 -0
  158. phoenix/server/api/queries.py +738 -0
  159. phoenix/server/api/routers/__init__.py +11 -0
  160. phoenix/server/api/routers/auth.py +284 -0
  161. phoenix/server/api/routers/embeddings.py +26 -0
  162. phoenix/server/api/routers/oauth2.py +488 -0
  163. phoenix/server/api/routers/v1/__init__.py +64 -0
  164. phoenix/server/api/routers/v1/datasets.py +1017 -0
  165. phoenix/server/api/routers/v1/evaluations.py +362 -0
  166. phoenix/server/api/routers/v1/experiment_evaluations.py +115 -0
  167. phoenix/server/api/routers/v1/experiment_runs.py +167 -0
  168. phoenix/server/api/routers/v1/experiments.py +308 -0
  169. phoenix/server/api/routers/v1/pydantic_compat.py +78 -0
  170. phoenix/server/api/routers/v1/spans.py +267 -0
  171. phoenix/server/api/routers/v1/traces.py +208 -0
  172. phoenix/server/api/routers/v1/utils.py +95 -0
  173. phoenix/server/api/schema.py +44 -247
  174. phoenix/server/api/subscriptions.py +597 -0
  175. phoenix/server/api/types/Annotation.py +21 -0
  176. phoenix/server/api/types/AnnotationSummary.py +55 -0
  177. phoenix/server/api/types/AnnotatorKind.py +16 -0
  178. phoenix/server/api/types/ApiKey.py +27 -0
  179. phoenix/server/api/types/AuthMethod.py +9 -0
  180. phoenix/server/api/types/ChatCompletionMessageRole.py +11 -0
  181. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +46 -0
  182. phoenix/server/api/types/Cluster.py +25 -24
  183. phoenix/server/api/types/CreateDatasetPayload.py +8 -0
  184. phoenix/server/api/types/DataQualityMetric.py +31 -13
  185. phoenix/server/api/types/Dataset.py +288 -63
  186. phoenix/server/api/types/DatasetExample.py +85 -0
  187. phoenix/server/api/types/DatasetExampleRevision.py +34 -0
  188. phoenix/server/api/types/DatasetVersion.py +14 -0
  189. phoenix/server/api/types/Dimension.py +32 -31
  190. phoenix/server/api/types/DocumentEvaluationSummary.py +9 -8
  191. phoenix/server/api/types/EmbeddingDimension.py +56 -49
  192. phoenix/server/api/types/Evaluation.py +25 -31
  193. phoenix/server/api/types/EvaluationSummary.py +30 -50
  194. phoenix/server/api/types/Event.py +20 -20
  195. phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
  196. phoenix/server/api/types/Experiment.py +152 -0
  197. phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
  198. phoenix/server/api/types/ExperimentComparison.py +17 -0
  199. phoenix/server/api/types/ExperimentRun.py +119 -0
  200. phoenix/server/api/types/ExperimentRunAnnotation.py +56 -0
  201. phoenix/server/api/types/GenerativeModel.py +9 -0
  202. phoenix/server/api/types/GenerativeProvider.py +85 -0
  203. phoenix/server/api/types/Inferences.py +80 -0
  204. phoenix/server/api/types/InferencesRole.py +23 -0
  205. phoenix/server/api/types/LabelFraction.py +7 -0
  206. phoenix/server/api/types/MimeType.py +2 -2
  207. phoenix/server/api/types/Model.py +54 -54
  208. phoenix/server/api/types/PerformanceMetric.py +8 -5
  209. phoenix/server/api/types/Project.py +407 -142
  210. phoenix/server/api/types/ProjectSession.py +139 -0
  211. phoenix/server/api/types/Segments.py +4 -4
  212. phoenix/server/api/types/Span.py +221 -176
  213. phoenix/server/api/types/SpanAnnotation.py +43 -0
  214. phoenix/server/api/types/SpanIOValue.py +15 -0
  215. phoenix/server/api/types/SystemApiKey.py +9 -0
  216. phoenix/server/api/types/TemplateLanguage.py +10 -0
  217. phoenix/server/api/types/TimeSeries.py +19 -15
  218. phoenix/server/api/types/TokenUsage.py +11 -0
  219. phoenix/server/api/types/Trace.py +154 -0
  220. phoenix/server/api/types/TraceAnnotation.py +45 -0
  221. phoenix/server/api/types/UMAPPoints.py +7 -7
  222. phoenix/server/api/types/User.py +60 -0
  223. phoenix/server/api/types/UserApiKey.py +45 -0
  224. phoenix/server/api/types/UserRole.py +15 -0
  225. phoenix/server/api/types/node.py +13 -107
  226. phoenix/server/api/types/pagination.py +156 -57
  227. phoenix/server/api/utils.py +34 -0
  228. phoenix/server/app.py +864 -115
  229. phoenix/server/bearer_auth.py +163 -0
  230. phoenix/server/dml_event.py +136 -0
  231. phoenix/server/dml_event_handler.py +256 -0
  232. phoenix/server/email/__init__.py +0 -0
  233. phoenix/server/email/sender.py +97 -0
  234. phoenix/server/email/templates/__init__.py +0 -0
  235. phoenix/server/email/templates/password_reset.html +19 -0
  236. phoenix/server/email/types.py +11 -0
  237. phoenix/server/grpc_server.py +102 -0
  238. phoenix/server/jwt_store.py +505 -0
  239. phoenix/server/main.py +305 -116
  240. phoenix/server/oauth2.py +52 -0
  241. phoenix/server/openapi/__init__.py +0 -0
  242. phoenix/server/prometheus.py +111 -0
  243. phoenix/server/rate_limiters.py +188 -0
  244. phoenix/server/static/.vite/manifest.json +87 -0
  245. phoenix/server/static/assets/components-Cy9nwIvF.js +2125 -0
  246. phoenix/server/static/assets/index-BKvHIxkk.js +113 -0
  247. phoenix/server/static/assets/pages-CUi2xCVQ.js +4449 -0
  248. phoenix/server/static/assets/vendor-DvC8cT4X.js +894 -0
  249. phoenix/server/static/assets/vendor-DxkFTwjz.css +1 -0
  250. phoenix/server/static/assets/vendor-arizeai-Do1793cv.js +662 -0
  251. phoenix/server/static/assets/vendor-codemirror-BzwZPyJM.js +24 -0
  252. phoenix/server/static/assets/vendor-recharts-_Jb7JjhG.js +59 -0
  253. phoenix/server/static/assets/vendor-shiki-Cl9QBraO.js +5 -0
  254. phoenix/server/static/assets/vendor-three-DwGkEfCM.js +2998 -0
  255. phoenix/server/telemetry.py +68 -0
  256. phoenix/server/templates/index.html +82 -23
  257. phoenix/server/thread_server.py +3 -3
  258. phoenix/server/types.py +275 -0
  259. phoenix/services.py +27 -18
  260. phoenix/session/client.py +743 -68
  261. phoenix/session/data_extractor.py +31 -7
  262. phoenix/session/evaluation.py +3 -9
  263. phoenix/session/session.py +263 -219
  264. phoenix/settings.py +22 -0
  265. phoenix/trace/__init__.py +2 -22
  266. phoenix/trace/attributes.py +338 -0
  267. phoenix/trace/dsl/README.md +116 -0
  268. phoenix/trace/dsl/filter.py +663 -213
  269. phoenix/trace/dsl/helpers.py +73 -21
  270. phoenix/trace/dsl/query.py +574 -201
  271. phoenix/trace/exporter.py +24 -19
  272. phoenix/trace/fixtures.py +368 -32
  273. phoenix/trace/otel.py +71 -219
  274. phoenix/trace/projects.py +3 -2
  275. phoenix/trace/schemas.py +33 -11
  276. phoenix/trace/span_evaluations.py +21 -16
  277. phoenix/trace/span_json_decoder.py +6 -4
  278. phoenix/trace/span_json_encoder.py +2 -2
  279. phoenix/trace/trace_dataset.py +47 -32
  280. phoenix/trace/utils.py +21 -4
  281. phoenix/utilities/__init__.py +0 -26
  282. phoenix/utilities/client.py +132 -0
  283. phoenix/utilities/deprecation.py +31 -0
  284. phoenix/utilities/error_handling.py +3 -2
  285. phoenix/utilities/json.py +109 -0
  286. phoenix/utilities/logging.py +8 -0
  287. phoenix/utilities/project.py +2 -2
  288. phoenix/utilities/re.py +49 -0
  289. phoenix/utilities/span_store.py +0 -23
  290. phoenix/utilities/template_formatters.py +99 -0
  291. phoenix/version.py +1 -1
  292. arize_phoenix-3.16.0.dist-info/METADATA +0 -495
  293. arize_phoenix-3.16.0.dist-info/RECORD +0 -178
  294. phoenix/core/project.py +0 -617
  295. phoenix/core/traces.py +0 -100
  296. phoenix/experimental/evals/__init__.py +0 -73
  297. phoenix/experimental/evals/evaluators.py +0 -413
  298. phoenix/experimental/evals/functions/__init__.py +0 -4
  299. phoenix/experimental/evals/functions/classify.py +0 -453
  300. phoenix/experimental/evals/functions/executor.py +0 -353
  301. phoenix/experimental/evals/functions/generate.py +0 -138
  302. phoenix/experimental/evals/functions/processing.py +0 -76
  303. phoenix/experimental/evals/models/__init__.py +0 -14
  304. phoenix/experimental/evals/models/anthropic.py +0 -175
  305. phoenix/experimental/evals/models/base.py +0 -170
  306. phoenix/experimental/evals/models/bedrock.py +0 -221
  307. phoenix/experimental/evals/models/litellm.py +0 -134
  308. phoenix/experimental/evals/models/openai.py +0 -448
  309. phoenix/experimental/evals/models/rate_limiters.py +0 -246
  310. phoenix/experimental/evals/models/vertex.py +0 -173
  311. phoenix/experimental/evals/models/vertexai.py +0 -186
  312. phoenix/experimental/evals/retrievals.py +0 -96
  313. phoenix/experimental/evals/templates/__init__.py +0 -50
  314. phoenix/experimental/evals/templates/default_templates.py +0 -472
  315. phoenix/experimental/evals/templates/template.py +0 -195
  316. phoenix/experimental/evals/utils/__init__.py +0 -172
  317. phoenix/experimental/evals/utils/threads.py +0 -27
  318. phoenix/server/api/helpers.py +0 -11
  319. phoenix/server/api/routers/evaluation_handler.py +0 -109
  320. phoenix/server/api/routers/span_handler.py +0 -70
  321. phoenix/server/api/routers/trace_handler.py +0 -60
  322. phoenix/server/api/types/DatasetRole.py +0 -23
  323. phoenix/server/static/index.css +0 -6
  324. phoenix/server/static/index.js +0 -7447
  325. phoenix/storage/span_store/__init__.py +0 -23
  326. phoenix/storage/span_store/text_file.py +0 -85
  327. phoenix/trace/dsl/missing.py +0 -60
  328. phoenix/trace/langchain/__init__.py +0 -3
  329. phoenix/trace/langchain/instrumentor.py +0 -35
  330. phoenix/trace/llama_index/__init__.py +0 -3
  331. phoenix/trace/llama_index/callback.py +0 -102
  332. phoenix/trace/openai/__init__.py +0 -3
  333. phoenix/trace/openai/instrumentor.py +0 -30
  334. {arize_phoenix-3.16.0.dist-info → arize_phoenix-7.7.0.dist-info}/licenses/IP_NOTICE +0 -0
  335. {arize_phoenix-3.16.0.dist-info → arize_phoenix-7.7.0.dist-info}/licenses/LICENSE +0 -0
  336. /phoenix/{datasets → db/insertion}/__init__.py +0 -0
  337. /phoenix/{experimental → db/migrations}/__init__.py +0 -0
  338. /phoenix/{storage → db/migrations/data_migration_scripts}/__init__.py +0 -0
@@ -1,13 +1,14 @@
1
1
  import json
2
+ from collections.abc import Mapping
2
3
  from dataclasses import asdict, dataclass, replace
3
- from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
4
+ from typing import Any, Optional, Union
4
5
 
5
- EmbeddingFeatures = Dict[str, "EmbeddingColumnNames"]
6
+ EmbeddingFeatures = dict[str, "EmbeddingColumnNames"]
6
7
  SchemaFieldName = str
7
- SchemaFieldValue = Union[Optional[str], Optional[List[str]], Optional[EmbeddingFeatures]]
8
+ SchemaFieldValue = Union[Optional[str], Optional[list[str]], Optional[EmbeddingFeatures]]
8
9
 
9
- MULTI_COLUMN_SCHEMA_FIELD_NAMES: Tuple[str, ...] = ("feature_column_names", "tag_column_names")
10
- SINGLE_COLUMN_SCHEMA_FIELD_NAMES: Tuple[str, ...] = (
10
+ MULTI_COLUMN_SCHEMA_FIELD_NAMES: tuple[str, ...] = ("feature_column_names", "tag_column_names")
11
+ SINGLE_COLUMN_SCHEMA_FIELD_NAMES: tuple[str, ...] = (
11
12
  "prediction_id_column_name",
12
13
  "timestamp_column_name",
13
14
  "prediction_label_column_name",
@@ -19,7 +20,7 @@ LLM_SCHEMA_FIELD_NAMES = ["prompt_column_names", "response_column_names"]
19
20
 
20
21
 
21
22
  @dataclass(frozen=True)
22
- class EmbeddingColumnNames(Dict[str, Any]):
23
+ class EmbeddingColumnNames(dict[str, Any]):
23
24
  """
24
25
  A dataclass to hold the column names for the embedding features.
25
26
  An embedding feature is a feature that is represented by a vector.
@@ -34,7 +35,6 @@ class EmbeddingColumnNames(Dict[str, Any]):
34
35
  @dataclass(frozen=True)
35
36
  class RetrievalEmbeddingColumnNames(EmbeddingColumnNames):
36
37
  """
37
- *** Experimental ***
38
38
  A relationship is a column that maps a prediction to another record.
39
39
 
40
40
  Example
@@ -81,8 +81,8 @@ class Schema:
81
81
  prediction_id_column_name: Optional[str] = None
82
82
  id_column_name: Optional[str] = None # Syntax sugar for prediction_id_column_name
83
83
  timestamp_column_name: Optional[str] = None
84
- feature_column_names: Optional[List[str]] = None
85
- tag_column_names: Optional[List[str]] = None
84
+ feature_column_names: Optional[list[str]] = None
85
+ tag_column_names: Optional[list[str]] = None
86
86
  prediction_label_column_name: Optional[str] = None
87
87
  prediction_score_column_name: Optional[str] = None
88
88
  actual_label_column_name: Optional[str] = None
@@ -92,7 +92,7 @@ class Schema:
92
92
  # document_column_names is used explicitly when the schema is used to capture a corpus
93
93
  document_column_names: Optional[EmbeddingColumnNames] = None
94
94
  embedding_feature_column_names: Optional[EmbeddingFeatures] = None
95
- excluded_column_names: Optional[List[str]] = None
95
+ excluded_column_names: Optional[list[str]] = None
96
96
 
97
97
  def __post_init__(self) -> None:
98
98
  # re-map document_column_names to be in the prompt_column_names position
@@ -108,7 +108,7 @@ class Schema:
108
108
  def replace(self, **changes: Any) -> "Schema":
109
109
  return replace(self, **changes)
110
110
 
111
- def asdict(self) -> Dict[str, str]:
111
+ def asdict(self) -> dict[str, str]:
112
112
  return asdict(self)
113
113
 
114
114
  def to_json(self) -> str:
@@ -1,5 +1,4 @@
1
1
  import math
2
- from typing import List
3
2
 
4
3
  import numpy as np
5
4
  from pandas import DataFrame, Series
@@ -11,8 +10,8 @@ from .schema import EmbeddingColumnNames, Schema
11
10
  RESERVED_EMBEDDING_NAMES = ("prompt", "response")
12
11
 
13
12
 
14
- def _check_valid_schema(schema: Schema) -> List[err.ValidationError]:
15
- errs: List[str] = []
13
+ def _check_valid_schema(schema: Schema) -> list[err.ValidationError]:
14
+ errs: list[str] = []
16
15
  if schema.excluded_column_names is None:
17
16
  return []
18
17
 
@@ -34,7 +33,7 @@ def _check_valid_schema(schema: Schema) -> List[err.ValidationError]:
34
33
  return []
35
34
 
36
35
 
37
- def validate_dataset_inputs(dataframe: DataFrame, schema: Schema) -> List[err.ValidationError]:
36
+ def validate_inferences_inputs(dataframe: DataFrame, schema: Schema) -> list[err.ValidationError]:
38
37
  errors = _check_missing_columns(dataframe, schema)
39
38
  if errors:
40
39
  return errors
@@ -53,12 +52,12 @@ def validate_dataset_inputs(dataframe: DataFrame, schema: Schema) -> List[err.Va
53
52
  return []
54
53
 
55
54
 
56
- def _check_valid_embedding_data(dataframe: DataFrame, schema: Schema) -> List[err.ValidationError]:
55
+ def _check_valid_embedding_data(dataframe: DataFrame, schema: Schema) -> list[err.ValidationError]:
57
56
  embedding_col_names = schema.embedding_feature_column_names
58
57
  if embedding_col_names is None:
59
58
  return []
60
59
 
61
- embedding_errors: List[err.ValidationError] = []
60
+ embedding_errors: list[err.ValidationError] = []
62
61
  for embedding_name, column_names in embedding_col_names.items():
63
62
  if embedding_name in RESERVED_EMBEDDING_NAMES:
64
63
  embedding_errors += _validate_reserved_embedding_name(embedding_name, schema)
@@ -71,8 +70,8 @@ def _check_valid_embedding_data(dataframe: DataFrame, schema: Schema) -> List[er
71
70
 
72
71
  def _check_valid_prompt_response_data(
73
72
  dataframe: DataFrame, schema: Schema
74
- ) -> List[err.ValidationError]:
75
- prompt_response_errors: List[err.ValidationError] = []
73
+ ) -> list[err.ValidationError]:
74
+ prompt_response_errors: list[err.ValidationError] = []
76
75
 
77
76
  prompt_response_column_names = {
78
77
  "prompt": schema.prompt_column_names,
@@ -89,7 +88,7 @@ def _check_valid_prompt_response_data(
89
88
 
90
89
  def _validate_reserved_embedding_name(
91
90
  embedding_name: str, schema: Schema
92
- ) -> List[err.ValidationError]:
91
+ ) -> list[err.ValidationError]:
93
92
  if embedding_name == "prompt" and schema.prompt_column_names is not None:
94
93
  return [err.InvalidEmbeddingReservedName(embedding_name, "schema.prompt_column_names")]
95
94
  elif embedding_name == "response" and schema.response_column_names is not None:
@@ -99,9 +98,9 @@ def _validate_reserved_embedding_name(
99
98
 
100
99
  def _validate_embedding_vector(
101
100
  dataframe: DataFrame, name: str, vector_column_name: str
102
- ) -> List[err.ValidationError]:
101
+ ) -> list[err.ValidationError]:
103
102
  vector_column = dataframe[vector_column_name]
104
- errors: List[err.ValidationError] = []
103
+ errors: list[err.ValidationError] = []
105
104
  vector_length = None
106
105
 
107
106
  for vector in vector_column:
@@ -156,8 +155,8 @@ def _validate_embedding_vector(
156
155
  return errors
157
156
 
158
157
 
159
- def _check_column_types(dataframe: DataFrame, schema: Schema) -> List[err.ValidationError]:
160
- wrong_type_cols: List[str] = []
158
+ def _check_column_types(dataframe: DataFrame, schema: Schema) -> list[err.ValidationError]:
159
+ wrong_type_cols: list[str] = []
161
160
  if schema.prediction_id_column_name is not None:
162
161
  if not (
163
162
  is_numeric_dtype(dataframe.dtypes[schema.prediction_id_column_name])
@@ -172,7 +171,7 @@ def _check_column_types(dataframe: DataFrame, schema: Schema) -> List[err.Valida
172
171
  return []
173
172
 
174
173
 
175
- def _check_missing_columns(dataframe: DataFrame, schema: Schema) -> List[err.ValidationError]:
174
+ def _check_missing_columns(dataframe: DataFrame, schema: Schema) -> list[err.ValidationError]:
176
175
  # converting to a set first makes the checks run a lot faster
177
176
  existing_columns = set(dataframe.columns)
178
177
  missing_columns = []
# Public API of the phoenix.logging package: expose only setup_logging,
# which configures either default (library) or structured (application) logging.
from ._config import setup_logging

__all__ = ["setup_logging"]
@@ -0,0 +1,90 @@
1
+ import atexit
2
+ import logging
3
+ import logging.config
4
+ import logging.handlers
5
+ import queue
6
+ from sys import stderr, stdout
7
+
8
+ from typing_extensions import assert_never
9
+
10
+ from phoenix.config import LoggingMode
11
+ from phoenix.logging._filter import NonErrorFilter
12
+ from phoenix.settings import Settings
13
+
14
+ from ._formatter import PhoenixJSONFormatter
15
+
16
+
def setup_logging() -> None:
    """
    Configure logging according to the mode selected in ``Settings``.

    Dispatches to the structured (application) configuration or the
    default (library) configuration; any unknown mode is rejected by
    ``assert_never`` so the type checker flags unhandled enum members.
    """
    mode = Settings.logging_mode
    if mode is LoggingMode.STRUCTURED:
        _setup_application_logging()
    elif mode is LoggingMode.DEFAULT:
        _setup_library_logging()
    else:
        assert_never(mode)
def _setup_library_logging() -> None:
    """
    Configure logging for the case where Phoenix is embedded as a library.

    Only sets log levels on the ``phoenix`` and ``sqlalchemy`` loggers,
    leaving handler setup and propagation to the host application.
    """
    for logger_name, level in (
        ("phoenix", Settings.logging_level),
        ("sqlalchemy", Settings.db_logging_level),
    ):
        logging.getLogger(logger_name).setLevel(level)
    logging.getLogger("phoenix").info("Default logging ready")
def _setup_application_logging() -> None:
    """
    Configure logging for the case where Phoenix runs as an application.

    Records from the ``phoenix`` and ``sqlalchemy`` loggers are funneled
    through a queue to JSON-formatted stream handlers running on a
    listener thread: records at INFO and below go to stdout, WARNING and
    above go to stderr.
    """
    # Remove any handlers previously installed on the sqlalchemy engine
    # logger so records are not emitted twice once our handler is attached.
    sql_engine_logger = logging.getLogger("sqlalchemy.engine.Engine")
    for handler in sql_engine_logger.handlers[:]:
        sql_engine_logger.removeHandler(handler)
        handler.close()

    phoenix_logger = logging.getLogger("phoenix")
    phoenix_logger.setLevel(Settings.logging_level)
    phoenix_logger.propagate = False  # Do not pass records to the root logger
    sql_logger = logging.getLogger("sqlalchemy")
    sql_logger.setLevel(Settings.db_logging_level)
    sql_logger.propagate = False  # Do not pass records to the root logger

    # Emitting loggers only enqueue; the stream handlers below run on the
    # QueueListener's thread, so logging never blocks on I/O.
    log_queue: "queue.Queue[logging.LogRecord]" = queue.Queue()
    queue_handler = logging.handlers.QueueHandler(log_queue)
    phoenix_logger.addHandler(queue_handler)
    sql_logger.addHandler(queue_handler)

    # Maps output JSON keys to LogRecord attribute names.
    fmt_keys = {
        "level": "levelname",
        "message": "message",
        "timestamp": "timestamp",
        "logger": "name",
        "module": "module",
        "function": "funcName",
        "line": "lineno",
        "thread_name": "threadName",
    }
    formatter = PhoenixJSONFormatter(fmt_keys=fmt_keys)

    # stdout handler: non-error records only (NonErrorFilter caps at INFO)
    stdout_handler = logging.StreamHandler(stdout)
    stdout_handler.setFormatter(formatter)
    stdout_handler.setLevel(Settings.logging_level)
    stdout_handler.addFilter(NonErrorFilter())

    # stderr handler: warnings and errors
    stderr_handler = logging.StreamHandler(stderr)
    stderr_handler.setFormatter(formatter)
    stderr_handler.setLevel(logging.WARNING)

    # Fix: the former `if queue_listener is not None` guard was dead code --
    # the QueueListener constructor cannot return None -- so it was removed.
    queue_listener = logging.handlers.QueueListener(log_queue, stdout_handler, stderr_handler)
    queue_listener.start()
    atexit.register(queue_listener.stop)
    phoenix_logger.info("Structured logging ready")
@@ -0,0 +1,6 @@
1
+ import logging
2
+
3
+
class NonErrorFilter(logging.Filter):
    """Admit only records at or below INFO severity."""

    def filter(self, record: logging.LogRecord) -> bool:
        # Reject anything above INFO (WARNING and up).
        return not (record.levelno > logging.INFO)
@@ -0,0 +1,69 @@
1
+ import datetime as dt
2
+ import json
3
+ import logging
4
+ from typing import Optional
5
+
# Attribute names that every LogRecord carries by default; anything else on a
# record is a user-supplied "extra" and gets copied into the JSON output.
LOG_RECORD_BUILTIN_ATTRS = {
    "args",
    "asctime",
    "created",
    "exc_info",
    "exc_text",
    "filename",
    "funcName",
    "levelname",
    "levelno",
    "lineno",
    "module",
    "msecs",
    "message",
    "msg",
    "name",
    "pathname",
    "process",
    "processName",
    "relativeCreated",
    "stack_info",
    "thread",
    "threadName",
    "taskName",
}


class PhoenixJSONFormatter(logging.Formatter):
    """
    Format log records as single-line JSON objects.

    ``fmt_keys`` maps output JSON keys to ``LogRecord`` attribute names.
    A UTC ISO-8601 ``timestamp`` and the rendered ``message`` are always
    included, as are ``exc_info``/``stack_info`` when present, plus any
    user-supplied "extra" attributes on the record.
    """

    def __init__(
        self,
        *,
        fmt_keys: Optional[dict[str, str]] = None,
    ):
        super().__init__()
        self.fmt_keys = fmt_keys if fmt_keys is not None else {}

    def format(self, record: logging.LogRecord) -> str:
        """Return the record serialized as a JSON string."""
        message = self._prepare_log_dict(record)
        # default=str: non-JSON-serializable values are stringified rather
        # than raising inside the logging pipeline.
        return json.dumps(message, default=str)

    def _prepare_log_dict(self, record: logging.LogRecord) -> dict[str, object]:
        # Fix: return type was annotated dict[str, str], but values include
        # arbitrary record attributes and extras -- corrected to dict[str, object].
        always_fields: dict[str, object] = {
            "message": record.getMessage(),
            "timestamp": dt.datetime.fromtimestamp(record.created, tz=dt.timezone.utc).isoformat(),
        }
        if record.exc_info is not None:
            always_fields["exc_info"] = self.formatException(record.exc_info)

        if record.stack_info is not None:
            always_fields["stack_info"] = self.formatStack(record.stack_info)

        # fmt_keys entries pull from always_fields first (consuming them),
        # then fall back to the raw record attribute.
        message = {
            key: msg_val
            if (msg_val := always_fields.pop(val, None)) is not None
            else getattr(record, val)
            for key, val in self.fmt_keys.items()
        }
        message.update(always_fields)

        # Copy user-supplied "extra" attributes into the output.
        for key, val in record.__dict__.items():
            if key not in LOG_RECORD_BUILTIN_ATTRS:
                message[key] = val

        return message
@@ -1,8 +1,9 @@
1
1
  import logging
2
2
  import warnings
3
3
  from abc import ABC, abstractmethod
4
+ from collections.abc import Iterable, Mapping
4
5
  from dataclasses import dataclass
5
- from typing import Any, Iterable, List, Mapping, Optional, Union
6
+ from typing import Any, Optional, Union
6
7
 
7
8
  import numpy as np
8
9
  import pandas as pd
@@ -36,13 +37,13 @@ class Metric(ABC):
36
37
  def calc(self, dataframe: pd.DataFrame) -> Any: ...
37
38
 
38
39
  @abstractmethod
39
- def operands(self) -> List[Column]: ...
40
+ def operands(self) -> list[Column]: ...
40
41
 
41
42
  def __call__(
42
43
  self,
43
44
  df: pd.DataFrame,
44
45
  /,
45
- subset_rows: Optional[Union[slice, List[int]]] = None,
46
+ subset_rows: Optional[Union[slice, list[int]]] = None,
46
47
  ) -> Any:
47
48
  """
48
49
  Computes the metric on a dataframe.
@@ -51,7 +52,7 @@ class Metric(ABC):
51
52
  ----------
52
53
  df: pandas DataFrame
53
54
  The dataframe input to the metric.
54
- subset_rows: Optional[Union[slice, List[int]]] = None
55
+ subset_rows: Optional[Union[slice, list[int]]] = None
55
56
  Optionally specifying a subset of rows for the computation.
56
57
  Can be a list or slice (e.g. `slice(100, 200)`) of integers.
57
58
  """
@@ -1,8 +1,9 @@
1
1
  import warnings
2
2
  from abc import ABC, abstractmethod
3
+ from collections.abc import Iterable, Sequence
3
4
  from dataclasses import dataclass
4
5
  from functools import partial
5
- from typing import Any, Iterable, Optional, Sequence, cast
6
+ from typing import Any, Optional, cast
6
7
 
7
8
  import numpy as np
8
9
  import pandas as pd
@@ -78,7 +79,7 @@ class IntervalBinning(BinningMethod):
78
79
  else pd.IntervalIndex(
79
80
  (
80
81
  pd.Interval(
81
- np.NINF,
82
+ -np.inf,
82
83
  np.inf,
83
84
  closed="neither",
84
85
  ),
@@ -208,7 +209,7 @@ class QuantileBinning(IntervalBinning):
208
209
  # Extend min and max to infinities, unless len(breaks) < 3,
209
210
  # in which case the min is kept and two bins are created.
210
211
  breaks = breaks[1:-1] if len(breaks) > 2 else breaks[:1]
211
- breaks = [np.NINF] + breaks + [np.inf]
212
+ breaks = [-np.inf] + breaks + [np.inf]
212
213
  return pd.IntervalIndex.from_breaks(
213
214
  breaks,
214
215
  closed="left",
@@ -1,8 +1,9 @@
1
1
  import math
2
2
  import warnings
3
+ from collections.abc import Callable
3
4
  from dataclasses import dataclass, field
4
5
  from functools import cached_property
5
- from typing import Callable, Union, cast
6
+ from typing import Union, cast
6
7
 
7
8
  import numpy as np
8
9
  import numpy.typing as npt
phoenix/metrics/mixins.py CHANGED
@@ -7,10 +7,11 @@ on cooperative multiple inheritance and method resolution order in Python.
7
7
  import collections
8
8
  import inspect
9
9
  from abc import ABC, abstractmethod
10
+ from collections.abc import Callable
10
11
  from dataclasses import dataclass, field, fields, replace
11
12
  from functools import cached_property
12
13
  from itertools import repeat
13
- from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Mapping, Optional
14
+ from typing import TYPE_CHECKING, Any, Iterator, Mapping, Optional
14
15
 
15
16
  import numpy as np
16
17
  import pandas as pd
@@ -42,7 +43,7 @@ class VectorOperator(ABC):
42
43
 
43
44
  @dataclass(frozen=True)
44
45
  class NullaryOperator(Metric, ABC):
45
- def operands(self) -> List[Column]:
46
+ def operands(self) -> list[Column]:
46
47
  return []
47
48
 
48
49
 
@@ -55,7 +56,7 @@ class UnaryOperator(Metric, ABC):
55
56
 
56
57
  operand: Column = Column()
57
58
 
58
- def operands(self) -> List[Column]:
59
+ def operands(self) -> list[Column]:
59
60
  return [self.operand]
60
61
 
61
62
 
@@ -98,10 +99,10 @@ class EvaluationMetricKeywordParameters(_BaseMapping):
98
99
  return sum(1 for _ in self)
99
100
 
100
101
  @property
101
- def columns(self) -> List[Column]:
102
+ def columns(self) -> list[Column]:
102
103
  return [v for v in self.values() if isinstance(v, Column)]
103
104
 
104
- def __call__(self, df: pd.DataFrame) -> Dict[str, Any]:
105
+ def __call__(self, df: pd.DataFrame) -> dict[str, Any]:
105
106
  return {k: v(df) if isinstance(v, Column) else v for k, v in self.items()}
106
107
 
107
108
 
@@ -142,7 +143,7 @@ class EvaluationMetric(Metric, ABC):
142
143
  ),
143
144
  )
144
145
 
145
- def operands(self) -> List[Column]:
146
+ def operands(self) -> list[Column]:
146
147
  return [self.actual, self.predicted] + self.parameters.columns
147
148
 
148
149
  def calc(self, df: pd.DataFrame) -> float:
@@ -1,5 +1,6 @@
1
+ from collections.abc import Iterable
1
2
  from dataclasses import dataclass, field
2
- from typing import Iterable, Optional, cast
3
+ from typing import Optional, cast
3
4
 
4
5
  import numpy as np
5
6
  import pandas as pd
@@ -1,7 +1,8 @@
1
+ from collections.abc import Callable, Iterable, Iterator
1
2
  from datetime import datetime, timedelta, timezone
2
3
  from functools import partial
3
4
  from itertools import accumulate, repeat
4
- from typing import Callable, Iterable, Iterator, Tuple, cast
5
+ from typing import cast
5
6
 
6
7
  import pandas as pd
7
8
  from typing_extensions import TypeAlias
@@ -41,12 +42,12 @@ def row_interval_from_sorted_time_index(
41
42
  time_index: pd.DatetimeIndex,
42
43
  time_start: datetime,
43
44
  time_stop: datetime,
44
- ) -> Tuple[StartIndex, StopIndex]:
45
+ ) -> tuple[StartIndex, StopIndex]:
45
46
  """
46
47
  Returns end exclusive time slice from sorted index.
47
48
  """
48
49
  return cast(
49
- Tuple[StartIndex, StopIndex],
50
+ tuple[StartIndex, StopIndex],
50
51
  time_index.searchsorted((time_start, time_stop)),
51
52
  )
52
53
 
@@ -86,7 +87,7 @@ def _groupers(
86
87
  end_time: datetime,
87
88
  evaluation_window: timedelta,
88
89
  sampling_interval: timedelta,
89
- ) -> Iterator[Tuple[StartTime, EndTime, pd.Grouper]]:
90
+ ) -> Iterator[tuple[StartTime, EndTime, pd.Grouper]]:
90
91
  """
91
92
  Yields pandas.Groupers from time series parameters.
92
93
  """
@@ -18,7 +18,7 @@ from abc import ABC
18
18
  from enum import Enum
19
19
  from inspect import Signature
20
20
  from itertools import chain, islice
21
- from typing import Any, Dict, List, Tuple, cast
21
+ from typing import Any, cast
22
22
 
23
23
  import numpy as np
24
24
  import pandas as pd
@@ -27,6 +27,8 @@ from sklearn import metrics as sk
27
27
  from sklearn.utils.multiclass import check_classification_targets
28
28
  from wrapt import PartialCallableObjectProxy
29
29
 
30
+ from phoenix.config import SKLEARN_VERSION
31
+
30
32
 
31
33
  class Eval(PartialCallableObjectProxy, ABC): # type: ignore
32
34
  def __call__(
@@ -157,7 +159,7 @@ def _coerce_dtype_if_necessary(
157
159
  def _eliminate_missing_values_from_all_series(
158
160
  *args: Any,
159
161
  **kwargs: Any,
160
- ) -> Tuple[List[Any], Dict[str, Any]]:
162
+ ) -> tuple[list[Any], dict[str, Any]]:
161
163
  positional_arguments = list(args)
162
164
  keyword_arguments = dict(kwargs)
163
165
  all_series = [
@@ -232,5 +234,9 @@ class SkEval(Enum):
232
234
  r2_score = RegressionEval(sk.r2_score)
233
235
  recall_score = ClassificationEval(sk.recall_score)
234
236
  roc_auc_score = ScoredClassificationEval(sk.roc_auc_score)
235
- root_mean_squared_error = RegressionEval(sk.mean_squared_error, squared=False)
237
+ root_mean_squared_error = (
238
+ RegressionEval(sk.mean_squared_error, squared=False)
239
+ if SKLEARN_VERSION < (1, 6)
240
+ else RegressionEval(sk.root_mean_squared_error)
241
+ )
236
242
  zero_one_loss = ClassificationEval(sk.zero_one_loss)
@@ -1,13 +1,11 @@
1
1
  from dataclasses import asdict, dataclass
2
- from typing import List, Set
3
2
 
4
3
  import numpy as np
5
4
  import numpy.typing as npt
6
- from hdbscan import HDBSCAN
7
5
  from typing_extensions import TypeAlias
8
6
 
9
7
  RowIndex: TypeAlias = int
10
- RawCluster: TypeAlias = Set[RowIndex]
8
+ RawCluster: TypeAlias = set[RowIndex]
11
9
  Matrix: TypeAlias = npt.NDArray[np.float64]
12
10
 
13
11
 
@@ -17,9 +15,11 @@ class Hdbscan:
17
15
  min_samples: float = 1
18
16
  cluster_selection_epsilon: float = 0.0
19
17
 
20
- def find_clusters(self, mat: Matrix) -> List[RawCluster]:
18
+ def find_clusters(self, mat: Matrix) -> list[RawCluster]:
19
+ from fast_hdbscan import HDBSCAN
20
+
21
21
  cluster_ids: npt.NDArray[np.int_] = HDBSCAN(**asdict(self)).fit_predict(mat)
22
- ans: List[RawCluster] = [set() for _ in range(np.max(cluster_ids) + 1)]
22
+ ans: list[RawCluster] = [set() for _ in range(np.max(cluster_ids) + 1)]
23
23
  for row_idx, cluster_id in enumerate(cluster_ids):
24
24
  if cluster_id > -1:
25
25
  ans[cluster_id].add(row_idx)
@@ -1,9 +1,9 @@
1
+ from collections.abc import Hashable, Mapping
1
2
  from dataclasses import dataclass
2
- from typing import Dict, List, Mapping, Protocol, Set, Tuple
3
+ from typing import Protocol, TypeVar
3
4
 
4
5
  import numpy as np
5
6
  import numpy.typing as npt
6
- from strawberry import ID
7
7
  from typing_extensions import TypeAlias
8
8
 
9
9
  from phoenix.pointcloud.clustering import RawCluster
@@ -12,13 +12,15 @@ Vector: TypeAlias = npt.NDArray[np.float64]
12
12
  Matrix: TypeAlias = npt.NDArray[np.float64]
13
13
  RowIndex: TypeAlias = int
14
14
 
15
+ _IdType = TypeVar("_IdType", bound=Hashable)
16
+
15
17
 
16
18
  class DimensionalityReducer(Protocol):
17
19
  def project(self, mat: Matrix, n_components: int) -> Matrix: ...
18
20
 
19
21
 
20
22
  class ClustersFinder(Protocol):
21
- def find_clusters(self, mat: Matrix) -> List[RawCluster]: ...
23
+ def find_clusters(self, mat: Matrix) -> list[RawCluster]: ...
22
24
 
23
25
 
24
26
  @dataclass(frozen=True)
@@ -28,9 +30,9 @@ class PointCloud:
28
30
 
29
31
  def generate(
30
32
  self,
31
- data: Mapping[ID, Vector],
33
+ data: Mapping[_IdType, Vector],
32
34
  n_components: int = 3,
33
- ) -> Tuple[Dict[ID, Vector], Dict[str, Set[ID]]]:
35
+ ) -> tuple[dict[_IdType, Vector], dict[str, set[_IdType]]]:
34
36
  """
35
37
  Given a set of vectors, projects them onto lower dimensions, and
36
38
  finds clusters among the projections.
@@ -6,12 +6,6 @@ import numpy as np
6
6
  import numpy.typing as npt
7
7
  from typing_extensions import TypeAlias
8
8
 
9
- with warnings.catch_warnings():
10
- from numba.core.errors import NumbaWarning
11
-
12
- warnings.simplefilter("ignore", category=NumbaWarning)
13
- from umap import UMAP
14
-
15
9
  Matrix: TypeAlias = npt.NDArray[np.float64]
16
10
 
17
11
 
@@ -25,6 +19,11 @@ class Umap:
25
19
  min_dist: float = 0.1
26
20
 
27
21
  def project(self, mat: Matrix, n_components: int) -> Matrix:
22
+ with warnings.catch_warnings():
23
+ from numba.core.errors import NumbaWarning
24
+
25
+ warnings.simplefilter("ignore", category=NumbaWarning)
26
+ from umap import UMAP
28
27
  config = asdict(self)
29
28
  config["n_components"] = n_components
30
29
  if len(mat) <= n_components: