arize 8.0.0a21__py3-none-any.whl → 8.0.0a23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. arize/__init__.py +17 -9
  2. arize/_exporter/client.py +55 -36
  3. arize/_exporter/parsers/tracing_data_parser.py +41 -30
  4. arize/_exporter/validation.py +3 -3
  5. arize/_flight/client.py +208 -77
  6. arize/_generated/api_client/__init__.py +30 -6
  7. arize/_generated/api_client/api/__init__.py +1 -0
  8. arize/_generated/api_client/api/datasets_api.py +864 -190
  9. arize/_generated/api_client/api/experiments_api.py +167 -131
  10. arize/_generated/api_client/api/projects_api.py +1197 -0
  11. arize/_generated/api_client/api_client.py +2 -2
  12. arize/_generated/api_client/configuration.py +42 -34
  13. arize/_generated/api_client/exceptions.py +2 -2
  14. arize/_generated/api_client/models/__init__.py +15 -4
  15. arize/_generated/api_client/models/dataset.py +10 -10
  16. arize/_generated/api_client/models/dataset_example.py +111 -0
  17. arize/_generated/api_client/models/dataset_example_update.py +100 -0
  18. arize/_generated/api_client/models/dataset_version.py +13 -13
  19. arize/_generated/api_client/models/datasets_create_request.py +16 -8
  20. arize/_generated/api_client/models/datasets_examples_insert_request.py +100 -0
  21. arize/_generated/api_client/models/datasets_examples_list200_response.py +106 -0
  22. arize/_generated/api_client/models/datasets_examples_update_request.py +102 -0
  23. arize/_generated/api_client/models/datasets_list200_response.py +10 -4
  24. arize/_generated/api_client/models/experiment.py +14 -16
  25. arize/_generated/api_client/models/experiment_run.py +108 -0
  26. arize/_generated/api_client/models/experiment_run_create.py +102 -0
  27. arize/_generated/api_client/models/experiments_create_request.py +16 -10
  28. arize/_generated/api_client/models/experiments_list200_response.py +10 -4
  29. arize/_generated/api_client/models/experiments_runs_list200_response.py +19 -5
  30. arize/_generated/api_client/models/{error.py → pagination_metadata.py} +13 -11
  31. arize/_generated/api_client/models/primitive_value.py +172 -0
  32. arize/_generated/api_client/models/problem.py +100 -0
  33. arize/_generated/api_client/models/project.py +99 -0
  34. arize/_generated/api_client/models/{datasets_list_examples200_response.py → projects_create_request.py} +13 -11
  35. arize/_generated/api_client/models/projects_list200_response.py +106 -0
  36. arize/_generated/api_client/rest.py +2 -2
  37. arize/_generated/api_client/test/test_dataset.py +4 -2
  38. arize/_generated/api_client/test/test_dataset_example.py +56 -0
  39. arize/_generated/api_client/test/test_dataset_example_update.py +52 -0
  40. arize/_generated/api_client/test/test_dataset_version.py +7 -2
  41. arize/_generated/api_client/test/test_datasets_api.py +27 -13
  42. arize/_generated/api_client/test/test_datasets_create_request.py +8 -4
  43. arize/_generated/api_client/test/{test_datasets_list_examples200_response.py → test_datasets_examples_insert_request.py} +19 -15
  44. arize/_generated/api_client/test/test_datasets_examples_list200_response.py +66 -0
  45. arize/_generated/api_client/test/test_datasets_examples_update_request.py +61 -0
  46. arize/_generated/api_client/test/test_datasets_list200_response.py +9 -3
  47. arize/_generated/api_client/test/test_experiment.py +2 -4
  48. arize/_generated/api_client/test/test_experiment_run.py +56 -0
  49. arize/_generated/api_client/test/test_experiment_run_create.py +54 -0
  50. arize/_generated/api_client/test/test_experiments_api.py +6 -6
  51. arize/_generated/api_client/test/test_experiments_create_request.py +9 -6
  52. arize/_generated/api_client/test/test_experiments_list200_response.py +9 -5
  53. arize/_generated/api_client/test/test_experiments_runs_list200_response.py +15 -5
  54. arize/_generated/api_client/test/test_pagination_metadata.py +53 -0
  55. arize/_generated/api_client/test/{test_error.py → test_primitive_value.py} +13 -14
  56. arize/_generated/api_client/test/test_problem.py +57 -0
  57. arize/_generated/api_client/test/test_project.py +58 -0
  58. arize/_generated/api_client/test/test_projects_api.py +59 -0
  59. arize/_generated/api_client/test/test_projects_create_request.py +54 -0
  60. arize/_generated/api_client/test/test_projects_list200_response.py +70 -0
  61. arize/_generated/api_client_README.md +43 -29
  62. arize/_generated/protocol/flight/flight_pb2.py +400 -0
  63. arize/_lazy.py +27 -19
  64. arize/client.py +269 -55
  65. arize/config.py +365 -116
  66. arize/constants/__init__.py +1 -0
  67. arize/constants/config.py +11 -4
  68. arize/constants/ml.py +6 -4
  69. arize/constants/openinference.py +2 -0
  70. arize/constants/pyarrow.py +2 -0
  71. arize/constants/spans.py +3 -1
  72. arize/datasets/__init__.py +1 -0
  73. arize/datasets/client.py +299 -84
  74. arize/datasets/errors.py +32 -2
  75. arize/datasets/validation.py +18 -8
  76. arize/embeddings/__init__.py +2 -0
  77. arize/embeddings/auto_generator.py +23 -19
  78. arize/embeddings/base_generators.py +89 -36
  79. arize/embeddings/constants.py +2 -0
  80. arize/embeddings/cv_generators.py +26 -4
  81. arize/embeddings/errors.py +27 -5
  82. arize/embeddings/nlp_generators.py +31 -12
  83. arize/embeddings/tabular_generators.py +32 -20
  84. arize/embeddings/usecases.py +12 -2
  85. arize/exceptions/__init__.py +1 -0
  86. arize/exceptions/auth.py +11 -1
  87. arize/exceptions/base.py +29 -4
  88. arize/exceptions/models.py +21 -2
  89. arize/exceptions/parameters.py +31 -0
  90. arize/exceptions/spaces.py +12 -1
  91. arize/exceptions/types.py +86 -7
  92. arize/exceptions/values.py +220 -20
  93. arize/experiments/__init__.py +1 -0
  94. arize/experiments/client.py +390 -286
  95. arize/experiments/evaluators/__init__.py +1 -0
  96. arize/experiments/evaluators/base.py +74 -41
  97. arize/experiments/evaluators/exceptions.py +6 -3
  98. arize/experiments/evaluators/executors.py +121 -73
  99. arize/experiments/evaluators/rate_limiters.py +106 -57
  100. arize/experiments/evaluators/types.py +34 -7
  101. arize/experiments/evaluators/utils.py +65 -27
  102. arize/experiments/functions.py +103 -101
  103. arize/experiments/tracing.py +52 -44
  104. arize/experiments/types.py +56 -31
  105. arize/logging.py +54 -22
  106. arize/models/__init__.py +1 -0
  107. arize/models/batch_validation/__init__.py +1 -0
  108. arize/models/batch_validation/errors.py +543 -65
  109. arize/models/batch_validation/validator.py +339 -300
  110. arize/models/bounded_executor.py +20 -7
  111. arize/models/casting.py +75 -29
  112. arize/models/client.py +326 -107
  113. arize/models/proto.py +95 -40
  114. arize/models/stream_validation.py +42 -14
  115. arize/models/surrogate_explainer/__init__.py +1 -0
  116. arize/models/surrogate_explainer/mimic.py +24 -13
  117. arize/pre_releases.py +43 -0
  118. arize/projects/__init__.py +1 -0
  119. arize/projects/client.py +129 -0
  120. arize/regions.py +40 -0
  121. arize/spans/__init__.py +1 -0
  122. arize/spans/client.py +130 -106
  123. arize/spans/columns.py +13 -0
  124. arize/spans/conversion.py +54 -38
  125. arize/spans/validation/__init__.py +1 -0
  126. arize/spans/validation/annotations/__init__.py +1 -0
  127. arize/spans/validation/annotations/annotations_validation.py +6 -4
  128. arize/spans/validation/annotations/dataframe_form_validation.py +13 -11
  129. arize/spans/validation/annotations/value_validation.py +35 -11
  130. arize/spans/validation/common/__init__.py +1 -0
  131. arize/spans/validation/common/argument_validation.py +33 -8
  132. arize/spans/validation/common/dataframe_form_validation.py +35 -9
  133. arize/spans/validation/common/errors.py +211 -11
  134. arize/spans/validation/common/value_validation.py +80 -13
  135. arize/spans/validation/evals/__init__.py +1 -0
  136. arize/spans/validation/evals/dataframe_form_validation.py +28 -8
  137. arize/spans/validation/evals/evals_validation.py +34 -4
  138. arize/spans/validation/evals/value_validation.py +26 -3
  139. arize/spans/validation/metadata/__init__.py +1 -1
  140. arize/spans/validation/metadata/argument_validation.py +14 -5
  141. arize/spans/validation/metadata/dataframe_form_validation.py +26 -10
  142. arize/spans/validation/metadata/value_validation.py +24 -10
  143. arize/spans/validation/spans/__init__.py +1 -0
  144. arize/spans/validation/spans/dataframe_form_validation.py +34 -13
  145. arize/spans/validation/spans/spans_validation.py +35 -4
  146. arize/spans/validation/spans/value_validation.py +76 -7
  147. arize/types.py +293 -157
  148. arize/utils/__init__.py +1 -0
  149. arize/utils/arrow.py +31 -15
  150. arize/utils/cache.py +34 -6
  151. arize/utils/dataframe.py +19 -2
  152. arize/utils/online_tasks/__init__.py +2 -0
  153. arize/utils/online_tasks/dataframe_preprocessor.py +53 -41
  154. arize/utils/openinference_conversion.py +44 -5
  155. arize/utils/proto.py +10 -0
  156. arize/utils/size.py +5 -3
  157. arize/version.py +3 -1
  158. {arize-8.0.0a21.dist-info → arize-8.0.0a23.dist-info}/METADATA +4 -3
  159. arize-8.0.0a23.dist-info/RECORD +174 -0
  160. {arize-8.0.0a21.dist-info → arize-8.0.0a23.dist-info}/WHEEL +1 -1
  161. arize-8.0.0a23.dist-info/licenses/LICENSE +176 -0
  162. arize-8.0.0a23.dist-info/licenses/NOTICE +13 -0
  163. arize/_generated/protocol/flight/export_pb2.py +0 -61
  164. arize/_generated/protocol/flight/ingest_pb2.py +0 -365
  165. arize-8.0.0a21.dist-info/RECORD +0 -146
  166. arize-8.0.0a21.dist-info/licenses/LICENSE.md +0 -12
arize/models/batch_validation/validator.py +339 -300
@@ -1,12 +1,15 @@
+ """Batch validation logic for ML model predictions and actuals."""
+
  from __future__ import annotations

- import datetime
  import logging
  import math
+ from datetime import datetime, timedelta, timezone
  from itertools import chain
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+ from typing import Any

  import numpy as np
+ import pandas as pd
  import pyarrow as pa

  from arize.constants.ml import (
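
Note: the import hunk above sets up the typing modernization applied throughout this file. A minimal sketch (illustrative, not taken from the package) of the style: `from __future__ import annotations` defers annotation evaluation, so PEP 604 unions (`str | None`) and PEP 585 built-in generics (`list[str]`) work in signatures even on interpreters older than 3.10.

    from __future__ import annotations


    def lookup(model_version: str | None = None) -> list[str]:
        # str | None replaces Optional[str]; list[str] replaces typing.List[str].
        return [] if model_version is None else [model_version]
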
@@ -55,23 +58,22 @@ from arize.types import (
  segments_intersect,
  )

- if TYPE_CHECKING:
- import pandas as pd
-
-
  logger = logging.getLogger(__name__)


  class Validator:
+ """Validator for batch data with schema and dataframe validation methods."""
+
  @staticmethod
  def validate_required_checks(
  dataframe: pd.DataFrame,
  model_id: str,
  environment: Environments,
  schema: BaseSchema,
- model_version: Optional[str] = None,
- batch_id: Optional[str] = None,
- ) -> List[err.ValidationError]:
+ model_version: str | None = None,
+ batch_id: str | None = None,
+ ) -> list[err.ValidationError]:
+ """Validate required checks for schema, environment, and DataFrame structure."""
  general_checks = chain(
  Validator._check_valid_schema_type(schema, environment),
  Validator._check_field_convertible_to_str(
@@ -87,7 +89,7 @@ class Validator:
  schema, CorpusSchema
  ):
  return list(general_checks)
- elif isinstance(schema, Schema):
+ if isinstance(schema, Schema):
  return list(
  chain(
  general_checks,
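
Note: the `elif` to `if` rewrites in this and the following hunks apply the guard-clause convention (the rule a linter such as Ruff's RET505 enforces): once a branch ends in `return`, the next test needs no `elif`. An illustrative, self-contained sketch of the pattern:

    from itertools import chain


    def collect_checks(value: object) -> list[str]:
        if isinstance(value, str):
            return ["string-checks"]
        if isinstance(value, int):  # plain `if`: the branch above already returned
            return list(chain(["general-checks"], ["int-checks"]))
        return []
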
@@ -108,10 +110,11 @@
  model_type: ModelTypes,
  environment: Environments,
  schema: BaseSchema,
- metric_families: Optional[List[Metrics]] = None,
- model_version: Optional[str] = None,
- batch_id: Optional[str] = None,
- ) -> List[err.ValidationError]:
+ metric_families: list[Metrics] | None = None,
+ model_version: str | None = None,
+ batch_id: str | None = None,
+ ) -> list[err.ValidationError]:
+ """Validate parameters including model type, environment, and schema consistency."""
  # general checks
  general_checks = chain(
  Validator._check_column_names_for_empty_strings(schema),
@@ -125,7 +128,7 @@
  )
  if isinstance(schema, CorpusSchema):
  return list(general_checks)
- elif isinstance(schema, Schema):
+ if isinstance(schema, Schema):
  general_checks = chain(
  general_checks,
  Validator._check_existence_prediction_id_column_delayed_schema(
@@ -153,7 +156,7 @@
  ),
  )
  return list(chain(general_checks, num_checks))
- elif model_type in CATEGORICAL_MODEL_TYPES:
+ if model_type in CATEGORICAL_MODEL_TYPES:
  sc_checks = chain(
  Validator._check_existence_preprod_pred_act_score_or_label(
  schema, environment
@@ -166,7 +169,7 @@
  ),
  )
  return list(chain(general_checks, sc_checks))
- elif model_type == ModelTypes.GENERATIVE_LLM:
+ if model_type == ModelTypes.GENERATIVE_LLM:
  gllm_checks = chain(
  Validator._check_existence_preprod_act(schema, environment),
  Validator._check_missing_object_detection_columns(
@@ -177,7 +180,7 @@
  ),
  )
  return list(chain(general_checks, gllm_checks))
- elif model_type == ModelTypes.RANKING:
+ if model_type == ModelTypes.RANKING:
  r_checks = chain(
  Validator._check_existence_group_id_rank_category_relevance(
  schema
@@ -190,7 +193,7 @@
  ),
  )
  return list(chain(general_checks, r_checks))
- elif model_type == ModelTypes.OBJECT_DETECTION:
+ if model_type == ModelTypes.OBJECT_DETECTION:
  od_checks = chain(
  Validator._check_exactly_one_cv_column_type(
  schema, environment
@@ -203,7 +206,7 @@
  ),
  )
  return list(chain(general_checks, od_checks))
- elif model_type == ModelTypes.MULTI_CLASS:
+ if model_type == ModelTypes.MULTI_CLASS:
  multi_class_checks = chain(
  Validator._check_existing_multi_class_columns(schema),
  Validator._check_missing_non_multi_class_columns(
@@ -218,8 +221,11 @@
  model_type: ModelTypes,
  schema: BaseSchema,
  pyarrow_schema: pa.Schema,
- ) -> List[err.ValidationError]:
- column_types = dict(zip(pyarrow_schema.names, pyarrow_schema.types))
+ ) -> list[err.ValidationError]:
+ """Validate column data types against expected types for the schema."""
+ column_types = dict(
+ zip(pyarrow_schema.names, pyarrow_schema.types, strict=True)
+ )

  if isinstance(schema, CorpusSchema):
  return list(
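
Note: `strict=True` (Python 3.10+) makes `zip` raise `ValueError` if its inputs differ in length instead of silently truncating. For a `pa.Schema` the `names` and `types` sequences always line up, so the flag simply enforces that invariant. A small sketch, assuming pyarrow is installed:

    import pyarrow as pa

    schema = pa.schema([("id", pa.string()), ("score", pa.float64())])
    # Raises ValueError if the two iterables ever diverge in length.
    column_types = dict(zip(schema.names, schema.types, strict=True))
    # {'id': DataType(string), 'score': DataType(double)}
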
@@ -227,7 +233,7 @@
  Validator._check_type_document_columns(schema, column_types)
  )
  )
- elif isinstance(schema, Schema):
+ if isinstance(schema, Schema):
  general_checks = chain(
  Validator._check_type_prediction_id(schema, column_types),
  Validator._check_type_timestamp(schema, column_types),
@@ -271,7 +277,7 @@
  ),
  )
  return list(chain(general_checks, gllm_checks))
- elif model_type == ModelTypes.RANKING:
+ if model_type == ModelTypes.RANKING:
  r_checks = chain(
  Validator._check_type_prediction_group_id(
  schema, column_types
@@ -285,7 +291,7 @@
  ),
  )
  return list(chain(general_checks, r_checks))
- elif model_type == ModelTypes.OBJECT_DETECTION:
+ if model_type == ModelTypes.OBJECT_DETECTION:
  od_checks = chain(
  Validator._check_type_image_segment_coordinates(
  schema, column_types
@@ -298,7 +304,7 @@
  ),
  )
  return list(chain(general_checks, od_checks))
- elif model_type == ModelTypes.MULTI_CLASS:
+ if model_type == ModelTypes.MULTI_CLASS:
  multi_class_checks = chain(
  Validator._check_type_multi_class_pred_threshold_act_scores(
  schema, column_types
@@ -315,7 +321,8 @@
  environment: Environments,
  schema: BaseSchema,
  model_type: ModelTypes,
- ) -> List[err.ValidationError]:
+ ) -> list[err.ValidationError]:
+ """Validate data values including ranges, formats, and consistency checks."""
  # ASSUMPTION: at this point the param and type checks should have passed.
  # This function may crash if that is not true, e.g. if columns are missing
  # or are of the wrong types.
@@ -338,7 +345,7 @@
  ),
  )
  )
- elif isinstance(schema, Schema):
+ if isinstance(schema, Schema):
  general_checks = chain(
  general_checks,
  Validator._check_value_timestamp(dataframe, schema),
@@ -429,21 +436,21 @@
  return list(general_checks)
  return []

- # ----------------------
- # Minimum requred checks
- # ----------------------
+ # -----------------------
+ # Minimum required checks
+ # -----------------------
  @staticmethod
  def _check_column_names_for_empty_strings(
  schema: BaseSchema,
- ) -> List[err.InvalidColumnNameEmptyString]:
+ ) -> list[err.InvalidColumnNameEmptyString]:
  if "" in schema.get_used_columns():
  return [err.InvalidColumnNameEmptyString()]
  return []

  @staticmethod
  def _check_field_convertible_to_str(
- model_id, model_version, batch_id
- ) -> List[err.InvalidFieldTypeConversion]:
+ model_id: object, model_version: object, batch_id: object
+ ) -> list[err.InvalidFieldTypeConversion]:
  # converting to a set first makes the checks run a lot faster
  wrong_fields = []
  if model_id is not None and not isinstance(model_id, str):
@@ -469,7 +476,7 @@
  @staticmethod
  def _check_field_type_embedding_features_column_names(
  schema: Schema,
- ) -> List[err.InvalidFieldTypeEmbeddingFeatures]:
+ ) -> list[err.InvalidFieldTypeEmbeddingFeatures]:
  if schema.embedding_feature_column_names is not None:
  if not isinstance(schema.embedding_feature_column_names, dict):
  return [err.InvalidFieldTypeEmbeddingFeatures()]
@@ -483,7 +490,7 @@
  @staticmethod
  def _check_field_type_prompt_response(
  schema: Schema,
- ) -> List[err.InvalidFieldTypePromptResponse]:
+ ) -> list[err.InvalidFieldTypePromptResponse]:
  errors = []
  if schema.prompt_column_names is not None and not isinstance(
  schema.prompt_column_names, (str, EmbeddingColumnNames)
@@ -502,7 +509,7 @@
  @staticmethod
  def _check_field_type_prompt_templates(
  schema: Schema,
- ) -> List[err.InvalidFieldTypePromptTemplates]:
+ ) -> list[err.InvalidFieldTypePromptTemplates]:
  if schema.prompt_template_column_names is not None and not isinstance(
  schema.prompt_template_column_names, PromptTemplateColumnNames
  ):
@@ -513,7 +520,7 @@
  def _check_field_type_llm_config(
  dataframe: pd.DataFrame,
  schema: Schema,
- ) -> List[Union[err.InvalidFieldTypeLlmConfig, err.InvalidTypeColumns]]:
+ ) -> list[err.InvalidFieldTypeLlmConfig | err.InvalidTypeColumns]:
  if schema.llm_config_column_names is None:
  return []
  if not isinstance(schema.llm_config_column_names, LLMConfigColumnNames):
@@ -548,7 +555,7 @@
  @staticmethod
  def _check_invalid_index(
  dataframe: pd.DataFrame,
- ) -> List[err.InvalidDataFrameIndex]:
+ ) -> list[err.InvalidDataFrameIndex]:
  if (dataframe.index != dataframe.reset_index(drop=True).index).any():
  return [err.InvalidDataFrameIndex()]
  return []
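
Note: for context on `_check_invalid_index` above, a dataframe that has been filtered or sorted keeps its original index labels, which the comparison against `reset_index(drop=True)` detects. An illustrative sketch (not from the package):

    import pandas as pd

    df = pd.DataFrame({"score": [0.1, 0.9, 0.4]})
    subset = df[df["score"] > 0.3]  # index is now [1, 2], not [0, 1]
    flagged = (subset.index != subset.reset_index(drop=True).index).any()  # True
    clean = subset.reset_index(drop=True)  # what a caller should do before logging
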
@@ -560,9 +567,9 @@
  @staticmethod
  def _check_model_type_and_metrics(
  model_type: ModelTypes,
- metric_families: Optional[List[Metrics]],
+ metric_families: list[Metrics] | None,
  schema: Schema,
- ) -> List[err.ValidationError]:
+ ) -> list[err.ValidationError]:
  if metric_families is None:
  return []

@@ -606,10 +613,10 @@
  @staticmethod
  def _check_model_mapping_combinations(
  model_type: ModelTypes,
- metric_families: List[Metrics],
+ metric_families: list[Metrics],
  schema: Schema,
- required_columns_map: List[Dict[str, Any]],
- ) -> Tuple[bool, List[str], List[List[str]]]:
+ required_columns_map: list[dict[str, Any]],
+ ) -> tuple[bool, list[str], list[list[str]]]:
  missing_columns = []
  for item in required_columns_map:
  if model_type.name.lower() == item.get("external_model_type"):
@@ -625,10 +632,10 @@
  metric_combinations.append(
  [metric.upper() for metric in metrics_list]
  )
- if set(metrics_list) == set(
+ if set(metrics_list) == {
  metric_family.name.lower()
  for metric_family in metric_families
- ):
+ }:
  # This is a valid combination of model type + metrics.
  # Now validate that required columns are in the schema.
  is_valid_combination = True
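
Note: the hunk above swaps a `set(...)` call around a generator for a set comprehension; the semantics are identical, and set equality ignores ordering either way. A hedged sketch with a stand-in enum:

    from enum import Enum


    class Metrics(Enum):  # stand-in for the package's Metrics enum
        CLASSIFICATION = 1
        REGRESSION = 2


    metric_families = [Metrics.REGRESSION, Metrics.CLASSIFICATION]
    metrics_list = ["classification", "regression"]
    matches = set(metrics_list) == {m.name.lower() for m in metric_families}  # True
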
@@ -665,10 +672,10 @@
  @staticmethod
  def _check_existence_prediction_id_column_delayed_schema(
  schema: Schema, model_type: ModelTypes
- ) -> List[err.MissingPredictionIdColumnForDelayedRecords]:
+ ) -> list[err.MissingPredictionIdColumnForDelayedRecords]:
  if schema.prediction_id_column_name is not None:
  return []
- # TODO: Revise logic once predicion_label column addition (for generative models)
+ # TODO: Revise logic once prediction_label column addition (for generative models)
  # is moved to beginning of log function
  if schema.is_delayed() and model_type is not ModelTypes.GENERATIVE_LLM:
  # We skip GENERATIVE model types since they are assigned a default
@@ -696,12 +703,12 @@
  def _check_missing_columns(
  dataframe: pd.DataFrame,
  schema: BaseSchema,
- ) -> List[err.MissingColumns]:
+ ) -> list[err.MissingColumns]:
  if isinstance(schema, CorpusSchema):
  return Validator._check_missing_columns_corpus_schema(
  dataframe, schema
  )
- elif isinstance(schema, Schema):
+ if isinstance(schema, Schema):
  return Validator._check_missing_columns_schema(dataframe, schema)
  return []

@@ -709,7 +716,7 @@
  def _check_missing_columns_schema(
  dataframe: pd.DataFrame,
  schema: Schema,
- ) -> List[err.MissingColumns]:
+ ) -> list[err.MissingColumns]:
  # converting to a set first makes the checks run a lot faster
  existing_columns = set(dataframe.columns)
  missing_columns = []
@@ -721,9 +728,13 @@
  missing_columns.append(col)

  if schema.feature_column_names is not None:
- for col in schema.feature_column_names:
- if col not in existing_columns:
- missing_columns.append(col)
+ missing_columns.extend(
+ [
+ col
+ for col in schema.feature_column_names
+ if col not in existing_columns
+ ]
+ )

  if schema.embedding_feature_column_names is not None:
  for (
@@ -752,44 +763,76 @@
  )

  if schema.tag_column_names is not None:
- for col in schema.tag_column_names:
- if col not in existing_columns:
- missing_columns.append(col)
+ missing_columns.extend(
+ [
+ col
+ for col in schema.tag_column_names
+ if col not in existing_columns
+ ]
+ )

  if schema.shap_values_column_names is not None:
- for col in schema.shap_values_column_names.values():
- if col not in existing_columns:
- missing_columns.append(col)
+ missing_columns.extend(
+ [
+ col
+ for col in schema.shap_values_column_names.values()
+ if col not in existing_columns
+ ]
+ )

  if schema.object_detection_prediction_column_names is not None:
- for col in schema.object_detection_prediction_column_names:
- if col is not None and col not in existing_columns:
- missing_columns.append(col)
+ missing_columns.extend(
+ [
+ col
+ for col in schema.object_detection_prediction_column_names
+ if col is not None and col not in existing_columns
+ ]
+ )

  if schema.object_detection_actual_column_names is not None:
- for col in schema.object_detection_actual_column_names:
- if col is not None and col not in existing_columns:
- missing_columns.append(col)
+ missing_columns.extend(
+ [
+ col
+ for col in schema.object_detection_actual_column_names
+ if col is not None and col not in existing_columns
+ ]
+ )

  if schema.semantic_segmentation_prediction_column_names is not None:
- for col in schema.semantic_segmentation_prediction_column_names:
- if col is not None and col not in existing_columns:
- missing_columns.append(col)
+ missing_columns.extend(
+ [
+ col
+ for col in schema.semantic_segmentation_prediction_column_names
+ if col is not None and col not in existing_columns
+ ]
+ )

  if schema.semantic_segmentation_actual_column_names is not None:
- for col in schema.semantic_segmentation_actual_column_names:
- if col is not None and col not in existing_columns:
- missing_columns.append(col)
+ missing_columns.extend(
+ [
+ col
+ for col in schema.semantic_segmentation_actual_column_names
+ if col is not None and col not in existing_columns
+ ]
+ )

  if schema.instance_segmentation_prediction_column_names is not None:
- for col in schema.instance_segmentation_prediction_column_names:
- if col is not None and col not in existing_columns:
- missing_columns.append(col)
+ missing_columns.extend(
+ [
+ col
+ for col in schema.instance_segmentation_prediction_column_names
+ if col is not None and col not in existing_columns
+ ]
+ )

  if schema.instance_segmentation_actual_column_names is not None:
- for col in schema.instance_segmentation_actual_column_names:
- if col is not None and col not in existing_columns:
- missing_columns.append(col)
+ missing_columns.extend(
+ [
+ col
+ for col in schema.instance_segmentation_actual_column_names
+ if col is not None and col not in existing_columns
+ ]
+ )

  if schema.prompt_column_names is not None:
  if isinstance(schema.prompt_column_names, str):
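
Note: every `missing_columns` loop in the hunks above collapses into the same pattern, `list.extend` with a comprehension (the rewrite suggested by a linter rule such as Ruff's PERF401). A minimal sketch with hypothetical column names:

    existing_columns = {"ts", "pred_label"}
    feature_column_names = ["age", "ts", "income"]

    missing_columns: list[str] = []
    # One extend call replaces append-inside-a-for-loop.
    missing_columns.extend(
        [col for col in feature_column_names if col not in existing_columns]
    )
    # missing_columns == ["age", "income"]
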
@@ -838,14 +881,22 @@
  )

  if schema.prompt_template_column_names is not None:
- for col in schema.prompt_template_column_names:
- if col is not None and col not in existing_columns:
- missing_columns.append(col)
+ missing_columns.extend(
+ [
+ col
+ for col in schema.prompt_template_column_names
+ if col is not None and col not in existing_columns
+ ]
+ )

  if schema.llm_config_column_names is not None:
- for col in schema.llm_config_column_names:
- if col is not None and col not in existing_columns:
- missing_columns.append(col)
+ missing_columns.extend(
+ [
+ col
+ for col in schema.llm_config_column_names
+ if col is not None and col not in existing_columns
+ ]
+ )

  if missing_columns:
  return [err.MissingColumns(missing_columns)]
@@ -855,7 +906,7 @@
  def _check_missing_columns_corpus_schema(
  dataframe: pd.DataFrame,
  schema: CorpusSchema,
- ) -> List[err.MissingColumns]:
+ ) -> list[err.MissingColumns]:
  # converting to a set first makes the checks run a lot faster
  existing_columns = set(dataframe.columns)
  missing_columns = []
@@ -912,7 +963,7 @@
  def _check_valid_schema_type(
  schema: BaseSchema,
  environment: Environments,
- ) -> List[err.InvalidSchemaType]:
+ ) -> list[err.InvalidSchemaType]:
  if environment == Environments.CORPUS and not (
  isinstance(schema, CorpusSchema)
  ):
@@ -934,7 +985,7 @@
  @staticmethod
  def _check_invalid_shap_suffix(
  schema: Schema,
- ) -> List[err.InvalidShapSuffix]:
+ ) -> list[err.InvalidShapSuffix]:
  invalid_column_names = set()

  if schema.feature_column_names is not None:
@@ -970,10 +1021,10 @@
  def _check_reserved_columns(
  schema: BaseSchema,
  model_type: ModelTypes,
- ) -> List[err.ReservedColumns]:
+ ) -> list[err.ReservedColumns]:
  if isinstance(schema, CorpusSchema):
  return []
- elif isinstance(schema, Schema):
+ if isinstance(schema, Schema):
  reserved_columns = []
  column_counts = schema.get_used_columns_counts()
  if model_type == ModelTypes.GENERATIVE_LLM:
@@ -1079,8 +1130,8 @@

  @staticmethod
  def _check_invalid_model_id(
- model_id: Optional[str],
- ) -> List[err.InvalidModelId]:
+ model_id: str | None,
+ ) -> list[err.InvalidModelId]:
  # assume it's been coerced to string beforehand
  if (not isinstance(model_id, str)) or len(model_id.strip()) == 0:
  return [err.InvalidModelId()]
@@ -1088,8 +1139,8 @@

  @staticmethod
  def _check_invalid_model_version(
- model_version: Optional[str] = None,
- ) -> List[err.InvalidModelVersion]:
+ model_version: str | None = None,
+ ) -> list[err.InvalidModelVersion]:
  if model_version is None:
  return []
  if (
@@ -1102,9 +1153,9 @@

  @staticmethod
  def _check_invalid_batch_id(
- batch_id: Optional[str],
+ batch_id: str | None,
  environment: Environments,
- ) -> List[err.InvalidBatchId]:
+ ) -> list[err.InvalidBatchId]:
  # assume it's been coerced to string beforehand
  if environment in (Environments.VALIDATION,) and (
  (not isinstance(batch_id, str)) or len(batch_id.strip()) == 0
@@ -1115,7 +1166,7 @@
  @staticmethod
  def _check_invalid_model_type(
  model_type: ModelTypes,
- ) -> List[err.InvalidModelType]:
+ ) -> list[err.InvalidModelType]:
  if model_type in (mt for mt in ModelTypes):
  return []
  return [err.InvalidModelType()]
@@ -1123,7 +1174,7 @@
  @staticmethod
  def _check_invalid_environment(
  environment: Environments,
- ) -> List[err.InvalidEnvironment]:
+ ) -> list[err.InvalidEnvironment]:
  if environment in (env for env in Environments):
  return []
  return [err.InvalidEnvironment()]
@@ -1132,7 +1183,7 @@
  def _check_existence_preprod_pred_act_score_or_label(
  schema: Schema,
  environment: Environments,
- ) -> List[err.MissingPreprodPredActNumericAndCategorical]:
+ ) -> list[err.MissingPreprodPredActNumericAndCategorical]:
  if environment in (Environments.VALIDATION, Environments.TRAINING) and (
  (
  schema.prediction_label_column_name is None
@@ -1149,7 +1200,7 @@
  @staticmethod
  def _check_exactly_one_cv_column_type(
  schema: Schema, environment: Environments
- ) -> List[Union[err.MultipleCVPredAct, err.MissingCVPredAct]]:
+ ) -> list[err.MultipleCVPredAct | err.MissingCVPredAct]:
  # Checks that the required prediction/actual columns are given in the schema depending on
  # the environment, for object detection models. There should be exactly one of
  # object detection, semantic segmentation, or instance segmentation columns.
@@ -1180,7 +1231,7 @@

  if cv_types_count == 0:
  return [err.MissingCVPredAct(environment)]
- elif cv_types_count > 1:
+ if cv_types_count > 1:
  return [err.MultipleCVPredAct(environment)]

  elif environment in (
@@ -1213,7 +1264,7 @@

  if cv_types_count == 0:
  return [err.MissingCVPredAct(environment)]
- elif cv_types_count > 1:
+ if cv_types_count > 1:
  return [err.MultipleCVPredAct(environment)]

  return []
@@ -1221,9 +1272,9 @@
  @staticmethod
  def _check_missing_object_detection_columns(
  schema: Schema, model_type: ModelTypes
- ) -> List[err.InvalidPredActCVColumnNamesForModelType]:
+ ) -> list[err.InvalidPredActCVColumnNamesForModelType]:
  # Checks that models that are not Object Detection models don't have, in the schema, the
- # object detection, semantic segmentation, or instance segmentation dedicated prediciton/actual
+ # object detection, semantic segmentation, or instance segmentation dedicated prediction/actual
  # column names
  if (
  schema.object_detection_prediction_column_names is not None
@@ -1239,7 +1290,7 @@
  @staticmethod
  def _check_missing_non_object_detection_columns(
  schema: Schema, model_type: ModelTypes
- ) -> List[err.InvalidPredActColumnNamesForModelType]:
+ ) -> list[err.InvalidPredActColumnNamesForModelType]:
  # Checks that object detection models don't have, in the schema, the columns reserved for
  # other model types
  columns_to_check = (
@@ -1253,10 +1304,7 @@
  schema.relevance_score_column_name,
  schema.relevance_labels_column_name,
  )
- wrong_cols = []
- for col in columns_to_check:
- if col is not None:
- wrong_cols.append(col)
+ wrong_cols = [col for col in columns_to_check if col is not None]
  if wrong_cols:
  allowed_cols = [
  "object_detection_prediction_column_names",
@@ -1276,7 +1324,7 @@
  @staticmethod
  def _check_missing_multi_class_columns(
  schema: Schema, model_type: ModelTypes
- ) -> List[err.InvalidPredActColumnNamesForModelType]:
+ ) -> list[err.InvalidPredActColumnNamesForModelType]:
  # Checks that models that are not Multi Class models don't have, in the schema, the
  # multi class dedicated threshold column
  if (
@@ -1295,7 +1343,7 @@
  @staticmethod
  def _check_existing_multi_class_columns(
  schema: Schema,
- ) -> List[err.MissingReqPredActColumnNamesForMultiClass]:
+ ) -> list[err.MissingReqPredActColumnNamesForMultiClass]:
  # Checks that models that are Multi Class models have, in the schema, the
  # required prediction score or actual score columns
  if (
@@ -1311,7 +1359,7 @@
  @staticmethod
  def _check_missing_non_multi_class_columns(
  schema: Schema, model_type: ModelTypes
- ) -> List[err.InvalidPredActColumnNamesForModelType]:
+ ) -> list[err.InvalidPredActColumnNamesForModelType]:
  # Checks that multi class models don't have, in the schema, the columns reserved for
  # other model types
  columns_to_check = (
@@ -1329,10 +1377,7 @@
  schema.instance_segmentation_prediction_column_names,
  schema.instance_segmentation_actual_column_names,
  )
- wrong_cols = []
- for col in columns_to_check:
- if col is not None:
- wrong_cols.append(col)
+ wrong_cols = [col for col in columns_to_check if col is not None]
  if wrong_cols:
  allowed_cols = [
  "prediction_score_column_name",
@@ -1350,7 +1395,7 @@
  def _check_existence_preprod_act(
  schema: Schema,
  environment: Environments,
- ) -> List[err.MissingPreprodAct]:
+ ) -> list[err.MissingPreprodAct]:
  if environment in (Environments.VALIDATION, Environments.TRAINING) and (
  schema.actual_label_column_name is None
  ):
@@ -1360,7 +1405,7 @@
  @staticmethod
  def _check_existence_group_id_rank_category_relevance(
  schema: Schema,
- ) -> List[err.MissingRequiredColumnsForRankingModel]:
+ ) -> list[err.MissingRequiredColumnsForRankingModel]:
  # prediction_group_id and rank columns are required as ranking prediction columns.
  ranking_prediction_cols = (
  schema.prediction_label_column_name,
@@ -1384,7 +1429,7 @@
  @staticmethod
  def _check_dataframe_for_duplicate_columns(
  schema: BaseSchema, dataframe: pd.DataFrame
- ) -> List[err.DuplicateColumnsInDataframe]:
+ ) -> list[err.DuplicateColumnsInDataframe]:
  # Get the columns used in the schema
  schema_col_used = schema.get_used_columns()
  # Get the duplicated column names from the dataframe
@@ -1400,7 +1445,7 @@
  @staticmethod
  def _check_invalid_number_of_embeddings(
  schema: Schema,
- ) -> List[err.InvalidNumberOfEmbeddings]:
+ ) -> list[err.InvalidNumberOfEmbeddings]:
  if schema.embedding_feature_column_names is not None:
  number_of_embeddings = len(schema.embedding_feature_column_names)
  if number_of_embeddings > MAX_NUMBER_OF_EMBEDDINGS:
@@ -1413,8 +1458,8 @@

  @staticmethod
  def _check_type_prediction_id(
- schema: Schema, column_types: Dict[str, Any]
- ) -> List[err.InvalidType]:
+ schema: Schema, column_types: dict[str, Any]
+ ) -> list[err.InvalidType]:
  col = schema.prediction_id_column_name
  if col in column_types:
  # should mirror server side
@@ -1437,8 +1482,8 @@

  @staticmethod
  def _check_type_timestamp(
- schema: Schema, column_types: Dict[str, Any]
- ) -> List[err.InvalidType]:
+ schema: Schema, column_types: dict[str, Any]
+ ) -> list[err.InvalidType]:
  col = schema.timestamp_column_name
  if col in column_types:
  # should mirror server side
@@ -1464,8 +1509,8 @@

  @staticmethod
  def _check_type_features(
- schema: Schema, column_types: Dict[str, Any]
- ) -> List[err.InvalidTypeFeatures]:
+ schema: Schema, column_types: dict[str, Any]
+ ) -> list[err.InvalidTypeFeatures]:
  if schema.feature_column_names is not None:
  # should mirror server side
  allowed_datatypes = (
@@ -1480,13 +1525,12 @@
  pa.null(),
  pa.list_(pa.string()),
  )
- wrong_type_cols = []
- for col in schema.feature_column_names:
- if (
- col in column_types
- and column_types[col] not in allowed_datatypes
- ):
- wrong_type_cols.append(col)
+ wrong_type_cols = [
+ col
+ for col in schema.feature_column_names
+ if col in column_types
+ and column_types[col] not in allowed_datatypes
+ ]
  if wrong_type_cols:
  return [
  err.InvalidTypeFeatures(
@@ -1504,8 +1548,8 @@

  @staticmethod
  def _check_type_embedding_features(
- schema: Schema, column_types: Dict[str, Any]
- ) -> List[err.InvalidTypeFeatures]:
+ schema: Schema, column_types: dict[str, Any]
+ ) -> list[err.InvalidTypeFeatures]:
  if schema.embedding_feature_column_names is not None:
  # should mirror server side
  allowed_vector_datatypes = (
@@ -1580,8 +1624,8 @@

  @staticmethod
  def _check_type_tags(
- schema: Schema, column_types: Dict[str, Any]
- ) -> List[err.InvalidTypeTags]:
+ schema: Schema, column_types: dict[str, Any]
+ ) -> list[err.InvalidTypeTags]:
  if schema.tag_column_names is not None:
  # should mirror server side
  allowed_datatypes = (
@@ -1595,13 +1639,12 @@
  pa.int8(),
  pa.null(),
  )
- wrong_type_cols = []
- for col in schema.tag_column_names:
- if (
- col in column_types
- and column_types[col] not in allowed_datatypes
- ):
- wrong_type_cols.append(col)
+ wrong_type_cols = [
+ col
+ for col in schema.tag_column_names
+ if col in column_types
+ and column_types[col] not in allowed_datatypes
+ ]
  if wrong_type_cols:
  return [
  err.InvalidTypeTags(
@@ -1612,8 +1655,8 @@

  @staticmethod
  def _check_type_shap_values(
- schema: Schema, column_types: Dict[str, Any]
- ) -> List[err.InvalidTypeShapValues]:
+ schema: Schema, column_types: dict[str, Any]
+ ) -> list[err.InvalidTypeShapValues]:
  if schema.shap_values_column_names is not None:
  # should mirror server side
  allowed_datatypes = (
@@ -1622,13 +1665,12 @@
  pa.float32(),
  pa.int32(),
  )
- wrong_type_cols = []
- for _, col in schema.shap_values_column_names.items():
- if (
- col in column_types
- and column_types[col] not in allowed_datatypes
- ):
- wrong_type_cols.append(col)
+ wrong_type_cols = [
+ col
+ for col in schema.shap_values_column_names.values()
+ if col in column_types
+ and column_types[col] not in allowed_datatypes
+ ]
  if wrong_type_cols:
  return [
  err.InvalidTypeShapValues(
@@ -1639,8 +1681,8 @@

  @staticmethod
  def _check_type_pred_act_labels(
- model_type: ModelTypes, schema: Schema, column_types: Dict[str, Any]
- ) -> List[err.InvalidType]:
+ model_type: ModelTypes, schema: Schema, column_types: dict[str, Any]
+ ) -> list[err.InvalidType]:
  errors = []
  columns = (
  ("Prediction labels", schema.prediction_label_column_name),
@@ -1703,8 +1745,8 @@

  @staticmethod
  def _check_type_pred_act_scores(
- model_type: ModelTypes, schema: Schema, column_types: Dict[str, Any]
- ) -> List[err.InvalidType]:
+ model_type: ModelTypes, schema: Schema, column_types: dict[str, Any]
+ ) -> list[err.InvalidType]:
  errors = []
  columns = (
  ("Prediction scores", schema.prediction_score_column_name),
@@ -1743,13 +1785,14 @@

  @staticmethod
  def _check_type_multi_class_pred_threshold_act_scores(
- schema: Schema, column_types: Dict[str, Any]
- ) -> List[err.InvalidType]:
- """
- Check type for prediction / threshold / actual scores for multiclass model
- Expect the scores to be a list of pyarrow structs that contains field "class_name" and field "score
- Where class_name is a string and score is a number
- Example: '[{"class_name": "class1", "score": 0.1}, {"class_name": "class2", "score": 0.2}, ...]'
+ schema: Schema, column_types: dict[str, Any]
+ ) -> list[err.InvalidType]:
+ """Check type for prediction / threshold / actual scores for multiclass model.
+
+ Expect the scores to be a list of pyarrow structs that contains field
+ "class_name" and field "score", where class_name is a string and score
+ is a number.
+ Example: '[{"class_name": "class1", "score": 0.1}, ...]'
  """
  errors = []
  columns = (
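
Note: the reflowed docstring above describes the expected multiclass score layout. A sketch (assuming pyarrow) of a column that would satisfy the type check:

    import pyarrow as pa

    # Each cell is a list of {class_name, score} structs, per the docstring.
    scores_type = pa.list_(
        pa.struct([("class_name", pa.string()), ("score", pa.float64())])
    )
    scores = pa.array(
        [[{"class_name": "class1", "score": 0.1},
          {"class_name": "class2", "score": 0.2}]],
        type=scores_type,
    )
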
@@ -1802,8 +1845,8 @@

  @staticmethod
  def _check_type_prompt_response(
- schema: Schema, column_types: Dict[str, Any]
- ) -> List[err.InvalidTypeColumns]:
+ schema: Schema, column_types: dict[str, Any]
+ ) -> list[err.InvalidTypeColumns]:
  fields_to_check = []
  if schema.prompt_column_names is not None:
  fields_to_check.append(schema.prompt_column_names)
@@ -1872,8 +1915,8 @@

  @staticmethod
  def _check_type_llm_prompt_templates(
- schema: Schema, column_types: Dict[str, Any]
- ) -> List[err.InvalidTypeColumns]:
+ schema: Schema, column_types: dict[str, Any]
+ ) -> list[err.InvalidTypeColumns]:
  if schema.prompt_template_column_names is None:
  return []

@@ -1913,8 +1956,8 @@

  @staticmethod
  def _check_type_llm_config(
- schema: Schema, column_types: Dict[str, Any]
- ) -> List[err.InvalidTypeColumns]:
+ schema: Schema, column_types: dict[str, Any]
+ ) -> list[err.InvalidTypeColumns]:
  if schema.llm_config_column_names is None:
  return []

@@ -1950,8 +1993,8 @@

  @staticmethod
  def _check_type_llm_run_metadata(
- schema: Schema, column_types: Dict[str, Any]
- ) -> List[err.InvalidTypeColumns]:
+ schema: Schema, column_types: dict[str, Any]
+ ) -> list[err.InvalidTypeColumns]:
  if schema.llm_run_metadata_column_names is None:
  return []

@@ -2023,8 +2066,8 @@

  @staticmethod
  def _check_type_retrieved_document_ids(
- schema: Schema, column_types: Dict[str, Any]
- ) -> List[err.InvalidType]:
+ schema: Schema, column_types: dict[str, Any]
+ ) -> list[err.InvalidType]:
  col = schema.retrieved_document_ids_column_name
  if col in column_types:
  # should mirror server side
@@ -2044,8 +2087,8 @@

  @staticmethod
  def _check_type_image_segment_coordinates(
- schema: Schema, column_types: Dict[str, Any]
- ) -> List[err.InvalidTypeColumns]:
+ schema: Schema, column_types: dict[str, Any]
+ ) -> list[err.InvalidTypeColumns]:
  # should mirror server side
  allowed_coordinate_types = (
  pa.list_(pa.list_(pa.float64())),
@@ -2090,9 +2133,8 @@
  wrong_type_cols.append(coord_col)

  if schema.instance_segmentation_prediction_column_names is not None:
- polygons_coord_col = (
- schema.instance_segmentation_prediction_column_names.polygon_coordinates_column_name # noqa: E501
- )
+ inst_seg_pred = schema.instance_segmentation_prediction_column_names
+ polygons_coord_col = inst_seg_pred.polygon_coordinates_column_name
  if (
  polygons_coord_col in column_types
  and column_types[polygons_coord_col]
@@ -2101,7 +2143,7 @@
  wrong_type_cols.append(polygons_coord_col)

  bbox_coord_col = (
- schema.instance_segmentation_prediction_column_names.bounding_boxes_coordinates_column_name # noqa: E501
+ inst_seg_pred.bounding_boxes_coordinates_column_name
  )
  if (
  bbox_coord_col in column_types
@@ -2110,9 +2152,8 @@
  wrong_type_cols.append(bbox_coord_col)

  if schema.instance_segmentation_actual_column_names is not None:
- coord_col = (
- schema.instance_segmentation_actual_column_names.polygon_coordinates_column_name # noqa: E501
- )
+ inst_seg_actual = schema.instance_segmentation_actual_column_names
+ coord_col = inst_seg_actual.polygon_coordinates_column_name
  if (
  coord_col in column_types
  and column_types[coord_col] not in allowed_coordinate_types
@@ -2120,7 +2161,7 @@
  wrong_type_cols.append(coord_col)

  bbox_coord_col = (
- schema.instance_segmentation_actual_column_names.bounding_boxes_coordinates_column_name # noqa: E501
+ inst_seg_actual.bounding_boxes_coordinates_column_name
  )
  if (
  bbox_coord_col in column_types
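
Note: the four hunks above replace `# noqa: E501` suppressions with a short local alias for a deeply nested attribute. A self-contained sketch of the pattern, using a hypothetical stand-in class:

    from dataclasses import dataclass


    @dataclass
    class SegmentationColumnNames:  # hypothetical stand-in for the real class
        polygon_coordinates_column_name: str
        bounding_boxes_coordinates_column_name: str


    inst_seg_actual = SegmentationColumnNames("polygons", "bboxes")
    # A short local name keeps each access under the line-length limit,
    # removing the need for noqa suppressions.
    coord_col = inst_seg_actual.polygon_coordinates_column_name
    bbox_coord_col = inst_seg_actual.bounding_boxes_coordinates_column_name
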
@@ -2141,8 +2182,8 @@ class Validator:
2141
2182
 
2142
2183
  @staticmethod
2143
2184
  def _check_type_image_segment_categories(
2144
- schema: Schema, column_types: Dict[str, Any]
2145
- ) -> List[err.InvalidTypeColumns]:
2185
+ schema: Schema, column_types: dict[str, Any]
2186
+ ) -> list[err.InvalidTypeColumns]:
2146
2187
  # should mirror server side
2147
2188
  allowed_category_datatypes = (
2148
2189
  pa.list_(pa.string()),
@@ -2210,8 +2251,8 @@ class Validator:
2210
2251
 
2211
2252
  @staticmethod
2212
2253
  def _check_type_image_segment_scores(
2213
- schema: Schema, column_types: Dict[str, Any]
2214
- ) -> List[err.InvalidTypeColumns]:
2254
+ schema: Schema, column_types: dict[str, Any]
2255
+ ) -> list[err.InvalidTypeColumns]:
2215
2256
  # should mirror server side
2216
2257
  allowed_score_datatypes = (
2217
2258
  pa.list_(pa.float64()),
@@ -2270,7 +2311,7 @@ class Validator:
2270
2311
  @staticmethod
2271
2312
  def _check_embedding_vectors_dimensionality(
2272
2313
  dataframe: pd.DataFrame, schema: Schema
2273
- ) -> List[err.ValidationError]:
2314
+ ) -> list[err.ValidationError]:
2274
2315
  if schema.embedding_feature_column_names is None:
2275
2316
  return []
2276
2317
 
@@ -2300,7 +2341,7 @@ class Validator:
2300
2341
  @staticmethod
2301
2342
  def _check_embedding_raw_data_characters(
2302
2343
  dataframe: pd.DataFrame, schema: Schema
2303
- ) -> List[err.ValidationError]:
2344
+ ) -> list[err.ValidationError]:
2304
2345
  if schema.embedding_feature_column_names is None:
2305
2346
  return []
2306
2347
 
@@ -2322,7 +2363,7 @@ class Validator:
2322
2363
  invalid_long_string_data_cols
2323
2364
  )
2324
2365
  ]
2325
- elif truncated_long_string_data_cols:
2366
+ if truncated_long_string_data_cols:
2326
2367
  logger.warning(
2327
2368
  get_truncation_warning_message(
2328
2369
  "Embedding raw data fields",
@@ -2334,7 +2375,7 @@ class Validator:
2334
2375
  @staticmethod
2335
2376
  def _check_value_rank(
2336
2377
  dataframe: pd.DataFrame, schema: Schema
2337
- ) -> List[err.InvalidRankValue]:
2378
+ ) -> list[err.InvalidRankValue]:
2338
2379
  col = schema.rank_column_name
2339
2380
  lbound, ubound = (1, 100)
2340
2381
 
@@ -2346,11 +2387,11 @@ class Validator:
2346
2387
 
2347
2388
  @staticmethod
2348
2389
  def _check_id_field_str_length(
2349
- dataframe: pd.DataFrame, schema_name: str, id_col_name: Optional[str]
2350
- ) -> List[err.ValidationError]:
2351
- """
2352
- Require prediction_id to be a string of length between MIN_PREDICTION_ID_LEN
2353
- and MAX_PREDICTION_ID_LEN
2390
+ dataframe: pd.DataFrame, schema_name: str, id_col_name: str | None
2391
+ ) -> list[err.ValidationError]:
2392
+ """Require prediction_id to be a string of length between MIN and MAX.
2393
+
2394
+ Between MIN_PREDICTION_ID_LEN and MAX_PREDICTION_ID_LEN.
2354
2395
  """
2355
2396
  # We check whether the column name can be None is allowed in `Validator.validate_params`
2356
2397
  if id_col_name is None:
@@ -2380,11 +2421,11 @@ class Validator:
2380
2421
 
2381
2422
  @staticmethod
2382
2423
  def _check_document_id_field_str_length(
2383
- dataframe: pd.DataFrame, schema_name: str, id_col_name: Optional[str]
2384
- ) -> List[err.ValidationError]:
2385
- """
2386
- Require document id to be a string of length between MIN_DOCUMENT_ID_LEN
2387
- and MAX_DOCUMENT_ID_LEN
2424
+ dataframe: pd.DataFrame, schema_name: str, id_col_name: str | None
2425
+ ) -> list[err.ValidationError]:
2426
+ """Require document id to be a string of length between MIN and MAX.
2427
+
2428
+ Between MIN_DOCUMENT_ID_LEN and MAX_DOCUMENT_ID_LEN.
2388
2429
  """
2389
2430
  # We check whether the column name can be None is allowed in `Validator.validate_params`
2390
2431
  if id_col_name is None:
@@ -2433,7 +2474,7 @@ class Validator:
2433
2474
  @staticmethod
2434
2475
  def _check_value_tag(
2435
2476
  dataframe: pd.DataFrame, schema: Schema
2436
- ) -> List[err.InvalidTagLength]:
2477
+ ) -> list[err.InvalidTagLength]:
2437
2478
  if schema.tag_column_names is None:
2438
2479
  return []
2439
2480
 
@@ -2459,7 +2500,7 @@ class Validator:
  truncated_tag_cols.append(col)
  if wrong_tag_cols:
  return [err.InvalidTagLength(wrong_tag_cols)]
- elif truncated_tag_cols:
+ if truncated_tag_cols:
  logger.warning(
  get_truncation_warning_message(
  "tags", MAX_TAG_LENGTH_TRUNCATION
@@ -2470,9 +2511,7 @@ class Validator:
  @staticmethod
  def _check_value_ranking_category(
  dataframe: pd.DataFrame, schema: Schema
- ) -> List[
- Union[err.InvalidValueMissingValue, err.InvalidRankingCategoryValue]
- ]:
+ ) -> list[err.InvalidValueMissingValue | err.InvalidRankingCategoryValue]:
  if schema.relevance_labels_column_name is not None:
  col = schema.relevance_labels_column_name
  elif schema.attributions_column_name is not None:
@@ -2503,7 +2542,7 @@ class Validator:
  @staticmethod
  def _check_length_multi_class_maps(
  dataframe: pd.DataFrame, schema: Schema
- ) -> List[err.InvalidNumClassesMultiClassMap]:
+ ) -> list[err.InvalidNumClassesMultiClassMap]:
  # each entry in column is a list of dictionaries mapping class names and scores
  # validate length of list of dictionaries for each column
  invalid_cols = {}
@@ -2540,15 +2579,13 @@ class Validator:
  @staticmethod
  def _check_classes_and_scores_values_in_multi_class_maps(
  dataframe: pd.DataFrame, schema: Schema
- ) -> List[
- Union[
- err.InvalidMultiClassClassNameLength,
- err.InvalidMultiClassActScoreValue,
- err.InvalidMultiClassPredScoreValue,
- ]
+ ) -> list[
+ err.InvalidMultiClassClassNameLength
+ | err.InvalidMultiClassActScoreValue
+ | err.InvalidMultiClassPredScoreValue
  ]:
- """
- Validate the class names and score values of dictionaries:
+ """Validate the class names and score values of dictionaries.
+
  - class name length
  - valid actual score
  - valid prediction / threshold score
@@ -2624,11 +2661,12 @@ class Validator:
  @staticmethod
  def _check_each_multi_class_pred_has_threshold(
  dataframe: pd.DataFrame, schema: Schema
- ) -> List[err.InvalidMultiClassThresholdClasses]:
- """
- For Multi Class, if threshold scores col is included in schema and dataframe,
- validate for each prediction score received, the associated threshold score
- for that class was received
+ ) -> list[err.InvalidMultiClassThresholdClasses]:
+ """Validate threshold scores for Multi Class models.
+
+ If threshold scores column is included in schema and dataframe, validate that
+ for each prediction score received, the associated threshold score for that
+ class was also received.
  """
  threshold_col = schema.multi_class_threshold_scores_column_name
  if threshold_col is None:
@@ -2657,10 +2695,10 @@ class Validator:
  def _check_value_timestamp(
  dataframe: pd.DataFrame,
  schema: Schema,
- ) -> List[Union[err.InvalidValueMissingValue, err.InvalidValueTimestamp]]:
+ ) -> list[err.InvalidValueMissingValue | err.InvalidValueTimestamp]:
  # Due to the timing difference between checking this here and the data finally
  # hitting the same check on server side, there's a some chance for a false
- # result, i.e. the check here suceeeds but the same check on server side fails.
+ # result, i.e. the check here succeeds but the same check on server side fails.
  col = schema.timestamp_column_name
  if col is not None and col in dataframe.columns:
  # When a timestamp column has Date and NaN, pyarrow will be fine, but
@@ -2673,19 +2711,15 @@ class Validator:
  )
  ]

- now_t = datetime.datetime.now()
+ now_t = datetime.now(tz=timezone.utc)
  lbound, ubound = (
  (
  now_t
- - datetime.timedelta(
- days=MAX_PAST_YEARS_FROM_CURRENT_TIME * 365
- )
+ - timedelta(days=MAX_PAST_YEARS_FROM_CURRENT_TIME * 365)
  ).timestamp(),
  (
  now_t
- + datetime.timedelta(
- days=MAX_FUTURE_YEARS_FROM_CURRENT_TIME * 365
- )
+ + timedelta(days=MAX_FUTURE_YEARS_FROM_CURRENT_TIME * 365)
  ).timestamp(),
  )
  # faster than pyarrow compute
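The hunk above replaces a naive local-time `datetime.datetime.now()` with a timezone-aware UTC timestamp and drops the module-qualified `datetime.timedelta` calls in favor of directly imported names; this implies the module's imports changed to something like `from datetime import datetime, timedelta, timezone` (the import hunk is not shown here). A runnable sketch of the bounds logic, with illustrative constants in place of the package's:

    # Sketch of the timestamp bounds check after the change. The year limits
    # below are illustrative, not the package's MAX_*_YEARS constants.
    from datetime import datetime, timedelta, timezone

    MAX_PAST_YEARS = 2
    MAX_FUTURE_YEARS = 1

    now_t = datetime.now(tz=timezone.utc)  # aware; independent of host timezone
    lbound = (now_t - timedelta(days=MAX_PAST_YEARS * 365)).timestamp()
    ubound = (now_t + timedelta(days=MAX_FUTURE_YEARS * 365)).timestamp()
    assert lbound < now_t.timestamp() < ubound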
@@ -2767,7 +2801,7 @@ class Validator:
  @staticmethod
  def _check_invalid_missing_values(
  dataframe: pd.DataFrame, schema: BaseSchema, model_type: ModelTypes
- ) -> List[err.InvalidValueMissingValue]:
+ ) -> list[err.InvalidValueMissingValue]:
  errors = []
  columns = ()
  if isinstance(schema, CorpusSchema):
@@ -2814,7 +2848,7 @@ class Validator:
  environment: Environments,
  schema: Schema,
  model_type: ModelTypes,
- ) -> List[err.InvalidRecord]:
+ ) -> list[err.InvalidRecord]:
  if environment in (Environments.VALIDATION, Environments.TRAINING):
  return []

@@ -2858,11 +2892,11 @@ class Validator:
  environment: Environments,
  schema: Schema,
  model_type: ModelTypes,
- ) -> List[err.InvalidRecord]:
- """
- Validates there's not a single row in the dataframe with pred_label, pred_score all
- evaluates to null OR with actual_label, actual_score all evaluates to null and returns
- errors if either of the two cases exists
+ ) -> list[err.InvalidRecord]:
+ """Validates there's not a single row in the dataframe with all nulls.
+
+ Returns errors if any row has all of pred_label and pred_score evaluating to
+ null, OR all of actual_label and actual_score evaluating to null.
  """
  if environment == Environments.PRODUCTION:
  return []
@@ -2905,21 +2939,23 @@ class Validator:

  @staticmethod
  def _check_invalid_record_helper(
- dataframe: pd.DataFrame, column_names: List[Optional[str]]
- ) -> List[err.InvalidRecord]:
- """
- This function checks that there are no null values in a subset of columns,
- returning an error if so. The column subset is computed from the input list of
- columns `column_names` that are not None and that are present in the dataframe
+ dataframe: pd.DataFrame, column_names: list[str | None]
+ ) -> list[err.InvalidRecord]:
+ """Check that there are no null values in a subset of columns.
+
+ The column subset is computed from the input list of columns `column_names`
+ that are not None and that are present in the dataframe. Returns an error if
+ null values are found.

  Returns:
  List[err.InvalidRecord]: An error expressing the rows that are problematic

  """
- columns_subset = []
- for col in column_names:
- if col is not None and col in dataframe.columns:
- columns_subset.append(col)
+ columns_subset = [
+ col
+ for col in column_names
+ if col is not None and col in dataframe.columns
+ ]
  if len(columns_subset) == 0:
  return []
  null_filter = dataframe[columns_subset].isnull().all(axis=1)
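The comprehension above filters the candidate column names down to those that are set and present, and the `isnull().all(axis=1)` mask flags rows where every remaining column is null. A small self-contained pandas illustration (column names and data are made up):

    # Row-wise all-null check in the style of _check_invalid_record_helper.
    import pandas as pd

    df = pd.DataFrame({"pred_label": [None, "a"], "pred_score": [None, 0.9]})
    candidates = ["pred_label", "pred_score", None, "missing_col"]
    subset = [c for c in candidates if c is not None and c in df.columns]
    null_filter = df[subset].isnull().all(axis=1)  # True only where all are null
    assert null_filter.tolist() == [True, False]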
@@ -2930,8 +2966,8 @@ class Validator:

  @staticmethod
  def _check_type_prediction_group_id(
- schema: Schema, column_types: Dict[str, Any]
- ) -> List[err.InvalidType]:
+ schema: Schema, column_types: dict[str, Any]
+ ) -> list[err.InvalidType]:
  col = schema.prediction_group_id_column_name
  if col in column_types:
  # should mirror server side
@@ -2954,8 +2990,8 @@ class Validator:

  @staticmethod
  def _check_type_rank(
- schema: Schema, column_types: Dict[str, Any]
- ) -> List[err.InvalidType]:
+ schema: Schema, column_types: dict[str, Any]
+ ) -> list[err.InvalidType]:
  col = schema.rank_column_name
  if col in column_types:
  allowed_datatypes = (
@@ -2976,8 +3012,8 @@ class Validator:

  @staticmethod
  def _check_type_ranking_category(
- schema: Schema, column_types: Dict[str, Any]
- ) -> List[err.InvalidType]:
+ schema: Schema, column_types: dict[str, Any]
+ ) -> list[err.InvalidType]:
  if schema.relevance_labels_column_name is not None:
  col = schema.relevance_labels_column_name
  elif schema.attributions_column_name is not None:
@@ -2999,7 +3035,7 @@ class Validator:
  @staticmethod
  def _check_value_bounding_boxes_coordinates(
  dataframe: pd.DataFrame, schema: Schema
- ) -> List[err.InvalidBoundingBoxesCoordinates]:
+ ) -> list[err.InvalidBoundingBoxesCoordinates]:
  errors = []
  if schema.object_detection_prediction_column_names is not None:
  coords_col_name = schema.object_detection_prediction_column_names.bounding_boxes_coordinates_column_name # noqa: E501
@@ -3020,7 +3056,7 @@ class Validator:
  @staticmethod
  def _check_value_bounding_boxes_categories(
  dataframe: pd.DataFrame, schema: Schema
- ) -> List[err.InvalidBoundingBoxesCategories]:
+ ) -> list[err.InvalidBoundingBoxesCategories]:
  errors = []
  if schema.object_detection_prediction_column_names is not None:
  cat_col_name = schema.object_detection_prediction_column_names.categories_column_name
@@ -3041,7 +3077,7 @@ class Validator:
  @staticmethod
  def _check_value_bounding_boxes_scores(
  dataframe: pd.DataFrame, schema: Schema
- ) -> List[err.InvalidBoundingBoxesScores]:
+ ) -> list[err.InvalidBoundingBoxesScores]:
  errors = []
  if schema.object_detection_prediction_column_names is not None:
  sc_col_name = schema.object_detection_prediction_column_names.scores_column_name
@@ -3066,7 +3102,7 @@ class Validator:
  @staticmethod
  def _check_value_semantic_segmentation_polygon_coordinates(
  dataframe: pd.DataFrame, schema: Schema
- ) -> List[err.InvalidPolygonCoordinates]:
+ ) -> list[err.InvalidPolygonCoordinates]:
  errors = []
  if schema.semantic_segmentation_prediction_column_names is not None:
  coords_col_name = schema.semantic_segmentation_prediction_column_names.polygon_coordinates_column_name # noqa: E501
@@ -3076,7 +3112,7 @@ class Validator:
  if error is not None:
  errors.append(error)
  if schema.semantic_segmentation_actual_column_names is not None:
- coords_col_name = schema.semantic_segmentation_actual_column_names.polygon_coordinates_column_name # noqa: E501
+ coords_col_name = schema.semantic_segmentation_actual_column_names.polygon_coordinates_column_name
  error = _check_value_polygon_coordinates_helper(
  dataframe[coords_col_name]
  )
@@ -3087,7 +3123,7 @@ class Validator:
  @staticmethod
  def _check_value_semantic_segmentation_polygon_categories(
  dataframe: pd.DataFrame, schema: Schema
- ) -> List[err.InvalidPolygonCategories]:
+ ) -> list[err.InvalidPolygonCategories]:
  errors = []
  if schema.semantic_segmentation_prediction_column_names is not None:
  cat_col_name = schema.semantic_segmentation_prediction_column_names.categories_column_name
@@ -3108,7 +3144,7 @@ class Validator:
  @staticmethod
  def _check_value_instance_segmentation_polygon_coordinates(
  dataframe: pd.DataFrame, schema: Schema
- ) -> List[err.InvalidPolygonCoordinates]:
+ ) -> list[err.InvalidPolygonCoordinates]:
  errors = []
  if schema.instance_segmentation_prediction_column_names is not None:
  coords_col_name = schema.instance_segmentation_prediction_column_names.polygon_coordinates_column_name # noqa: E501
@@ -3118,7 +3154,7 @@ class Validator:
  if error is not None:
  errors.append(error)
  if schema.instance_segmentation_actual_column_names is not None:
- coords_col_name = schema.instance_segmentation_actual_column_names.polygon_coordinates_column_name # noqa: E501
+ coords_col_name = schema.instance_segmentation_actual_column_names.polygon_coordinates_column_name
  error = _check_value_polygon_coordinates_helper(
  dataframe[coords_col_name]
  )
@@ -3129,7 +3165,7 @@ class Validator:
  @staticmethod
  def _check_value_instance_segmentation_polygon_categories(
  dataframe: pd.DataFrame, schema: Schema
- ) -> List[err.InvalidPolygonCategories]:
+ ) -> list[err.InvalidPolygonCategories]:
  errors = []
  if schema.instance_segmentation_prediction_column_names is not None:
  cat_col_name = schema.instance_segmentation_prediction_column_names.categories_column_name
@@ -3150,7 +3186,7 @@ class Validator:
  @staticmethod
  def _check_value_instance_segmentation_polygon_scores(
  dataframe: pd.DataFrame, schema: Schema
- ) -> List[err.InvalidPolygonScores]:
+ ) -> list[err.InvalidPolygonScores]:
  errors = []
  if schema.instance_segmentation_prediction_column_names is not None:
  sc_col_name = schema.instance_segmentation_prediction_column_names.scores_column_name
@@ -3165,7 +3201,7 @@ class Validator:
  @staticmethod
  def _check_value_instance_segmentation_bbox_coordinates(
  dataframe: pd.DataFrame, schema: Schema
- ) -> List[err.InvalidBoundingBoxesCoordinates]:
+ ) -> list[err.InvalidBoundingBoxesCoordinates]:
  errors = []
  if schema.instance_segmentation_prediction_column_names is not None:
  coords_col_name = schema.instance_segmentation_prediction_column_names.bounding_boxes_coordinates_column_name # noqa: E501
@@ -3188,7 +3224,7 @@ class Validator:
  @staticmethod
  def _check_value_prompt_response(
  dataframe: pd.DataFrame, schema: Schema
- ) -> List[err.ValidationError]:
+ ) -> list[err.ValidationError]:
  vector_cols_to_check = []
  text_cols_to_check = []
  if isinstance(schema.prompt_column_names, str):
@@ -3253,7 +3289,7 @@ class Validator:
  @staticmethod
  def _check_value_llm_model_name(
  dataframe: pd.DataFrame, schema: Schema
- ) -> List[err.InvalidStringLengthInColumn]:
+ ) -> list[err.InvalidStringLengthInColumn]:
  if schema.llm_config_column_names is None:
  return []
  col = schema.llm_config_column_names.model_column_name
@@ -3270,7 +3306,7 @@ class Validator:
  max_length=MAX_LLM_MODEL_NAME_LENGTH,
  )
  ]
- elif max_len > MAX_LLM_MODEL_NAME_LENGTH_TRUNCATION:
+ if max_len > MAX_LLM_MODEL_NAME_LENGTH_TRUNCATION:
  logger.warning(
  get_truncation_warning_message(
  "LLM model names", MAX_LLM_MODEL_NAME_LENGTH_TRUNCATION
@@ -3281,7 +3317,7 @@ class Validator:
  @staticmethod
  def _check_value_llm_prompt_template(
  dataframe: pd.DataFrame, schema: Schema
- ) -> List[err.InvalidStringLengthInColumn]:
+ ) -> list[err.InvalidStringLengthInColumn]:
  if schema.prompt_template_column_names is None:
  return []
  col = schema.prompt_template_column_names.template_column_name
@@ -3298,7 +3334,7 @@ class Validator:
  max_length=MAX_PROMPT_TEMPLATE_LENGTH,
  )
  ]
- elif max_len > MAX_PROMPT_TEMPLATE_LENGTH_TRUNCATION:
+ if max_len > MAX_PROMPT_TEMPLATE_LENGTH_TRUNCATION:
  logger.warning(
  get_truncation_warning_message(
  "prompt templates",
@@ -3310,7 +3346,7 @@ class Validator:
  @staticmethod
  def _check_value_llm_prompt_template_version(
  dataframe: pd.DataFrame, schema: Schema
- ) -> List[err.InvalidStringLengthInColumn]:
+ ) -> list[err.InvalidStringLengthInColumn]:
  if schema.prompt_template_column_names is None:
  return []
  col = schema.prompt_template_column_names.template_version_column_name
@@ -3327,7 +3363,7 @@ class Validator:
  max_length=MAX_PROMPT_TEMPLATE_VERSION_LENGTH,
  )
  ]
- elif max_len > MAX_PROMPT_TEMPLATE_VERSION_LENGTH_TRUNCATION:
+ if max_len > MAX_PROMPT_TEMPLATE_VERSION_LENGTH_TRUNCATION:
  logger.warning(
  get_truncation_warning_message(
  "prompt template versions",
@@ -3338,8 +3374,8 @@ class Validator:

  @staticmethod
  def _check_type_document_columns(
- schema: CorpusSchema, column_types: Dict[str, Any]
- ) -> List[err.InvalidTypeColumns]:
+ schema: CorpusSchema, column_types: dict[str, Any]
+ ) -> list[err.InvalidTypeColumns]:
  invalid_types = []
  # Check document id
  col = schema.document_id_column_name
@@ -3424,16 +3460,15 @@ class Validator:
  return []


- def _check_value_string_length_helper(x):
+ def _check_value_string_length_helper(x: object) -> int:
  if isinstance(x, str):
  return len(x)
- else:
- return 0
+ return 0


  def _check_value_vector_dimensionality_helper(
- dataframe: pd.DataFrame, cols_to_check: List[str]
- ) -> Tuple[List[str], List[str]]:
+ dataframe: pd.DataFrame, cols_to_check: list[str]
+ ) -> tuple[list[str], list[str]]:
  invalid_low_dimensionality_vector_cols = []
  invalid_high_dimensionality_vector_cols = []
  for col in cols_to_check:
@@ -3452,8 +3487,8 @@ def _check_value_vector_dimensionality_helper(


  def _check_value_raw_data_length_helper(
- dataframe: pd.DataFrame, cols_to_check: List[str]
- ) -> Tuple[List[str], List[str]]:
+ dataframe: pd.DataFrame, cols_to_check: list[str]
+ ) -> tuple[list[str], list[str]]:
  invalid_long_string_data_cols = []
  truncated_long_string_data_cols = []
  for col in cols_to_check:
@@ -3469,7 +3504,7 @@ def _check_value_raw_data_length_helper(
  )
  except TypeError as exc:
  e = TypeError(f"Cannot validate the column '{col}'. " + str(exc))
- logger.error(e)
+ logger.exception(e)
  raise e from exc
  if max_data_len > MAX_RAW_DATA_CHARACTERS:
  invalid_long_string_data_cols.append(col)
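The switch from `logger.error(e)` to `logger.exception(e)` matters only inside an `except` block: both log at ERROR level, but `logger.exception` also appends the active traceback. A minimal demonstration (the logger name is illustrative):

    # logger.exception attaches the in-flight traceback; logger.error does not.
    import logging

    logging.basicConfig(level=logging.ERROR)
    logger = logging.getLogger("validation_demo")

    try:
        len(3)  # raises TypeError
    except TypeError as exc:
        e = TypeError("Cannot validate the column 'demo'. " + str(exc))
        logger.exception(e)  # message plus traceback of the TypeError
        # logger.error(e) would emit only the message line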
@@ -3480,8 +3515,8 @@

  def _check_value_bounding_boxes_coordinates_helper(
  coordinates_col: pd.Series,
- ) -> Union[err.InvalidBoundingBoxesCoordinates, None]:
- def check(boxes):
+ ) -> err.InvalidBoundingBoxesCoordinates | None:
+ def check(boxes: object) -> None:
  # We allow for zero boxes. None coordinates list is not allowed (will break following tests:
  # 'NoneType is not iterable')
  if boxes is None:
@@ -3502,7 +3537,9 @@ def _check_value_bounding_boxes_coordinates_helper(
  return None


- def _box_coordinates_wrong_format(box_coords):
+ def _box_coordinates_wrong_format(
+ box_coords: object,
+ ) -> err.InvalidBoundingBoxesCoordinates | None:
  if (
  # Coordinates should be a collection of 4 floats
  len(box_coords) != 4
@@ -3516,12 +3553,13 @@ def _box_coordinates_wrong_format(box_coords):
  return err.InvalidBoundingBoxesCoordinates(
  reason="boxes_coordinates_wrong_format"
  )
+ return None


  def _check_value_bounding_boxes_categories_helper(
  categories_col: pd.Series,
- ) -> Union[err.InvalidBoundingBoxesCategories, None]:
- def check(categories):
+ ) -> err.InvalidBoundingBoxesCategories | None:
+ def check(categories: object) -> None:
  # We allow for zero boxes. None category list is not allowed (will break following tests:
  # 'NoneType is not iterable')
  if categories is None:
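The `return None` added above makes the no-error path explicit, so every branch of an `ErrorType | None` checker returns visibly. A compact sketch of that shape, using a hypothetical InvalidBoxFormat class rather than the package's err types:

    # Hypothetical error-or-None checker mirroring the explicit-return style.
    class InvalidBoxFormat:
        def __init__(self, reason: str) -> None:
            self.reason = reason

    def box_wrong_format(box_coords: list[float]) -> InvalidBoxFormat | None:
        # Coordinates should be a collection of 4 floats
        if len(box_coords) != 4 or not all(
            isinstance(c, float) for c in box_coords
        ):
            return InvalidBoxFormat(reason="boxes_coordinates_wrong_format")
        return None  # explicit success path, matching the change above

    assert box_wrong_format([0.0, 0.0, 1.0, 1.0]) is None
    assert box_wrong_format([0.0, 1.0]) is not None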
@@ -3542,8 +3580,8 @@ def _check_value_bounding_boxes_categories_helper(

  def _check_value_bounding_boxes_scores_helper(
  scores_col: pd.Series,
- ) -> Union[err.InvalidBoundingBoxesScores, None]:
- def check(scores):
+ ) -> err.InvalidBoundingBoxesScores | None:
+ def check(scores: object) -> None:
  # We allow for zero boxes. None confidence score list is not allowed (will break following tests:
  # 'NoneType is not iterable')
  if scores is None:
@@ -3562,9 +3600,10 @@ def _check_value_bounding_boxes_scores_helper(
  return None


- def _polygon_coordinates_wrong_format(polygon_coords):
- """
- Check if polygon coordinates are valid.
+ def _polygon_coordinates_wrong_format(
+ polygon_coords: object,
+ ) -> err.InvalidPolygonCoordinates | None:
+ """Check if polygon coordinates are valid.

  Validates:
  - Has at least 3 vertices (6 coordinates)
@@ -3610,9 +3649,9 @@ def _polygon_coordinates_wrong_format(polygon_coords):

  # Check for self-intersections
  # We need to check if any two non-adjacent edges intersect
- edges = []
- for i in range(len(points)):
- edges.append((points[i], points[(i + 1) % len(points)]))
+ edges = [
+ (points[i], points[(i + 1) % len(points)]) for i in range(len(points))
+ ]

  for i in range(len(edges)):
  for j in range(i + 2, len(edges)):
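The edge list above pairs each vertex with its successor (wrapping via the modulo), and the nested loop compares non-adjacent edge pairs (`j` starts at `i + 2`). The intersection predicate itself falls outside this hunk; a standard way to implement one is the orientation (counter-clockwise) test, sketched here as an assumption about the general technique rather than the package's actual helper:

    # Orientation-based segment intersection test; a common implementation of
    # the non-adjacent-edge check described in the comments above.
    def _ccw(a, b, c):
        # > 0 if a->b->c turns counter-clockwise, < 0 if clockwise, 0 if collinear
        return (b[0] - a[0]) * (c[1] - a[1]) - (b[1] - a[1]) * (c[0] - a[0])

    def segments_intersect(p1, p2, p3, p4):
        # Proper crossing only; shared endpoints and collinear touches return False.
        d1, d2 = _ccw(p3, p4, p1), _ccw(p3, p4, p2)
        d3, d4 = _ccw(p1, p2, p3), _ccw(p1, p2, p4)
        return d1 * d2 < 0 and d3 * d4 < 0

    points = [(0, 0), (2, 0), (2, 2), (0, 2)]  # a simple, non-self-intersecting square
    edges = [(points[i], points[(i + 1) % len(points)]) for i in range(len(points))]
    assert not any(
        segments_intersect(*edges[i], *edges[j])
        for i in range(len(edges))
        for j in range(i + 2, len(edges))
    )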
@@ -3634,8 +3673,8 @@ def _polygon_coordinates_wrong_format(polygon_coords):

  def _check_value_polygon_coordinates_helper(
  coordinates_col: pd.Series,
- ) -> Union[err.InvalidPolygonCoordinates, None]:
- def check(polygons):
+ ) -> err.InvalidPolygonCoordinates | None:
+ def check(polygons: object) -> None:
  # We allow for zero polygons. None coordinates list is not allowed (will break following tests:
  # 'NoneType is not iterable')
  if polygons is None:
@@ -3658,8 +3697,8 @@ def _check_value_polygon_coordinates_helper(

  def _check_value_polygon_categories_helper(
  categories_col: pd.Series,
- ) -> Union[err.InvalidPolygonCategories, None]:
- def check(categories):
+ ) -> err.InvalidPolygonCategories | None:
+ def check(categories: object) -> None:
  # We allow for zero boxes. None category list is not allowed (will break following tests:
  # 'NoneType is not iterable')
  if categories is None:
@@ -3678,8 +3717,8 @@ def _check_value_polygon_categories_helper(

  def _check_value_polygon_scores_helper(
  scores_col: pd.Series,
- ) -> Union[err.InvalidPolygonScores, None]:
- def check(scores):
+ ) -> err.InvalidPolygonScores | None:
+ def check(scores: object) -> None:
  # We allow for zero boxes. None confidence score list is not allowed (will break following tests:
  # 'NoneType is not iterable')
  if scores is None:
@@ -3696,7 +3735,7 @@ def _check_value_polygon_scores_helper(
  return None


- def _count_characters_raw_data(data: Union[str, List[str]]) -> int:
+ def _count_characters_raw_data(data: str | list[str]) -> int:
  character_count = 0
  if isinstance(data, str):
  character_count = len(data)