arize 8.0.0a22__py3-none-any.whl → 8.0.0b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (171)
  1. arize/__init__.py +28 -19
  2. arize/_exporter/client.py +56 -37
  3. arize/_exporter/parsers/tracing_data_parser.py +41 -30
  4. arize/_exporter/validation.py +3 -3
  5. arize/_flight/client.py +207 -76
  6. arize/_generated/api_client/__init__.py +30 -6
  7. arize/_generated/api_client/api/__init__.py +1 -0
  8. arize/_generated/api_client/api/datasets_api.py +864 -190
  9. arize/_generated/api_client/api/experiments_api.py +167 -131
  10. arize/_generated/api_client/api/projects_api.py +1197 -0
  11. arize/_generated/api_client/api_client.py +2 -2
  12. arize/_generated/api_client/configuration.py +42 -34
  13. arize/_generated/api_client/exceptions.py +2 -2
  14. arize/_generated/api_client/models/__init__.py +15 -4
  15. arize/_generated/api_client/models/dataset.py +10 -10
  16. arize/_generated/api_client/models/dataset_example.py +111 -0
  17. arize/_generated/api_client/models/dataset_example_update.py +100 -0
  18. arize/_generated/api_client/models/dataset_version.py +13 -13
  19. arize/_generated/api_client/models/datasets_create_request.py +16 -8
  20. arize/_generated/api_client/models/datasets_examples_insert_request.py +100 -0
  21. arize/_generated/api_client/models/datasets_examples_list200_response.py +106 -0
  22. arize/_generated/api_client/models/datasets_examples_update_request.py +102 -0
  23. arize/_generated/api_client/models/datasets_list200_response.py +10 -4
  24. arize/_generated/api_client/models/experiment.py +14 -16
  25. arize/_generated/api_client/models/experiment_run.py +108 -0
  26. arize/_generated/api_client/models/experiment_run_create.py +102 -0
  27. arize/_generated/api_client/models/experiments_create_request.py +16 -10
  28. arize/_generated/api_client/models/experiments_list200_response.py +10 -4
  29. arize/_generated/api_client/models/experiments_runs_list200_response.py +19 -5
  30. arize/_generated/api_client/models/{error.py → pagination_metadata.py} +13 -11
  31. arize/_generated/api_client/models/primitive_value.py +172 -0
  32. arize/_generated/api_client/models/problem.py +100 -0
  33. arize/_generated/api_client/models/project.py +99 -0
  34. arize/_generated/api_client/models/{datasets_list_examples200_response.py → projects_create_request.py} +13 -11
  35. arize/_generated/api_client/models/projects_list200_response.py +106 -0
  36. arize/_generated/api_client/rest.py +2 -2
  37. arize/_generated/api_client/test/test_dataset.py +4 -2
  38. arize/_generated/api_client/test/test_dataset_example.py +56 -0
  39. arize/_generated/api_client/test/test_dataset_example_update.py +52 -0
  40. arize/_generated/api_client/test/test_dataset_version.py +7 -2
  41. arize/_generated/api_client/test/test_datasets_api.py +27 -13
  42. arize/_generated/api_client/test/test_datasets_create_request.py +8 -4
  43. arize/_generated/api_client/test/{test_datasets_list_examples200_response.py → test_datasets_examples_insert_request.py} +19 -15
  44. arize/_generated/api_client/test/test_datasets_examples_list200_response.py +66 -0
  45. arize/_generated/api_client/test/test_datasets_examples_update_request.py +61 -0
  46. arize/_generated/api_client/test/test_datasets_list200_response.py +9 -3
  47. arize/_generated/api_client/test/test_experiment.py +2 -4
  48. arize/_generated/api_client/test/test_experiment_run.py +56 -0
  49. arize/_generated/api_client/test/test_experiment_run_create.py +54 -0
  50. arize/_generated/api_client/test/test_experiments_api.py +6 -6
  51. arize/_generated/api_client/test/test_experiments_create_request.py +9 -6
  52. arize/_generated/api_client/test/test_experiments_list200_response.py +9 -5
  53. arize/_generated/api_client/test/test_experiments_runs_list200_response.py +15 -5
  54. arize/_generated/api_client/test/test_pagination_metadata.py +53 -0
  55. arize/_generated/api_client/test/{test_error.py → test_primitive_value.py} +13 -14
  56. arize/_generated/api_client/test/test_problem.py +57 -0
  57. arize/_generated/api_client/test/test_project.py +58 -0
  58. arize/_generated/api_client/test/test_projects_api.py +59 -0
  59. arize/_generated/api_client/test/test_projects_create_request.py +54 -0
  60. arize/_generated/api_client/test/test_projects_list200_response.py +70 -0
  61. arize/_generated/api_client_README.md +43 -29
  62. arize/_generated/protocol/flight/flight_pb2.py +400 -0
  63. arize/_lazy.py +27 -19
  64. arize/client.py +181 -58
  65. arize/config.py +324 -116
  66. arize/constants/__init__.py +1 -0
  67. arize/constants/config.py +11 -4
  68. arize/constants/ml.py +6 -4
  69. arize/constants/openinference.py +2 -0
  70. arize/constants/pyarrow.py +2 -0
  71. arize/constants/spans.py +3 -1
  72. arize/datasets/__init__.py +1 -0
  73. arize/datasets/client.py +304 -84
  74. arize/datasets/errors.py +32 -2
  75. arize/datasets/validation.py +18 -8
  76. arize/embeddings/__init__.py +2 -0
  77. arize/embeddings/auto_generator.py +23 -19
  78. arize/embeddings/base_generators.py +89 -36
  79. arize/embeddings/constants.py +2 -0
  80. arize/embeddings/cv_generators.py +26 -4
  81. arize/embeddings/errors.py +27 -5
  82. arize/embeddings/nlp_generators.py +43 -18
  83. arize/embeddings/tabular_generators.py +46 -31
  84. arize/embeddings/usecases.py +12 -2
  85. arize/exceptions/__init__.py +1 -0
  86. arize/exceptions/auth.py +11 -1
  87. arize/exceptions/base.py +29 -4
  88. arize/exceptions/models.py +21 -2
  89. arize/exceptions/parameters.py +31 -0
  90. arize/exceptions/spaces.py +12 -1
  91. arize/exceptions/types.py +86 -7
  92. arize/exceptions/values.py +220 -20
  93. arize/experiments/__init__.py +13 -0
  94. arize/experiments/client.py +394 -285
  95. arize/experiments/evaluators/__init__.py +1 -0
  96. arize/experiments/evaluators/base.py +74 -41
  97. arize/experiments/evaluators/exceptions.py +6 -3
  98. arize/experiments/evaluators/executors.py +121 -73
  99. arize/experiments/evaluators/rate_limiters.py +106 -57
  100. arize/experiments/evaluators/types.py +34 -7
  101. arize/experiments/evaluators/utils.py +65 -27
  102. arize/experiments/functions.py +103 -101
  103. arize/experiments/tracing.py +52 -44
  104. arize/experiments/types.py +56 -31
  105. arize/logging.py +54 -22
  106. arize/ml/__init__.py +1 -0
  107. arize/ml/batch_validation/__init__.py +1 -0
  108. arize/{models → ml}/batch_validation/errors.py +545 -67
  109. arize/{models → ml}/batch_validation/validator.py +344 -303
  110. arize/ml/bounded_executor.py +47 -0
  111. arize/{models → ml}/casting.py +118 -108
  112. arize/{models → ml}/client.py +339 -118
  113. arize/{models → ml}/proto.py +97 -42
  114. arize/{models → ml}/stream_validation.py +43 -15
  115. arize/ml/surrogate_explainer/__init__.py +1 -0
  116. arize/{models → ml}/surrogate_explainer/mimic.py +25 -10
  117. arize/{types.py → ml/types.py} +355 -354
  118. arize/pre_releases.py +44 -0
  119. arize/projects/__init__.py +1 -0
  120. arize/projects/client.py +134 -0
  121. arize/regions.py +40 -0
  122. arize/spans/__init__.py +1 -0
  123. arize/spans/client.py +204 -175
  124. arize/spans/columns.py +13 -0
  125. arize/spans/conversion.py +60 -37
  126. arize/spans/validation/__init__.py +1 -0
  127. arize/spans/validation/annotations/__init__.py +1 -0
  128. arize/spans/validation/annotations/annotations_validation.py +6 -4
  129. arize/spans/validation/annotations/dataframe_form_validation.py +13 -11
  130. arize/spans/validation/annotations/value_validation.py +35 -11
  131. arize/spans/validation/common/__init__.py +1 -0
  132. arize/spans/validation/common/argument_validation.py +33 -8
  133. arize/spans/validation/common/dataframe_form_validation.py +35 -9
  134. arize/spans/validation/common/errors.py +211 -11
  135. arize/spans/validation/common/value_validation.py +81 -14
  136. arize/spans/validation/evals/__init__.py +1 -0
  137. arize/spans/validation/evals/dataframe_form_validation.py +28 -8
  138. arize/spans/validation/evals/evals_validation.py +34 -4
  139. arize/spans/validation/evals/value_validation.py +26 -3
  140. arize/spans/validation/metadata/__init__.py +1 -1
  141. arize/spans/validation/metadata/argument_validation.py +14 -5
  142. arize/spans/validation/metadata/dataframe_form_validation.py +26 -10
  143. arize/spans/validation/metadata/value_validation.py +24 -10
  144. arize/spans/validation/spans/__init__.py +1 -0
  145. arize/spans/validation/spans/dataframe_form_validation.py +35 -14
  146. arize/spans/validation/spans/spans_validation.py +35 -4
  147. arize/spans/validation/spans/value_validation.py +78 -8
  148. arize/utils/__init__.py +1 -0
  149. arize/utils/arrow.py +31 -15
  150. arize/utils/cache.py +34 -6
  151. arize/utils/dataframe.py +20 -3
  152. arize/utils/online_tasks/__init__.py +2 -0
  153. arize/utils/online_tasks/dataframe_preprocessor.py +58 -47
  154. arize/utils/openinference_conversion.py +44 -5
  155. arize/utils/proto.py +10 -0
  156. arize/utils/size.py +5 -3
  157. arize/utils/types.py +105 -0
  158. arize/version.py +3 -1
  159. {arize-8.0.0a22.dist-info → arize-8.0.0b0.dist-info}/METADATA +13 -6
  160. arize-8.0.0b0.dist-info/RECORD +175 -0
  161. {arize-8.0.0a22.dist-info → arize-8.0.0b0.dist-info}/WHEEL +1 -1
  162. arize-8.0.0b0.dist-info/licenses/LICENSE +176 -0
  163. arize-8.0.0b0.dist-info/licenses/NOTICE +13 -0
  164. arize/_generated/protocol/flight/export_pb2.py +0 -61
  165. arize/_generated/protocol/flight/ingest_pb2.py +0 -365
  166. arize/models/__init__.py +0 -0
  167. arize/models/batch_validation/__init__.py +0 -0
  168. arize/models/bounded_executor.py +0 -34
  169. arize/models/surrogate_explainer/__init__.py +0 -0
  170. arize-8.0.0a22.dist-info/RECORD +0 -146
  171. arize-8.0.0a22.dist-info/licenses/LICENSE.md +0 -12
@@ -1,4 +1,4 @@
1
- from typing import Any
1
+ """Automatic embedding generation factory for various ML use cases."""
2
2
 
3
3
  import pandas as pd
4
4
 
@@ -30,7 +30,14 @@ UseCaseLike = str | UseCases.NLP | UseCases.CV | UseCases.STRUCTURED
30
30
 
31
31
 
32
32
  class EmbeddingGenerator:
33
- def __init__(self, **kwargs: str):
33
+ """Factory class for creating embedding generators based on use case."""
34
+
35
+ def __init__(self, **kwargs: str) -> None:
36
+ """Raise error directing users to use from_use_case factory method.
37
+
38
+ Raises:
39
+ OSError: Always raised to prevent direct instantiation.
40
+ """
34
41
  raise OSError(
35
42
  f"{self.__class__.__name__} is designed to be instantiated using the "
36
43
  f"`{self.__class__.__name__}.from_use_case(use_case, **kwargs)` method."
@@ -38,23 +45,24 @@ class EmbeddingGenerator:
38
45
 
39
46
  @staticmethod
40
47
  def from_use_case(
41
- use_case: UseCaseLike, **kwargs: Any
48
+ use_case: UseCaseLike, **kwargs: object
42
49
  ) -> BaseEmbeddingGenerator:
50
+ """Create an embedding generator for the specified use case."""
43
51
  if use_case == UseCases.NLP.SEQUENCE_CLASSIFICATION:
44
52
  return EmbeddingGeneratorForNLPSequenceClassification(**kwargs)
45
- elif use_case == UseCases.NLP.SUMMARIZATION:
53
+ if use_case == UseCases.NLP.SUMMARIZATION:
46
54
  return EmbeddingGeneratorForNLPSummarization(**kwargs)
47
- elif use_case == UseCases.CV.IMAGE_CLASSIFICATION:
55
+ if use_case == UseCases.CV.IMAGE_CLASSIFICATION:
48
56
  return EmbeddingGeneratorForCVImageClassification(**kwargs)
49
- elif use_case == UseCases.CV.OBJECT_DETECTION:
57
+ if use_case == UseCases.CV.OBJECT_DETECTION:
50
58
  return EmbeddingGeneratorForCVObjectDetection(**kwargs)
51
- elif use_case == UseCases.STRUCTURED.TABULAR_EMBEDDINGS:
59
+ if use_case == UseCases.STRUCTURED.TABULAR_EMBEDDINGS:
52
60
  return EmbeddingGeneratorForTabularFeatures(**kwargs)
53
- else:
54
- raise ValueError(f"Invalid use case {use_case}")
61
+ raise ValueError(f"Invalid use case {use_case}")
55
62
 
56
63
  @classmethod
57
64
  def list_default_models(cls) -> pd.DataFrame:
65
+ """Return a DataFrame of default models for each use case."""
58
66
  df = pd.DataFrame(
59
67
  {
60
68
  "Area": ["NLP", "NLP", "CV", "CV", "STRUCTURED"],
@@ -74,13 +82,12 @@ class EmbeddingGenerator:
74
82
  ],
75
83
  }
76
84
  )
77
- df.sort_values(
78
- by=[col for col in df.columns], ascending=True, inplace=True
79
- )
85
+ df.sort_values(by=list(df.columns), ascending=True, inplace=True)
80
86
  return df.reset_index(drop=True)
81
87
 
82
88
  @classmethod
83
89
  def list_pretrained_models(cls) -> pd.DataFrame:
90
+ """Return a DataFrame of all available pretrained models."""
84
91
  data = {
85
92
  "Task": ["NLP" for _ in NLP_PRETRAINED_MODELS]
86
93
  + ["CV" for _ in CV_PRETRAINED_MODELS],
@@ -91,18 +98,15 @@ class EmbeddingGenerator:
91
98
  "Model Name": NLP_PRETRAINED_MODELS + CV_PRETRAINED_MODELS,
92
99
  }
93
100
  df = pd.DataFrame(data)
94
- df.sort_values(
95
- by=[col for col in df.columns], ascending=True, inplace=True
96
- )
101
+ df.sort_values(by=list(df.columns), ascending=True, inplace=True)
97
102
  return df.reset_index(drop=True)
98
103
 
99
104
  @staticmethod
100
105
  def __parse_model_arch(model_name: str) -> str:
101
106
  if constants.GPT.lower() in model_name.lower():
102
107
  return constants.GPT
103
- elif constants.BERT.lower() in model_name.lower():
108
+ if constants.BERT.lower() in model_name.lower():
104
109
  return constants.BERT
105
- elif constants.VIT.lower() in model_name.lower():
110
+ if constants.VIT.lower() in model_name.lower():
106
111
  return constants.VIT
107
- else:
108
- raise ValueError("Invalid model_name, unknown architecture.")
112
+ raise ValueError("Invalid model_name, unknown architecture.")
@@ -1,8 +1,9 @@
1
+ """Base embedding generator classes for NLP, CV, and tabular data."""
2
+
1
3
  import os
2
4
  from abc import ABC, abstractmethod
3
5
  from enum import Enum
4
6
  from functools import partial
5
- from typing import Dict, List, Union, cast
6
7
 
7
8
  import pandas as pd
8
9
 
@@ -31,9 +32,26 @@ transformer_logging.enable_progress_bar()
31
32
 
32
33
 
33
34
  class BaseEmbeddingGenerator(ABC):
35
+ """Abstract base class for all embedding generators."""
36
+
34
37
  def __init__(
35
- self, use_case: Enum, model_name: str, batch_size: int = 100, **kwargs
36
- ):
38
+ self,
39
+ use_case: Enum,
40
+ model_name: str,
41
+ batch_size: int = 100,
42
+ **kwargs: object,
43
+ ) -> None:
44
+ """Initialize the embedding generator with model and configuration.
45
+
46
+ Args:
47
+ use_case: Enum specifying the use case for embedding generation.
48
+ model_name: Name of the pre-trained model to use.
49
+ batch_size: Number of samples to process per batch.
50
+ **kwargs: Additional arguments for model initialization.
51
+
52
+ Raises:
53
+ HuggingFaceRepositoryNotFound: If the model name is not found on HuggingFace.
54
+ """
37
55
  self.__use_case = self._parse_use_case(use_case=use_case)
38
56
  self.__model_name = model_name
39
57
  self.__device = self.select_device()
@@ -45,43 +63,50 @@ class BaseEmbeddingGenerator(ABC):
45
63
  ).to(self.device)
46
64
  except OSError as e:
47
65
  raise err.HuggingFaceRepositoryNotFound(model_name) from e
48
- except Exception as e:
49
- raise e
66
+ except Exception:
67
+ raise
50
68
 
51
69
  @abstractmethod
52
- def generate_embeddings(self, **kwargs) -> pd.Series: ...
70
+ def generate_embeddings(self, **kwargs: object) -> pd.Series:
71
+ """Generate embeddings for the input data."""
72
+ ...
53
73
 
54
74
  def select_device(self) -> torch.device:
75
+ """Select the best available device (CUDA, MPS, or CPU) for model execution."""
55
76
  if torch.cuda.is_available():
56
77
  return torch.device("cuda")
57
- elif torch.backends.mps.is_available():
78
+ if torch.backends.mps.is_available():
58
79
  return torch.device("mps")
59
- else:
60
- logger.warning(
61
- "No available GPU has been detected. The use of GPU acceleration is "
62
- "strongly recommended. You can check for GPU availability by running "
63
- "`torch.cuda.is_available()` or `torch.backends.mps.is_available()`."
64
- )
65
- return torch.device("cpu")
80
+ logger.warning(
81
+ "No available GPU has been detected. The use of GPU acceleration is "
82
+ "strongly recommended. You can check for GPU availability by running "
83
+ "`torch.cuda.is_available()` or `torch.backends.mps.is_available()`."
84
+ )
85
+ return torch.device("cpu")
66
86
 
67
87
  @property
68
88
  def use_case(self) -> str:
89
+ """Return the use case for this embedding generator."""
69
90
  return self.__use_case
70
91
 
71
92
  @property
72
93
  def model_name(self) -> str:
94
+ """Return the name of the model being used."""
73
95
  return self.__model_name
74
96
 
75
97
  @property
76
- def model(self):
98
+ def model(self) -> object:
99
+ """Return the underlying model instance."""
77
100
  return self.__model
78
101
 
79
102
  @property
80
103
  def device(self) -> torch.device:
104
+ """Return the device (CPU/GPU) being used for computation."""
81
105
  return self.__device
82
106
 
83
107
  @property
84
108
  def batch_size(self) -> int:
109
+ """Return the batch size for processing."""
85
110
  return self.__batch_size
86
111
 
87
112
  @batch_size.setter
@@ -89,11 +114,10 @@ class BaseEmbeddingGenerator(ABC):
89
114
  err_message = "New batch size should be an integer greater than 0."
90
115
  if not isinstance(new_batch_size, int):
91
116
  raise TypeError(err_message)
92
- elif new_batch_size <= 0:
117
+ if new_batch_size <= 0:
93
118
  raise ValueError(err_message)
94
- else:
95
- self.__batch_size = new_batch_size
96
- logger.info(f"Batch size has been set to {new_batch_size}.")
119
+ self.__batch_size = new_batch_size
120
+ logger.info(f"Batch size has been set to {new_batch_size}.")
97
121
 
98
122
  @staticmethod
99
123
  def _parse_use_case(use_case: Enum) -> str:
@@ -102,8 +126,8 @@ class BaseEmbeddingGenerator(ABC):
102
126
  return f"{uc_area}.{uc_task}"
103
127
 
104
128
  def _get_embedding_vector(
105
- self, batch: Dict[str, torch.Tensor], method
106
- ) -> Dict[str, torch.Tensor]:
129
+ self, batch: dict[str, torch.Tensor], method: str
130
+ ) -> dict[str, torch.Tensor]:
107
131
  with torch.no_grad():
108
132
  outputs = self.model(**batch)
109
133
  # (batch_size, seq_length/or/num_tokens, hidden_size)
@@ -116,20 +140,23 @@ class BaseEmbeddingGenerator(ABC):
116
140
  return {"embedding_vector": embeddings.cpu().numpy().astype(float)}
117
141
 
118
142
  @staticmethod
119
- def check_invalid_index(field: Union[pd.Series, pd.DataFrame]) -> None:
143
+ def check_invalid_index(field: pd.Series | pd.DataFrame) -> None:
144
+ """Check if the field has a valid index and raise error if invalid."""
120
145
  if (field.index != field.reset_index(drop=True).index).any():
121
146
  if isinstance(field, pd.DataFrame):
122
147
  raise err.InvalidIndexError("DataFrame")
123
- else:
124
- raise err.InvalidIndexError(str(field.name))
148
+ raise err.InvalidIndexError(str(field.name))
125
149
 
126
150
  @abstractmethod
127
151
  def __repr__(self) -> str:
128
- pass
152
+ """Return a string representation of the embedding generator."""
129
153
 
130
154
 
131
155
  class NLPEmbeddingGenerator(BaseEmbeddingGenerator):
156
+ """Base class for NLP embedding generators with text tokenization support."""
157
+
132
158
  def __repr__(self) -> str:
159
+ """Return a string representation of the NLP embedding generator."""
133
160
  return (
134
161
  f"{self.__class__.__name__}(\n"
135
162
  f" use_case={self.use_case},\n"
@@ -146,8 +173,16 @@ class NLPEmbeddingGenerator(BaseEmbeddingGenerator):
146
173
  use_case: Enum,
147
174
  model_name: str,
148
175
  tokenizer_max_length: int = 512,
149
- **kwargs,
150
- ):
176
+ **kwargs: object,
177
+ ) -> None:
178
+ """Initialize the NLP embedding generator with tokenizer configuration.
179
+
180
+ Args:
181
+ use_case: Enum specifying the NLP use case.
182
+ model_name: Name of the pre-trained NLP model.
183
+ tokenizer_max_length: Maximum sequence length for the tokenizer.
184
+ **kwargs: Additional arguments for model initialization.
185
+ """
151
186
  super().__init__(use_case=use_case, model_name=model_name, **kwargs)
152
187
  self.__tokenizer_max_length = tokenizer_max_length
153
188
  # We don't check for the tokenizer's existence since it is coupled with the corresponding model
@@ -158,16 +193,19 @@ class NLPEmbeddingGenerator(BaseEmbeddingGenerator):
158
193
  )
159
194
 
160
195
  @property
161
- def tokenizer(self):
196
+ def tokenizer(self) -> object:
197
+ """Return the tokenizer instance for text processing."""
162
198
  return self.__tokenizer
163
199
 
164
200
  @property
165
201
  def tokenizer_max_length(self) -> int:
202
+ """Return the maximum sequence length for the tokenizer."""
166
203
  return self.__tokenizer_max_length
167
204
 
168
205
  def tokenize(
169
- self, batch: Dict[str, List[str]], text_feat_name: str
206
+ self, batch: dict[str, list[str]], text_feat_name: str
170
207
  ) -> BatchEncoding:
208
+ """Tokenize a batch of text inputs."""
171
209
  return self.tokenizer(
172
210
  batch[text_feat_name],
173
211
  padding=True,
@@ -178,7 +216,10 @@ class NLPEmbeddingGenerator(BaseEmbeddingGenerator):
178
216
 
179
217
 
180
218
  class CVEmbeddingGenerator(BaseEmbeddingGenerator):
219
+ """Base class for computer vision embedding generators with image preprocessing support."""
220
+
181
221
  def __repr__(self) -> str:
222
+ """Return a string representation of the computer vision embedding generator."""
182
223
  return (
183
224
  f"{self.__class__.__name__}(\n"
184
225
  f" use_case={self.use_case},\n"
@@ -189,7 +230,16 @@ class CVEmbeddingGenerator(BaseEmbeddingGenerator):
189
230
  f")"
190
231
  )
191
232
 
192
- def __init__(self, use_case: Enum, model_name: str, **kwargs):
233
+ def __init__(
234
+ self, use_case: Enum, model_name: str, **kwargs: object
235
+ ) -> None:
236
+ """Initialize the computer vision embedding generator with image processor.
237
+
238
+ Args:
239
+ use_case: Enum specifying the computer vision use case.
240
+ model_name: Name of the pre-trained vision model.
241
+ **kwargs: Additional arguments for model initialization.
242
+ """
193
243
  super().__init__(use_case=use_case, model_name=model_name, **kwargs)
194
244
  logger.info("Downloading image processor")
195
245
  # We don't check for the image processor's existence since it is coupled with the corresponding model
@@ -199,18 +249,21 @@ class CVEmbeddingGenerator(BaseEmbeddingGenerator):
199
249
  )
200
250
 
201
251
  @property
202
- def image_processor(self):
252
+ def image_processor(self) -> object:
253
+ """Return the image processor instance for image preprocessing."""
203
254
  return self.__image_processor
204
255
 
205
256
  @staticmethod
206
257
  def open_image(image_path: str) -> Image.Image:
258
+ """Open and convert an image to RGB format."""
207
259
  if not os.path.exists(image_path):
208
260
  raise ValueError(f"Cannot find image {image_path}")
209
261
  return Image.open(image_path).convert("RGB")
210
262
 
211
263
  def preprocess_image(
212
- self, batch: Dict[str, List[str]], local_image_feat_name: str
213
- ):
264
+ self, batch: dict[str, list[str]], local_image_feat_name: str
265
+ ) -> object:
266
+ """Preprocess a batch of images for model input."""
214
267
  return self.image_processor(
215
268
  [
216
269
  self.open_image(image_path)
@@ -220,8 +273,7 @@ class CVEmbeddingGenerator(BaseEmbeddingGenerator):
220
273
  ).to(self.device)
221
274
 
222
275
  def generate_embeddings(self, local_image_path_col: pd.Series) -> pd.Series:
223
- """
224
- Obtain embedding vectors from your image data using pre-trained image models.
276
+ """Obtain embedding vectors from your image data using pre-trained image models.
225
277
 
226
278
  :param local_image_path_col: a pandas Series containing the local path to the images to
227
279
  be used to generate the embedding vectors.
@@ -252,4 +304,5 @@ class CVEmbeddingGenerator(BaseEmbeddingGenerator):
252
304
  batched=True,
253
305
  batch_size=self.batch_size,
254
306
  )
255
- return cast(pd.DataFrame, ds.to_pandas())["embedding_vector"]
307
+ df: pd.DataFrame = ds.to_pandas()
308
+ return df["embedding_vector"]
@@ -1,3 +1,5 @@
1
+ """Embedding generation constants and pre-trained model definitions."""
2
+
1
3
  DEFAULT_NLP_SEQUENCE_CLASSIFICATION_MODEL = "distilbert-base-uncased"
2
4
  DEFAULT_NLP_SUMMARIZATION_MODEL = "distilbert-base-uncased"
3
5
  DEFAULT_TABULAR_MODEL = "distilbert-base-uncased"
@@ -1,3 +1,5 @@
1
+ """Computer vision embedding generators for image classification and object detection."""
2
+
1
3
  from arize.embeddings.base_generators import CVEmbeddingGenerator
2
4
  from arize.embeddings.constants import (
3
5
  DEFAULT_CV_IMAGE_CLASSIFICATION_MODEL,
@@ -7,9 +9,19 @@ from arize.embeddings.usecases import UseCases
7
9
 
8
10
 
9
11
  class EmbeddingGeneratorForCVImageClassification(CVEmbeddingGenerator):
12
+ """Embedding generator for computer vision image classification tasks."""
13
+
10
14
  def __init__(
11
- self, model_name: str = DEFAULT_CV_IMAGE_CLASSIFICATION_MODEL, **kwargs
12
- ):
15
+ self,
16
+ model_name: str = DEFAULT_CV_IMAGE_CLASSIFICATION_MODEL,
17
+ **kwargs: object,
18
+ ) -> None:
19
+ """Initialize the image classification embedding generator.
20
+
21
+ Args:
22
+ model_name: Name of the pre-trained vision model.
23
+ **kwargs: Additional arguments for model initialization.
24
+ """
13
25
  super().__init__(
14
26
  use_case=UseCases.CV.IMAGE_CLASSIFICATION,
15
27
  model_name=model_name,
@@ -18,9 +30,19 @@ class EmbeddingGeneratorForCVImageClassification(CVEmbeddingGenerator):
18
30
 
19
31
 
20
32
  class EmbeddingGeneratorForCVObjectDetection(CVEmbeddingGenerator):
33
+ """Embedding generator for computer vision object detection tasks."""
34
+
21
35
  def __init__(
22
- self, model_name: str = DEFAULT_CV_OBJECT_DETECTION_MODEL, **kwargs
23
- ):
36
+ self,
37
+ model_name: str = DEFAULT_CV_OBJECT_DETECTION_MODEL,
38
+ **kwargs: object,
39
+ ) -> None:
40
+ """Initialize the object detection embedding generator.
41
+
42
+ Args:
43
+ model_name: Name of the pre-trained vision model.
44
+ **kwargs: Additional arguments for model initialization.
45
+ """
24
46
  super().__init__(
25
47
  use_case=UseCases.CV.OBJECT_DETECTION,
26
48
  model_name=model_name,
@@ -1,37 +1,59 @@
1
+ """Embedding generation exception classes."""
2
+
3
+
1
4
  class InvalidIndexError(Exception):
5
+ """Raised when DataFrame or Series has an invalid index."""
6
+
2
7
  def __repr__(self) -> str:
8
+ """Return a string representation for debugging and logging."""
3
9
  return "Invalid_Index_Error"
4
10
 
5
11
  def __str__(self) -> str:
12
+ """Return a human-readable error message."""
6
13
  return self.error_message()
7
14
 
8
15
  def __init__(self, field_name: str) -> None:
16
+ """Initialize the exception with field name context.
17
+
18
+ Args:
19
+ field_name: Name of the DataFrame or Series field with invalid index.
20
+ """
9
21
  self.field_name = field_name
10
22
 
11
23
  def error_message(self) -> str:
24
+ """Return the error message for this exception."""
12
25
  if self.field_name == "DataFrame":
13
26
  return (
14
27
  f"The index of the {self.field_name} is invalid; "
15
28
  f"reset the index by using df.reset_index(drop=True, inplace=True)"
16
29
  )
17
- else:
18
- return (
19
- f"The index of the Series given by the column '{self.field_name}' is invalid; "
20
- f"reset the index by using df.reset_index(drop=True, inplace=True)"
21
- )
30
+ return (
31
+ f"The index of the Series given by the column '{self.field_name}' is invalid; "
32
+ f"reset the index by using df.reset_index(drop=True, inplace=True)"
33
+ )
22
34
 
23
35
 
24
36
  class HuggingFaceRepositoryNotFound(Exception):
37
+ """Raised when HuggingFace model repository is not found."""
38
+
25
39
  def __repr__(self) -> str:
40
+ """Return a string representation for debugging and logging."""
26
41
  return "HuggingFace_Repository_Not_Found_Error"
27
42
 
28
43
  def __str__(self) -> str:
44
+ """Return a human-readable error message."""
29
45
  return self.error_message()
30
46
 
31
47
  def __init__(self, model_name: str) -> None:
48
+ """Initialize the exception with model name context.
49
+
50
+ Args:
51
+ model_name: Name of the HuggingFace model that was not found.
52
+ """
32
53
  self.model_name = model_name
33
54
 
34
55
  def error_message(self) -> str:
56
+ """Return the error message for this exception."""
35
57
  return (
36
58
  f"The given model name '{self.model_name}' is not a valid model identifier listed on "
37
59
  "'https://huggingface.co/models'. "
@@ -1,6 +1,7 @@
1
+ """NLP embedding generators for text classification and summarization tasks."""
2
+
1
3
  import logging
2
4
  from functools import partial
3
- from typing import Optional, cast
4
5
 
5
6
  import pandas as pd
6
7
 
@@ -22,11 +23,19 @@ logger = logging.getLogger(__name__)
22
23
 
23
24
 
24
25
  class EmbeddingGeneratorForNLPSequenceClassification(NLPEmbeddingGenerator):
26
+ """Embedding generator for NLP sequence classification tasks."""
27
+
25
28
  def __init__(
26
29
  self,
27
30
  model_name: str = DEFAULT_NLP_SEQUENCE_CLASSIFICATION_MODEL,
28
- **kwargs,
29
- ):
31
+ **kwargs: object,
32
+ ) -> None:
33
+ """Initialize the sequence classification embedding generator.
34
+
35
+ Args:
36
+ model_name: Name of the pre-trained NLP model.
37
+ **kwargs: Additional arguments for model initialization.
38
+ """
30
39
  super().__init__(
31
40
  use_case=UseCases.NLP.SEQUENCE_CLASSIFICATION,
32
41
  model_name=model_name,
@@ -36,15 +45,17 @@ class EmbeddingGeneratorForNLPSequenceClassification(NLPEmbeddingGenerator):
36
45
  def generate_embeddings(
37
46
  self,
38
47
  text_col: pd.Series,
39
- class_label_col: Optional[pd.Series] = None,
48
+ class_label_col: pd.Series | None = None,
40
49
  ) -> pd.Series:
41
- """
42
- Obtain embedding vectors from your text data using pre-trained large language models.
50
+ """Obtain embedding vectors from your text data using pre-trained large language models.
51
+
52
+ Args:
53
+ text_col: A pandas Series containing the different pieces of text.
54
+ class_label_col: If this column is passed, the sentence "The classification label
55
+ is <class_label>" will be appended to the text in the `text_col`.
43
56
 
44
- :param text_col: a pandas Series containing the different pieces of text.
45
- :param class_label_col: if this column is passed, the sentence "The classification label
46
- is <class_label>" will be appended to the text in the `text_col`.
47
- :return: a pandas Series containing the embedding vectors.
57
+ Returns:
58
+ A pandas Series containing the embedding vectors.
48
59
  """
49
60
  if not isinstance(text_col, pd.Series):
50
61
  raise TypeError("text_col must be a pandas Series")
@@ -72,13 +83,24 @@ class EmbeddingGeneratorForNLPSequenceClassification(NLPEmbeddingGenerator):
72
83
  batched=True,
73
84
  batch_size=self.batch_size,
74
85
  )
75
- return cast(pd.DataFrame, ds.to_pandas())["embedding_vector"]
86
+ df: pd.DataFrame = ds.to_pandas()
87
+ return df["embedding_vector"]
76
88
 
77
89
 
78
90
  class EmbeddingGeneratorForNLPSummarization(NLPEmbeddingGenerator):
91
+ """Embedding generator for NLP text summarization tasks."""
92
+
79
93
  def __init__(
80
- self, model_name: str = DEFAULT_NLP_SUMMARIZATION_MODEL, **kwargs
81
- ):
94
+ self,
95
+ model_name: str = DEFAULT_NLP_SUMMARIZATION_MODEL,
96
+ **kwargs: object,
97
+ ) -> None:
98
+ """Initialize the text summarization embedding generator.
99
+
100
+ Args:
101
+ model_name: Name of the pre-trained NLP model.
102
+ **kwargs: Additional arguments for model initialization.
103
+ """
82
104
  super().__init__(
83
105
  use_case=UseCases.NLP.SUMMARIZATION,
84
106
  model_name=model_name,
@@ -89,11 +111,13 @@ class EmbeddingGeneratorForNLPSummarization(NLPEmbeddingGenerator):
89
111
  self,
90
112
  text_col: pd.Series,
91
113
  ) -> pd.Series:
92
- """
93
- Obtain embedding vectors from your text data using pre-trained large language models.
114
+ """Obtain embedding vectors from your text data using pre-trained large language models.
115
+
116
+ Args:
117
+ text_col: A pandas Series containing the different pieces of text.
94
118
 
95
- :param text_col: a pandas Series containing the different pieces of text.
96
- :return: a pandas Series containing the embedding vectors.
119
+ Returns:
120
+ A pandas Series containing the embedding vectors.
97
121
  """
98
122
  if not isinstance(text_col, pd.Series):
99
123
  raise TypeError("text_col must be a pandas Series")
@@ -108,4 +132,5 @@ class EmbeddingGeneratorForNLPSummarization(NLPEmbeddingGenerator):
108
132
  batched=True,
109
133
  batch_size=self.batch_size,
110
134
  )
111
- return cast(pd.DataFrame, ds.to_pandas())["embedding_vector"]
135
+ df: pd.DataFrame = ds.to_pandas()
136
+ return df["embedding_vector"]