arize 8.0.0a22__py3-none-any.whl → 8.0.0b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (171)
  1. arize/__init__.py +28 -19
  2. arize/_exporter/client.py +56 -37
  3. arize/_exporter/parsers/tracing_data_parser.py +41 -30
  4. arize/_exporter/validation.py +3 -3
  5. arize/_flight/client.py +207 -76
  6. arize/_generated/api_client/__init__.py +30 -6
  7. arize/_generated/api_client/api/__init__.py +1 -0
  8. arize/_generated/api_client/api/datasets_api.py +864 -190
  9. arize/_generated/api_client/api/experiments_api.py +167 -131
  10. arize/_generated/api_client/api/projects_api.py +1197 -0
  11. arize/_generated/api_client/api_client.py +2 -2
  12. arize/_generated/api_client/configuration.py +42 -34
  13. arize/_generated/api_client/exceptions.py +2 -2
  14. arize/_generated/api_client/models/__init__.py +15 -4
  15. arize/_generated/api_client/models/dataset.py +10 -10
  16. arize/_generated/api_client/models/dataset_example.py +111 -0
  17. arize/_generated/api_client/models/dataset_example_update.py +100 -0
  18. arize/_generated/api_client/models/dataset_version.py +13 -13
  19. arize/_generated/api_client/models/datasets_create_request.py +16 -8
  20. arize/_generated/api_client/models/datasets_examples_insert_request.py +100 -0
  21. arize/_generated/api_client/models/datasets_examples_list200_response.py +106 -0
  22. arize/_generated/api_client/models/datasets_examples_update_request.py +102 -0
  23. arize/_generated/api_client/models/datasets_list200_response.py +10 -4
  24. arize/_generated/api_client/models/experiment.py +14 -16
  25. arize/_generated/api_client/models/experiment_run.py +108 -0
  26. arize/_generated/api_client/models/experiment_run_create.py +102 -0
  27. arize/_generated/api_client/models/experiments_create_request.py +16 -10
  28. arize/_generated/api_client/models/experiments_list200_response.py +10 -4
  29. arize/_generated/api_client/models/experiments_runs_list200_response.py +19 -5
  30. arize/_generated/api_client/models/{error.py → pagination_metadata.py} +13 -11
  31. arize/_generated/api_client/models/primitive_value.py +172 -0
  32. arize/_generated/api_client/models/problem.py +100 -0
  33. arize/_generated/api_client/models/project.py +99 -0
  34. arize/_generated/api_client/models/{datasets_list_examples200_response.py → projects_create_request.py} +13 -11
  35. arize/_generated/api_client/models/projects_list200_response.py +106 -0
  36. arize/_generated/api_client/rest.py +2 -2
  37. arize/_generated/api_client/test/test_dataset.py +4 -2
  38. arize/_generated/api_client/test/test_dataset_example.py +56 -0
  39. arize/_generated/api_client/test/test_dataset_example_update.py +52 -0
  40. arize/_generated/api_client/test/test_dataset_version.py +7 -2
  41. arize/_generated/api_client/test/test_datasets_api.py +27 -13
  42. arize/_generated/api_client/test/test_datasets_create_request.py +8 -4
  43. arize/_generated/api_client/test/{test_datasets_list_examples200_response.py → test_datasets_examples_insert_request.py} +19 -15
  44. arize/_generated/api_client/test/test_datasets_examples_list200_response.py +66 -0
  45. arize/_generated/api_client/test/test_datasets_examples_update_request.py +61 -0
  46. arize/_generated/api_client/test/test_datasets_list200_response.py +9 -3
  47. arize/_generated/api_client/test/test_experiment.py +2 -4
  48. arize/_generated/api_client/test/test_experiment_run.py +56 -0
  49. arize/_generated/api_client/test/test_experiment_run_create.py +54 -0
  50. arize/_generated/api_client/test/test_experiments_api.py +6 -6
  51. arize/_generated/api_client/test/test_experiments_create_request.py +9 -6
  52. arize/_generated/api_client/test/test_experiments_list200_response.py +9 -5
  53. arize/_generated/api_client/test/test_experiments_runs_list200_response.py +15 -5
  54. arize/_generated/api_client/test/test_pagination_metadata.py +53 -0
  55. arize/_generated/api_client/test/{test_error.py → test_primitive_value.py} +13 -14
  56. arize/_generated/api_client/test/test_problem.py +57 -0
  57. arize/_generated/api_client/test/test_project.py +58 -0
  58. arize/_generated/api_client/test/test_projects_api.py +59 -0
  59. arize/_generated/api_client/test/test_projects_create_request.py +54 -0
  60. arize/_generated/api_client/test/test_projects_list200_response.py +70 -0
  61. arize/_generated/api_client_README.md +43 -29
  62. arize/_generated/protocol/flight/flight_pb2.py +400 -0
  63. arize/_lazy.py +27 -19
  64. arize/client.py +181 -58
  65. arize/config.py +324 -116
  66. arize/constants/__init__.py +1 -0
  67. arize/constants/config.py +11 -4
  68. arize/constants/ml.py +6 -4
  69. arize/constants/openinference.py +2 -0
  70. arize/constants/pyarrow.py +2 -0
  71. arize/constants/spans.py +3 -1
  72. arize/datasets/__init__.py +1 -0
  73. arize/datasets/client.py +304 -84
  74. arize/datasets/errors.py +32 -2
  75. arize/datasets/validation.py +18 -8
  76. arize/embeddings/__init__.py +2 -0
  77. arize/embeddings/auto_generator.py +23 -19
  78. arize/embeddings/base_generators.py +89 -36
  79. arize/embeddings/constants.py +2 -0
  80. arize/embeddings/cv_generators.py +26 -4
  81. arize/embeddings/errors.py +27 -5
  82. arize/embeddings/nlp_generators.py +43 -18
  83. arize/embeddings/tabular_generators.py +46 -31
  84. arize/embeddings/usecases.py +12 -2
  85. arize/exceptions/__init__.py +1 -0
  86. arize/exceptions/auth.py +11 -1
  87. arize/exceptions/base.py +29 -4
  88. arize/exceptions/models.py +21 -2
  89. arize/exceptions/parameters.py +31 -0
  90. arize/exceptions/spaces.py +12 -1
  91. arize/exceptions/types.py +86 -7
  92. arize/exceptions/values.py +220 -20
  93. arize/experiments/__init__.py +13 -0
  94. arize/experiments/client.py +394 -285
  95. arize/experiments/evaluators/__init__.py +1 -0
  96. arize/experiments/evaluators/base.py +74 -41
  97. arize/experiments/evaluators/exceptions.py +6 -3
  98. arize/experiments/evaluators/executors.py +121 -73
  99. arize/experiments/evaluators/rate_limiters.py +106 -57
  100. arize/experiments/evaluators/types.py +34 -7
  101. arize/experiments/evaluators/utils.py +65 -27
  102. arize/experiments/functions.py +103 -101
  103. arize/experiments/tracing.py +52 -44
  104. arize/experiments/types.py +56 -31
  105. arize/logging.py +54 -22
  106. arize/ml/__init__.py +1 -0
  107. arize/ml/batch_validation/__init__.py +1 -0
  108. arize/{models → ml}/batch_validation/errors.py +545 -67
  109. arize/{models → ml}/batch_validation/validator.py +344 -303
  110. arize/ml/bounded_executor.py +47 -0
  111. arize/{models → ml}/casting.py +118 -108
  112. arize/{models → ml}/client.py +339 -118
  113. arize/{models → ml}/proto.py +97 -42
  114. arize/{models → ml}/stream_validation.py +43 -15
  115. arize/ml/surrogate_explainer/__init__.py +1 -0
  116. arize/{models → ml}/surrogate_explainer/mimic.py +25 -10
  117. arize/{types.py → ml/types.py} +355 -354
  118. arize/pre_releases.py +44 -0
  119. arize/projects/__init__.py +1 -0
  120. arize/projects/client.py +134 -0
  121. arize/regions.py +40 -0
  122. arize/spans/__init__.py +1 -0
  123. arize/spans/client.py +204 -175
  124. arize/spans/columns.py +13 -0
  125. arize/spans/conversion.py +60 -37
  126. arize/spans/validation/__init__.py +1 -0
  127. arize/spans/validation/annotations/__init__.py +1 -0
  128. arize/spans/validation/annotations/annotations_validation.py +6 -4
  129. arize/spans/validation/annotations/dataframe_form_validation.py +13 -11
  130. arize/spans/validation/annotations/value_validation.py +35 -11
  131. arize/spans/validation/common/__init__.py +1 -0
  132. arize/spans/validation/common/argument_validation.py +33 -8
  133. arize/spans/validation/common/dataframe_form_validation.py +35 -9
  134. arize/spans/validation/common/errors.py +211 -11
  135. arize/spans/validation/common/value_validation.py +81 -14
  136. arize/spans/validation/evals/__init__.py +1 -0
  137. arize/spans/validation/evals/dataframe_form_validation.py +28 -8
  138. arize/spans/validation/evals/evals_validation.py +34 -4
  139. arize/spans/validation/evals/value_validation.py +26 -3
  140. arize/spans/validation/metadata/__init__.py +1 -1
  141. arize/spans/validation/metadata/argument_validation.py +14 -5
  142. arize/spans/validation/metadata/dataframe_form_validation.py +26 -10
  143. arize/spans/validation/metadata/value_validation.py +24 -10
  144. arize/spans/validation/spans/__init__.py +1 -0
  145. arize/spans/validation/spans/dataframe_form_validation.py +35 -14
  146. arize/spans/validation/spans/spans_validation.py +35 -4
  147. arize/spans/validation/spans/value_validation.py +78 -8
  148. arize/utils/__init__.py +1 -0
  149. arize/utils/arrow.py +31 -15
  150. arize/utils/cache.py +34 -6
  151. arize/utils/dataframe.py +20 -3
  152. arize/utils/online_tasks/__init__.py +2 -0
  153. arize/utils/online_tasks/dataframe_preprocessor.py +58 -47
  154. arize/utils/openinference_conversion.py +44 -5
  155. arize/utils/proto.py +10 -0
  156. arize/utils/size.py +5 -3
  157. arize/utils/types.py +105 -0
  158. arize/version.py +3 -1
  159. {arize-8.0.0a22.dist-info → arize-8.0.0b0.dist-info}/METADATA +13 -6
  160. arize-8.0.0b0.dist-info/RECORD +175 -0
  161. {arize-8.0.0a22.dist-info → arize-8.0.0b0.dist-info}/WHEEL +1 -1
  162. arize-8.0.0b0.dist-info/licenses/LICENSE +176 -0
  163. arize-8.0.0b0.dist-info/licenses/NOTICE +13 -0
  164. arize/_generated/protocol/flight/export_pb2.py +0 -61
  165. arize/_generated/protocol/flight/ingest_pb2.py +0 -365
  166. arize/models/__init__.py +0 -0
  167. arize/models/batch_validation/__init__.py +0 -0
  168. arize/models/bounded_executor.py +0 -34
  169. arize/models/surrogate_explainer/__init__.py +0 -0
  170. arize-8.0.0a22.dist-info/RECORD +0 -146
  171. arize-8.0.0a22.dist-info/licenses/LICENSE.md +0 -12
@@ -1,12 +1,15 @@
1
+ """Batch validation logic for ML model predictions and actuals."""
2
+
1
3
  from __future__ import annotations
2
4
 
3
- import datetime
4
5
  import logging
5
6
  import math
7
+ from datetime import datetime, timedelta, timezone
6
8
  from itertools import chain
7
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
9
+ from typing import Any
8
10
 
9
11
  import numpy as np
12
+ import pandas as pd
10
13
  import pyarrow as pa
11
14
 
12
15
  from arize.constants.ml import (
@@ -37,8 +40,8 @@ from arize.constants.ml import (
37
40
  MODEL_MAPPING_CONFIG,
38
41
  )
39
42
  from arize.logging import get_truncation_warning_message
40
- from arize.models.batch_validation import errors as err
41
- from arize.types import (
43
+ from arize.ml.batch_validation import errors as err
44
+ from arize.ml.types import (
42
45
  CATEGORICAL_MODEL_TYPES,
43
46
  NUMERIC_MODEL_TYPES,
44
47
  BaseSchema,
@@ -50,28 +53,29 @@ from arize.types import (
50
53
  ModelTypes,
51
54
  PromptTemplateColumnNames,
52
55
  Schema,
56
+ segments_intersect,
57
+ )
58
+ from arize.utils.types import (
53
59
  is_dict_of,
54
60
  is_iterable_of,
55
- segments_intersect,
56
61
  )
57
62
 
58
- if TYPE_CHECKING:
59
- import pandas as pd
60
-
61
-
62
63
  logger = logging.getLogger(__name__)
63
64
 
64
65
 
65
66
  class Validator:
67
+ """Validator for batch data with schema and dataframe validation methods."""
68
+
66
69
  @staticmethod
67
70
  def validate_required_checks(
68
71
  dataframe: pd.DataFrame,
69
72
  model_id: str,
70
73
  environment: Environments,
71
74
  schema: BaseSchema,
72
- model_version: Optional[str] = None,
73
- batch_id: Optional[str] = None,
74
- ) -> List[err.ValidationError]:
75
+ model_version: str | None = None,
76
+ batch_id: str | None = None,
77
+ ) -> list[err.ValidationError]:
78
+ """Validate required checks for schema, environment, and DataFrame structure."""
75
79
  general_checks = chain(
76
80
  Validator._check_valid_schema_type(schema, environment),
77
81
  Validator._check_field_convertible_to_str(
@@ -87,7 +91,7 @@ class Validator:
87
91
  schema, CorpusSchema
88
92
  ):
89
93
  return list(general_checks)
90
- elif isinstance(schema, Schema):
94
+ if isinstance(schema, Schema):
91
95
  return list(
92
96
  chain(
93
97
  general_checks,
@@ -108,10 +112,11 @@ class Validator:
108
112
  model_type: ModelTypes,
109
113
  environment: Environments,
110
114
  schema: BaseSchema,
111
- metric_families: Optional[List[Metrics]] = None,
112
- model_version: Optional[str] = None,
113
- batch_id: Optional[str] = None,
114
- ) -> List[err.ValidationError]:
115
+ metric_families: list[Metrics] | None = None,
116
+ model_version: str | None = None,
117
+ batch_id: str | None = None,
118
+ ) -> list[err.ValidationError]:
119
+ """Validate parameters including model type, environment, and schema consistency."""
115
120
  # general checks
116
121
  general_checks = chain(
117
122
  Validator._check_column_names_for_empty_strings(schema),
@@ -125,7 +130,7 @@ class Validator:
125
130
  )
126
131
  if isinstance(schema, CorpusSchema):
127
132
  return list(general_checks)
128
- elif isinstance(schema, Schema):
133
+ if isinstance(schema, Schema):
129
134
  general_checks = chain(
130
135
  general_checks,
131
136
  Validator._check_existence_prediction_id_column_delayed_schema(
@@ -153,7 +158,7 @@ class Validator:
153
158
  ),
154
159
  )
155
160
  return list(chain(general_checks, num_checks))
156
- elif model_type in CATEGORICAL_MODEL_TYPES:
161
+ if model_type in CATEGORICAL_MODEL_TYPES:
157
162
  sc_checks = chain(
158
163
  Validator._check_existence_preprod_pred_act_score_or_label(
159
164
  schema, environment
@@ -166,7 +171,7 @@ class Validator:
166
171
  ),
167
172
  )
168
173
  return list(chain(general_checks, sc_checks))
169
- elif model_type == ModelTypes.GENERATIVE_LLM:
174
+ if model_type == ModelTypes.GENERATIVE_LLM:
170
175
  gllm_checks = chain(
171
176
  Validator._check_existence_preprod_act(schema, environment),
172
177
  Validator._check_missing_object_detection_columns(
@@ -177,7 +182,7 @@ class Validator:
177
182
  ),
178
183
  )
179
184
  return list(chain(general_checks, gllm_checks))
180
- elif model_type == ModelTypes.RANKING:
185
+ if model_type == ModelTypes.RANKING:
181
186
  r_checks = chain(
182
187
  Validator._check_existence_group_id_rank_category_relevance(
183
188
  schema
@@ -190,7 +195,7 @@ class Validator:
190
195
  ),
191
196
  )
192
197
  return list(chain(general_checks, r_checks))
193
- elif model_type == ModelTypes.OBJECT_DETECTION:
198
+ if model_type == ModelTypes.OBJECT_DETECTION:
194
199
  od_checks = chain(
195
200
  Validator._check_exactly_one_cv_column_type(
196
201
  schema, environment
@@ -203,7 +208,7 @@ class Validator:
203
208
  ),
204
209
  )
205
210
  return list(chain(general_checks, od_checks))
206
- elif model_type == ModelTypes.MULTI_CLASS:
211
+ if model_type == ModelTypes.MULTI_CLASS:
207
212
  multi_class_checks = chain(
208
213
  Validator._check_existing_multi_class_columns(schema),
209
214
  Validator._check_missing_non_multi_class_columns(
@@ -218,8 +223,11 @@ class Validator:
218
223
  model_type: ModelTypes,
219
224
  schema: BaseSchema,
220
225
  pyarrow_schema: pa.Schema,
221
- ) -> List[err.ValidationError]:
222
- column_types = dict(zip(pyarrow_schema.names, pyarrow_schema.types))
226
+ ) -> list[err.ValidationError]:
227
+ """Validate column data types against expected types for the schema."""
228
+ column_types = dict(
229
+ zip(pyarrow_schema.names, pyarrow_schema.types, strict=True)
230
+ )
223
231
 
224
232
  if isinstance(schema, CorpusSchema):
225
233
  return list(
@@ -227,7 +235,7 @@ class Validator:
227
235
  Validator._check_type_document_columns(schema, column_types)
228
236
  )
229
237
  )
230
- elif isinstance(schema, Schema):
238
+ if isinstance(schema, Schema):
231
239
  general_checks = chain(
232
240
  Validator._check_type_prediction_id(schema, column_types),
233
241
  Validator._check_type_timestamp(schema, column_types),
@@ -271,7 +279,7 @@ class Validator:
271
279
  ),
272
280
  )
273
281
  return list(chain(general_checks, gllm_checks))
274
- elif model_type == ModelTypes.RANKING:
282
+ if model_type == ModelTypes.RANKING:
275
283
  r_checks = chain(
276
284
  Validator._check_type_prediction_group_id(
277
285
  schema, column_types
@@ -285,7 +293,7 @@ class Validator:
285
293
  ),
286
294
  )
287
295
  return list(chain(general_checks, r_checks))
288
- elif model_type == ModelTypes.OBJECT_DETECTION:
296
+ if model_type == ModelTypes.OBJECT_DETECTION:
289
297
  od_checks = chain(
290
298
  Validator._check_type_image_segment_coordinates(
291
299
  schema, column_types
@@ -298,7 +306,7 @@ class Validator:
298
306
  ),
299
307
  )
300
308
  return list(chain(general_checks, od_checks))
301
- elif model_type == ModelTypes.MULTI_CLASS:
309
+ if model_type == ModelTypes.MULTI_CLASS:
302
310
  multi_class_checks = chain(
303
311
  Validator._check_type_multi_class_pred_threshold_act_scores(
304
312
  schema, column_types
@@ -315,7 +323,8 @@ class Validator:
315
323
  environment: Environments,
316
324
  schema: BaseSchema,
317
325
  model_type: ModelTypes,
318
- ) -> List[err.ValidationError]:
326
+ ) -> list[err.ValidationError]:
327
+ """Validate data values including ranges, formats, and consistency checks."""
319
328
  # ASSUMPTION: at this point the param and type checks should have passed.
320
329
  # This function may crash if that is not true, e.g. if columns are missing
321
330
  # or are of the wrong types.
@@ -338,7 +347,7 @@ class Validator:
338
347
  ),
339
348
  )
340
349
  )
341
- elif isinstance(schema, Schema):
350
+ if isinstance(schema, Schema):
342
351
  general_checks = chain(
343
352
  general_checks,
344
353
  Validator._check_value_timestamp(dataframe, schema),
@@ -429,21 +438,21 @@ class Validator:
429
438
  return list(general_checks)
430
439
  return []
431
440
 
432
- # ----------------------
433
- # Minimum requred checks
434
- # ----------------------
441
+ # -----------------------
442
+ # Minimum required checks
443
+ # -----------------------
435
444
  @staticmethod
436
445
  def _check_column_names_for_empty_strings(
437
446
  schema: BaseSchema,
438
- ) -> List[err.InvalidColumnNameEmptyString]:
447
+ ) -> list[err.InvalidColumnNameEmptyString]:
439
448
  if "" in schema.get_used_columns():
440
449
  return [err.InvalidColumnNameEmptyString()]
441
450
  return []
442
451
 
443
452
  @staticmethod
444
453
  def _check_field_convertible_to_str(
445
- model_id, model_version, batch_id
446
- ) -> List[err.InvalidFieldTypeConversion]:
454
+ model_id: object, model_version: object, batch_id: object
455
+ ) -> list[err.InvalidFieldTypeConversion]:
447
456
  # converting to a set first makes the checks run a lot faster
448
457
  wrong_fields = []
449
458
  if model_id is not None and not isinstance(model_id, str):
@@ -469,7 +478,7 @@ class Validator:
469
478
  @staticmethod
470
479
  def _check_field_type_embedding_features_column_names(
471
480
  schema: Schema,
472
- ) -> List[err.InvalidFieldTypeEmbeddingFeatures]:
481
+ ) -> list[err.InvalidFieldTypeEmbeddingFeatures]:
473
482
  if schema.embedding_feature_column_names is not None:
474
483
  if not isinstance(schema.embedding_feature_column_names, dict):
475
484
  return [err.InvalidFieldTypeEmbeddingFeatures()]
@@ -483,7 +492,7 @@ class Validator:
483
492
  @staticmethod
484
493
  def _check_field_type_prompt_response(
485
494
  schema: Schema,
486
- ) -> List[err.InvalidFieldTypePromptResponse]:
495
+ ) -> list[err.InvalidFieldTypePromptResponse]:
487
496
  errors = []
488
497
  if schema.prompt_column_names is not None and not isinstance(
489
498
  schema.prompt_column_names, (str, EmbeddingColumnNames)
@@ -502,7 +511,7 @@ class Validator:
502
511
  @staticmethod
503
512
  def _check_field_type_prompt_templates(
504
513
  schema: Schema,
505
- ) -> List[err.InvalidFieldTypePromptTemplates]:
514
+ ) -> list[err.InvalidFieldTypePromptTemplates]:
506
515
  if schema.prompt_template_column_names is not None and not isinstance(
507
516
  schema.prompt_template_column_names, PromptTemplateColumnNames
508
517
  ):
@@ -513,7 +522,7 @@ class Validator:
513
522
  def _check_field_type_llm_config(
514
523
  dataframe: pd.DataFrame,
515
524
  schema: Schema,
516
- ) -> List[Union[err.InvalidFieldTypeLlmConfig, err.InvalidTypeColumns]]:
525
+ ) -> list[err.InvalidFieldTypeLlmConfig | err.InvalidTypeColumns]:
517
526
  if schema.llm_config_column_names is None:
518
527
  return []
519
528
  if not isinstance(schema.llm_config_column_names, LLMConfigColumnNames):
@@ -548,7 +557,7 @@ class Validator:
548
557
  @staticmethod
549
558
  def _check_invalid_index(
550
559
  dataframe: pd.DataFrame,
551
- ) -> List[err.InvalidDataFrameIndex]:
560
+ ) -> list[err.InvalidDataFrameIndex]:
552
561
  if (dataframe.index != dataframe.reset_index(drop=True).index).any():
553
562
  return [err.InvalidDataFrameIndex()]
554
563
  return []
@@ -560,9 +569,9 @@ class Validator:
560
569
  @staticmethod
561
570
  def _check_model_type_and_metrics(
562
571
  model_type: ModelTypes,
563
- metric_families: Optional[List[Metrics]],
572
+ metric_families: list[Metrics] | None,
564
573
  schema: Schema,
565
- ) -> List[err.ValidationError]:
574
+ ) -> list[err.ValidationError]:
566
575
  if metric_families is None:
567
576
  return []
568
577
 
@@ -606,10 +615,10 @@ class Validator:
606
615
  @staticmethod
607
616
  def _check_model_mapping_combinations(
608
617
  model_type: ModelTypes,
609
- metric_families: List[Metrics],
618
+ metric_families: list[Metrics],
610
619
  schema: Schema,
611
- required_columns_map: List[Dict[str, Any]],
612
- ) -> Tuple[bool, List[str], List[List[str]]]:
620
+ required_columns_map: list[dict[str, Any]],
621
+ ) -> tuple[bool, list[str], list[list[str]]]:
613
622
  missing_columns = []
614
623
  for item in required_columns_map:
615
624
  if model_type.name.lower() == item.get("external_model_type"):
@@ -625,10 +634,10 @@ class Validator:
625
634
  metric_combinations.append(
626
635
  [metric.upper() for metric in metrics_list]
627
636
  )
628
- if set(metrics_list) == set(
637
+ if set(metrics_list) == {
629
638
  metric_family.name.lower()
630
639
  for metric_family in metric_families
631
- ):
640
+ }:
632
641
  # This is a valid combination of model type + metrics.
633
642
  # Now validate that required columns are in the schema.
634
643
  is_valid_combination = True
@@ -665,10 +674,10 @@ class Validator:
665
674
  @staticmethod
666
675
  def _check_existence_prediction_id_column_delayed_schema(
667
676
  schema: Schema, model_type: ModelTypes
668
- ) -> List[err.MissingPredictionIdColumnForDelayedRecords]:
677
+ ) -> list[err.MissingPredictionIdColumnForDelayedRecords]:
669
678
  if schema.prediction_id_column_name is not None:
670
679
  return []
671
- # TODO: Revise logic once predicion_label column addition (for generative models)
680
+ # TODO: Revise logic once prediction_label column addition (for generative models)
672
681
  # is moved to beginning of log function
673
682
  if schema.is_delayed() and model_type is not ModelTypes.GENERATIVE_LLM:
674
683
  # We skip GENERATIVE model types since they are assigned a default
@@ -696,12 +705,12 @@ class Validator:
696
705
  def _check_missing_columns(
697
706
  dataframe: pd.DataFrame,
698
707
  schema: BaseSchema,
699
- ) -> List[err.MissingColumns]:
708
+ ) -> list[err.MissingColumns]:
700
709
  if isinstance(schema, CorpusSchema):
701
710
  return Validator._check_missing_columns_corpus_schema(
702
711
  dataframe, schema
703
712
  )
704
- elif isinstance(schema, Schema):
713
+ if isinstance(schema, Schema):
705
714
  return Validator._check_missing_columns_schema(dataframe, schema)
706
715
  return []
707
716
 
@@ -709,7 +718,7 @@ class Validator:
709
718
  def _check_missing_columns_schema(
710
719
  dataframe: pd.DataFrame,
711
720
  schema: Schema,
712
- ) -> List[err.MissingColumns]:
721
+ ) -> list[err.MissingColumns]:
713
722
  # converting to a set first makes the checks run a lot faster
714
723
  existing_columns = set(dataframe.columns)
715
724
  missing_columns = []
@@ -721,9 +730,13 @@ class Validator:
721
730
  missing_columns.append(col)
722
731
 
723
732
  if schema.feature_column_names is not None:
724
- for col in schema.feature_column_names:
725
- if col not in existing_columns:
726
- missing_columns.append(col)
733
+ missing_columns.extend(
734
+ [
735
+ col
736
+ for col in schema.feature_column_names
737
+ if col not in existing_columns
738
+ ]
739
+ )
727
740
 
728
741
  if schema.embedding_feature_column_names is not None:
729
742
  for (
@@ -752,44 +765,76 @@ class Validator:
752
765
  )
753
766
 
754
767
  if schema.tag_column_names is not None:
755
- for col in schema.tag_column_names:
756
- if col not in existing_columns:
757
- missing_columns.append(col)
768
+ missing_columns.extend(
769
+ [
770
+ col
771
+ for col in schema.tag_column_names
772
+ if col not in existing_columns
773
+ ]
774
+ )
758
775
 
759
776
  if schema.shap_values_column_names is not None:
760
- for col in schema.shap_values_column_names.values():
761
- if col not in existing_columns:
762
- missing_columns.append(col)
777
+ missing_columns.extend(
778
+ [
779
+ col
780
+ for col in schema.shap_values_column_names.values()
781
+ if col not in existing_columns
782
+ ]
783
+ )
763
784
 
764
785
  if schema.object_detection_prediction_column_names is not None:
765
- for col in schema.object_detection_prediction_column_names:
766
- if col is not None and col not in existing_columns:
767
- missing_columns.append(col)
786
+ missing_columns.extend(
787
+ [
788
+ col
789
+ for col in schema.object_detection_prediction_column_names
790
+ if col is not None and col not in existing_columns
791
+ ]
792
+ )
768
793
 
769
794
  if schema.object_detection_actual_column_names is not None:
770
- for col in schema.object_detection_actual_column_names:
771
- if col is not None and col not in existing_columns:
772
- missing_columns.append(col)
795
+ missing_columns.extend(
796
+ [
797
+ col
798
+ for col in schema.object_detection_actual_column_names
799
+ if col is not None and col not in existing_columns
800
+ ]
801
+ )
773
802
 
774
803
  if schema.semantic_segmentation_prediction_column_names is not None:
775
- for col in schema.semantic_segmentation_prediction_column_names:
776
- if col is not None and col not in existing_columns:
777
- missing_columns.append(col)
804
+ missing_columns.extend(
805
+ [
806
+ col
807
+ for col in schema.semantic_segmentation_prediction_column_names
808
+ if col is not None and col not in existing_columns
809
+ ]
810
+ )
778
811
 
779
812
  if schema.semantic_segmentation_actual_column_names is not None:
780
- for col in schema.semantic_segmentation_actual_column_names:
781
- if col is not None and col not in existing_columns:
782
- missing_columns.append(col)
813
+ missing_columns.extend(
814
+ [
815
+ col
816
+ for col in schema.semantic_segmentation_actual_column_names
817
+ if col is not None and col not in existing_columns
818
+ ]
819
+ )
783
820
 
784
821
  if schema.instance_segmentation_prediction_column_names is not None:
785
- for col in schema.instance_segmentation_prediction_column_names:
786
- if col is not None and col not in existing_columns:
787
- missing_columns.append(col)
822
+ missing_columns.extend(
823
+ [
824
+ col
825
+ for col in schema.instance_segmentation_prediction_column_names
826
+ if col is not None and col not in existing_columns
827
+ ]
828
+ )
788
829
 
789
830
  if schema.instance_segmentation_actual_column_names is not None:
790
- for col in schema.instance_segmentation_actual_column_names:
791
- if col is not None and col not in existing_columns:
792
- missing_columns.append(col)
831
+ missing_columns.extend(
832
+ [
833
+ col
834
+ for col in schema.instance_segmentation_actual_column_names
835
+ if col is not None and col not in existing_columns
836
+ ]
837
+ )
793
838
 
794
839
  if schema.prompt_column_names is not None:
795
840
  if isinstance(schema.prompt_column_names, str):
@@ -838,14 +883,22 @@ class Validator:
838
883
  )
839
884
 
840
885
  if schema.prompt_template_column_names is not None:
841
- for col in schema.prompt_template_column_names:
842
- if col is not None and col not in existing_columns:
843
- missing_columns.append(col)
886
+ missing_columns.extend(
887
+ [
888
+ col
889
+ for col in schema.prompt_template_column_names
890
+ if col is not None and col not in existing_columns
891
+ ]
892
+ )
844
893
 
845
894
  if schema.llm_config_column_names is not None:
846
- for col in schema.llm_config_column_names:
847
- if col is not None and col not in existing_columns:
848
- missing_columns.append(col)
895
+ missing_columns.extend(
896
+ [
897
+ col
898
+ for col in schema.llm_config_column_names
899
+ if col is not None and col not in existing_columns
900
+ ]
901
+ )
849
902
 
850
903
  if missing_columns:
851
904
  return [err.MissingColumns(missing_columns)]
@@ -855,7 +908,7 @@ class Validator:
855
908
  def _check_missing_columns_corpus_schema(
856
909
  dataframe: pd.DataFrame,
857
910
  schema: CorpusSchema,
858
- ) -> List[err.MissingColumns]:
911
+ ) -> list[err.MissingColumns]:
859
912
  # converting to a set first makes the checks run a lot faster
860
913
  existing_columns = set(dataframe.columns)
861
914
  missing_columns = []
@@ -912,7 +965,7 @@ class Validator:
912
965
  def _check_valid_schema_type(
913
966
  schema: BaseSchema,
914
967
  environment: Environments,
915
- ) -> List[err.InvalidSchemaType]:
968
+ ) -> list[err.InvalidSchemaType]:
916
969
  if environment == Environments.CORPUS and not (
917
970
  isinstance(schema, CorpusSchema)
918
971
  ):
@@ -934,7 +987,7 @@ class Validator:
934
987
  @staticmethod
935
988
  def _check_invalid_shap_suffix(
936
989
  schema: Schema,
937
- ) -> List[err.InvalidShapSuffix]:
990
+ ) -> list[err.InvalidShapSuffix]:
938
991
  invalid_column_names = set()
939
992
 
940
993
  if schema.feature_column_names is not None:
@@ -970,10 +1023,10 @@ class Validator:
970
1023
  def _check_reserved_columns(
971
1024
  schema: BaseSchema,
972
1025
  model_type: ModelTypes,
973
- ) -> List[err.ReservedColumns]:
1026
+ ) -> list[err.ReservedColumns]:
974
1027
  if isinstance(schema, CorpusSchema):
975
1028
  return []
976
- elif isinstance(schema, Schema):
1029
+ if isinstance(schema, Schema):
977
1030
  reserved_columns = []
978
1031
  column_counts = schema.get_used_columns_counts()
979
1032
  if model_type == ModelTypes.GENERATIVE_LLM:
@@ -1079,8 +1132,8 @@ class Validator:
1079
1132
 
1080
1133
  @staticmethod
1081
1134
  def _check_invalid_model_id(
1082
- model_id: Optional[str],
1083
- ) -> List[err.InvalidModelId]:
1135
+ model_id: str | None,
1136
+ ) -> list[err.InvalidModelId]:
1084
1137
  # assume it's been coerced to string beforehand
1085
1138
  if (not isinstance(model_id, str)) or len(model_id.strip()) == 0:
1086
1139
  return [err.InvalidModelId()]
@@ -1088,8 +1141,8 @@ class Validator:
1088
1141
 
1089
1142
  @staticmethod
1090
1143
  def _check_invalid_model_version(
1091
- model_version: Optional[str] = None,
1092
- ) -> List[err.InvalidModelVersion]:
1144
+ model_version: str | None = None,
1145
+ ) -> list[err.InvalidModelVersion]:
1093
1146
  if model_version is None:
1094
1147
  return []
1095
1148
  if (
@@ -1102,9 +1155,9 @@ class Validator:
1102
1155
 
1103
1156
  @staticmethod
1104
1157
  def _check_invalid_batch_id(
1105
- batch_id: Optional[str],
1158
+ batch_id: str | None,
1106
1159
  environment: Environments,
1107
- ) -> List[err.InvalidBatchId]:
1160
+ ) -> list[err.InvalidBatchId]:
1108
1161
  # assume it's been coerced to string beforehand
1109
1162
  if environment in (Environments.VALIDATION,) and (
1110
1163
  (not isinstance(batch_id, str)) or len(batch_id.strip()) == 0
@@ -1115,7 +1168,7 @@ class Validator:
1115
1168
  @staticmethod
1116
1169
  def _check_invalid_model_type(
1117
1170
  model_type: ModelTypes,
1118
- ) -> List[err.InvalidModelType]:
1171
+ ) -> list[err.InvalidModelType]:
1119
1172
  if model_type in (mt for mt in ModelTypes):
1120
1173
  return []
1121
1174
  return [err.InvalidModelType()]
@@ -1123,7 +1176,7 @@ class Validator:
1123
1176
  @staticmethod
1124
1177
  def _check_invalid_environment(
1125
1178
  environment: Environments,
1126
- ) -> List[err.InvalidEnvironment]:
1179
+ ) -> list[err.InvalidEnvironment]:
1127
1180
  if environment in (env for env in Environments):
1128
1181
  return []
1129
1182
  return [err.InvalidEnvironment()]
@@ -1132,7 +1185,7 @@ class Validator:
1132
1185
  def _check_existence_preprod_pred_act_score_or_label(
1133
1186
  schema: Schema,
1134
1187
  environment: Environments,
1135
- ) -> List[err.MissingPreprodPredActNumericAndCategorical]:
1188
+ ) -> list[err.MissingPreprodPredActNumericAndCategorical]:
1136
1189
  if environment in (Environments.VALIDATION, Environments.TRAINING) and (
1137
1190
  (
1138
1191
  schema.prediction_label_column_name is None
@@ -1149,7 +1202,7 @@ class Validator:
1149
1202
  @staticmethod
1150
1203
  def _check_exactly_one_cv_column_type(
1151
1204
  schema: Schema, environment: Environments
1152
- ) -> List[Union[err.MultipleCVPredAct, err.MissingCVPredAct]]:
1205
+ ) -> list[err.MultipleCVPredAct | err.MissingCVPredAct]:
1153
1206
  # Checks that the required prediction/actual columns are given in the schema depending on
1154
1207
  # the environment, for object detection models. There should be exactly one of
1155
1208
  # object detection, semantic segmentation, or instance segmentation columns.
@@ -1180,7 +1233,7 @@ class Validator:
1180
1233
 
1181
1234
  if cv_types_count == 0:
1182
1235
  return [err.MissingCVPredAct(environment)]
1183
- elif cv_types_count > 1:
1236
+ if cv_types_count > 1:
1184
1237
  return [err.MultipleCVPredAct(environment)]
1185
1238
 
1186
1239
  elif environment in (
@@ -1213,7 +1266,7 @@ class Validator:
1213
1266
 
1214
1267
  if cv_types_count == 0:
1215
1268
  return [err.MissingCVPredAct(environment)]
1216
- elif cv_types_count > 1:
1269
+ if cv_types_count > 1:
1217
1270
  return [err.MultipleCVPredAct(environment)]
1218
1271
 
1219
1272
  return []
@@ -1221,9 +1274,9 @@ class Validator:
1221
1274
  @staticmethod
1222
1275
  def _check_missing_object_detection_columns(
1223
1276
  schema: Schema, model_type: ModelTypes
1224
- ) -> List[err.InvalidPredActCVColumnNamesForModelType]:
1277
+ ) -> list[err.InvalidPredActCVColumnNamesForModelType]:
1225
1278
  # Checks that models that are not Object Detection models don't have, in the schema, the
1226
- # object detection, semantic segmentation, or instance segmentation dedicated prediciton/actual
1279
+ # object detection, semantic segmentation, or instance segmentation dedicated prediction/actual
1227
1280
  # column names
1228
1281
  if (
1229
1282
  schema.object_detection_prediction_column_names is not None
@@ -1239,7 +1292,7 @@ class Validator:
1239
1292
  @staticmethod
1240
1293
  def _check_missing_non_object_detection_columns(
1241
1294
  schema: Schema, model_type: ModelTypes
1242
- ) -> List[err.InvalidPredActColumnNamesForModelType]:
1295
+ ) -> list[err.InvalidPredActColumnNamesForModelType]:
1243
1296
  # Checks that object detection models don't have, in the schema, the columns reserved for
1244
1297
  # other model types
1245
1298
  columns_to_check = (
@@ -1253,10 +1306,7 @@ class Validator:
1253
1306
  schema.relevance_score_column_name,
1254
1307
  schema.relevance_labels_column_name,
1255
1308
  )
1256
- wrong_cols = []
1257
- for col in columns_to_check:
1258
- if col is not None:
1259
- wrong_cols.append(col)
1309
+ wrong_cols = [col for col in columns_to_check if col is not None]
1260
1310
  if wrong_cols:
1261
1311
  allowed_cols = [
1262
1312
  "object_detection_prediction_column_names",
@@ -1276,7 +1326,7 @@ class Validator:
1276
1326
  @staticmethod
1277
1327
  def _check_missing_multi_class_columns(
1278
1328
  schema: Schema, model_type: ModelTypes
1279
- ) -> List[err.InvalidPredActColumnNamesForModelType]:
1329
+ ) -> list[err.InvalidPredActColumnNamesForModelType]:
1280
1330
  # Checks that models that are not Multi Class models don't have, in the schema, the
1281
1331
  # multi class dedicated threshold column
1282
1332
  if (
@@ -1295,7 +1345,7 @@ class Validator:
1295
1345
  @staticmethod
1296
1346
  def _check_existing_multi_class_columns(
1297
1347
  schema: Schema,
1298
- ) -> List[err.MissingReqPredActColumnNamesForMultiClass]:
1348
+ ) -> list[err.MissingReqPredActColumnNamesForMultiClass]:
1299
1349
  # Checks that models that are Multi Class models have, in the schema, the
1300
1350
  # required prediction score or actual score columns
1301
1351
  if (
@@ -1311,7 +1361,7 @@ class Validator:
1311
1361
  @staticmethod
1312
1362
  def _check_missing_non_multi_class_columns(
1313
1363
  schema: Schema, model_type: ModelTypes
1314
- ) -> List[err.InvalidPredActColumnNamesForModelType]:
1364
+ ) -> list[err.InvalidPredActColumnNamesForModelType]:
1315
1365
  # Checks that multi class models don't have, in the schema, the columns reserved for
1316
1366
  # other model types
1317
1367
  columns_to_check = (
@@ -1329,10 +1379,7 @@ class Validator:
1329
1379
  schema.instance_segmentation_prediction_column_names,
1330
1380
  schema.instance_segmentation_actual_column_names,
1331
1381
  )
1332
- wrong_cols = []
1333
- for col in columns_to_check:
1334
- if col is not None:
1335
- wrong_cols.append(col)
1382
+ wrong_cols = [col for col in columns_to_check if col is not None]
1336
1383
  if wrong_cols:
1337
1384
  allowed_cols = [
1338
1385
  "prediction_score_column_name",
@@ -1350,7 +1397,7 @@ class Validator:
1350
1397
  def _check_existence_preprod_act(
1351
1398
  schema: Schema,
1352
1399
  environment: Environments,
1353
- ) -> List[err.MissingPreprodAct]:
1400
+ ) -> list[err.MissingPreprodAct]:
1354
1401
  if environment in (Environments.VALIDATION, Environments.TRAINING) and (
1355
1402
  schema.actual_label_column_name is None
1356
1403
  ):
@@ -1360,7 +1407,7 @@ class Validator:
1360
1407
  @staticmethod
1361
1408
  def _check_existence_group_id_rank_category_relevance(
1362
1409
  schema: Schema,
1363
- ) -> List[err.MissingRequiredColumnsForRankingModel]:
1410
+ ) -> list[err.MissingRequiredColumnsForRankingModel]:
1364
1411
  # prediction_group_id and rank columns are required as ranking prediction columns.
1365
1412
  ranking_prediction_cols = (
1366
1413
  schema.prediction_label_column_name,
@@ -1384,7 +1431,7 @@ class Validator:
1384
1431
  @staticmethod
1385
1432
  def _check_dataframe_for_duplicate_columns(
1386
1433
  schema: BaseSchema, dataframe: pd.DataFrame
1387
- ) -> List[err.DuplicateColumnsInDataframe]:
1434
+ ) -> list[err.DuplicateColumnsInDataframe]:
1388
1435
  # Get the columns used in the schema
1389
1436
  schema_col_used = schema.get_used_columns()
1390
1437
  # Get the duplicated column names from the dataframe
@@ -1400,7 +1447,7 @@ class Validator:
1400
1447
  @staticmethod
1401
1448
  def _check_invalid_number_of_embeddings(
1402
1449
  schema: Schema,
1403
- ) -> List[err.InvalidNumberOfEmbeddings]:
1450
+ ) -> list[err.InvalidNumberOfEmbeddings]:
1404
1451
  if schema.embedding_feature_column_names is not None:
1405
1452
  number_of_embeddings = len(schema.embedding_feature_column_names)
1406
1453
  if number_of_embeddings > MAX_NUMBER_OF_EMBEDDINGS:
@@ -1413,8 +1460,8 @@ class Validator:
1413
1460
 
1414
1461
  @staticmethod
1415
1462
  def _check_type_prediction_id(
1416
- schema: Schema, column_types: Dict[str, Any]
1417
- ) -> List[err.InvalidType]:
1463
+ schema: Schema, column_types: dict[str, Any]
1464
+ ) -> list[err.InvalidType]:
1418
1465
  col = schema.prediction_id_column_name
1419
1466
  if col in column_types:
1420
1467
  # should mirror server side
@@ -1437,8 +1484,8 @@ class Validator:
1437
1484
 
1438
1485
  @staticmethod
1439
1486
  def _check_type_timestamp(
1440
- schema: Schema, column_types: Dict[str, Any]
1441
- ) -> List[err.InvalidType]:
1487
+ schema: Schema, column_types: dict[str, Any]
1488
+ ) -> list[err.InvalidType]:
1442
1489
  col = schema.timestamp_column_name
1443
1490
  if col in column_types:
1444
1491
  # should mirror server side
@@ -1464,8 +1511,8 @@ class Validator:
1464
1511
 
1465
1512
  @staticmethod
1466
1513
  def _check_type_features(
1467
- schema: Schema, column_types: Dict[str, Any]
1468
- ) -> List[err.InvalidTypeFeatures]:
1514
+ schema: Schema, column_types: dict[str, Any]
1515
+ ) -> list[err.InvalidTypeFeatures]:
1469
1516
  if schema.feature_column_names is not None:
1470
1517
  # should mirror server side
1471
1518
  allowed_datatypes = (
@@ -1480,13 +1527,12 @@ class Validator:
1480
1527
  pa.null(),
1481
1528
  pa.list_(pa.string()),
1482
1529
  )
1483
- wrong_type_cols = []
1484
- for col in schema.feature_column_names:
1485
- if (
1486
- col in column_types
1487
- and column_types[col] not in allowed_datatypes
1488
- ):
1489
- wrong_type_cols.append(col)
1530
+ wrong_type_cols = [
1531
+ col
1532
+ for col in schema.feature_column_names
1533
+ if col in column_types
1534
+ and column_types[col] not in allowed_datatypes
1535
+ ]
1490
1536
  if wrong_type_cols:
1491
1537
  return [
1492
1538
  err.InvalidTypeFeatures(
@@ -1504,8 +1550,8 @@ class Validator:
1504
1550
 
1505
1551
  @staticmethod
1506
1552
  def _check_type_embedding_features(
1507
- schema: Schema, column_types: Dict[str, Any]
1508
- ) -> List[err.InvalidTypeFeatures]:
1553
+ schema: Schema, column_types: dict[str, Any]
1554
+ ) -> list[err.InvalidTypeFeatures]:
1509
1555
  if schema.embedding_feature_column_names is not None:
1510
1556
  # should mirror server side
1511
1557
  allowed_vector_datatypes = (
@@ -1580,8 +1626,8 @@ class Validator:
1580
1626
 
1581
1627
  @staticmethod
1582
1628
  def _check_type_tags(
1583
- schema: Schema, column_types: Dict[str, Any]
1584
- ) -> List[err.InvalidTypeTags]:
1629
+ schema: Schema, column_types: dict[str, Any]
1630
+ ) -> list[err.InvalidTypeTags]:
1585
1631
  if schema.tag_column_names is not None:
1586
1632
  # should mirror server side
1587
1633
  allowed_datatypes = (
@@ -1595,13 +1641,12 @@ class Validator:
1595
1641
  pa.int8(),
1596
1642
  pa.null(),
1597
1643
  )
1598
- wrong_type_cols = []
1599
- for col in schema.tag_column_names:
1600
- if (
1601
- col in column_types
1602
- and column_types[col] not in allowed_datatypes
1603
- ):
1604
- wrong_type_cols.append(col)
1644
+ wrong_type_cols = [
1645
+ col
1646
+ for col in schema.tag_column_names
1647
+ if col in column_types
1648
+ and column_types[col] not in allowed_datatypes
1649
+ ]
1605
1650
  if wrong_type_cols:
1606
1651
  return [
1607
1652
  err.InvalidTypeTags(
@@ -1612,8 +1657,8 @@ class Validator:
1612
1657
 
1613
1658
  @staticmethod
1614
1659
  def _check_type_shap_values(
1615
- schema: Schema, column_types: Dict[str, Any]
1616
- ) -> List[err.InvalidTypeShapValues]:
1660
+ schema: Schema, column_types: dict[str, Any]
1661
+ ) -> list[err.InvalidTypeShapValues]:
1617
1662
  if schema.shap_values_column_names is not None:
1618
1663
  # should mirror server side
1619
1664
  allowed_datatypes = (
@@ -1622,13 +1667,12 @@ class Validator:
1622
1667
  pa.float32(),
1623
1668
  pa.int32(),
1624
1669
  )
1625
- wrong_type_cols = []
1626
- for _, col in schema.shap_values_column_names.items():
1627
- if (
1628
- col in column_types
1629
- and column_types[col] not in allowed_datatypes
1630
- ):
1631
- wrong_type_cols.append(col)
1670
+ wrong_type_cols = [
1671
+ col
1672
+ for col in schema.shap_values_column_names.values()
1673
+ if col in column_types
1674
+ and column_types[col] not in allowed_datatypes
1675
+ ]
1632
1676
  if wrong_type_cols:
1633
1677
  return [
1634
1678
  err.InvalidTypeShapValues(
@@ -1639,8 +1683,8 @@ class Validator:
1639
1683
 
1640
1684
  @staticmethod
1641
1685
  def _check_type_pred_act_labels(
1642
- model_type: ModelTypes, schema: Schema, column_types: Dict[str, Any]
1643
- ) -> List[err.InvalidType]:
1686
+ model_type: ModelTypes, schema: Schema, column_types: dict[str, Any]
1687
+ ) -> list[err.InvalidType]:
1644
1688
  errors = []
1645
1689
  columns = (
1646
1690
  ("Prediction labels", schema.prediction_label_column_name),
@@ -1703,8 +1747,8 @@ class Validator:
1703
1747
 
1704
1748
  @staticmethod
1705
1749
  def _check_type_pred_act_scores(
1706
- model_type: ModelTypes, schema: Schema, column_types: Dict[str, Any]
1707
- ) -> List[err.InvalidType]:
1750
+ model_type: ModelTypes, schema: Schema, column_types: dict[str, Any]
1751
+ ) -> list[err.InvalidType]:
1708
1752
  errors = []
1709
1753
  columns = (
1710
1754
  ("Prediction scores", schema.prediction_score_column_name),
@@ -1743,13 +1787,14 @@ class Validator:
1743
1787
 
1744
1788
  @staticmethod
1745
1789
  def _check_type_multi_class_pred_threshold_act_scores(
1746
- schema: Schema, column_types: Dict[str, Any]
1747
- ) -> List[err.InvalidType]:
1748
- """
1749
- Check type for prediction / threshold / actual scores for multiclass model
1750
- Expect the scores to be a list of pyarrow structs that contains field "class_name" and field "score
1751
- Where class_name is a string and score is a number
1752
- Example: '[{"class_name": "class1", "score": 0.1}, {"class_name": "class2", "score": 0.2}, ...]'
1790
+ schema: Schema, column_types: dict[str, Any]
1791
+ ) -> list[err.InvalidType]:
1792
+ """Check type for prediction / threshold / actual scores for multiclass model.
1793
+
1794
+ Expect the scores to be a list of pyarrow structs that contains field
1795
+ "class_name" and field "score", where class_name is a string and score
1796
+ is a number.
1797
+ Example: '[{"class_name": "class1", "score": 0.1}, ...]'
1753
1798
  """
1754
1799
  errors = []
1755
1800
  columns = (
@@ -1802,8 +1847,8 @@ class Validator:
1802
1847
 
1803
1848
  @staticmethod
1804
1849
  def _check_type_prompt_response(
1805
- schema: Schema, column_types: Dict[str, Any]
1806
- ) -> List[err.InvalidTypeColumns]:
1850
+ schema: Schema, column_types: dict[str, Any]
1851
+ ) -> list[err.InvalidTypeColumns]:
1807
1852
  fields_to_check = []
1808
1853
  if schema.prompt_column_names is not None:
1809
1854
  fields_to_check.append(schema.prompt_column_names)
@@ -1872,8 +1917,8 @@ class Validator:
1872
1917
 
1873
1918
  @staticmethod
1874
1919
  def _check_type_llm_prompt_templates(
1875
- schema: Schema, column_types: Dict[str, Any]
1876
- ) -> List[err.InvalidTypeColumns]:
1920
+ schema: Schema, column_types: dict[str, Any]
1921
+ ) -> list[err.InvalidTypeColumns]:
1877
1922
  if schema.prompt_template_column_names is None:
1878
1923
  return []
1879
1924
 
@@ -1913,8 +1958,8 @@ class Validator:
1913
1958
 
1914
1959
  @staticmethod
1915
1960
  def _check_type_llm_config(
1916
- schema: Schema, column_types: Dict[str, Any]
1917
- ) -> List[err.InvalidTypeColumns]:
1961
+ schema: Schema, column_types: dict[str, Any]
1962
+ ) -> list[err.InvalidTypeColumns]:
1918
1963
  if schema.llm_config_column_names is None:
1919
1964
  return []
1920
1965
 
@@ -1950,8 +1995,8 @@ class Validator:
1950
1995
 
1951
1996
  @staticmethod
1952
1997
  def _check_type_llm_run_metadata(
1953
- schema: Schema, column_types: Dict[str, Any]
1954
- ) -> List[err.InvalidTypeColumns]:
1998
+ schema: Schema, column_types: dict[str, Any]
1999
+ ) -> list[err.InvalidTypeColumns]:
1955
2000
  if schema.llm_run_metadata_column_names is None:
1956
2001
  return []
1957
2002
 
@@ -2023,8 +2068,8 @@ class Validator:
2023
2068
 
2024
2069
  @staticmethod
2025
2070
  def _check_type_retrieved_document_ids(
2026
- schema: Schema, column_types: Dict[str, Any]
2027
- ) -> List[err.InvalidType]:
2071
+ schema: Schema, column_types: dict[str, Any]
2072
+ ) -> list[err.InvalidType]:
2028
2073
  col = schema.retrieved_document_ids_column_name
2029
2074
  if col in column_types:
2030
2075
  # should mirror server side
@@ -2044,8 +2089,8 @@ class Validator:
2044
2089
 
2045
2090
  @staticmethod
2046
2091
  def _check_type_image_segment_coordinates(
2047
- schema: Schema, column_types: Dict[str, Any]
2048
- ) -> List[err.InvalidTypeColumns]:
2092
+ schema: Schema, column_types: dict[str, Any]
2093
+ ) -> list[err.InvalidTypeColumns]:
2049
2094
  # should mirror server side
2050
2095
  allowed_coordinate_types = (
2051
2096
  pa.list_(pa.list_(pa.float64())),
@@ -2090,9 +2135,8 @@ class Validator:
2090
2135
  wrong_type_cols.append(coord_col)
2091
2136
 
2092
2137
  if schema.instance_segmentation_prediction_column_names is not None:
2093
- polygons_coord_col = (
2094
- schema.instance_segmentation_prediction_column_names.polygon_coordinates_column_name # noqa: E501
2095
- )
2138
+ inst_seg_pred = schema.instance_segmentation_prediction_column_names
2139
+ polygons_coord_col = inst_seg_pred.polygon_coordinates_column_name
2096
2140
  if (
2097
2141
  polygons_coord_col in column_types
2098
2142
  and column_types[polygons_coord_col]
@@ -2101,7 +2145,7 @@ class Validator:
2101
2145
  wrong_type_cols.append(polygons_coord_col)
2102
2146
 
2103
2147
  bbox_coord_col = (
2104
- schema.instance_segmentation_prediction_column_names.bounding_boxes_coordinates_column_name # noqa: E501
2148
+ inst_seg_pred.bounding_boxes_coordinates_column_name
2105
2149
  )
2106
2150
  if (
2107
2151
  bbox_coord_col in column_types
@@ -2110,9 +2154,8 @@ class Validator:
2110
2154
  wrong_type_cols.append(bbox_coord_col)
2111
2155
 
2112
2156
  if schema.instance_segmentation_actual_column_names is not None:
2113
- coord_col = (
2114
- schema.instance_segmentation_actual_column_names.polygon_coordinates_column_name # noqa: E501
2115
- )
2157
+ inst_seg_actual = schema.instance_segmentation_actual_column_names
2158
+ coord_col = inst_seg_actual.polygon_coordinates_column_name
2116
2159
  if (
2117
2160
  coord_col in column_types
2118
2161
  and column_types[coord_col] not in allowed_coordinate_types
@@ -2120,7 +2163,7 @@ class Validator:
2120
2163
  wrong_type_cols.append(coord_col)
2121
2164
 
2122
2165
  bbox_coord_col = (
2123
- schema.instance_segmentation_actual_column_names.bounding_boxes_coordinates_column_name # noqa: E501
2166
+ inst_seg_actual.bounding_boxes_coordinates_column_name
2124
2167
  )
2125
2168
  if (
2126
2169
  bbox_coord_col in column_types
@@ -2141,8 +2184,8 @@ class Validator:
2141
2184
 
2142
2185
  @staticmethod
2143
2186
  def _check_type_image_segment_categories(
2144
- schema: Schema, column_types: Dict[str, Any]
2145
- ) -> List[err.InvalidTypeColumns]:
2187
+ schema: Schema, column_types: dict[str, Any]
2188
+ ) -> list[err.InvalidTypeColumns]:
2146
2189
  # should mirror server side
2147
2190
  allowed_category_datatypes = (
2148
2191
  pa.list_(pa.string()),
@@ -2210,8 +2253,8 @@ class Validator:
2210
2253
 
2211
2254
  @staticmethod
2212
2255
  def _check_type_image_segment_scores(
2213
- schema: Schema, column_types: Dict[str, Any]
2214
- ) -> List[err.InvalidTypeColumns]:
2256
+ schema: Schema, column_types: dict[str, Any]
2257
+ ) -> list[err.InvalidTypeColumns]:
2215
2258
  # should mirror server side
2216
2259
  allowed_score_datatypes = (
2217
2260
  pa.list_(pa.float64()),
@@ -2270,7 +2313,7 @@ class Validator:
2270
2313
  @staticmethod
2271
2314
  def _check_embedding_vectors_dimensionality(
2272
2315
  dataframe: pd.DataFrame, schema: Schema
2273
- ) -> List[err.ValidationError]:
2316
+ ) -> list[err.ValidationError]:
2274
2317
  if schema.embedding_feature_column_names is None:
2275
2318
  return []
2276
2319
 
@@ -2300,7 +2343,7 @@ class Validator:
2300
2343
  @staticmethod
2301
2344
  def _check_embedding_raw_data_characters(
2302
2345
  dataframe: pd.DataFrame, schema: Schema
2303
- ) -> List[err.ValidationError]:
2346
+ ) -> list[err.ValidationError]:
2304
2347
  if schema.embedding_feature_column_names is None:
2305
2348
  return []
2306
2349
 
@@ -2322,7 +2365,7 @@ class Validator:
2322
2365
  invalid_long_string_data_cols
2323
2366
  )
2324
2367
  ]
2325
- elif truncated_long_string_data_cols:
2368
+ if truncated_long_string_data_cols:
2326
2369
  logger.warning(
2327
2370
  get_truncation_warning_message(
2328
2371
  "Embedding raw data fields",
@@ -2334,7 +2377,7 @@ class Validator:
2334
2377
  @staticmethod
2335
2378
  def _check_value_rank(
2336
2379
  dataframe: pd.DataFrame, schema: Schema
2337
- ) -> List[err.InvalidRankValue]:
2380
+ ) -> list[err.InvalidRankValue]:
2338
2381
  col = schema.rank_column_name
2339
2382
  lbound, ubound = (1, 100)
2340
2383
 
@@ -2346,11 +2389,11 @@ class Validator:
2346
2389
 
2347
2390
  @staticmethod
2348
2391
  def _check_id_field_str_length(
2349
- dataframe: pd.DataFrame, schema_name: str, id_col_name: Optional[str]
2350
- ) -> List[err.ValidationError]:
2351
- """
2352
- Require prediction_id to be a string of length between MIN_PREDICTION_ID_LEN
2353
- and MAX_PREDICTION_ID_LEN
2392
+ dataframe: pd.DataFrame, schema_name: str, id_col_name: str | None
2393
+ ) -> list[err.ValidationError]:
2394
+ """Require prediction_id to be a string of length between MIN and MAX.
2395
+
2396
+ Between MIN_PREDICTION_ID_LEN and MAX_PREDICTION_ID_LEN.
2354
2397
  """
2355
2398
  # We check whether the column name can be None is allowed in `Validator.validate_params`
2356
2399
  if id_col_name is None:
@@ -2380,11 +2423,11 @@ class Validator:
2380
2423
 
2381
2424
  @staticmethod
2382
2425
  def _check_document_id_field_str_length(
2383
- dataframe: pd.DataFrame, schema_name: str, id_col_name: Optional[str]
2384
- ) -> List[err.ValidationError]:
2385
- """
2386
- Require document id to be a string of length between MIN_DOCUMENT_ID_LEN
2387
- and MAX_DOCUMENT_ID_LEN
2426
+ dataframe: pd.DataFrame, schema_name: str, id_col_name: str | None
2427
+ ) -> list[err.ValidationError]:
2428
+ """Require document id to be a string of length between MIN and MAX.
2429
+
2430
+ Between MIN_DOCUMENT_ID_LEN and MAX_DOCUMENT_ID_LEN.
2388
2431
  """
2389
2432
  # We check whether the column name can be None is allowed in `Validator.validate_params`
2390
2433
  if id_col_name is None:
@@ -2433,7 +2476,7 @@ class Validator:
2433
2476
  @staticmethod
2434
2477
  def _check_value_tag(
2435
2478
  dataframe: pd.DataFrame, schema: Schema
2436
- ) -> List[err.InvalidTagLength]:
2479
+ ) -> list[err.InvalidTagLength]:
2437
2480
  if schema.tag_column_names is None:
2438
2481
  return []
2439
2482
 
@@ -2459,7 +2502,7 @@ class Validator:
2459
2502
  truncated_tag_cols.append(col)
2460
2503
  if wrong_tag_cols:
2461
2504
  return [err.InvalidTagLength(wrong_tag_cols)]
2462
- elif truncated_tag_cols:
2505
+ if truncated_tag_cols:
2463
2506
  logger.warning(
2464
2507
  get_truncation_warning_message(
2465
2508
  "tags", MAX_TAG_LENGTH_TRUNCATION
@@ -2470,9 +2513,7 @@ class Validator:
2470
2513
  @staticmethod
2471
2514
  def _check_value_ranking_category(
2472
2515
  dataframe: pd.DataFrame, schema: Schema
2473
- ) -> List[
2474
- Union[err.InvalidValueMissingValue, err.InvalidRankingCategoryValue]
2475
- ]:
2516
+ ) -> list[err.InvalidValueMissingValue | err.InvalidRankingCategoryValue]:
2476
2517
  if schema.relevance_labels_column_name is not None:
2477
2518
  col = schema.relevance_labels_column_name
2478
2519
  elif schema.attributions_column_name is not None:
@@ -2503,7 +2544,7 @@ class Validator:
2503
2544
  @staticmethod
2504
2545
  def _check_length_multi_class_maps(
2505
2546
  dataframe: pd.DataFrame, schema: Schema
2506
- ) -> List[err.InvalidNumClassesMultiClassMap]:
2547
+ ) -> list[err.InvalidNumClassesMultiClassMap]:
2507
2548
  # each entry in column is a list of dictionaries mapping class names and scores
2508
2549
  # validate length of list of dictionaries for each column
2509
2550
  invalid_cols = {}
@@ -2540,15 +2581,13 @@ class Validator:
2540
2581
  @staticmethod
2541
2582
  def _check_classes_and_scores_values_in_multi_class_maps(
2542
2583
  dataframe: pd.DataFrame, schema: Schema
2543
- ) -> List[
2544
- Union[
2545
- err.InvalidMultiClassClassNameLength,
2546
- err.InvalidMultiClassActScoreValue,
2547
- err.InvalidMultiClassPredScoreValue,
2548
- ]
2584
+ ) -> list[
2585
+ err.InvalidMultiClassClassNameLength
2586
+ | err.InvalidMultiClassActScoreValue
2587
+ | err.InvalidMultiClassPredScoreValue
2549
2588
  ]:
2550
- """
2551
- Validate the class names and score values of dictionaries:
2589
+ """Validate the class names and score values of dictionaries.
2590
+
2552
2591
  - class name length
2553
2592
  - valid actual score
2554
2593
  - valid prediction / threshold score
@@ -2624,11 +2663,12 @@ class Validator:
2624
2663
  @staticmethod
2625
2664
  def _check_each_multi_class_pred_has_threshold(
2626
2665
  dataframe: pd.DataFrame, schema: Schema
2627
- ) -> List[err.InvalidMultiClassThresholdClasses]:
2628
- """
2629
- For Multi Class, if threshold scores col is included in schema and dataframe,
2630
- validate for each prediction score received, the associated threshold score
2631
- for that class was received
2666
+ ) -> list[err.InvalidMultiClassThresholdClasses]:
2667
+ """Validate threshold scores for Multi Class models.
2668
+
2669
+ If threshold scores column is included in schema and dataframe, validate that
2670
+ for each prediction score received, the associated threshold score for that
2671
+ class was also received.
2632
2672
  """
2633
2673
  threshold_col = schema.multi_class_threshold_scores_column_name
2634
2674
  if threshold_col is None:
@@ -2657,10 +2697,10 @@ class Validator:
2657
2697
  def _check_value_timestamp(
2658
2698
  dataframe: pd.DataFrame,
2659
2699
  schema: Schema,
2660
- ) -> List[Union[err.InvalidValueMissingValue, err.InvalidValueTimestamp]]:
2700
+ ) -> list[err.InvalidValueMissingValue | err.InvalidValueTimestamp]:
2661
2701
  # Due to the timing difference between checking this here and the data finally
2662
2702
  # hitting the same check on server side, there's a some chance for a false
2663
- # result, i.e. the check here suceeeds but the same check on server side fails.
2703
+ # result, i.e. the check here succeeds but the same check on server side fails.
2664
2704
  col = schema.timestamp_column_name
2665
2705
  if col is not None and col in dataframe.columns:
2666
2706
  # When a timestamp column has Date and NaN, pyarrow will be fine, but
@@ -2673,19 +2713,15 @@ class Validator:
2673
2713
  )
2674
2714
  ]
2675
2715
 
2676
- now_t = datetime.datetime.now()
2716
+ now_t = datetime.now(tz=timezone.utc)
2677
2717
  lbound, ubound = (
2678
2718
  (
2679
2719
  now_t
2680
- - datetime.timedelta(
2681
- days=MAX_PAST_YEARS_FROM_CURRENT_TIME * 365
2682
- )
2720
+ - timedelta(days=MAX_PAST_YEARS_FROM_CURRENT_TIME * 365)
2683
2721
  ).timestamp(),
2684
2722
  (
2685
2723
  now_t
2686
- + datetime.timedelta(
2687
- days=MAX_FUTURE_YEARS_FROM_CURRENT_TIME * 365
2688
- )
2724
+ + timedelta(days=MAX_FUTURE_YEARS_FROM_CURRENT_TIME * 365)
2689
2725
  ).timestamp(),
2690
2726
  )
2691
2727
  # faster than pyarrow compute
@@ -2767,7 +2803,7 @@ class Validator:
2767
2803
  @staticmethod
2768
2804
  def _check_invalid_missing_values(
2769
2805
  dataframe: pd.DataFrame, schema: BaseSchema, model_type: ModelTypes
2770
- ) -> List[err.InvalidValueMissingValue]:
2806
+ ) -> list[err.InvalidValueMissingValue]:
2771
2807
  errors = []
2772
2808
  columns = ()
2773
2809
  if isinstance(schema, CorpusSchema):
@@ -2814,7 +2850,7 @@ class Validator:
2814
2850
  environment: Environments,
2815
2851
  schema: Schema,
2816
2852
  model_type: ModelTypes,
2817
- ) -> List[err.InvalidRecord]:
2853
+ ) -> list[err.InvalidRecord]:
2818
2854
  if environment in (Environments.VALIDATION, Environments.TRAINING):
2819
2855
  return []
2820
2856
 
@@ -2858,11 +2894,11 @@ class Validator:
2858
2894
  environment: Environments,
2859
2895
  schema: Schema,
2860
2896
  model_type: ModelTypes,
2861
- ) -> List[err.InvalidRecord]:
2862
- """
2863
- Validates there's not a single row in the dataframe with pred_label, pred_score all
2864
- evaluates to null OR with actual_label, actual_score all evaluates to null and returns
2865
- errors if either of the two cases exists
2897
+ ) -> list[err.InvalidRecord]:
2898
+ """Validates there's not a single row in the dataframe with all nulls.
2899
+
2900
+ Returns errors if any row has all of pred_label and pred_score evaluating to
2901
+ null, OR all of actual_label and actual_score evaluating to null.
2866
2902
  """
2867
2903
  if environment == Environments.PRODUCTION:
2868
2904
  return []
@@ -2905,21 +2941,23 @@ class Validator:
2905
2941
 
2906
2942
  @staticmethod
2907
2943
  def _check_invalid_record_helper(
2908
- dataframe: pd.DataFrame, column_names: List[Optional[str]]
2909
- ) -> List[err.InvalidRecord]:
2910
- """
2911
- This function checks that there are no null values in a subset of columns,
2912
- returning an error if so. The column subset is computed from the input list of
2913
- columns `column_names` that are not None and that are present in the dataframe
2944
+ dataframe: pd.DataFrame, column_names: list[str | None]
2945
+ ) -> list[err.InvalidRecord]:
2946
+ """Check that there are no null values in a subset of columns.
2947
+
2948
+ The column subset is computed from the input list of columns `column_names`
2949
+ that are not None and that are present in the dataframe. Returns an error if
2950
+ null values are found.
2914
2951
 
2915
2952
  Returns:
2916
2953
  List[err.InvalidRecord]: An error expressing the rows that are problematic
2917
2954
 
2918
2955
  """
2919
- columns_subset = []
2920
- for col in column_names:
2921
- if col is not None and col in dataframe.columns:
2922
- columns_subset.append(col)
2956
+ columns_subset = [
2957
+ col
2958
+ for col in column_names
2959
+ if col is not None and col in dataframe.columns
2960
+ ]
2923
2961
  if len(columns_subset) == 0:
2924
2962
  return []
2925
2963
  null_filter = dataframe[columns_subset].isnull().all(axis=1)
@@ -2930,8 +2968,8 @@ class Validator:
2930
2968
 
2931
2969
  @staticmethod
2932
2970
  def _check_type_prediction_group_id(
2933
- schema: Schema, column_types: Dict[str, Any]
2934
- ) -> List[err.InvalidType]:
2971
+ schema: Schema, column_types: dict[str, Any]
2972
+ ) -> list[err.InvalidType]:
2935
2973
  col = schema.prediction_group_id_column_name
2936
2974
  if col in column_types:
2937
2975
  # should mirror server side
@@ -2954,8 +2992,8 @@ class Validator:
2954
2992
 
2955
2993
  @staticmethod
2956
2994
  def _check_type_rank(
2957
- schema: Schema, column_types: Dict[str, Any]
2958
- ) -> List[err.InvalidType]:
2995
+ schema: Schema, column_types: dict[str, Any]
2996
+ ) -> list[err.InvalidType]:
2959
2997
  col = schema.rank_column_name
2960
2998
  if col in column_types:
2961
2999
  allowed_datatypes = (
@@ -2976,8 +3014,8 @@ class Validator:
2976
3014
 
2977
3015
  @staticmethod
2978
3016
  def _check_type_ranking_category(
2979
- schema: Schema, column_types: Dict[str, Any]
2980
- ) -> List[err.InvalidType]:
3017
+ schema: Schema, column_types: dict[str, Any]
3018
+ ) -> list[err.InvalidType]:
2981
3019
  if schema.relevance_labels_column_name is not None:
2982
3020
  col = schema.relevance_labels_column_name
2983
3021
  elif schema.attributions_column_name is not None:
@@ -2999,7 +3037,7 @@ class Validator:
2999
3037
  @staticmethod
3000
3038
  def _check_value_bounding_boxes_coordinates(
3001
3039
  dataframe: pd.DataFrame, schema: Schema
3002
- ) -> List[err.InvalidBoundingBoxesCoordinates]:
3040
+ ) -> list[err.InvalidBoundingBoxesCoordinates]:
3003
3041
  errors = []
3004
3042
  if schema.object_detection_prediction_column_names is not None:
3005
3043
  coords_col_name = schema.object_detection_prediction_column_names.bounding_boxes_coordinates_column_name # noqa: E501
@@ -3020,7 +3058,7 @@ class Validator:
3020
3058
  @staticmethod
3021
3059
  def _check_value_bounding_boxes_categories(
3022
3060
  dataframe: pd.DataFrame, schema: Schema
3023
- ) -> List[err.InvalidBoundingBoxesCategories]:
3061
+ ) -> list[err.InvalidBoundingBoxesCategories]:
3024
3062
  errors = []
3025
3063
  if schema.object_detection_prediction_column_names is not None:
3026
3064
  cat_col_name = schema.object_detection_prediction_column_names.categories_column_name
@@ -3041,7 +3079,7 @@ class Validator:
3041
3079
  @staticmethod
3042
3080
  def _check_value_bounding_boxes_scores(
3043
3081
  dataframe: pd.DataFrame, schema: Schema
3044
- ) -> List[err.InvalidBoundingBoxesScores]:
3082
+ ) -> list[err.InvalidBoundingBoxesScores]:
3045
3083
  errors = []
3046
3084
  if schema.object_detection_prediction_column_names is not None:
3047
3085
  sc_col_name = schema.object_detection_prediction_column_names.scores_column_name
@@ -3066,7 +3104,7 @@ class Validator:
3066
3104
  @staticmethod
3067
3105
  def _check_value_semantic_segmentation_polygon_coordinates(
3068
3106
  dataframe: pd.DataFrame, schema: Schema
3069
- ) -> List[err.InvalidPolygonCoordinates]:
3107
+ ) -> list[err.InvalidPolygonCoordinates]:
3070
3108
  errors = []
3071
3109
  if schema.semantic_segmentation_prediction_column_names is not None:
3072
3110
  coords_col_name = schema.semantic_segmentation_prediction_column_names.polygon_coordinates_column_name # noqa: E501
@@ -3076,7 +3114,7 @@ class Validator:
3076
3114
  if error is not None:
3077
3115
  errors.append(error)
3078
3116
  if schema.semantic_segmentation_actual_column_names is not None:
3079
- coords_col_name = schema.semantic_segmentation_actual_column_names.polygon_coordinates_column_name # noqa: E501
3117
+ coords_col_name = schema.semantic_segmentation_actual_column_names.polygon_coordinates_column_name
3080
3118
  error = _check_value_polygon_coordinates_helper(
3081
3119
  dataframe[coords_col_name]
3082
3120
  )
@@ -3087,7 +3125,7 @@ class Validator:
3087
3125
  @staticmethod
3088
3126
  def _check_value_semantic_segmentation_polygon_categories(
3089
3127
  dataframe: pd.DataFrame, schema: Schema
3090
- ) -> List[err.InvalidPolygonCategories]:
3128
+ ) -> list[err.InvalidPolygonCategories]:
3091
3129
  errors = []
3092
3130
  if schema.semantic_segmentation_prediction_column_names is not None:
3093
3131
  cat_col_name = schema.semantic_segmentation_prediction_column_names.categories_column_name
@@ -3108,7 +3146,7 @@ class Validator:
3108
3146
  @staticmethod
3109
3147
  def _check_value_instance_segmentation_polygon_coordinates(
3110
3148
  dataframe: pd.DataFrame, schema: Schema
3111
- ) -> List[err.InvalidPolygonCoordinates]:
3149
+ ) -> list[err.InvalidPolygonCoordinates]:
3112
3150
  errors = []
3113
3151
  if schema.instance_segmentation_prediction_column_names is not None:
3114
3152
  coords_col_name = schema.instance_segmentation_prediction_column_names.polygon_coordinates_column_name # noqa: E501
@@ -3118,7 +3156,7 @@ class Validator:
3118
3156
  if error is not None:
3119
3157
  errors.append(error)
3120
3158
  if schema.instance_segmentation_actual_column_names is not None:
3121
- coords_col_name = schema.instance_segmentation_actual_column_names.polygon_coordinates_column_name # noqa: E501
3159
+ coords_col_name = schema.instance_segmentation_actual_column_names.polygon_coordinates_column_name
3122
3160
  error = _check_value_polygon_coordinates_helper(
3123
3161
  dataframe[coords_col_name]
3124
3162
  )
@@ -3129,7 +3167,7 @@ class Validator:
3129
3167
  @staticmethod
3130
3168
  def _check_value_instance_segmentation_polygon_categories(
3131
3169
  dataframe: pd.DataFrame, schema: Schema
3132
- ) -> List[err.InvalidPolygonCategories]:
3170
+ ) -> list[err.InvalidPolygonCategories]:
3133
3171
  errors = []
3134
3172
  if schema.instance_segmentation_prediction_column_names is not None:
3135
3173
  cat_col_name = schema.instance_segmentation_prediction_column_names.categories_column_name
@@ -3150,7 +3188,7 @@ class Validator:
3150
3188
  @staticmethod
3151
3189
  def _check_value_instance_segmentation_polygon_scores(
3152
3190
  dataframe: pd.DataFrame, schema: Schema
3153
- ) -> List[err.InvalidPolygonScores]:
3191
+ ) -> list[err.InvalidPolygonScores]:
3154
3192
  errors = []
3155
3193
  if schema.instance_segmentation_prediction_column_names is not None:
3156
3194
  sc_col_name = schema.instance_segmentation_prediction_column_names.scores_column_name
@@ -3165,7 +3203,7 @@ class Validator:
3165
3203
  @staticmethod
3166
3204
  def _check_value_instance_segmentation_bbox_coordinates(
3167
3205
  dataframe: pd.DataFrame, schema: Schema
3168
- ) -> List[err.InvalidBoundingBoxesCoordinates]:
3206
+ ) -> list[err.InvalidBoundingBoxesCoordinates]:
3169
3207
  errors = []
3170
3208
  if schema.instance_segmentation_prediction_column_names is not None:
3171
3209
  coords_col_name = schema.instance_segmentation_prediction_column_names.bounding_boxes_coordinates_column_name # noqa: E501
@@ -3188,7 +3226,7 @@ class Validator:
3188
3226
  @staticmethod
3189
3227
  def _check_value_prompt_response(
3190
3228
  dataframe: pd.DataFrame, schema: Schema
3191
- ) -> List[err.ValidationError]:
3229
+ ) -> list[err.ValidationError]:
3192
3230
  vector_cols_to_check = []
3193
3231
  text_cols_to_check = []
3194
3232
  if isinstance(schema.prompt_column_names, str):
@@ -3253,7 +3291,7 @@ class Validator:
3253
3291
  @staticmethod
3254
3292
  def _check_value_llm_model_name(
3255
3293
  dataframe: pd.DataFrame, schema: Schema
3256
- ) -> List[err.InvalidStringLengthInColumn]:
3294
+ ) -> list[err.InvalidStringLengthInColumn]:
3257
3295
  if schema.llm_config_column_names is None:
3258
3296
  return []
3259
3297
  col = schema.llm_config_column_names.model_column_name
@@ -3270,7 +3308,7 @@ class Validator:
3270
3308
  max_length=MAX_LLM_MODEL_NAME_LENGTH,
3271
3309
  )
3272
3310
  ]
3273
- elif max_len > MAX_LLM_MODEL_NAME_LENGTH_TRUNCATION:
3311
+ if max_len > MAX_LLM_MODEL_NAME_LENGTH_TRUNCATION:
3274
3312
  logger.warning(
3275
3313
  get_truncation_warning_message(
3276
3314
  "LLM model names", MAX_LLM_MODEL_NAME_LENGTH_TRUNCATION
@@ -3281,7 +3319,7 @@ class Validator:
3281
3319
  @staticmethod
3282
3320
  def _check_value_llm_prompt_template(
3283
3321
  dataframe: pd.DataFrame, schema: Schema
3284
- ) -> List[err.InvalidStringLengthInColumn]:
3322
+ ) -> list[err.InvalidStringLengthInColumn]:
3285
3323
  if schema.prompt_template_column_names is None:
3286
3324
  return []
3287
3325
  col = schema.prompt_template_column_names.template_column_name
@@ -3298,7 +3336,7 @@ class Validator:
3298
3336
  max_length=MAX_PROMPT_TEMPLATE_LENGTH,
3299
3337
  )
3300
3338
  ]
3301
- elif max_len > MAX_PROMPT_TEMPLATE_LENGTH_TRUNCATION:
3339
+ if max_len > MAX_PROMPT_TEMPLATE_LENGTH_TRUNCATION:
3302
3340
  logger.warning(
3303
3341
  get_truncation_warning_message(
3304
3342
  "prompt templates",
@@ -3310,7 +3348,7 @@ class Validator:
3310
3348
  @staticmethod
3311
3349
  def _check_value_llm_prompt_template_version(
3312
3350
  dataframe: pd.DataFrame, schema: Schema
3313
- ) -> List[err.InvalidStringLengthInColumn]:
3351
+ ) -> list[err.InvalidStringLengthInColumn]:
3314
3352
  if schema.prompt_template_column_names is None:
3315
3353
  return []
3316
3354
  col = schema.prompt_template_column_names.template_version_column_name
@@ -3327,7 +3365,7 @@ class Validator:
3327
3365
  max_length=MAX_PROMPT_TEMPLATE_VERSION_LENGTH,
3328
3366
  )
3329
3367
  ]
3330
- elif max_len > MAX_PROMPT_TEMPLATE_VERSION_LENGTH_TRUNCATION:
3368
+ if max_len > MAX_PROMPT_TEMPLATE_VERSION_LENGTH_TRUNCATION:
3331
3369
  logger.warning(
3332
3370
  get_truncation_warning_message(
3333
3371
  "prompt template versions",
@@ -3338,8 +3376,8 @@ class Validator:
3338
3376
 
3339
3377
  @staticmethod
3340
3378
  def _check_type_document_columns(
3341
- schema: CorpusSchema, column_types: Dict[str, Any]
3342
- ) -> List[err.InvalidTypeColumns]:
3379
+ schema: CorpusSchema, column_types: dict[str, Any]
3380
+ ) -> list[err.InvalidTypeColumns]:
3343
3381
  invalid_types = []
3344
3382
  # Check document id
3345
3383
  col = schema.document_id_column_name
@@ -3424,16 +3462,15 @@ class Validator:
3424
3462
  return []
3425
3463
 
3426
3464
 
3427
- def _check_value_string_length_helper(x):
3465
+ def _check_value_string_length_helper(x: object) -> int:
3428
3466
  if isinstance(x, str):
3429
3467
  return len(x)
3430
- else:
3431
- return 0
3468
+ return 0
3432
3469
 
3433
3470
 
3434
3471
  def _check_value_vector_dimensionality_helper(
3435
- dataframe: pd.DataFrame, cols_to_check: List[str]
3436
- ) -> Tuple[List[str], List[str]]:
3472
+ dataframe: pd.DataFrame, cols_to_check: list[str]
3473
+ ) -> tuple[list[str], list[str]]:
3437
3474
  invalid_low_dimensionality_vector_cols = []
3438
3475
  invalid_high_dimensionality_vector_cols = []
3439
3476
  for col in cols_to_check:
@@ -3452,8 +3489,8 @@ def _check_value_vector_dimensionality_helper(
3452
3489
 
3453
3490
 
3454
3491
  def _check_value_raw_data_length_helper(
3455
- dataframe: pd.DataFrame, cols_to_check: List[str]
3456
- ) -> Tuple[List[str], List[str]]:
3492
+ dataframe: pd.DataFrame, cols_to_check: list[str]
3493
+ ) -> tuple[list[str], list[str]]:
3457
3494
  invalid_long_string_data_cols = []
3458
3495
  truncated_long_string_data_cols = []
3459
3496
  for col in cols_to_check:
@@ -3469,7 +3506,7 @@ def _check_value_raw_data_length_helper(
3469
3506
  )
3470
3507
  except TypeError as exc:
3471
3508
  e = TypeError(f"Cannot validate the column '{col}'. " + str(exc))
3472
- logger.error(e)
3509
+ logger.exception(e)
3473
3510
  raise e from exc
3474
3511
  if max_data_len > MAX_RAW_DATA_CHARACTERS:
3475
3512
  invalid_long_string_data_cols.append(col)
@@ -3480,8 +3517,8 @@ def _check_value_raw_data_length_helper(
3480
3517
 
3481
3518
  def _check_value_bounding_boxes_coordinates_helper(
3482
3519
  coordinates_col: pd.Series,
3483
- ) -> Union[err.InvalidBoundingBoxesCoordinates, None]:
3484
- def check(boxes):
3520
+ ) -> err.InvalidBoundingBoxesCoordinates | None:
3521
+ def check(boxes: object) -> None:
3485
3522
  # We allow for zero boxes. None coordinates list is not allowed (will break following tests:
3486
3523
  # 'NoneType is not iterable')
3487
3524
  if boxes is None:
@@ -3502,7 +3539,9 @@ def _check_value_bounding_boxes_coordinates_helper(
3502
3539
  return None
3503
3540
 
3504
3541
 
3505
- def _box_coordinates_wrong_format(box_coords):
3542
+ def _box_coordinates_wrong_format(
3543
+ box_coords: object,
3544
+ ) -> err.InvalidBoundingBoxesCoordinates | None:
3506
3545
  if (
3507
3546
  # Coordinates should be a collection of 4 floats
3508
3547
  len(box_coords) != 4
@@ -3516,12 +3555,13 @@ def _box_coordinates_wrong_format(box_coords):
3516
3555
  return err.InvalidBoundingBoxesCoordinates(
3517
3556
  reason="boxes_coordinates_wrong_format"
3518
3557
  )
3558
+ return None
3519
3559
 
3520
3560
 
3521
3561
  def _check_value_bounding_boxes_categories_helper(
3522
3562
  categories_col: pd.Series,
3523
- ) -> Union[err.InvalidBoundingBoxesCategories, None]:
3524
- def check(categories):
3563
+ ) -> err.InvalidBoundingBoxesCategories | None:
3564
+ def check(categories: object) -> None:
3525
3565
  # We allow for zero boxes. None category list is not allowed (will break following tests:
3526
3566
  # 'NoneType is not iterable')
3527
3567
  if categories is None:
@@ -3542,8 +3582,8 @@ def _check_value_bounding_boxes_categories_helper(
3542
3582
 
3543
3583
  def _check_value_bounding_boxes_scores_helper(
3544
3584
  scores_col: pd.Series,
3545
- ) -> Union[err.InvalidBoundingBoxesScores, None]:
3546
- def check(scores):
3585
+ ) -> err.InvalidBoundingBoxesScores | None:
3586
+ def check(scores: object) -> None:
3547
3587
  # We allow for zero boxes. None confidence score list is not allowed (will break following tests:
3548
3588
  # 'NoneType is not iterable')
3549
3589
  if scores is None:
@@ -3562,9 +3602,10 @@ def _check_value_bounding_boxes_scores_helper(
3562
3602
  return None
3563
3603
 
3564
3604
 
3565
- def _polygon_coordinates_wrong_format(polygon_coords):
3566
- """
3567
- Check if polygon coordinates are valid.
3605
+ def _polygon_coordinates_wrong_format(
3606
+ polygon_coords: object,
3607
+ ) -> err.InvalidPolygonCoordinates | None:
3608
+ """Check if polygon coordinates are valid.
3568
3609
 
3569
3610
  Validates:
3570
3611
  - Has at least 3 vertices (6 coordinates)
@@ -3610,9 +3651,9 @@ def _polygon_coordinates_wrong_format(polygon_coords):
3610
3651
 
3611
3652
  # Check for self-intersections
3612
3653
  # We need to check if any two non-adjacent edges intersect
3613
- edges = []
3614
- for i in range(len(points)):
3615
- edges.append((points[i], points[(i + 1) % len(points)]))
3654
+ edges = [
3655
+ (points[i], points[(i + 1) % len(points)]) for i in range(len(points))
3656
+ ]
3616
3657
 
3617
3658
  for i in range(len(edges)):
3618
3659
  for j in range(i + 2, len(edges)):
@@ -3634,8 +3675,8 @@ def _polygon_coordinates_wrong_format(polygon_coords):
3634
3675
 
3635
3676
  def _check_value_polygon_coordinates_helper(
3636
3677
  coordinates_col: pd.Series,
3637
- ) -> Union[err.InvalidPolygonCoordinates, None]:
3638
- def check(polygons):
3678
+ ) -> err.InvalidPolygonCoordinates | None:
3679
+ def check(polygons: object) -> None:
3639
3680
  # We allow for zero polygons. None coordinates list is not allowed (will break following tests:
3640
3681
  # 'NoneType is not iterable')
3641
3682
  if polygons is None:
@@ -3658,8 +3699,8 @@ def _check_value_polygon_coordinates_helper(
3658
3699
 
3659
3700
  def _check_value_polygon_categories_helper(
3660
3701
  categories_col: pd.Series,
3661
- ) -> Union[err.InvalidPolygonCategories, None]:
3662
- def check(categories):
3702
+ ) -> err.InvalidPolygonCategories | None:
3703
+ def check(categories: object) -> None:
3663
3704
  # We allow for zero boxes. None category list is not allowed (will break following tests:
3664
3705
  # 'NoneType is not iterable')
3665
3706
  if categories is None:
@@ -3678,8 +3719,8 @@ def _check_value_polygon_categories_helper(
3678
3719
 
3679
3720
  def _check_value_polygon_scores_helper(
3680
3721
  scores_col: pd.Series,
3681
- ) -> Union[err.InvalidPolygonScores, None]:
3682
- def check(scores):
3722
+ ) -> err.InvalidPolygonScores | None:
3723
+ def check(scores: object) -> None:
3683
3724
  # We allow for zero boxes. None confidence score list is not allowed (will break following tests:
3684
3725
  # 'NoneType is not iterable')
3685
3726
  if scores is None:
@@ -3696,7 +3737,7 @@ def _check_value_polygon_scores_helper(
3696
3737
  return None
3697
3738
 
3698
3739
 
3699
- def _count_characters_raw_data(data: Union[str, List[str]]) -> int:
3740
+ def _count_characters_raw_data(data: str | list[str]) -> int:
3700
3741
  character_count = 0
3701
3742
  if isinstance(data, str):
3702
3743
  character_count = len(data)