arize 8.0.0a21__py3-none-any.whl → 8.0.0a23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arize/__init__.py +17 -9
- arize/_exporter/client.py +55 -36
- arize/_exporter/parsers/tracing_data_parser.py +41 -30
- arize/_exporter/validation.py +3 -3
- arize/_flight/client.py +208 -77
- arize/_generated/api_client/__init__.py +30 -6
- arize/_generated/api_client/api/__init__.py +1 -0
- arize/_generated/api_client/api/datasets_api.py +864 -190
- arize/_generated/api_client/api/experiments_api.py +167 -131
- arize/_generated/api_client/api/projects_api.py +1197 -0
- arize/_generated/api_client/api_client.py +2 -2
- arize/_generated/api_client/configuration.py +42 -34
- arize/_generated/api_client/exceptions.py +2 -2
- arize/_generated/api_client/models/__init__.py +15 -4
- arize/_generated/api_client/models/dataset.py +10 -10
- arize/_generated/api_client/models/dataset_example.py +111 -0
- arize/_generated/api_client/models/dataset_example_update.py +100 -0
- arize/_generated/api_client/models/dataset_version.py +13 -13
- arize/_generated/api_client/models/datasets_create_request.py +16 -8
- arize/_generated/api_client/models/datasets_examples_insert_request.py +100 -0
- arize/_generated/api_client/models/datasets_examples_list200_response.py +106 -0
- arize/_generated/api_client/models/datasets_examples_update_request.py +102 -0
- arize/_generated/api_client/models/datasets_list200_response.py +10 -4
- arize/_generated/api_client/models/experiment.py +14 -16
- arize/_generated/api_client/models/experiment_run.py +108 -0
- arize/_generated/api_client/models/experiment_run_create.py +102 -0
- arize/_generated/api_client/models/experiments_create_request.py +16 -10
- arize/_generated/api_client/models/experiments_list200_response.py +10 -4
- arize/_generated/api_client/models/experiments_runs_list200_response.py +19 -5
- arize/_generated/api_client/models/{error.py → pagination_metadata.py} +13 -11
- arize/_generated/api_client/models/primitive_value.py +172 -0
- arize/_generated/api_client/models/problem.py +100 -0
- arize/_generated/api_client/models/project.py +99 -0
- arize/_generated/api_client/models/{datasets_list_examples200_response.py → projects_create_request.py} +13 -11
- arize/_generated/api_client/models/projects_list200_response.py +106 -0
- arize/_generated/api_client/rest.py +2 -2
- arize/_generated/api_client/test/test_dataset.py +4 -2
- arize/_generated/api_client/test/test_dataset_example.py +56 -0
- arize/_generated/api_client/test/test_dataset_example_update.py +52 -0
- arize/_generated/api_client/test/test_dataset_version.py +7 -2
- arize/_generated/api_client/test/test_datasets_api.py +27 -13
- arize/_generated/api_client/test/test_datasets_create_request.py +8 -4
- arize/_generated/api_client/test/{test_datasets_list_examples200_response.py → test_datasets_examples_insert_request.py} +19 -15
- arize/_generated/api_client/test/test_datasets_examples_list200_response.py +66 -0
- arize/_generated/api_client/test/test_datasets_examples_update_request.py +61 -0
- arize/_generated/api_client/test/test_datasets_list200_response.py +9 -3
- arize/_generated/api_client/test/test_experiment.py +2 -4
- arize/_generated/api_client/test/test_experiment_run.py +56 -0
- arize/_generated/api_client/test/test_experiment_run_create.py +54 -0
- arize/_generated/api_client/test/test_experiments_api.py +6 -6
- arize/_generated/api_client/test/test_experiments_create_request.py +9 -6
- arize/_generated/api_client/test/test_experiments_list200_response.py +9 -5
- arize/_generated/api_client/test/test_experiments_runs_list200_response.py +15 -5
- arize/_generated/api_client/test/test_pagination_metadata.py +53 -0
- arize/_generated/api_client/test/{test_error.py → test_primitive_value.py} +13 -14
- arize/_generated/api_client/test/test_problem.py +57 -0
- arize/_generated/api_client/test/test_project.py +58 -0
- arize/_generated/api_client/test/test_projects_api.py +59 -0
- arize/_generated/api_client/test/test_projects_create_request.py +54 -0
- arize/_generated/api_client/test/test_projects_list200_response.py +70 -0
- arize/_generated/api_client_README.md +43 -29
- arize/_generated/protocol/flight/flight_pb2.py +400 -0
- arize/_lazy.py +27 -19
- arize/client.py +269 -55
- arize/config.py +365 -116
- arize/constants/__init__.py +1 -0
- arize/constants/config.py +11 -4
- arize/constants/ml.py +6 -4
- arize/constants/openinference.py +2 -0
- arize/constants/pyarrow.py +2 -0
- arize/constants/spans.py +3 -1
- arize/datasets/__init__.py +1 -0
- arize/datasets/client.py +299 -84
- arize/datasets/errors.py +32 -2
- arize/datasets/validation.py +18 -8
- arize/embeddings/__init__.py +2 -0
- arize/embeddings/auto_generator.py +23 -19
- arize/embeddings/base_generators.py +89 -36
- arize/embeddings/constants.py +2 -0
- arize/embeddings/cv_generators.py +26 -4
- arize/embeddings/errors.py +27 -5
- arize/embeddings/nlp_generators.py +31 -12
- arize/embeddings/tabular_generators.py +32 -20
- arize/embeddings/usecases.py +12 -2
- arize/exceptions/__init__.py +1 -0
- arize/exceptions/auth.py +11 -1
- arize/exceptions/base.py +29 -4
- arize/exceptions/models.py +21 -2
- arize/exceptions/parameters.py +31 -0
- arize/exceptions/spaces.py +12 -1
- arize/exceptions/types.py +86 -7
- arize/exceptions/values.py +220 -20
- arize/experiments/__init__.py +1 -0
- arize/experiments/client.py +390 -286
- arize/experiments/evaluators/__init__.py +1 -0
- arize/experiments/evaluators/base.py +74 -41
- arize/experiments/evaluators/exceptions.py +6 -3
- arize/experiments/evaluators/executors.py +121 -73
- arize/experiments/evaluators/rate_limiters.py +106 -57
- arize/experiments/evaluators/types.py +34 -7
- arize/experiments/evaluators/utils.py +65 -27
- arize/experiments/functions.py +103 -101
- arize/experiments/tracing.py +52 -44
- arize/experiments/types.py +56 -31
- arize/logging.py +54 -22
- arize/models/__init__.py +1 -0
- arize/models/batch_validation/__init__.py +1 -0
- arize/models/batch_validation/errors.py +543 -65
- arize/models/batch_validation/validator.py +339 -300
- arize/models/bounded_executor.py +20 -7
- arize/models/casting.py +75 -29
- arize/models/client.py +326 -107
- arize/models/proto.py +95 -40
- arize/models/stream_validation.py +42 -14
- arize/models/surrogate_explainer/__init__.py +1 -0
- arize/models/surrogate_explainer/mimic.py +24 -13
- arize/pre_releases.py +43 -0
- arize/projects/__init__.py +1 -0
- arize/projects/client.py +129 -0
- arize/regions.py +40 -0
- arize/spans/__init__.py +1 -0
- arize/spans/client.py +130 -106
- arize/spans/columns.py +13 -0
- arize/spans/conversion.py +54 -38
- arize/spans/validation/__init__.py +1 -0
- arize/spans/validation/annotations/__init__.py +1 -0
- arize/spans/validation/annotations/annotations_validation.py +6 -4
- arize/spans/validation/annotations/dataframe_form_validation.py +13 -11
- arize/spans/validation/annotations/value_validation.py +35 -11
- arize/spans/validation/common/__init__.py +1 -0
- arize/spans/validation/common/argument_validation.py +33 -8
- arize/spans/validation/common/dataframe_form_validation.py +35 -9
- arize/spans/validation/common/errors.py +211 -11
- arize/spans/validation/common/value_validation.py +80 -13
- arize/spans/validation/evals/__init__.py +1 -0
- arize/spans/validation/evals/dataframe_form_validation.py +28 -8
- arize/spans/validation/evals/evals_validation.py +34 -4
- arize/spans/validation/evals/value_validation.py +26 -3
- arize/spans/validation/metadata/__init__.py +1 -1
- arize/spans/validation/metadata/argument_validation.py +14 -5
- arize/spans/validation/metadata/dataframe_form_validation.py +26 -10
- arize/spans/validation/metadata/value_validation.py +24 -10
- arize/spans/validation/spans/__init__.py +1 -0
- arize/spans/validation/spans/dataframe_form_validation.py +34 -13
- arize/spans/validation/spans/spans_validation.py +35 -4
- arize/spans/validation/spans/value_validation.py +76 -7
- arize/types.py +293 -157
- arize/utils/__init__.py +1 -0
- arize/utils/arrow.py +31 -15
- arize/utils/cache.py +34 -6
- arize/utils/dataframe.py +19 -2
- arize/utils/online_tasks/__init__.py +2 -0
- arize/utils/online_tasks/dataframe_preprocessor.py +53 -41
- arize/utils/openinference_conversion.py +44 -5
- arize/utils/proto.py +10 -0
- arize/utils/size.py +5 -3
- arize/version.py +3 -1
- {arize-8.0.0a21.dist-info → arize-8.0.0a23.dist-info}/METADATA +4 -3
- arize-8.0.0a23.dist-info/RECORD +174 -0
- {arize-8.0.0a21.dist-info → arize-8.0.0a23.dist-info}/WHEEL +1 -1
- arize-8.0.0a23.dist-info/licenses/LICENSE +176 -0
- arize-8.0.0a23.dist-info/licenses/NOTICE +13 -0
- arize/_generated/protocol/flight/export_pb2.py +0 -61
- arize/_generated/protocol/flight/ingest_pb2.py +0 -365
- arize-8.0.0a21.dist-info/RECORD +0 -146
- arize-8.0.0a21.dist-info/licenses/LICENSE.md +0 -12
|
@@ -1,12 +1,15 @@
|
|
|
1
|
+
"""Batch validation logic for ML model predictions and actuals."""
|
|
2
|
+
|
|
1
3
|
from __future__ import annotations
|
|
2
4
|
|
|
3
|
-
import datetime
|
|
4
5
|
import logging
|
|
5
6
|
import math
|
|
7
|
+
from datetime import datetime, timedelta, timezone
|
|
6
8
|
from itertools import chain
|
|
7
|
-
from typing import
|
|
9
|
+
from typing import Any
|
|
8
10
|
|
|
9
11
|
import numpy as np
|
|
12
|
+
import pandas as pd
|
|
10
13
|
import pyarrow as pa
|
|
11
14
|
|
|
12
15
|
from arize.constants.ml import (
|
|
@@ -55,23 +58,22 @@ from arize.types import (
|
|
|
55
58
|
segments_intersect,
|
|
56
59
|
)
|
|
57
60
|
|
|
58
|
-
if TYPE_CHECKING:
|
|
59
|
-
import pandas as pd
|
|
60
|
-
|
|
61
|
-
|
|
62
61
|
logger = logging.getLogger(__name__)
|
|
63
62
|
|
|
64
63
|
|
|
65
64
|
class Validator:
|
|
65
|
+
"""Validator for batch data with schema and dataframe validation methods."""
|
|
66
|
+
|
|
66
67
|
@staticmethod
|
|
67
68
|
def validate_required_checks(
|
|
68
69
|
dataframe: pd.DataFrame,
|
|
69
70
|
model_id: str,
|
|
70
71
|
environment: Environments,
|
|
71
72
|
schema: BaseSchema,
|
|
72
|
-
model_version:
|
|
73
|
-
batch_id:
|
|
74
|
-
) ->
|
|
73
|
+
model_version: str | None = None,
|
|
74
|
+
batch_id: str | None = None,
|
|
75
|
+
) -> list[err.ValidationError]:
|
|
76
|
+
"""Validate required checks for schema, environment, and DataFrame structure."""
|
|
75
77
|
general_checks = chain(
|
|
76
78
|
Validator._check_valid_schema_type(schema, environment),
|
|
77
79
|
Validator._check_field_convertible_to_str(
|
|
@@ -87,7 +89,7 @@ class Validator:
|
|
|
87
89
|
schema, CorpusSchema
|
|
88
90
|
):
|
|
89
91
|
return list(general_checks)
|
|
90
|
-
|
|
92
|
+
if isinstance(schema, Schema):
|
|
91
93
|
return list(
|
|
92
94
|
chain(
|
|
93
95
|
general_checks,
|
|
@@ -108,10 +110,11 @@ class Validator:
|
|
|
108
110
|
model_type: ModelTypes,
|
|
109
111
|
environment: Environments,
|
|
110
112
|
schema: BaseSchema,
|
|
111
|
-
metric_families:
|
|
112
|
-
model_version:
|
|
113
|
-
batch_id:
|
|
114
|
-
) ->
|
|
113
|
+
metric_families: list[Metrics] | None = None,
|
|
114
|
+
model_version: str | None = None,
|
|
115
|
+
batch_id: str | None = None,
|
|
116
|
+
) -> list[err.ValidationError]:
|
|
117
|
+
"""Validate parameters including model type, environment, and schema consistency."""
|
|
115
118
|
# general checks
|
|
116
119
|
general_checks = chain(
|
|
117
120
|
Validator._check_column_names_for_empty_strings(schema),
|
|
@@ -125,7 +128,7 @@ class Validator:
|
|
|
125
128
|
)
|
|
126
129
|
if isinstance(schema, CorpusSchema):
|
|
127
130
|
return list(general_checks)
|
|
128
|
-
|
|
131
|
+
if isinstance(schema, Schema):
|
|
129
132
|
general_checks = chain(
|
|
130
133
|
general_checks,
|
|
131
134
|
Validator._check_existence_prediction_id_column_delayed_schema(
|
|
@@ -153,7 +156,7 @@ class Validator:
|
|
|
153
156
|
),
|
|
154
157
|
)
|
|
155
158
|
return list(chain(general_checks, num_checks))
|
|
156
|
-
|
|
159
|
+
if model_type in CATEGORICAL_MODEL_TYPES:
|
|
157
160
|
sc_checks = chain(
|
|
158
161
|
Validator._check_existence_preprod_pred_act_score_or_label(
|
|
159
162
|
schema, environment
|
|
@@ -166,7 +169,7 @@ class Validator:
|
|
|
166
169
|
),
|
|
167
170
|
)
|
|
168
171
|
return list(chain(general_checks, sc_checks))
|
|
169
|
-
|
|
172
|
+
if model_type == ModelTypes.GENERATIVE_LLM:
|
|
170
173
|
gllm_checks = chain(
|
|
171
174
|
Validator._check_existence_preprod_act(schema, environment),
|
|
172
175
|
Validator._check_missing_object_detection_columns(
|
|
@@ -177,7 +180,7 @@ class Validator:
|
|
|
177
180
|
),
|
|
178
181
|
)
|
|
179
182
|
return list(chain(general_checks, gllm_checks))
|
|
180
|
-
|
|
183
|
+
if model_type == ModelTypes.RANKING:
|
|
181
184
|
r_checks = chain(
|
|
182
185
|
Validator._check_existence_group_id_rank_category_relevance(
|
|
183
186
|
schema
|
|
@@ -190,7 +193,7 @@ class Validator:
|
|
|
190
193
|
),
|
|
191
194
|
)
|
|
192
195
|
return list(chain(general_checks, r_checks))
|
|
193
|
-
|
|
196
|
+
if model_type == ModelTypes.OBJECT_DETECTION:
|
|
194
197
|
od_checks = chain(
|
|
195
198
|
Validator._check_exactly_one_cv_column_type(
|
|
196
199
|
schema, environment
|
|
@@ -203,7 +206,7 @@ class Validator:
|
|
|
203
206
|
),
|
|
204
207
|
)
|
|
205
208
|
return list(chain(general_checks, od_checks))
|
|
206
|
-
|
|
209
|
+
if model_type == ModelTypes.MULTI_CLASS:
|
|
207
210
|
multi_class_checks = chain(
|
|
208
211
|
Validator._check_existing_multi_class_columns(schema),
|
|
209
212
|
Validator._check_missing_non_multi_class_columns(
|
|
@@ -218,8 +221,11 @@ class Validator:
|
|
|
218
221
|
model_type: ModelTypes,
|
|
219
222
|
schema: BaseSchema,
|
|
220
223
|
pyarrow_schema: pa.Schema,
|
|
221
|
-
) ->
|
|
222
|
-
|
|
224
|
+
) -> list[err.ValidationError]:
|
|
225
|
+
"""Validate column data types against expected types for the schema."""
|
|
226
|
+
column_types = dict(
|
|
227
|
+
zip(pyarrow_schema.names, pyarrow_schema.types, strict=True)
|
|
228
|
+
)
|
|
223
229
|
|
|
224
230
|
if isinstance(schema, CorpusSchema):
|
|
225
231
|
return list(
|
|
@@ -227,7 +233,7 @@ class Validator:
|
|
|
227
233
|
Validator._check_type_document_columns(schema, column_types)
|
|
228
234
|
)
|
|
229
235
|
)
|
|
230
|
-
|
|
236
|
+
if isinstance(schema, Schema):
|
|
231
237
|
general_checks = chain(
|
|
232
238
|
Validator._check_type_prediction_id(schema, column_types),
|
|
233
239
|
Validator._check_type_timestamp(schema, column_types),
|
|
@@ -271,7 +277,7 @@ class Validator:
|
|
|
271
277
|
),
|
|
272
278
|
)
|
|
273
279
|
return list(chain(general_checks, gllm_checks))
|
|
274
|
-
|
|
280
|
+
if model_type == ModelTypes.RANKING:
|
|
275
281
|
r_checks = chain(
|
|
276
282
|
Validator._check_type_prediction_group_id(
|
|
277
283
|
schema, column_types
|
|
@@ -285,7 +291,7 @@ class Validator:
|
|
|
285
291
|
),
|
|
286
292
|
)
|
|
287
293
|
return list(chain(general_checks, r_checks))
|
|
288
|
-
|
|
294
|
+
if model_type == ModelTypes.OBJECT_DETECTION:
|
|
289
295
|
od_checks = chain(
|
|
290
296
|
Validator._check_type_image_segment_coordinates(
|
|
291
297
|
schema, column_types
|
|
@@ -298,7 +304,7 @@ class Validator:
|
|
|
298
304
|
),
|
|
299
305
|
)
|
|
300
306
|
return list(chain(general_checks, od_checks))
|
|
301
|
-
|
|
307
|
+
if model_type == ModelTypes.MULTI_CLASS:
|
|
302
308
|
multi_class_checks = chain(
|
|
303
309
|
Validator._check_type_multi_class_pred_threshold_act_scores(
|
|
304
310
|
schema, column_types
|
|
@@ -315,7 +321,8 @@ class Validator:
|
|
|
315
321
|
environment: Environments,
|
|
316
322
|
schema: BaseSchema,
|
|
317
323
|
model_type: ModelTypes,
|
|
318
|
-
) ->
|
|
324
|
+
) -> list[err.ValidationError]:
|
|
325
|
+
"""Validate data values including ranges, formats, and consistency checks."""
|
|
319
326
|
# ASSUMPTION: at this point the param and type checks should have passed.
|
|
320
327
|
# This function may crash if that is not true, e.g. if columns are missing
|
|
321
328
|
# or are of the wrong types.
|
|
@@ -338,7 +345,7 @@ class Validator:
|
|
|
338
345
|
),
|
|
339
346
|
)
|
|
340
347
|
)
|
|
341
|
-
|
|
348
|
+
if isinstance(schema, Schema):
|
|
342
349
|
general_checks = chain(
|
|
343
350
|
general_checks,
|
|
344
351
|
Validator._check_value_timestamp(dataframe, schema),
|
|
@@ -429,21 +436,21 @@ class Validator:
|
|
|
429
436
|
return list(general_checks)
|
|
430
437
|
return []
|
|
431
438
|
|
|
432
|
-
#
|
|
433
|
-
# Minimum
|
|
434
|
-
#
|
|
439
|
+
# -----------------------
|
|
440
|
+
# Minimum required checks
|
|
441
|
+
# -----------------------
|
|
435
442
|
@staticmethod
|
|
436
443
|
def _check_column_names_for_empty_strings(
|
|
437
444
|
schema: BaseSchema,
|
|
438
|
-
) ->
|
|
445
|
+
) -> list[err.InvalidColumnNameEmptyString]:
|
|
439
446
|
if "" in schema.get_used_columns():
|
|
440
447
|
return [err.InvalidColumnNameEmptyString()]
|
|
441
448
|
return []
|
|
442
449
|
|
|
443
450
|
@staticmethod
|
|
444
451
|
def _check_field_convertible_to_str(
|
|
445
|
-
model_id, model_version, batch_id
|
|
446
|
-
) ->
|
|
452
|
+
model_id: object, model_version: object, batch_id: object
|
|
453
|
+
) -> list[err.InvalidFieldTypeConversion]:
|
|
447
454
|
# converting to a set first makes the checks run a lot faster
|
|
448
455
|
wrong_fields = []
|
|
449
456
|
if model_id is not None and not isinstance(model_id, str):
|
|
@@ -469,7 +476,7 @@ class Validator:
|
|
|
469
476
|
@staticmethod
|
|
470
477
|
def _check_field_type_embedding_features_column_names(
|
|
471
478
|
schema: Schema,
|
|
472
|
-
) ->
|
|
479
|
+
) -> list[err.InvalidFieldTypeEmbeddingFeatures]:
|
|
473
480
|
if schema.embedding_feature_column_names is not None:
|
|
474
481
|
if not isinstance(schema.embedding_feature_column_names, dict):
|
|
475
482
|
return [err.InvalidFieldTypeEmbeddingFeatures()]
|
|
@@ -483,7 +490,7 @@ class Validator:
|
|
|
483
490
|
@staticmethod
|
|
484
491
|
def _check_field_type_prompt_response(
|
|
485
492
|
schema: Schema,
|
|
486
|
-
) ->
|
|
493
|
+
) -> list[err.InvalidFieldTypePromptResponse]:
|
|
487
494
|
errors = []
|
|
488
495
|
if schema.prompt_column_names is not None and not isinstance(
|
|
489
496
|
schema.prompt_column_names, (str, EmbeddingColumnNames)
|
|
@@ -502,7 +509,7 @@ class Validator:
|
|
|
502
509
|
@staticmethod
|
|
503
510
|
def _check_field_type_prompt_templates(
|
|
504
511
|
schema: Schema,
|
|
505
|
-
) ->
|
|
512
|
+
) -> list[err.InvalidFieldTypePromptTemplates]:
|
|
506
513
|
if schema.prompt_template_column_names is not None and not isinstance(
|
|
507
514
|
schema.prompt_template_column_names, PromptTemplateColumnNames
|
|
508
515
|
):
|
|
@@ -513,7 +520,7 @@ class Validator:
|
|
|
513
520
|
def _check_field_type_llm_config(
|
|
514
521
|
dataframe: pd.DataFrame,
|
|
515
522
|
schema: Schema,
|
|
516
|
-
) ->
|
|
523
|
+
) -> list[err.InvalidFieldTypeLlmConfig | err.InvalidTypeColumns]:
|
|
517
524
|
if schema.llm_config_column_names is None:
|
|
518
525
|
return []
|
|
519
526
|
if not isinstance(schema.llm_config_column_names, LLMConfigColumnNames):
|
|
@@ -548,7 +555,7 @@ class Validator:
|
|
|
548
555
|
@staticmethod
|
|
549
556
|
def _check_invalid_index(
|
|
550
557
|
dataframe: pd.DataFrame,
|
|
551
|
-
) ->
|
|
558
|
+
) -> list[err.InvalidDataFrameIndex]:
|
|
552
559
|
if (dataframe.index != dataframe.reset_index(drop=True).index).any():
|
|
553
560
|
return [err.InvalidDataFrameIndex()]
|
|
554
561
|
return []
|
|
@@ -560,9 +567,9 @@ class Validator:
|
|
|
560
567
|
@staticmethod
|
|
561
568
|
def _check_model_type_and_metrics(
|
|
562
569
|
model_type: ModelTypes,
|
|
563
|
-
metric_families:
|
|
570
|
+
metric_families: list[Metrics] | None,
|
|
564
571
|
schema: Schema,
|
|
565
|
-
) ->
|
|
572
|
+
) -> list[err.ValidationError]:
|
|
566
573
|
if metric_families is None:
|
|
567
574
|
return []
|
|
568
575
|
|
|
@@ -606,10 +613,10 @@ class Validator:
|
|
|
606
613
|
@staticmethod
|
|
607
614
|
def _check_model_mapping_combinations(
|
|
608
615
|
model_type: ModelTypes,
|
|
609
|
-
metric_families:
|
|
616
|
+
metric_families: list[Metrics],
|
|
610
617
|
schema: Schema,
|
|
611
|
-
required_columns_map:
|
|
612
|
-
) ->
|
|
618
|
+
required_columns_map: list[dict[str, Any]],
|
|
619
|
+
) -> tuple[bool, list[str], list[list[str]]]:
|
|
613
620
|
missing_columns = []
|
|
614
621
|
for item in required_columns_map:
|
|
615
622
|
if model_type.name.lower() == item.get("external_model_type"):
|
|
@@ -625,10 +632,10 @@ class Validator:
|
|
|
625
632
|
metric_combinations.append(
|
|
626
633
|
[metric.upper() for metric in metrics_list]
|
|
627
634
|
)
|
|
628
|
-
if set(metrics_list) ==
|
|
635
|
+
if set(metrics_list) == {
|
|
629
636
|
metric_family.name.lower()
|
|
630
637
|
for metric_family in metric_families
|
|
631
|
-
|
|
638
|
+
}:
|
|
632
639
|
# This is a valid combination of model type + metrics.
|
|
633
640
|
# Now validate that required columns are in the schema.
|
|
634
641
|
is_valid_combination = True
|
|
@@ -665,10 +672,10 @@ class Validator:
|
|
|
665
672
|
@staticmethod
|
|
666
673
|
def _check_existence_prediction_id_column_delayed_schema(
|
|
667
674
|
schema: Schema, model_type: ModelTypes
|
|
668
|
-
) ->
|
|
675
|
+
) -> list[err.MissingPredictionIdColumnForDelayedRecords]:
|
|
669
676
|
if schema.prediction_id_column_name is not None:
|
|
670
677
|
return []
|
|
671
|
-
# TODO: Revise logic once
|
|
678
|
+
# TODO: Revise logic once prediction_label column addition (for generative models)
|
|
672
679
|
# is moved to beginning of log function
|
|
673
680
|
if schema.is_delayed() and model_type is not ModelTypes.GENERATIVE_LLM:
|
|
674
681
|
# We skip GENERATIVE model types since they are assigned a default
|
|
@@ -696,12 +703,12 @@ class Validator:
|
|
|
696
703
|
def _check_missing_columns(
|
|
697
704
|
dataframe: pd.DataFrame,
|
|
698
705
|
schema: BaseSchema,
|
|
699
|
-
) ->
|
|
706
|
+
) -> list[err.MissingColumns]:
|
|
700
707
|
if isinstance(schema, CorpusSchema):
|
|
701
708
|
return Validator._check_missing_columns_corpus_schema(
|
|
702
709
|
dataframe, schema
|
|
703
710
|
)
|
|
704
|
-
|
|
711
|
+
if isinstance(schema, Schema):
|
|
705
712
|
return Validator._check_missing_columns_schema(dataframe, schema)
|
|
706
713
|
return []
|
|
707
714
|
|
|
@@ -709,7 +716,7 @@ class Validator:
|
|
|
709
716
|
def _check_missing_columns_schema(
|
|
710
717
|
dataframe: pd.DataFrame,
|
|
711
718
|
schema: Schema,
|
|
712
|
-
) ->
|
|
719
|
+
) -> list[err.MissingColumns]:
|
|
713
720
|
# converting to a set first makes the checks run a lot faster
|
|
714
721
|
existing_columns = set(dataframe.columns)
|
|
715
722
|
missing_columns = []
|
|
@@ -721,9 +728,13 @@ class Validator:
|
|
|
721
728
|
missing_columns.append(col)
|
|
722
729
|
|
|
723
730
|
if schema.feature_column_names is not None:
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
|
|
731
|
+
missing_columns.extend(
|
|
732
|
+
[
|
|
733
|
+
col
|
|
734
|
+
for col in schema.feature_column_names
|
|
735
|
+
if col not in existing_columns
|
|
736
|
+
]
|
|
737
|
+
)
|
|
727
738
|
|
|
728
739
|
if schema.embedding_feature_column_names is not None:
|
|
729
740
|
for (
|
|
@@ -752,44 +763,76 @@ class Validator:
|
|
|
752
763
|
)
|
|
753
764
|
|
|
754
765
|
if schema.tag_column_names is not None:
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
|
|
766
|
+
missing_columns.extend(
|
|
767
|
+
[
|
|
768
|
+
col
|
|
769
|
+
for col in schema.tag_column_names
|
|
770
|
+
if col not in existing_columns
|
|
771
|
+
]
|
|
772
|
+
)
|
|
758
773
|
|
|
759
774
|
if schema.shap_values_column_names is not None:
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
|
|
775
|
+
missing_columns.extend(
|
|
776
|
+
[
|
|
777
|
+
col
|
|
778
|
+
for col in schema.shap_values_column_names.values()
|
|
779
|
+
if col not in existing_columns
|
|
780
|
+
]
|
|
781
|
+
)
|
|
763
782
|
|
|
764
783
|
if schema.object_detection_prediction_column_names is not None:
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
784
|
+
missing_columns.extend(
|
|
785
|
+
[
|
|
786
|
+
col
|
|
787
|
+
for col in schema.object_detection_prediction_column_names
|
|
788
|
+
if col is not None and col not in existing_columns
|
|
789
|
+
]
|
|
790
|
+
)
|
|
768
791
|
|
|
769
792
|
if schema.object_detection_actual_column_names is not None:
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
793
|
+
missing_columns.extend(
|
|
794
|
+
[
|
|
795
|
+
col
|
|
796
|
+
for col in schema.object_detection_actual_column_names
|
|
797
|
+
if col is not None and col not in existing_columns
|
|
798
|
+
]
|
|
799
|
+
)
|
|
773
800
|
|
|
774
801
|
if schema.semantic_segmentation_prediction_column_names is not None:
|
|
775
|
-
|
|
776
|
-
|
|
777
|
-
|
|
802
|
+
missing_columns.extend(
|
|
803
|
+
[
|
|
804
|
+
col
|
|
805
|
+
for col in schema.semantic_segmentation_prediction_column_names
|
|
806
|
+
if col is not None and col not in existing_columns
|
|
807
|
+
]
|
|
808
|
+
)
|
|
778
809
|
|
|
779
810
|
if schema.semantic_segmentation_actual_column_names is not None:
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
|
|
811
|
+
missing_columns.extend(
|
|
812
|
+
[
|
|
813
|
+
col
|
|
814
|
+
for col in schema.semantic_segmentation_actual_column_names
|
|
815
|
+
if col is not None and col not in existing_columns
|
|
816
|
+
]
|
|
817
|
+
)
|
|
783
818
|
|
|
784
819
|
if schema.instance_segmentation_prediction_column_names is not None:
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
|
|
820
|
+
missing_columns.extend(
|
|
821
|
+
[
|
|
822
|
+
col
|
|
823
|
+
for col in schema.instance_segmentation_prediction_column_names
|
|
824
|
+
if col is not None and col not in existing_columns
|
|
825
|
+
]
|
|
826
|
+
)
|
|
788
827
|
|
|
789
828
|
if schema.instance_segmentation_actual_column_names is not None:
|
|
790
|
-
|
|
791
|
-
|
|
792
|
-
|
|
829
|
+
missing_columns.extend(
|
|
830
|
+
[
|
|
831
|
+
col
|
|
832
|
+
for col in schema.instance_segmentation_actual_column_names
|
|
833
|
+
if col is not None and col not in existing_columns
|
|
834
|
+
]
|
|
835
|
+
)
|
|
793
836
|
|
|
794
837
|
if schema.prompt_column_names is not None:
|
|
795
838
|
if isinstance(schema.prompt_column_names, str):
|
|
@@ -838,14 +881,22 @@ class Validator:
|
|
|
838
881
|
)
|
|
839
882
|
|
|
840
883
|
if schema.prompt_template_column_names is not None:
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
|
|
884
|
+
missing_columns.extend(
|
|
885
|
+
[
|
|
886
|
+
col
|
|
887
|
+
for col in schema.prompt_template_column_names
|
|
888
|
+
if col is not None and col not in existing_columns
|
|
889
|
+
]
|
|
890
|
+
)
|
|
844
891
|
|
|
845
892
|
if schema.llm_config_column_names is not None:
|
|
846
|
-
|
|
847
|
-
|
|
848
|
-
|
|
893
|
+
missing_columns.extend(
|
|
894
|
+
[
|
|
895
|
+
col
|
|
896
|
+
for col in schema.llm_config_column_names
|
|
897
|
+
if col is not None and col not in existing_columns
|
|
898
|
+
]
|
|
899
|
+
)
|
|
849
900
|
|
|
850
901
|
if missing_columns:
|
|
851
902
|
return [err.MissingColumns(missing_columns)]
|
|
@@ -855,7 +906,7 @@ class Validator:
|
|
|
855
906
|
def _check_missing_columns_corpus_schema(
|
|
856
907
|
dataframe: pd.DataFrame,
|
|
857
908
|
schema: CorpusSchema,
|
|
858
|
-
) ->
|
|
909
|
+
) -> list[err.MissingColumns]:
|
|
859
910
|
# converting to a set first makes the checks run a lot faster
|
|
860
911
|
existing_columns = set(dataframe.columns)
|
|
861
912
|
missing_columns = []
|
|
@@ -912,7 +963,7 @@ class Validator:
|
|
|
912
963
|
def _check_valid_schema_type(
|
|
913
964
|
schema: BaseSchema,
|
|
914
965
|
environment: Environments,
|
|
915
|
-
) ->
|
|
966
|
+
) -> list[err.InvalidSchemaType]:
|
|
916
967
|
if environment == Environments.CORPUS and not (
|
|
917
968
|
isinstance(schema, CorpusSchema)
|
|
918
969
|
):
|
|
@@ -934,7 +985,7 @@ class Validator:
|
|
|
934
985
|
@staticmethod
|
|
935
986
|
def _check_invalid_shap_suffix(
|
|
936
987
|
schema: Schema,
|
|
937
|
-
) ->
|
|
988
|
+
) -> list[err.InvalidShapSuffix]:
|
|
938
989
|
invalid_column_names = set()
|
|
939
990
|
|
|
940
991
|
if schema.feature_column_names is not None:
|
|
@@ -970,10 +1021,10 @@ class Validator:
|
|
|
970
1021
|
def _check_reserved_columns(
|
|
971
1022
|
schema: BaseSchema,
|
|
972
1023
|
model_type: ModelTypes,
|
|
973
|
-
) ->
|
|
1024
|
+
) -> list[err.ReservedColumns]:
|
|
974
1025
|
if isinstance(schema, CorpusSchema):
|
|
975
1026
|
return []
|
|
976
|
-
|
|
1027
|
+
if isinstance(schema, Schema):
|
|
977
1028
|
reserved_columns = []
|
|
978
1029
|
column_counts = schema.get_used_columns_counts()
|
|
979
1030
|
if model_type == ModelTypes.GENERATIVE_LLM:
|
|
@@ -1079,8 +1130,8 @@ class Validator:
|
|
|
1079
1130
|
|
|
1080
1131
|
@staticmethod
|
|
1081
1132
|
def _check_invalid_model_id(
|
|
1082
|
-
model_id:
|
|
1083
|
-
) ->
|
|
1133
|
+
model_id: str | None,
|
|
1134
|
+
) -> list[err.InvalidModelId]:
|
|
1084
1135
|
# assume it's been coerced to string beforehand
|
|
1085
1136
|
if (not isinstance(model_id, str)) or len(model_id.strip()) == 0:
|
|
1086
1137
|
return [err.InvalidModelId()]
|
|
@@ -1088,8 +1139,8 @@ class Validator:
|
|
|
1088
1139
|
|
|
1089
1140
|
@staticmethod
|
|
1090
1141
|
def _check_invalid_model_version(
|
|
1091
|
-
model_version:
|
|
1092
|
-
) ->
|
|
1142
|
+
model_version: str | None = None,
|
|
1143
|
+
) -> list[err.InvalidModelVersion]:
|
|
1093
1144
|
if model_version is None:
|
|
1094
1145
|
return []
|
|
1095
1146
|
if (
|
|
@@ -1102,9 +1153,9 @@ class Validator:
|
|
|
1102
1153
|
|
|
1103
1154
|
@staticmethod
|
|
1104
1155
|
def _check_invalid_batch_id(
|
|
1105
|
-
batch_id:
|
|
1156
|
+
batch_id: str | None,
|
|
1106
1157
|
environment: Environments,
|
|
1107
|
-
) ->
|
|
1158
|
+
) -> list[err.InvalidBatchId]:
|
|
1108
1159
|
# assume it's been coerced to string beforehand
|
|
1109
1160
|
if environment in (Environments.VALIDATION,) and (
|
|
1110
1161
|
(not isinstance(batch_id, str)) or len(batch_id.strip()) == 0
|
|
@@ -1115,7 +1166,7 @@ class Validator:
|
|
|
1115
1166
|
@staticmethod
|
|
1116
1167
|
def _check_invalid_model_type(
|
|
1117
1168
|
model_type: ModelTypes,
|
|
1118
|
-
) ->
|
|
1169
|
+
) -> list[err.InvalidModelType]:
|
|
1119
1170
|
if model_type in (mt for mt in ModelTypes):
|
|
1120
1171
|
return []
|
|
1121
1172
|
return [err.InvalidModelType()]
|
|
@@ -1123,7 +1174,7 @@ class Validator:
|
|
|
1123
1174
|
@staticmethod
|
|
1124
1175
|
def _check_invalid_environment(
|
|
1125
1176
|
environment: Environments,
|
|
1126
|
-
) ->
|
|
1177
|
+
) -> list[err.InvalidEnvironment]:
|
|
1127
1178
|
if environment in (env for env in Environments):
|
|
1128
1179
|
return []
|
|
1129
1180
|
return [err.InvalidEnvironment()]
|
|
@@ -1132,7 +1183,7 @@ class Validator:
|
|
|
1132
1183
|
def _check_existence_preprod_pred_act_score_or_label(
|
|
1133
1184
|
schema: Schema,
|
|
1134
1185
|
environment: Environments,
|
|
1135
|
-
) ->
|
|
1186
|
+
) -> list[err.MissingPreprodPredActNumericAndCategorical]:
|
|
1136
1187
|
if environment in (Environments.VALIDATION, Environments.TRAINING) and (
|
|
1137
1188
|
(
|
|
1138
1189
|
schema.prediction_label_column_name is None
|
|
@@ -1149,7 +1200,7 @@ class Validator:
|
|
|
1149
1200
|
@staticmethod
|
|
1150
1201
|
def _check_exactly_one_cv_column_type(
|
|
1151
1202
|
schema: Schema, environment: Environments
|
|
1152
|
-
) ->
|
|
1203
|
+
) -> list[err.MultipleCVPredAct | err.MissingCVPredAct]:
|
|
1153
1204
|
# Checks that the required prediction/actual columns are given in the schema depending on
|
|
1154
1205
|
# the environment, for object detection models. There should be exactly one of
|
|
1155
1206
|
# object detection, semantic segmentation, or instance segmentation columns.
|
|
@@ -1180,7 +1231,7 @@ class Validator:
|
|
|
1180
1231
|
|
|
1181
1232
|
if cv_types_count == 0:
|
|
1182
1233
|
return [err.MissingCVPredAct(environment)]
|
|
1183
|
-
|
|
1234
|
+
if cv_types_count > 1:
|
|
1184
1235
|
return [err.MultipleCVPredAct(environment)]
|
|
1185
1236
|
|
|
1186
1237
|
elif environment in (
|
|
@@ -1213,7 +1264,7 @@ class Validator:
|
|
|
1213
1264
|
|
|
1214
1265
|
if cv_types_count == 0:
|
|
1215
1266
|
return [err.MissingCVPredAct(environment)]
|
|
1216
|
-
|
|
1267
|
+
if cv_types_count > 1:
|
|
1217
1268
|
return [err.MultipleCVPredAct(environment)]
|
|
1218
1269
|
|
|
1219
1270
|
return []
|
|
@@ -1221,9 +1272,9 @@ class Validator:
|
|
|
1221
1272
|
@staticmethod
|
|
1222
1273
|
def _check_missing_object_detection_columns(
|
|
1223
1274
|
schema: Schema, model_type: ModelTypes
|
|
1224
|
-
) ->
|
|
1275
|
+
) -> list[err.InvalidPredActCVColumnNamesForModelType]:
|
|
1225
1276
|
# Checks that models that are not Object Detection models don't have, in the schema, the
|
|
1226
|
-
# object detection, semantic segmentation, or instance segmentation dedicated
|
|
1277
|
+
# object detection, semantic segmentation, or instance segmentation dedicated prediction/actual
|
|
1227
1278
|
# column names
|
|
1228
1279
|
if (
|
|
1229
1280
|
schema.object_detection_prediction_column_names is not None
|
|
@@ -1239,7 +1290,7 @@ class Validator:
|
|
|
1239
1290
|
@staticmethod
|
|
1240
1291
|
def _check_missing_non_object_detection_columns(
|
|
1241
1292
|
schema: Schema, model_type: ModelTypes
|
|
1242
|
-
) ->
|
|
1293
|
+
) -> list[err.InvalidPredActColumnNamesForModelType]:
|
|
1243
1294
|
# Checks that object detection models don't have, in the schema, the columns reserved for
|
|
1244
1295
|
# other model types
|
|
1245
1296
|
columns_to_check = (
|
|
@@ -1253,10 +1304,7 @@ class Validator:
|
|
|
1253
1304
|
schema.relevance_score_column_name,
|
|
1254
1305
|
schema.relevance_labels_column_name,
|
|
1255
1306
|
)
|
|
1256
|
-
wrong_cols = []
|
|
1257
|
-
for col in columns_to_check:
|
|
1258
|
-
if col is not None:
|
|
1259
|
-
wrong_cols.append(col)
|
|
1307
|
+
wrong_cols = [col for col in columns_to_check if col is not None]
|
|
1260
1308
|
if wrong_cols:
|
|
1261
1309
|
allowed_cols = [
|
|
1262
1310
|
"object_detection_prediction_column_names",
|
|
@@ -1276,7 +1324,7 @@ class Validator:
|
|
|
1276
1324
|
@staticmethod
|
|
1277
1325
|
def _check_missing_multi_class_columns(
|
|
1278
1326
|
schema: Schema, model_type: ModelTypes
|
|
1279
|
-
) ->
|
|
1327
|
+
) -> list[err.InvalidPredActColumnNamesForModelType]:
|
|
1280
1328
|
# Checks that models that are not Multi Class models don't have, in the schema, the
|
|
1281
1329
|
# multi class dedicated threshold column
|
|
1282
1330
|
if (
|
|
@@ -1295,7 +1343,7 @@ class Validator:
|
|
|
1295
1343
|
@staticmethod
|
|
1296
1344
|
def _check_existing_multi_class_columns(
|
|
1297
1345
|
schema: Schema,
|
|
1298
|
-
) ->
|
|
1346
|
+
) -> list[err.MissingReqPredActColumnNamesForMultiClass]:
|
|
1299
1347
|
# Checks that models that are Multi Class models have, in the schema, the
|
|
1300
1348
|
# required prediction score or actual score columns
|
|
1301
1349
|
if (
|
|
@@ -1311,7 +1359,7 @@ class Validator:
|
|
|
1311
1359
|
@staticmethod
|
|
1312
1360
|
def _check_missing_non_multi_class_columns(
|
|
1313
1361
|
schema: Schema, model_type: ModelTypes
|
|
1314
|
-
) ->
|
|
1362
|
+
) -> list[err.InvalidPredActColumnNamesForModelType]:
|
|
1315
1363
|
# Checks that multi class models don't have, in the schema, the columns reserved for
|
|
1316
1364
|
# other model types
|
|
1317
1365
|
columns_to_check = (
|
|
@@ -1329,10 +1377,7 @@ class Validator:
|
|
|
1329
1377
|
schema.instance_segmentation_prediction_column_names,
|
|
1330
1378
|
schema.instance_segmentation_actual_column_names,
|
|
1331
1379
|
)
|
|
1332
|
-
wrong_cols = []
|
|
1333
|
-
for col in columns_to_check:
|
|
1334
|
-
if col is not None:
|
|
1335
|
-
wrong_cols.append(col)
|
|
1380
|
+
wrong_cols = [col for col in columns_to_check if col is not None]
|
|
1336
1381
|
if wrong_cols:
|
|
1337
1382
|
allowed_cols = [
|
|
1338
1383
|
"prediction_score_column_name",
|
|
@@ -1350,7 +1395,7 @@ class Validator:
|
|
|
1350
1395
|
def _check_existence_preprod_act(
|
|
1351
1396
|
schema: Schema,
|
|
1352
1397
|
environment: Environments,
|
|
1353
|
-
) ->
|
|
1398
|
+
) -> list[err.MissingPreprodAct]:
|
|
1354
1399
|
if environment in (Environments.VALIDATION, Environments.TRAINING) and (
|
|
1355
1400
|
schema.actual_label_column_name is None
|
|
1356
1401
|
):
|
|
@@ -1360,7 +1405,7 @@ class Validator:
|
|
|
1360
1405
|
@staticmethod
|
|
1361
1406
|
def _check_existence_group_id_rank_category_relevance(
|
|
1362
1407
|
schema: Schema,
|
|
1363
|
-
) ->
|
|
1408
|
+
) -> list[err.MissingRequiredColumnsForRankingModel]:
|
|
1364
1409
|
# prediction_group_id and rank columns are required as ranking prediction columns.
|
|
1365
1410
|
ranking_prediction_cols = (
|
|
1366
1411
|
schema.prediction_label_column_name,
|
|
@@ -1384,7 +1429,7 @@ class Validator:
|
|
|
1384
1429
|
@staticmethod
|
|
1385
1430
|
def _check_dataframe_for_duplicate_columns(
|
|
1386
1431
|
schema: BaseSchema, dataframe: pd.DataFrame
|
|
1387
|
-
) ->
|
|
1432
|
+
) -> list[err.DuplicateColumnsInDataframe]:
|
|
1388
1433
|
# Get the columns used in the schema
|
|
1389
1434
|
schema_col_used = schema.get_used_columns()
|
|
1390
1435
|
# Get the duplicated column names from the dataframe
|
|
@@ -1400,7 +1445,7 @@ class Validator:
|
|
|
1400
1445
|
@staticmethod
|
|
1401
1446
|
def _check_invalid_number_of_embeddings(
|
|
1402
1447
|
schema: Schema,
|
|
1403
|
-
) ->
|
|
1448
|
+
) -> list[err.InvalidNumberOfEmbeddings]:
|
|
1404
1449
|
if schema.embedding_feature_column_names is not None:
|
|
1405
1450
|
number_of_embeddings = len(schema.embedding_feature_column_names)
|
|
1406
1451
|
if number_of_embeddings > MAX_NUMBER_OF_EMBEDDINGS:
|
|
@@ -1413,8 +1458,8 @@ class Validator:
|
|
|
1413
1458
|
|
|
1414
1459
|
@staticmethod
|
|
1415
1460
|
def _check_type_prediction_id(
|
|
1416
|
-
schema: Schema, column_types:
|
|
1417
|
-
) ->
|
|
1461
|
+
schema: Schema, column_types: dict[str, Any]
|
|
1462
|
+
) -> list[err.InvalidType]:
|
|
1418
1463
|
col = schema.prediction_id_column_name
|
|
1419
1464
|
if col in column_types:
|
|
1420
1465
|
# should mirror server side
|
|
@@ -1437,8 +1482,8 @@ class Validator:
|
|
|
1437
1482
|
|
|
1438
1483
|
@staticmethod
|
|
1439
1484
|
def _check_type_timestamp(
|
|
1440
|
-
schema: Schema, column_types:
|
|
1441
|
-
) ->
|
|
1485
|
+
schema: Schema, column_types: dict[str, Any]
|
|
1486
|
+
) -> list[err.InvalidType]:
|
|
1442
1487
|
col = schema.timestamp_column_name
|
|
1443
1488
|
if col in column_types:
|
|
1444
1489
|
# should mirror server side
|
|
@@ -1464,8 +1509,8 @@ class Validator:
|
|
|
1464
1509
|
|
|
1465
1510
|
@staticmethod
|
|
1466
1511
|
def _check_type_features(
|
|
1467
|
-
schema: Schema, column_types:
|
|
1468
|
-
) ->
|
|
1512
|
+
schema: Schema, column_types: dict[str, Any]
|
|
1513
|
+
) -> list[err.InvalidTypeFeatures]:
|
|
1469
1514
|
if schema.feature_column_names is not None:
|
|
1470
1515
|
# should mirror server side
|
|
1471
1516
|
allowed_datatypes = (
|
|
@@ -1480,13 +1525,12 @@ class Validator:
|
|
|
1480
1525
|
pa.null(),
|
|
1481
1526
|
pa.list_(pa.string()),
|
|
1482
1527
|
)
|
|
1483
|
-
wrong_type_cols = [
|
|
1484
|
-
|
|
1485
|
-
|
|
1486
|
-
|
|
1487
|
-
|
|
1488
|
-
|
|
1489
|
-
wrong_type_cols.append(col)
|
|
1528
|
+
wrong_type_cols = [
|
|
1529
|
+
col
|
|
1530
|
+
for col in schema.feature_column_names
|
|
1531
|
+
if col in column_types
|
|
1532
|
+
and column_types[col] not in allowed_datatypes
|
|
1533
|
+
]
|
|
1490
1534
|
if wrong_type_cols:
|
|
1491
1535
|
return [
|
|
1492
1536
|
err.InvalidTypeFeatures(
|
|
@@ -1504,8 +1548,8 @@ class Validator:
|
|
|
1504
1548
|
|
|
1505
1549
|
@staticmethod
|
|
1506
1550
|
def _check_type_embedding_features(
|
|
1507
|
-
schema: Schema, column_types:
|
|
1508
|
-
) ->
|
|
1551
|
+
schema: Schema, column_types: dict[str, Any]
|
|
1552
|
+
) -> list[err.InvalidTypeFeatures]:
|
|
1509
1553
|
if schema.embedding_feature_column_names is not None:
|
|
1510
1554
|
# should mirror server side
|
|
1511
1555
|
allowed_vector_datatypes = (
|
|
@@ -1580,8 +1624,8 @@ class Validator:
|
|
|
1580
1624
|
|
|
1581
1625
|
@staticmethod
|
|
1582
1626
|
def _check_type_tags(
|
|
1583
|
-
schema: Schema, column_types:
|
|
1584
|
-
) ->
|
|
1627
|
+
schema: Schema, column_types: dict[str, Any]
|
|
1628
|
+
) -> list[err.InvalidTypeTags]:
|
|
1585
1629
|
if schema.tag_column_names is not None:
|
|
1586
1630
|
# should mirror server side
|
|
1587
1631
|
allowed_datatypes = (
|
|
@@ -1595,13 +1639,12 @@ class Validator:
|
|
|
1595
1639
|
pa.int8(),
|
|
1596
1640
|
pa.null(),
|
|
1597
1641
|
)
|
|
1598
|
-
wrong_type_cols = [
|
|
1599
|
-
|
|
1600
|
-
|
|
1601
|
-
|
|
1602
|
-
|
|
1603
|
-
|
|
1604
|
-
wrong_type_cols.append(col)
|
|
1642
|
+
wrong_type_cols = [
|
|
1643
|
+
col
|
|
1644
|
+
for col in schema.tag_column_names
|
|
1645
|
+
if col in column_types
|
|
1646
|
+
and column_types[col] not in allowed_datatypes
|
|
1647
|
+
]
|
|
1605
1648
|
if wrong_type_cols:
|
|
1606
1649
|
return [
|
|
1607
1650
|
err.InvalidTypeTags(
|
|
@@ -1612,8 +1655,8 @@ class Validator:
|
|
|
1612
1655
|
|
|
1613
1656
|
@staticmethod
|
|
1614
1657
|
def _check_type_shap_values(
|
|
1615
|
-
schema: Schema, column_types:
|
|
1616
|
-
) ->
|
|
1658
|
+
schema: Schema, column_types: dict[str, Any]
|
|
1659
|
+
) -> list[err.InvalidTypeShapValues]:
|
|
1617
1660
|
if schema.shap_values_column_names is not None:
|
|
1618
1661
|
# should mirror server side
|
|
1619
1662
|
allowed_datatypes = (
|
|
@@ -1622,13 +1665,12 @@ class Validator:
|
|
|
1622
1665
|
pa.float32(),
|
|
1623
1666
|
pa.int32(),
|
|
1624
1667
|
)
|
|
1625
|
-
wrong_type_cols = [
|
|
1626
|
-
|
|
1627
|
-
|
|
1628
|
-
|
|
1629
|
-
|
|
1630
|
-
|
|
1631
|
-
wrong_type_cols.append(col)
|
|
1668
|
+
wrong_type_cols = [
|
|
1669
|
+
col
|
|
1670
|
+
for col in schema.shap_values_column_names.values()
|
|
1671
|
+
if col in column_types
|
|
1672
|
+
and column_types[col] not in allowed_datatypes
|
|
1673
|
+
]
|
|
1632
1674
|
if wrong_type_cols:
|
|
1633
1675
|
return [
|
|
1634
1676
|
err.InvalidTypeShapValues(
|
|
@@ -1639,8 +1681,8 @@ class Validator:
|
|
|
1639
1681
|
|
|
1640
1682
|
@staticmethod
|
|
1641
1683
|
def _check_type_pred_act_labels(
|
|
1642
|
-
model_type: ModelTypes, schema: Schema, column_types:
|
|
1643
|
-
) ->
|
|
1684
|
+
model_type: ModelTypes, schema: Schema, column_types: dict[str, Any]
|
|
1685
|
+
) -> list[err.InvalidType]:
|
|
1644
1686
|
errors = []
|
|
1645
1687
|
columns = (
|
|
1646
1688
|
("Prediction labels", schema.prediction_label_column_name),
|
|
@@ -1703,8 +1745,8 @@ class Validator:
|
|
|
1703
1745
|
|
|
1704
1746
|
@staticmethod
|
|
1705
1747
|
def _check_type_pred_act_scores(
|
|
1706
|
-
model_type: ModelTypes, schema: Schema, column_types:
|
|
1707
|
-
) ->
|
|
1748
|
+
model_type: ModelTypes, schema: Schema, column_types: dict[str, Any]
|
|
1749
|
+
) -> list[err.InvalidType]:
|
|
1708
1750
|
errors = []
|
|
1709
1751
|
columns = (
|
|
1710
1752
|
("Prediction scores", schema.prediction_score_column_name),
|
|
@@ -1743,13 +1785,14 @@ class Validator:
|
|
|
1743
1785
|
|
|
1744
1786
|
@staticmethod
|
|
1745
1787
|
def _check_type_multi_class_pred_threshold_act_scores(
|
|
1746
|
-
schema: Schema, column_types:
|
|
1747
|
-
) ->
|
|
1748
|
-
"""
|
|
1749
|
-
|
|
1750
|
-
Expect the scores to be a list of pyarrow structs that contains field
|
|
1751
|
-
|
|
1752
|
-
|
|
1788
|
+
schema: Schema, column_types: dict[str, Any]
|
|
1789
|
+
) -> list[err.InvalidType]:
|
|
1790
|
+
"""Check type for prediction / threshold / actual scores for multiclass model.
|
|
1791
|
+
|
|
1792
|
+
Expect the scores to be a list of pyarrow structs that contains field
|
|
1793
|
+
"class_name" and field "score", where class_name is a string and score
|
|
1794
|
+
is a number.
|
|
1795
|
+
Example: '[{"class_name": "class1", "score": 0.1}, ...]'
|
|
1753
1796
|
"""
|
|
1754
1797
|
errors = []
|
|
1755
1798
|
columns = (
|
|
@@ -1802,8 +1845,8 @@ class Validator:
|
|
|
1802
1845
|
|
|
1803
1846
|
@staticmethod
|
|
1804
1847
|
def _check_type_prompt_response(
|
|
1805
|
-
schema: Schema, column_types:
|
|
1806
|
-
) ->
|
|
1848
|
+
schema: Schema, column_types: dict[str, Any]
|
|
1849
|
+
) -> list[err.InvalidTypeColumns]:
|
|
1807
1850
|
fields_to_check = []
|
|
1808
1851
|
if schema.prompt_column_names is not None:
|
|
1809
1852
|
fields_to_check.append(schema.prompt_column_names)
|
|
@@ -1872,8 +1915,8 @@ class Validator:
|
|
|
1872
1915
|
|
|
1873
1916
|
@staticmethod
|
|
1874
1917
|
def _check_type_llm_prompt_templates(
|
|
1875
|
-
schema: Schema, column_types:
|
|
1876
|
-
) ->
|
|
1918
|
+
schema: Schema, column_types: dict[str, Any]
|
|
1919
|
+
) -> list[err.InvalidTypeColumns]:
|
|
1877
1920
|
if schema.prompt_template_column_names is None:
|
|
1878
1921
|
return []
|
|
1879
1922
|
|
|
@@ -1913,8 +1956,8 @@ class Validator:
|
|
|
1913
1956
|
|
|
1914
1957
|
@staticmethod
|
|
1915
1958
|
def _check_type_llm_config(
|
|
1916
|
-
schema: Schema, column_types:
|
|
1917
|
-
) ->
|
|
1959
|
+
schema: Schema, column_types: dict[str, Any]
|
|
1960
|
+
) -> list[err.InvalidTypeColumns]:
|
|
1918
1961
|
if schema.llm_config_column_names is None:
|
|
1919
1962
|
return []
|
|
1920
1963
|
|
|
@@ -1950,8 +1993,8 @@ class Validator:
|
|
|
1950
1993
|
|
|
1951
1994
|
@staticmethod
|
|
1952
1995
|
def _check_type_llm_run_metadata(
|
|
1953
|
-
schema: Schema, column_types:
|
|
1954
|
-
) ->
|
|
1996
|
+
schema: Schema, column_types: dict[str, Any]
|
|
1997
|
+
) -> list[err.InvalidTypeColumns]:
|
|
1955
1998
|
if schema.llm_run_metadata_column_names is None:
|
|
1956
1999
|
return []
|
|
1957
2000
|
|
|
@@ -2023,8 +2066,8 @@ class Validator:
|
|
|
2023
2066
|
|
|
2024
2067
|
@staticmethod
|
|
2025
2068
|
def _check_type_retrieved_document_ids(
|
|
2026
|
-
schema: Schema, column_types:
|
|
2027
|
-
) ->
|
|
2069
|
+
schema: Schema, column_types: dict[str, Any]
|
|
2070
|
+
) -> list[err.InvalidType]:
|
|
2028
2071
|
col = schema.retrieved_document_ids_column_name
|
|
2029
2072
|
if col in column_types:
|
|
2030
2073
|
# should mirror server side
|
|
@@ -2044,8 +2087,8 @@ class Validator:
|
|
|
2044
2087
|
|
|
2045
2088
|
@staticmethod
|
|
2046
2089
|
def _check_type_image_segment_coordinates(
|
|
2047
|
-
schema: Schema, column_types:
|
|
2048
|
-
) ->
|
|
2090
|
+
schema: Schema, column_types: dict[str, Any]
|
|
2091
|
+
) -> list[err.InvalidTypeColumns]:
|
|
2049
2092
|
# should mirror server side
|
|
2050
2093
|
allowed_coordinate_types = (
|
|
2051
2094
|
pa.list_(pa.list_(pa.float64())),
|
|
@@ -2090,9 +2133,8 @@ class Validator:
|
|
|
2090
2133
|
wrong_type_cols.append(coord_col)
|
|
2091
2134
|
|
|
2092
2135
|
if schema.instance_segmentation_prediction_column_names is not None:
|
|
2093
|
-
|
|
2094
|
-
|
|
2095
|
-
)
|
|
2136
|
+
inst_seg_pred = schema.instance_segmentation_prediction_column_names
|
|
2137
|
+
polygons_coord_col = inst_seg_pred.polygon_coordinates_column_name
|
|
2096
2138
|
if (
|
|
2097
2139
|
polygons_coord_col in column_types
|
|
2098
2140
|
and column_types[polygons_coord_col]
|
|
@@ -2101,7 +2143,7 @@ class Validator:
|
|
|
2101
2143
|
wrong_type_cols.append(polygons_coord_col)
|
|
2102
2144
|
|
|
2103
2145
|
bbox_coord_col = (
|
|
2104
|
-
|
|
2146
|
+
inst_seg_pred.bounding_boxes_coordinates_column_name
|
|
2105
2147
|
)
|
|
2106
2148
|
if (
|
|
2107
2149
|
bbox_coord_col in column_types
|
|
@@ -2110,9 +2152,8 @@ class Validator:
|
|
|
2110
2152
|
wrong_type_cols.append(bbox_coord_col)
|
|
2111
2153
|
|
|
2112
2154
|
if schema.instance_segmentation_actual_column_names is not None:
|
|
2113
|
-
|
|
2114
|
-
|
|
2115
|
-
)
|
|
2155
|
+
inst_seg_actual = schema.instance_segmentation_actual_column_names
|
|
2156
|
+
coord_col = inst_seg_actual.polygon_coordinates_column_name
|
|
2116
2157
|
if (
|
|
2117
2158
|
coord_col in column_types
|
|
2118
2159
|
and column_types[coord_col] not in allowed_coordinate_types
|
|
@@ -2120,7 +2161,7 @@ class Validator:
|
|
|
2120
2161
|
wrong_type_cols.append(coord_col)
|
|
2121
2162
|
|
|
2122
2163
|
bbox_coord_col = (
|
|
2123
|
-
|
|
2164
|
+
inst_seg_actual.bounding_boxes_coordinates_column_name
|
|
2124
2165
|
)
|
|
2125
2166
|
if (
|
|
2126
2167
|
bbox_coord_col in column_types
|
|
@@ -2141,8 +2182,8 @@ class Validator:
|
|
|
2141
2182
|
|
|
2142
2183
|
@staticmethod
|
|
2143
2184
|
def _check_type_image_segment_categories(
|
|
2144
|
-
schema: Schema, column_types:
|
|
2145
|
-
) ->
|
|
2185
|
+
schema: Schema, column_types: dict[str, Any]
|
|
2186
|
+
) -> list[err.InvalidTypeColumns]:
|
|
2146
2187
|
# should mirror server side
|
|
2147
2188
|
allowed_category_datatypes = (
|
|
2148
2189
|
pa.list_(pa.string()),
|
|
@@ -2210,8 +2251,8 @@ class Validator:
|
|
|
2210
2251
|
|
|
2211
2252
|
@staticmethod
|
|
2212
2253
|
def _check_type_image_segment_scores(
|
|
2213
|
-
schema: Schema, column_types:
|
|
2214
|
-
) ->
|
|
2254
|
+
schema: Schema, column_types: dict[str, Any]
|
|
2255
|
+
) -> list[err.InvalidTypeColumns]:
|
|
2215
2256
|
# should mirror server side
|
|
2216
2257
|
allowed_score_datatypes = (
|
|
2217
2258
|
pa.list_(pa.float64()),
|
|
@@ -2270,7 +2311,7 @@ class Validator:
|
|
|
2270
2311
|
@staticmethod
|
|
2271
2312
|
def _check_embedding_vectors_dimensionality(
|
|
2272
2313
|
dataframe: pd.DataFrame, schema: Schema
|
|
2273
|
-
) ->
|
|
2314
|
+
) -> list[err.ValidationError]:
|
|
2274
2315
|
if schema.embedding_feature_column_names is None:
|
|
2275
2316
|
return []
|
|
2276
2317
|
|
|
@@ -2300,7 +2341,7 @@ class Validator:
|
|
|
2300
2341
|
@staticmethod
|
|
2301
2342
|
def _check_embedding_raw_data_characters(
|
|
2302
2343
|
dataframe: pd.DataFrame, schema: Schema
|
|
2303
|
-
) ->
|
|
2344
|
+
) -> list[err.ValidationError]:
|
|
2304
2345
|
if schema.embedding_feature_column_names is None:
|
|
2305
2346
|
return []
|
|
2306
2347
|
|
|
@@ -2322,7 +2363,7 @@ class Validator:
|
|
|
2322
2363
|
invalid_long_string_data_cols
|
|
2323
2364
|
)
|
|
2324
2365
|
]
|
|
2325
|
-
|
|
2366
|
+
if truncated_long_string_data_cols:
|
|
2326
2367
|
logger.warning(
|
|
2327
2368
|
get_truncation_warning_message(
|
|
2328
2369
|
"Embedding raw data fields",
|
|
@@ -2334,7 +2375,7 @@ class Validator:
|
|
|
2334
2375
|
@staticmethod
|
|
2335
2376
|
def _check_value_rank(
|
|
2336
2377
|
dataframe: pd.DataFrame, schema: Schema
|
|
2337
|
-
) ->
|
|
2378
|
+
) -> list[err.InvalidRankValue]:
|
|
2338
2379
|
col = schema.rank_column_name
|
|
2339
2380
|
lbound, ubound = (1, 100)
|
|
2340
2381
|
|
|
@@ -2346,11 +2387,11 @@ class Validator:
|
|
|
2346
2387
|
|
|
2347
2388
|
@staticmethod
|
|
2348
2389
|
def _check_id_field_str_length(
|
|
2349
|
-
dataframe: pd.DataFrame, schema_name: str, id_col_name:
|
|
2350
|
-
) ->
|
|
2351
|
-
"""
|
|
2352
|
-
|
|
2353
|
-
and MAX_PREDICTION_ID_LEN
|
|
2390
|
+
dataframe: pd.DataFrame, schema_name: str, id_col_name: str | None
|
|
2391
|
+
) -> list[err.ValidationError]:
|
|
2392
|
+
"""Require prediction_id to be a string of length between MIN and MAX.
|
|
2393
|
+
|
|
2394
|
+
Between MIN_PREDICTION_ID_LEN and MAX_PREDICTION_ID_LEN.
|
|
2354
2395
|
"""
|
|
2355
2396
|
# We check whether the column name can be None is allowed in `Validator.validate_params`
|
|
2356
2397
|
if id_col_name is None:
|
|
@@ -2380,11 +2421,11 @@ class Validator:
|
|
|
2380
2421
|
|
|
2381
2422
|
@staticmethod
|
|
2382
2423
|
def _check_document_id_field_str_length(
|
|
2383
|
-
dataframe: pd.DataFrame, schema_name: str, id_col_name:
|
|
2384
|
-
) ->
|
|
2385
|
-
"""
|
|
2386
|
-
|
|
2387
|
-
and MAX_DOCUMENT_ID_LEN
|
|
2424
|
+
dataframe: pd.DataFrame, schema_name: str, id_col_name: str | None
|
|
2425
|
+
) -> list[err.ValidationError]:
|
|
2426
|
+
"""Require document id to be a string of length between MIN and MAX.
|
|
2427
|
+
|
|
2428
|
+
Between MIN_DOCUMENT_ID_LEN and MAX_DOCUMENT_ID_LEN.
|
|
2388
2429
|
"""
|
|
2389
2430
|
# We check whether the column name can be None is allowed in `Validator.validate_params`
|
|
2390
2431
|
if id_col_name is None:
|
|
@@ -2433,7 +2474,7 @@ class Validator:
|
|
|
2433
2474
|
@staticmethod
|
|
2434
2475
|
def _check_value_tag(
|
|
2435
2476
|
dataframe: pd.DataFrame, schema: Schema
|
|
2436
|
-
) ->
|
|
2477
|
+
) -> list[err.InvalidTagLength]:
|
|
2437
2478
|
if schema.tag_column_names is None:
|
|
2438
2479
|
return []
|
|
2439
2480
|
|
|
@@ -2459,7 +2500,7 @@ class Validator:
|
|
|
2459
2500
|
truncated_tag_cols.append(col)
|
|
2460
2501
|
if wrong_tag_cols:
|
|
2461
2502
|
return [err.InvalidTagLength(wrong_tag_cols)]
|
|
2462
|
-
|
|
2503
|
+
if truncated_tag_cols:
|
|
2463
2504
|
logger.warning(
|
|
2464
2505
|
get_truncation_warning_message(
|
|
2465
2506
|
"tags", MAX_TAG_LENGTH_TRUNCATION
|
|
@@ -2470,9 +2511,7 @@ class Validator:
|
|
|
2470
2511
|
@staticmethod
|
|
2471
2512
|
def _check_value_ranking_category(
|
|
2472
2513
|
dataframe: pd.DataFrame, schema: Schema
|
|
2473
|
-
) ->
|
|
2474
|
-
Union[err.InvalidValueMissingValue, err.InvalidRankingCategoryValue]
|
|
2475
|
-
]:
|
|
2514
|
+
) -> list[err.InvalidValueMissingValue | err.InvalidRankingCategoryValue]:
|
|
2476
2515
|
if schema.relevance_labels_column_name is not None:
|
|
2477
2516
|
col = schema.relevance_labels_column_name
|
|
2478
2517
|
elif schema.attributions_column_name is not None:
|
|
@@ -2503,7 +2542,7 @@ class Validator:
|
|
|
2503
2542
|
@staticmethod
|
|
2504
2543
|
def _check_length_multi_class_maps(
|
|
2505
2544
|
dataframe: pd.DataFrame, schema: Schema
|
|
2506
|
-
) ->
|
|
2545
|
+
) -> list[err.InvalidNumClassesMultiClassMap]:
|
|
2507
2546
|
# each entry in column is a list of dictionaries mapping class names and scores
|
|
2508
2547
|
# validate length of list of dictionaries for each column
|
|
2509
2548
|
invalid_cols = {}
|
|
@@ -2540,15 +2579,13 @@ class Validator:
|
|
|
2540
2579
|
@staticmethod
|
|
2541
2580
|
def _check_classes_and_scores_values_in_multi_class_maps(
|
|
2542
2581
|
dataframe: pd.DataFrame, schema: Schema
|
|
2543
|
-
) ->
|
|
2544
|
-
|
|
2545
|
-
|
|
2546
|
-
|
|
2547
|
-
err.InvalidMultiClassPredScoreValue,
|
|
2548
|
-
]
|
|
2582
|
+
) -> list[
|
|
2583
|
+
err.InvalidMultiClassClassNameLength
|
|
2584
|
+
| err.InvalidMultiClassActScoreValue
|
|
2585
|
+
| err.InvalidMultiClassPredScoreValue
|
|
2549
2586
|
]:
|
|
2550
|
-
"""
|
|
2551
|
-
|
|
2587
|
+
"""Validate the class names and score values of dictionaries.
|
|
2588
|
+
|
|
2552
2589
|
- class name length
|
|
2553
2590
|
- valid actual score
|
|
2554
2591
|
- valid prediction / threshold score
|
|
@@ -2624,11 +2661,12 @@ class Validator:
|
|
|
2624
2661
|
@staticmethod
|
|
2625
2662
|
def _check_each_multi_class_pred_has_threshold(
|
|
2626
2663
|
dataframe: pd.DataFrame, schema: Schema
|
|
2627
|
-
) ->
|
|
2628
|
-
"""
|
|
2629
|
-
|
|
2630
|
-
|
|
2631
|
-
for
|
|
2664
|
+
) -> list[err.InvalidMultiClassThresholdClasses]:
|
|
2665
|
+
"""Validate threshold scores for Multi Class models.
|
|
2666
|
+
|
|
2667
|
+
If threshold scores column is included in schema and dataframe, validate that
|
|
2668
|
+
for each prediction score received, the associated threshold score for that
|
|
2669
|
+
class was also received.
|
|
2632
2670
|
"""
|
|
2633
2671
|
threshold_col = schema.multi_class_threshold_scores_column_name
|
|
2634
2672
|
if threshold_col is None:
|
|
@@ -2657,10 +2695,10 @@ class Validator:
|
|
|
2657
2695
|
def _check_value_timestamp(
|
|
2658
2696
|
dataframe: pd.DataFrame,
|
|
2659
2697
|
schema: Schema,
|
|
2660
|
-
) ->
|
|
2698
|
+
) -> list[err.InvalidValueMissingValue | err.InvalidValueTimestamp]:
|
|
2661
2699
|
# Due to the timing difference between checking this here and the data finally
|
|
2662
2700
|
# hitting the same check on server side, there's a some chance for a false
|
|
2663
|
-
# result, i.e. the check here
|
|
2701
|
+
# result, i.e. the check here succeeds but the same check on server side fails.
|
|
2664
2702
|
col = schema.timestamp_column_name
|
|
2665
2703
|
if col is not None and col in dataframe.columns:
|
|
2666
2704
|
# When a timestamp column has Date and NaN, pyarrow will be fine, but
|
|
@@ -2673,19 +2711,15 @@ class Validator:
|
|
|
2673
2711
|
)
|
|
2674
2712
|
]
|
|
2675
2713
|
|
|
2676
|
-
now_t = datetime.
|
|
2714
|
+
now_t = datetime.now(tz=timezone.utc)
|
|
2677
2715
|
lbound, ubound = (
|
|
2678
2716
|
(
|
|
2679
2717
|
now_t
|
|
2680
|
-
-
|
|
2681
|
-
days=MAX_PAST_YEARS_FROM_CURRENT_TIME * 365
|
|
2682
|
-
)
|
|
2718
|
+
- timedelta(days=MAX_PAST_YEARS_FROM_CURRENT_TIME * 365)
|
|
2683
2719
|
).timestamp(),
|
|
2684
2720
|
(
|
|
2685
2721
|
now_t
|
|
2686
|
-
+
|
|
2687
|
-
days=MAX_FUTURE_YEARS_FROM_CURRENT_TIME * 365
|
|
2688
|
-
)
|
|
2722
|
+
+ timedelta(days=MAX_FUTURE_YEARS_FROM_CURRENT_TIME * 365)
|
|
2689
2723
|
).timestamp(),
|
|
2690
2724
|
)
|
|
2691
2725
|
# faster than pyarrow compute
|
|
@@ -2767,7 +2801,7 @@ class Validator:
|
|
|
2767
2801
|
@staticmethod
|
|
2768
2802
|
def _check_invalid_missing_values(
|
|
2769
2803
|
dataframe: pd.DataFrame, schema: BaseSchema, model_type: ModelTypes
|
|
2770
|
-
) ->
|
|
2804
|
+
) -> list[err.InvalidValueMissingValue]:
|
|
2771
2805
|
errors = []
|
|
2772
2806
|
columns = ()
|
|
2773
2807
|
if isinstance(schema, CorpusSchema):
|
|
@@ -2814,7 +2848,7 @@ class Validator:
|
|
|
2814
2848
|
environment: Environments,
|
|
2815
2849
|
schema: Schema,
|
|
2816
2850
|
model_type: ModelTypes,
|
|
2817
|
-
) ->
|
|
2851
|
+
) -> list[err.InvalidRecord]:
|
|
2818
2852
|
if environment in (Environments.VALIDATION, Environments.TRAINING):
|
|
2819
2853
|
return []
|
|
2820
2854
|
|
|
@@ -2858,11 +2892,11 @@ class Validator:
|
|
|
2858
2892
|
environment: Environments,
|
|
2859
2893
|
schema: Schema,
|
|
2860
2894
|
model_type: ModelTypes,
|
|
2861
|
-
) ->
|
|
2862
|
-
"""
|
|
2863
|
-
|
|
2864
|
-
|
|
2865
|
-
|
|
2895
|
+
) -> list[err.InvalidRecord]:
|
|
2896
|
+
"""Validates there's not a single row in the dataframe with all nulls.
|
|
2897
|
+
|
|
2898
|
+
Returns errors if any row has all of pred_label and pred_score evaluating to
|
|
2899
|
+
null, OR all of actual_label and actual_score evaluating to null.
|
|
2866
2900
|
"""
|
|
2867
2901
|
if environment == Environments.PRODUCTION:
|
|
2868
2902
|
return []
|
|
@@ -2905,21 +2939,23 @@ class Validator:
|
|
|
2905
2939
|
|
|
2906
2940
|
@staticmethod
|
|
2907
2941
|
def _check_invalid_record_helper(
|
|
2908
|
-
dataframe: pd.DataFrame, column_names:
|
|
2909
|
-
) ->
|
|
2910
|
-
"""
|
|
2911
|
-
|
|
2912
|
-
|
|
2913
|
-
|
|
2942
|
+
dataframe: pd.DataFrame, column_names: list[str | None]
|
|
2943
|
+
) -> list[err.InvalidRecord]:
|
|
2944
|
+
"""Check that there are no null values in a subset of columns.
|
|
2945
|
+
|
|
2946
|
+
The column subset is computed from the input list of columns `column_names`
|
|
2947
|
+
that are not None and that are present in the dataframe. Returns an error if
|
|
2948
|
+
null values are found.
|
|
2914
2949
|
|
|
2915
2950
|
Returns:
|
|
2916
2951
|
List[err.InvalidRecord]: An error expressing the rows that are problematic
|
|
2917
2952
|
|
|
2918
2953
|
"""
|
|
2919
|
-
columns_subset = [
|
|
2920
|
-
|
|
2921
|
-
|
|
2922
|
-
|
|
2954
|
+
columns_subset = [
|
|
2955
|
+
col
|
|
2956
|
+
for col in column_names
|
|
2957
|
+
if col is not None and col in dataframe.columns
|
|
2958
|
+
]
|
|
2923
2959
|
if len(columns_subset) == 0:
|
|
2924
2960
|
return []
|
|
2925
2961
|
null_filter = dataframe[columns_subset].isnull().all(axis=1)
|
|
@@ -2930,8 +2966,8 @@ class Validator:
|
|
|
2930
2966
|
|
|
2931
2967
|
@staticmethod
|
|
2932
2968
|
def _check_type_prediction_group_id(
|
|
2933
|
-
schema: Schema, column_types:
|
|
2934
|
-
) ->
|
|
2969
|
+
schema: Schema, column_types: dict[str, Any]
|
|
2970
|
+
) -> list[err.InvalidType]:
|
|
2935
2971
|
col = schema.prediction_group_id_column_name
|
|
2936
2972
|
if col in column_types:
|
|
2937
2973
|
# should mirror server side
|
|
@@ -2954,8 +2990,8 @@ class Validator:
|
|
|
2954
2990
|
|
|
2955
2991
|
@staticmethod
|
|
2956
2992
|
def _check_type_rank(
|
|
2957
|
-
schema: Schema, column_types:
|
|
2958
|
-
) ->
|
|
2993
|
+
schema: Schema, column_types: dict[str, Any]
|
|
2994
|
+
) -> list[err.InvalidType]:
|
|
2959
2995
|
col = schema.rank_column_name
|
|
2960
2996
|
if col in column_types:
|
|
2961
2997
|
allowed_datatypes = (
|
|
@@ -2976,8 +3012,8 @@ class Validator:
|
|
|
2976
3012
|
|
|
2977
3013
|
@staticmethod
|
|
2978
3014
|
def _check_type_ranking_category(
|
|
2979
|
-
schema: Schema, column_types:
|
|
2980
|
-
) ->
|
|
3015
|
+
schema: Schema, column_types: dict[str, Any]
|
|
3016
|
+
) -> list[err.InvalidType]:
|
|
2981
3017
|
if schema.relevance_labels_column_name is not None:
|
|
2982
3018
|
col = schema.relevance_labels_column_name
|
|
2983
3019
|
elif schema.attributions_column_name is not None:
|
|
@@ -2999,7 +3035,7 @@ class Validator:
|
|
|
2999
3035
|
@staticmethod
|
|
3000
3036
|
def _check_value_bounding_boxes_coordinates(
|
|
3001
3037
|
dataframe: pd.DataFrame, schema: Schema
|
|
3002
|
-
) ->
|
|
3038
|
+
) -> list[err.InvalidBoundingBoxesCoordinates]:
|
|
3003
3039
|
errors = []
|
|
3004
3040
|
if schema.object_detection_prediction_column_names is not None:
|
|
3005
3041
|
coords_col_name = schema.object_detection_prediction_column_names.bounding_boxes_coordinates_column_name # noqa: E501
|
|
@@ -3020,7 +3056,7 @@ class Validator:
|
|
|
3020
3056
|
@staticmethod
|
|
3021
3057
|
def _check_value_bounding_boxes_categories(
|
|
3022
3058
|
dataframe: pd.DataFrame, schema: Schema
|
|
3023
|
-
) ->
|
|
3059
|
+
) -> list[err.InvalidBoundingBoxesCategories]:
|
|
3024
3060
|
errors = []
|
|
3025
3061
|
if schema.object_detection_prediction_column_names is not None:
|
|
3026
3062
|
cat_col_name = schema.object_detection_prediction_column_names.categories_column_name
|
|
@@ -3041,7 +3077,7 @@ class Validator:
|
|
|
3041
3077
|
@staticmethod
|
|
3042
3078
|
def _check_value_bounding_boxes_scores(
|
|
3043
3079
|
dataframe: pd.DataFrame, schema: Schema
|
|
3044
|
-
) ->
|
|
3080
|
+
) -> list[err.InvalidBoundingBoxesScores]:
|
|
3045
3081
|
errors = []
|
|
3046
3082
|
if schema.object_detection_prediction_column_names is not None:
|
|
3047
3083
|
sc_col_name = schema.object_detection_prediction_column_names.scores_column_name
|
|
@@ -3066,7 +3102,7 @@ class Validator:
|
|
|
3066
3102
|
@staticmethod
|
|
3067
3103
|
def _check_value_semantic_segmentation_polygon_coordinates(
|
|
3068
3104
|
dataframe: pd.DataFrame, schema: Schema
|
|
3069
|
-
) ->
|
|
3105
|
+
) -> list[err.InvalidPolygonCoordinates]:
|
|
3070
3106
|
errors = []
|
|
3071
3107
|
if schema.semantic_segmentation_prediction_column_names is not None:
|
|
3072
3108
|
coords_col_name = schema.semantic_segmentation_prediction_column_names.polygon_coordinates_column_name # noqa: E501
|
|
@@ -3076,7 +3112,7 @@ class Validator:
|
|
|
3076
3112
|
if error is not None:
|
|
3077
3113
|
errors.append(error)
|
|
3078
3114
|
if schema.semantic_segmentation_actual_column_names is not None:
|
|
3079
|
-
coords_col_name = schema.semantic_segmentation_actual_column_names.polygon_coordinates_column_name
|
|
3115
|
+
coords_col_name = schema.semantic_segmentation_actual_column_names.polygon_coordinates_column_name
|
|
3080
3116
|
error = _check_value_polygon_coordinates_helper(
|
|
3081
3117
|
dataframe[coords_col_name]
|
|
3082
3118
|
)
|
|
@@ -3087,7 +3123,7 @@ class Validator:
|
|
|
3087
3123
|
@staticmethod
|
|
3088
3124
|
def _check_value_semantic_segmentation_polygon_categories(
|
|
3089
3125
|
dataframe: pd.DataFrame, schema: Schema
|
|
3090
|
-
) ->
|
|
3126
|
+
) -> list[err.InvalidPolygonCategories]:
|
|
3091
3127
|
errors = []
|
|
3092
3128
|
if schema.semantic_segmentation_prediction_column_names is not None:
|
|
3093
3129
|
cat_col_name = schema.semantic_segmentation_prediction_column_names.categories_column_name
|
|
@@ -3108,7 +3144,7 @@ class Validator:
|
|
|
3108
3144
|
@staticmethod
|
|
3109
3145
|
def _check_value_instance_segmentation_polygon_coordinates(
|
|
3110
3146
|
dataframe: pd.DataFrame, schema: Schema
|
|
3111
|
-
) ->
|
|
3147
|
+
) -> list[err.InvalidPolygonCoordinates]:
|
|
3112
3148
|
errors = []
|
|
3113
3149
|
if schema.instance_segmentation_prediction_column_names is not None:
|
|
3114
3150
|
coords_col_name = schema.instance_segmentation_prediction_column_names.polygon_coordinates_column_name # noqa: E501
|
|
@@ -3118,7 +3154,7 @@ class Validator:
|
|
|
3118
3154
|
if error is not None:
|
|
3119
3155
|
errors.append(error)
|
|
3120
3156
|
if schema.instance_segmentation_actual_column_names is not None:
|
|
3121
|
-
coords_col_name = schema.instance_segmentation_actual_column_names.polygon_coordinates_column_name
|
|
3157
|
+
coords_col_name = schema.instance_segmentation_actual_column_names.polygon_coordinates_column_name
|
|
3122
3158
|
error = _check_value_polygon_coordinates_helper(
|
|
3123
3159
|
dataframe[coords_col_name]
|
|
3124
3160
|
)
|
|
@@ -3129,7 +3165,7 @@ class Validator:
|
|
|
3129
3165
|
@staticmethod
|
|
3130
3166
|
def _check_value_instance_segmentation_polygon_categories(
|
|
3131
3167
|
dataframe: pd.DataFrame, schema: Schema
|
|
3132
|
-
) ->
|
|
3168
|
+
) -> list[err.InvalidPolygonCategories]:
|
|
3133
3169
|
errors = []
|
|
3134
3170
|
if schema.instance_segmentation_prediction_column_names is not None:
|
|
3135
3171
|
cat_col_name = schema.instance_segmentation_prediction_column_names.categories_column_name
|
|
@@ -3150,7 +3186,7 @@ class Validator:
|
|
|
3150
3186
|
@staticmethod
|
|
3151
3187
|
def _check_value_instance_segmentation_polygon_scores(
|
|
3152
3188
|
dataframe: pd.DataFrame, schema: Schema
|
|
3153
|
-
) ->
|
|
3189
|
+
) -> list[err.InvalidPolygonScores]:
|
|
3154
3190
|
errors = []
|
|
3155
3191
|
if schema.instance_segmentation_prediction_column_names is not None:
|
|
3156
3192
|
sc_col_name = schema.instance_segmentation_prediction_column_names.scores_column_name
|
|
@@ -3165,7 +3201,7 @@ class Validator:
|
|
|
3165
3201
|
@staticmethod
|
|
3166
3202
|
def _check_value_instance_segmentation_bbox_coordinates(
|
|
3167
3203
|
dataframe: pd.DataFrame, schema: Schema
|
|
3168
|
-
) ->
|
|
3204
|
+
) -> list[err.InvalidBoundingBoxesCoordinates]:
|
|
3169
3205
|
errors = []
|
|
3170
3206
|
if schema.instance_segmentation_prediction_column_names is not None:
|
|
3171
3207
|
coords_col_name = schema.instance_segmentation_prediction_column_names.bounding_boxes_coordinates_column_name # noqa: E501
|
|
@@ -3188,7 +3224,7 @@ class Validator:
|
|
|
3188
3224
|
@staticmethod
|
|
3189
3225
|
def _check_value_prompt_response(
|
|
3190
3226
|
dataframe: pd.DataFrame, schema: Schema
|
|
3191
|
-
) ->
|
|
3227
|
+
) -> list[err.ValidationError]:
|
|
3192
3228
|
vector_cols_to_check = []
|
|
3193
3229
|
text_cols_to_check = []
|
|
3194
3230
|
if isinstance(schema.prompt_column_names, str):
|
|
@@ -3253,7 +3289,7 @@ class Validator:
|
|
|
3253
3289
|
@staticmethod
|
|
3254
3290
|
def _check_value_llm_model_name(
|
|
3255
3291
|
dataframe: pd.DataFrame, schema: Schema
|
|
3256
|
-
) ->
|
|
3292
|
+
) -> list[err.InvalidStringLengthInColumn]:
|
|
3257
3293
|
if schema.llm_config_column_names is None:
|
|
3258
3294
|
return []
|
|
3259
3295
|
col = schema.llm_config_column_names.model_column_name
|
|
@@ -3270,7 +3306,7 @@ class Validator:
|
|
|
3270
3306
|
max_length=MAX_LLM_MODEL_NAME_LENGTH,
|
|
3271
3307
|
)
|
|
3272
3308
|
]
|
|
3273
|
-
|
|
3309
|
+
if max_len > MAX_LLM_MODEL_NAME_LENGTH_TRUNCATION:
|
|
3274
3310
|
logger.warning(
|
|
3275
3311
|
get_truncation_warning_message(
|
|
3276
3312
|
"LLM model names", MAX_LLM_MODEL_NAME_LENGTH_TRUNCATION
|
|
@@ -3281,7 +3317,7 @@ class Validator:
|
|
|
3281
3317
|
@staticmethod
|
|
3282
3318
|
def _check_value_llm_prompt_template(
|
|
3283
3319
|
dataframe: pd.DataFrame, schema: Schema
|
|
3284
|
-
) ->
|
|
3320
|
+
) -> list[err.InvalidStringLengthInColumn]:
|
|
3285
3321
|
if schema.prompt_template_column_names is None:
|
|
3286
3322
|
return []
|
|
3287
3323
|
col = schema.prompt_template_column_names.template_column_name
|
|
@@ -3298,7 +3334,7 @@ class Validator:
|
|
|
3298
3334
|
max_length=MAX_PROMPT_TEMPLATE_LENGTH,
|
|
3299
3335
|
)
|
|
3300
3336
|
]
|
|
3301
|
-
|
|
3337
|
+
if max_len > MAX_PROMPT_TEMPLATE_LENGTH_TRUNCATION:
|
|
3302
3338
|
logger.warning(
|
|
3303
3339
|
get_truncation_warning_message(
|
|
3304
3340
|
"prompt templates",
|
|
@@ -3310,7 +3346,7 @@ class Validator:
|
|
|
3310
3346
|
@staticmethod
|
|
3311
3347
|
def _check_value_llm_prompt_template_version(
|
|
3312
3348
|
dataframe: pd.DataFrame, schema: Schema
|
|
3313
|
-
) ->
|
|
3349
|
+
) -> list[err.InvalidStringLengthInColumn]:
|
|
3314
3350
|
if schema.prompt_template_column_names is None:
|
|
3315
3351
|
return []
|
|
3316
3352
|
col = schema.prompt_template_column_names.template_version_column_name
|
|
@@ -3327,7 +3363,7 @@ class Validator:
|
|
|
3327
3363
|
max_length=MAX_PROMPT_TEMPLATE_VERSION_LENGTH,
|
|
3328
3364
|
)
|
|
3329
3365
|
]
|
|
3330
|
-
|
|
3366
|
+
if max_len > MAX_PROMPT_TEMPLATE_VERSION_LENGTH_TRUNCATION:
|
|
3331
3367
|
logger.warning(
|
|
3332
3368
|
get_truncation_warning_message(
|
|
3333
3369
|
"prompt template versions",
|
|
@@ -3338,8 +3374,8 @@ class Validator:
|
|
|
3338
3374
|
|
|
3339
3375
|
@staticmethod
|
|
3340
3376
|
def _check_type_document_columns(
|
|
3341
|
-
schema: CorpusSchema, column_types:
|
|
3342
|
-
) ->
|
|
3377
|
+
schema: CorpusSchema, column_types: dict[str, Any]
|
|
3378
|
+
) -> list[err.InvalidTypeColumns]:
|
|
3343
3379
|
invalid_types = []
|
|
3344
3380
|
# Check document id
|
|
3345
3381
|
col = schema.document_id_column_name
|
|
@@ -3424,16 +3460,15 @@ class Validator:
|
|
|
3424
3460
|
return []
|
|
3425
3461
|
|
|
3426
3462
|
|
|
3427
|
-
def _check_value_string_length_helper(x):
|
|
3463
|
+
def _check_value_string_length_helper(x: object) -> int:
|
|
3428
3464
|
if isinstance(x, str):
|
|
3429
3465
|
return len(x)
|
|
3430
|
-
|
|
3431
|
-
return 0
|
|
3466
|
+
return 0
|
|
3432
3467
|
|
|
3433
3468
|
|
|
3434
3469
|
def _check_value_vector_dimensionality_helper(
|
|
3435
|
-
dataframe: pd.DataFrame, cols_to_check:
|
|
3436
|
-
) ->
|
|
3470
|
+
dataframe: pd.DataFrame, cols_to_check: list[str]
|
|
3471
|
+
) -> tuple[list[str], list[str]]:
|
|
3437
3472
|
invalid_low_dimensionality_vector_cols = []
|
|
3438
3473
|
invalid_high_dimensionality_vector_cols = []
|
|
3439
3474
|
for col in cols_to_check:
|
|
@@ -3452,8 +3487,8 @@ def _check_value_vector_dimensionality_helper(
|
|
|
3452
3487
|
|
|
3453
3488
|
|
|
3454
3489
|
def _check_value_raw_data_length_helper(
|
|
3455
|
-
dataframe: pd.DataFrame, cols_to_check:
|
|
3456
|
-
) ->
|
|
3490
|
+
dataframe: pd.DataFrame, cols_to_check: list[str]
|
|
3491
|
+
) -> tuple[list[str], list[str]]:
|
|
3457
3492
|
invalid_long_string_data_cols = []
|
|
3458
3493
|
truncated_long_string_data_cols = []
|
|
3459
3494
|
for col in cols_to_check:
|
|
@@ -3469,7 +3504,7 @@ def _check_value_raw_data_length_helper(
|
|
|
3469
3504
|
)
|
|
3470
3505
|
except TypeError as exc:
|
|
3471
3506
|
e = TypeError(f"Cannot validate the column '{col}'. " + str(exc))
|
|
3472
|
-
logger.
|
|
3507
|
+
logger.exception(e)
|
|
3473
3508
|
raise e from exc
|
|
3474
3509
|
if max_data_len > MAX_RAW_DATA_CHARACTERS:
|
|
3475
3510
|
invalid_long_string_data_cols.append(col)
|
|
@@ -3480,8 +3515,8 @@ def _check_value_raw_data_length_helper(
|
|
|
3480
3515
|
|
|
3481
3516
|
def _check_value_bounding_boxes_coordinates_helper(
|
|
3482
3517
|
coordinates_col: pd.Series,
|
|
3483
|
-
) ->
|
|
3484
|
-
def check(boxes):
|
|
3518
|
+
) -> err.InvalidBoundingBoxesCoordinates | None:
|
|
3519
|
+
def check(boxes: object) -> None:
|
|
3485
3520
|
# We allow for zero boxes. None coordinates list is not allowed (will break following tests:
|
|
3486
3521
|
# 'NoneType is not iterable')
|
|
3487
3522
|
if boxes is None:
|
|
@@ -3502,7 +3537,9 @@ def _check_value_bounding_boxes_coordinates_helper(
|
|
|
3502
3537
|
return None
|
|
3503
3538
|
|
|
3504
3539
|
|
|
3505
|
-
def _box_coordinates_wrong_format(
|
|
3540
|
+
def _box_coordinates_wrong_format(
|
|
3541
|
+
box_coords: object,
|
|
3542
|
+
) -> err.InvalidBoundingBoxesCoordinates | None:
|
|
3506
3543
|
if (
|
|
3507
3544
|
# Coordinates should be a collection of 4 floats
|
|
3508
3545
|
len(box_coords) != 4
|
|
@@ -3516,12 +3553,13 @@ def _box_coordinates_wrong_format(box_coords):
|
|
|
3516
3553
|
return err.InvalidBoundingBoxesCoordinates(
|
|
3517
3554
|
reason="boxes_coordinates_wrong_format"
|
|
3518
3555
|
)
|
|
3556
|
+
return None
|
|
3519
3557
|
|
|
3520
3558
|
|
|
3521
3559
|
def _check_value_bounding_boxes_categories_helper(
|
|
3522
3560
|
categories_col: pd.Series,
|
|
3523
|
-
) ->
|
|
3524
|
-
def check(categories):
|
|
3561
|
+
) -> err.InvalidBoundingBoxesCategories | None:
|
|
3562
|
+
def check(categories: object) -> None:
|
|
3525
3563
|
# We allow for zero boxes. None category list is not allowed (will break following tests:
|
|
3526
3564
|
# 'NoneType is not iterable')
|
|
3527
3565
|
if categories is None:
|
|
@@ -3542,8 +3580,8 @@ def _check_value_bounding_boxes_categories_helper(
|
|
|
3542
3580
|
|
|
3543
3581
|
def _check_value_bounding_boxes_scores_helper(
|
|
3544
3582
|
scores_col: pd.Series,
|
|
3545
|
-
) ->
|
|
3546
|
-
def check(scores):
|
|
3583
|
+
) -> err.InvalidBoundingBoxesScores | None:
|
|
3584
|
+
def check(scores: object) -> None:
|
|
3547
3585
|
# We allow for zero boxes. None confidence score list is not allowed (will break following tests:
|
|
3548
3586
|
# 'NoneType is not iterable')
|
|
3549
3587
|
if scores is None:
|
|
@@ -3562,9 +3600,10 @@ def _check_value_bounding_boxes_scores_helper(
|
|
|
3562
3600
|
return None
|
|
3563
3601
|
|
|
3564
3602
|
|
|
3565
|
-
def _polygon_coordinates_wrong_format(
|
|
3566
|
-
|
|
3567
|
-
|
|
3603
|
+
def _polygon_coordinates_wrong_format(
|
|
3604
|
+
polygon_coords: object,
|
|
3605
|
+
) -> err.InvalidPolygonCoordinates | None:
|
|
3606
|
+
"""Check if polygon coordinates are valid.
|
|
3568
3607
|
|
|
3569
3608
|
Validates:
|
|
3570
3609
|
- Has at least 3 vertices (6 coordinates)
|
|
@@ -3610,9 +3649,9 @@ def _polygon_coordinates_wrong_format(polygon_coords):
|
|
|
3610
3649
|
|
|
3611
3650
|
# Check for self-intersections
|
|
3612
3651
|
# We need to check if any two non-adjacent edges intersect
|
|
3613
|
-
edges = [
|
|
3614
|
-
|
|
3615
|
-
|
|
3652
|
+
edges = [
|
|
3653
|
+
(points[i], points[(i + 1) % len(points)]) for i in range(len(points))
|
|
3654
|
+
]
|
|
3616
3655
|
|
|
3617
3656
|
for i in range(len(edges)):
|
|
3618
3657
|
for j in range(i + 2, len(edges)):
|
|
@@ -3634,8 +3673,8 @@ def _polygon_coordinates_wrong_format(polygon_coords):
|
|
|
3634
3673
|
|
|
3635
3674
|
def _check_value_polygon_coordinates_helper(
|
|
3636
3675
|
coordinates_col: pd.Series,
|
|
3637
|
-
) ->
|
|
3638
|
-
def check(polygons):
|
|
3676
|
+
) -> err.InvalidPolygonCoordinates | None:
|
|
3677
|
+
def check(polygons: object) -> None:
|
|
3639
3678
|
# We allow for zero polygons. None coordinates list is not allowed (will break following tests:
|
|
3640
3679
|
# 'NoneType is not iterable')
|
|
3641
3680
|
if polygons is None:
|
|
@@ -3658,8 +3697,8 @@ def _check_value_polygon_coordinates_helper(
|
|
|
3658
3697
|
|
|
3659
3698
|
def _check_value_polygon_categories_helper(
|
|
3660
3699
|
categories_col: pd.Series,
|
|
3661
|
-
) ->
|
|
3662
|
-
def check(categories):
|
|
3700
|
+
) -> err.InvalidPolygonCategories | None:
|
|
3701
|
+
def check(categories: object) -> None:
|
|
3663
3702
|
# We allow for zero boxes. None category list is not allowed (will break following tests:
|
|
3664
3703
|
# 'NoneType is not iterable')
|
|
3665
3704
|
if categories is None:
|
|
@@ -3678,8 +3717,8 @@ def _check_value_polygon_categories_helper(
|
|
|
3678
3717
|
|
|
3679
3718
|
def _check_value_polygon_scores_helper(
|
|
3680
3719
|
scores_col: pd.Series,
|
|
3681
|
-
) ->
|
|
3682
|
-
def check(scores):
|
|
3720
|
+
) -> err.InvalidPolygonScores | None:
|
|
3721
|
+
def check(scores: object) -> None:
|
|
3683
3722
|
# We allow for zero boxes. None confidence score list is not allowed (will break following tests:
|
|
3684
3723
|
# 'NoneType is not iterable')
|
|
3685
3724
|
if scores is None:
|
|
@@ -3696,7 +3735,7 @@ def _check_value_polygon_scores_helper(
|
|
|
3696
3735
|
return None
|
|
3697
3736
|
|
|
3698
3737
|
|
|
3699
|
-
def _count_characters_raw_data(data:
|
|
3738
|
+
def _count_characters_raw_data(data: str | list[str]) -> int:
|
|
3700
3739
|
character_count = 0
|
|
3701
3740
|
if isinstance(data, str):
|
|
3702
3741
|
character_count = len(data)
|