arize 8.0.0a22__py3-none-any.whl → 8.0.0a23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arize/__init__.py +17 -9
- arize/_exporter/client.py +55 -36
- arize/_exporter/parsers/tracing_data_parser.py +41 -30
- arize/_exporter/validation.py +3 -3
- arize/_flight/client.py +207 -76
- arize/_generated/api_client/__init__.py +30 -6
- arize/_generated/api_client/api/__init__.py +1 -0
- arize/_generated/api_client/api/datasets_api.py +864 -190
- arize/_generated/api_client/api/experiments_api.py +167 -131
- arize/_generated/api_client/api/projects_api.py +1197 -0
- arize/_generated/api_client/api_client.py +2 -2
- arize/_generated/api_client/configuration.py +42 -34
- arize/_generated/api_client/exceptions.py +2 -2
- arize/_generated/api_client/models/__init__.py +15 -4
- arize/_generated/api_client/models/dataset.py +10 -10
- arize/_generated/api_client/models/dataset_example.py +111 -0
- arize/_generated/api_client/models/dataset_example_update.py +100 -0
- arize/_generated/api_client/models/dataset_version.py +13 -13
- arize/_generated/api_client/models/datasets_create_request.py +16 -8
- arize/_generated/api_client/models/datasets_examples_insert_request.py +100 -0
- arize/_generated/api_client/models/datasets_examples_list200_response.py +106 -0
- arize/_generated/api_client/models/datasets_examples_update_request.py +102 -0
- arize/_generated/api_client/models/datasets_list200_response.py +10 -4
- arize/_generated/api_client/models/experiment.py +14 -16
- arize/_generated/api_client/models/experiment_run.py +108 -0
- arize/_generated/api_client/models/experiment_run_create.py +102 -0
- arize/_generated/api_client/models/experiments_create_request.py +16 -10
- arize/_generated/api_client/models/experiments_list200_response.py +10 -4
- arize/_generated/api_client/models/experiments_runs_list200_response.py +19 -5
- arize/_generated/api_client/models/{error.py → pagination_metadata.py} +13 -11
- arize/_generated/api_client/models/primitive_value.py +172 -0
- arize/_generated/api_client/models/problem.py +100 -0
- arize/_generated/api_client/models/project.py +99 -0
- arize/_generated/api_client/models/{datasets_list_examples200_response.py → projects_create_request.py} +13 -11
- arize/_generated/api_client/models/projects_list200_response.py +106 -0
- arize/_generated/api_client/rest.py +2 -2
- arize/_generated/api_client/test/test_dataset.py +4 -2
- arize/_generated/api_client/test/test_dataset_example.py +56 -0
- arize/_generated/api_client/test/test_dataset_example_update.py +52 -0
- arize/_generated/api_client/test/test_dataset_version.py +7 -2
- arize/_generated/api_client/test/test_datasets_api.py +27 -13
- arize/_generated/api_client/test/test_datasets_create_request.py +8 -4
- arize/_generated/api_client/test/{test_datasets_list_examples200_response.py → test_datasets_examples_insert_request.py} +19 -15
- arize/_generated/api_client/test/test_datasets_examples_list200_response.py +66 -0
- arize/_generated/api_client/test/test_datasets_examples_update_request.py +61 -0
- arize/_generated/api_client/test/test_datasets_list200_response.py +9 -3
- arize/_generated/api_client/test/test_experiment.py +2 -4
- arize/_generated/api_client/test/test_experiment_run.py +56 -0
- arize/_generated/api_client/test/test_experiment_run_create.py +54 -0
- arize/_generated/api_client/test/test_experiments_api.py +6 -6
- arize/_generated/api_client/test/test_experiments_create_request.py +9 -6
- arize/_generated/api_client/test/test_experiments_list200_response.py +9 -5
- arize/_generated/api_client/test/test_experiments_runs_list200_response.py +15 -5
- arize/_generated/api_client/test/test_pagination_metadata.py +53 -0
- arize/_generated/api_client/test/{test_error.py → test_primitive_value.py} +13 -14
- arize/_generated/api_client/test/test_problem.py +57 -0
- arize/_generated/api_client/test/test_project.py +58 -0
- arize/_generated/api_client/test/test_projects_api.py +59 -0
- arize/_generated/api_client/test/test_projects_create_request.py +54 -0
- arize/_generated/api_client/test/test_projects_list200_response.py +70 -0
- arize/_generated/api_client_README.md +43 -29
- arize/_generated/protocol/flight/flight_pb2.py +400 -0
- arize/_lazy.py +27 -19
- arize/client.py +268 -55
- arize/config.py +365 -116
- arize/constants/__init__.py +1 -0
- arize/constants/config.py +11 -4
- arize/constants/ml.py +6 -4
- arize/constants/openinference.py +2 -0
- arize/constants/pyarrow.py +2 -0
- arize/constants/spans.py +3 -1
- arize/datasets/__init__.py +1 -0
- arize/datasets/client.py +299 -84
- arize/datasets/errors.py +32 -2
- arize/datasets/validation.py +18 -8
- arize/embeddings/__init__.py +2 -0
- arize/embeddings/auto_generator.py +23 -19
- arize/embeddings/base_generators.py +89 -36
- arize/embeddings/constants.py +2 -0
- arize/embeddings/cv_generators.py +26 -4
- arize/embeddings/errors.py +27 -5
- arize/embeddings/nlp_generators.py +31 -12
- arize/embeddings/tabular_generators.py +32 -20
- arize/embeddings/usecases.py +12 -2
- arize/exceptions/__init__.py +1 -0
- arize/exceptions/auth.py +11 -1
- arize/exceptions/base.py +29 -4
- arize/exceptions/models.py +21 -2
- arize/exceptions/parameters.py +31 -0
- arize/exceptions/spaces.py +12 -1
- arize/exceptions/types.py +86 -7
- arize/exceptions/values.py +220 -20
- arize/experiments/__init__.py +1 -0
- arize/experiments/client.py +389 -285
- arize/experiments/evaluators/__init__.py +1 -0
- arize/experiments/evaluators/base.py +74 -41
- arize/experiments/evaluators/exceptions.py +6 -3
- arize/experiments/evaluators/executors.py +121 -73
- arize/experiments/evaluators/rate_limiters.py +106 -57
- arize/experiments/evaluators/types.py +34 -7
- arize/experiments/evaluators/utils.py +65 -27
- arize/experiments/functions.py +103 -101
- arize/experiments/tracing.py +52 -44
- arize/experiments/types.py +56 -31
- arize/logging.py +54 -22
- arize/models/__init__.py +1 -0
- arize/models/batch_validation/__init__.py +1 -0
- arize/models/batch_validation/errors.py +543 -65
- arize/models/batch_validation/validator.py +339 -300
- arize/models/bounded_executor.py +20 -7
- arize/models/casting.py +75 -29
- arize/models/client.py +326 -107
- arize/models/proto.py +95 -40
- arize/models/stream_validation.py +42 -14
- arize/models/surrogate_explainer/__init__.py +1 -0
- arize/models/surrogate_explainer/mimic.py +24 -13
- arize/pre_releases.py +43 -0
- arize/projects/__init__.py +1 -0
- arize/projects/client.py +129 -0
- arize/regions.py +40 -0
- arize/spans/__init__.py +1 -0
- arize/spans/client.py +130 -106
- arize/spans/columns.py +13 -0
- arize/spans/conversion.py +54 -38
- arize/spans/validation/__init__.py +1 -0
- arize/spans/validation/annotations/__init__.py +1 -0
- arize/spans/validation/annotations/annotations_validation.py +6 -4
- arize/spans/validation/annotations/dataframe_form_validation.py +13 -11
- arize/spans/validation/annotations/value_validation.py +35 -11
- arize/spans/validation/common/__init__.py +1 -0
- arize/spans/validation/common/argument_validation.py +33 -8
- arize/spans/validation/common/dataframe_form_validation.py +35 -9
- arize/spans/validation/common/errors.py +211 -11
- arize/spans/validation/common/value_validation.py +80 -13
- arize/spans/validation/evals/__init__.py +1 -0
- arize/spans/validation/evals/dataframe_form_validation.py +28 -8
- arize/spans/validation/evals/evals_validation.py +34 -4
- arize/spans/validation/evals/value_validation.py +26 -3
- arize/spans/validation/metadata/__init__.py +1 -1
- arize/spans/validation/metadata/argument_validation.py +14 -5
- arize/spans/validation/metadata/dataframe_form_validation.py +26 -10
- arize/spans/validation/metadata/value_validation.py +24 -10
- arize/spans/validation/spans/__init__.py +1 -0
- arize/spans/validation/spans/dataframe_form_validation.py +34 -13
- arize/spans/validation/spans/spans_validation.py +35 -4
- arize/spans/validation/spans/value_validation.py +76 -7
- arize/types.py +293 -157
- arize/utils/__init__.py +1 -0
- arize/utils/arrow.py +31 -15
- arize/utils/cache.py +34 -6
- arize/utils/dataframe.py +19 -2
- arize/utils/online_tasks/__init__.py +2 -0
- arize/utils/online_tasks/dataframe_preprocessor.py +53 -41
- arize/utils/openinference_conversion.py +44 -5
- arize/utils/proto.py +10 -0
- arize/utils/size.py +5 -3
- arize/version.py +3 -1
- {arize-8.0.0a22.dist-info → arize-8.0.0a23.dist-info}/METADATA +4 -3
- arize-8.0.0a23.dist-info/RECORD +174 -0
- {arize-8.0.0a22.dist-info → arize-8.0.0a23.dist-info}/WHEEL +1 -1
- arize-8.0.0a23.dist-info/licenses/LICENSE +176 -0
- arize-8.0.0a23.dist-info/licenses/NOTICE +13 -0
- arize/_generated/protocol/flight/export_pb2.py +0 -61
- arize/_generated/protocol/flight/ingest_pb2.py +0 -365
- arize-8.0.0a22.dist-info/RECORD +0 -146
- arize-8.0.0a22.dist-info/licenses/LICENSE.md +0 -12
arize/datasets/validation.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
|
|
1
|
+
"""Dataset validation logic for structure and content checks."""
|
|
2
2
|
|
|
3
3
|
import pandas as pd
|
|
4
4
|
|
|
@@ -7,7 +7,17 @@ from arize.datasets import errors as err
|
|
|
7
7
|
|
|
8
8
|
def validate_dataset_df(
|
|
9
9
|
df: pd.DataFrame,
|
|
10
|
-
) ->
|
|
10
|
+
) -> list[err.DatasetError]:
|
|
11
|
+
"""Validate a dataset DataFrame for structural and content errors.
|
|
12
|
+
|
|
13
|
+
Checks for required columns, unique ID values, and non-empty data.
|
|
14
|
+
|
|
15
|
+
Args:
|
|
16
|
+
df: The pandas DataFrame to validate.
|
|
17
|
+
|
|
18
|
+
Returns:
|
|
19
|
+
A list of DatasetError objects found during validation. Empty list if valid.
|
|
20
|
+
"""
|
|
11
21
|
## check all require columns are present
|
|
12
22
|
required_columns_errors = _check_required_columns(df)
|
|
13
23
|
if required_columns_errors:
|
|
@@ -19,14 +29,14 @@ def validate_dataset_df(
|
|
|
19
29
|
return id_column_unique_constraint_error
|
|
20
30
|
|
|
21
31
|
# check DataFrame has at least one row in it
|
|
22
|
-
|
|
23
|
-
if
|
|
24
|
-
return
|
|
32
|
+
empty_dataframe_error = _check_empty_dataframe(df)
|
|
33
|
+
if empty_dataframe_error:
|
|
34
|
+
return empty_dataframe_error
|
|
25
35
|
|
|
26
36
|
return []
|
|
27
37
|
|
|
28
38
|
|
|
29
|
-
def _check_required_columns(df: pd.DataFrame) ->
|
|
39
|
+
def _check_required_columns(df: pd.DataFrame) -> list[err.DatasetError]:
|
|
30
40
|
required_columns = ["id", "created_at", "updated_at"]
|
|
31
41
|
missing_columns = set(required_columns) - set(df.columns)
|
|
32
42
|
if missing_columns:
|
|
@@ -34,13 +44,13 @@ def _check_required_columns(df: pd.DataFrame) -> List[err.DatasetError]:
|
|
|
34
44
|
return []
|
|
35
45
|
|
|
36
46
|
|
|
37
|
-
def _check_id_column_is_unique(df: pd.DataFrame) ->
|
|
47
|
+
def _check_id_column_is_unique(df: pd.DataFrame) -> list[err.DatasetError]:
|
|
38
48
|
if not df["id"].is_unique:
|
|
39
49
|
return [err.IDColumnUniqueConstraintError()]
|
|
40
50
|
return []
|
|
41
51
|
|
|
42
52
|
|
|
43
|
-
def _check_empty_dataframe(df: pd.DataFrame) ->
|
|
53
|
+
def _check_empty_dataframe(df: pd.DataFrame) -> list[err.DatasetError]:
|
|
44
54
|
if df.empty:
|
|
45
55
|
return [err.EmptyDatasetError()]
|
|
46
56
|
return []
|
arize/embeddings/__init__.py
CHANGED
arize/embeddings/auto_generator.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
|
|
1
|
+
"""Automatic embedding generation factory for various ML use cases."""
|
|
2
2
|
|
|
3
3
|
import pandas as pd
|
|
4
4
|
|
|
@@ -30,7 +30,14 @@ UseCaseLike = str | UseCases.NLP | UseCases.CV | UseCases.STRUCTURED
|
|
|
30
30
|
|
|
31
31
|
|
|
32
32
|
class EmbeddingGenerator:
|
|
33
|
-
|
|
33
|
+
"""Factory class for creating embedding generators based on use case."""
|
|
34
|
+
|
|
35
|
+
def __init__(self, **kwargs: str) -> None:
|
|
36
|
+
"""Raise error directing users to use from_use_case factory method.
|
|
37
|
+
|
|
38
|
+
Raises:
|
|
39
|
+
OSError: Always raised to prevent direct instantiation.
|
|
40
|
+
"""
|
|
34
41
|
raise OSError(
|
|
35
42
|
f"{self.__class__.__name__} is designed to be instantiated using the "
|
|
36
43
|
f"`{self.__class__.__name__}.from_use_case(use_case, **kwargs)` method."
|
|
@@ -38,23 +45,24 @@ class EmbeddingGenerator:
|
|
|
38
45
|
|
|
39
46
|
@staticmethod
|
|
40
47
|
def from_use_case(
|
|
41
|
-
use_case: UseCaseLike, **kwargs:
|
|
48
|
+
use_case: UseCaseLike, **kwargs: object
|
|
42
49
|
) -> BaseEmbeddingGenerator:
|
|
50
|
+
"""Create an embedding generator for the specified use case."""
|
|
43
51
|
if use_case == UseCases.NLP.SEQUENCE_CLASSIFICATION:
|
|
44
52
|
return EmbeddingGeneratorForNLPSequenceClassification(**kwargs)
|
|
45
|
-
|
|
53
|
+
if use_case == UseCases.NLP.SUMMARIZATION:
|
|
46
54
|
return EmbeddingGeneratorForNLPSummarization(**kwargs)
|
|
47
|
-
|
|
55
|
+
if use_case == UseCases.CV.IMAGE_CLASSIFICATION:
|
|
48
56
|
return EmbeddingGeneratorForCVImageClassification(**kwargs)
|
|
49
|
-
|
|
57
|
+
if use_case == UseCases.CV.OBJECT_DETECTION:
|
|
50
58
|
return EmbeddingGeneratorForCVObjectDetection(**kwargs)
|
|
51
|
-
|
|
59
|
+
if use_case == UseCases.STRUCTURED.TABULAR_EMBEDDINGS:
|
|
52
60
|
return EmbeddingGeneratorForTabularFeatures(**kwargs)
|
|
53
|
-
|
|
54
|
-
raise ValueError(f"Invalid use case {use_case}")
|
|
61
|
+
raise ValueError(f"Invalid use case {use_case}")
|
|
55
62
|
|
|
56
63
|
@classmethod
|
|
57
64
|
def list_default_models(cls) -> pd.DataFrame:
|
|
65
|
+
"""Return a DataFrame of default models for each use case."""
|
|
58
66
|
df = pd.DataFrame(
|
|
59
67
|
{
|
|
60
68
|
"Area": ["NLP", "NLP", "CV", "CV", "STRUCTURED"],
|
|
@@ -74,13 +82,12 @@ class EmbeddingGenerator:
|
|
|
74
82
|
],
|
|
75
83
|
}
|
|
76
84
|
)
|
|
77
|
-
df.sort_values(
|
|
78
|
-
by=[col for col in df.columns], ascending=True, inplace=True
|
|
79
|
-
)
|
|
85
|
+
df.sort_values(by=list(df.columns), ascending=True, inplace=True)
|
|
80
86
|
return df.reset_index(drop=True)
|
|
81
87
|
|
|
82
88
|
@classmethod
|
|
83
89
|
def list_pretrained_models(cls) -> pd.DataFrame:
|
|
90
|
+
"""Return a DataFrame of all available pretrained models."""
|
|
84
91
|
data = {
|
|
85
92
|
"Task": ["NLP" for _ in NLP_PRETRAINED_MODELS]
|
|
86
93
|
+ ["CV" for _ in CV_PRETRAINED_MODELS],
|
|
@@ -91,18 +98,15 @@ class EmbeddingGenerator:
|
|
|
91
98
|
"Model Name": NLP_PRETRAINED_MODELS + CV_PRETRAINED_MODELS,
|
|
92
99
|
}
|
|
93
100
|
df = pd.DataFrame(data)
|
|
94
|
-
df.sort_values(
|
|
95
|
-
by=[col for col in df.columns], ascending=True, inplace=True
|
|
96
|
-
)
|
|
101
|
+
df.sort_values(by=list(df.columns), ascending=True, inplace=True)
|
|
97
102
|
return df.reset_index(drop=True)
|
|
98
103
|
|
|
99
104
|
@staticmethod
|
|
100
105
|
def __parse_model_arch(model_name: str) -> str:
|
|
101
106
|
if constants.GPT.lower() in model_name.lower():
|
|
102
107
|
return constants.GPT
|
|
103
|
-
|
|
108
|
+
if constants.BERT.lower() in model_name.lower():
|
|
104
109
|
return constants.BERT
|
|
105
|
-
|
|
110
|
+
if constants.VIT.lower() in model_name.lower():
|
|
106
111
|
return constants.VIT
|
|
107
|
-
|
|
108
|
-
raise ValueError("Invalid model_name, unknown architecture.")
|
|
112
|
+
raise ValueError("Invalid model_name, unknown architecture.")
|
arize/embeddings/base_generators.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
|
1
|
+
"""Base embedding generator classes for NLP, CV, and tabular data."""
|
|
2
|
+
|
|
1
3
|
import os
|
|
2
4
|
from abc import ABC, abstractmethod
|
|
3
5
|
from enum import Enum
|
|
4
6
|
from functools import partial
|
|
5
|
-
from typing import Dict, List, Union, cast
|
|
6
7
|
|
|
7
8
|
import pandas as pd
|
|
8
9
|
|
|
@@ -31,9 +32,26 @@ transformer_logging.enable_progress_bar()
|
|
|
31
32
|
|
|
32
33
|
|
|
33
34
|
class BaseEmbeddingGenerator(ABC):
|
|
35
|
+
"""Abstract base class for all embedding generators."""
|
|
36
|
+
|
|
34
37
|
def __init__(
|
|
35
|
-
self,
|
|
36
|
-
|
|
38
|
+
self,
|
|
39
|
+
use_case: Enum,
|
|
40
|
+
model_name: str,
|
|
41
|
+
batch_size: int = 100,
|
|
42
|
+
**kwargs: object,
|
|
43
|
+
) -> None:
|
|
44
|
+
"""Initialize the embedding generator with model and configuration.
|
|
45
|
+
|
|
46
|
+
Args:
|
|
47
|
+
use_case: Enum specifying the use case for embedding generation.
|
|
48
|
+
model_name: Name of the pre-trained model to use.
|
|
49
|
+
batch_size: Number of samples to process per batch.
|
|
50
|
+
**kwargs: Additional arguments for model initialization.
|
|
51
|
+
|
|
52
|
+
Raises:
|
|
53
|
+
HuggingFaceRepositoryNotFound: If the model name is not found on HuggingFace.
|
|
54
|
+
"""
|
|
37
55
|
self.__use_case = self._parse_use_case(use_case=use_case)
|
|
38
56
|
self.__model_name = model_name
|
|
39
57
|
self.__device = self.select_device()
|
|
@@ -45,43 +63,50 @@ class BaseEmbeddingGenerator(ABC):
|
|
|
45
63
|
).to(self.device)
|
|
46
64
|
except OSError as e:
|
|
47
65
|
raise err.HuggingFaceRepositoryNotFound(model_name) from e
|
|
48
|
-
except Exception
|
|
49
|
-
raise
|
|
66
|
+
except Exception:
|
|
67
|
+
raise
|
|
50
68
|
|
|
51
69
|
@abstractmethod
|
|
52
|
-
def generate_embeddings(self, **kwargs) -> pd.Series:
|
|
70
|
+
def generate_embeddings(self, **kwargs: object) -> pd.Series:
|
|
71
|
+
"""Generate embeddings for the input data."""
|
|
72
|
+
...
|
|
53
73
|
|
|
54
74
|
def select_device(self) -> torch.device:
|
|
75
|
+
"""Select the best available device (CUDA, MPS, or CPU) for model execution."""
|
|
55
76
|
if torch.cuda.is_available():
|
|
56
77
|
return torch.device("cuda")
|
|
57
|
-
|
|
78
|
+
if torch.backends.mps.is_available():
|
|
58
79
|
return torch.device("mps")
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
return torch.device("cpu")
|
|
80
|
+
logger.warning(
|
|
81
|
+
"No available GPU has been detected. The use of GPU acceleration is "
|
|
82
|
+
"strongly recommended. You can check for GPU availability by running "
|
|
83
|
+
"`torch.cuda.is_available()` or `torch.backends.mps.is_available()`."
|
|
84
|
+
)
|
|
85
|
+
return torch.device("cpu")
|
|
66
86
|
|
|
67
87
|
@property
|
|
68
88
|
def use_case(self) -> str:
|
|
89
|
+
"""Return the use case for this embedding generator."""
|
|
69
90
|
return self.__use_case
|
|
70
91
|
|
|
71
92
|
@property
|
|
72
93
|
def model_name(self) -> str:
|
|
94
|
+
"""Return the name of the model being used."""
|
|
73
95
|
return self.__model_name
|
|
74
96
|
|
|
75
97
|
@property
|
|
76
|
-
def model(self):
|
|
98
|
+
def model(self) -> object:
|
|
99
|
+
"""Return the underlying model instance."""
|
|
77
100
|
return self.__model
|
|
78
101
|
|
|
79
102
|
@property
|
|
80
103
|
def device(self) -> torch.device:
|
|
104
|
+
"""Return the device (CPU/GPU) being used for computation."""
|
|
81
105
|
return self.__device
|
|
82
106
|
|
|
83
107
|
@property
|
|
84
108
|
def batch_size(self) -> int:
|
|
109
|
+
"""Return the batch size for processing."""
|
|
85
110
|
return self.__batch_size
|
|
86
111
|
|
|
87
112
|
@batch_size.setter
|
|
@@ -89,11 +114,10 @@ class BaseEmbeddingGenerator(ABC):
|
|
|
89
114
|
err_message = "New batch size should be an integer greater than 0."
|
|
90
115
|
if not isinstance(new_batch_size, int):
|
|
91
116
|
raise TypeError(err_message)
|
|
92
|
-
|
|
117
|
+
if new_batch_size <= 0:
|
|
93
118
|
raise ValueError(err_message)
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
logger.info(f"Batch size has been set to {new_batch_size}.")
|
|
119
|
+
self.__batch_size = new_batch_size
|
|
120
|
+
logger.info(f"Batch size has been set to {new_batch_size}.")
|
|
97
121
|
|
|
98
122
|
@staticmethod
|
|
99
123
|
def _parse_use_case(use_case: Enum) -> str:
|
|
@@ -102,8 +126,8 @@ class BaseEmbeddingGenerator(ABC):
|
|
|
102
126
|
return f"{uc_area}.{uc_task}"
|
|
103
127
|
|
|
104
128
|
def _get_embedding_vector(
|
|
105
|
-
self, batch:
|
|
106
|
-
) ->
|
|
129
|
+
self, batch: dict[str, torch.Tensor], method: str
|
|
130
|
+
) -> dict[str, torch.Tensor]:
|
|
107
131
|
with torch.no_grad():
|
|
108
132
|
outputs = self.model(**batch)
|
|
109
133
|
# (batch_size, seq_length/or/num_tokens, hidden_size)
|
|
@@ -116,20 +140,23 @@ class BaseEmbeddingGenerator(ABC):
|
|
|
116
140
|
return {"embedding_vector": embeddings.cpu().numpy().astype(float)}
|
|
117
141
|
|
|
118
142
|
@staticmethod
|
|
119
|
-
def check_invalid_index(field:
|
|
143
|
+
def check_invalid_index(field: pd.Series | pd.DataFrame) -> None:
|
|
144
|
+
"""Check if the field has a valid index and raise error if invalid."""
|
|
120
145
|
if (field.index != field.reset_index(drop=True).index).any():
|
|
121
146
|
if isinstance(field, pd.DataFrame):
|
|
122
147
|
raise err.InvalidIndexError("DataFrame")
|
|
123
|
-
|
|
124
|
-
raise err.InvalidIndexError(str(field.name))
|
|
148
|
+
raise err.InvalidIndexError(str(field.name))
|
|
125
149
|
|
|
126
150
|
@abstractmethod
|
|
127
151
|
def __repr__(self) -> str:
|
|
128
|
-
|
|
152
|
+
"""Return a string representation of the embedding generator."""
|
|
129
153
|
|
|
130
154
|
|
|
131
155
|
class NLPEmbeddingGenerator(BaseEmbeddingGenerator):
|
|
156
|
+
"""Base class for NLP embedding generators with text tokenization support."""
|
|
157
|
+
|
|
132
158
|
def __repr__(self) -> str:
|
|
159
|
+
"""Return a string representation of the NLP embedding generator."""
|
|
133
160
|
return (
|
|
134
161
|
f"{self.__class__.__name__}(\n"
|
|
135
162
|
f" use_case={self.use_case},\n"
|
|
@@ -146,8 +173,16 @@ class NLPEmbeddingGenerator(BaseEmbeddingGenerator):
|
|
|
146
173
|
use_case: Enum,
|
|
147
174
|
model_name: str,
|
|
148
175
|
tokenizer_max_length: int = 512,
|
|
149
|
-
**kwargs,
|
|
150
|
-
):
|
|
176
|
+
**kwargs: object,
|
|
177
|
+
) -> None:
|
|
178
|
+
"""Initialize the NLP embedding generator with tokenizer configuration.
|
|
179
|
+
|
|
180
|
+
Args:
|
|
181
|
+
use_case: Enum specifying the NLP use case.
|
|
182
|
+
model_name: Name of the pre-trained NLP model.
|
|
183
|
+
tokenizer_max_length: Maximum sequence length for the tokenizer.
|
|
184
|
+
**kwargs: Additional arguments for model initialization.
|
|
185
|
+
"""
|
|
151
186
|
super().__init__(use_case=use_case, model_name=model_name, **kwargs)
|
|
152
187
|
self.__tokenizer_max_length = tokenizer_max_length
|
|
153
188
|
# We don't check for the tokenizer's existence since it is coupled with the corresponding model
|
|
@@ -158,16 +193,19 @@ class NLPEmbeddingGenerator(BaseEmbeddingGenerator):
|
|
|
158
193
|
)
|
|
159
194
|
|
|
160
195
|
@property
|
|
161
|
-
def tokenizer(self):
|
|
196
|
+
def tokenizer(self) -> object:
|
|
197
|
+
"""Return the tokenizer instance for text processing."""
|
|
162
198
|
return self.__tokenizer
|
|
163
199
|
|
|
164
200
|
@property
|
|
165
201
|
def tokenizer_max_length(self) -> int:
|
|
202
|
+
"""Return the maximum sequence length for the tokenizer."""
|
|
166
203
|
return self.__tokenizer_max_length
|
|
167
204
|
|
|
168
205
|
def tokenize(
|
|
169
|
-
self, batch:
|
|
206
|
+
self, batch: dict[str, list[str]], text_feat_name: str
|
|
170
207
|
) -> BatchEncoding:
|
|
208
|
+
"""Tokenize a batch of text inputs."""
|
|
171
209
|
return self.tokenizer(
|
|
172
210
|
batch[text_feat_name],
|
|
173
211
|
padding=True,
|
|
@@ -178,7 +216,10 @@ class NLPEmbeddingGenerator(BaseEmbeddingGenerator):
|
|
|
178
216
|
|
|
179
217
|
|
|
180
218
|
class CVEmbeddingGenerator(BaseEmbeddingGenerator):
|
|
219
|
+
"""Base class for computer vision embedding generators with image preprocessing support."""
|
|
220
|
+
|
|
181
221
|
def __repr__(self) -> str:
|
|
222
|
+
"""Return a string representation of the computer vision embedding generator."""
|
|
182
223
|
return (
|
|
183
224
|
f"{self.__class__.__name__}(\n"
|
|
184
225
|
f" use_case={self.use_case},\n"
|
|
@@ -189,7 +230,16 @@ class CVEmbeddingGenerator(BaseEmbeddingGenerator):
|
|
|
189
230
|
f")"
|
|
190
231
|
)
|
|
191
232
|
|
|
192
|
-
def __init__(
|
|
233
|
+
def __init__(
|
|
234
|
+
self, use_case: Enum, model_name: str, **kwargs: object
|
|
235
|
+
) -> None:
|
|
236
|
+
"""Initialize the computer vision embedding generator with image processor.
|
|
237
|
+
|
|
238
|
+
Args:
|
|
239
|
+
use_case: Enum specifying the computer vision use case.
|
|
240
|
+
model_name: Name of the pre-trained vision model.
|
|
241
|
+
**kwargs: Additional arguments for model initialization.
|
|
242
|
+
"""
|
|
193
243
|
super().__init__(use_case=use_case, model_name=model_name, **kwargs)
|
|
194
244
|
logger.info("Downloading image processor")
|
|
195
245
|
# We don't check for the image processor's existence since it is coupled with the corresponding model
|
|
@@ -199,18 +249,21 @@ class CVEmbeddingGenerator(BaseEmbeddingGenerator):
|
|
|
199
249
|
)
|
|
200
250
|
|
|
201
251
|
@property
|
|
202
|
-
def image_processor(self):
|
|
252
|
+
def image_processor(self) -> object:
|
|
253
|
+
"""Return the image processor instance for image preprocessing."""
|
|
203
254
|
return self.__image_processor
|
|
204
255
|
|
|
205
256
|
@staticmethod
|
|
206
257
|
def open_image(image_path: str) -> Image.Image:
|
|
258
|
+
"""Open and convert an image to RGB format."""
|
|
207
259
|
if not os.path.exists(image_path):
|
|
208
260
|
raise ValueError(f"Cannot find image {image_path}")
|
|
209
261
|
return Image.open(image_path).convert("RGB")
|
|
210
262
|
|
|
211
263
|
def preprocess_image(
|
|
212
|
-
self, batch:
|
|
213
|
-
):
|
|
264
|
+
self, batch: dict[str, list[str]], local_image_feat_name: str
|
|
265
|
+
) -> object:
|
|
266
|
+
"""Preprocess a batch of images for model input."""
|
|
214
267
|
return self.image_processor(
|
|
215
268
|
[
|
|
216
269
|
self.open_image(image_path)
|
|
@@ -220,8 +273,7 @@ class CVEmbeddingGenerator(BaseEmbeddingGenerator):
|
|
|
220
273
|
).to(self.device)
|
|
221
274
|
|
|
222
275
|
def generate_embeddings(self, local_image_path_col: pd.Series) -> pd.Series:
|
|
223
|
-
"""
|
|
224
|
-
Obtain embedding vectors from your image data using pre-trained image models.
|
|
276
|
+
"""Obtain embedding vectors from your image data using pre-trained image models.
|
|
225
277
|
|
|
226
278
|
:param local_image_path_col: a pandas Series containing the local path to the images to
|
|
227
279
|
be used to generate the embedding vectors.
|
|
@@ -252,4 +304,5 @@ class CVEmbeddingGenerator(BaseEmbeddingGenerator):
|
|
|
252
304
|
batched=True,
|
|
253
305
|
batch_size=self.batch_size,
|
|
254
306
|
)
|
|
255
|
-
|
|
307
|
+
df: pd.DataFrame = ds.to_pandas()
|
|
308
|
+
return df["embedding_vector"]
|
arize/embeddings/constants.py
CHANGED
arize/embeddings/cv_generators.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""Computer vision embedding generators for image classification and object detection."""
|
|
2
|
+
|
|
1
3
|
from arize.embeddings.base_generators import CVEmbeddingGenerator
|
|
2
4
|
from arize.embeddings.constants import (
|
|
3
5
|
DEFAULT_CV_IMAGE_CLASSIFICATION_MODEL,
|
|
@@ -7,9 +9,19 @@ from arize.embeddings.usecases import UseCases
|
|
|
7
9
|
|
|
8
10
|
|
|
9
11
|
class EmbeddingGeneratorForCVImageClassification(CVEmbeddingGenerator):
|
|
12
|
+
"""Embedding generator for computer vision image classification tasks."""
|
|
13
|
+
|
|
10
14
|
def __init__(
|
|
11
|
-
self,
|
|
12
|
-
|
|
15
|
+
self,
|
|
16
|
+
model_name: str = DEFAULT_CV_IMAGE_CLASSIFICATION_MODEL,
|
|
17
|
+
**kwargs: object,
|
|
18
|
+
) -> None:
|
|
19
|
+
"""Initialize the image classification embedding generator.
|
|
20
|
+
|
|
21
|
+
Args:
|
|
22
|
+
model_name: Name of the pre-trained vision model.
|
|
23
|
+
**kwargs: Additional arguments for model initialization.
|
|
24
|
+
"""
|
|
13
25
|
super().__init__(
|
|
14
26
|
use_case=UseCases.CV.IMAGE_CLASSIFICATION,
|
|
15
27
|
model_name=model_name,
|
|
@@ -18,9 +30,19 @@ class EmbeddingGeneratorForCVImageClassification(CVEmbeddingGenerator):
|
|
|
18
30
|
|
|
19
31
|
|
|
20
32
|
class EmbeddingGeneratorForCVObjectDetection(CVEmbeddingGenerator):
|
|
33
|
+
"""Embedding generator for computer vision object detection tasks."""
|
|
34
|
+
|
|
21
35
|
def __init__(
|
|
22
|
-
self,
|
|
23
|
-
|
|
36
|
+
self,
|
|
37
|
+
model_name: str = DEFAULT_CV_OBJECT_DETECTION_MODEL,
|
|
38
|
+
**kwargs: object,
|
|
39
|
+
) -> None:
|
|
40
|
+
"""Initialize the object detection embedding generator.
|
|
41
|
+
|
|
42
|
+
Args:
|
|
43
|
+
model_name: Name of the pre-trained vision model.
|
|
44
|
+
**kwargs: Additional arguments for model initialization.
|
|
45
|
+
"""
|
|
24
46
|
super().__init__(
|
|
25
47
|
use_case=UseCases.CV.OBJECT_DETECTION,
|
|
26
48
|
model_name=model_name,
|
arize/embeddings/errors.py
CHANGED
|
@@ -1,37 +1,59 @@
|
|
|
1
|
+
"""Embedding generation exception classes."""
|
|
2
|
+
|
|
3
|
+
|
|
1
4
|
class InvalidIndexError(Exception):
|
|
5
|
+
"""Raised when DataFrame or Series has an invalid index."""
|
|
6
|
+
|
|
2
7
|
def __repr__(self) -> str:
|
|
8
|
+
"""Return a string representation for debugging and logging."""
|
|
3
9
|
return "Invalid_Index_Error"
|
|
4
10
|
|
|
5
11
|
def __str__(self) -> str:
|
|
12
|
+
"""Return a human-readable error message."""
|
|
6
13
|
return self.error_message()
|
|
7
14
|
|
|
8
15
|
def __init__(self, field_name: str) -> None:
|
|
16
|
+
"""Initialize the exception with field name context.
|
|
17
|
+
|
|
18
|
+
Args:
|
|
19
|
+
field_name: Name of the DataFrame or Series field with invalid index.
|
|
20
|
+
"""
|
|
9
21
|
self.field_name = field_name
|
|
10
22
|
|
|
11
23
|
def error_message(self) -> str:
|
|
24
|
+
"""Return the error message for this exception."""
|
|
12
25
|
if self.field_name == "DataFrame":
|
|
13
26
|
return (
|
|
14
27
|
f"The index of the {self.field_name} is invalid; "
|
|
15
28
|
f"reset the index by using df.reset_index(drop=True, inplace=True)"
|
|
16
29
|
)
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
)
|
|
30
|
+
return (
|
|
31
|
+
f"The index of the Series given by the column '{self.field_name}' is invalid; "
|
|
32
|
+
f"reset the index by using df.reset_index(drop=True, inplace=True)"
|
|
33
|
+
)
|
|
22
34
|
|
|
23
35
|
|
|
24
36
|
class HuggingFaceRepositoryNotFound(Exception):
|
|
37
|
+
"""Raised when HuggingFace model repository is not found."""
|
|
38
|
+
|
|
25
39
|
def __repr__(self) -> str:
|
|
40
|
+
"""Return a string representation for debugging and logging."""
|
|
26
41
|
return "HuggingFace_Repository_Not_Found_Error"
|
|
27
42
|
|
|
28
43
|
def __str__(self) -> str:
|
|
44
|
+
"""Return a human-readable error message."""
|
|
29
45
|
return self.error_message()
|
|
30
46
|
|
|
31
47
|
def __init__(self, model_name: str) -> None:
|
|
48
|
+
"""Initialize the exception with model name context.
|
|
49
|
+
|
|
50
|
+
Args:
|
|
51
|
+
model_name: Name of the HuggingFace model that was not found.
|
|
52
|
+
"""
|
|
32
53
|
self.model_name = model_name
|
|
33
54
|
|
|
34
55
|
def error_message(self) -> str:
|
|
56
|
+
"""Return the error message for this exception."""
|
|
35
57
|
return (
|
|
36
58
|
f"The given model name '{self.model_name}' is not a valid model identifier listed on "
|
|
37
59
|
"'https://huggingface.co/models'. "
|
arize/embeddings/nlp_generators.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
|
+
"""NLP embedding generators for text classification and summarization tasks."""
|
|
2
|
+
|
|
1
3
|
import logging
|
|
2
4
|
from functools import partial
|
|
3
|
-
from typing import Optional, cast
|
|
4
5
|
|
|
5
6
|
import pandas as pd
|
|
6
7
|
|
|
@@ -22,11 +23,19 @@ logger = logging.getLogger(__name__)
|
|
|
22
23
|
|
|
23
24
|
|
|
24
25
|
class EmbeddingGeneratorForNLPSequenceClassification(NLPEmbeddingGenerator):
|
|
26
|
+
"""Embedding generator for NLP sequence classification tasks."""
|
|
27
|
+
|
|
25
28
|
def __init__(
|
|
26
29
|
self,
|
|
27
30
|
model_name: str = DEFAULT_NLP_SEQUENCE_CLASSIFICATION_MODEL,
|
|
28
|
-
**kwargs,
|
|
29
|
-
):
|
|
31
|
+
**kwargs: object,
|
|
32
|
+
) -> None:
|
|
33
|
+
"""Initialize the sequence classification embedding generator.
|
|
34
|
+
|
|
35
|
+
Args:
|
|
36
|
+
model_name: Name of the pre-trained NLP model.
|
|
37
|
+
**kwargs: Additional arguments for model initialization.
|
|
38
|
+
"""
|
|
30
39
|
super().__init__(
|
|
31
40
|
use_case=UseCases.NLP.SEQUENCE_CLASSIFICATION,
|
|
32
41
|
model_name=model_name,
|
|
@@ -36,10 +45,9 @@ class EmbeddingGeneratorForNLPSequenceClassification(NLPEmbeddingGenerator):
|
|
|
36
45
|
def generate_embeddings(
|
|
37
46
|
self,
|
|
38
47
|
text_col: pd.Series,
|
|
39
|
-
class_label_col:
|
|
48
|
+
class_label_col: pd.Series | None = None,
|
|
40
49
|
) -> pd.Series:
|
|
41
|
-
"""
|
|
42
|
-
Obtain embedding vectors from your text data using pre-trained large language models.
|
|
50
|
+
"""Obtain embedding vectors from your text data using pre-trained large language models.
|
|
43
51
|
|
|
44
52
|
:param text_col: a pandas Series containing the different pieces of text.
|
|
45
53
|
:param class_label_col: if this column is passed, the sentence "The classification label
|
|
@@ -72,13 +80,24 @@ class EmbeddingGeneratorForNLPSequenceClassification(NLPEmbeddingGenerator):
|
|
|
72
80
|
batched=True,
|
|
73
81
|
batch_size=self.batch_size,
|
|
74
82
|
)
|
|
75
|
-
|
|
83
|
+
df: pd.DataFrame = ds.to_pandas()
|
|
84
|
+
return df["embedding_vector"]
|
|
76
85
|
|
|
77
86
|
|
|
78
87
|
class EmbeddingGeneratorForNLPSummarization(NLPEmbeddingGenerator):
|
|
88
|
+
"""Embedding generator for NLP text summarization tasks."""
|
|
89
|
+
|
|
79
90
|
def __init__(
|
|
80
|
-
self,
|
|
81
|
-
|
|
91
|
+
self,
|
|
92
|
+
model_name: str = DEFAULT_NLP_SUMMARIZATION_MODEL,
|
|
93
|
+
**kwargs: object,
|
|
94
|
+
) -> None:
|
|
95
|
+
"""Initialize the text summarization embedding generator.
|
|
96
|
+
|
|
97
|
+
Args:
|
|
98
|
+
model_name: Name of the pre-trained NLP model.
|
|
99
|
+
**kwargs: Additional arguments for model initialization.
|
|
100
|
+
"""
|
|
82
101
|
super().__init__(
|
|
83
102
|
use_case=UseCases.NLP.SUMMARIZATION,
|
|
84
103
|
model_name=model_name,
|
|
@@ -89,8 +108,7 @@ class EmbeddingGeneratorForNLPSummarization(NLPEmbeddingGenerator):
|
|
|
89
108
|
self,
|
|
90
109
|
text_col: pd.Series,
|
|
91
110
|
) -> pd.Series:
|
|
92
|
-
"""
|
|
93
|
-
Obtain embedding vectors from your text data using pre-trained large language models.
|
|
111
|
+
"""Obtain embedding vectors from your text data using pre-trained large language models.
|
|
94
112
|
|
|
95
113
|
:param text_col: a pandas Series containing the different pieces of text.
|
|
96
114
|
:return: a pandas Series containing the embedding vectors.
|
|
@@ -108,4 +126,5 @@ class EmbeddingGeneratorForNLPSummarization(NLPEmbeddingGenerator):
|
|
|
108
126
|
batched=True,
|
|
109
127
|
batch_size=self.batch_size,
|
|
110
128
|
)
|
|
111
|
-
|
|
129
|
+
df: pd.DataFrame = ds.to_pandas()
|
|
130
|
+
return df["embedding_vector"]
|