arize 8.0.0a21__py3-none-any.whl → 8.0.0a23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. arize/__init__.py +17 -9
  2. arize/_exporter/client.py +55 -36
  3. arize/_exporter/parsers/tracing_data_parser.py +41 -30
  4. arize/_exporter/validation.py +3 -3
  5. arize/_flight/client.py +208 -77
  6. arize/_generated/api_client/__init__.py +30 -6
  7. arize/_generated/api_client/api/__init__.py +1 -0
  8. arize/_generated/api_client/api/datasets_api.py +864 -190
  9. arize/_generated/api_client/api/experiments_api.py +167 -131
  10. arize/_generated/api_client/api/projects_api.py +1197 -0
  11. arize/_generated/api_client/api_client.py +2 -2
  12. arize/_generated/api_client/configuration.py +42 -34
  13. arize/_generated/api_client/exceptions.py +2 -2
  14. arize/_generated/api_client/models/__init__.py +15 -4
  15. arize/_generated/api_client/models/dataset.py +10 -10
  16. arize/_generated/api_client/models/dataset_example.py +111 -0
  17. arize/_generated/api_client/models/dataset_example_update.py +100 -0
  18. arize/_generated/api_client/models/dataset_version.py +13 -13
  19. arize/_generated/api_client/models/datasets_create_request.py +16 -8
  20. arize/_generated/api_client/models/datasets_examples_insert_request.py +100 -0
  21. arize/_generated/api_client/models/datasets_examples_list200_response.py +106 -0
  22. arize/_generated/api_client/models/datasets_examples_update_request.py +102 -0
  23. arize/_generated/api_client/models/datasets_list200_response.py +10 -4
  24. arize/_generated/api_client/models/experiment.py +14 -16
  25. arize/_generated/api_client/models/experiment_run.py +108 -0
  26. arize/_generated/api_client/models/experiment_run_create.py +102 -0
  27. arize/_generated/api_client/models/experiments_create_request.py +16 -10
  28. arize/_generated/api_client/models/experiments_list200_response.py +10 -4
  29. arize/_generated/api_client/models/experiments_runs_list200_response.py +19 -5
  30. arize/_generated/api_client/models/{error.py → pagination_metadata.py} +13 -11
  31. arize/_generated/api_client/models/primitive_value.py +172 -0
  32. arize/_generated/api_client/models/problem.py +100 -0
  33. arize/_generated/api_client/models/project.py +99 -0
  34. arize/_generated/api_client/models/{datasets_list_examples200_response.py → projects_create_request.py} +13 -11
  35. arize/_generated/api_client/models/projects_list200_response.py +106 -0
  36. arize/_generated/api_client/rest.py +2 -2
  37. arize/_generated/api_client/test/test_dataset.py +4 -2
  38. arize/_generated/api_client/test/test_dataset_example.py +56 -0
  39. arize/_generated/api_client/test/test_dataset_example_update.py +52 -0
  40. arize/_generated/api_client/test/test_dataset_version.py +7 -2
  41. arize/_generated/api_client/test/test_datasets_api.py +27 -13
  42. arize/_generated/api_client/test/test_datasets_create_request.py +8 -4
  43. arize/_generated/api_client/test/{test_datasets_list_examples200_response.py → test_datasets_examples_insert_request.py} +19 -15
  44. arize/_generated/api_client/test/test_datasets_examples_list200_response.py +66 -0
  45. arize/_generated/api_client/test/test_datasets_examples_update_request.py +61 -0
  46. arize/_generated/api_client/test/test_datasets_list200_response.py +9 -3
  47. arize/_generated/api_client/test/test_experiment.py +2 -4
  48. arize/_generated/api_client/test/test_experiment_run.py +56 -0
  49. arize/_generated/api_client/test/test_experiment_run_create.py +54 -0
  50. arize/_generated/api_client/test/test_experiments_api.py +6 -6
  51. arize/_generated/api_client/test/test_experiments_create_request.py +9 -6
  52. arize/_generated/api_client/test/test_experiments_list200_response.py +9 -5
  53. arize/_generated/api_client/test/test_experiments_runs_list200_response.py +15 -5
  54. arize/_generated/api_client/test/test_pagination_metadata.py +53 -0
  55. arize/_generated/api_client/test/{test_error.py → test_primitive_value.py} +13 -14
  56. arize/_generated/api_client/test/test_problem.py +57 -0
  57. arize/_generated/api_client/test/test_project.py +58 -0
  58. arize/_generated/api_client/test/test_projects_api.py +59 -0
  59. arize/_generated/api_client/test/test_projects_create_request.py +54 -0
  60. arize/_generated/api_client/test/test_projects_list200_response.py +70 -0
  61. arize/_generated/api_client_README.md +43 -29
  62. arize/_generated/protocol/flight/flight_pb2.py +400 -0
  63. arize/_lazy.py +27 -19
  64. arize/client.py +269 -55
  65. arize/config.py +365 -116
  66. arize/constants/__init__.py +1 -0
  67. arize/constants/config.py +11 -4
  68. arize/constants/ml.py +6 -4
  69. arize/constants/openinference.py +2 -0
  70. arize/constants/pyarrow.py +2 -0
  71. arize/constants/spans.py +3 -1
  72. arize/datasets/__init__.py +1 -0
  73. arize/datasets/client.py +299 -84
  74. arize/datasets/errors.py +32 -2
  75. arize/datasets/validation.py +18 -8
  76. arize/embeddings/__init__.py +2 -0
  77. arize/embeddings/auto_generator.py +23 -19
  78. arize/embeddings/base_generators.py +89 -36
  79. arize/embeddings/constants.py +2 -0
  80. arize/embeddings/cv_generators.py +26 -4
  81. arize/embeddings/errors.py +27 -5
  82. arize/embeddings/nlp_generators.py +31 -12
  83. arize/embeddings/tabular_generators.py +32 -20
  84. arize/embeddings/usecases.py +12 -2
  85. arize/exceptions/__init__.py +1 -0
  86. arize/exceptions/auth.py +11 -1
  87. arize/exceptions/base.py +29 -4
  88. arize/exceptions/models.py +21 -2
  89. arize/exceptions/parameters.py +31 -0
  90. arize/exceptions/spaces.py +12 -1
  91. arize/exceptions/types.py +86 -7
  92. arize/exceptions/values.py +220 -20
  93. arize/experiments/__init__.py +1 -0
  94. arize/experiments/client.py +390 -286
  95. arize/experiments/evaluators/__init__.py +1 -0
  96. arize/experiments/evaluators/base.py +74 -41
  97. arize/experiments/evaluators/exceptions.py +6 -3
  98. arize/experiments/evaluators/executors.py +121 -73
  99. arize/experiments/evaluators/rate_limiters.py +106 -57
  100. arize/experiments/evaluators/types.py +34 -7
  101. arize/experiments/evaluators/utils.py +65 -27
  102. arize/experiments/functions.py +103 -101
  103. arize/experiments/tracing.py +52 -44
  104. arize/experiments/types.py +56 -31
  105. arize/logging.py +54 -22
  106. arize/models/__init__.py +1 -0
  107. arize/models/batch_validation/__init__.py +1 -0
  108. arize/models/batch_validation/errors.py +543 -65
  109. arize/models/batch_validation/validator.py +339 -300
  110. arize/models/bounded_executor.py +20 -7
  111. arize/models/casting.py +75 -29
  112. arize/models/client.py +326 -107
  113. arize/models/proto.py +95 -40
  114. arize/models/stream_validation.py +42 -14
  115. arize/models/surrogate_explainer/__init__.py +1 -0
  116. arize/models/surrogate_explainer/mimic.py +24 -13
  117. arize/pre_releases.py +43 -0
  118. arize/projects/__init__.py +1 -0
  119. arize/projects/client.py +129 -0
  120. arize/regions.py +40 -0
  121. arize/spans/__init__.py +1 -0
  122. arize/spans/client.py +130 -106
  123. arize/spans/columns.py +13 -0
  124. arize/spans/conversion.py +54 -38
  125. arize/spans/validation/__init__.py +1 -0
  126. arize/spans/validation/annotations/__init__.py +1 -0
  127. arize/spans/validation/annotations/annotations_validation.py +6 -4
  128. arize/spans/validation/annotations/dataframe_form_validation.py +13 -11
  129. arize/spans/validation/annotations/value_validation.py +35 -11
  130. arize/spans/validation/common/__init__.py +1 -0
  131. arize/spans/validation/common/argument_validation.py +33 -8
  132. arize/spans/validation/common/dataframe_form_validation.py +35 -9
  133. arize/spans/validation/common/errors.py +211 -11
  134. arize/spans/validation/common/value_validation.py +80 -13
  135. arize/spans/validation/evals/__init__.py +1 -0
  136. arize/spans/validation/evals/dataframe_form_validation.py +28 -8
  137. arize/spans/validation/evals/evals_validation.py +34 -4
  138. arize/spans/validation/evals/value_validation.py +26 -3
  139. arize/spans/validation/metadata/__init__.py +1 -1
  140. arize/spans/validation/metadata/argument_validation.py +14 -5
  141. arize/spans/validation/metadata/dataframe_form_validation.py +26 -10
  142. arize/spans/validation/metadata/value_validation.py +24 -10
  143. arize/spans/validation/spans/__init__.py +1 -0
  144. arize/spans/validation/spans/dataframe_form_validation.py +34 -13
  145. arize/spans/validation/spans/spans_validation.py +35 -4
  146. arize/spans/validation/spans/value_validation.py +76 -7
  147. arize/types.py +293 -157
  148. arize/utils/__init__.py +1 -0
  149. arize/utils/arrow.py +31 -15
  150. arize/utils/cache.py +34 -6
  151. arize/utils/dataframe.py +19 -2
  152. arize/utils/online_tasks/__init__.py +2 -0
  153. arize/utils/online_tasks/dataframe_preprocessor.py +53 -41
  154. arize/utils/openinference_conversion.py +44 -5
  155. arize/utils/proto.py +10 -0
  156. arize/utils/size.py +5 -3
  157. arize/version.py +3 -1
  158. {arize-8.0.0a21.dist-info → arize-8.0.0a23.dist-info}/METADATA +4 -3
  159. arize-8.0.0a23.dist-info/RECORD +174 -0
  160. {arize-8.0.0a21.dist-info → arize-8.0.0a23.dist-info}/WHEEL +1 -1
  161. arize-8.0.0a23.dist-info/licenses/LICENSE +176 -0
  162. arize-8.0.0a23.dist-info/licenses/NOTICE +13 -0
  163. arize/_generated/protocol/flight/export_pb2.py +0 -61
  164. arize/_generated/protocol/flight/ingest_pb2.py +0 -365
  165. arize-8.0.0a21.dist-info/RECORD +0 -146
  166. arize-8.0.0a21.dist-info/licenses/LICENSE.md +0 -12
@@ -1,4 +1,4 @@
1
- from typing import List
1
+ """Dataset validation logic for structure and content checks."""
2
2
 
3
3
  import pandas as pd
4
4
 
@@ -7,7 +7,17 @@ from arize.datasets import errors as err
7
7
 
8
8
  def validate_dataset_df(
9
9
  df: pd.DataFrame,
10
- ) -> List[err.DatasetError]:
10
+ ) -> list[err.DatasetError]:
11
+ """Validate a dataset DataFrame for structural and content errors.
12
+
13
+ Checks for required columns, unique ID values, and non-empty data.
14
+
15
+ Args:
16
+ df: The pandas DataFrame to validate.
17
+
18
+ Returns:
19
+ A list of DatasetError objects found during validation. Empty list if valid.
20
+ """
11
21
  ## check all require columns are present
12
22
  required_columns_errors = _check_required_columns(df)
13
23
  if required_columns_errors:
@@ -19,14 +29,14 @@ def validate_dataset_df(
19
29
  return id_column_unique_constraint_error
20
30
 
21
31
  # check DataFrame has at least one row in it
22
- emtpy_dataframe_error = _check_empty_dataframe(df)
23
- if emtpy_dataframe_error:
24
- return emtpy_dataframe_error
32
+ empty_dataframe_error = _check_empty_dataframe(df)
33
+ if empty_dataframe_error:
34
+ return empty_dataframe_error
25
35
 
26
36
  return []
27
37
 
28
38
 
29
- def _check_required_columns(df: pd.DataFrame) -> List[err.DatasetError]:
39
+ def _check_required_columns(df: pd.DataFrame) -> list[err.DatasetError]:
30
40
  required_columns = ["id", "created_at", "updated_at"]
31
41
  missing_columns = set(required_columns) - set(df.columns)
32
42
  if missing_columns:
@@ -34,13 +44,13 @@ def _check_required_columns(df: pd.DataFrame) -> List[err.DatasetError]:
34
44
  return []
35
45
 
36
46
 
37
- def _check_id_column_is_unique(df: pd.DataFrame) -> List[err.DatasetError]:
47
+ def _check_id_column_is_unique(df: pd.DataFrame) -> list[err.DatasetError]:
38
48
  if not df["id"].is_unique:
39
49
  return [err.IDColumnUniqueConstraintError()]
40
50
  return []
41
51
 
42
52
 
43
- def _check_empty_dataframe(df: pd.DataFrame) -> List[err.DatasetError]:
53
+ def _check_empty_dataframe(df: pd.DataFrame) -> list[err.DatasetError]:
44
54
  if df.empty:
45
55
  return [err.EmptyDatasetError()]
46
56
  return []
@@ -1,3 +1,5 @@
1
+ """Embedding generation and use case utilities for the Arize SDK."""
2
+
1
3
  from arize.embeddings.auto_generator import EmbeddingGenerator
2
4
  from arize.embeddings.usecases import UseCases
3
5
 
@@ -1,4 +1,4 @@
1
- from typing import Any
1
+ """Automatic embedding generation factory for various ML use cases."""
2
2
 
3
3
  import pandas as pd
4
4
 
@@ -30,7 +30,14 @@ UseCaseLike = str | UseCases.NLP | UseCases.CV | UseCases.STRUCTURED
30
30
 
31
31
 
32
32
  class EmbeddingGenerator:
33
- def __init__(self, **kwargs: str):
33
+ """Factory class for creating embedding generators based on use case."""
34
+
35
+ def __init__(self, **kwargs: str) -> None:
36
+ """Raise error directing users to use from_use_case factory method.
37
+
38
+ Raises:
39
+ OSError: Always raised to prevent direct instantiation.
40
+ """
34
41
  raise OSError(
35
42
  f"{self.__class__.__name__} is designed to be instantiated using the "
36
43
  f"`{self.__class__.__name__}.from_use_case(use_case, **kwargs)` method."
@@ -38,23 +45,24 @@ class EmbeddingGenerator:
38
45
 
39
46
  @staticmethod
40
47
  def from_use_case(
41
- use_case: UseCaseLike, **kwargs: Any
48
+ use_case: UseCaseLike, **kwargs: object
42
49
  ) -> BaseEmbeddingGenerator:
50
+ """Create an embedding generator for the specified use case."""
43
51
  if use_case == UseCases.NLP.SEQUENCE_CLASSIFICATION:
44
52
  return EmbeddingGeneratorForNLPSequenceClassification(**kwargs)
45
- elif use_case == UseCases.NLP.SUMMARIZATION:
53
+ if use_case == UseCases.NLP.SUMMARIZATION:
46
54
  return EmbeddingGeneratorForNLPSummarization(**kwargs)
47
- elif use_case == UseCases.CV.IMAGE_CLASSIFICATION:
55
+ if use_case == UseCases.CV.IMAGE_CLASSIFICATION:
48
56
  return EmbeddingGeneratorForCVImageClassification(**kwargs)
49
- elif use_case == UseCases.CV.OBJECT_DETECTION:
57
+ if use_case == UseCases.CV.OBJECT_DETECTION:
50
58
  return EmbeddingGeneratorForCVObjectDetection(**kwargs)
51
- elif use_case == UseCases.STRUCTURED.TABULAR_EMBEDDINGS:
59
+ if use_case == UseCases.STRUCTURED.TABULAR_EMBEDDINGS:
52
60
  return EmbeddingGeneratorForTabularFeatures(**kwargs)
53
- else:
54
- raise ValueError(f"Invalid use case {use_case}")
61
+ raise ValueError(f"Invalid use case {use_case}")
55
62
 
56
63
  @classmethod
57
64
  def list_default_models(cls) -> pd.DataFrame:
65
+ """Return a DataFrame of default models for each use case."""
58
66
  df = pd.DataFrame(
59
67
  {
60
68
  "Area": ["NLP", "NLP", "CV", "CV", "STRUCTURED"],
@@ -74,13 +82,12 @@ class EmbeddingGenerator:
74
82
  ],
75
83
  }
76
84
  )
77
- df.sort_values(
78
- by=[col for col in df.columns], ascending=True, inplace=True
79
- )
85
+ df.sort_values(by=list(df.columns), ascending=True, inplace=True)
80
86
  return df.reset_index(drop=True)
81
87
 
82
88
  @classmethod
83
89
  def list_pretrained_models(cls) -> pd.DataFrame:
90
+ """Return a DataFrame of all available pretrained models."""
84
91
  data = {
85
92
  "Task": ["NLP" for _ in NLP_PRETRAINED_MODELS]
86
93
  + ["CV" for _ in CV_PRETRAINED_MODELS],
@@ -91,18 +98,15 @@ class EmbeddingGenerator:
91
98
  "Model Name": NLP_PRETRAINED_MODELS + CV_PRETRAINED_MODELS,
92
99
  }
93
100
  df = pd.DataFrame(data)
94
- df.sort_values(
95
- by=[col for col in df.columns], ascending=True, inplace=True
96
- )
101
+ df.sort_values(by=list(df.columns), ascending=True, inplace=True)
97
102
  return df.reset_index(drop=True)
98
103
 
99
104
  @staticmethod
100
105
  def __parse_model_arch(model_name: str) -> str:
101
106
  if constants.GPT.lower() in model_name.lower():
102
107
  return constants.GPT
103
- elif constants.BERT.lower() in model_name.lower():
108
+ if constants.BERT.lower() in model_name.lower():
104
109
  return constants.BERT
105
- elif constants.VIT.lower() in model_name.lower():
110
+ if constants.VIT.lower() in model_name.lower():
106
111
  return constants.VIT
107
- else:
108
- raise ValueError("Invalid model_name, unknown architecture.")
112
+ raise ValueError("Invalid model_name, unknown architecture.")
@@ -1,8 +1,9 @@
1
+ """Base embedding generator classes for NLP, CV, and tabular data."""
2
+
1
3
  import os
2
4
  from abc import ABC, abstractmethod
3
5
  from enum import Enum
4
6
  from functools import partial
5
- from typing import Dict, List, Union, cast
6
7
 
7
8
  import pandas as pd
8
9
 
@@ -31,9 +32,26 @@ transformer_logging.enable_progress_bar()
31
32
 
32
33
 
33
34
  class BaseEmbeddingGenerator(ABC):
35
+ """Abstract base class for all embedding generators."""
36
+
34
37
  def __init__(
35
- self, use_case: Enum, model_name: str, batch_size: int = 100, **kwargs
36
- ):
38
+ self,
39
+ use_case: Enum,
40
+ model_name: str,
41
+ batch_size: int = 100,
42
+ **kwargs: object,
43
+ ) -> None:
44
+ """Initialize the embedding generator with model and configuration.
45
+
46
+ Args:
47
+ use_case: Enum specifying the use case for embedding generation.
48
+ model_name: Name of the pre-trained model to use.
49
+ batch_size: Number of samples to process per batch.
50
+ **kwargs: Additional arguments for model initialization.
51
+
52
+ Raises:
53
+ HuggingFaceRepositoryNotFound: If the model name is not found on HuggingFace.
54
+ """
37
55
  self.__use_case = self._parse_use_case(use_case=use_case)
38
56
  self.__model_name = model_name
39
57
  self.__device = self.select_device()
@@ -45,43 +63,50 @@ class BaseEmbeddingGenerator(ABC):
45
63
  ).to(self.device)
46
64
  except OSError as e:
47
65
  raise err.HuggingFaceRepositoryNotFound(model_name) from e
48
- except Exception as e:
49
- raise e
66
+ except Exception:
67
+ raise
50
68
 
51
69
  @abstractmethod
52
- def generate_embeddings(self, **kwargs) -> pd.Series: ...
70
+ def generate_embeddings(self, **kwargs: object) -> pd.Series:
71
+ """Generate embeddings for the input data."""
72
+ ...
53
73
 
54
74
  def select_device(self) -> torch.device:
75
+ """Select the best available device (CUDA, MPS, or CPU) for model execution."""
55
76
  if torch.cuda.is_available():
56
77
  return torch.device("cuda")
57
- elif torch.backends.mps.is_available():
78
+ if torch.backends.mps.is_available():
58
79
  return torch.device("mps")
59
- else:
60
- logger.warning(
61
- "No available GPU has been detected. The use of GPU acceleration is "
62
- "strongly recommended. You can check for GPU availability by running "
63
- "`torch.cuda.is_available()` or `torch.backends.mps.is_available()`."
64
- )
65
- return torch.device("cpu")
80
+ logger.warning(
81
+ "No available GPU has been detected. The use of GPU acceleration is "
82
+ "strongly recommended. You can check for GPU availability by running "
83
+ "`torch.cuda.is_available()` or `torch.backends.mps.is_available()`."
84
+ )
85
+ return torch.device("cpu")
66
86
 
67
87
  @property
68
88
  def use_case(self) -> str:
89
+ """Return the use case for this embedding generator."""
69
90
  return self.__use_case
70
91
 
71
92
  @property
72
93
  def model_name(self) -> str:
94
+ """Return the name of the model being used."""
73
95
  return self.__model_name
74
96
 
75
97
  @property
76
- def model(self):
98
+ def model(self) -> object:
99
+ """Return the underlying model instance."""
77
100
  return self.__model
78
101
 
79
102
  @property
80
103
  def device(self) -> torch.device:
104
+ """Return the device (CPU/GPU) being used for computation."""
81
105
  return self.__device
82
106
 
83
107
  @property
84
108
  def batch_size(self) -> int:
109
+ """Return the batch size for processing."""
85
110
  return self.__batch_size
86
111
 
87
112
  @batch_size.setter
@@ -89,11 +114,10 @@ class BaseEmbeddingGenerator(ABC):
89
114
  err_message = "New batch size should be an integer greater than 0."
90
115
  if not isinstance(new_batch_size, int):
91
116
  raise TypeError(err_message)
92
- elif new_batch_size <= 0:
117
+ if new_batch_size <= 0:
93
118
  raise ValueError(err_message)
94
- else:
95
- self.__batch_size = new_batch_size
96
- logger.info(f"Batch size has been set to {new_batch_size}.")
119
+ self.__batch_size = new_batch_size
120
+ logger.info(f"Batch size has been set to {new_batch_size}.")
97
121
 
98
122
  @staticmethod
99
123
  def _parse_use_case(use_case: Enum) -> str:
@@ -102,8 +126,8 @@ class BaseEmbeddingGenerator(ABC):
102
126
  return f"{uc_area}.{uc_task}"
103
127
 
104
128
  def _get_embedding_vector(
105
- self, batch: Dict[str, torch.Tensor], method
106
- ) -> Dict[str, torch.Tensor]:
129
+ self, batch: dict[str, torch.Tensor], method: str
130
+ ) -> dict[str, torch.Tensor]:
107
131
  with torch.no_grad():
108
132
  outputs = self.model(**batch)
109
133
  # (batch_size, seq_length/or/num_tokens, hidden_size)
@@ -116,20 +140,23 @@ class BaseEmbeddingGenerator(ABC):
116
140
  return {"embedding_vector": embeddings.cpu().numpy().astype(float)}
117
141
 
118
142
  @staticmethod
119
- def check_invalid_index(field: Union[pd.Series, pd.DataFrame]) -> None:
143
+ def check_invalid_index(field: pd.Series | pd.DataFrame) -> None:
144
+ """Check if the field has a valid index and raise error if invalid."""
120
145
  if (field.index != field.reset_index(drop=True).index).any():
121
146
  if isinstance(field, pd.DataFrame):
122
147
  raise err.InvalidIndexError("DataFrame")
123
- else:
124
- raise err.InvalidIndexError(str(field.name))
148
+ raise err.InvalidIndexError(str(field.name))
125
149
 
126
150
  @abstractmethod
127
151
  def __repr__(self) -> str:
128
- pass
152
+ """Return a string representation of the embedding generator."""
129
153
 
130
154
 
131
155
  class NLPEmbeddingGenerator(BaseEmbeddingGenerator):
156
+ """Base class for NLP embedding generators with text tokenization support."""
157
+
132
158
  def __repr__(self) -> str:
159
+ """Return a string representation of the NLP embedding generator."""
133
160
  return (
134
161
  f"{self.__class__.__name__}(\n"
135
162
  f" use_case={self.use_case},\n"
@@ -146,8 +173,16 @@ class NLPEmbeddingGenerator(BaseEmbeddingGenerator):
146
173
  use_case: Enum,
147
174
  model_name: str,
148
175
  tokenizer_max_length: int = 512,
149
- **kwargs,
150
- ):
176
+ **kwargs: object,
177
+ ) -> None:
178
+ """Initialize the NLP embedding generator with tokenizer configuration.
179
+
180
+ Args:
181
+ use_case: Enum specifying the NLP use case.
182
+ model_name: Name of the pre-trained NLP model.
183
+ tokenizer_max_length: Maximum sequence length for the tokenizer.
184
+ **kwargs: Additional arguments for model initialization.
185
+ """
151
186
  super().__init__(use_case=use_case, model_name=model_name, **kwargs)
152
187
  self.__tokenizer_max_length = tokenizer_max_length
153
188
  # We don't check for the tokenizer's existence since it is coupled with the corresponding model
@@ -158,16 +193,19 @@ class NLPEmbeddingGenerator(BaseEmbeddingGenerator):
158
193
  )
159
194
 
160
195
  @property
161
- def tokenizer(self):
196
+ def tokenizer(self) -> object:
197
+ """Return the tokenizer instance for text processing."""
162
198
  return self.__tokenizer
163
199
 
164
200
  @property
165
201
  def tokenizer_max_length(self) -> int:
202
+ """Return the maximum sequence length for the tokenizer."""
166
203
  return self.__tokenizer_max_length
167
204
 
168
205
  def tokenize(
169
- self, batch: Dict[str, List[str]], text_feat_name: str
206
+ self, batch: dict[str, list[str]], text_feat_name: str
170
207
  ) -> BatchEncoding:
208
+ """Tokenize a batch of text inputs."""
171
209
  return self.tokenizer(
172
210
  batch[text_feat_name],
173
211
  padding=True,
@@ -178,7 +216,10 @@ class NLPEmbeddingGenerator(BaseEmbeddingGenerator):
178
216
 
179
217
 
180
218
  class CVEmbeddingGenerator(BaseEmbeddingGenerator):
219
+ """Base class for computer vision embedding generators with image preprocessing support."""
220
+
181
221
  def __repr__(self) -> str:
222
+ """Return a string representation of the computer vision embedding generator."""
182
223
  return (
183
224
  f"{self.__class__.__name__}(\n"
184
225
  f" use_case={self.use_case},\n"
@@ -189,7 +230,16 @@ class CVEmbeddingGenerator(BaseEmbeddingGenerator):
189
230
  f")"
190
231
  )
191
232
 
192
- def __init__(self, use_case: Enum, model_name: str, **kwargs):
233
+ def __init__(
234
+ self, use_case: Enum, model_name: str, **kwargs: object
235
+ ) -> None:
236
+ """Initialize the computer vision embedding generator with image processor.
237
+
238
+ Args:
239
+ use_case: Enum specifying the computer vision use case.
240
+ model_name: Name of the pre-trained vision model.
241
+ **kwargs: Additional arguments for model initialization.
242
+ """
193
243
  super().__init__(use_case=use_case, model_name=model_name, **kwargs)
194
244
  logger.info("Downloading image processor")
195
245
  # We don't check for the image processor's existence since it is coupled with the corresponding model
@@ -199,18 +249,21 @@ class CVEmbeddingGenerator(BaseEmbeddingGenerator):
199
249
  )
200
250
 
201
251
  @property
202
- def image_processor(self):
252
+ def image_processor(self) -> object:
253
+ """Return the image processor instance for image preprocessing."""
203
254
  return self.__image_processor
204
255
 
205
256
  @staticmethod
206
257
  def open_image(image_path: str) -> Image.Image:
258
+ """Open and convert an image to RGB format."""
207
259
  if not os.path.exists(image_path):
208
260
  raise ValueError(f"Cannot find image {image_path}")
209
261
  return Image.open(image_path).convert("RGB")
210
262
 
211
263
  def preprocess_image(
212
- self, batch: Dict[str, List[str]], local_image_feat_name: str
213
- ):
264
+ self, batch: dict[str, list[str]], local_image_feat_name: str
265
+ ) -> object:
266
+ """Preprocess a batch of images for model input."""
214
267
  return self.image_processor(
215
268
  [
216
269
  self.open_image(image_path)
@@ -220,8 +273,7 @@ class CVEmbeddingGenerator(BaseEmbeddingGenerator):
220
273
  ).to(self.device)
221
274
 
222
275
  def generate_embeddings(self, local_image_path_col: pd.Series) -> pd.Series:
223
- """
224
- Obtain embedding vectors from your image data using pre-trained image models.
276
+ """Obtain embedding vectors from your image data using pre-trained image models.
225
277
 
226
278
  :param local_image_path_col: a pandas Series containing the local path to the images to
227
279
  be used to generate the embedding vectors.
@@ -252,4 +304,5 @@ class CVEmbeddingGenerator(BaseEmbeddingGenerator):
252
304
  batched=True,
253
305
  batch_size=self.batch_size,
254
306
  )
255
- return cast(pd.DataFrame, ds.to_pandas())["embedding_vector"]
307
+ df: pd.DataFrame = ds.to_pandas()
308
+ return df["embedding_vector"]
@@ -1,3 +1,5 @@
1
+ """Embedding generation constants and pre-trained model definitions."""
2
+
1
3
  DEFAULT_NLP_SEQUENCE_CLASSIFICATION_MODEL = "distilbert-base-uncased"
2
4
  DEFAULT_NLP_SUMMARIZATION_MODEL = "distilbert-base-uncased"
3
5
  DEFAULT_TABULAR_MODEL = "distilbert-base-uncased"
@@ -1,3 +1,5 @@
1
+ """Computer vision embedding generators for image classification and object detection."""
2
+
1
3
  from arize.embeddings.base_generators import CVEmbeddingGenerator
2
4
  from arize.embeddings.constants import (
3
5
  DEFAULT_CV_IMAGE_CLASSIFICATION_MODEL,
@@ -7,9 +9,19 @@ from arize.embeddings.usecases import UseCases
7
9
 
8
10
 
9
11
  class EmbeddingGeneratorForCVImageClassification(CVEmbeddingGenerator):
12
+ """Embedding generator for computer vision image classification tasks."""
13
+
10
14
  def __init__(
11
- self, model_name: str = DEFAULT_CV_IMAGE_CLASSIFICATION_MODEL, **kwargs
12
- ):
15
+ self,
16
+ model_name: str = DEFAULT_CV_IMAGE_CLASSIFICATION_MODEL,
17
+ **kwargs: object,
18
+ ) -> None:
19
+ """Initialize the image classification embedding generator.
20
+
21
+ Args:
22
+ model_name: Name of the pre-trained vision model.
23
+ **kwargs: Additional arguments for model initialization.
24
+ """
13
25
  super().__init__(
14
26
  use_case=UseCases.CV.IMAGE_CLASSIFICATION,
15
27
  model_name=model_name,
@@ -18,9 +30,19 @@ class EmbeddingGeneratorForCVImageClassification(CVEmbeddingGenerator):
18
30
 
19
31
 
20
32
  class EmbeddingGeneratorForCVObjectDetection(CVEmbeddingGenerator):
33
+ """Embedding generator for computer vision object detection tasks."""
34
+
21
35
  def __init__(
22
- self, model_name: str = DEFAULT_CV_OBJECT_DETECTION_MODEL, **kwargs
23
- ):
36
+ self,
37
+ model_name: str = DEFAULT_CV_OBJECT_DETECTION_MODEL,
38
+ **kwargs: object,
39
+ ) -> None:
40
+ """Initialize the object detection embedding generator.
41
+
42
+ Args:
43
+ model_name: Name of the pre-trained vision model.
44
+ **kwargs: Additional arguments for model initialization.
45
+ """
24
46
  super().__init__(
25
47
  use_case=UseCases.CV.OBJECT_DETECTION,
26
48
  model_name=model_name,
@@ -1,37 +1,59 @@
1
+ """Embedding generation exception classes."""
2
+
3
+
1
4
  class InvalidIndexError(Exception):
5
+ """Raised when DataFrame or Series has an invalid index."""
6
+
2
7
  def __repr__(self) -> str:
8
+ """Return a string representation for debugging and logging."""
3
9
  return "Invalid_Index_Error"
4
10
 
5
11
  def __str__(self) -> str:
12
+ """Return a human-readable error message."""
6
13
  return self.error_message()
7
14
 
8
15
  def __init__(self, field_name: str) -> None:
16
+ """Initialize the exception with field name context.
17
+
18
+ Args:
19
+ field_name: Name of the DataFrame or Series field with invalid index.
20
+ """
9
21
  self.field_name = field_name
10
22
 
11
23
  def error_message(self) -> str:
24
+ """Return the error message for this exception."""
12
25
  if self.field_name == "DataFrame":
13
26
  return (
14
27
  f"The index of the {self.field_name} is invalid; "
15
28
  f"reset the index by using df.reset_index(drop=True, inplace=True)"
16
29
  )
17
- else:
18
- return (
19
- f"The index of the Series given by the column '{self.field_name}' is invalid; "
20
- f"reset the index by using df.reset_index(drop=True, inplace=True)"
21
- )
30
+ return (
31
+ f"The index of the Series given by the column '{self.field_name}' is invalid; "
32
+ f"reset the index by using df.reset_index(drop=True, inplace=True)"
33
+ )
22
34
 
23
35
 
24
36
  class HuggingFaceRepositoryNotFound(Exception):
37
+ """Raised when HuggingFace model repository is not found."""
38
+
25
39
  def __repr__(self) -> str:
40
+ """Return a string representation for debugging and logging."""
26
41
  return "HuggingFace_Repository_Not_Found_Error"
27
42
 
28
43
  def __str__(self) -> str:
44
+ """Return a human-readable error message."""
29
45
  return self.error_message()
30
46
 
31
47
  def __init__(self, model_name: str) -> None:
48
+ """Initialize the exception with model name context.
49
+
50
+ Args:
51
+ model_name: Name of the HuggingFace model that was not found.
52
+ """
32
53
  self.model_name = model_name
33
54
 
34
55
  def error_message(self) -> str:
56
+ """Return the error message for this exception."""
35
57
  return (
36
58
  f"The given model name '{self.model_name}' is not a valid model identifier listed on "
37
59
  "'https://huggingface.co/models'. "
@@ -1,6 +1,7 @@
1
+ """NLP embedding generators for text classification and summarization tasks."""
2
+
1
3
  import logging
2
4
  from functools import partial
3
- from typing import Optional, cast
4
5
 
5
6
  import pandas as pd
6
7
 
@@ -22,11 +23,19 @@ logger = logging.getLogger(__name__)
22
23
 
23
24
 
24
25
  class EmbeddingGeneratorForNLPSequenceClassification(NLPEmbeddingGenerator):
26
+ """Embedding generator for NLP sequence classification tasks."""
27
+
25
28
  def __init__(
26
29
  self,
27
30
  model_name: str = DEFAULT_NLP_SEQUENCE_CLASSIFICATION_MODEL,
28
- **kwargs,
29
- ):
31
+ **kwargs: object,
32
+ ) -> None:
33
+ """Initialize the sequence classification embedding generator.
34
+
35
+ Args:
36
+ model_name: Name of the pre-trained NLP model.
37
+ **kwargs: Additional arguments for model initialization.
38
+ """
30
39
  super().__init__(
31
40
  use_case=UseCases.NLP.SEQUENCE_CLASSIFICATION,
32
41
  model_name=model_name,
@@ -36,10 +45,9 @@ class EmbeddingGeneratorForNLPSequenceClassification(NLPEmbeddingGenerator):
36
45
  def generate_embeddings(
37
46
  self,
38
47
  text_col: pd.Series,
39
- class_label_col: Optional[pd.Series] = None,
48
+ class_label_col: pd.Series | None = None,
40
49
  ) -> pd.Series:
41
- """
42
- Obtain embedding vectors from your text data using pre-trained large language models.
50
+ """Obtain embedding vectors from your text data using pre-trained large language models.
43
51
 
44
52
  :param text_col: a pandas Series containing the different pieces of text.
45
53
  :param class_label_col: if this column is passed, the sentence "The classification label
@@ -72,13 +80,24 @@ class EmbeddingGeneratorForNLPSequenceClassification(NLPEmbeddingGenerator):
72
80
  batched=True,
73
81
  batch_size=self.batch_size,
74
82
  )
75
- return cast(pd.DataFrame, ds.to_pandas())["embedding_vector"]
83
+ df: pd.DataFrame = ds.to_pandas()
84
+ return df["embedding_vector"]
76
85
 
77
86
 
78
87
  class EmbeddingGeneratorForNLPSummarization(NLPEmbeddingGenerator):
88
+ """Embedding generator for NLP text summarization tasks."""
89
+
79
90
  def __init__(
80
- self, model_name: str = DEFAULT_NLP_SUMMARIZATION_MODEL, **kwargs
81
- ):
91
+ self,
92
+ model_name: str = DEFAULT_NLP_SUMMARIZATION_MODEL,
93
+ **kwargs: object,
94
+ ) -> None:
95
+ """Initialize the text summarization embedding generator.
96
+
97
+ Args:
98
+ model_name: Name of the pre-trained NLP model.
99
+ **kwargs: Additional arguments for model initialization.
100
+ """
82
101
  super().__init__(
83
102
  use_case=UseCases.NLP.SUMMARIZATION,
84
103
  model_name=model_name,
@@ -89,8 +108,7 @@ class EmbeddingGeneratorForNLPSummarization(NLPEmbeddingGenerator):
89
108
  self,
90
109
  text_col: pd.Series,
91
110
  ) -> pd.Series:
92
- """
93
- Obtain embedding vectors from your text data using pre-trained large language models.
111
+ """Obtain embedding vectors from your text data using pre-trained large language models.
94
112
 
95
113
  :param text_col: a pandas Series containing the different pieces of text.
96
114
  :return: a pandas Series containing the embedding vectors.
@@ -108,4 +126,5 @@ class EmbeddingGeneratorForNLPSummarization(NLPEmbeddingGenerator):
108
126
  batched=True,
109
127
  batch_size=self.batch_size,
110
128
  )
111
- return cast(pd.DataFrame, ds.to_pandas())["embedding_vector"]
129
+ df: pd.DataFrame = ds.to_pandas()
130
+ return df["embedding_vector"]