arize 8.0.0b2__py3-none-any.whl → 8.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. arize/__init__.py +8 -1
  2. arize/_exporter/client.py +18 -17
  3. arize/_exporter/parsers/tracing_data_parser.py +9 -4
  4. arize/_exporter/validation.py +1 -1
  5. arize/_flight/client.py +33 -13
  6. arize/_lazy.py +37 -2
  7. arize/client.py +61 -35
  8. arize/config.py +168 -14
  9. arize/constants/config.py +1 -0
  10. arize/datasets/client.py +32 -19
  11. arize/embeddings/auto_generator.py +14 -7
  12. arize/embeddings/base_generators.py +15 -9
  13. arize/embeddings/cv_generators.py +2 -2
  14. arize/embeddings/nlp_generators.py +8 -8
  15. arize/embeddings/tabular_generators.py +5 -5
  16. arize/exceptions/config.py +22 -0
  17. arize/exceptions/parameters.py +1 -1
  18. arize/exceptions/values.py +8 -5
  19. arize/experiments/__init__.py +4 -0
  20. arize/experiments/client.py +17 -11
  21. arize/experiments/evaluators/base.py +6 -3
  22. arize/experiments/evaluators/executors.py +6 -4
  23. arize/experiments/evaluators/rate_limiters.py +3 -1
  24. arize/experiments/evaluators/types.py +7 -5
  25. arize/experiments/evaluators/utils.py +7 -5
  26. arize/experiments/functions.py +111 -48
  27. arize/experiments/tracing.py +4 -1
  28. arize/experiments/types.py +31 -26
  29. arize/logging.py +53 -32
  30. arize/ml/batch_validation/validator.py +82 -70
  31. arize/ml/bounded_executor.py +25 -6
  32. arize/ml/casting.py +45 -27
  33. arize/ml/client.py +35 -28
  34. arize/ml/proto.py +16 -17
  35. arize/ml/stream_validation.py +63 -25
  36. arize/ml/surrogate_explainer/mimic.py +15 -7
  37. arize/ml/types.py +26 -12
  38. arize/pre_releases.py +7 -6
  39. arize/py.typed +0 -0
  40. arize/regions.py +10 -10
  41. arize/spans/client.py +113 -21
  42. arize/spans/conversion.py +7 -5
  43. arize/spans/validation/annotations/dataframe_form_validation.py +1 -1
  44. arize/spans/validation/annotations/value_validation.py +11 -14
  45. arize/spans/validation/common/dataframe_form_validation.py +1 -1
  46. arize/spans/validation/common/value_validation.py +10 -13
  47. arize/spans/validation/evals/value_validation.py +1 -1
  48. arize/spans/validation/metadata/argument_validation.py +1 -1
  49. arize/spans/validation/metadata/dataframe_form_validation.py +1 -1
  50. arize/spans/validation/metadata/value_validation.py +23 -1
  51. arize/utils/arrow.py +37 -1
  52. arize/utils/online_tasks/dataframe_preprocessor.py +8 -4
  53. arize/utils/proto.py +0 -1
  54. arize/utils/types.py +6 -6
  55. arize/version.py +1 -1
  56. {arize-8.0.0b2.dist-info → arize-8.0.1.dist-info}/METADATA +18 -3
  57. {arize-8.0.0b2.dist-info → arize-8.0.1.dist-info}/RECORD +60 -58
  58. {arize-8.0.0b2.dist-info → arize-8.0.1.dist-info}/WHEEL +0 -0
  59. {arize-8.0.0b2.dist-info → arize-8.0.1.dist-info}/licenses/LICENSE +0 -0
  60. {arize-8.0.0b2.dist-info → arize-8.0.1.dist-info}/licenses/NOTICE +0 -0
arize/config.py CHANGED
@@ -25,6 +25,7 @@ from arize.constants.config import (
     ENV_API_KEY,
     ENV_API_SCHEME,
     ENV_ARIZE_DIRECTORY,
+    ENV_BASE_DOMAIN,
     ENV_ENABLE_CACHING,
     ENV_FLIGHT_HOST,
     ENV_FLIGHT_PORT,
@@ -42,6 +43,7 @@ from arize.constants.config import (
 )
 from arize.constants.pyarrow import MAX_CHUNKSIZE
 from arize.exceptions.auth import MissingAPIKeyError
+from arize.exceptions.config import MultipleEndpointOverridesError
 from arize.regions import REGION_ENDPOINTS, Region
 from arize.version import __version__
 
@@ -53,18 +55,44 @@ ALLOWED_HTTP_SCHEMES = {"http", "https"}
 
 
 def _is_sensitive_field(name: str) -> bool:
+    """Check if a field name contains sensitive information markers.
+
+    Args:
+        name: The field name to check.
+
+    Returns:
+        bool: True if the field name contains 'key', 'token', or 'secret' (case-insensitive).
+    """
     n = name.lower()
     return bool(any(k in n for k in SENSITIVE_FIELD_MARKERS))
 
 
 def _mask_secret(secret: str, N: int = 4) -> str:
-    """Show first N chars then '***'; empty string if empty."""
+    """Mask a secret string by showing only the first N characters.
+
+    Args:
+        secret: The secret string to mask.
+        N: Number of characters to show before masking. Defaults to 4.
+
+    Returns:
+        str: The masked string (first N chars + '***'), or empty string if input is empty.
+    """
     if len(secret) == 0:
         return ""
     return f"{secret[:N]}***"
 
 
 def _endpoint(scheme: str, base: str, path: str = "") -> str:
+    """Construct a full endpoint URL from scheme, base, and optional path.
+
+    Args:
+        scheme: The URL scheme (e.g., "http", "https").
+        base: The base URL or hostname.
+        path: Optional path to append to the base URL. Defaults to empty string.
+
+    Returns:
+        str: The fully constructed endpoint URL.
+    """
     endpoint = scheme + "://" + base.rstrip("/")
     if path:
         endpoint += "/" + path.lstrip("/")
@@ -72,6 +100,18 @@ def _endpoint(scheme: str, base: str, path: str = "") -> str:
 
 
 def _env_http_scheme(name: str, default: str) -> str:
+    """Get an HTTP scheme from environment variable with validation.
+
+    Args:
+        name: The environment variable name.
+        default: The default value if the environment variable is not set.
+
+    Returns:
+        str: The validated HTTP scheme ('http' or 'https').
+
+    Raises:
+        ValueError: If the scheme is not 'http' or 'https'.
+    """
     v = _env_str(name, default).lower()
     if v not in ALLOWED_HTTP_SCHEMES:
         raise ValueError(
@@ -86,6 +126,20 @@ def _env_str(
     min_len: int | None = None,
     max_len: int | None = None,
 ) -> str:
+    """Get a string value from environment variable with length validation.
+
+    Args:
+        name: The environment variable name.
+        default: The default value if the environment variable is not set.
+        min_len: Optional minimum length constraint for the string.
+        max_len: Optional maximum length constraint for the string.
+
+    Returns:
+        str: The validated string value (stripped of whitespace).
+
+    Raises:
+        ValueError: If the string length violates min_len or max_len constraints.
+    """
     val = os.getenv(name, default).strip()
 
     if min_len is not None and len(val) < min_len:
@@ -107,6 +161,20 @@ def _env_int(
     min_val: int | None = None,
     max_val: int | None = None,
 ) -> int:
+    """Get an integer value from environment variable with range validation.
+
+    Args:
+        name: The environment variable name.
+        default: The default value if the environment variable is not set.
+        min_val: Optional minimum value constraint for the integer.
+        max_val: Optional maximum value constraint for the integer.
+
+    Returns:
+        int: The validated integer value.
+
+    Raises:
+        ValueError: If the value cannot be parsed as an integer or violates min_val/max_val constraints.
+    """
     raw = os.getenv(name, default)
     try:
         val = int(raw)
@@ -132,6 +200,20 @@ def _env_float(
     min_val: float | None = None,
     max_val: float | None = None,
 ) -> float:
+    """Get a float value from environment variable with range validation.
+
+    Args:
+        name: The environment variable name.
+        default: The default value if the environment variable is not set.
+        min_val: Optional minimum value constraint for the float.
+        max_val: Optional maximum value constraint for the float.
+
+    Returns:
+        float: The validated float value.
+
+    Raises:
+        ValueError: If the value cannot be parsed as a float or violates min_val/max_val constraints.
+    """
     raw = os.getenv(name, default)
     try:
         val = float(raw)
@@ -152,10 +234,28 @@
 
 
 def _env_bool(name: str, default: bool) -> bool:
+    """Get a boolean value from environment variable.
+
+    Args:
+        name: The environment variable name.
+        default: The default boolean value if the environment variable is not set.
+
+    Returns:
+        bool: The parsed boolean value.
+    """
     return _parse_bool(os.getenv(name, str(default)))
 
 
 def _parse_bool(val: bool | str | None) -> bool:
+    """Parse a boolean value from various input types.
+
+    Args:
+        val: The value to parse. Can be a bool, string, or None.
+
+    Returns:
+        bool: True if the value is already True or matches one of the truthy strings
+            ('1', 'true', 'yes', 'on', case-insensitive). False otherwise.
+    """
     if isinstance(val, bool):
         return val
     return (val or "").strip().lower() in {"1", "true", "yes", "on"}
@@ -227,15 +327,27 @@ class SDKConfiguration:
             individual host/port settings.
             Environment variable: ARIZE_REGION.
             Default: :class:`Region.UNSET`.
-        single_host: Single host to use for all endpoints. Overrides individual host settings.
+        single_host: Single host to use for all endpoints. When specified, overrides
+            individual host settings.
             Environment variable: ARIZE_SINGLE_HOST.
             Default: "" (not set).
-        single_port: Single port to use for all endpoints. Overrides individual port settings (0-65535).
+        single_port: Single port to use for all endpoints. When specified, overrides
+            individual port settings (0-65535).
            Environment variable: ARIZE_SINGLE_PORT.
            Default: 0 (not set).
+        base_domain: Base domain for generating all endpoint hosts. Intended for Private Connect
+            setups. When specified, generates hosts as api.<base_domain>, otlp.<base_domain>,
+            flight.<base_domain>. When specified, overrides individual host settings.
+            Environment variable: ARIZE_BASE_DOMAIN.
+            Default: "" (not set).
+
+    Note:
+        The endpoint override options (region, single_host/single_port, base_domain) are
+        mutually exclusive. Specifying more than one will raise MultipleEndpointOverridesError.
 
     Raises:
         MissingAPIKeyError: If api_key is not provided via argument or environment variable.
+        MultipleEndpointOverridesError: If multiple endpoint override options are provided.
     """
 
     api_key: str = field(
@@ -326,27 +438,73 @@ class SDKConfiguration:
             ENV_SINGLE_PORT, 0, min_val=0, max_val=65535
         )
     )
+    base_domain: str = field(
+        default_factory=lambda: _env_str(ENV_BASE_DOMAIN, "")
+    )
 
     def __post_init__(self) -> None:
         """Validate and configure SDK endpoints after initialization.
 
+        Endpoint override options are mutually exclusive. Only one of the following
+        can be specified:
+            1. region - Overrides all via REGION_ENDPOINTS mapping
+            2. single_host/single_port - Overrides individual hosts/ports
+            3. base_domain - Generates hosts from base domain
+
+        If none are specified, per-endpoint host/port settings are used.
+
         Raises:
-            MissingAPIKeyError: If api_key is not provided via argument or environment variable.
+            MissingAPIKeyError: If api_key is not provided.
+            MultipleEndpointOverridesError: If multiple endpoint override options are provided.
         """
-        # Validate Configuration
+        # Validate configuration
         if not self.api_key:
             raise MissingAPIKeyError()
 
+        # Check which override options are set
+        has_base_domain = bool(self.base_domain)
         has_single_host = bool(self.single_host)
         has_single_port = self.single_port != 0
         has_region = self.region is not Region.UNSET
-        if (has_single_host or has_single_port) and has_region:
+
+        # Ensure only one override method is used (mutually exclusive)
+        override_count = sum(
+            [has_base_domain, has_single_host or has_single_port, has_region]
+        )
+        if override_count > 1:
+            # Determine which overrides were provided
+            provided_overrides = []
+            if has_region:
+                provided_overrides.append(f"region={self.region.value}")
+            if has_single_host or has_single_port:
+                if has_single_host:
+                    provided_overrides.append(
+                        f"single_host={self.single_host!r}"
+                    )
+                if has_single_port:
+                    provided_overrides.append(f"single_port={self.single_port}")
+            if has_base_domain:
+                provided_overrides.append(f"base_domain={self.base_domain!r}")
+
+            error_message = (
+                f"Multiple endpoint override options provided: {', '.join(provided_overrides)}. "
+                "Only one of the following can be specified: 'region', "
+                "'single_host'/'single_port', or 'base_domain'."
+            )
+            logger.error(error_message)
+            raise MultipleEndpointOverridesError(error_message)
+
+        if has_base_domain:
             logger.info(
-                "Multiple endpoint override options provided. Preference order is: "
-                "region > single_host/single_port > per-endpoint host/port."
+                "Base domain %r provided; generating hosts from base domain.",
+                self.base_domain,
+            )
+            object.__setattr__(self, "api_host", f"api.{self.base_domain}")
+            object.__setattr__(self, "otlp_host", f"otlp.{self.base_domain}")
+            object.__setattr__(
+                self, "flight_host", f"flight.{self.base_domain}"
            )
 
-        # Single host override: if single_host is set, it overrides hosts
         if has_single_host:
            logger.info(
                "Single host %r provided; overriding hosts configuration with single host.",
@@ -356,7 +514,6 @@ class SDKConfiguration:
         object.__setattr__(self, "otlp_host", self.single_host)
         object.__setattr__(self, "flight_host", self.single_host)
 
-        # Single port override: if single_port is set, it overrides ports
         if has_single_port:
             logger.info(
                 "Single port %s provided; overriding ports configuration with single port.",
@@ -364,15 +521,12 @@ class SDKConfiguration:
             )
         object.__setattr__(self, "flight_port", self.single_port)
 
-        # Region override: if region is set, it *always* wins over host/port fields
         if has_region:
-            endpoints = REGION_ENDPOINTS[self.region]
-
-            # Override config (region trumps everything)
             logger.info(
                 "Region %s provided; overriding hosts & ports configuration with region defaults.",
                 self.region.value,
             )
+            endpoints = REGION_ENDPOINTS[self.region]
            object.__setattr__(self, "api_host", endpoints.api_host)
            object.__setattr__(self, "otlp_host", endpoints.otlp_host)
            object.__setattr__(self, "flight_host", endpoints.flight_host)
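
For orientation, a minimal sketch of the new base_domain behavior (the API key and domain values are placeholders; the field and exception names come from the diff above):

    # Sketch only: values are illustrative placeholders.
    from arize.config import SDKConfiguration
    from arize.exceptions.config import MultipleEndpointOverridesError

    cfg = SDKConfiguration(
        api_key="ak-12345678",
        base_domain="pc.example.com",  # e.g. a Private Connect domain
    )
    # __post_init__ derives every endpoint host from the base domain:
    assert cfg.api_host == "api.pc.example.com"
    assert cfg.otlp_host == "otlp.pc.example.com"
    assert cfg.flight_host == "flight.pc.example.com"

    # Mixing override styles now fails fast instead of silently picking a winner:
    try:
        SDKConfiguration(
            api_key="ak-12345678",
            base_domain="pc.example.com",
            single_host="localhost",
        )
    except MultipleEndpointOverridesError:
        pass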
arize/constants/config.py CHANGED
@@ -14,6 +14,7 @@ ENV_FLIGHT_PORT = "ARIZE_FLIGHT_PORT"
 ENV_FLIGHT_SCHEME = "ARIZE_FLIGHT_SCHEME"
 ENV_SINGLE_HOST = "ARIZE_SINGLE_HOST"
 ENV_SINGLE_PORT = "ARIZE_SINGLE_PORT"
+ENV_BASE_DOMAIN = "ARIZE_BASE_DOMAIN"
 ENV_PYARROW_MAX_CHUNKSIZE = "ARIZE_MAX_CHUNKSIZE"
 ENV_REQUEST_VERIFY = "ARIZE_REQUEST_VERIFY"
 ENV_MAX_HTTP_PAYLOAD_SIZE_MB = "ARIZE_MAX_HTTP_PAYLOAD_SIZE_MB"
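
The same override can be driven purely through the environment; a short sketch (the domain value is a placeholder):

    # Sketch only: set the new variable before constructing the configuration.
    import os

    os.environ["ARIZE_BASE_DOMAIN"] = "pc.example.com"

    from arize.config import SDKConfiguration

    cfg = SDKConfiguration(api_key="ak-12345678")  # reads ARIZE_BASE_DOMAIN
    assert cfg.flight_host == "flight.pc.example.com"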
arize/datasets/client.py CHANGED
@@ -5,7 +5,7 @@ from __future__ import annotations
 import logging
 import time
 import uuid
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, cast
 
 import pandas as pd
 import pyarrow as pa
@@ -24,6 +24,10 @@ from arize.utils.openinference_conversion import (
 )
 from arize.utils.size import get_payload_size_mb
 
 if TYPE_CHECKING:
+    # builtins is needed to use builtins.list in type annotations because
+    # the class has a list() method that shadows the built-in list type
+    import builtins
+
     from arize._generated.api_client.api_client import ApiClient
     from arize.config import SDKConfiguration
@@ -97,7 +101,7 @@ class DatasetsClient:
         *,
         name: str,
         space_id: str,
-        examples: list[dict[str, object]] | pd.DataFrame,
+        examples: builtins.list[dict[str, object]] | pd.DataFrame,
         force_http: bool = False,
     ) -> models.Dataset:
         """Create a dataset with JSON examples.
@@ -150,7 +154,7 @@
         from arize._generated import api_client as gen
 
         data = (
-            examples.to_dict(orient="records")  # type: ignore
+            examples.to_dict(orient="records")
             if isinstance(examples, pd.DataFrame)
             else examples
         )
@@ -158,7 +162,8 @@
         body = gen.DatasetsCreateRequest(
             name=name,
             space_id=space_id,
-            examples=data,
+            # Cast: pandas to_dict returns dict[Hashable, Any] but API requires dict[str, Any]
+            examples=cast("list[dict[str, Any]]", data),
         )
         return self._api.datasets_create(datasets_create_request=body)
 
@@ -169,15 +174,12 @@
             "Trying to convert to DataFrame for more efficient upload via "
             "gRPC + Flight."
         )
-        data = (
-            examples
-            if isinstance(examples, pd.DataFrame)
-            else pd.DataFrame(examples)
-        )
+        if not isinstance(examples, pd.DataFrame):
+            examples = pd.DataFrame(examples)
         return self._create_dataset_via_flight(
             name=name,
             space_id=space_id,
-            examples=data,
+            examples=examples,
         )
 
     @prerelease_endpoint(key="datasets.get", stage=ReleaseStage.BETA)
@@ -280,7 +282,11 @@
         )
         if dataset_df is not None:
             return models.DatasetsExamplesList200Response(
-                examples=dataset_df.to_dict(orient="records"),  # type: ignore
+                # Cast: Pydantic validates and converts dicts to DatasetExample at runtime
+                examples=cast(
+                    "list[models.DatasetExample]",
+                    dataset_df.to_dict(orient="records"),
+                ),
                 pagination=models.PaginationMetadata(
                     has_more=False,  # Note that all=True
                 ),
@@ -321,7 +327,11 @@
         )
 
         return models.DatasetsExamplesList200Response(
-            examples=dataset_df.to_dict(orient="records"),  # type: ignore
+            # Cast: Pydantic validates and converts dicts to DatasetExample at runtime
+            examples=cast(
+                "list[models.DatasetExample]",
+                dataset_df.to_dict(orient="records"),
+            ),
            pagination=models.PaginationMetadata(
                has_more=False,  # Note that all=True
            ),
@@ -336,7 +346,7 @@
         *,
         dataset_id: str,
         dataset_version_id: str = "",
-        examples: list[dict[str, object]] | pd.DataFrame,
+        examples: builtins.list[dict[str, object]] | pd.DataFrame,
     ) -> models.Dataset:
         """Append new examples to an existing dataset.
 
@@ -377,11 +387,14 @@
         )
 
         data = (
-            examples.to_dict(orient="records")  # type: ignore
+            examples.to_dict(orient="records")
             if isinstance(examples, pd.DataFrame)
             else examples
         )
-        body = gen.DatasetsExamplesInsertRequest(examples=data)
+        # Cast: pandas to_dict returns dict[Hashable, Any] but API requires dict[str, Any]
+        body = gen.DatasetsExamplesInsertRequest(
+            examples=cast("list[dict[str, Any]]", data)
+        )
 
         return self._api.datasets_examples_insert(
             dataset_id=dataset_id,
@@ -394,7 +407,7 @@
         name: str,
         space_id: str,
         examples: pd.DataFrame,
-    ) -> object:
+    ) -> models.Dataset:
         """Internal method to create a dataset using Flight protocol for large example sets."""
         data = examples.copy()
         # Convert datetime columns to int64 (ms since epoch)
@@ -454,19 +467,19 @@ def _set_default_columns_for_dataset(df: pd.DataFrame) -> pd.DataFrame:
     """Set default values for created_at and updated_at columns if missing or null."""
     current_time = int(time.time() * 1000)
     if "created_at" in df.columns:
-        if df["created_at"].isnull().values.any():  # type: ignore
+        if df["created_at"].isnull().any():
             df["created_at"].fillna(current_time, inplace=True)
         else:
             df["created_at"] = current_time
 
     if "updated_at" in df.columns:
-        if df["updated_at"].isnull().values.any():  # type: ignore
+        if df["updated_at"].isnull().any():
             df["updated_at"].fillna(current_time, inplace=True)
         else:
             df["updated_at"] = current_time
 
     if "id" in df.columns:
-        if df["id"].isnull().values.any():  # type: ignore
+        if df["id"].isnull().any():
             df["id"] = df["id"].apply(
                 lambda x: str(uuid.uuid4()) if pd.isnull(x) else x
             )
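
One detail worth calling out from the diff above is the builtins import: DatasetsClient defines a list() method, so inside the class body the bare name list refers to that method rather than the built-in type. A minimal reproduction of the problem, with invented names:

    # Minimal reproduction of the shadowing problem; class and methods are invented.
    from __future__ import annotations

    import builtins


    class Client:
        def list(self) -> None:
            """This method shadows the built-in `list` inside the class body."""

        # Annotating with bare `list[...]` here would resolve to the method above
        # under static analysis; `builtins.list[...]` is unambiguous.
        def create(self, examples: builtins.list[dict[str, object]]) -> None:
            for example in examples:
                print(example)


    Client().create([{"input": "hello"}])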
arize/embeddings/auto_generator.py CHANGED
@@ -1,5 +1,7 @@
 """Automatic embedding generation factory for various ML use cases."""
 
+from typing import TypeAlias
+
 import pandas as pd
 
 from arize.embeddings import constants
@@ -24,9 +26,14 @@ from arize.embeddings.nlp_generators import (
 from arize.embeddings.tabular_generators import (
     EmbeddingGeneratorForTabularFeatures,
 )
-from arize.embeddings.usecases import UseCases
+from arize.embeddings.usecases import (
+    CVUseCases,
+    NLPUseCases,
+    TabularUseCases,
+    UseCases,
+)
 
-UseCaseLike = str | UseCases.NLP | UseCases.CV | UseCases.STRUCTURED
+UseCaseLike: TypeAlias = str | NLPUseCases | CVUseCases | TabularUseCases
 
 
 class EmbeddingGenerator:
@@ -49,15 +56,15 @@ class EmbeddingGenerator:
     ) -> BaseEmbeddingGenerator:
         """Create an embedding generator for the specified use case."""
         if use_case == UseCases.NLP.SEQUENCE_CLASSIFICATION:
-            return EmbeddingGeneratorForNLPSequenceClassification(**kwargs)
+            return EmbeddingGeneratorForNLPSequenceClassification(**kwargs)  # type: ignore[arg-type]
         if use_case == UseCases.NLP.SUMMARIZATION:
-            return EmbeddingGeneratorForNLPSummarization(**kwargs)
+            return EmbeddingGeneratorForNLPSummarization(**kwargs)  # type: ignore[arg-type]
         if use_case == UseCases.CV.IMAGE_CLASSIFICATION:
-            return EmbeddingGeneratorForCVImageClassification(**kwargs)
+            return EmbeddingGeneratorForCVImageClassification(**kwargs)  # type: ignore[arg-type]
         if use_case == UseCases.CV.OBJECT_DETECTION:
-            return EmbeddingGeneratorForCVObjectDetection(**kwargs)
+            return EmbeddingGeneratorForCVObjectDetection(**kwargs)  # type: ignore[arg-type]
         if use_case == UseCases.STRUCTURED.TABULAR_EMBEDDINGS:
-            return EmbeddingGeneratorForTabularFeatures(**kwargs)
+            return EmbeddingGeneratorForTabularFeatures(**kwargs)  # type: ignore[arg-type]
         raise ValueError(f"Invalid use case {use_case}")
 
     @classmethod
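
A brief aside on the UseCaseLike change, assuming (as the diff suggests) that UseCases.NLP and friends alias the enum classes now imported directly: type checkers want resolvable class names in a union, which attribute access does not reliably give them. A minimal illustration with invented members:

    # Illustration only: the enum members are invented, not the SDK's real ones.
    from enum import Enum
    from typing import TypeAlias


    class NLPUseCases(Enum):
        SEQUENCE_CLASSIFICATION = "sequence_classification"


    class CVUseCases(Enum):
        IMAGE_CLASSIFICATION = "image_classification"


    # Direct class references form a union both runtime and type checkers accept.
    UseCaseLike: TypeAlias = str | NLPUseCases | CVUseCases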
arize/embeddings/base_generators.py CHANGED
@@ -14,11 +14,15 @@ try:
     import torch
     from datasets import Dataset
     from PIL import Image
-    from transformers import (  # type: ignore
+    from transformers import (
         AutoImageProcessor,
         AutoModel,
         AutoTokenizer,
+        BaseImageProcessor,
         BatchEncoding,
+        BatchFeature,
+        PreTrainedModel,
+        PreTrainedTokenizerBase,
     )
     from transformers.utils import logging as transformer_logging
 except ImportError as e:
@@ -67,7 +71,9 @@ class BaseEmbeddingGenerator(ABC):
             raise
 
     @abstractmethod
-    def generate_embeddings(self, **kwargs: object) -> pd.Series:
+    def generate_embeddings(
+        self, **kwargs: object
+    ) -> pd.Series | tuple[pd.Series, pd.Series]:
         """Generate embeddings for the input data."""
         ...
 
@@ -95,7 +101,7 @@
         return self.__model_name
 
     @property
-    def model(self) -> object:
+    def model(self) -> PreTrainedModel:
         """Return the underlying model instance."""
         return self.__model
 
@@ -183,7 +189,7 @@ class NLPEmbeddingGenerator(BaseEmbeddingGenerator):
             tokenizer_max_length: Maximum sequence length for the tokenizer.
             **kwargs: Additional arguments for model initialization.
         """
-        super().__init__(use_case=use_case, model_name=model_name, **kwargs)
+        super().__init__(use_case=use_case, model_name=model_name, **kwargs)  # type: ignore[arg-type]
         self.__tokenizer_max_length = tokenizer_max_length
         # We don't check for the tokenizer's existence since it is coupled with the corresponding model
         # We check the model's existence in `BaseEmbeddingGenerator`
@@ -193,7 +199,7 @@
         )
 
     @property
-    def tokenizer(self) -> object:
+    def tokenizer(self) -> PreTrainedTokenizerBase:
         """Return the tokenizer instance for text processing."""
         return self.__tokenizer
 
@@ -240,7 +246,7 @@ class CVEmbeddingGenerator(BaseEmbeddingGenerator):
             model_name: Name of the pre-trained vision model.
             **kwargs: Additional arguments for model initialization.
         """
-        super().__init__(use_case=use_case, model_name=model_name, **kwargs)
+        super().__init__(use_case=use_case, model_name=model_name, **kwargs)  # type: ignore[arg-type]
         logger.info("Downloading image processor")
         # We don't check for the image processor's existence since it is coupled with the corresponding model
         # We check the model's existence in `BaseEmbeddingGenerator`
@@ -249,7 +255,7 @@
         )
 
     @property
-    def image_processor(self) -> object:
+    def image_processor(self) -> BaseImageProcessor:
         """Return the image processor instance for image preprocessing."""
         return self.__image_processor
 
@@ -262,7 +268,7 @@
 
     def preprocess_image(
         self, batch: dict[str, list[str]], local_image_feat_name: str
-    ) -> object:
+    ) -> BatchFeature:
         """Preprocess a batch of images for model input."""
         return self.image_processor(
             [
@@ -272,7 +278,7 @@
             return_tensors="pt",
         ).to(self.device)
 
-    def generate_embeddings(self, local_image_path_col: pd.Series) -> pd.Series:
+    def generate_embeddings(self, local_image_path_col: pd.Series) -> pd.Series:  # type: ignore[override]
         """Obtain embedding vectors from your image data using pre-trained image models.
 
         :param local_image_path_col: a pandas Series containing the local path to the images to
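
A quick note on the `# type: ignore[override]` markers above: the abstract method accepts **kwargs, while the subclasses declare concrete parameters, which static checkers flag as an incompatible (non-Liskov) override even though the behavior is intentional. A minimal reproduction, with invented names:

    # Minimal reproduction of the override pattern; names are invented.
    from abc import ABC, abstractmethod


    class Base(ABC):
        @abstractmethod
        def generate(self, **kwargs: object) -> str:
            """Accepts arbitrary keyword arguments."""


    class Concrete(Base):
        # mypy: error: Signature of "generate" incompatible with supertype "Base"
        def generate(self, text: str) -> str:  # type: ignore[override]
            return text.upper()


    print(Concrete().generate(text="hello"))  # HELLO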
arize/embeddings/cv_generators.py CHANGED
@@ -25,7 +25,7 @@ class EmbeddingGeneratorForCVImageClassification(CVEmbeddingGenerator):
         super().__init__(
             use_case=UseCases.CV.IMAGE_CLASSIFICATION,
             model_name=model_name,
-            **kwargs,
+            **kwargs,  # type: ignore[arg-type]
         )
 
 
@@ -46,5 +46,5 @@ class EmbeddingGeneratorForCVObjectDetection(CVEmbeddingGenerator):
         super().__init__(
             use_case=UseCases.CV.OBJECT_DETECTION,
             model_name=model_name,
-            **kwargs,
+            **kwargs,  # type: ignore[arg-type]
         )
@@ -39,10 +39,10 @@ class EmbeddingGeneratorForNLPSequenceClassification(NLPEmbeddingGenerator):
39
39
  super().__init__(
40
40
  use_case=UseCases.NLP.SEQUENCE_CLASSIFICATION,
41
41
  model_name=model_name,
42
- **kwargs,
42
+ **kwargs, # type: ignore[arg-type]
43
43
  )
44
44
 
45
- def generate_embeddings(
45
+ def generate_embeddings( # type: ignore[override]
46
46
  self,
47
47
  text_col: pd.Series,
48
48
  class_label_col: pd.Series | None = None,
@@ -65,10 +65,10 @@ class EmbeddingGeneratorForNLPSequenceClassification(NLPEmbeddingGenerator):
65
65
  if class_label_col is not None:
66
66
  if not isinstance(class_label_col, pd.Series):
67
67
  raise TypeError("class_label_col must be a pandas Series")
68
- df = pd.concat(
68
+ temp_df = pd.concat(
69
69
  {"text": text_col, "class_label": class_label_col}, axis=1
70
70
  )
71
- prepared_text_col = df.apply(
71
+ prepared_text_col = temp_df.apply(
72
72
  lambda row: f" The classification label is {row['class_label']}. {row['text']}",
73
73
  axis=1,
74
74
  )
@@ -83,8 +83,8 @@ class EmbeddingGeneratorForNLPSequenceClassification(NLPEmbeddingGenerator):
83
83
  batched=True,
84
84
  batch_size=self.batch_size,
85
85
  )
86
- df: pd.DataFrame = ds.to_pandas()
87
- return df["embedding_vector"]
86
+ result_df: pd.DataFrame = ds.to_pandas()
87
+ return result_df["embedding_vector"]
88
88
 
89
89
 
90
90
  class EmbeddingGeneratorForNLPSummarization(NLPEmbeddingGenerator):
@@ -104,10 +104,10 @@ class EmbeddingGeneratorForNLPSummarization(NLPEmbeddingGenerator):
104
104
  super().__init__(
105
105
  use_case=UseCases.NLP.SUMMARIZATION,
106
106
  model_name=model_name,
107
- **kwargs,
107
+ **kwargs, # type: ignore[arg-type]
108
108
  )
109
109
 
110
- def generate_embeddings(
110
+ def generate_embeddings( # type: ignore[override]
111
111
  self,
112
112
  text_col: pd.Series,
113
113
  ) -> pd.Series: