arize 8.0.0b2__py3-none-any.whl → 8.0.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (60)
  1. arize/__init__.py +8 -1
  2. arize/_exporter/client.py +18 -17
  3. arize/_exporter/parsers/tracing_data_parser.py +9 -4
  4. arize/_exporter/validation.py +1 -1
  5. arize/_flight/client.py +33 -13
  6. arize/_lazy.py +37 -2
  7. arize/client.py +61 -35
  8. arize/config.py +168 -14
  9. arize/constants/config.py +1 -0
  10. arize/datasets/client.py +32 -19
  11. arize/embeddings/auto_generator.py +14 -7
  12. arize/embeddings/base_generators.py +15 -9
  13. arize/embeddings/cv_generators.py +2 -2
  14. arize/embeddings/nlp_generators.py +8 -8
  15. arize/embeddings/tabular_generators.py +5 -5
  16. arize/exceptions/config.py +22 -0
  17. arize/exceptions/parameters.py +1 -1
  18. arize/exceptions/values.py +8 -5
  19. arize/experiments/__init__.py +4 -0
  20. arize/experiments/client.py +17 -11
  21. arize/experiments/evaluators/base.py +6 -3
  22. arize/experiments/evaluators/executors.py +6 -4
  23. arize/experiments/evaluators/rate_limiters.py +3 -1
  24. arize/experiments/evaluators/types.py +7 -5
  25. arize/experiments/evaluators/utils.py +7 -5
  26. arize/experiments/functions.py +111 -48
  27. arize/experiments/tracing.py +4 -1
  28. arize/experiments/types.py +31 -26
  29. arize/logging.py +53 -32
  30. arize/ml/batch_validation/validator.py +82 -70
  31. arize/ml/bounded_executor.py +25 -6
  32. arize/ml/casting.py +45 -27
  33. arize/ml/client.py +35 -28
  34. arize/ml/proto.py +16 -17
  35. arize/ml/stream_validation.py +63 -25
  36. arize/ml/surrogate_explainer/mimic.py +15 -7
  37. arize/ml/types.py +26 -12
  38. arize/pre_releases.py +7 -6
  39. arize/py.typed +0 -0
  40. arize/regions.py +10 -10
  41. arize/spans/client.py +113 -21
  42. arize/spans/conversion.py +7 -5
  43. arize/spans/validation/annotations/dataframe_form_validation.py +1 -1
  44. arize/spans/validation/annotations/value_validation.py +11 -14
  45. arize/spans/validation/common/dataframe_form_validation.py +1 -1
  46. arize/spans/validation/common/value_validation.py +10 -13
  47. arize/spans/validation/evals/value_validation.py +1 -1
  48. arize/spans/validation/metadata/argument_validation.py +1 -1
  49. arize/spans/validation/metadata/dataframe_form_validation.py +1 -1
  50. arize/spans/validation/metadata/value_validation.py +23 -1
  51. arize/utils/arrow.py +37 -1
  52. arize/utils/online_tasks/dataframe_preprocessor.py +8 -4
  53. arize/utils/proto.py +0 -1
  54. arize/utils/types.py +6 -6
  55. arize/version.py +1 -1
  56. {arize-8.0.0b2.dist-info → arize-8.0.1.dist-info}/METADATA +18 -3
  57. {arize-8.0.0b2.dist-info → arize-8.0.1.dist-info}/RECORD +60 -58
  58. {arize-8.0.0b2.dist-info → arize-8.0.1.dist-info}/WHEEL +0 -0
  59. {arize-8.0.0b2.dist-info → arize-8.0.1.dist-info}/licenses/LICENSE +0 -0
  60. {arize-8.0.0b2.dist-info → arize-8.0.1.dist-info}/licenses/NOTICE +0 -0
arize/ml/batch_validation/validator.py CHANGED
@@ -6,7 +6,10 @@ import logging
 import math
 from datetime import datetime, timedelta, timezone
 from itertools import chain
-from typing import Any
+from typing import TYPE_CHECKING, Any, cast
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
 
 import numpy as np
 import pandas as pd
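The new `TYPE_CHECKING` guard is the standard way to import names needed only by the type checker, avoiding any runtime import cost; `cast` is its runtime-no-op companion. A minimal sketch of the pattern (names below are illustrative, not taken from the package):

```python
from __future__ import annotations

from typing import TYPE_CHECKING, cast

if TYPE_CHECKING:
    # Only evaluated by static type checkers; never imported at runtime.
    from collections.abc import Sequence


def first_point(flat: object) -> tuple[float, float]:
    # cast() performs no runtime check; it just tells the checker what to assume.
    coords = cast("Sequence[float]", flat)
    return (coords[0], coords[1])
```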
@@ -115,6 +118,7 @@ from arize.ml.types import (
     ModelTypes,
     PromptTemplateColumnNames,
     Schema,
+    _normalize_column_names,
     segments_intersect,
 )
 from arize.utils.types import (
@@ -412,25 +416,25 @@ class Validator:
         if isinstance(schema, Schema):
             general_checks = chain(
                 general_checks,
-                Validator._check_value_timestamp(dataframe, schema),
-                Validator._check_id_field_str_length(
+                Validator._check_value_timestamp(dataframe, schema),  # type: ignore[arg-type]
+                Validator._check_id_field_str_length(  # type: ignore[arg-type]
                     dataframe,
                     "prediction_id_column_name",
                     schema.prediction_id_column_name,
                 ),
-                Validator._check_embedding_vectors_dimensionality(
+                Validator._check_embedding_vectors_dimensionality(  # type: ignore[arg-type]
                     dataframe, schema
                 ),
-                Validator._check_embedding_raw_data_characters(
+                Validator._check_embedding_raw_data_characters(  # type: ignore[arg-type]
                     dataframe, schema
                 ),
-                Validator._check_invalid_record_prod(
+                Validator._check_invalid_record_prod(  # type: ignore[arg-type]
                     dataframe, environment, schema, model_type
                 ),
-                Validator._check_invalid_record_preprod(
+                Validator._check_invalid_record_preprod(  # type: ignore[arg-type]
                     dataframe, environment, schema, model_type
                 ),
-                Validator._check_value_tag(dataframe, schema),
+                Validator._check_value_tag(dataframe, schema),  # type: ignore[arg-type]
             )
         if model_type == ModelTypes.RANKING:
             r_checks = chain(
@@ -555,7 +559,7 @@ class Validator:
     def _check_field_type_prompt_response(
         schema: Schema,
     ) -> list[InvalidFieldTypePromptResponse]:
-        errors = []
+        errors: list[InvalidFieldTypePromptResponse] = []
         if schema.prompt_column_names is not None and not isinstance(
             schema.prompt_column_names, (str, EmbeddingColumnNames)
         ):
@@ -679,7 +683,7 @@ class Validator:
         schema: Schema,
         required_columns_map: list[dict[str, Any]],
     ) -> tuple[bool, list[str], list[list[str]]]:
-        missing_columns = []
+        missing_columns: list[str] = []
         for item in required_columns_map:
             if model_type.name.lower() == item.get("external_model_type"):
                 is_valid_combination = False
@@ -793,7 +797,9 @@ class Validator:
             missing_columns.extend(
                 [
                     col
-                    for col in schema.feature_column_names
+                    for col in _normalize_column_names(
+                        schema.feature_column_names
+                    )
                     if col not in existing_columns
                 ]
             )
@@ -828,7 +834,7 @@ class Validator:
             missing_columns.extend(
                 [
                     col
-                    for col in schema.tag_column_names
+                    for col in _normalize_column_names(schema.tag_column_names)
                     if col not in existing_columns
                 ]
             )
@@ -1051,22 +1057,19 @@ class Validator:
         invalid_column_names = set()
 
         if schema.feature_column_names is not None:
-            for col in schema.feature_column_names:
+            for col in _normalize_column_names(schema.feature_column_names):
                 if isinstance(col, str) and col.endswith("_shap"):
                     invalid_column_names.add(col)
 
         if schema.embedding_feature_column_names is not None:
             for emb_col_names in schema.embedding_feature_column_names.values():
-                for col in emb_col_names:
-                    if (
-                        col is not None
-                        and isinstance(col, str)
-                        and col.endswith("_shap")
-                    ):
+                cols_list = [c for c in emb_col_names if c is not None]
+                for col in cols_list:
+                    if col.endswith("_shap"):
                         invalid_column_names.add(col)
 
         if schema.tag_column_names is not None:
-            for col in schema.tag_column_names:
+            for col in _normalize_column_names(schema.tag_column_names):
                 if isinstance(col, str) and col.endswith("_shap"):
                     invalid_column_names.add(col)
 
@@ -1396,7 +1399,7 @@ class Validator:
         return [
             InvalidPredActColumnNamesForModelType(
                 model_type,
-                None,
+                None,  # type: ignore[arg-type]
                 [schema.multi_class_threshold_scores_column_name],
             )
         ]
@@ -1448,7 +1451,9 @@ class Validator:
             ]
             return [
                 InvalidPredActColumnNamesForModelType(
-                    model_type, allowed_cols, wrong_cols
+                    model_type,
+                    allowed_cols,
+                    wrong_cols,  # type: ignore[arg-type]
                 )
             ]
         return []
@@ -1589,7 +1594,7 @@ class Validator:
         )
         wrong_type_cols = [
             col
-            for col in schema.feature_column_names
+            for col in _normalize_column_names(schema.feature_column_names)
             if col in column_types
             and column_types[col] not in allowed_datatypes
         ]
@@ -1703,7 +1708,7 @@ class Validator:
         )
         wrong_type_cols = [
             col
-            for col in schema.tag_column_names
+            for col in _normalize_column_names(schema.tag_column_names)
             if col in column_types
             and column_types[col] not in allowed_datatypes
         ]
@@ -1750,6 +1755,7 @@ class Validator:
             ("Prediction labels", schema.prediction_label_column_name),
             ("Actual labels", schema.actual_label_column_name),
         )
+        allowed_datatypes: tuple[Any, ...]
         if (
             model_type in CATEGORICAL_MODEL_TYPES
             or model_type == ModelTypes.GENERATIVE_LLM
@@ -2071,10 +2077,8 @@ class Validator:
         )
         wrong_type_cols = []
         if schema.tag_column_names:
-            if (
-                LLM_RUN_METADATA_TOTAL_TOKEN_COUNT_TAG_NAME
-                in schema.tag_column_names
-            ) and (
+            tag_cols = _normalize_column_names(schema.tag_column_names)
+            if (LLM_RUN_METADATA_TOTAL_TOKEN_COUNT_TAG_NAME in tag_cols) and (
                 LLM_RUN_METADATA_TOTAL_TOKEN_COUNT_TAG_NAME in column_types
                 and column_types[LLM_RUN_METADATA_TOTAL_TOKEN_COUNT_TAG_NAME]
                 not in allowed_datatypes
@@ -2082,10 +2086,7 @@ class Validator:
                 wrong_type_cols.append(
                     schema.llm_run_metadata_column_names.total_token_count_column_name
                 )
-            if (
-                LLM_RUN_METADATA_PROMPT_TOKEN_COUNT_TAG_NAME
-                in schema.tag_column_names
-            ) and (
+            if (LLM_RUN_METADATA_PROMPT_TOKEN_COUNT_TAG_NAME in tag_cols) and (
                 LLM_RUN_METADATA_PROMPT_TOKEN_COUNT_TAG_NAME in column_types
                 and column_types[LLM_RUN_METADATA_PROMPT_TOKEN_COUNT_TAG_NAME]
                 not in allowed_datatypes
@@ -2094,8 +2095,7 @@ class Validator:
                 wrong_type_cols.append(
                     schema.llm_run_metadata_column_names.prompt_token_count_column_name
                 )
             if (
-                LLM_RUN_METADATA_RESPONSE_TOKEN_COUNT_TAG_NAME
-                in schema.tag_column_names
+                LLM_RUN_METADATA_RESPONSE_TOKEN_COUNT_TAG_NAME in tag_cols
             ) and (
                 LLM_RUN_METADATA_RESPONSE_TOKEN_COUNT_TAG_NAME in column_types
                 and column_types[LLM_RUN_METADATA_RESPONSE_TOKEN_COUNT_TAG_NAME]
@@ -2104,10 +2104,7 @@ class Validator:
                 wrong_type_cols.append(
                     schema.llm_run_metadata_column_names.response_token_count_column_name
                 )
-            if (
-                LLM_RUN_METADATA_RESPONSE_LATENCY_MS_TAG_NAME
-                in schema.tag_column_names
-            ) and (
+            if (LLM_RUN_METADATA_RESPONSE_LATENCY_MS_TAG_NAME in tag_cols) and (
                 LLM_RUN_METADATA_RESPONSE_LATENCY_MS_TAG_NAME in column_types
                 and column_types[LLM_RUN_METADATA_RESPONSE_LATENCY_MS_TAG_NAME]
                 not in allowed_datatypes
@@ -2120,7 +2117,7 @@ class Validator:
         if wrong_type_cols:
             return [
                 InvalidTypeColumns(
-                    wrong_type_columns=wrong_type_cols,
+                    wrong_type_columns=wrong_type_cols,  # type: ignore[arg-type]
                     expected_types=["int", "float"],
                 )
             ]
@@ -2525,7 +2522,7 @@ class Validator:
             and len(dataframe)
         ):
             return True
-        return (
+        return bool(
             dataframe[col_name]
             .astype(str)
             .str.len()
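Wrapping the chained pandas expression in `bool(...)` makes the method's declared `bool` return type honest: comparisons on NumPy-backed scalars return `numpy.bool_`, not the built-in. A small illustration (sample values assumed, not from the package):

```python
import pandas as pd

ids = pd.Series(["a1", "b22", "c333"])
too_long = ids.astype(str).str.len().max() > 2  # numpy.bool_, not bool
print(type(too_long))        # <class 'numpy.bool_'>
print(type(bool(too_long)))  # <class 'bool'>
```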
@@ -2542,15 +2539,15 @@ class Validator:
 
         wrong_tag_cols = []
         truncated_tag_cols = []
-        for col in schema.tag_column_names:
+        for col in _normalize_column_names(schema.tag_column_names):
             # This is to be defensive, validate_params should guarantee that this column is in
             # the dataframe, via _check_missing_columns, and return an error before reaching this
             # block if not
             # Checks max tag length when any values in a column are strings
             if (
                 col in dataframe.columns
-                and dataframe[col].map(type).eq(str).any()
-            ):  # type:ignore
+                and dataframe[col].map(type).eq(str).any()  # type: ignore[arg-type]
+            ):
                 max_tag_len = (
                     dataframe[col]
                     .apply(_check_value_string_length_helper)
@@ -2574,6 +2571,7 @@ class Validator:
     def _check_value_ranking_category(
         dataframe: pd.DataFrame, schema: Schema
     ) -> list[InvalidValueMissingValue | InvalidRankingCategoryValue]:
+        col: str | None
         if schema.relevance_labels_column_name is not None:
             col = schema.relevance_labels_column_name
         elif schema.attributions_column_name is not None:
@@ -2581,7 +2579,7 @@ class Validator:
         else:
             col = schema.actual_label_column_name
         if col is not None and col in dataframe.columns:
-            if dataframe[col].isnull().values.any():  # type: ignore
+            if dataframe[col].isnull().any():
                 # do not attach duplicated missing value error
                 # which would be caught by _check_value_missing
                 return []
@@ -2661,7 +2659,11 @@ class Validator:
         invalid_pred_scores = {}
         lbound, ubound = (0, 1)
         invalid_actual_scores = False
-        errors = []
+        errors: list[
+            InvalidMultiClassClassNameLength
+            | InvalidMultiClassActScoreValue
+            | InvalidMultiClassPredScoreValue
+        ] = []
         for col in cols:
             if (
                 col is None
@@ -2711,9 +2713,9 @@ class Validator:
         if invalid_class_names:
             errors.append(InvalidMultiClassClassNameLength(invalid_class_names))
         if invalid_pred_scores:
-            errors.append(InvalidMultiClassPredScoreValue(invalid_pred_scores))
+            errors.append(InvalidMultiClassPredScoreValue(invalid_pred_scores))  # type: ignore[arg-type]
         if invalid_actual_scores:
-            errors.append(InvalidMultiClassActScoreValue(col))
+            errors.append(InvalidMultiClassActScoreValue(col))  # type: ignore[arg-type, arg-type]
         return errors
 
     @staticmethod
@@ -2762,7 +2764,7 @@ class Validator:
         # When a timestamp column has Date and NaN, pyarrow will be fine, but
         # pandas min/max will fail due to type incompatibility. So we check for
         # missing value first.
-        if dataframe[col].isnull().values.any():  # type: ignore
+        if dataframe[col].isnull().any():
             return [
                 InvalidValueMissingValue("Prediction timestamp", "missing")
             ]
@@ -2859,7 +2861,7 @@ class Validator:
         dataframe: pd.DataFrame, schema: BaseSchema, model_type: ModelTypes
     ) -> list[InvalidValueMissingValue]:
         errors = []
-        columns = ()
+        columns: tuple[tuple[str, str | None], ...] = ()
         if isinstance(schema, CorpusSchema):
             columns = (("Document ID", schema.document_id_column_name),)
         elif isinstance(schema, Schema):
@@ -3018,7 +3020,7 @@ class Validator:
         null_index = null_filter[null_filter].index.values
         if len(null_index) == 0:
             return []
-        return [InvalidRecord(columns_subset, null_index)]  # type: ignore
+        return [InvalidRecord(columns_subset, null_index)]  # type: ignore[arg-type]
 
     @staticmethod
     def _check_type_prediction_group_id(
@@ -3070,6 +3072,7 @@ class Validator:
     def _check_type_ranking_category(
         schema: Schema, column_types: dict[str, Any]
     ) -> list[InvalidType]:
+        col: str | None
        if schema.relevance_labels_column_name is not None:
            col = schema.relevance_labels_column_name
        elif schema.attributions_column_name is not None:
@@ -3316,7 +3319,7 @@ class Validator:
             dataframe, vector_cols_to_check
         )
 
-        errors = []
+        errors: list[ValidationError] = []
         if invalid_long_string_data_cols:
             errors.append(
                 InvalidValueEmbeddingRawDataTooLong(
@@ -3325,7 +3328,7 @@ class Validator:
             )
         if invalid_low_dim_vector_cols or invalid_high_dim_vector_cols:
             errors.append(
-                InvalidValueEmbeddingVectorDimensionality(
+                InvalidValueEmbeddingVectorDimensionality(  # type: ignore[arg-type]
                     invalid_low_dim_vector_cols,
                     invalid_high_dim_vector_cols,
                 )
@@ -3433,6 +3436,7 @@ class Validator:
         schema: CorpusSchema, column_types: dict[str, Any]
     ) -> list[InvalidTypeColumns]:
         invalid_types = []
+        allowed_datatypes: tuple[Any, ...]
         # Check document id
         col = schema.document_id_column_name
         if col in column_types:
@@ -3577,7 +3581,8 @@ def _check_value_bounding_boxes_coordinates_helper(
     # 'NoneType is not iterable')
     if boxes is None:
         raise InvalidBoundingBoxesCoordinates(reason="none_boxes")
-    for box in boxes:
+    # Type ignore: boxes comes from pandas Series, validated at runtime to be iterable
+    for box in boxes:  # type: ignore[attr-defined]
         if box is None or len(box) == 0:
             raise InvalidBoundingBoxesCoordinates(
                 reason="none_or_empty_box"
@@ -3598,13 +3603,14 @@ def _box_coordinates_wrong_format(
 ) -> InvalidBoundingBoxesCoordinates | None:
     if (
         # Coordinates should be a collection of 4 floats
-        len(box_coords) != 4
+        len(box_coords) != 4  # type: ignore[arg-type]
         # Coordinates should be positive
-        or any(k < 0 for k in box_coords)
+        # Type ignore: box_coords validated at runtime to be iterable/indexable
+        or any(k < 0 for k in box_coords)  # type: ignore[attr-defined]
         # Coordinates represent the top-left & bottom-right corners of a box: x1 < x2
-        or box_coords[0] >= box_coords[2]
+        or box_coords[0] >= box_coords[2]  # type: ignore[index]
         # Coordinates represent the top-left & bottom-right corners of a box: y1 < y2
-        or box_coords[1] >= box_coords[3]
+        or box_coords[1] >= box_coords[3]  # type: ignore[index]
     ):
         return InvalidBoundingBoxesCoordinates(
             reason="boxes_coordinates_wrong_format"
@@ -3620,7 +3626,8 @@ def _check_value_bounding_boxes_categories_helper(
     # 'NoneType is not iterable')
     if categories is None:
         raise InvalidBoundingBoxesCategories(reason="none_category_list")
-    for category in categories:
+    # Type ignore: categories validated at runtime to be iterable
+    for category in categories:  # type: ignore[attr-defined]
         # Allow for empty string category, no None values
         if category is None:
             raise InvalidBoundingBoxesCategories(reason="none_category")
@@ -3640,7 +3647,8 @@ def _check_value_bounding_boxes_scores_helper(
     # 'NoneType is not iterable')
     if scores is None:
         raise InvalidBoundingBoxesScores(reason="none_score_list")
-    for score in scores:
+    # Type ignore: scores validated at runtime to be iterable
+    for score in scores:  # type: ignore[attr-defined]
         # Confidence scores are between 0 and 1
         if score < 0 or score > 1:
             raise InvalidBoundingBoxesScores(reason="scores_out_of_bounds")
@@ -3673,21 +3681,22 @@ def _polygon_coordinates_wrong_format(
     # Basic validations
     if (
         # Coordinates should be a collection of more than 6 floats (3 pairs of x,y coordinates)
-        len(polygon_coords) < 6
+        len(polygon_coords) < 6  # type: ignore[arg-type]
         # Coordinates should be positive
-        or any(k < 0 for k in polygon_coords)
+        # Type ignore: polygon_coords validated at runtime to be iterable
+        or any(k < 0 for k in polygon_coords)  # type: ignore[arg-type, attr-defined]
         # Coordinates should be a collection of pairs of floats
-        or len(polygon_coords) % 2 != 0
+        or len(polygon_coords) % 2 != 0  # type: ignore[arg-type]
     ):
         return InvalidPolygonCoordinates(
             reason="polygon_coordinates_wrong_format",
-            coordinates=polygon_coords,
+            coordinates=polygon_coords,  # type: ignore[arg-type]
         )
 
     # Convert flat list to list of points [(x1,y1), (x2,y2), ...]
+    coords_seq = cast("Sequence[float]", polygon_coords)
     points = [
-        (polygon_coords[i], polygon_coords[i + 1])
-        for i in range(0, len(polygon_coords), 2)
+        (coords_seq[i], coords_seq[i + 1]) for i in range(0, len(coords_seq), 2)
     ]
 
     # Check for repeated vertices
@@ -3696,7 +3705,7 @@ def _polygon_coordinates_wrong_format(
             if points[i] == points[j]:
                 return InvalidPolygonCoordinates(
                     reason="polygon_coordinates_repeated_vertices",
-                    coordinates=polygon_coords,
+                    coordinates=polygon_coords,  # type: ignore[arg-type]
                 )
 
     # Check for self-intersections
@@ -3717,7 +3726,7 @@ def _polygon_coordinates_wrong_format(
         ):
             return InvalidPolygonCoordinates(
                 reason="polygon_coordinates_self_intersecting_vertices",
-                coordinates=polygon_coords,
+                coordinates=polygon_coords,  # type: ignore[arg-type]
             )
 
     return None
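To make the flat-list handling concrete, here is a hedged sketch of the point pairing and repeated-vertex check these hunks implement; the helper name and sample data are illustrative, not the package's API:

```python
from itertools import combinations


def has_repeated_vertices(flat: list[float]) -> bool:
    # Pair up [x1, y1, x2, y2, ...] into points, mirroring the diff above.
    points = [(flat[i], flat[i + 1]) for i in range(0, len(flat), 2)]
    # Pairwise equality mirrors the nested i/j loop in the repeated-vertex hunk.
    return any(a == b for a, b in combinations(points, 2))


print(has_repeated_vertices([0.0, 0.0, 1.0, 0.0, 1.0, 1.0]))  # False
print(has_repeated_vertices([0.0, 0.0, 1.0, 0.0, 0.0, 0.0]))  # True
```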
@@ -3731,7 +3740,8 @@ def _check_value_polygon_coordinates_helper(
     # 'NoneType is not iterable')
     if polygons is None:
         raise InvalidPolygonCoordinates(reason="none_polygons")
-    for polygon in polygons:
+    # Type ignore: polygons validated at runtime to be iterable
+    for polygon in polygons:  # type: ignore[attr-defined]
         if polygon is None or len(polygon) == 0:
             raise InvalidPolygonCoordinates(reason="none_or_empty_polygon")
         error = _polygon_coordinates_wrong_format(polygon)
@@ -3753,7 +3763,8 @@ def _check_value_polygon_categories_helper(
     # 'NoneType is not iterable')
     if categories is None:
         raise InvalidPolygonCategories(reason="none_category_list")
-    for category in categories:
+    # Type ignore: categories validated at runtime to be iterable
+    for category in categories:  # type: ignore[attr-defined]
         # Allow for empty string category, no None values
         if category is None:
             raise InvalidPolygonCategories(reason="none_category")
@@ -3773,7 +3784,8 @@ def _check_value_polygon_scores_helper(
     # 'NoneType is not iterable')
     if scores is None:
         raise InvalidPolygonScores(reason="none_score_list")
-    for score in scores:
+    # Type ignore: scores validated at runtime to be iterable
+    for score in scores:  # type: ignore[attr-defined]
         # Confidence scores are between 0 and 1
         if score < 0 or score > 1:
             raise InvalidPolygonScores(reason="scores_out_of_bounds")
arize/ml/bounded_executor.py CHANGED
@@ -24,12 +24,26 @@ class BoundedExecutor:
         self.executor = ThreadPoolExecutor(max_workers=max_workers)
         self.semaphore = BoundedSemaphore(bound + max_workers)
 
-    """See concurrent.futures.Executor#submit"""
-
     def submit(
         self, fn: Callable[..., object], *args: object, **kwargs: object
     ) -> object:
-        """Submit a callable to be executed with bounded concurrency."""
+        """Submit a callable to be executed with bounded concurrency.
+
+        This method blocks if the work queue is full (at the bound limit) until
+        space becomes available. Compatible with concurrent.futures.Executor.submit().
+
+        Args:
+            fn: The callable to execute.
+            *args: Positional arguments to pass to the callable.
+            **kwargs: Keyword arguments to pass to the callable.
+
+        Returns:
+            concurrent.futures.Future: A Future representing the pending execution.
+
+        Raises:
+            Exception: Any exception raised during submission is re-raised after
+                releasing the semaphore.
+        """
         self.semaphore.acquire()
         try:
             future = self.executor.submit(fn, *args, **kwargs)
@@ -40,8 +54,13 @@ class BoundedExecutor:
         future.add_done_callback(lambda _: self.semaphore.release())
         return future
 
-    """See concurrent.futures.Executor#shutdown"""
-
     def shutdown(self, wait: bool = True) -> None:
-        """Shutdown the executor, optionally waiting for pending tasks to complete."""
+        """Shutdown the executor, optionally waiting for pending tasks to complete.
+
+        Compatible with concurrent.futures.Executor.shutdown().
+
+        Args:
+            wait: If True, blocks until all pending tasks complete. If False,
+                returns immediately without waiting. Defaults to True.
+        """
         self.executor.shutdown(wait)
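The expanded docstrings describe a back-pressure pattern: a semaphore sized `bound + max_workers` caps how much work can be queued or running, so producers block in `submit` instead of growing an unbounded queue. A hypothetical usage sketch (constructor argument names are assumed from the attributes shown above, not confirmed by the diff):

```python
from arize.ml.bounded_executor import BoundedExecutor  # import path per this wheel


def upload(record: dict) -> None:
    ...  # send one record to the backend


executor = BoundedExecutor(bound=100, max_workers=8)
for record in ({"id": i} for i in range(10_000)):
    # Blocks once 100 submissions are pending beyond the 8 running workers,
    # because each Future releases the semaphore only when it completes.
    executor.submit(upload, record)
executor.shutdown(wait=True)
```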
arize/ml/casting.py CHANGED
@@ -1,4 +1,3 @@
-# type: ignore[pb2]
 """Type casting utilities for ML model data conversion."""
 
 from __future__ import annotations
@@ -14,8 +13,8 @@ from arize.ml.types import (
     Schema,
     TypedColumns,
     TypedValue,
+    _normalize_column_names,
 )
-from arize.utils.types import is_list_of
 
 if TYPE_CHECKING:
     import pandas as pd
@@ -25,7 +24,11 @@ class CastingError(Exception):
     """Raised when type casting fails for a value."""
 
     def __str__(self) -> str:
-        """Return a human-readable error message."""
+        """Return a human-readable error message.
+
+        Returns:
+            str: The formatted error message describing the casting failure.
+        """
         return self.error_message()
 
     def __init__(self, error_msg: str, typed_value: TypedValue) -> None:
@@ -39,7 +42,11 @@ class CastingError(Exception):
         self.typed_value = typed_value
 
     def error_message(self) -> str:
-        """Return the error message for this exception."""
+        """Return the error message for this exception.
+
+        Returns:
+            str: Detailed error message including the value, its type, target type, and failure reason.
+        """
         return (
             f"Failed to cast value {self.typed_value.value} of type {type(self.typed_value.value)} "
             f"to type {self.typed_value.type}. "
@@ -51,14 +58,18 @@ class ColumnCastingError(Exception):
     """Raised when type casting fails for a column."""
 
     def __str__(self) -> str:
-        """Return a human-readable error message."""
+        """Return a human-readable error message.
+
+        Returns:
+            str: The formatted error message describing the column casting failure.
+        """
         return self.error_message()
 
     def __init__(
         self,
         error_msg: str,
-        attempted_columns: str,
-        attempted_type: TypedColumns,
+        attempted_columns: list[str],
+        attempted_type: str,
     ) -> None:
         """Initialize the exception with column casting context.
 
@@ -72,7 +83,11 @@ class ColumnCastingError(Exception):
         self.attempted_casting_type = attempted_type
 
     def error_message(self) -> str:
-        """Return the error message for this exception."""
+        """Return the error message for this exception.
+
+        Returns:
+            str: Detailed error message including the target type, affected columns, and failure reason.
+        """
         return (
             f"Failed to cast to type {self.attempted_casting_type} "
             f"for columns: {log_a_list(self.attempted_casting_columns, 'and')}. "
@@ -84,7 +99,11 @@ class InvalidTypedColumnsError(Exception):
     """Raised when typed columns are invalid or incorrectly specified."""
 
     def __str__(self) -> str:
-        """Return a human-readable error message."""
+        """Return a human-readable error message.
+
+        Returns:
+            str: The formatted error message describing the invalid typed columns.
+        """
         return self.error_message()
 
     def __init__(self, field_name: str, reason: str) -> None:
@@ -98,7 +117,11 @@ class InvalidTypedColumnsError(Exception):
         self.reason = reason
 
     def error_message(self) -> str:
-        """Return the error message for this exception."""
+        """Return the error message for this exception.
+
+        Returns:
+            str: Error message describing which field has invalid typed columns and why.
+        """
         return f"The {self.field_name} TypedColumns object {self.reason}."
 
 
@@ -106,7 +129,11 @@ class InvalidSchemaFieldTypeError(Exception):
     """Raised when schema field has invalid or unexpected type."""
 
     def __str__(self) -> str:
-        """Return a human-readable error message."""
+        """Return a human-readable error message.
+
+        Returns:
+            str: The formatted error message describing the invalid schema field type.
+        """
         return self.error_message()
 
     def __init__(self, msg: str) -> None:
@@ -118,7 +145,11 @@ class InvalidSchemaFieldTypeError(Exception):
         self.msg = msg
 
     def error_message(self) -> str:
-        """Return the error message for this exception."""
+        """Return the error message for this exception.
+
+        Returns:
+            str: The error message describing the schema field type issue.
+        """
         return self.msg
 
 
@@ -381,23 +412,10 @@ def _convert_schema_field_types(
         Schema: A Schema, with feature and tag column names converted to the
             List[string] format expected in downstream validation.
     """
-    feature_column_names_list = (
+    feature_column_names_list = _normalize_column_names(
         schema.feature_column_names
-        if is_list_of(schema.feature_column_names, str)
-        else (
-            schema.feature_column_names.get_all_column_names()
-            if schema.feature_column_names
-            else []
-        )
-    )
-
-    tag_column_names_list = (
-        schema.tag_column_names
-        if is_list_of(schema.tag_column_names, str)
-        else schema.tag_column_names.get_all_column_names()
-        if schema.tag_column_names
-        else []
     )
+    tag_column_names_list = _normalize_column_names(schema.tag_column_names)
 
     schema_dict = {
         "feature_column_names": feature_column_names_list,