arize 8.0.0a22__py3-none-any.whl → 8.0.0b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arize/__init__.py +28 -19
- arize/_exporter/client.py +56 -37
- arize/_exporter/parsers/tracing_data_parser.py +41 -30
- arize/_exporter/validation.py +3 -3
- arize/_flight/client.py +207 -76
- arize/_generated/api_client/__init__.py +30 -6
- arize/_generated/api_client/api/__init__.py +1 -0
- arize/_generated/api_client/api/datasets_api.py +864 -190
- arize/_generated/api_client/api/experiments_api.py +167 -131
- arize/_generated/api_client/api/projects_api.py +1197 -0
- arize/_generated/api_client/api_client.py +2 -2
- arize/_generated/api_client/configuration.py +42 -34
- arize/_generated/api_client/exceptions.py +2 -2
- arize/_generated/api_client/models/__init__.py +15 -4
- arize/_generated/api_client/models/dataset.py +10 -10
- arize/_generated/api_client/models/dataset_example.py +111 -0
- arize/_generated/api_client/models/dataset_example_update.py +100 -0
- arize/_generated/api_client/models/dataset_version.py +13 -13
- arize/_generated/api_client/models/datasets_create_request.py +16 -8
- arize/_generated/api_client/models/datasets_examples_insert_request.py +100 -0
- arize/_generated/api_client/models/datasets_examples_list200_response.py +106 -0
- arize/_generated/api_client/models/datasets_examples_update_request.py +102 -0
- arize/_generated/api_client/models/datasets_list200_response.py +10 -4
- arize/_generated/api_client/models/experiment.py +14 -16
- arize/_generated/api_client/models/experiment_run.py +108 -0
- arize/_generated/api_client/models/experiment_run_create.py +102 -0
- arize/_generated/api_client/models/experiments_create_request.py +16 -10
- arize/_generated/api_client/models/experiments_list200_response.py +10 -4
- arize/_generated/api_client/models/experiments_runs_list200_response.py +19 -5
- arize/_generated/api_client/models/{error.py → pagination_metadata.py} +13 -11
- arize/_generated/api_client/models/primitive_value.py +172 -0
- arize/_generated/api_client/models/problem.py +100 -0
- arize/_generated/api_client/models/project.py +99 -0
- arize/_generated/api_client/models/{datasets_list_examples200_response.py → projects_create_request.py} +13 -11
- arize/_generated/api_client/models/projects_list200_response.py +106 -0
- arize/_generated/api_client/rest.py +2 -2
- arize/_generated/api_client/test/test_dataset.py +4 -2
- arize/_generated/api_client/test/test_dataset_example.py +56 -0
- arize/_generated/api_client/test/test_dataset_example_update.py +52 -0
- arize/_generated/api_client/test/test_dataset_version.py +7 -2
- arize/_generated/api_client/test/test_datasets_api.py +27 -13
- arize/_generated/api_client/test/test_datasets_create_request.py +8 -4
- arize/_generated/api_client/test/{test_datasets_list_examples200_response.py → test_datasets_examples_insert_request.py} +19 -15
- arize/_generated/api_client/test/test_datasets_examples_list200_response.py +66 -0
- arize/_generated/api_client/test/test_datasets_examples_update_request.py +61 -0
- arize/_generated/api_client/test/test_datasets_list200_response.py +9 -3
- arize/_generated/api_client/test/test_experiment.py +2 -4
- arize/_generated/api_client/test/test_experiment_run.py +56 -0
- arize/_generated/api_client/test/test_experiment_run_create.py +54 -0
- arize/_generated/api_client/test/test_experiments_api.py +6 -6
- arize/_generated/api_client/test/test_experiments_create_request.py +9 -6
- arize/_generated/api_client/test/test_experiments_list200_response.py +9 -5
- arize/_generated/api_client/test/test_experiments_runs_list200_response.py +15 -5
- arize/_generated/api_client/test/test_pagination_metadata.py +53 -0
- arize/_generated/api_client/test/{test_error.py → test_primitive_value.py} +13 -14
- arize/_generated/api_client/test/test_problem.py +57 -0
- arize/_generated/api_client/test/test_project.py +58 -0
- arize/_generated/api_client/test/test_projects_api.py +59 -0
- arize/_generated/api_client/test/test_projects_create_request.py +54 -0
- arize/_generated/api_client/test/test_projects_list200_response.py +70 -0
- arize/_generated/api_client_README.md +43 -29
- arize/_generated/protocol/flight/flight_pb2.py +400 -0
- arize/_lazy.py +27 -19
- arize/client.py +181 -58
- arize/config.py +324 -116
- arize/constants/__init__.py +1 -0
- arize/constants/config.py +11 -4
- arize/constants/ml.py +6 -4
- arize/constants/openinference.py +2 -0
- arize/constants/pyarrow.py +2 -0
- arize/constants/spans.py +3 -1
- arize/datasets/__init__.py +1 -0
- arize/datasets/client.py +304 -84
- arize/datasets/errors.py +32 -2
- arize/datasets/validation.py +18 -8
- arize/embeddings/__init__.py +2 -0
- arize/embeddings/auto_generator.py +23 -19
- arize/embeddings/base_generators.py +89 -36
- arize/embeddings/constants.py +2 -0
- arize/embeddings/cv_generators.py +26 -4
- arize/embeddings/errors.py +27 -5
- arize/embeddings/nlp_generators.py +43 -18
- arize/embeddings/tabular_generators.py +46 -31
- arize/embeddings/usecases.py +12 -2
- arize/exceptions/__init__.py +1 -0
- arize/exceptions/auth.py +11 -1
- arize/exceptions/base.py +29 -4
- arize/exceptions/models.py +21 -2
- arize/exceptions/parameters.py +31 -0
- arize/exceptions/spaces.py +12 -1
- arize/exceptions/types.py +86 -7
- arize/exceptions/values.py +220 -20
- arize/experiments/__init__.py +13 -0
- arize/experiments/client.py +394 -285
- arize/experiments/evaluators/__init__.py +1 -0
- arize/experiments/evaluators/base.py +74 -41
- arize/experiments/evaluators/exceptions.py +6 -3
- arize/experiments/evaluators/executors.py +121 -73
- arize/experiments/evaluators/rate_limiters.py +106 -57
- arize/experiments/evaluators/types.py +34 -7
- arize/experiments/evaluators/utils.py +65 -27
- arize/experiments/functions.py +103 -101
- arize/experiments/tracing.py +52 -44
- arize/experiments/types.py +56 -31
- arize/logging.py +54 -22
- arize/ml/__init__.py +1 -0
- arize/ml/batch_validation/__init__.py +1 -0
- arize/{models → ml}/batch_validation/errors.py +545 -67
- arize/{models → ml}/batch_validation/validator.py +344 -303
- arize/ml/bounded_executor.py +47 -0
- arize/{models → ml}/casting.py +118 -108
- arize/{models → ml}/client.py +339 -118
- arize/{models → ml}/proto.py +97 -42
- arize/{models → ml}/stream_validation.py +43 -15
- arize/ml/surrogate_explainer/__init__.py +1 -0
- arize/{models → ml}/surrogate_explainer/mimic.py +25 -10
- arize/{types.py → ml/types.py} +355 -354
- arize/pre_releases.py +44 -0
- arize/projects/__init__.py +1 -0
- arize/projects/client.py +134 -0
- arize/regions.py +40 -0
- arize/spans/__init__.py +1 -0
- arize/spans/client.py +204 -175
- arize/spans/columns.py +13 -0
- arize/spans/conversion.py +60 -37
- arize/spans/validation/__init__.py +1 -0
- arize/spans/validation/annotations/__init__.py +1 -0
- arize/spans/validation/annotations/annotations_validation.py +6 -4
- arize/spans/validation/annotations/dataframe_form_validation.py +13 -11
- arize/spans/validation/annotations/value_validation.py +35 -11
- arize/spans/validation/common/__init__.py +1 -0
- arize/spans/validation/common/argument_validation.py +33 -8
- arize/spans/validation/common/dataframe_form_validation.py +35 -9
- arize/spans/validation/common/errors.py +211 -11
- arize/spans/validation/common/value_validation.py +81 -14
- arize/spans/validation/evals/__init__.py +1 -0
- arize/spans/validation/evals/dataframe_form_validation.py +28 -8
- arize/spans/validation/evals/evals_validation.py +34 -4
- arize/spans/validation/evals/value_validation.py +26 -3
- arize/spans/validation/metadata/__init__.py +1 -1
- arize/spans/validation/metadata/argument_validation.py +14 -5
- arize/spans/validation/metadata/dataframe_form_validation.py +26 -10
- arize/spans/validation/metadata/value_validation.py +24 -10
- arize/spans/validation/spans/__init__.py +1 -0
- arize/spans/validation/spans/dataframe_form_validation.py +35 -14
- arize/spans/validation/spans/spans_validation.py +35 -4
- arize/spans/validation/spans/value_validation.py +78 -8
- arize/utils/__init__.py +1 -0
- arize/utils/arrow.py +31 -15
- arize/utils/cache.py +34 -6
- arize/utils/dataframe.py +20 -3
- arize/utils/online_tasks/__init__.py +2 -0
- arize/utils/online_tasks/dataframe_preprocessor.py +58 -47
- arize/utils/openinference_conversion.py +44 -5
- arize/utils/proto.py +10 -0
- arize/utils/size.py +5 -3
- arize/utils/types.py +105 -0
- arize/version.py +3 -1
- {arize-8.0.0a22.dist-info → arize-8.0.0b0.dist-info}/METADATA +13 -6
- arize-8.0.0b0.dist-info/RECORD +175 -0
- {arize-8.0.0a22.dist-info → arize-8.0.0b0.dist-info}/WHEEL +1 -1
- arize-8.0.0b0.dist-info/licenses/LICENSE +176 -0
- arize-8.0.0b0.dist-info/licenses/NOTICE +13 -0
- arize/_generated/protocol/flight/export_pb2.py +0 -61
- arize/_generated/protocol/flight/ingest_pb2.py +0 -365
- arize/models/__init__.py +0 -0
- arize/models/batch_validation/__init__.py +0 -0
- arize/models/bounded_executor.py +0 -34
- arize/models/surrogate_explainer/__init__.py +0 -0
- arize-8.0.0a22.dist-info/RECORD +0 -146
- arize-8.0.0a22.dist-info/licenses/LICENSE.md +0 -12
|
@@ -1,12 +1,15 @@
|
|
|
1
|
+
"""Batch validation logic for ML model predictions and actuals."""
|
|
2
|
+
|
|
1
3
|
from __future__ import annotations
|
|
2
4
|
|
|
3
|
-
import datetime
|
|
4
5
|
import logging
|
|
5
6
|
import math
|
|
7
|
+
from datetime import datetime, timedelta, timezone
|
|
6
8
|
from itertools import chain
|
|
7
|
-
from typing import
|
|
9
|
+
from typing import Any
|
|
8
10
|
|
|
9
11
|
import numpy as np
|
|
12
|
+
import pandas as pd
|
|
10
13
|
import pyarrow as pa
|
|
11
14
|
|
|
12
15
|
from arize.constants.ml import (
|
|
@@ -37,8 +40,8 @@ from arize.constants.ml import (
|
|
|
37
40
|
MODEL_MAPPING_CONFIG,
|
|
38
41
|
)
|
|
39
42
|
from arize.logging import get_truncation_warning_message
|
|
40
|
-
from arize.
|
|
41
|
-
from arize.types import (
|
|
43
|
+
from arize.ml.batch_validation import errors as err
|
|
44
|
+
from arize.ml.types import (
|
|
42
45
|
CATEGORICAL_MODEL_TYPES,
|
|
43
46
|
NUMERIC_MODEL_TYPES,
|
|
44
47
|
BaseSchema,
|
|
@@ -50,28 +53,29 @@ from arize.types import (
|
|
|
50
53
|
ModelTypes,
|
|
51
54
|
PromptTemplateColumnNames,
|
|
52
55
|
Schema,
|
|
56
|
+
segments_intersect,
|
|
57
|
+
)
|
|
58
|
+
from arize.utils.types import (
|
|
53
59
|
is_dict_of,
|
|
54
60
|
is_iterable_of,
|
|
55
|
-
segments_intersect,
|
|
56
61
|
)
|
|
57
62
|
|
|
58
|
-
if TYPE_CHECKING:
|
|
59
|
-
import pandas as pd
|
|
60
|
-
|
|
61
|
-
|
|
62
63
|
logger = logging.getLogger(__name__)
|
|
63
64
|
|
|
64
65
|
|
|
65
66
|
class Validator:
|
|
67
|
+
"""Validator for batch data with schema and dataframe validation methods."""
|
|
68
|
+
|
|
66
69
|
@staticmethod
|
|
67
70
|
def validate_required_checks(
|
|
68
71
|
dataframe: pd.DataFrame,
|
|
69
72
|
model_id: str,
|
|
70
73
|
environment: Environments,
|
|
71
74
|
schema: BaseSchema,
|
|
72
|
-
model_version:
|
|
73
|
-
batch_id:
|
|
74
|
-
) ->
|
|
75
|
+
model_version: str | None = None,
|
|
76
|
+
batch_id: str | None = None,
|
|
77
|
+
) -> list[err.ValidationError]:
|
|
78
|
+
"""Validate required checks for schema, environment, and DataFrame structure."""
|
|
75
79
|
general_checks = chain(
|
|
76
80
|
Validator._check_valid_schema_type(schema, environment),
|
|
77
81
|
Validator._check_field_convertible_to_str(
|
|
@@ -87,7 +91,7 @@ class Validator:
|
|
|
87
91
|
schema, CorpusSchema
|
|
88
92
|
):
|
|
89
93
|
return list(general_checks)
|
|
90
|
-
|
|
94
|
+
if isinstance(schema, Schema):
|
|
91
95
|
return list(
|
|
92
96
|
chain(
|
|
93
97
|
general_checks,
|
|
@@ -108,10 +112,11 @@ class Validator:
|
|
|
108
112
|
model_type: ModelTypes,
|
|
109
113
|
environment: Environments,
|
|
110
114
|
schema: BaseSchema,
|
|
111
|
-
metric_families:
|
|
112
|
-
model_version:
|
|
113
|
-
batch_id:
|
|
114
|
-
) ->
|
|
115
|
+
metric_families: list[Metrics] | None = None,
|
|
116
|
+
model_version: str | None = None,
|
|
117
|
+
batch_id: str | None = None,
|
|
118
|
+
) -> list[err.ValidationError]:
|
|
119
|
+
"""Validate parameters including model type, environment, and schema consistency."""
|
|
115
120
|
# general checks
|
|
116
121
|
general_checks = chain(
|
|
117
122
|
Validator._check_column_names_for_empty_strings(schema),
|
|
@@ -125,7 +130,7 @@ class Validator:
|
|
|
125
130
|
)
|
|
126
131
|
if isinstance(schema, CorpusSchema):
|
|
127
132
|
return list(general_checks)
|
|
128
|
-
|
|
133
|
+
if isinstance(schema, Schema):
|
|
129
134
|
general_checks = chain(
|
|
130
135
|
general_checks,
|
|
131
136
|
Validator._check_existence_prediction_id_column_delayed_schema(
|
|
@@ -153,7 +158,7 @@ class Validator:
|
|
|
153
158
|
),
|
|
154
159
|
)
|
|
155
160
|
return list(chain(general_checks, num_checks))
|
|
156
|
-
|
|
161
|
+
if model_type in CATEGORICAL_MODEL_TYPES:
|
|
157
162
|
sc_checks = chain(
|
|
158
163
|
Validator._check_existence_preprod_pred_act_score_or_label(
|
|
159
164
|
schema, environment
|
|
@@ -166,7 +171,7 @@ class Validator:
|
|
|
166
171
|
),
|
|
167
172
|
)
|
|
168
173
|
return list(chain(general_checks, sc_checks))
|
|
169
|
-
|
|
174
|
+
if model_type == ModelTypes.GENERATIVE_LLM:
|
|
170
175
|
gllm_checks = chain(
|
|
171
176
|
Validator._check_existence_preprod_act(schema, environment),
|
|
172
177
|
Validator._check_missing_object_detection_columns(
|
|
@@ -177,7 +182,7 @@ class Validator:
|
|
|
177
182
|
),
|
|
178
183
|
)
|
|
179
184
|
return list(chain(general_checks, gllm_checks))
|
|
180
|
-
|
|
185
|
+
if model_type == ModelTypes.RANKING:
|
|
181
186
|
r_checks = chain(
|
|
182
187
|
Validator._check_existence_group_id_rank_category_relevance(
|
|
183
188
|
schema
|
|
@@ -190,7 +195,7 @@ class Validator:
|
|
|
190
195
|
),
|
|
191
196
|
)
|
|
192
197
|
return list(chain(general_checks, r_checks))
|
|
193
|
-
|
|
198
|
+
if model_type == ModelTypes.OBJECT_DETECTION:
|
|
194
199
|
od_checks = chain(
|
|
195
200
|
Validator._check_exactly_one_cv_column_type(
|
|
196
201
|
schema, environment
|
|
@@ -203,7 +208,7 @@ class Validator:
|
|
|
203
208
|
),
|
|
204
209
|
)
|
|
205
210
|
return list(chain(general_checks, od_checks))
|
|
206
|
-
|
|
211
|
+
if model_type == ModelTypes.MULTI_CLASS:
|
|
207
212
|
multi_class_checks = chain(
|
|
208
213
|
Validator._check_existing_multi_class_columns(schema),
|
|
209
214
|
Validator._check_missing_non_multi_class_columns(
|
|
@@ -218,8 +223,11 @@ class Validator:
|
|
|
218
223
|
model_type: ModelTypes,
|
|
219
224
|
schema: BaseSchema,
|
|
220
225
|
pyarrow_schema: pa.Schema,
|
|
221
|
-
) ->
|
|
222
|
-
|
|
226
|
+
) -> list[err.ValidationError]:
|
|
227
|
+
"""Validate column data types against expected types for the schema."""
|
|
228
|
+
column_types = dict(
|
|
229
|
+
zip(pyarrow_schema.names, pyarrow_schema.types, strict=True)
|
|
230
|
+
)
|
|
223
231
|
|
|
224
232
|
if isinstance(schema, CorpusSchema):
|
|
225
233
|
return list(
|
|
@@ -227,7 +235,7 @@ class Validator:
|
|
|
227
235
|
Validator._check_type_document_columns(schema, column_types)
|
|
228
236
|
)
|
|
229
237
|
)
|
|
230
|
-
|
|
238
|
+
if isinstance(schema, Schema):
|
|
231
239
|
general_checks = chain(
|
|
232
240
|
Validator._check_type_prediction_id(schema, column_types),
|
|
233
241
|
Validator._check_type_timestamp(schema, column_types),
|
|
@@ -271,7 +279,7 @@ class Validator:
|
|
|
271
279
|
),
|
|
272
280
|
)
|
|
273
281
|
return list(chain(general_checks, gllm_checks))
|
|
274
|
-
|
|
282
|
+
if model_type == ModelTypes.RANKING:
|
|
275
283
|
r_checks = chain(
|
|
276
284
|
Validator._check_type_prediction_group_id(
|
|
277
285
|
schema, column_types
|
|
@@ -285,7 +293,7 @@ class Validator:
|
|
|
285
293
|
),
|
|
286
294
|
)
|
|
287
295
|
return list(chain(general_checks, r_checks))
|
|
288
|
-
|
|
296
|
+
if model_type == ModelTypes.OBJECT_DETECTION:
|
|
289
297
|
od_checks = chain(
|
|
290
298
|
Validator._check_type_image_segment_coordinates(
|
|
291
299
|
schema, column_types
|
|
@@ -298,7 +306,7 @@ class Validator:
|
|
|
298
306
|
),
|
|
299
307
|
)
|
|
300
308
|
return list(chain(general_checks, od_checks))
|
|
301
|
-
|
|
309
|
+
if model_type == ModelTypes.MULTI_CLASS:
|
|
302
310
|
multi_class_checks = chain(
|
|
303
311
|
Validator._check_type_multi_class_pred_threshold_act_scores(
|
|
304
312
|
schema, column_types
|
|
@@ -315,7 +323,8 @@ class Validator:
|
|
|
315
323
|
environment: Environments,
|
|
316
324
|
schema: BaseSchema,
|
|
317
325
|
model_type: ModelTypes,
|
|
318
|
-
) ->
|
|
326
|
+
) -> list[err.ValidationError]:
|
|
327
|
+
"""Validate data values including ranges, formats, and consistency checks."""
|
|
319
328
|
# ASSUMPTION: at this point the param and type checks should have passed.
|
|
320
329
|
# This function may crash if that is not true, e.g. if columns are missing
|
|
321
330
|
# or are of the wrong types.
|
|
@@ -338,7 +347,7 @@ class Validator:
|
|
|
338
347
|
),
|
|
339
348
|
)
|
|
340
349
|
)
|
|
341
|
-
|
|
350
|
+
if isinstance(schema, Schema):
|
|
342
351
|
general_checks = chain(
|
|
343
352
|
general_checks,
|
|
344
353
|
Validator._check_value_timestamp(dataframe, schema),
|
|
@@ -429,21 +438,21 @@ class Validator:
|
|
|
429
438
|
return list(general_checks)
|
|
430
439
|
return []
|
|
431
440
|
|
|
432
|
-
#
|
|
433
|
-
# Minimum
|
|
434
|
-
#
|
|
441
|
+
# -----------------------
|
|
442
|
+
# Minimum required checks
|
|
443
|
+
# -----------------------
|
|
435
444
|
@staticmethod
|
|
436
445
|
def _check_column_names_for_empty_strings(
|
|
437
446
|
schema: BaseSchema,
|
|
438
|
-
) ->
|
|
447
|
+
) -> list[err.InvalidColumnNameEmptyString]:
|
|
439
448
|
if "" in schema.get_used_columns():
|
|
440
449
|
return [err.InvalidColumnNameEmptyString()]
|
|
441
450
|
return []
|
|
442
451
|
|
|
443
452
|
@staticmethod
|
|
444
453
|
def _check_field_convertible_to_str(
|
|
445
|
-
model_id, model_version, batch_id
|
|
446
|
-
) ->
|
|
454
|
+
model_id: object, model_version: object, batch_id: object
|
|
455
|
+
) -> list[err.InvalidFieldTypeConversion]:
|
|
447
456
|
# converting to a set first makes the checks run a lot faster
|
|
448
457
|
wrong_fields = []
|
|
449
458
|
if model_id is not None and not isinstance(model_id, str):
|
|
@@ -469,7 +478,7 @@ class Validator:
|
|
|
469
478
|
@staticmethod
|
|
470
479
|
def _check_field_type_embedding_features_column_names(
|
|
471
480
|
schema: Schema,
|
|
472
|
-
) ->
|
|
481
|
+
) -> list[err.InvalidFieldTypeEmbeddingFeatures]:
|
|
473
482
|
if schema.embedding_feature_column_names is not None:
|
|
474
483
|
if not isinstance(schema.embedding_feature_column_names, dict):
|
|
475
484
|
return [err.InvalidFieldTypeEmbeddingFeatures()]
|
|
@@ -483,7 +492,7 @@ class Validator:
|
|
|
483
492
|
@staticmethod
|
|
484
493
|
def _check_field_type_prompt_response(
|
|
485
494
|
schema: Schema,
|
|
486
|
-
) ->
|
|
495
|
+
) -> list[err.InvalidFieldTypePromptResponse]:
|
|
487
496
|
errors = []
|
|
488
497
|
if schema.prompt_column_names is not None and not isinstance(
|
|
489
498
|
schema.prompt_column_names, (str, EmbeddingColumnNames)
|
|
@@ -502,7 +511,7 @@ class Validator:
|
|
|
502
511
|
@staticmethod
|
|
503
512
|
def _check_field_type_prompt_templates(
|
|
504
513
|
schema: Schema,
|
|
505
|
-
) ->
|
|
514
|
+
) -> list[err.InvalidFieldTypePromptTemplates]:
|
|
506
515
|
if schema.prompt_template_column_names is not None and not isinstance(
|
|
507
516
|
schema.prompt_template_column_names, PromptTemplateColumnNames
|
|
508
517
|
):
|
|
@@ -513,7 +522,7 @@ class Validator:
|
|
|
513
522
|
def _check_field_type_llm_config(
|
|
514
523
|
dataframe: pd.DataFrame,
|
|
515
524
|
schema: Schema,
|
|
516
|
-
) ->
|
|
525
|
+
) -> list[err.InvalidFieldTypeLlmConfig | err.InvalidTypeColumns]:
|
|
517
526
|
if schema.llm_config_column_names is None:
|
|
518
527
|
return []
|
|
519
528
|
if not isinstance(schema.llm_config_column_names, LLMConfigColumnNames):
|
|
@@ -548,7 +557,7 @@ class Validator:
|
|
|
548
557
|
@staticmethod
|
|
549
558
|
def _check_invalid_index(
|
|
550
559
|
dataframe: pd.DataFrame,
|
|
551
|
-
) ->
|
|
560
|
+
) -> list[err.InvalidDataFrameIndex]:
|
|
552
561
|
if (dataframe.index != dataframe.reset_index(drop=True).index).any():
|
|
553
562
|
return [err.InvalidDataFrameIndex()]
|
|
554
563
|
return []
|
|
@@ -560,9 +569,9 @@ class Validator:
|
|
|
560
569
|
@staticmethod
|
|
561
570
|
def _check_model_type_and_metrics(
|
|
562
571
|
model_type: ModelTypes,
|
|
563
|
-
metric_families:
|
|
572
|
+
metric_families: list[Metrics] | None,
|
|
564
573
|
schema: Schema,
|
|
565
|
-
) ->
|
|
574
|
+
) -> list[err.ValidationError]:
|
|
566
575
|
if metric_families is None:
|
|
567
576
|
return []
|
|
568
577
|
|
|
@@ -606,10 +615,10 @@ class Validator:
|
|
|
606
615
|
@staticmethod
|
|
607
616
|
def _check_model_mapping_combinations(
|
|
608
617
|
model_type: ModelTypes,
|
|
609
|
-
metric_families:
|
|
618
|
+
metric_families: list[Metrics],
|
|
610
619
|
schema: Schema,
|
|
611
|
-
required_columns_map:
|
|
612
|
-
) ->
|
|
620
|
+
required_columns_map: list[dict[str, Any]],
|
|
621
|
+
) -> tuple[bool, list[str], list[list[str]]]:
|
|
613
622
|
missing_columns = []
|
|
614
623
|
for item in required_columns_map:
|
|
615
624
|
if model_type.name.lower() == item.get("external_model_type"):
|
|
@@ -625,10 +634,10 @@ class Validator:
|
|
|
625
634
|
metric_combinations.append(
|
|
626
635
|
[metric.upper() for metric in metrics_list]
|
|
627
636
|
)
|
|
628
|
-
if set(metrics_list) ==
|
|
637
|
+
if set(metrics_list) == {
|
|
629
638
|
metric_family.name.lower()
|
|
630
639
|
for metric_family in metric_families
|
|
631
|
-
|
|
640
|
+
}:
|
|
632
641
|
# This is a valid combination of model type + metrics.
|
|
633
642
|
# Now validate that required columns are in the schema.
|
|
634
643
|
is_valid_combination = True
|
|
@@ -665,10 +674,10 @@ class Validator:
|
|
|
665
674
|
@staticmethod
|
|
666
675
|
def _check_existence_prediction_id_column_delayed_schema(
|
|
667
676
|
schema: Schema, model_type: ModelTypes
|
|
668
|
-
) ->
|
|
677
|
+
) -> list[err.MissingPredictionIdColumnForDelayedRecords]:
|
|
669
678
|
if schema.prediction_id_column_name is not None:
|
|
670
679
|
return []
|
|
671
|
-
# TODO: Revise logic once
|
|
680
|
+
# TODO: Revise logic once prediction_label column addition (for generative models)
|
|
672
681
|
# is moved to beginning of log function
|
|
673
682
|
if schema.is_delayed() and model_type is not ModelTypes.GENERATIVE_LLM:
|
|
674
683
|
# We skip GENERATIVE model types since they are assigned a default
|
|
@@ -696,12 +705,12 @@ class Validator:
|
|
|
696
705
|
def _check_missing_columns(
|
|
697
706
|
dataframe: pd.DataFrame,
|
|
698
707
|
schema: BaseSchema,
|
|
699
|
-
) ->
|
|
708
|
+
) -> list[err.MissingColumns]:
|
|
700
709
|
if isinstance(schema, CorpusSchema):
|
|
701
710
|
return Validator._check_missing_columns_corpus_schema(
|
|
702
711
|
dataframe, schema
|
|
703
712
|
)
|
|
704
|
-
|
|
713
|
+
if isinstance(schema, Schema):
|
|
705
714
|
return Validator._check_missing_columns_schema(dataframe, schema)
|
|
706
715
|
return []
|
|
707
716
|
|
|
@@ -709,7 +718,7 @@ class Validator:
|
|
|
709
718
|
def _check_missing_columns_schema(
|
|
710
719
|
dataframe: pd.DataFrame,
|
|
711
720
|
schema: Schema,
|
|
712
|
-
) ->
|
|
721
|
+
) -> list[err.MissingColumns]:
|
|
713
722
|
# converting to a set first makes the checks run a lot faster
|
|
714
723
|
existing_columns = set(dataframe.columns)
|
|
715
724
|
missing_columns = []
|
|
@@ -721,9 +730,13 @@ class Validator:
|
|
|
721
730
|
missing_columns.append(col)
|
|
722
731
|
|
|
723
732
|
if schema.feature_column_names is not None:
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
|
|
733
|
+
missing_columns.extend(
|
|
734
|
+
[
|
|
735
|
+
col
|
|
736
|
+
for col in schema.feature_column_names
|
|
737
|
+
if col not in existing_columns
|
|
738
|
+
]
|
|
739
|
+
)
|
|
727
740
|
|
|
728
741
|
if schema.embedding_feature_column_names is not None:
|
|
729
742
|
for (
|
|
@@ -752,44 +765,76 @@ class Validator:
|
|
|
752
765
|
)
|
|
753
766
|
|
|
754
767
|
if schema.tag_column_names is not None:
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
|
|
768
|
+
missing_columns.extend(
|
|
769
|
+
[
|
|
770
|
+
col
|
|
771
|
+
for col in schema.tag_column_names
|
|
772
|
+
if col not in existing_columns
|
|
773
|
+
]
|
|
774
|
+
)
|
|
758
775
|
|
|
759
776
|
if schema.shap_values_column_names is not None:
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
|
|
777
|
+
missing_columns.extend(
|
|
778
|
+
[
|
|
779
|
+
col
|
|
780
|
+
for col in schema.shap_values_column_names.values()
|
|
781
|
+
if col not in existing_columns
|
|
782
|
+
]
|
|
783
|
+
)
|
|
763
784
|
|
|
764
785
|
if schema.object_detection_prediction_column_names is not None:
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
786
|
+
missing_columns.extend(
|
|
787
|
+
[
|
|
788
|
+
col
|
|
789
|
+
for col in schema.object_detection_prediction_column_names
|
|
790
|
+
if col is not None and col not in existing_columns
|
|
791
|
+
]
|
|
792
|
+
)
|
|
768
793
|
|
|
769
794
|
if schema.object_detection_actual_column_names is not None:
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
795
|
+
missing_columns.extend(
|
|
796
|
+
[
|
|
797
|
+
col
|
|
798
|
+
for col in schema.object_detection_actual_column_names
|
|
799
|
+
if col is not None and col not in existing_columns
|
|
800
|
+
]
|
|
801
|
+
)
|
|
773
802
|
|
|
774
803
|
if schema.semantic_segmentation_prediction_column_names is not None:
|
|
775
|
-
|
|
776
|
-
|
|
777
|
-
|
|
804
|
+
missing_columns.extend(
|
|
805
|
+
[
|
|
806
|
+
col
|
|
807
|
+
for col in schema.semantic_segmentation_prediction_column_names
|
|
808
|
+
if col is not None and col not in existing_columns
|
|
809
|
+
]
|
|
810
|
+
)
|
|
778
811
|
|
|
779
812
|
if schema.semantic_segmentation_actual_column_names is not None:
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
|
|
813
|
+
missing_columns.extend(
|
|
814
|
+
[
|
|
815
|
+
col
|
|
816
|
+
for col in schema.semantic_segmentation_actual_column_names
|
|
817
|
+
if col is not None and col not in existing_columns
|
|
818
|
+
]
|
|
819
|
+
)
|
|
783
820
|
|
|
784
821
|
if schema.instance_segmentation_prediction_column_names is not None:
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
|
|
822
|
+
missing_columns.extend(
|
|
823
|
+
[
|
|
824
|
+
col
|
|
825
|
+
for col in schema.instance_segmentation_prediction_column_names
|
|
826
|
+
if col is not None and col not in existing_columns
|
|
827
|
+
]
|
|
828
|
+
)
|
|
788
829
|
|
|
789
830
|
if schema.instance_segmentation_actual_column_names is not None:
|
|
790
|
-
|
|
791
|
-
|
|
792
|
-
|
|
831
|
+
missing_columns.extend(
|
|
832
|
+
[
|
|
833
|
+
col
|
|
834
|
+
for col in schema.instance_segmentation_actual_column_names
|
|
835
|
+
if col is not None and col not in existing_columns
|
|
836
|
+
]
|
|
837
|
+
)
|
|
793
838
|
|
|
794
839
|
if schema.prompt_column_names is not None:
|
|
795
840
|
if isinstance(schema.prompt_column_names, str):
|
|
@@ -838,14 +883,22 @@ class Validator:
|
|
|
838
883
|
)
|
|
839
884
|
|
|
840
885
|
if schema.prompt_template_column_names is not None:
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
|
|
886
|
+
missing_columns.extend(
|
|
887
|
+
[
|
|
888
|
+
col
|
|
889
|
+
for col in schema.prompt_template_column_names
|
|
890
|
+
if col is not None and col not in existing_columns
|
|
891
|
+
]
|
|
892
|
+
)
|
|
844
893
|
|
|
845
894
|
if schema.llm_config_column_names is not None:
|
|
846
|
-
|
|
847
|
-
|
|
848
|
-
|
|
895
|
+
missing_columns.extend(
|
|
896
|
+
[
|
|
897
|
+
col
|
|
898
|
+
for col in schema.llm_config_column_names
|
|
899
|
+
if col is not None and col not in existing_columns
|
|
900
|
+
]
|
|
901
|
+
)
|
|
849
902
|
|
|
850
903
|
if missing_columns:
|
|
851
904
|
return [err.MissingColumns(missing_columns)]
|
|
@@ -855,7 +908,7 @@ class Validator:
|
|
|
855
908
|
def _check_missing_columns_corpus_schema(
|
|
856
909
|
dataframe: pd.DataFrame,
|
|
857
910
|
schema: CorpusSchema,
|
|
858
|
-
) ->
|
|
911
|
+
) -> list[err.MissingColumns]:
|
|
859
912
|
# converting to a set first makes the checks run a lot faster
|
|
860
913
|
existing_columns = set(dataframe.columns)
|
|
861
914
|
missing_columns = []
|
|
@@ -912,7 +965,7 @@ class Validator:
|
|
|
912
965
|
def _check_valid_schema_type(
|
|
913
966
|
schema: BaseSchema,
|
|
914
967
|
environment: Environments,
|
|
915
|
-
) ->
|
|
968
|
+
) -> list[err.InvalidSchemaType]:
|
|
916
969
|
if environment == Environments.CORPUS and not (
|
|
917
970
|
isinstance(schema, CorpusSchema)
|
|
918
971
|
):
|
|
@@ -934,7 +987,7 @@ class Validator:
|
|
|
934
987
|
@staticmethod
|
|
935
988
|
def _check_invalid_shap_suffix(
|
|
936
989
|
schema: Schema,
|
|
937
|
-
) ->
|
|
990
|
+
) -> list[err.InvalidShapSuffix]:
|
|
938
991
|
invalid_column_names = set()
|
|
939
992
|
|
|
940
993
|
if schema.feature_column_names is not None:
|
|
@@ -970,10 +1023,10 @@ class Validator:
|
|
|
970
1023
|
def _check_reserved_columns(
|
|
971
1024
|
schema: BaseSchema,
|
|
972
1025
|
model_type: ModelTypes,
|
|
973
|
-
) ->
|
|
1026
|
+
) -> list[err.ReservedColumns]:
|
|
974
1027
|
if isinstance(schema, CorpusSchema):
|
|
975
1028
|
return []
|
|
976
|
-
|
|
1029
|
+
if isinstance(schema, Schema):
|
|
977
1030
|
reserved_columns = []
|
|
978
1031
|
column_counts = schema.get_used_columns_counts()
|
|
979
1032
|
if model_type == ModelTypes.GENERATIVE_LLM:
|
|
@@ -1079,8 +1132,8 @@ class Validator:
|
|
|
1079
1132
|
|
|
1080
1133
|
@staticmethod
|
|
1081
1134
|
def _check_invalid_model_id(
|
|
1082
|
-
model_id:
|
|
1083
|
-
) ->
|
|
1135
|
+
model_id: str | None,
|
|
1136
|
+
) -> list[err.InvalidModelId]:
|
|
1084
1137
|
# assume it's been coerced to string beforehand
|
|
1085
1138
|
if (not isinstance(model_id, str)) or len(model_id.strip()) == 0:
|
|
1086
1139
|
return [err.InvalidModelId()]
|
|
@@ -1088,8 +1141,8 @@ class Validator:
|
|
|
1088
1141
|
|
|
1089
1142
|
@staticmethod
|
|
1090
1143
|
def _check_invalid_model_version(
|
|
1091
|
-
model_version:
|
|
1092
|
-
) ->
|
|
1144
|
+
model_version: str | None = None,
|
|
1145
|
+
) -> list[err.InvalidModelVersion]:
|
|
1093
1146
|
if model_version is None:
|
|
1094
1147
|
return []
|
|
1095
1148
|
if (
|
|
@@ -1102,9 +1155,9 @@ class Validator:
|
|
|
1102
1155
|
|
|
1103
1156
|
@staticmethod
|
|
1104
1157
|
def _check_invalid_batch_id(
|
|
1105
|
-
batch_id:
|
|
1158
|
+
batch_id: str | None,
|
|
1106
1159
|
environment: Environments,
|
|
1107
|
-
) ->
|
|
1160
|
+
) -> list[err.InvalidBatchId]:
|
|
1108
1161
|
# assume it's been coerced to string beforehand
|
|
1109
1162
|
if environment in (Environments.VALIDATION,) and (
|
|
1110
1163
|
(not isinstance(batch_id, str)) or len(batch_id.strip()) == 0
|
|
@@ -1115,7 +1168,7 @@ class Validator:
|
|
|
1115
1168
|
@staticmethod
|
|
1116
1169
|
def _check_invalid_model_type(
|
|
1117
1170
|
model_type: ModelTypes,
|
|
1118
|
-
) ->
|
|
1171
|
+
) -> list[err.InvalidModelType]:
|
|
1119
1172
|
if model_type in (mt for mt in ModelTypes):
|
|
1120
1173
|
return []
|
|
1121
1174
|
return [err.InvalidModelType()]
|
|
@@ -1123,7 +1176,7 @@ class Validator:
|
|
|
1123
1176
|
@staticmethod
|
|
1124
1177
|
def _check_invalid_environment(
|
|
1125
1178
|
environment: Environments,
|
|
1126
|
-
) ->
|
|
1179
|
+
) -> list[err.InvalidEnvironment]:
|
|
1127
1180
|
if environment in (env for env in Environments):
|
|
1128
1181
|
return []
|
|
1129
1182
|
return [err.InvalidEnvironment()]
|
|
@@ -1132,7 +1185,7 @@ class Validator:
|
|
|
1132
1185
|
def _check_existence_preprod_pred_act_score_or_label(
|
|
1133
1186
|
schema: Schema,
|
|
1134
1187
|
environment: Environments,
|
|
1135
|
-
) ->
|
|
1188
|
+
) -> list[err.MissingPreprodPredActNumericAndCategorical]:
|
|
1136
1189
|
if environment in (Environments.VALIDATION, Environments.TRAINING) and (
|
|
1137
1190
|
(
|
|
1138
1191
|
schema.prediction_label_column_name is None
|
|
@@ -1149,7 +1202,7 @@ class Validator:
|
|
|
1149
1202
|
@staticmethod
|
|
1150
1203
|
def _check_exactly_one_cv_column_type(
|
|
1151
1204
|
schema: Schema, environment: Environments
|
|
1152
|
-
) ->
|
|
1205
|
+
) -> list[err.MultipleCVPredAct | err.MissingCVPredAct]:
|
|
1153
1206
|
# Checks that the required prediction/actual columns are given in the schema depending on
|
|
1154
1207
|
# the environment, for object detection models. There should be exactly one of
|
|
1155
1208
|
# object detection, semantic segmentation, or instance segmentation columns.
|
|
@@ -1180,7 +1233,7 @@ class Validator:
|
|
|
1180
1233
|
|
|
1181
1234
|
if cv_types_count == 0:
|
|
1182
1235
|
return [err.MissingCVPredAct(environment)]
|
|
1183
|
-
|
|
1236
|
+
if cv_types_count > 1:
|
|
1184
1237
|
return [err.MultipleCVPredAct(environment)]
|
|
1185
1238
|
|
|
1186
1239
|
elif environment in (
|
|
@@ -1213,7 +1266,7 @@ class Validator:
|
|
|
1213
1266
|
|
|
1214
1267
|
if cv_types_count == 0:
|
|
1215
1268
|
return [err.MissingCVPredAct(environment)]
|
|
1216
|
-
|
|
1269
|
+
if cv_types_count > 1:
|
|
1217
1270
|
return [err.MultipleCVPredAct(environment)]
|
|
1218
1271
|
|
|
1219
1272
|
return []
|
|
@@ -1221,9 +1274,9 @@ class Validator:
|
|
|
1221
1274
|
@staticmethod
|
|
1222
1275
|
def _check_missing_object_detection_columns(
|
|
1223
1276
|
schema: Schema, model_type: ModelTypes
|
|
1224
|
-
) ->
|
|
1277
|
+
) -> list[err.InvalidPredActCVColumnNamesForModelType]:
|
|
1225
1278
|
# Checks that models that are not Object Detection models don't have, in the schema, the
|
|
1226
|
-
# object detection, semantic segmentation, or instance segmentation dedicated
|
|
1279
|
+
# object detection, semantic segmentation, or instance segmentation dedicated prediction/actual
|
|
1227
1280
|
# column names
|
|
1228
1281
|
if (
|
|
1229
1282
|
schema.object_detection_prediction_column_names is not None
|
|
@@ -1239,7 +1292,7 @@ class Validator:
|
|
|
1239
1292
|
@staticmethod
|
|
1240
1293
|
def _check_missing_non_object_detection_columns(
|
|
1241
1294
|
schema: Schema, model_type: ModelTypes
|
|
1242
|
-
) ->
|
|
1295
|
+
) -> list[err.InvalidPredActColumnNamesForModelType]:
|
|
1243
1296
|
# Checks that object detection models don't have, in the schema, the columns reserved for
|
|
1244
1297
|
# other model types
|
|
1245
1298
|
columns_to_check = (
|
|
@@ -1253,10 +1306,7 @@ class Validator:
|
|
|
1253
1306
|
schema.relevance_score_column_name,
|
|
1254
1307
|
schema.relevance_labels_column_name,
|
|
1255
1308
|
)
|
|
1256
|
-
wrong_cols = []
|
|
1257
|
-
for col in columns_to_check:
|
|
1258
|
-
if col is not None:
|
|
1259
|
-
wrong_cols.append(col)
|
|
1309
|
+
wrong_cols = [col for col in columns_to_check if col is not None]
|
|
1260
1310
|
if wrong_cols:
|
|
1261
1311
|
allowed_cols = [
|
|
1262
1312
|
"object_detection_prediction_column_names",
|
|
@@ -1276,7 +1326,7 @@ class Validator:
|
|
|
1276
1326
|
@staticmethod
|
|
1277
1327
|
def _check_missing_multi_class_columns(
|
|
1278
1328
|
schema: Schema, model_type: ModelTypes
|
|
1279
|
-
) ->
|
|
1329
|
+
) -> list[err.InvalidPredActColumnNamesForModelType]:
|
|
1280
1330
|
# Checks that models that are not Multi Class models don't have, in the schema, the
|
|
1281
1331
|
# multi class dedicated threshold column
|
|
1282
1332
|
if (
|
|
@@ -1295,7 +1345,7 @@ class Validator:
|
|
|
1295
1345
|
@staticmethod
|
|
1296
1346
|
def _check_existing_multi_class_columns(
|
|
1297
1347
|
schema: Schema,
|
|
1298
|
-
) ->
|
|
1348
|
+
) -> list[err.MissingReqPredActColumnNamesForMultiClass]:
|
|
1299
1349
|
# Checks that models that are Multi Class models have, in the schema, the
|
|
1300
1350
|
# required prediction score or actual score columns
|
|
1301
1351
|
if (
|
|
@@ -1311,7 +1361,7 @@ class Validator:
|
|
|
1311
1361
|
@staticmethod
|
|
1312
1362
|
def _check_missing_non_multi_class_columns(
|
|
1313
1363
|
schema: Schema, model_type: ModelTypes
|
|
1314
|
-
) ->
|
|
1364
|
+
) -> list[err.InvalidPredActColumnNamesForModelType]:
|
|
1315
1365
|
# Checks that multi class models don't have, in the schema, the columns reserved for
|
|
1316
1366
|
# other model types
|
|
1317
1367
|
columns_to_check = (
|
|
@@ -1329,10 +1379,7 @@ class Validator:
|
|
|
1329
1379
|
schema.instance_segmentation_prediction_column_names,
|
|
1330
1380
|
schema.instance_segmentation_actual_column_names,
|
|
1331
1381
|
)
|
|
1332
|
-
wrong_cols = []
|
|
1333
|
-
for col in columns_to_check:
|
|
1334
|
-
if col is not None:
|
|
1335
|
-
wrong_cols.append(col)
|
|
1382
|
+
wrong_cols = [col for col in columns_to_check if col is not None]
|
|
1336
1383
|
if wrong_cols:
|
|
1337
1384
|
allowed_cols = [
|
|
1338
1385
|
"prediction_score_column_name",
|
|
@@ -1350,7 +1397,7 @@ class Validator:
|
|
|
1350
1397
|
def _check_existence_preprod_act(
|
|
1351
1398
|
schema: Schema,
|
|
1352
1399
|
environment: Environments,
|
|
1353
|
-
) ->
|
|
1400
|
+
) -> list[err.MissingPreprodAct]:
|
|
1354
1401
|
if environment in (Environments.VALIDATION, Environments.TRAINING) and (
|
|
1355
1402
|
schema.actual_label_column_name is None
|
|
1356
1403
|
):
|
|
@@ -1360,7 +1407,7 @@ class Validator:
|
|
|
1360
1407
|
@staticmethod
|
|
1361
1408
|
def _check_existence_group_id_rank_category_relevance(
|
|
1362
1409
|
schema: Schema,
|
|
1363
|
-
) ->
|
|
1410
|
+
) -> list[err.MissingRequiredColumnsForRankingModel]:
|
|
1364
1411
|
# prediction_group_id and rank columns are required as ranking prediction columns.
|
|
1365
1412
|
ranking_prediction_cols = (
|
|
1366
1413
|
schema.prediction_label_column_name,
|
|
@@ -1384,7 +1431,7 @@ class Validator:
|
|
|
1384
1431
|
@staticmethod
|
|
1385
1432
|
def _check_dataframe_for_duplicate_columns(
|
|
1386
1433
|
schema: BaseSchema, dataframe: pd.DataFrame
|
|
1387
|
-
) ->
|
|
1434
|
+
) -> list[err.DuplicateColumnsInDataframe]:
|
|
1388
1435
|
# Get the columns used in the schema
|
|
1389
1436
|
schema_col_used = schema.get_used_columns()
|
|
1390
1437
|
# Get the duplicated column names from the dataframe
|
|
@@ -1400,7 +1447,7 @@ class Validator:
|
|
|
1400
1447
|
@staticmethod
|
|
1401
1448
|
def _check_invalid_number_of_embeddings(
|
|
1402
1449
|
schema: Schema,
|
|
1403
|
-
) ->
|
|
1450
|
+
) -> list[err.InvalidNumberOfEmbeddings]:
|
|
1404
1451
|
if schema.embedding_feature_column_names is not None:
|
|
1405
1452
|
number_of_embeddings = len(schema.embedding_feature_column_names)
|
|
1406
1453
|
if number_of_embeddings > MAX_NUMBER_OF_EMBEDDINGS:
|
|
@@ -1413,8 +1460,8 @@ class Validator:
|
|
|
1413
1460
|
|
|
1414
1461
|
@staticmethod
|
|
1415
1462
|
def _check_type_prediction_id(
|
|
1416
|
-
schema: Schema, column_types:
|
|
1417
|
-
) ->
|
|
1463
|
+
schema: Schema, column_types: dict[str, Any]
|
|
1464
|
+
) -> list[err.InvalidType]:
|
|
1418
1465
|
col = schema.prediction_id_column_name
|
|
1419
1466
|
if col in column_types:
|
|
1420
1467
|
# should mirror server side
|
|
@@ -1437,8 +1484,8 @@ class Validator:
|
|
|
1437
1484
|
|
|
1438
1485
|
@staticmethod
|
|
1439
1486
|
def _check_type_timestamp(
|
|
1440
|
-
schema: Schema, column_types:
|
|
1441
|
-
) ->
|
|
1487
|
+
schema: Schema, column_types: dict[str, Any]
|
|
1488
|
+
) -> list[err.InvalidType]:
|
|
1442
1489
|
col = schema.timestamp_column_name
|
|
1443
1490
|
if col in column_types:
|
|
1444
1491
|
# should mirror server side
|
|
@@ -1464,8 +1511,8 @@ class Validator:
|
|
|
1464
1511
|
|
|
1465
1512
|
@staticmethod
|
|
1466
1513
|
def _check_type_features(
|
|
1467
|
-
schema: Schema, column_types:
|
|
1468
|
-
) ->
|
|
1514
|
+
schema: Schema, column_types: dict[str, Any]
|
|
1515
|
+
) -> list[err.InvalidTypeFeatures]:
|
|
1469
1516
|
if schema.feature_column_names is not None:
|
|
1470
1517
|
# should mirror server side
|
|
1471
1518
|
allowed_datatypes = (
|
|
@@ -1480,13 +1527,12 @@ class Validator:
|
|
|
1480
1527
|
pa.null(),
|
|
1481
1528
|
pa.list_(pa.string()),
|
|
1482
1529
|
)
|
|
1483
|
-
wrong_type_cols = [
|
|
1484
|
-
|
|
1485
|
-
|
|
1486
|
-
|
|
1487
|
-
|
|
1488
|
-
|
|
1489
|
-
wrong_type_cols.append(col)
|
|
1530
|
+
wrong_type_cols = [
|
|
1531
|
+
col
|
|
1532
|
+
for col in schema.feature_column_names
|
|
1533
|
+
if col in column_types
|
|
1534
|
+
and column_types[col] not in allowed_datatypes
|
|
1535
|
+
]
|
|
1490
1536
|
if wrong_type_cols:
|
|
1491
1537
|
return [
|
|
1492
1538
|
err.InvalidTypeFeatures(
|
|
@@ -1504,8 +1550,8 @@ class Validator:
|
|
|
1504
1550
|
|
|
1505
1551
|
@staticmethod
|
|
1506
1552
|
def _check_type_embedding_features(
|
|
1507
|
-
schema: Schema, column_types:
|
|
1508
|
-
) ->
|
|
1553
|
+
schema: Schema, column_types: dict[str, Any]
|
|
1554
|
+
) -> list[err.InvalidTypeFeatures]:
|
|
1509
1555
|
if schema.embedding_feature_column_names is not None:
|
|
1510
1556
|
# should mirror server side
|
|
1511
1557
|
allowed_vector_datatypes = (
|
|
@@ -1580,8 +1626,8 @@ class Validator:
|
|
|
1580
1626
|
|
|
1581
1627
|
@staticmethod
|
|
1582
1628
|
def _check_type_tags(
|
|
1583
|
-
schema: Schema, column_types:
|
|
1584
|
-
) ->
|
|
1629
|
+
schema: Schema, column_types: dict[str, Any]
|
|
1630
|
+
) -> list[err.InvalidTypeTags]:
|
|
1585
1631
|
if schema.tag_column_names is not None:
|
|
1586
1632
|
# should mirror server side
|
|
1587
1633
|
allowed_datatypes = (
|
|
@@ -1595,13 +1641,12 @@ class Validator:
|
|
|
1595
1641
|
pa.int8(),
|
|
1596
1642
|
pa.null(),
|
|
1597
1643
|
)
|
|
1598
|
-
wrong_type_cols = [
|
|
1599
|
-
|
|
1600
|
-
|
|
1601
|
-
|
|
1602
|
-
|
|
1603
|
-
|
|
1604
|
-
wrong_type_cols.append(col)
|
|
1644
|
+
wrong_type_cols = [
|
|
1645
|
+
col
|
|
1646
|
+
for col in schema.tag_column_names
|
|
1647
|
+
if col in column_types
|
|
1648
|
+
and column_types[col] not in allowed_datatypes
|
|
1649
|
+
]
|
|
1605
1650
|
if wrong_type_cols:
|
|
1606
1651
|
return [
|
|
1607
1652
|
err.InvalidTypeTags(
|
|
@@ -1612,8 +1657,8 @@ class Validator:
|
|
|
1612
1657
|
|
|
1613
1658
|
@staticmethod
|
|
1614
1659
|
def _check_type_shap_values(
|
|
1615
|
-
schema: Schema, column_types:
|
|
1616
|
-
) ->
|
|
1660
|
+
schema: Schema, column_types: dict[str, Any]
|
|
1661
|
+
) -> list[err.InvalidTypeShapValues]:
|
|
1617
1662
|
if schema.shap_values_column_names is not None:
|
|
1618
1663
|
# should mirror server side
|
|
1619
1664
|
allowed_datatypes = (
|
|
@@ -1622,13 +1667,12 @@ class Validator:
|
|
|
1622
1667
|
pa.float32(),
|
|
1623
1668
|
pa.int32(),
|
|
1624
1669
|
)
|
|
1625
|
-
wrong_type_cols = [
|
|
1626
|
-
|
|
1627
|
-
|
|
1628
|
-
|
|
1629
|
-
|
|
1630
|
-
|
|
1631
|
-
wrong_type_cols.append(col)
|
|
1670
|
+
wrong_type_cols = [
|
|
1671
|
+
col
|
|
1672
|
+
for col in schema.shap_values_column_names.values()
|
|
1673
|
+
if col in column_types
|
|
1674
|
+
and column_types[col] not in allowed_datatypes
|
|
1675
|
+
]
|
|
1632
1676
|
if wrong_type_cols:
|
|
1633
1677
|
return [
|
|
1634
1678
|
err.InvalidTypeShapValues(
|
|
@@ -1639,8 +1683,8 @@ class Validator:
|
|
|
1639
1683
|
|
|
1640
1684
|
@staticmethod
|
|
1641
1685
|
def _check_type_pred_act_labels(
|
|
1642
|
-
model_type: ModelTypes, schema: Schema, column_types:
|
|
1643
|
-
) ->
|
|
1686
|
+
model_type: ModelTypes, schema: Schema, column_types: dict[str, Any]
|
|
1687
|
+
) -> list[err.InvalidType]:
|
|
1644
1688
|
errors = []
|
|
1645
1689
|
columns = (
|
|
1646
1690
|
("Prediction labels", schema.prediction_label_column_name),
|
|
@@ -1703,8 +1747,8 @@ class Validator:
|
|
|
1703
1747
|
|
|
1704
1748
|
@staticmethod
|
|
1705
1749
|
def _check_type_pred_act_scores(
|
|
1706
|
-
model_type: ModelTypes, schema: Schema, column_types:
|
|
1707
|
-
) ->
|
|
1750
|
+
model_type: ModelTypes, schema: Schema, column_types: dict[str, Any]
|
|
1751
|
+
) -> list[err.InvalidType]:
|
|
1708
1752
|
errors = []
|
|
1709
1753
|
columns = (
|
|
1710
1754
|
("Prediction scores", schema.prediction_score_column_name),
|
|
@@ -1743,13 +1787,14 @@ class Validator:
|
|
|
1743
1787
|
|
|
1744
1788
|
@staticmethod
|
|
1745
1789
|
def _check_type_multi_class_pred_threshold_act_scores(
|
|
1746
|
-
schema: Schema, column_types:
|
|
1747
|
-
) ->
|
|
1748
|
-
"""
|
|
1749
|
-
|
|
1750
|
-
Expect the scores to be a list of pyarrow structs that contains field
|
|
1751
|
-
|
|
1752
|
-
|
|
1790
|
+
schema: Schema, column_types: dict[str, Any]
|
|
1791
|
+
) -> list[err.InvalidType]:
|
|
1792
|
+
"""Check type for prediction / threshold / actual scores for multiclass model.
|
|
1793
|
+
|
|
1794
|
+
Expect the scores to be a list of pyarrow structs that contains field
|
|
1795
|
+
"class_name" and field "score", where class_name is a string and score
|
|
1796
|
+
is a number.
|
|
1797
|
+
Example: '[{"class_name": "class1", "score": 0.1}, ...]'
|
|
1753
1798
|
"""
|
|
1754
1799
|
errors = []
|
|
1755
1800
|
columns = (
|
|
@@ -1802,8 +1847,8 @@ class Validator:
|
|
|
1802
1847
|
|
|
1803
1848
|
@staticmethod
|
|
1804
1849
|
def _check_type_prompt_response(
|
|
1805
|
-
schema: Schema, column_types:
|
|
1806
|
-
) ->
|
|
1850
|
+
schema: Schema, column_types: dict[str, Any]
|
|
1851
|
+
) -> list[err.InvalidTypeColumns]:
|
|
1807
1852
|
fields_to_check = []
|
|
1808
1853
|
if schema.prompt_column_names is not None:
|
|
1809
1854
|
fields_to_check.append(schema.prompt_column_names)
|
|
@@ -1872,8 +1917,8 @@ class Validator:
|
|
|
1872
1917
|
|
|
1873
1918
|
@staticmethod
|
|
1874
1919
|
def _check_type_llm_prompt_templates(
|
|
1875
|
-
schema: Schema, column_types:
|
|
1876
|
-
) ->
|
|
1920
|
+
schema: Schema, column_types: dict[str, Any]
|
|
1921
|
+
) -> list[err.InvalidTypeColumns]:
|
|
1877
1922
|
if schema.prompt_template_column_names is None:
|
|
1878
1923
|
return []
|
|
1879
1924
|
|
|
@@ -1913,8 +1958,8 @@ class Validator:
|
|
|
1913
1958
|
|
|
1914
1959
|
@staticmethod
|
|
1915
1960
|
def _check_type_llm_config(
|
|
1916
|
-
schema: Schema, column_types:
|
|
1917
|
-
) ->
|
|
1961
|
+
schema: Schema, column_types: dict[str, Any]
|
|
1962
|
+
) -> list[err.InvalidTypeColumns]:
|
|
1918
1963
|
if schema.llm_config_column_names is None:
|
|
1919
1964
|
return []
|
|
1920
1965
|
|
|
@@ -1950,8 +1995,8 @@ class Validator:
|
|
|
1950
1995
|
|
|
1951
1996
|
@staticmethod
|
|
1952
1997
|
def _check_type_llm_run_metadata(
|
|
1953
|
-
schema: Schema, column_types:
|
|
1954
|
-
) ->
|
|
1998
|
+
schema: Schema, column_types: dict[str, Any]
|
|
1999
|
+
) -> list[err.InvalidTypeColumns]:
|
|
1955
2000
|
if schema.llm_run_metadata_column_names is None:
|
|
1956
2001
|
return []
|
|
1957
2002
|
|
|
@@ -2023,8 +2068,8 @@ class Validator:
|
|
|
2023
2068
|
|
|
2024
2069
|
@staticmethod
|
|
2025
2070
|
def _check_type_retrieved_document_ids(
|
|
2026
|
-
schema: Schema, column_types:
|
|
2027
|
-
) ->
|
|
2071
|
+
schema: Schema, column_types: dict[str, Any]
|
|
2072
|
+
) -> list[err.InvalidType]:
|
|
2028
2073
|
col = schema.retrieved_document_ids_column_name
|
|
2029
2074
|
if col in column_types:
|
|
2030
2075
|
# should mirror server side
|
|
@@ -2044,8 +2089,8 @@ class Validator:
|
|
|
2044
2089
|
|
|
2045
2090
|
@staticmethod
|
|
2046
2091
|
def _check_type_image_segment_coordinates(
|
|
2047
|
-
schema: Schema, column_types:
|
|
2048
|
-
) ->
|
|
2092
|
+
schema: Schema, column_types: dict[str, Any]
|
|
2093
|
+
) -> list[err.InvalidTypeColumns]:
|
|
2049
2094
|
# should mirror server side
|
|
2050
2095
|
allowed_coordinate_types = (
|
|
2051
2096
|
pa.list_(pa.list_(pa.float64())),
|
|
@@ -2090,9 +2135,8 @@ class Validator:
|
|
|
2090
2135
|
wrong_type_cols.append(coord_col)
|
|
2091
2136
|
|
|
2092
2137
|
if schema.instance_segmentation_prediction_column_names is not None:
|
|
2093
|
-
|
|
2094
|
-
|
|
2095
|
-
)
|
|
2138
|
+
inst_seg_pred = schema.instance_segmentation_prediction_column_names
|
|
2139
|
+
polygons_coord_col = inst_seg_pred.polygon_coordinates_column_name
|
|
2096
2140
|
if (
|
|
2097
2141
|
polygons_coord_col in column_types
|
|
2098
2142
|
and column_types[polygons_coord_col]
|
|
@@ -2101,7 +2145,7 @@ class Validator:
|
|
|
2101
2145
|
wrong_type_cols.append(polygons_coord_col)
|
|
2102
2146
|
|
|
2103
2147
|
bbox_coord_col = (
|
|
2104
|
-
|
|
2148
|
+
inst_seg_pred.bounding_boxes_coordinates_column_name
|
|
2105
2149
|
)
|
|
2106
2150
|
if (
|
|
2107
2151
|
bbox_coord_col in column_types
|
|
@@ -2110,9 +2154,8 @@ class Validator:
|
|
|
2110
2154
|
wrong_type_cols.append(bbox_coord_col)
|
|
2111
2155
|
|
|
2112
2156
|
if schema.instance_segmentation_actual_column_names is not None:
|
|
2113
|
-
|
|
2114
|
-
|
|
2115
|
-
)
|
|
2157
|
+
inst_seg_actual = schema.instance_segmentation_actual_column_names
|
|
2158
|
+
coord_col = inst_seg_actual.polygon_coordinates_column_name
|
|
2116
2159
|
if (
|
|
2117
2160
|
coord_col in column_types
|
|
2118
2161
|
and column_types[coord_col] not in allowed_coordinate_types
|
|
@@ -2120,7 +2163,7 @@ class Validator:
|
|
|
2120
2163
|
wrong_type_cols.append(coord_col)
|
|
2121
2164
|
|
|
2122
2165
|
bbox_coord_col = (
|
|
2123
|
-
|
|
2166
|
+
inst_seg_actual.bounding_boxes_coordinates_column_name
|
|
2124
2167
|
)
|
|
2125
2168
|
if (
|
|
2126
2169
|
bbox_coord_col in column_types
|
|
@@ -2141,8 +2184,8 @@ class Validator:
|
|
|
2141
2184
|
|
|
2142
2185
|
@staticmethod
|
|
2143
2186
|
def _check_type_image_segment_categories(
|
|
2144
|
-
schema: Schema, column_types:
|
|
2145
|
-
) ->
|
|
2187
|
+
schema: Schema, column_types: dict[str, Any]
|
|
2188
|
+
) -> list[err.InvalidTypeColumns]:
|
|
2146
2189
|
# should mirror server side
|
|
2147
2190
|
allowed_category_datatypes = (
|
|
2148
2191
|
pa.list_(pa.string()),
|
|
@@ -2210,8 +2253,8 @@ class Validator:
|
|
|
2210
2253
|
|
|
2211
2254
|
@staticmethod
|
|
2212
2255
|
def _check_type_image_segment_scores(
|
|
2213
|
-
schema: Schema, column_types:
|
|
2214
|
-
) ->
|
|
2256
|
+
schema: Schema, column_types: dict[str, Any]
|
|
2257
|
+
) -> list[err.InvalidTypeColumns]:
|
|
2215
2258
|
# should mirror server side
|
|
2216
2259
|
allowed_score_datatypes = (
|
|
2217
2260
|
pa.list_(pa.float64()),
|
|
@@ -2270,7 +2313,7 @@ class Validator:
|
|
|
2270
2313
|
@staticmethod
|
|
2271
2314
|
def _check_embedding_vectors_dimensionality(
|
|
2272
2315
|
dataframe: pd.DataFrame, schema: Schema
|
|
2273
|
-
) ->
|
|
2316
|
+
) -> list[err.ValidationError]:
|
|
2274
2317
|
if schema.embedding_feature_column_names is None:
|
|
2275
2318
|
return []
|
|
2276
2319
|
|
|
@@ -2300,7 +2343,7 @@ class Validator:
|
|
|
2300
2343
|
@staticmethod
|
|
2301
2344
|
def _check_embedding_raw_data_characters(
|
|
2302
2345
|
dataframe: pd.DataFrame, schema: Schema
|
|
2303
|
-
) ->
|
|
2346
|
+
) -> list[err.ValidationError]:
|
|
2304
2347
|
if schema.embedding_feature_column_names is None:
|
|
2305
2348
|
return []
|
|
2306
2349
|
|
|
@@ -2322,7 +2365,7 @@ class Validator:
|
|
|
2322
2365
|
invalid_long_string_data_cols
|
|
2323
2366
|
)
|
|
2324
2367
|
]
|
|
2325
|
-
|
|
2368
|
+
if truncated_long_string_data_cols:
|
|
2326
2369
|
logger.warning(
|
|
2327
2370
|
get_truncation_warning_message(
|
|
2328
2371
|
"Embedding raw data fields",
|
|
@@ -2334,7 +2377,7 @@ class Validator:
|
|
|
2334
2377
|
@staticmethod
|
|
2335
2378
|
def _check_value_rank(
|
|
2336
2379
|
dataframe: pd.DataFrame, schema: Schema
|
|
2337
|
-
) ->
|
|
2380
|
+
) -> list[err.InvalidRankValue]:
|
|
2338
2381
|
col = schema.rank_column_name
|
|
2339
2382
|
lbound, ubound = (1, 100)
|
|
2340
2383
|
|
|
@@ -2346,11 +2389,11 @@ class Validator:
|
|
|
2346
2389
|
|
|
2347
2390
|
@staticmethod
|
|
2348
2391
|
def _check_id_field_str_length(
|
|
2349
|
-
dataframe: pd.DataFrame, schema_name: str, id_col_name:
|
|
2350
|
-
) ->
|
|
2351
|
-
"""
|
|
2352
|
-
|
|
2353
|
-
and MAX_PREDICTION_ID_LEN
|
|
2392
|
+
dataframe: pd.DataFrame, schema_name: str, id_col_name: str | None
|
|
2393
|
+
) -> list[err.ValidationError]:
|
|
2394
|
+
"""Require prediction_id to be a string of length between MIN and MAX.
|
|
2395
|
+
|
|
2396
|
+
Between MIN_PREDICTION_ID_LEN and MAX_PREDICTION_ID_LEN.
|
|
2354
2397
|
"""
|
|
2355
2398
|
# We check whether the column name can be None is allowed in `Validator.validate_params`
|
|
2356
2399
|
if id_col_name is None:
|
|
@@ -2380,11 +2423,11 @@ class Validator:
|
|
|
2380
2423
|
|
|
2381
2424
|
@staticmethod
|
|
2382
2425
|
def _check_document_id_field_str_length(
|
|
2383
|
-
dataframe: pd.DataFrame, schema_name: str, id_col_name:
|
|
2384
|
-
) ->
|
|
2385
|
-
"""
|
|
2386
|
-
|
|
2387
|
-
and MAX_DOCUMENT_ID_LEN
|
|
2426
|
+
dataframe: pd.DataFrame, schema_name: str, id_col_name: str | None
|
|
2427
|
+
) -> list[err.ValidationError]:
|
|
2428
|
+
"""Require document id to be a string of length between MIN and MAX.
|
|
2429
|
+
|
|
2430
|
+
Between MIN_DOCUMENT_ID_LEN and MAX_DOCUMENT_ID_LEN.
|
|
2388
2431
|
"""
|
|
2389
2432
|
# We check whether the column name can be None is allowed in `Validator.validate_params`
|
|
2390
2433
|
if id_col_name is None:
|
|
@@ -2433,7 +2476,7 @@ class Validator:
|
|
|
2433
2476
|
@staticmethod
|
|
2434
2477
|
def _check_value_tag(
|
|
2435
2478
|
dataframe: pd.DataFrame, schema: Schema
|
|
2436
|
-
) ->
|
|
2479
|
+
) -> list[err.InvalidTagLength]:
|
|
2437
2480
|
if schema.tag_column_names is None:
|
|
2438
2481
|
return []
|
|
2439
2482
|
|
|
@@ -2459,7 +2502,7 @@ class Validator:
|
|
|
2459
2502
|
truncated_tag_cols.append(col)
|
|
2460
2503
|
if wrong_tag_cols:
|
|
2461
2504
|
return [err.InvalidTagLength(wrong_tag_cols)]
|
|
2462
|
-
|
|
2505
|
+
if truncated_tag_cols:
|
|
2463
2506
|
logger.warning(
|
|
2464
2507
|
get_truncation_warning_message(
|
|
2465
2508
|
"tags", MAX_TAG_LENGTH_TRUNCATION
|
|
@@ -2470,9 +2513,7 @@ class Validator:
|
|
|
2470
2513
|
@staticmethod
|
|
2471
2514
|
def _check_value_ranking_category(
|
|
2472
2515
|
dataframe: pd.DataFrame, schema: Schema
|
|
2473
|
-
) ->
|
|
2474
|
-
Union[err.InvalidValueMissingValue, err.InvalidRankingCategoryValue]
|
|
2475
|
-
]:
|
|
2516
|
+
) -> list[err.InvalidValueMissingValue | err.InvalidRankingCategoryValue]:
|
|
2476
2517
|
if schema.relevance_labels_column_name is not None:
|
|
2477
2518
|
col = schema.relevance_labels_column_name
|
|
2478
2519
|
elif schema.attributions_column_name is not None:
|
|
@@ -2503,7 +2544,7 @@ class Validator:
|
|
|
2503
2544
|
@staticmethod
|
|
2504
2545
|
def _check_length_multi_class_maps(
|
|
2505
2546
|
dataframe: pd.DataFrame, schema: Schema
|
|
2506
|
-
) ->
|
|
2547
|
+
) -> list[err.InvalidNumClassesMultiClassMap]:
|
|
2507
2548
|
# each entry in column is a list of dictionaries mapping class names and scores
|
|
2508
2549
|
# validate length of list of dictionaries for each column
|
|
2509
2550
|
invalid_cols = {}
|
|
@@ -2540,15 +2581,13 @@ class Validator:
|
|
|
2540
2581
|
@staticmethod
|
|
2541
2582
|
def _check_classes_and_scores_values_in_multi_class_maps(
|
|
2542
2583
|
dataframe: pd.DataFrame, schema: Schema
|
|
2543
|
-
) ->
|
|
2544
|
-
|
|
2545
|
-
|
|
2546
|
-
|
|
2547
|
-
err.InvalidMultiClassPredScoreValue,
|
|
2548
|
-
]
|
|
2584
|
+
) -> list[
|
|
2585
|
+
err.InvalidMultiClassClassNameLength
|
|
2586
|
+
| err.InvalidMultiClassActScoreValue
|
|
2587
|
+
| err.InvalidMultiClassPredScoreValue
|
|
2549
2588
|
]:
|
|
2550
|
-
"""
|
|
2551
|
-
|
|
2589
|
+
"""Validate the class names and score values of dictionaries.
|
|
2590
|
+
|
|
2552
2591
|
- class name length
|
|
2553
2592
|
- valid actual score
|
|
2554
2593
|
- valid prediction / threshold score
|
|
@@ -2624,11 +2663,12 @@ class Validator:
|
|
|
2624
2663
|
@staticmethod
|
|
2625
2664
|
def _check_each_multi_class_pred_has_threshold(
|
|
2626
2665
|
dataframe: pd.DataFrame, schema: Schema
|
|
2627
|
-
) ->
|
|
2628
|
-
"""
|
|
2629
|
-
|
|
2630
|
-
|
|
2631
|
-
for
|
|
2666
|
+
) -> list[err.InvalidMultiClassThresholdClasses]:
|
|
2667
|
+
"""Validate threshold scores for Multi Class models.
|
|
2668
|
+
|
|
2669
|
+
If threshold scores column is included in schema and dataframe, validate that
|
|
2670
|
+
for each prediction score received, the associated threshold score for that
|
|
2671
|
+
class was also received.
|
|
2632
2672
|
"""
|
|
2633
2673
|
threshold_col = schema.multi_class_threshold_scores_column_name
|
|
2634
2674
|
if threshold_col is None:
|
|
@@ -2657,10 +2697,10 @@ class Validator:
|
|
|
2657
2697
|
def _check_value_timestamp(
|
|
2658
2698
|
dataframe: pd.DataFrame,
|
|
2659
2699
|
schema: Schema,
|
|
2660
|
-
) ->
|
|
2700
|
+
) -> list[err.InvalidValueMissingValue | err.InvalidValueTimestamp]:
|
|
2661
2701
|
# Due to the timing difference between checking this here and the data finally
|
|
2662
2702
|
# hitting the same check on server side, there's a some chance for a false
|
|
2663
|
-
# result, i.e. the check here
|
|
2703
|
+
# result, i.e. the check here succeeds but the same check on server side fails.
|
|
2664
2704
|
col = schema.timestamp_column_name
|
|
2665
2705
|
if col is not None and col in dataframe.columns:
|
|
2666
2706
|
# When a timestamp column has Date and NaN, pyarrow will be fine, but
|
|
@@ -2673,19 +2713,15 @@ class Validator:
|
|
|
2673
2713
|
)
|
|
2674
2714
|
]
|
|
2675
2715
|
|
|
2676
|
-
now_t = datetime.
|
|
2716
|
+
now_t = datetime.now(tz=timezone.utc)
|
|
2677
2717
|
lbound, ubound = (
|
|
2678
2718
|
(
|
|
2679
2719
|
now_t
|
|
2680
|
-
-
|
|
2681
|
-
days=MAX_PAST_YEARS_FROM_CURRENT_TIME * 365
|
|
2682
|
-
)
|
|
2720
|
+
- timedelta(days=MAX_PAST_YEARS_FROM_CURRENT_TIME * 365)
|
|
2683
2721
|
).timestamp(),
|
|
2684
2722
|
(
|
|
2685
2723
|
now_t
|
|
2686
|
-
+
|
|
2687
|
-
days=MAX_FUTURE_YEARS_FROM_CURRENT_TIME * 365
|
|
2688
|
-
)
|
|
2724
|
+
+ timedelta(days=MAX_FUTURE_YEARS_FROM_CURRENT_TIME * 365)
|
|
2689
2725
|
).timestamp(),
|
|
2690
2726
|
)
|
|
2691
2727
|
# faster than pyarrow compute
|
|
@@ -2767,7 +2803,7 @@ class Validator:
|
|
|
2767
2803
|
@staticmethod
|
|
2768
2804
|
def _check_invalid_missing_values(
|
|
2769
2805
|
dataframe: pd.DataFrame, schema: BaseSchema, model_type: ModelTypes
|
|
2770
|
-
) ->
|
|
2806
|
+
) -> list[err.InvalidValueMissingValue]:
|
|
2771
2807
|
errors = []
|
|
2772
2808
|
columns = ()
|
|
2773
2809
|
if isinstance(schema, CorpusSchema):
|
|
@@ -2814,7 +2850,7 @@ class Validator:
|
|
|
2814
2850
|
environment: Environments,
|
|
2815
2851
|
schema: Schema,
|
|
2816
2852
|
model_type: ModelTypes,
|
|
2817
|
-
) ->
|
|
2853
|
+
) -> list[err.InvalidRecord]:
|
|
2818
2854
|
if environment in (Environments.VALIDATION, Environments.TRAINING):
|
|
2819
2855
|
return []
|
|
2820
2856
|
|
|
@@ -2858,11 +2894,11 @@ class Validator:
|
|
|
2858
2894
|
environment: Environments,
|
|
2859
2895
|
schema: Schema,
|
|
2860
2896
|
model_type: ModelTypes,
|
|
2861
|
-
) ->
|
|
2862
|
-
"""
|
|
2863
|
-
|
|
2864
|
-
|
|
2865
|
-
|
|
2897
|
+
) -> list[err.InvalidRecord]:
|
|
2898
|
+
"""Validates there's not a single row in the dataframe with all nulls.
|
|
2899
|
+
|
|
2900
|
+
Returns errors if any row has all of pred_label and pred_score evaluating to
|
|
2901
|
+
null, OR all of actual_label and actual_score evaluating to null.
|
|
2866
2902
|
"""
|
|
2867
2903
|
if environment == Environments.PRODUCTION:
|
|
2868
2904
|
return []
|
|
@@ -2905,21 +2941,23 @@ class Validator:
|
|
|
2905
2941
|
|
|
2906
2942
|
@staticmethod
|
|
2907
2943
|
def _check_invalid_record_helper(
|
|
2908
|
-
dataframe: pd.DataFrame, column_names:
|
|
2909
|
-
) ->
|
|
2910
|
-
"""
|
|
2911
|
-
|
|
2912
|
-
|
|
2913
|
-
|
|
2944
|
+
dataframe: pd.DataFrame, column_names: list[str | None]
|
|
2945
|
+
) -> list[err.InvalidRecord]:
|
|
2946
|
+
"""Check that there are no null values in a subset of columns.
|
|
2947
|
+
|
|
2948
|
+
The column subset is computed from the input list of columns `column_names`
|
|
2949
|
+
that are not None and that are present in the dataframe. Returns an error if
|
|
2950
|
+
null values are found.
|
|
2914
2951
|
|
|
2915
2952
|
Returns:
|
|
2916
2953
|
List[err.InvalidRecord]: An error expressing the rows that are problematic
|
|
2917
2954
|
|
|
2918
2955
|
"""
|
|
2919
|
-
columns_subset = [
|
|
2920
|
-
|
|
2921
|
-
|
|
2922
|
-
|
|
2956
|
+
columns_subset = [
|
|
2957
|
+
col
|
|
2958
|
+
for col in column_names
|
|
2959
|
+
if col is not None and col in dataframe.columns
|
|
2960
|
+
]
|
|
2923
2961
|
if len(columns_subset) == 0:
|
|
2924
2962
|
return []
|
|
2925
2963
|
null_filter = dataframe[columns_subset].isnull().all(axis=1)
|
|
@@ -2930,8 +2968,8 @@ class Validator:
|
|
|
2930
2968
|
|
|
2931
2969
|
@staticmethod
|
|
2932
2970
|
def _check_type_prediction_group_id(
|
|
2933
|
-
schema: Schema, column_types:
|
|
2934
|
-
) ->
|
|
2971
|
+
schema: Schema, column_types: dict[str, Any]
|
|
2972
|
+
) -> list[err.InvalidType]:
|
|
2935
2973
|
col = schema.prediction_group_id_column_name
|
|
2936
2974
|
if col in column_types:
|
|
2937
2975
|
# should mirror server side
|
|
@@ -2954,8 +2992,8 @@ class Validator:
|
|
|
2954
2992
|
|
|
2955
2993
|
@staticmethod
|
|
2956
2994
|
def _check_type_rank(
|
|
2957
|
-
schema: Schema, column_types:
|
|
2958
|
-
) ->
|
|
2995
|
+
schema: Schema, column_types: dict[str, Any]
|
|
2996
|
+
) -> list[err.InvalidType]:
|
|
2959
2997
|
col = schema.rank_column_name
|
|
2960
2998
|
if col in column_types:
|
|
2961
2999
|
allowed_datatypes = (
|
|
@@ -2976,8 +3014,8 @@ class Validator:
|
|
|
2976
3014
|
|
|
2977
3015
|
@staticmethod
|
|
2978
3016
|
def _check_type_ranking_category(
|
|
2979
|
-
schema: Schema, column_types:
|
|
2980
|
-
) ->
|
|
3017
|
+
schema: Schema, column_types: dict[str, Any]
|
|
3018
|
+
) -> list[err.InvalidType]:
|
|
2981
3019
|
if schema.relevance_labels_column_name is not None:
|
|
2982
3020
|
col = schema.relevance_labels_column_name
|
|
2983
3021
|
elif schema.attributions_column_name is not None:
|
|
@@ -2999,7 +3037,7 @@ class Validator:
|
|
|
2999
3037
|
@staticmethod
|
|
3000
3038
|
def _check_value_bounding_boxes_coordinates(
|
|
3001
3039
|
dataframe: pd.DataFrame, schema: Schema
|
|
3002
|
-
) ->
|
|
3040
|
+
) -> list[err.InvalidBoundingBoxesCoordinates]:
|
|
3003
3041
|
errors = []
|
|
3004
3042
|
if schema.object_detection_prediction_column_names is not None:
|
|
3005
3043
|
coords_col_name = schema.object_detection_prediction_column_names.bounding_boxes_coordinates_column_name # noqa: E501
|
|
@@ -3020,7 +3058,7 @@ class Validator:
|
|
|
3020
3058
|
@staticmethod
|
|
3021
3059
|
def _check_value_bounding_boxes_categories(
|
|
3022
3060
|
dataframe: pd.DataFrame, schema: Schema
|
|
3023
|
-
) ->
|
|
3061
|
+
) -> list[err.InvalidBoundingBoxesCategories]:
|
|
3024
3062
|
errors = []
|
|
3025
3063
|
if schema.object_detection_prediction_column_names is not None:
|
|
3026
3064
|
cat_col_name = schema.object_detection_prediction_column_names.categories_column_name
|
|
@@ -3041,7 +3079,7 @@ class Validator:
|
|
|
3041
3079
|
@staticmethod
|
|
3042
3080
|
def _check_value_bounding_boxes_scores(
|
|
3043
3081
|
dataframe: pd.DataFrame, schema: Schema
|
|
3044
|
-
) ->
|
|
3082
|
+
) -> list[err.InvalidBoundingBoxesScores]:
|
|
3045
3083
|
errors = []
|
|
3046
3084
|
if schema.object_detection_prediction_column_names is not None:
|
|
3047
3085
|
sc_col_name = schema.object_detection_prediction_column_names.scores_column_name
|
|
@@ -3066,7 +3104,7 @@ class Validator:
|
|
|
3066
3104
|
@staticmethod
|
|
3067
3105
|
def _check_value_semantic_segmentation_polygon_coordinates(
|
|
3068
3106
|
dataframe: pd.DataFrame, schema: Schema
|
|
3069
|
-
) ->
|
|
3107
|
+
) -> list[err.InvalidPolygonCoordinates]:
|
|
3070
3108
|
errors = []
|
|
3071
3109
|
if schema.semantic_segmentation_prediction_column_names is not None:
|
|
3072
3110
|
coords_col_name = schema.semantic_segmentation_prediction_column_names.polygon_coordinates_column_name # noqa: E501
|
|
@@ -3076,7 +3114,7 @@ class Validator:
|
|
|
3076
3114
|
if error is not None:
|
|
3077
3115
|
errors.append(error)
|
|
3078
3116
|
if schema.semantic_segmentation_actual_column_names is not None:
|
|
3079
|
-
coords_col_name = schema.semantic_segmentation_actual_column_names.polygon_coordinates_column_name
|
|
3117
|
+
coords_col_name = schema.semantic_segmentation_actual_column_names.polygon_coordinates_column_name
|
|
3080
3118
|
error = _check_value_polygon_coordinates_helper(
|
|
3081
3119
|
dataframe[coords_col_name]
|
|
3082
3120
|
)
|
|
@@ -3087,7 +3125,7 @@ class Validator:
|
|
|
3087
3125
|
@staticmethod
|
|
3088
3126
|
def _check_value_semantic_segmentation_polygon_categories(
|
|
3089
3127
|
dataframe: pd.DataFrame, schema: Schema
|
|
3090
|
-
) ->
|
|
3128
|
+
) -> list[err.InvalidPolygonCategories]:
|
|
3091
3129
|
errors = []
|
|
3092
3130
|
if schema.semantic_segmentation_prediction_column_names is not None:
|
|
3093
3131
|
cat_col_name = schema.semantic_segmentation_prediction_column_names.categories_column_name
|
|
@@ -3108,7 +3146,7 @@ class Validator:
|
|
|
3108
3146
|
@staticmethod
|
|
3109
3147
|
def _check_value_instance_segmentation_polygon_coordinates(
|
|
3110
3148
|
dataframe: pd.DataFrame, schema: Schema
|
|
3111
|
-
) ->
|
|
3149
|
+
) -> list[err.InvalidPolygonCoordinates]:
|
|
3112
3150
|
errors = []
|
|
3113
3151
|
if schema.instance_segmentation_prediction_column_names is not None:
|
|
3114
3152
|
coords_col_name = schema.instance_segmentation_prediction_column_names.polygon_coordinates_column_name # noqa: E501
|
|
@@ -3118,7 +3156,7 @@ class Validator:
|
|
|
3118
3156
|
if error is not None:
|
|
3119
3157
|
errors.append(error)
|
|
3120
3158
|
if schema.instance_segmentation_actual_column_names is not None:
|
|
3121
|
-
coords_col_name = schema.instance_segmentation_actual_column_names.polygon_coordinates_column_name
|
|
3159
|
+
coords_col_name = schema.instance_segmentation_actual_column_names.polygon_coordinates_column_name
|
|
3122
3160
|
error = _check_value_polygon_coordinates_helper(
|
|
3123
3161
|
dataframe[coords_col_name]
|
|
3124
3162
|
)
|
|
@@ -3129,7 +3167,7 @@ class Validator:
|
|
|
3129
3167
|
@staticmethod
|
|
3130
3168
|
def _check_value_instance_segmentation_polygon_categories(
|
|
3131
3169
|
dataframe: pd.DataFrame, schema: Schema
|
|
3132
|
-
) ->
|
|
3170
|
+
) -> list[err.InvalidPolygonCategories]:
|
|
3133
3171
|
errors = []
|
|
3134
3172
|
if schema.instance_segmentation_prediction_column_names is not None:
|
|
3135
3173
|
cat_col_name = schema.instance_segmentation_prediction_column_names.categories_column_name
|
|
@@ -3150,7 +3188,7 @@ class Validator:
|
|
|
3150
3188
|
@staticmethod
|
|
3151
3189
|
def _check_value_instance_segmentation_polygon_scores(
|
|
3152
3190
|
dataframe: pd.DataFrame, schema: Schema
|
|
3153
|
-
) ->
|
|
3191
|
+
) -> list[err.InvalidPolygonScores]:
|
|
3154
3192
|
errors = []
|
|
3155
3193
|
if schema.instance_segmentation_prediction_column_names is not None:
|
|
3156
3194
|
sc_col_name = schema.instance_segmentation_prediction_column_names.scores_column_name
|
|
@@ -3165,7 +3203,7 @@ class Validator:
|
|
|
3165
3203
|
@staticmethod
|
|
3166
3204
|
def _check_value_instance_segmentation_bbox_coordinates(
|
|
3167
3205
|
dataframe: pd.DataFrame, schema: Schema
|
|
3168
|
-
) ->
|
|
3206
|
+
) -> list[err.InvalidBoundingBoxesCoordinates]:
|
|
3169
3207
|
errors = []
|
|
3170
3208
|
if schema.instance_segmentation_prediction_column_names is not None:
|
|
3171
3209
|
coords_col_name = schema.instance_segmentation_prediction_column_names.bounding_boxes_coordinates_column_name # noqa: E501
|
|
@@ -3188,7 +3226,7 @@ class Validator:
|
|
|
3188
3226
|
@staticmethod
|
|
3189
3227
|
def _check_value_prompt_response(
|
|
3190
3228
|
dataframe: pd.DataFrame, schema: Schema
|
|
3191
|
-
) ->
|
|
3229
|
+
) -> list[err.ValidationError]:
|
|
3192
3230
|
vector_cols_to_check = []
|
|
3193
3231
|
text_cols_to_check = []
|
|
3194
3232
|
if isinstance(schema.prompt_column_names, str):
|
|
@@ -3253,7 +3291,7 @@ class Validator:
|
|
|
3253
3291
|
@staticmethod
|
|
3254
3292
|
def _check_value_llm_model_name(
|
|
3255
3293
|
dataframe: pd.DataFrame, schema: Schema
|
|
3256
|
-
) ->
|
|
3294
|
+
) -> list[err.InvalidStringLengthInColumn]:
|
|
3257
3295
|
if schema.llm_config_column_names is None:
|
|
3258
3296
|
return []
|
|
3259
3297
|
col = schema.llm_config_column_names.model_column_name
|
|
@@ -3270,7 +3308,7 @@ class Validator:
|
|
|
3270
3308
|
max_length=MAX_LLM_MODEL_NAME_LENGTH,
|
|
3271
3309
|
)
|
|
3272
3310
|
]
|
|
3273
|
-
|
|
3311
|
+
if max_len > MAX_LLM_MODEL_NAME_LENGTH_TRUNCATION:
|
|
3274
3312
|
logger.warning(
|
|
3275
3313
|
get_truncation_warning_message(
|
|
3276
3314
|
"LLM model names", MAX_LLM_MODEL_NAME_LENGTH_TRUNCATION
|
|
@@ -3281,7 +3319,7 @@ class Validator:
|
|
|
3281
3319
|
@staticmethod
|
|
3282
3320
|
def _check_value_llm_prompt_template(
|
|
3283
3321
|
dataframe: pd.DataFrame, schema: Schema
|
|
3284
|
-
) ->
|
|
3322
|
+
) -> list[err.InvalidStringLengthInColumn]:
|
|
3285
3323
|
if schema.prompt_template_column_names is None:
|
|
3286
3324
|
return []
|
|
3287
3325
|
col = schema.prompt_template_column_names.template_column_name
|
|
@@ -3298,7 +3336,7 @@ class Validator:
|
|
|
3298
3336
|
max_length=MAX_PROMPT_TEMPLATE_LENGTH,
|
|
3299
3337
|
)
|
|
3300
3338
|
]
|
|
3301
|
-
|
|
3339
|
+
if max_len > MAX_PROMPT_TEMPLATE_LENGTH_TRUNCATION:
|
|
3302
3340
|
logger.warning(
|
|
3303
3341
|
get_truncation_warning_message(
|
|
3304
3342
|
"prompt templates",
|
|
@@ -3310,7 +3348,7 @@ class Validator:
|
|
|
3310
3348
|
@staticmethod
|
|
3311
3349
|
def _check_value_llm_prompt_template_version(
|
|
3312
3350
|
dataframe: pd.DataFrame, schema: Schema
|
|
3313
|
-
) ->
|
|
3351
|
+
) -> list[err.InvalidStringLengthInColumn]:
|
|
3314
3352
|
if schema.prompt_template_column_names is None:
|
|
3315
3353
|
return []
|
|
3316
3354
|
col = schema.prompt_template_column_names.template_version_column_name
|
|
@@ -3327,7 +3365,7 @@ class Validator:
|
|
|
3327
3365
|
max_length=MAX_PROMPT_TEMPLATE_VERSION_LENGTH,
|
|
3328
3366
|
)
|
|
3329
3367
|
]
|
|
3330
|
-
|
|
3368
|
+
if max_len > MAX_PROMPT_TEMPLATE_VERSION_LENGTH_TRUNCATION:
|
|
3331
3369
|
logger.warning(
|
|
3332
3370
|
get_truncation_warning_message(
|
|
3333
3371
|
"prompt template versions",
|
|
@@ -3338,8 +3376,8 @@ class Validator:
|
|
|
3338
3376
|
|
|
3339
3377
|
@staticmethod
|
|
3340
3378
|
def _check_type_document_columns(
|
|
3341
|
-
schema: CorpusSchema, column_types:
|
|
3342
|
-
) ->
|
|
3379
|
+
schema: CorpusSchema, column_types: dict[str, Any]
|
|
3380
|
+
) -> list[err.InvalidTypeColumns]:
|
|
3343
3381
|
invalid_types = []
|
|
3344
3382
|
# Check document id
|
|
3345
3383
|
col = schema.document_id_column_name
|
|
@@ -3424,16 +3462,15 @@ class Validator:
|
|
|
3424
3462
|
return []
|
|
3425
3463
|
|
|
3426
3464
|
|
|
3427
|
-
def _check_value_string_length_helper(x):
|
|
3465
|
+
def _check_value_string_length_helper(x: object) -> int:
|
|
3428
3466
|
if isinstance(x, str):
|
|
3429
3467
|
return len(x)
|
|
3430
|
-
|
|
3431
|
-
return 0
|
|
3468
|
+
return 0
|
|
3432
3469
|
|
|
3433
3470
|
|
|
3434
3471
|
def _check_value_vector_dimensionality_helper(
|
|
3435
|
-
dataframe: pd.DataFrame, cols_to_check:
|
|
3436
|
-
) ->
|
|
3472
|
+
dataframe: pd.DataFrame, cols_to_check: list[str]
|
|
3473
|
+
) -> tuple[list[str], list[str]]:
|
|
3437
3474
|
invalid_low_dimensionality_vector_cols = []
|
|
3438
3475
|
invalid_high_dimensionality_vector_cols = []
|
|
3439
3476
|
for col in cols_to_check:
|
|
@@ -3452,8 +3489,8 @@ def _check_value_vector_dimensionality_helper(
|
|
|
3452
3489
|
|
|
3453
3490
|
|
|
3454
3491
|
def _check_value_raw_data_length_helper(
|
|
3455
|
-
dataframe: pd.DataFrame, cols_to_check:
|
|
3456
|
-
) ->
|
|
3492
|
+
dataframe: pd.DataFrame, cols_to_check: list[str]
|
|
3493
|
+
) -> tuple[list[str], list[str]]:
|
|
3457
3494
|
invalid_long_string_data_cols = []
|
|
3458
3495
|
truncated_long_string_data_cols = []
|
|
3459
3496
|
for col in cols_to_check:
|
|
@@ -3469,7 +3506,7 @@ def _check_value_raw_data_length_helper(
|
|
|
3469
3506
|
)
|
|
3470
3507
|
except TypeError as exc:
|
|
3471
3508
|
e = TypeError(f"Cannot validate the column '{col}'. " + str(exc))
|
|
3472
|
-
logger.
|
|
3509
|
+
logger.exception(e)
|
|
3473
3510
|
raise e from exc
|
|
3474
3511
|
if max_data_len > MAX_RAW_DATA_CHARACTERS:
|
|
3475
3512
|
invalid_long_string_data_cols.append(col)
|
|
@@ -3480,8 +3517,8 @@ def _check_value_raw_data_length_helper(
|
|
|
3480
3517
|
|
|
3481
3518
|
def _check_value_bounding_boxes_coordinates_helper(
|
|
3482
3519
|
coordinates_col: pd.Series,
|
|
3483
|
-
) ->
|
|
3484
|
-
def check(boxes):
|
|
3520
|
+
) -> err.InvalidBoundingBoxesCoordinates | None:
|
|
3521
|
+
def check(boxes: object) -> None:
|
|
3485
3522
|
# We allow for zero boxes. None coordinates list is not allowed (will break following tests:
|
|
3486
3523
|
# 'NoneType is not iterable')
|
|
3487
3524
|
if boxes is None:
|
|
@@ -3502,7 +3539,9 @@ def _check_value_bounding_boxes_coordinates_helper(
|
|
|
3502
3539
|
return None
|
|
3503
3540
|
|
|
3504
3541
|
|
|
3505
|
-
def _box_coordinates_wrong_format(
|
|
3542
|
+
def _box_coordinates_wrong_format(
|
|
3543
|
+
box_coords: object,
|
|
3544
|
+
) -> err.InvalidBoundingBoxesCoordinates | None:
|
|
3506
3545
|
if (
|
|
3507
3546
|
# Coordinates should be a collection of 4 floats
|
|
3508
3547
|
len(box_coords) != 4
|
|
@@ -3516,12 +3555,13 @@ def _box_coordinates_wrong_format(box_coords):
|
|
|
3516
3555
|
return err.InvalidBoundingBoxesCoordinates(
|
|
3517
3556
|
reason="boxes_coordinates_wrong_format"
|
|
3518
3557
|
)
|
|
3558
|
+
return None
|
|
3519
3559
|
|
|
3520
3560
|
|
|
3521
3561
|
def _check_value_bounding_boxes_categories_helper(
|
|
3522
3562
|
categories_col: pd.Series,
|
|
3523
|
-
) ->
|
|
3524
|
-
def check(categories):
|
|
3563
|
+
) -> err.InvalidBoundingBoxesCategories | None:
|
|
3564
|
+
def check(categories: object) -> None:
|
|
3525
3565
|
# We allow for zero boxes. None category list is not allowed (will break following tests:
|
|
3526
3566
|
# 'NoneType is not iterable')
|
|
3527
3567
|
if categories is None:
|
|
@@ -3542,8 +3582,8 @@ def _check_value_bounding_boxes_categories_helper(
|
|
|
3542
3582
|
|
|
3543
3583
|
def _check_value_bounding_boxes_scores_helper(
|
|
3544
3584
|
scores_col: pd.Series,
|
|
3545
|
-
) ->
|
|
3546
|
-
def check(scores):
|
|
3585
|
+
) -> err.InvalidBoundingBoxesScores | None:
|
|
3586
|
+
def check(scores: object) -> None:
|
|
3547
3587
|
# We allow for zero boxes. None confidence score list is not allowed (will break following tests:
|
|
3548
3588
|
# 'NoneType is not iterable')
|
|
3549
3589
|
if scores is None:
|
|
@@ -3562,9 +3602,10 @@ def _check_value_bounding_boxes_scores_helper(
|
|
|
3562
3602
|
return None
|
|
3563
3603
|
|
|
3564
3604
|
|
|
3565
|
-
def _polygon_coordinates_wrong_format(
|
|
3566
|
-
|
|
3567
|
-
|
|
3605
|
+
def _polygon_coordinates_wrong_format(
|
|
3606
|
+
polygon_coords: object,
|
|
3607
|
+
) -> err.InvalidPolygonCoordinates | None:
|
|
3608
|
+
"""Check if polygon coordinates are valid.
|
|
3568
3609
|
|
|
3569
3610
|
Validates:
|
|
3570
3611
|
- Has at least 3 vertices (6 coordinates)
|
|
@@ -3610,9 +3651,9 @@ def _polygon_coordinates_wrong_format(polygon_coords):
|
|
|
3610
3651
|
|
|
3611
3652
|
# Check for self-intersections
|
|
3612
3653
|
# We need to check if any two non-adjacent edges intersect
|
|
3613
|
-
edges = [
|
|
3614
|
-
|
|
3615
|
-
|
|
3654
|
+
edges = [
|
|
3655
|
+
(points[i], points[(i + 1) % len(points)]) for i in range(len(points))
|
|
3656
|
+
]
|
|
3616
3657
|
|
|
3617
3658
|
for i in range(len(edges)):
|
|
3618
3659
|
for j in range(i + 2, len(edges)):
|
|
@@ -3634,8 +3675,8 @@ def _polygon_coordinates_wrong_format(polygon_coords):
|
|
|
3634
3675
|
|
|
3635
3676
|
def _check_value_polygon_coordinates_helper(
|
|
3636
3677
|
coordinates_col: pd.Series,
|
|
3637
|
-
) ->
|
|
3638
|
-
def check(polygons):
|
|
3678
|
+
) -> err.InvalidPolygonCoordinates | None:
|
|
3679
|
+
def check(polygons: object) -> None:
|
|
3639
3680
|
# We allow for zero polygons. None coordinates list is not allowed (will break following tests:
|
|
3640
3681
|
# 'NoneType is not iterable')
|
|
3641
3682
|
if polygons is None:
|
|
@@ -3658,8 +3699,8 @@ def _check_value_polygon_coordinates_helper(
|
|
|
3658
3699
|
|
|
3659
3700
|
def _check_value_polygon_categories_helper(
|
|
3660
3701
|
categories_col: pd.Series,
|
|
3661
|
-
) ->
|
|
3662
|
-
def check(categories):
|
|
3702
|
+
) -> err.InvalidPolygonCategories | None:
|
|
3703
|
+
def check(categories: object) -> None:
|
|
3663
3704
|
# We allow for zero boxes. None category list is not allowed (will break following tests:
|
|
3664
3705
|
# 'NoneType is not iterable')
|
|
3665
3706
|
if categories is None:
|
|
@@ -3678,8 +3719,8 @@ def _check_value_polygon_categories_helper(
|
|
|
3678
3719
|
|
|
3679
3720
|
def _check_value_polygon_scores_helper(
|
|
3680
3721
|
scores_col: pd.Series,
|
|
3681
|
-
) ->
|
|
3682
|
-
def check(scores):
|
|
3722
|
+
) -> err.InvalidPolygonScores | None:
|
|
3723
|
+
def check(scores: object) -> None:
|
|
3683
3724
|
# We allow for zero boxes. None confidence score list is not allowed (will break following tests:
|
|
3684
3725
|
# 'NoneType is not iterable')
|
|
3685
3726
|
if scores is None:
|
|
@@ -3696,7 +3737,7 @@ def _check_value_polygon_scores_helper(
|
|
|
3696
3737
|
return None
|
|
3697
3738
|
|
|
3698
3739
|
|
|
3699
|
-
def _count_characters_raw_data(data:
|
|
3740
|
+
def _count_characters_raw_data(data: str | list[str]) -> int:
|
|
3700
3741
|
character_count = 0
|
|
3701
3742
|
if isinstance(data, str):
|
|
3702
3743
|
character_count = len(data)
|