arize-phoenix 3.16.0__py3-none-any.whl → 7.7.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their public registries, and is provided for informational purposes only.
Potentially problematic release: this version of arize-phoenix might be problematic.
- arize_phoenix-7.7.0.dist-info/METADATA +261 -0
- arize_phoenix-7.7.0.dist-info/RECORD +345 -0
- {arize_phoenix-3.16.0.dist-info → arize_phoenix-7.7.0.dist-info}/WHEEL +1 -1
- arize_phoenix-7.7.0.dist-info/entry_points.txt +3 -0
- phoenix/__init__.py +86 -14
- phoenix/auth.py +309 -0
- phoenix/config.py +675 -45
- phoenix/core/model.py +32 -30
- phoenix/core/model_schema.py +102 -109
- phoenix/core/model_schema_adapter.py +48 -45
- phoenix/datetime_utils.py +24 -3
- phoenix/db/README.md +54 -0
- phoenix/db/__init__.py +4 -0
- phoenix/db/alembic.ini +85 -0
- phoenix/db/bulk_inserter.py +294 -0
- phoenix/db/engines.py +208 -0
- phoenix/db/enums.py +20 -0
- phoenix/db/facilitator.py +113 -0
- phoenix/db/helpers.py +159 -0
- phoenix/db/insertion/constants.py +2 -0
- phoenix/db/insertion/dataset.py +227 -0
- phoenix/db/insertion/document_annotation.py +171 -0
- phoenix/db/insertion/evaluation.py +191 -0
- phoenix/db/insertion/helpers.py +98 -0
- phoenix/db/insertion/span.py +193 -0
- phoenix/db/insertion/span_annotation.py +158 -0
- phoenix/db/insertion/trace_annotation.py +158 -0
- phoenix/db/insertion/types.py +256 -0
- phoenix/db/migrate.py +86 -0
- phoenix/db/migrations/data_migration_scripts/populate_project_sessions.py +199 -0
- phoenix/db/migrations/env.py +114 -0
- phoenix/db/migrations/script.py.mako +26 -0
- phoenix/db/migrations/versions/10460e46d750_datasets.py +317 -0
- phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +126 -0
- phoenix/db/migrations/versions/4ded9e43755f_create_project_sessions_table.py +66 -0
- phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py +157 -0
- phoenix/db/migrations/versions/cf03bd6bae1d_init.py +280 -0
- phoenix/db/models.py +807 -0
- phoenix/exceptions.py +5 -1
- phoenix/experiments/__init__.py +6 -0
- phoenix/experiments/evaluators/__init__.py +29 -0
- phoenix/experiments/evaluators/base.py +158 -0
- phoenix/experiments/evaluators/code_evaluators.py +184 -0
- phoenix/experiments/evaluators/llm_evaluators.py +473 -0
- phoenix/experiments/evaluators/utils.py +236 -0
- phoenix/experiments/functions.py +772 -0
- phoenix/experiments/tracing.py +86 -0
- phoenix/experiments/types.py +726 -0
- phoenix/experiments/utils.py +25 -0
- phoenix/inferences/__init__.py +0 -0
- phoenix/{datasets → inferences}/errors.py +6 -5
- phoenix/{datasets → inferences}/fixtures.py +49 -42
- phoenix/{datasets/dataset.py → inferences/inferences.py} +121 -105
- phoenix/{datasets → inferences}/schema.py +11 -11
- phoenix/{datasets → inferences}/validation.py +13 -14
- phoenix/logging/__init__.py +3 -0
- phoenix/logging/_config.py +90 -0
- phoenix/logging/_filter.py +6 -0
- phoenix/logging/_formatter.py +69 -0
- phoenix/metrics/__init__.py +5 -4
- phoenix/metrics/binning.py +4 -3
- phoenix/metrics/metrics.py +2 -1
- phoenix/metrics/mixins.py +7 -6
- phoenix/metrics/retrieval_metrics.py +2 -1
- phoenix/metrics/timeseries.py +5 -4
- phoenix/metrics/wrappers.py +9 -3
- phoenix/pointcloud/clustering.py +5 -5
- phoenix/pointcloud/pointcloud.py +7 -5
- phoenix/pointcloud/projectors.py +5 -6
- phoenix/pointcloud/umap_parameters.py +53 -52
- phoenix/server/api/README.md +28 -0
- phoenix/server/api/auth.py +44 -0
- phoenix/server/api/context.py +152 -9
- phoenix/server/api/dataloaders/__init__.py +91 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +139 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +54 -0
- phoenix/server/api/dataloaders/cache/__init__.py +3 -0
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +68 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +131 -0
- phoenix/server/api/dataloaders/dataset_example_spans.py +38 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +144 -0
- phoenix/server/api/dataloaders/document_evaluations.py +31 -0
- phoenix/server/api/dataloaders/document_retrieval_metrics.py +89 -0
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +79 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +58 -0
- phoenix/server/api/dataloaders/experiment_run_annotations.py +36 -0
- phoenix/server/api/dataloaders/experiment_run_counts.py +49 -0
- phoenix/server/api/dataloaders/experiment_sequence_number.py +44 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +188 -0
- phoenix/server/api/dataloaders/min_start_or_max_end_times.py +85 -0
- phoenix/server/api/dataloaders/project_by_name.py +31 -0
- phoenix/server/api/dataloaders/record_counts.py +116 -0
- phoenix/server/api/dataloaders/session_io.py +79 -0
- phoenix/server/api/dataloaders/session_num_traces.py +30 -0
- phoenix/server/api/dataloaders/session_num_traces_with_error.py +32 -0
- phoenix/server/api/dataloaders/session_token_usages.py +41 -0
- phoenix/server/api/dataloaders/session_trace_latency_ms_quantile.py +55 -0
- phoenix/server/api/dataloaders/span_annotations.py +26 -0
- phoenix/server/api/dataloaders/span_dataset_examples.py +31 -0
- phoenix/server/api/dataloaders/span_descendants.py +57 -0
- phoenix/server/api/dataloaders/span_projects.py +33 -0
- phoenix/server/api/dataloaders/token_counts.py +124 -0
- phoenix/server/api/dataloaders/trace_by_trace_ids.py +25 -0
- phoenix/server/api/dataloaders/trace_root_spans.py +32 -0
- phoenix/server/api/dataloaders/user_roles.py +30 -0
- phoenix/server/api/dataloaders/users.py +33 -0
- phoenix/server/api/exceptions.py +48 -0
- phoenix/server/api/helpers/__init__.py +12 -0
- phoenix/server/api/helpers/dataset_helpers.py +217 -0
- phoenix/server/api/helpers/experiment_run_filters.py +763 -0
- phoenix/server/api/helpers/playground_clients.py +948 -0
- phoenix/server/api/helpers/playground_registry.py +70 -0
- phoenix/server/api/helpers/playground_spans.py +455 -0
- phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
- phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
- phoenix/server/api/input_types/ChatCompletionInput.py +38 -0
- phoenix/server/api/input_types/ChatCompletionMessageInput.py +24 -0
- phoenix/server/api/input_types/ClearProjectInput.py +15 -0
- phoenix/server/api/input_types/ClusterInput.py +2 -2
- phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
- phoenix/server/api/input_types/CreateSpanAnnotationInput.py +18 -0
- phoenix/server/api/input_types/CreateTraceAnnotationInput.py +18 -0
- phoenix/server/api/input_types/DataQualityMetricInput.py +5 -2
- phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
- phoenix/server/api/input_types/DatasetSort.py +17 -0
- phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
- phoenix/server/api/input_types/DeleteAnnotationsInput.py +7 -0
- phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
- phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
- phoenix/server/api/input_types/DeleteExperimentsInput.py +7 -0
- phoenix/server/api/input_types/DimensionFilter.py +4 -4
- phoenix/server/api/input_types/GenerativeModelInput.py +17 -0
- phoenix/server/api/input_types/Granularity.py +1 -1
- phoenix/server/api/input_types/InvocationParameters.py +162 -0
- phoenix/server/api/input_types/PatchAnnotationInput.py +19 -0
- phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
- phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
- phoenix/server/api/input_types/PerformanceMetricInput.py +5 -2
- phoenix/server/api/input_types/ProjectSessionSort.py +29 -0
- phoenix/server/api/input_types/SpanAnnotationSort.py +17 -0
- phoenix/server/api/input_types/SpanSort.py +134 -69
- phoenix/server/api/input_types/TemplateOptions.py +10 -0
- phoenix/server/api/input_types/TraceAnnotationSort.py +17 -0
- phoenix/server/api/input_types/UserRoleInput.py +9 -0
- phoenix/server/api/mutations/__init__.py +28 -0
- phoenix/server/api/mutations/api_key_mutations.py +167 -0
- phoenix/server/api/mutations/chat_mutations.py +593 -0
- phoenix/server/api/mutations/dataset_mutations.py +591 -0
- phoenix/server/api/mutations/experiment_mutations.py +75 -0
- phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +21 -18
- phoenix/server/api/mutations/project_mutations.py +57 -0
- phoenix/server/api/mutations/span_annotations_mutations.py +128 -0
- phoenix/server/api/mutations/trace_annotations_mutations.py +127 -0
- phoenix/server/api/mutations/user_mutations.py +329 -0
- phoenix/server/api/openapi/__init__.py +0 -0
- phoenix/server/api/openapi/main.py +17 -0
- phoenix/server/api/openapi/schema.py +16 -0
- phoenix/server/api/queries.py +738 -0
- phoenix/server/api/routers/__init__.py +11 -0
- phoenix/server/api/routers/auth.py +284 -0
- phoenix/server/api/routers/embeddings.py +26 -0
- phoenix/server/api/routers/oauth2.py +488 -0
- phoenix/server/api/routers/v1/__init__.py +64 -0
- phoenix/server/api/routers/v1/datasets.py +1017 -0
- phoenix/server/api/routers/v1/evaluations.py +362 -0
- phoenix/server/api/routers/v1/experiment_evaluations.py +115 -0
- phoenix/server/api/routers/v1/experiment_runs.py +167 -0
- phoenix/server/api/routers/v1/experiments.py +308 -0
- phoenix/server/api/routers/v1/pydantic_compat.py +78 -0
- phoenix/server/api/routers/v1/spans.py +267 -0
- phoenix/server/api/routers/v1/traces.py +208 -0
- phoenix/server/api/routers/v1/utils.py +95 -0
- phoenix/server/api/schema.py +44 -247
- phoenix/server/api/subscriptions.py +597 -0
- phoenix/server/api/types/Annotation.py +21 -0
- phoenix/server/api/types/AnnotationSummary.py +55 -0
- phoenix/server/api/types/AnnotatorKind.py +16 -0
- phoenix/server/api/types/ApiKey.py +27 -0
- phoenix/server/api/types/AuthMethod.py +9 -0
- phoenix/server/api/types/ChatCompletionMessageRole.py +11 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +46 -0
- phoenix/server/api/types/Cluster.py +25 -24
- phoenix/server/api/types/CreateDatasetPayload.py +8 -0
- phoenix/server/api/types/DataQualityMetric.py +31 -13
- phoenix/server/api/types/Dataset.py +288 -63
- phoenix/server/api/types/DatasetExample.py +85 -0
- phoenix/server/api/types/DatasetExampleRevision.py +34 -0
- phoenix/server/api/types/DatasetVersion.py +14 -0
- phoenix/server/api/types/Dimension.py +32 -31
- phoenix/server/api/types/DocumentEvaluationSummary.py +9 -8
- phoenix/server/api/types/EmbeddingDimension.py +56 -49
- phoenix/server/api/types/Evaluation.py +25 -31
- phoenix/server/api/types/EvaluationSummary.py +30 -50
- phoenix/server/api/types/Event.py +20 -20
- phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
- phoenix/server/api/types/Experiment.py +152 -0
- phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
- phoenix/server/api/types/ExperimentComparison.py +17 -0
- phoenix/server/api/types/ExperimentRun.py +119 -0
- phoenix/server/api/types/ExperimentRunAnnotation.py +56 -0
- phoenix/server/api/types/GenerativeModel.py +9 -0
- phoenix/server/api/types/GenerativeProvider.py +85 -0
- phoenix/server/api/types/Inferences.py +80 -0
- phoenix/server/api/types/InferencesRole.py +23 -0
- phoenix/server/api/types/LabelFraction.py +7 -0
- phoenix/server/api/types/MimeType.py +2 -2
- phoenix/server/api/types/Model.py +54 -54
- phoenix/server/api/types/PerformanceMetric.py +8 -5
- phoenix/server/api/types/Project.py +407 -142
- phoenix/server/api/types/ProjectSession.py +139 -0
- phoenix/server/api/types/Segments.py +4 -4
- phoenix/server/api/types/Span.py +221 -176
- phoenix/server/api/types/SpanAnnotation.py +43 -0
- phoenix/server/api/types/SpanIOValue.py +15 -0
- phoenix/server/api/types/SystemApiKey.py +9 -0
- phoenix/server/api/types/TemplateLanguage.py +10 -0
- phoenix/server/api/types/TimeSeries.py +19 -15
- phoenix/server/api/types/TokenUsage.py +11 -0
- phoenix/server/api/types/Trace.py +154 -0
- phoenix/server/api/types/TraceAnnotation.py +45 -0
- phoenix/server/api/types/UMAPPoints.py +7 -7
- phoenix/server/api/types/User.py +60 -0
- phoenix/server/api/types/UserApiKey.py +45 -0
- phoenix/server/api/types/UserRole.py +15 -0
- phoenix/server/api/types/node.py +13 -107
- phoenix/server/api/types/pagination.py +156 -57
- phoenix/server/api/utils.py +34 -0
- phoenix/server/app.py +864 -115
- phoenix/server/bearer_auth.py +163 -0
- phoenix/server/dml_event.py +136 -0
- phoenix/server/dml_event_handler.py +256 -0
- phoenix/server/email/__init__.py +0 -0
- phoenix/server/email/sender.py +97 -0
- phoenix/server/email/templates/__init__.py +0 -0
- phoenix/server/email/templates/password_reset.html +19 -0
- phoenix/server/email/types.py +11 -0
- phoenix/server/grpc_server.py +102 -0
- phoenix/server/jwt_store.py +505 -0
- phoenix/server/main.py +305 -116
- phoenix/server/oauth2.py +52 -0
- phoenix/server/openapi/__init__.py +0 -0
- phoenix/server/prometheus.py +111 -0
- phoenix/server/rate_limiters.py +188 -0
- phoenix/server/static/.vite/manifest.json +87 -0
- phoenix/server/static/assets/components-Cy9nwIvF.js +2125 -0
- phoenix/server/static/assets/index-BKvHIxkk.js +113 -0
- phoenix/server/static/assets/pages-CUi2xCVQ.js +4449 -0
- phoenix/server/static/assets/vendor-DvC8cT4X.js +894 -0
- phoenix/server/static/assets/vendor-DxkFTwjz.css +1 -0
- phoenix/server/static/assets/vendor-arizeai-Do1793cv.js +662 -0
- phoenix/server/static/assets/vendor-codemirror-BzwZPyJM.js +24 -0
- phoenix/server/static/assets/vendor-recharts-_Jb7JjhG.js +59 -0
- phoenix/server/static/assets/vendor-shiki-Cl9QBraO.js +5 -0
- phoenix/server/static/assets/vendor-three-DwGkEfCM.js +2998 -0
- phoenix/server/telemetry.py +68 -0
- phoenix/server/templates/index.html +82 -23
- phoenix/server/thread_server.py +3 -3
- phoenix/server/types.py +275 -0
- phoenix/services.py +27 -18
- phoenix/session/client.py +743 -68
- phoenix/session/data_extractor.py +31 -7
- phoenix/session/evaluation.py +3 -9
- phoenix/session/session.py +263 -219
- phoenix/settings.py +22 -0
- phoenix/trace/__init__.py +2 -22
- phoenix/trace/attributes.py +338 -0
- phoenix/trace/dsl/README.md +116 -0
- phoenix/trace/dsl/filter.py +663 -213
- phoenix/trace/dsl/helpers.py +73 -21
- phoenix/trace/dsl/query.py +574 -201
- phoenix/trace/exporter.py +24 -19
- phoenix/trace/fixtures.py +368 -32
- phoenix/trace/otel.py +71 -219
- phoenix/trace/projects.py +3 -2
- phoenix/trace/schemas.py +33 -11
- phoenix/trace/span_evaluations.py +21 -16
- phoenix/trace/span_json_decoder.py +6 -4
- phoenix/trace/span_json_encoder.py +2 -2
- phoenix/trace/trace_dataset.py +47 -32
- phoenix/trace/utils.py +21 -4
- phoenix/utilities/__init__.py +0 -26
- phoenix/utilities/client.py +132 -0
- phoenix/utilities/deprecation.py +31 -0
- phoenix/utilities/error_handling.py +3 -2
- phoenix/utilities/json.py +109 -0
- phoenix/utilities/logging.py +8 -0
- phoenix/utilities/project.py +2 -2
- phoenix/utilities/re.py +49 -0
- phoenix/utilities/span_store.py +0 -23
- phoenix/utilities/template_formatters.py +99 -0
- phoenix/version.py +1 -1
- arize_phoenix-3.16.0.dist-info/METADATA +0 -495
- arize_phoenix-3.16.0.dist-info/RECORD +0 -178
- phoenix/core/project.py +0 -617
- phoenix/core/traces.py +0 -100
- phoenix/experimental/evals/__init__.py +0 -73
- phoenix/experimental/evals/evaluators.py +0 -413
- phoenix/experimental/evals/functions/__init__.py +0 -4
- phoenix/experimental/evals/functions/classify.py +0 -453
- phoenix/experimental/evals/functions/executor.py +0 -353
- phoenix/experimental/evals/functions/generate.py +0 -138
- phoenix/experimental/evals/functions/processing.py +0 -76
- phoenix/experimental/evals/models/__init__.py +0 -14
- phoenix/experimental/evals/models/anthropic.py +0 -175
- phoenix/experimental/evals/models/base.py +0 -170
- phoenix/experimental/evals/models/bedrock.py +0 -221
- phoenix/experimental/evals/models/litellm.py +0 -134
- phoenix/experimental/evals/models/openai.py +0 -448
- phoenix/experimental/evals/models/rate_limiters.py +0 -246
- phoenix/experimental/evals/models/vertex.py +0 -173
- phoenix/experimental/evals/models/vertexai.py +0 -186
- phoenix/experimental/evals/retrievals.py +0 -96
- phoenix/experimental/evals/templates/__init__.py +0 -50
- phoenix/experimental/evals/templates/default_templates.py +0 -472
- phoenix/experimental/evals/templates/template.py +0 -195
- phoenix/experimental/evals/utils/__init__.py +0 -172
- phoenix/experimental/evals/utils/threads.py +0 -27
- phoenix/server/api/helpers.py +0 -11
- phoenix/server/api/routers/evaluation_handler.py +0 -109
- phoenix/server/api/routers/span_handler.py +0 -70
- phoenix/server/api/routers/trace_handler.py +0 -60
- phoenix/server/api/types/DatasetRole.py +0 -23
- phoenix/server/static/index.css +0 -6
- phoenix/server/static/index.js +0 -7447
- phoenix/storage/span_store/__init__.py +0 -23
- phoenix/storage/span_store/text_file.py +0 -85
- phoenix/trace/dsl/missing.py +0 -60
- phoenix/trace/langchain/__init__.py +0 -3
- phoenix/trace/langchain/instrumentor.py +0 -35
- phoenix/trace/llama_index/__init__.py +0 -3
- phoenix/trace/llama_index/callback.py +0 -102
- phoenix/trace/openai/__init__.py +0 -3
- phoenix/trace/openai/instrumentor.py +0 -30
- {arize_phoenix-3.16.0.dist-info → arize_phoenix-7.7.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-3.16.0.dist-info → arize_phoenix-7.7.0.dist-info}/licenses/LICENSE +0 -0
- /phoenix/{datasets → db/insertion}/__init__.py +0 -0
- /phoenix/{experimental → db/migrations}/__init__.py +0 -0
- /phoenix/{storage → db/migrations/data_migration_scripts}/__init__.py +0 -0

phoenix/{datasets → inferences}/schema.py
CHANGED

```diff
@@ -1,13 +1,14 @@
 import json
+from collections.abc import Mapping
 from dataclasses import asdict, dataclass, replace
-from typing import Any,
+from typing import Any, Optional, Union

-EmbeddingFeatures =
+EmbeddingFeatures = dict[str, "EmbeddingColumnNames"]
 SchemaFieldName = str
-SchemaFieldValue = Union[Optional[str], Optional[
+SchemaFieldValue = Union[Optional[str], Optional[list[str]], Optional[EmbeddingFeatures]]

-MULTI_COLUMN_SCHEMA_FIELD_NAMES:
-SINGLE_COLUMN_SCHEMA_FIELD_NAMES:
+MULTI_COLUMN_SCHEMA_FIELD_NAMES: tuple[str, ...] = ("feature_column_names", "tag_column_names")
+SINGLE_COLUMN_SCHEMA_FIELD_NAMES: tuple[str, ...] = (
     "prediction_id_column_name",
     "timestamp_column_name",
     "prediction_label_column_name",
@@ -19,7 +20,7 @@ LLM_SCHEMA_FIELD_NAMES = ["prompt_column_names", "response_column_names"]


 @dataclass(frozen=True)
-class EmbeddingColumnNames(
+class EmbeddingColumnNames(dict[str, Any]):
     """
     A dataclass to hold the column names for the embedding features.
     An embedding feature is a feature that is represented by a vector.
@@ -34,7 +35,6 @@ class EmbeddingColumnNames(Dict[str, Any]):
 @dataclass(frozen=True)
 class RetrievalEmbeddingColumnNames(EmbeddingColumnNames):
     """
-    *** Experimental ***
     A relationship is a column that maps a prediction to another record.

     Example
@@ -81,8 +81,8 @@ class Schema:
     prediction_id_column_name: Optional[str] = None
     id_column_name: Optional[str] = None  # Syntax sugar for prediction_id_column_name
     timestamp_column_name: Optional[str] = None
-    feature_column_names: Optional[
-    tag_column_names: Optional[
+    feature_column_names: Optional[list[str]] = None
+    tag_column_names: Optional[list[str]] = None
     prediction_label_column_name: Optional[str] = None
     prediction_score_column_name: Optional[str] = None
     actual_label_column_name: Optional[str] = None
@@ -92,7 +92,7 @@ class Schema:
     # document_column_names is used explicitly when the schema is used to capture a corpus
     document_column_names: Optional[EmbeddingColumnNames] = None
     embedding_feature_column_names: Optional[EmbeddingFeatures] = None
-    excluded_column_names: Optional[
+    excluded_column_names: Optional[list[str]] = None

     def __post_init__(self) -> None:
         # re-map document_column_names to be in the prompt_column_names position
@@ -108,7 +108,7 @@ class Schema:
     def replace(self, **changes: Any) -> "Schema":
         return replace(self, **changes)

-    def asdict(self) ->
+    def asdict(self) -> dict[str, str]:
         return asdict(self)

     def to_json(self) -> str:
```
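
The pattern running through these hunks — and most of the smaller hunks below — is the PEP 585 migration: `typing.Dict`/`List`/`Tuple` annotations (visible in the truncated removed lines) give way to the builtin generics accepted on Python 3.9+. A minimal illustration of the two spellings; the example names are ours, not Phoenix's:

```python
from typing import Dict, List  # pre-PEP 585 spellings, still importable

OldStyle = Dict[str, List[float]]  # the form the removed lines used
NewStyle = dict[str, list[float]]  # the builtin generics the added lines use

def embedding_dims(features: NewStyle) -> int:
    # Usable anywhere the old alias was; no typing import required.
    return len(features)

print(embedding_dims({"embedding": [0.1, 0.2]}))  # 1
```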
phoenix/{datasets → inferences}/validation.py
CHANGED

```diff
@@ -1,5 +1,4 @@
 import math
-from typing import List

 import numpy as np
 from pandas import DataFrame, Series
@@ -11,8 +10,8 @@ from .schema import EmbeddingColumnNames, Schema
 RESERVED_EMBEDDING_NAMES = ("prompt", "response")


-def _check_valid_schema(schema: Schema) ->
-    errs:
+def _check_valid_schema(schema: Schema) -> list[err.ValidationError]:
+    errs: list[str] = []
     if schema.excluded_column_names is None:
         return []

@@ -34,7 +33,7 @@ def _check_valid_schema(schema: Schema) -> List[err.ValidationError]:
     return []


-def
+def validate_inferences_inputs(dataframe: DataFrame, schema: Schema) -> list[err.ValidationError]:
     errors = _check_missing_columns(dataframe, schema)
     if errors:
         return errors
@@ -53,12 +52,12 @@ def validate_dataset_inputs(dataframe: DataFrame, schema: Schema) -> List[err.Va
     return []


-def _check_valid_embedding_data(dataframe: DataFrame, schema: Schema) ->
+def _check_valid_embedding_data(dataframe: DataFrame, schema: Schema) -> list[err.ValidationError]:
     embedding_col_names = schema.embedding_feature_column_names
     if embedding_col_names is None:
         return []

-    embedding_errors:
+    embedding_errors: list[err.ValidationError] = []
     for embedding_name, column_names in embedding_col_names.items():
         if embedding_name in RESERVED_EMBEDDING_NAMES:
             embedding_errors += _validate_reserved_embedding_name(embedding_name, schema)
@@ -71,8 +70,8 @@ def _check_valid_embedding_data(dataframe: DataFrame, schema: Schema) -> List[er

 def _check_valid_prompt_response_data(
     dataframe: DataFrame, schema: Schema
-) ->
-    prompt_response_errors:
+) -> list[err.ValidationError]:
+    prompt_response_errors: list[err.ValidationError] = []

     prompt_response_column_names = {
         "prompt": schema.prompt_column_names,
@@ -89,7 +88,7 @@ def _check_valid_prompt_response_data(

 def _validate_reserved_embedding_name(
     embedding_name: str, schema: Schema
-) ->
+) -> list[err.ValidationError]:
     if embedding_name == "prompt" and schema.prompt_column_names is not None:
         return [err.InvalidEmbeddingReservedName(embedding_name, "schema.prompt_column_names")]
     elif embedding_name == "response" and schema.response_column_names is not None:
@@ -99,9 +98,9 @@ def _validate_reserved_embedding_name(

 def _validate_embedding_vector(
     dataframe: DataFrame, name: str, vector_column_name: str
-) ->
+) -> list[err.ValidationError]:
     vector_column = dataframe[vector_column_name]
-    errors:
+    errors: list[err.ValidationError] = []
     vector_length = None

     for vector in vector_column:
@@ -156,8 +155,8 @@ def _validate_embedding_vector(
     return errors


-def _check_column_types(dataframe: DataFrame, schema: Schema) ->
-    wrong_type_cols:
+def _check_column_types(dataframe: DataFrame, schema: Schema) -> list[err.ValidationError]:
+    wrong_type_cols: list[str] = []
     if schema.prediction_id_column_name is not None:
         if not (
             is_numeric_dtype(dataframe.dtypes[schema.prediction_id_column_name])
@@ -172,7 +171,7 @@ def _check_column_types(dataframe: DataFrame, schema: Schema) -> List[err.Valida
     return []


-def _check_missing_columns(dataframe: DataFrame, schema: Schema) ->
+def _check_missing_columns(dataframe: DataFrame, schema: Schema) -> list[err.ValidationError]:
     # converting to a set first makes the checks run a lot faster
     existing_columns = set(dataframe.columns)
     missing_columns = []
```
phoenix/logging/_config.py
ADDED

```diff
@@ -0,0 +1,90 @@
+import atexit
+import logging
+import logging.config
+import logging.handlers
+import queue
+from sys import stderr, stdout
+
+from typing_extensions import assert_never
+
+from phoenix.config import LoggingMode
+from phoenix.logging._filter import NonErrorFilter
+from phoenix.settings import Settings
+
+from ._formatter import PhoenixJSONFormatter
+
+
+def setup_logging() -> None:
+    """
+    Configures logging for the specified logging mode.
+    """
+    logging_mode = Settings.logging_mode
+    if logging_mode is LoggingMode.DEFAULT:
+        _setup_library_logging()
+    elif logging_mode is LoggingMode.STRUCTURED:
+        _setup_application_logging()
+    else:
+        assert_never(logging_mode)
+
+
+def _setup_library_logging() -> None:
+    """
+    Configures logging if Phoenix is used as a library
+    """
+    logger = logging.getLogger("phoenix")
+    logger.setLevel(Settings.logging_level)
+    db_logger = logging.getLogger("sqlalchemy")
+    db_logger.setLevel(Settings.db_logging_level)
+    logger.info("Default logging ready")
+
+
+def _setup_application_logging() -> None:
+    """
+    Configures logging if Phoenix is used as an application
+    """
+    sql_engine_logger = logging.getLogger("sqlalchemy.engine.Engine")
+    # Remove all existing handlers
+    for handler in sql_engine_logger.handlers[:]:
+        sql_engine_logger.removeHandler(handler)
+        handler.close()
+
+    phoenix_logger = logging.getLogger("phoenix")
+    phoenix_logger.setLevel(Settings.logging_level)
+    phoenix_logger.propagate = False  # Do not pass records to the root logger
+    sql_logger = logging.getLogger("sqlalchemy")
+    sql_logger.setLevel(Settings.db_logging_level)
+    sql_logger.propagate = False  # Do not pass records to the root logger
+
+    log_queue = queue.Queue()  # type:ignore
+    queue_handler = logging.handlers.QueueHandler(log_queue)
+    phoenix_logger.addHandler(queue_handler)
+    sql_logger.addHandler(queue_handler)
+
+    fmt_keys = {
+        "level": "levelname",
+        "message": "message",
+        "timestamp": "timestamp",
+        "logger": "name",
+        "module": "module",
+        "function": "funcName",
+        "line": "lineno",
+        "thread_name": "threadName",
+    }
+    formatter = PhoenixJSONFormatter(fmt_keys=fmt_keys)
+
+    # stdout handler
+    stdout_handler = logging.StreamHandler(stdout)
+    stdout_handler.setFormatter(formatter)
+    stdout_handler.setLevel(Settings.logging_level)
+    stdout_handler.addFilter(NonErrorFilter())
+
+    # stderr handler
+    stderr_handler = logging.StreamHandler(stderr)
+    stderr_handler.setFormatter(formatter)
+    stderr_handler.setLevel(logging.WARNING)
+
+    queue_listener = logging.handlers.QueueListener(log_queue, stdout_handler, stderr_handler)
+    if queue_listener is not None:
+        queue_listener.start()
+        atexit.register(queue_listener.stop)
+    phoenix_logger.info("Structured logging ready")
```
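
`_setup_application_logging` pushes every record through a `queue.Queue`: the calling thread only enqueues via `QueueHandler`, while a `QueueListener` background thread does the formatting and stream I/O. A self-contained sketch of that stdlib pattern (names here are illustrative, not Phoenix's):

```python
import atexit
import logging
import logging.handlers
import queue

log_queue: "queue.Queue[logging.LogRecord]" = queue.Queue()

# Producer side: the logger only enqueues records (cheap, no blocking I/O).
logger = logging.getLogger("demo")
logger.addHandler(logging.handlers.QueueHandler(log_queue))
logger.setLevel(logging.INFO)

# Consumer side: a background thread dequeues and does the real formatting/IO.
console = logging.StreamHandler()
listener = logging.handlers.QueueListener(log_queue, console)
listener.start()
atexit.register(listener.stop)  # drain the queue on interpreter exit

logger.info("hello from the queue")
```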
phoenix/logging/_formatter.py
ADDED

```diff
@@ -0,0 +1,69 @@
+import datetime as dt
+import json
+import logging
+from typing import Optional
+
+LOG_RECORD_BUILTIN_ATTRS = {
+    "args",
+    "asctime",
+    "created",
+    "exc_info",
+    "exc_text",
+    "filename",
+    "funcName",
+    "levelname",
+    "levelno",
+    "lineno",
+    "module",
+    "msecs",
+    "message",
+    "msg",
+    "name",
+    "pathname",
+    "process",
+    "processName",
+    "relativeCreated",
+    "stack_info",
+    "thread",
+    "threadName",
+    "taskName",
+}
+
+
+class PhoenixJSONFormatter(logging.Formatter):
+    def __init__(
+        self,
+        *,
+        fmt_keys: Optional[dict[str, str]] = None,
+    ):
+        super().__init__()
+        self.fmt_keys = fmt_keys if fmt_keys is not None else {}
+
+    def format(self, record: logging.LogRecord) -> str:
+        message = self._prepare_log_dict(record)
+        return json.dumps(message, default=str)
+
+    def _prepare_log_dict(self, record: logging.LogRecord) -> dict[str, str]:
+        always_fields = {
+            "message": record.getMessage(),
+            "timestamp": dt.datetime.fromtimestamp(record.created, tz=dt.timezone.utc).isoformat(),
+        }
+        if record.exc_info is not None:
+            always_fields["exc_info"] = self.formatException(record.exc_info)
+
+        if record.stack_info is not None:
+            always_fields["stack_info"] = self.formatStack(record.stack_info)
+
+        message = {
+            key: msg_val
+            if (msg_val := always_fields.pop(val, None)) is not None
+            else getattr(record, val)
+            for key, val in self.fmt_keys.items()
+        }
+        message.update(always_fields)
+
+        for key, val in record.__dict__.items():
+            if key not in LOG_RECORD_BUILTIN_ATTRS:
+                message[key] = val
+
+        return message
```
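
Configured as in `_config.py` above, the formatter emits one JSON object per record, merging the `fmt_keys` projection, the always-present `message`/`timestamp` fields, and any `extra` attributes. A hedged sketch of what that looks like (the printed line is an approximation; key order and the timestamp will differ):

```python
import logging

# Assumes PhoenixJSONFormatter is importable as defined in the diff above.
from phoenix.logging._formatter import PhoenixJSONFormatter

handler = logging.StreamHandler()
handler.setFormatter(PhoenixJSONFormatter(fmt_keys={"level": "levelname", "logger": "name"}))
logger = logging.getLogger("demo")
logger.addHandler(handler)
logger.warning("disk almost full", extra={"free_mb": 12})
# Emits a single JSON line similar to:
# {"level": "WARNING", "logger": "demo", "message": "disk almost full",
#  "timestamp": "2024-01-01T00:00:00+00:00", "free_mb": 12}
```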
phoenix/metrics/__init__.py
CHANGED

```diff
@@ -1,8 +1,9 @@
 import logging
 import warnings
 from abc import ABC, abstractmethod
+from collections.abc import Iterable, Mapping
 from dataclasses import dataclass
-from typing import Any,
+from typing import Any, Optional, Union

 import numpy as np
 import pandas as pd
@@ -36,13 +37,13 @@ class Metric(ABC):
     def calc(self, dataframe: pd.DataFrame) -> Any: ...

     @abstractmethod
-    def operands(self) ->
+    def operands(self) -> list[Column]: ...

     def __call__(
         self,
         df: pd.DataFrame,
         /,
-        subset_rows: Optional[Union[slice,
+        subset_rows: Optional[Union[slice, list[int]]] = None,
     ) -> Any:
         """
         Computes the metric on a dataframe.
@@ -51,7 +52,7 @@ class Metric(ABC):
         ----------
         df: pandas DataFrame
             The dataframe input to the metric.
-        subset_rows: Optional[Union[slice,
+        subset_rows: Optional[Union[slice, list[int]]] = None
             Optionally specifying a subset of rows for the computation.
             Can be a list or slice (e.g. `slice(100, 200)`) of integers.
         """
```
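
Per the docstring just above, `subset_rows` takes either a `slice` or a `list[int]` of row positions. A standalone illustration of those two selector shapes, using `iloc` on an invented dataframe (we assume the semantics are positional, iloc-style, as the docstring's `slice(100, 200)` example suggests):

```python
import pandas as pd

df = pd.DataFrame({"latency_ms": range(1_000)})

# The two selector shapes accepted by Metric.__call__'s subset_rows parameter:
window = df.iloc[slice(100, 200)]  # a slice: rows 100..199, end-exclusive
sample = df.iloc[[0, 10, 20]]      # a list of integer row positions

print(len(window), len(sample))  # 100 3
```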
phoenix/metrics/binning.py
CHANGED

```diff
@@ -1,8 +1,9 @@
 import warnings
 from abc import ABC, abstractmethod
+from collections.abc import Iterable, Sequence
 from dataclasses import dataclass
 from functools import partial
-from typing import Any,
+from typing import Any, Optional, cast

 import numpy as np
 import pandas as pd
@@ -78,7 +79,7 @@ class IntervalBinning(BinningMethod):
         else pd.IntervalIndex(
             (
                 pd.Interval(
-                    np.
+                    -np.inf,
                     np.inf,
                     closed="neither",
                 ),
@@ -208,7 +209,7 @@ class QuantileBinning(IntervalBinning):
         # Extend min and max to infinities, unless len(breaks) < 3,
         # in which case the min is kept and two bins are created.
         breaks = breaks[1:-1] if len(breaks) > 2 else breaks[:1]
-        breaks = [np.
+        breaks = [-np.inf] + breaks + [np.inf]
         return pd.IntervalIndex.from_breaks(
             breaks,
             closed="left",
```
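
Both hunks replace a truncated `np.…` constant with the sign-prefixed `-np.inf`. That is the NumPy-2.0-safe spelling: the old `np.NINF` alias (which, we assume, is what the removed lines used) was removed in NumPy 2.0, while `-np.inf` works on both major versions:

```python
import numpy as np

# -np.inf is the portable spelling of negative infinity; the np.NINF alias
# it presumably replaces here was removed in NumPy 2.0.
assert -np.inf == float("-inf")
assert np.isneginf(-np.inf)
```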
phoenix/metrics/metrics.py
CHANGED

```diff
@@ -1,8 +1,9 @@
 import math
 import warnings
+from collections.abc import Callable
 from dataclasses import dataclass, field
 from functools import cached_property
-from typing import
+from typing import Union, cast

 import numpy as np
 import numpy.typing as npt
```
phoenix/metrics/mixins.py
CHANGED

```diff
@@ -7,10 +7,11 @@ on cooperative multiple inheritance and method resolution order in Python.
 import collections
 import inspect
 from abc import ABC, abstractmethod
+from collections.abc import Callable
 from dataclasses import dataclass, field, fields, replace
 from functools import cached_property
 from itertools import repeat
-from typing import TYPE_CHECKING, Any,
+from typing import TYPE_CHECKING, Any, Iterator, Mapping, Optional

 import numpy as np
 import pandas as pd
@@ -42,7 +43,7 @@ class VectorOperator(ABC):

 @dataclass(frozen=True)
 class NullaryOperator(Metric, ABC):
-    def operands(self) ->
+    def operands(self) -> list[Column]:
         return []


@@ -55,7 +56,7 @@ class UnaryOperator(Metric, ABC):

     operand: Column = Column()

-    def operands(self) ->
+    def operands(self) -> list[Column]:
         return [self.operand]


@@ -98,10 +99,10 @@ class EvaluationMetricKeywordParameters(_BaseMapping):
         return sum(1 for _ in self)

     @property
-    def columns(self) ->
+    def columns(self) -> list[Column]:
         return [v for v in self.values() if isinstance(v, Column)]

-    def __call__(self, df: pd.DataFrame) ->
+    def __call__(self, df: pd.DataFrame) -> dict[str, Any]:
         return {k: v(df) if isinstance(v, Column) else v for k, v in self.items()}


@@ -142,7 +143,7 @@ class EvaluationMetric(Metric, ABC):
         ),
     )

-    def operands(self) ->
+    def operands(self) -> list[Column]:
         return [self.actual, self.predicted] + self.parameters.columns

     def calc(self, df: pd.DataFrame) -> float:
```
phoenix/metrics/timeseries.py
CHANGED

```diff
@@ -1,7 +1,8 @@
+from collections.abc import Callable, Iterable, Iterator
 from datetime import datetime, timedelta, timezone
 from functools import partial
 from itertools import accumulate, repeat
-from typing import
+from typing import cast

 import pandas as pd
 from typing_extensions import TypeAlias
@@ -41,12 +42,12 @@ def row_interval_from_sorted_time_index(
     time_index: pd.DatetimeIndex,
     time_start: datetime,
     time_stop: datetime,
-) ->
+) -> tuple[StartIndex, StopIndex]:
     """
     Returns end exclusive time slice from sorted index.
     """
     return cast(
-
+        tuple[StartIndex, StopIndex],
         time_index.searchsorted((time_start, time_stop)),
     )

@@ -86,7 +87,7 @@ def _groupers(
     end_time: datetime,
     evaluation_window: timedelta,
     sampling_interval: timedelta,
-) -> Iterator[
+) -> Iterator[tuple[StartTime, EndTime, pd.Grouper]]:
     """
     Yields pandas.Groupers from time series parameters.
     """
```
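
`row_interval_from_sorted_time_index` relies on `DatetimeIndex.searchsorted`, which maps a `(start, stop)` pair of timestamps to the pair of insertion positions bounding the end-exclusive row interval — hence the `tuple[StartIndex, StopIndex]` cast above. A standalone illustration:

```python
import pandas as pd

idx = pd.date_range("2024-01-01", periods=6, freq="h")  # a sorted DatetimeIndex
start, stop = idx.searchsorted((idx[2], idx[5]))  # insertion positions

print(start, stop)           # 2 5
print(len(idx[start:stop]))  # 3 rows in [idx[2], idx[5]), end-exclusive
```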
phoenix/metrics/wrappers.py
CHANGED

```diff
@@ -18,7 +18,7 @@ from abc import ABC
 from enum import Enum
 from inspect import Signature
 from itertools import chain, islice
-from typing import Any,
+from typing import Any, cast

 import numpy as np
 import pandas as pd
@@ -27,6 +27,8 @@ from sklearn import metrics as sk
 from sklearn.utils.multiclass import check_classification_targets
 from wrapt import PartialCallableObjectProxy

+from phoenix.config import SKLEARN_VERSION
+

 class Eval(PartialCallableObjectProxy, ABC):  # type: ignore
     def __call__(
@@ -157,7 +159,7 @@ def _coerce_dtype_if_necessary(
 def _eliminate_missing_values_from_all_series(
     *args: Any,
     **kwargs: Any,
-) ->
+) -> tuple[list[Any], dict[str, Any]]:
     positional_arguments = list(args)
     keyword_arguments = dict(kwargs)
     all_series = [
@@ -232,5 +234,9 @@ class SkEval(Enum):
     r2_score = RegressionEval(sk.r2_score)
     recall_score = ClassificationEval(sk.recall_score)
     roc_auc_score = ScoredClassificationEval(sk.roc_auc_score)
-    root_mean_squared_error =
+    root_mean_squared_error = (
+        RegressionEval(sk.mean_squared_error, squared=False)
+        if SKLEARN_VERSION < (1, 6)
+        else RegressionEval(sk.root_mean_squared_error)
+    )
     zero_one_loss = ClassificationEval(sk.zero_one_loss)
```
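
The last hunk version-gates the RMSE eval: scikit-learn 1.6 removed the `squared=` keyword from `mean_squared_error`, and `root_mean_squared_error` (added in 1.4) is its replacement. The same guard in isolation, with a locally computed version tuple standing in for the `SKLEARN_VERSION` constant that phoenix.config now provides:

```python
import sklearn
from sklearn import metrics as sk

# Parse "1.6.1" -> (1, 6); a stand-in for phoenix.config.SKLEARN_VERSION.
SKLEARN_VERSION = tuple(int(p) for p in sklearn.__version__.split(".")[:2])

y_true, y_pred = [3.0, 5.0], [2.5, 5.5]
if SKLEARN_VERSION < (1, 6):
    rmse = sk.mean_squared_error(y_true, y_pred, squared=False)  # keyword removed in 1.6
else:
    rmse = sk.root_mean_squared_error(y_true, y_pred)  # added in scikit-learn 1.4
print(rmse)  # 0.5
```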
phoenix/pointcloud/clustering.py
CHANGED

```diff
@@ -1,13 +1,11 @@
 from dataclasses import asdict, dataclass
-from typing import List, Set

 import numpy as np
 import numpy.typing as npt
-from hdbscan import HDBSCAN
 from typing_extensions import TypeAlias

 RowIndex: TypeAlias = int
-RawCluster: TypeAlias =
+RawCluster: TypeAlias = set[RowIndex]
 Matrix: TypeAlias = npt.NDArray[np.float64]


@@ -17,9 +15,11 @@ class Hdbscan:
     min_samples: float = 1
     cluster_selection_epsilon: float = 0.0

-    def find_clusters(self, mat: Matrix) ->
+    def find_clusters(self, mat: Matrix) -> list[RawCluster]:
+        from fast_hdbscan import HDBSCAN
+
         cluster_ids: npt.NDArray[np.int_] = HDBSCAN(**asdict(self)).fit_predict(mat)
-        ans:
+        ans: list[RawCluster] = [set() for _ in range(np.max(cluster_ids) + 1)]
         for row_idx, cluster_id in enumerate(cluster_ids):
             if cluster_id > -1:
                 ans[cluster_id].add(row_idx)
```
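
Two changes land here: the clustering backend moves from the `hdbscan` package to `fast_hdbscan`, and the import is deferred into `find_clusters`, so the dependency loads only when clustering actually runs. A usage sketch under the assumption that `fast_hdbscan` is installed and the module path is as shown (the sample data is invented, and we rely on the dataclass defaults):

```python
import numpy as np
from phoenix.pointcloud.clustering import Hdbscan

rng = np.random.default_rng(0)
points = np.vstack([
    rng.normal(0.0, 0.1, (50, 2)),  # one tight blob
    rng.normal(5.0, 0.1, (50, 2)),  # a second, well-separated blob
])

clusters = Hdbscan().find_clusters(points)  # fast_hdbscan is imported here
print(len(clusters))  # expect 2 for two well-separated blobs
```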
phoenix/pointcloud/pointcloud.py
CHANGED

```diff
@@ -1,9 +1,9 @@
+from collections.abc import Hashable, Mapping
 from dataclasses import dataclass
-from typing import
+from typing import Protocol, TypeVar

 import numpy as np
 import numpy.typing as npt
-from strawberry import ID
 from typing_extensions import TypeAlias

 from phoenix.pointcloud.clustering import RawCluster
@@ -12,13 +12,15 @@ Vector: TypeAlias = npt.NDArray[np.float64]
 Matrix: TypeAlias = npt.NDArray[np.float64]
 RowIndex: TypeAlias = int

+_IdType = TypeVar("_IdType", bound=Hashable)
+

 class DimensionalityReducer(Protocol):
     def project(self, mat: Matrix, n_components: int) -> Matrix: ...


 class ClustersFinder(Protocol):
-    def find_clusters(self, mat: Matrix) ->
+    def find_clusters(self, mat: Matrix) -> list[RawCluster]: ...


 @dataclass(frozen=True)
@@ -28,9 +30,9 @@ class PointCloud:

     def generate(
         self,
-        data: Mapping[
+        data: Mapping[_IdType, Vector],
         n_components: int = 3,
-    ) ->
+    ) -> tuple[dict[_IdType, Vector], dict[str, set[_IdType]]]:
         """
         Given a set of vectors, projects them onto lower dimensions, and
         finds clusters among the projections.
```
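
`generate` was previously keyed on strawberry's GraphQL `ID` (the removed import); it is now generic over any hashable key via `_IdType`, so whatever key type callers pass in is the key type they get back in the projections and cluster sets. A toy illustration of that TypeVar pattern, independent of Phoenix:

```python
from collections.abc import Hashable, Mapping
from typing import TypeVar

_IdType = TypeVar("_IdType", bound=Hashable)

def first_key(data: Mapping[_IdType, float]) -> _IdType:
    # The return type tracks whatever key type the caller used.
    return next(iter(data))

a: str = first_key({"event-1": 0.5})            # _IdType inferred as str
b: tuple[int, int] = first_key({(1, 2): 0.5})   # ...or as tuple[int, int]
```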
phoenix/pointcloud/projectors.py
CHANGED

```diff
@@ -6,12 +6,6 @@ import numpy as np
 import numpy.typing as npt
 from typing_extensions import TypeAlias

-with warnings.catch_warnings():
-    from numba.core.errors import NumbaWarning
-
-    warnings.simplefilter("ignore", category=NumbaWarning)
-    from umap import UMAP
-
 Matrix: TypeAlias = npt.NDArray[np.float64]


@@ -25,6 +19,11 @@ class Umap:
     min_dist: float = 0.1

     def project(self, mat: Matrix, n_components: int) -> Matrix:
+        with warnings.catch_warnings():
+            from numba.core.errors import NumbaWarning
+
+            warnings.simplefilter("ignore", category=NumbaWarning)
+            from umap import UMAP
         config = asdict(self)
         config["n_components"] = n_components
         if len(mat) <= n_components:
```