arize 8.0.0a21__py3-none-any.whl → 8.0.0a23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arize/__init__.py +17 -9
- arize/_exporter/client.py +55 -36
- arize/_exporter/parsers/tracing_data_parser.py +41 -30
- arize/_exporter/validation.py +3 -3
- arize/_flight/client.py +208 -77
- arize/_generated/api_client/__init__.py +30 -6
- arize/_generated/api_client/api/__init__.py +1 -0
- arize/_generated/api_client/api/datasets_api.py +864 -190
- arize/_generated/api_client/api/experiments_api.py +167 -131
- arize/_generated/api_client/api/projects_api.py +1197 -0
- arize/_generated/api_client/api_client.py +2 -2
- arize/_generated/api_client/configuration.py +42 -34
- arize/_generated/api_client/exceptions.py +2 -2
- arize/_generated/api_client/models/__init__.py +15 -4
- arize/_generated/api_client/models/dataset.py +10 -10
- arize/_generated/api_client/models/dataset_example.py +111 -0
- arize/_generated/api_client/models/dataset_example_update.py +100 -0
- arize/_generated/api_client/models/dataset_version.py +13 -13
- arize/_generated/api_client/models/datasets_create_request.py +16 -8
- arize/_generated/api_client/models/datasets_examples_insert_request.py +100 -0
- arize/_generated/api_client/models/datasets_examples_list200_response.py +106 -0
- arize/_generated/api_client/models/datasets_examples_update_request.py +102 -0
- arize/_generated/api_client/models/datasets_list200_response.py +10 -4
- arize/_generated/api_client/models/experiment.py +14 -16
- arize/_generated/api_client/models/experiment_run.py +108 -0
- arize/_generated/api_client/models/experiment_run_create.py +102 -0
- arize/_generated/api_client/models/experiments_create_request.py +16 -10
- arize/_generated/api_client/models/experiments_list200_response.py +10 -4
- arize/_generated/api_client/models/experiments_runs_list200_response.py +19 -5
- arize/_generated/api_client/models/{error.py → pagination_metadata.py} +13 -11
- arize/_generated/api_client/models/primitive_value.py +172 -0
- arize/_generated/api_client/models/problem.py +100 -0
- arize/_generated/api_client/models/project.py +99 -0
- arize/_generated/api_client/models/{datasets_list_examples200_response.py → projects_create_request.py} +13 -11
- arize/_generated/api_client/models/projects_list200_response.py +106 -0
- arize/_generated/api_client/rest.py +2 -2
- arize/_generated/api_client/test/test_dataset.py +4 -2
- arize/_generated/api_client/test/test_dataset_example.py +56 -0
- arize/_generated/api_client/test/test_dataset_example_update.py +52 -0
- arize/_generated/api_client/test/test_dataset_version.py +7 -2
- arize/_generated/api_client/test/test_datasets_api.py +27 -13
- arize/_generated/api_client/test/test_datasets_create_request.py +8 -4
- arize/_generated/api_client/test/{test_datasets_list_examples200_response.py → test_datasets_examples_insert_request.py} +19 -15
- arize/_generated/api_client/test/test_datasets_examples_list200_response.py +66 -0
- arize/_generated/api_client/test/test_datasets_examples_update_request.py +61 -0
- arize/_generated/api_client/test/test_datasets_list200_response.py +9 -3
- arize/_generated/api_client/test/test_experiment.py +2 -4
- arize/_generated/api_client/test/test_experiment_run.py +56 -0
- arize/_generated/api_client/test/test_experiment_run_create.py +54 -0
- arize/_generated/api_client/test/test_experiments_api.py +6 -6
- arize/_generated/api_client/test/test_experiments_create_request.py +9 -6
- arize/_generated/api_client/test/test_experiments_list200_response.py +9 -5
- arize/_generated/api_client/test/test_experiments_runs_list200_response.py +15 -5
- arize/_generated/api_client/test/test_pagination_metadata.py +53 -0
- arize/_generated/api_client/test/{test_error.py → test_primitive_value.py} +13 -14
- arize/_generated/api_client/test/test_problem.py +57 -0
- arize/_generated/api_client/test/test_project.py +58 -0
- arize/_generated/api_client/test/test_projects_api.py +59 -0
- arize/_generated/api_client/test/test_projects_create_request.py +54 -0
- arize/_generated/api_client/test/test_projects_list200_response.py +70 -0
- arize/_generated/api_client_README.md +43 -29
- arize/_generated/protocol/flight/flight_pb2.py +400 -0
- arize/_lazy.py +27 -19
- arize/client.py +269 -55
- arize/config.py +365 -116
- arize/constants/__init__.py +1 -0
- arize/constants/config.py +11 -4
- arize/constants/ml.py +6 -4
- arize/constants/openinference.py +2 -0
- arize/constants/pyarrow.py +2 -0
- arize/constants/spans.py +3 -1
- arize/datasets/__init__.py +1 -0
- arize/datasets/client.py +299 -84
- arize/datasets/errors.py +32 -2
- arize/datasets/validation.py +18 -8
- arize/embeddings/__init__.py +2 -0
- arize/embeddings/auto_generator.py +23 -19
- arize/embeddings/base_generators.py +89 -36
- arize/embeddings/constants.py +2 -0
- arize/embeddings/cv_generators.py +26 -4
- arize/embeddings/errors.py +27 -5
- arize/embeddings/nlp_generators.py +31 -12
- arize/embeddings/tabular_generators.py +32 -20
- arize/embeddings/usecases.py +12 -2
- arize/exceptions/__init__.py +1 -0
- arize/exceptions/auth.py +11 -1
- arize/exceptions/base.py +29 -4
- arize/exceptions/models.py +21 -2
- arize/exceptions/parameters.py +31 -0
- arize/exceptions/spaces.py +12 -1
- arize/exceptions/types.py +86 -7
- arize/exceptions/values.py +220 -20
- arize/experiments/__init__.py +1 -0
- arize/experiments/client.py +390 -286
- arize/experiments/evaluators/__init__.py +1 -0
- arize/experiments/evaluators/base.py +74 -41
- arize/experiments/evaluators/exceptions.py +6 -3
- arize/experiments/evaluators/executors.py +121 -73
- arize/experiments/evaluators/rate_limiters.py +106 -57
- arize/experiments/evaluators/types.py +34 -7
- arize/experiments/evaluators/utils.py +65 -27
- arize/experiments/functions.py +103 -101
- arize/experiments/tracing.py +52 -44
- arize/experiments/types.py +56 -31
- arize/logging.py +54 -22
- arize/models/__init__.py +1 -0
- arize/models/batch_validation/__init__.py +1 -0
- arize/models/batch_validation/errors.py +543 -65
- arize/models/batch_validation/validator.py +339 -300
- arize/models/bounded_executor.py +20 -7
- arize/models/casting.py +75 -29
- arize/models/client.py +326 -107
- arize/models/proto.py +95 -40
- arize/models/stream_validation.py +42 -14
- arize/models/surrogate_explainer/__init__.py +1 -0
- arize/models/surrogate_explainer/mimic.py +24 -13
- arize/pre_releases.py +43 -0
- arize/projects/__init__.py +1 -0
- arize/projects/client.py +129 -0
- arize/regions.py +40 -0
- arize/spans/__init__.py +1 -0
- arize/spans/client.py +130 -106
- arize/spans/columns.py +13 -0
- arize/spans/conversion.py +54 -38
- arize/spans/validation/__init__.py +1 -0
- arize/spans/validation/annotations/__init__.py +1 -0
- arize/spans/validation/annotations/annotations_validation.py +6 -4
- arize/spans/validation/annotations/dataframe_form_validation.py +13 -11
- arize/spans/validation/annotations/value_validation.py +35 -11
- arize/spans/validation/common/__init__.py +1 -0
- arize/spans/validation/common/argument_validation.py +33 -8
- arize/spans/validation/common/dataframe_form_validation.py +35 -9
- arize/spans/validation/common/errors.py +211 -11
- arize/spans/validation/common/value_validation.py +80 -13
- arize/spans/validation/evals/__init__.py +1 -0
- arize/spans/validation/evals/dataframe_form_validation.py +28 -8
- arize/spans/validation/evals/evals_validation.py +34 -4
- arize/spans/validation/evals/value_validation.py +26 -3
- arize/spans/validation/metadata/__init__.py +1 -1
- arize/spans/validation/metadata/argument_validation.py +14 -5
- arize/spans/validation/metadata/dataframe_form_validation.py +26 -10
- arize/spans/validation/metadata/value_validation.py +24 -10
- arize/spans/validation/spans/__init__.py +1 -0
- arize/spans/validation/spans/dataframe_form_validation.py +34 -13
- arize/spans/validation/spans/spans_validation.py +35 -4
- arize/spans/validation/spans/value_validation.py +76 -7
- arize/types.py +293 -157
- arize/utils/__init__.py +1 -0
- arize/utils/arrow.py +31 -15
- arize/utils/cache.py +34 -6
- arize/utils/dataframe.py +19 -2
- arize/utils/online_tasks/__init__.py +2 -0
- arize/utils/online_tasks/dataframe_preprocessor.py +53 -41
- arize/utils/openinference_conversion.py +44 -5
- arize/utils/proto.py +10 -0
- arize/utils/size.py +5 -3
- arize/version.py +3 -1
- {arize-8.0.0a21.dist-info → arize-8.0.0a23.dist-info}/METADATA +4 -3
- arize-8.0.0a23.dist-info/RECORD +174 -0
- {arize-8.0.0a21.dist-info → arize-8.0.0a23.dist-info}/WHEEL +1 -1
- arize-8.0.0a23.dist-info/licenses/LICENSE +176 -0
- arize-8.0.0a23.dist-info/licenses/NOTICE +13 -0
- arize/_generated/protocol/flight/export_pb2.py +0 -61
- arize/_generated/protocol/flight/ingest_pb2.py +0 -365
- arize-8.0.0a21.dist-info/RECORD +0 -146
- arize-8.0.0a21.dist-info/licenses/LICENSE.md +0 -12
arize/types.py
CHANGED
|
@@ -1,19 +1,17 @@
|
|
|
1
|
+
"""Common type definitions and data models used across the Arize SDK."""
|
|
2
|
+
|
|
1
3
|
import json
|
|
2
4
|
import logging
|
|
3
5
|
import math
|
|
6
|
+
from collections.abc import Iterable, Iterator, Sequence
|
|
4
7
|
from dataclasses import asdict, dataclass, replace
|
|
5
8
|
from datetime import datetime
|
|
6
9
|
from decimal import Decimal
|
|
7
10
|
from enum import Enum, unique
|
|
8
11
|
from itertools import chain
|
|
9
12
|
from typing import (
|
|
10
|
-
Dict,
|
|
11
|
-
Iterable,
|
|
12
|
-
List,
|
|
13
13
|
NamedTuple,
|
|
14
|
-
|
|
15
|
-
Set,
|
|
16
|
-
Tuple,
|
|
14
|
+
Self,
|
|
17
15
|
TypeVar,
|
|
18
16
|
)
|
|
19
17
|
|
|
@@ -48,6 +46,8 @@ logger = logging.getLogger(__name__)
|
|
|
48
46
|
|
|
49
47
|
@unique
|
|
50
48
|
class ModelTypes(Enum):
|
|
49
|
+
"""Enum representing supported model types in Arize."""
|
|
50
|
+
|
|
51
51
|
NUMERIC = 1
|
|
52
52
|
SCORE_CATEGORICAL = 2
|
|
53
53
|
RANKING = 3
|
|
@@ -58,7 +58,8 @@ class ModelTypes(Enum):
|
|
|
58
58
|
MULTI_CLASS = 8
|
|
59
59
|
|
|
60
60
|
@classmethod
|
|
61
|
-
def list_types(cls):
|
|
61
|
+
def list_types(cls) -> list[str]:
|
|
62
|
+
"""Return a list of all type names in this enum."""
|
|
62
63
|
return [t.name for t in cls]
|
|
63
64
|
|
|
64
65
|
|
|
@@ -70,7 +71,10 @@ CATEGORICAL_MODEL_TYPES = [
|
|
|
70
71
|
|
|
71
72
|
|
|
72
73
|
class DocEnum(Enum):
|
|
73
|
-
|
|
74
|
+
"""Enum subclass supporting inline documentation for enum members."""
|
|
75
|
+
|
|
76
|
+
def __new__(cls, value: object, doc: str | None = None) -> Self:
|
|
77
|
+
"""Create a new enum instance with optional documentation."""
|
|
74
78
|
self = object.__new__(
|
|
75
79
|
cls
|
|
76
80
|
) # calling super().__new__(value) here would fail
|
|
@@ -80,13 +84,13 @@ class DocEnum(Enum):
|
|
|
80
84
|
return self
|
|
81
85
|
|
|
82
86
|
def __repr__(self) -> str:
|
|
87
|
+
"""Return a string representation including documentation."""
|
|
83
88
|
return f"{self.name} metrics include: {self.__doc__}"
|
|
84
89
|
|
|
85
90
|
|
|
86
91
|
@unique
|
|
87
92
|
class Metrics(DocEnum):
|
|
88
|
-
"""
|
|
89
|
-
Metric groupings, used for validation of schema columns in log() call.
|
|
93
|
+
"""Metric groupings, used for validation of schema columns in log() call.
|
|
90
94
|
|
|
91
95
|
See docstring descriptions of the Enum with __doc__ or __repr__(), e.g.:
|
|
92
96
|
Metrics.RANKING.__doc__
|
|
@@ -105,6 +109,8 @@ class Metrics(DocEnum):
|
|
|
105
109
|
|
|
106
110
|
@unique
|
|
107
111
|
class Environments(Enum):
|
|
112
|
+
"""Enum representing deployment environments for models."""
|
|
113
|
+
|
|
108
114
|
TRAINING = 1
|
|
109
115
|
VALIDATION = 2
|
|
110
116
|
PRODUCTION = 3
|
|
@@ -114,11 +120,18 @@ class Environments(Enum):
|
|
|
114
120
|
|
|
115
121
|
@dataclass
|
|
116
122
|
class EmbeddingColumnNames:
|
|
123
|
+
"""Column names for embedding feature data."""
|
|
124
|
+
|
|
117
125
|
vector_column_name: str = ""
|
|
118
126
|
data_column_name: str | None = None
|
|
119
127
|
link_to_data_column_name: str | None = None
|
|
120
128
|
|
|
121
|
-
def __post_init__(self):
|
|
129
|
+
def __post_init__(self) -> None:
|
|
130
|
+
"""Validate that vector column name is specified.
|
|
131
|
+
|
|
132
|
+
Raises:
|
|
133
|
+
ValueError: If vector_column_name is empty.
|
|
134
|
+
"""
|
|
122
135
|
if not self.vector_column_name:
|
|
123
136
|
raise ValueError(
|
|
124
137
|
"embedding_features require a vector to be specified. You can "
|
|
@@ -126,7 +139,8 @@ class EmbeddingColumnNames:
|
|
|
126
139
|
"(from arize.pandas.embeddings) if you do not have them"
|
|
127
140
|
)
|
|
128
141
|
|
|
129
|
-
def __iter__(self):
|
|
142
|
+
def __iter__(self) -> Iterator[str | None]:
|
|
143
|
+
"""Iterate over the embedding column names."""
|
|
130
144
|
return iter(
|
|
131
145
|
(
|
|
132
146
|
self.vector_column_name,
|
|
@@ -137,14 +151,16 @@ class EmbeddingColumnNames:
|
|
|
137
151
|
|
|
138
152
|
|
|
139
153
|
class Embedding(NamedTuple):
|
|
140
|
-
vector
|
|
141
|
-
|
|
154
|
+
"""Container for embedding vector data with optional raw data and links."""
|
|
155
|
+
|
|
156
|
+
vector: list[float]
|
|
157
|
+
data: str | list[str] | None = None
|
|
142
158
|
link_to_data: str | None = None
|
|
143
159
|
|
|
144
160
|
def validate(self, emb_name: str | int | float) -> None:
|
|
145
|
-
"""
|
|
146
|
-
|
|
147
|
-
|
|
161
|
+
"""Validates that the embedding object passed is of the correct format.
|
|
162
|
+
|
|
163
|
+
Ensures validations are passed for vector, data, and link_to_data fields.
|
|
148
164
|
|
|
149
165
|
Arguments:
|
|
150
166
|
---------
|
|
@@ -167,19 +183,16 @@ class Embedding(NamedTuple):
|
|
|
167
183
|
if self.link_to_data is not None:
|
|
168
184
|
self._validate_embedding_link_to_data(emb_name, self.link_to_data)
|
|
169
185
|
|
|
170
|
-
return
|
|
186
|
+
return
|
|
171
187
|
|
|
172
188
|
def _validate_embedding_vector(
|
|
173
189
|
self,
|
|
174
190
|
emb_name: str | int | float,
|
|
175
191
|
) -> None:
|
|
176
|
-
"""
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
pandas Series)
|
|
181
|
-
2. List must not be empty
|
|
182
|
-
3. Elements in list must be floats
|
|
192
|
+
"""Validates that the embedding vector passed is of the correct format.
|
|
193
|
+
|
|
194
|
+
Requirements: 1) Type must be list or convertible to list (like numpy arrays,
|
|
195
|
+
pandas Series), 2) List must not be empty, 3) Elements in list must be floats.
|
|
183
196
|
|
|
184
197
|
Arguments:
|
|
185
198
|
---------
|
|
@@ -209,11 +222,11 @@ class Embedding(NamedTuple):
|
|
|
209
222
|
|
|
210
223
|
@staticmethod
|
|
211
224
|
def _validate_embedding_data(
|
|
212
|
-
emb_name: str | int | float, data: str |
|
|
225
|
+
emb_name: str | int | float, data: str | list[str]
|
|
213
226
|
) -> None:
|
|
214
|
-
"""
|
|
215
|
-
|
|
216
|
-
|
|
227
|
+
"""Validates that the embedding raw data field is of the correct format.
|
|
228
|
+
|
|
229
|
+
Requirement: Must be string or list of strings (NLP case).
|
|
217
230
|
|
|
218
231
|
Arguments:
|
|
219
232
|
---------
|
|
@@ -247,7 +260,7 @@ class Embedding(NamedTuple):
|
|
|
247
260
|
f"Embedding data field must not contain more than {MAX_RAW_DATA_CHARACTERS} characters. "
|
|
248
261
|
f"Found {character_count}."
|
|
249
262
|
)
|
|
250
|
-
|
|
263
|
+
if character_count > MAX_RAW_DATA_CHARACTERS_TRUNCATION:
|
|
251
264
|
logger.warning(
|
|
252
265
|
get_truncation_warning_message(
|
|
253
266
|
"Embedding raw data fields",
|
|
@@ -259,9 +272,9 @@ class Embedding(NamedTuple):
|
|
|
259
272
|
def _validate_embedding_link_to_data(
|
|
260
273
|
emb_name: str | int | float, link_to_data: str
|
|
261
274
|
) -> None:
|
|
262
|
-
"""
|
|
263
|
-
|
|
264
|
-
|
|
275
|
+
"""Validates that the embedding link to data field is of the correct format.
|
|
276
|
+
|
|
277
|
+
Requirement: Must be string.
|
|
265
278
|
|
|
266
279
|
Arguments:
|
|
267
280
|
---------
|
|
@@ -282,13 +295,11 @@ class Embedding(NamedTuple):
|
|
|
282
295
|
|
|
283
296
|
@staticmethod
|
|
284
297
|
def _is_valid_iterable(
|
|
285
|
-
data: str |
|
|
298
|
+
data: str | list[str] | list[float] | np.ndarray,
|
|
286
299
|
) -> bool:
|
|
287
|
-
"""
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
2. numpy array or
|
|
291
|
-
3. pandas Series
|
|
300
|
+
"""Validates that the input data field is of the correct iterable type.
|
|
301
|
+
|
|
302
|
+
Accepted types: 1) List, 2) numpy array, or 3) pandas Series.
|
|
292
303
|
|
|
293
304
|
Arguments:
|
|
294
305
|
---------
|
|
@@ -327,12 +338,15 @@ class Embedding(NamedTuple):
|
|
|
327
338
|
|
|
328
339
|
|
|
329
340
|
class LLMRunMetadata(NamedTuple):
|
|
341
|
+
"""Metadata for LLM execution including token counts and latency."""
|
|
342
|
+
|
|
330
343
|
total_token_count: int | None = None
|
|
331
344
|
prompt_token_count: int | None = None
|
|
332
345
|
response_token_count: int | None = None
|
|
333
346
|
response_latency_ms: int | float | None = None
|
|
334
347
|
|
|
335
348
|
def validate(self) -> None:
|
|
349
|
+
"""Validate the field values and constraints."""
|
|
336
350
|
allowed_types = (int, float, np.int16, np.int32, np.float16, np.float32)
|
|
337
351
|
if not isinstance(self.total_token_count, allowed_types):
|
|
338
352
|
raise InvalidValueType(
|
|
@@ -361,9 +375,9 @@ class LLMRunMetadata(NamedTuple):
|
|
|
361
375
|
|
|
362
376
|
|
|
363
377
|
class ObjectDetectionColumnNames(NamedTuple):
|
|
364
|
-
"""
|
|
365
|
-
|
|
366
|
-
actual schema parameter.
|
|
378
|
+
"""Used to log object detection prediction and actual values.
|
|
379
|
+
|
|
380
|
+
These values are assigned to the prediction or actual schema parameter.
|
|
367
381
|
|
|
368
382
|
Arguments:
|
|
369
383
|
---------
|
|
@@ -385,9 +399,9 @@ class ObjectDetectionColumnNames(NamedTuple):
|
|
|
385
399
|
|
|
386
400
|
|
|
387
401
|
class SemanticSegmentationColumnNames(NamedTuple):
|
|
388
|
-
"""
|
|
389
|
-
|
|
390
|
-
actual schema parameter.
|
|
402
|
+
"""Used to log semantic segmentation prediction and actual values.
|
|
403
|
+
|
|
404
|
+
These values are assigned to the prediction or actual schema parameter.
|
|
391
405
|
|
|
392
406
|
Arguments:
|
|
393
407
|
---------
|
|
@@ -405,8 +419,7 @@ class SemanticSegmentationColumnNames(NamedTuple):
|
|
|
405
419
|
|
|
406
420
|
|
|
407
421
|
class InstanceSegmentationPredictionColumnNames(NamedTuple):
|
|
408
|
-
"""
|
|
409
|
-
Used to log instance segmentation prediction values that are assigned to the prediction schema parameter.
|
|
422
|
+
"""Used to log instance segmentation prediction values for the prediction schema parameter.
|
|
410
423
|
|
|
411
424
|
Arguments:
|
|
412
425
|
---------
|
|
@@ -433,8 +446,7 @@ class InstanceSegmentationPredictionColumnNames(NamedTuple):
|
|
|
433
446
|
|
|
434
447
|
|
|
435
448
|
class InstanceSegmentationActualColumnNames(NamedTuple):
|
|
436
|
-
"""
|
|
437
|
-
Used to log instance segmentation actual values that are assigned to the actual schema parameter.
|
|
449
|
+
"""Used to log instance segmentation actual values that are assigned to the actual schema parameter.
|
|
438
450
|
|
|
439
451
|
Arguments:
|
|
440
452
|
---------
|
|
@@ -455,12 +467,15 @@ class InstanceSegmentationActualColumnNames(NamedTuple):
|
|
|
455
467
|
|
|
456
468
|
|
|
457
469
|
class ObjectDetectionLabel(NamedTuple):
|
|
458
|
-
|
|
459
|
-
|
|
470
|
+
"""Label data for object detection tasks with bounding boxes and categories."""
|
|
471
|
+
|
|
472
|
+
bounding_boxes_coordinates: list[list[float]]
|
|
473
|
+
categories: list[str]
|
|
460
474
|
# Actual Object Detection Labels won't have scores
|
|
461
|
-
scores:
|
|
475
|
+
scores: list[float] | None = None
|
|
462
476
|
|
|
463
|
-
def validate(self, prediction_or_actual: str):
|
|
477
|
+
def validate(self, prediction_or_actual: str) -> None:
|
|
478
|
+
"""Validate the object detection label fields and constraints."""
|
|
464
479
|
# Validate bounding boxes
|
|
465
480
|
self._validate_bounding_boxes_coordinates()
|
|
466
481
|
# Validate categories
|
|
@@ -470,7 +485,7 @@ class ObjectDetectionLabel(NamedTuple):
|
|
|
470
485
|
# Validate we have the same number of bounding boxes, categories and scores
|
|
471
486
|
self._validate_count_match()
|
|
472
487
|
|
|
473
|
-
def _validate_bounding_boxes_coordinates(self):
|
|
488
|
+
def _validate_bounding_boxes_coordinates(self) -> None:
|
|
474
489
|
if not is_list_of(self.bounding_boxes_coordinates, list):
|
|
475
490
|
raise TypeError(
|
|
476
491
|
"Object Detection Label bounding boxes must be a list of lists of floats"
|
|
@@ -478,14 +493,14 @@ class ObjectDetectionLabel(NamedTuple):
|
|
|
478
493
|
for coordinates in self.bounding_boxes_coordinates:
|
|
479
494
|
_validate_bounding_box_coordinates(coordinates)
|
|
480
495
|
|
|
481
|
-
def _validate_categories(self):
|
|
496
|
+
def _validate_categories(self) -> None:
|
|
482
497
|
# Allows for categories as empty strings
|
|
483
498
|
if not is_list_of(self.categories, str):
|
|
484
499
|
raise TypeError(
|
|
485
500
|
"Object Detection Label categories must be a list of strings"
|
|
486
501
|
)
|
|
487
502
|
|
|
488
|
-
def _validate_scores(self, prediction_or_actual: str):
|
|
503
|
+
def _validate_scores(self, prediction_or_actual: str) -> None:
|
|
489
504
|
if self.scores is None:
|
|
490
505
|
if prediction_or_actual == "prediction":
|
|
491
506
|
raise ValueError(
|
|
@@ -507,7 +522,7 @@ class ObjectDetectionLabel(NamedTuple):
|
|
|
507
522
|
f"{self.scores}"
|
|
508
523
|
)
|
|
509
524
|
|
|
510
|
-
def _validate_count_match(self):
|
|
525
|
+
def _validate_count_match(self) -> None:
|
|
511
526
|
n_bounding_boxes = len(self.bounding_boxes_coordinates)
|
|
512
527
|
if n_bounding_boxes == 0:
|
|
513
528
|
raise ValueError(
|
|
@@ -534,10 +549,13 @@ class ObjectDetectionLabel(NamedTuple):
|
|
|
534
549
|
|
|
535
550
|
|
|
536
551
|
class SemanticSegmentationLabel(NamedTuple):
|
|
537
|
-
|
|
538
|
-
categories: List[str]
|
|
552
|
+
"""Label data for semantic segmentation with polygon coordinates and categories."""
|
|
539
553
|
|
|
540
|
-
|
|
554
|
+
polygon_coordinates: list[list[float]]
|
|
555
|
+
categories: list[str]
|
|
556
|
+
|
|
557
|
+
def validate(self) -> None:
|
|
558
|
+
"""Validate the field values and constraints."""
|
|
541
559
|
# Validate polygon coordinates
|
|
542
560
|
self._validate_polygon_coordinates()
|
|
543
561
|
# Validate categories
|
|
@@ -545,17 +563,17 @@ class SemanticSegmentationLabel(NamedTuple):
|
|
|
545
563
|
# Validate we have the same number of polygon coordinates and categories
|
|
546
564
|
self._validate_count_match()
|
|
547
565
|
|
|
548
|
-
def _validate_polygon_coordinates(self):
|
|
566
|
+
def _validate_polygon_coordinates(self) -> None:
|
|
549
567
|
_validate_polygon_coordinates(self.polygon_coordinates)
|
|
550
568
|
|
|
551
|
-
def _validate_categories(self):
|
|
569
|
+
def _validate_categories(self) -> None:
|
|
552
570
|
# Allows for categories as empty strings
|
|
553
571
|
if not is_list_of(self.categories, str):
|
|
554
572
|
raise TypeError(
|
|
555
573
|
"Semantic Segmentation Label categories must be a list of strings"
|
|
556
574
|
)
|
|
557
575
|
|
|
558
|
-
def _validate_count_match(self):
|
|
576
|
+
def _validate_count_match(self) -> None:
|
|
559
577
|
n_polygon_coordinates = len(self.polygon_coordinates)
|
|
560
578
|
if n_polygon_coordinates == 0:
|
|
561
579
|
raise ValueError(
|
|
@@ -573,12 +591,15 @@ class SemanticSegmentationLabel(NamedTuple):
|
|
|
573
591
|
|
|
574
592
|
|
|
575
593
|
class InstanceSegmentationPredictionLabel(NamedTuple):
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
594
|
+
"""Prediction label for instance segmentation with polygons and category information."""
|
|
595
|
+
|
|
596
|
+
polygon_coordinates: list[list[float]]
|
|
597
|
+
categories: list[str]
|
|
598
|
+
scores: list[float] | None = None
|
|
599
|
+
bounding_boxes_coordinates: list[list[float]] | None = None
|
|
580
600
|
|
|
581
|
-
def validate(self):
|
|
601
|
+
def validate(self) -> None:
|
|
602
|
+
"""Validate the field values and constraints."""
|
|
582
603
|
# Validate polygon coordinates
|
|
583
604
|
self._validate_polygon_coordinates()
|
|
584
605
|
# Validate categories
|
|
@@ -590,17 +611,17 @@ class InstanceSegmentationPredictionLabel(NamedTuple):
|
|
|
590
611
|
# Validate we have the same number of polygon coordinates and categories
|
|
591
612
|
self._validate_count_match()
|
|
592
613
|
|
|
593
|
-
def _validate_polygon_coordinates(self):
|
|
614
|
+
def _validate_polygon_coordinates(self) -> None:
|
|
594
615
|
_validate_polygon_coordinates(self.polygon_coordinates)
|
|
595
616
|
|
|
596
|
-
def _validate_categories(self):
|
|
617
|
+
def _validate_categories(self) -> None:
|
|
597
618
|
# Allows for categories as empty strings
|
|
598
619
|
if not is_list_of(self.categories, str):
|
|
599
620
|
raise TypeError(
|
|
600
621
|
"Instance Segmentation Prediction Label categories must be a list of strings"
|
|
601
622
|
)
|
|
602
623
|
|
|
603
|
-
def _validate_scores(self):
|
|
624
|
+
def _validate_scores(self) -> None:
|
|
604
625
|
if self.scores is not None:
|
|
605
626
|
if not is_list_of(self.scores, float):
|
|
606
627
|
raise TypeError(
|
|
@@ -613,7 +634,7 @@ class InstanceSegmentationPredictionLabel(NamedTuple):
|
|
|
613
634
|
f"{self.scores}"
|
|
614
635
|
)
|
|
615
636
|
|
|
616
|
-
def _validate_bounding_boxes(self):
|
|
637
|
+
def _validate_bounding_boxes(self) -> None:
|
|
617
638
|
if self.bounding_boxes_coordinates is not None:
|
|
618
639
|
if not is_list_of(self.bounding_boxes_coordinates, list):
|
|
619
640
|
raise TypeError(
|
|
@@ -622,7 +643,7 @@ class InstanceSegmentationPredictionLabel(NamedTuple):
|
|
|
622
643
|
for coordinates in self.bounding_boxes_coordinates:
|
|
623
644
|
_validate_bounding_box_coordinates(coordinates)
|
|
624
645
|
|
|
625
|
-
def _validate_count_match(self):
|
|
646
|
+
def _validate_count_match(self) -> None:
|
|
626
647
|
n_polygon_coordinates = len(self.polygon_coordinates)
|
|
627
648
|
if n_polygon_coordinates == 0:
|
|
628
649
|
raise ValueError(
|
|
@@ -657,11 +678,14 @@ class InstanceSegmentationPredictionLabel(NamedTuple):
|
|
|
657
678
|
|
|
658
679
|
|
|
659
680
|
class InstanceSegmentationActualLabel(NamedTuple):
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
681
|
+
"""Actual label for instance segmentation with polygon coordinates and categories."""
|
|
682
|
+
|
|
683
|
+
polygon_coordinates: list[list[float]]
|
|
684
|
+
categories: list[str]
|
|
685
|
+
bounding_boxes_coordinates: list[list[float]] | None = None
|
|
663
686
|
|
|
664
|
-
def validate(self):
|
|
687
|
+
def validate(self) -> None:
|
|
688
|
+
"""Validate the field values and constraints."""
|
|
665
689
|
# Validate polygon coordinates
|
|
666
690
|
self._validate_polygon_coordinates()
|
|
667
691
|
# Validate categories
|
|
@@ -671,17 +695,17 @@ class InstanceSegmentationActualLabel(NamedTuple):
|
|
|
671
695
|
# Validate we have the same number of polygon coordinates and categories
|
|
672
696
|
self._validate_count_match()
|
|
673
697
|
|
|
674
|
-
def _validate_polygon_coordinates(self):
|
|
698
|
+
def _validate_polygon_coordinates(self) -> None:
|
|
675
699
|
_validate_polygon_coordinates(self.polygon_coordinates)
|
|
676
700
|
|
|
677
|
-
def _validate_categories(self):
|
|
701
|
+
def _validate_categories(self) -> None:
|
|
678
702
|
# Allows for categories as empty strings
|
|
679
703
|
if not is_list_of(self.categories, str):
|
|
680
704
|
raise TypeError(
|
|
681
705
|
"Instance Segmentation Actual Label categories must be a list of strings"
|
|
682
706
|
)
|
|
683
707
|
|
|
684
|
-
def _validate_bounding_boxes(self):
|
|
708
|
+
def _validate_bounding_boxes(self) -> None:
|
|
685
709
|
if self.bounding_boxes_coordinates is not None:
|
|
686
710
|
if not is_list_of(self.bounding_boxes_coordinates, list):
|
|
687
711
|
raise TypeError(
|
|
@@ -690,7 +714,7 @@ class InstanceSegmentationActualLabel(NamedTuple):
|
|
|
690
714
|
for coordinates in self.bounding_boxes_coordinates:
|
|
691
715
|
_validate_bounding_box_coordinates(coordinates)
|
|
692
716
|
|
|
693
|
-
def _validate_count_match(self):
|
|
717
|
+
def _validate_count_match(self) -> None:
|
|
694
718
|
n_polygon_coordinates = len(self.polygon_coordinates)
|
|
695
719
|
if n_polygon_coordinates == 0:
|
|
696
720
|
raise ValueError(
|
|
@@ -717,8 +741,7 @@ class InstanceSegmentationActualLabel(NamedTuple):
|
|
|
717
741
|
|
|
718
742
|
|
|
719
743
|
class MultiClassPredictionLabel(NamedTuple):
|
|
720
|
-
"""
|
|
721
|
-
Used to log multi class prediction label
|
|
744
|
+
"""Used to log multi class prediction label.
|
|
722
745
|
|
|
723
746
|
Arguments:
|
|
724
747
|
---------
|
|
@@ -729,15 +752,16 @@ class MultiClassPredictionLabel(NamedTuple):
|
|
|
729
752
|
|
|
730
753
|
"""
|
|
731
754
|
|
|
732
|
-
prediction_scores:
|
|
733
|
-
threshold_scores:
|
|
755
|
+
prediction_scores: dict[str, float | int]
|
|
756
|
+
threshold_scores: dict[str, float | int] | None = None
|
|
734
757
|
|
|
735
|
-
def validate(self):
|
|
758
|
+
def validate(self) -> None:
|
|
759
|
+
"""Validate the field values and constraints."""
|
|
736
760
|
# Validate scores
|
|
737
761
|
self._validate_prediction_scores()
|
|
738
762
|
self._validate_threshold_scores()
|
|
739
763
|
|
|
740
|
-
def _validate_prediction_scores(self):
|
|
764
|
+
def _validate_prediction_scores(self) -> None:
|
|
741
765
|
# prediction dictionary validations
|
|
742
766
|
if not is_dict_of(
|
|
743
767
|
self.prediction_scores,
|
|
@@ -778,7 +802,7 @@ class MultiClassPredictionLabel(NamedTuple):
|
|
|
778
802
|
"invalid. All scores (values in dictionary) must be between 0 and 1, inclusive."
|
|
779
803
|
)
|
|
780
804
|
|
|
781
|
-
def _validate_threshold_scores(self):
|
|
805
|
+
def _validate_threshold_scores(self) -> None:
|
|
782
806
|
if self.threshold_scores is None or len(self.threshold_scores) == 0:
|
|
783
807
|
return
|
|
784
808
|
if not is_dict_of(
|
|
@@ -822,8 +846,7 @@ class MultiClassPredictionLabel(NamedTuple):
|
|
|
822
846
|
|
|
823
847
|
|
|
824
848
|
class MultiClassActualLabel(NamedTuple):
|
|
825
|
-
"""
|
|
826
|
-
Used to log multi class actual label
|
|
849
|
+
"""Used to log multi class actual label.
|
|
827
850
|
|
|
828
851
|
Arguments:
|
|
829
852
|
---------
|
|
@@ -833,13 +856,14 @@ class MultiClassActualLabel(NamedTuple):
|
|
|
833
856
|
|
|
834
857
|
"""
|
|
835
858
|
|
|
836
|
-
actual_scores:
|
|
859
|
+
actual_scores: dict[str, float | int]
|
|
837
860
|
|
|
838
|
-
def validate(self):
|
|
861
|
+
def validate(self) -> None:
|
|
862
|
+
"""Validate the field values and constraints."""
|
|
839
863
|
# Validate scores
|
|
840
864
|
self._validate_actual_scores()
|
|
841
865
|
|
|
842
|
-
def _validate_actual_scores(self):
|
|
866
|
+
def _validate_actual_scores(self) -> None:
|
|
843
867
|
if not is_dict_of(
|
|
844
868
|
self.actual_scores,
|
|
845
869
|
key_allowed_types=str,
|
|
@@ -879,12 +903,15 @@ class MultiClassActualLabel(NamedTuple):
|
|
|
879
903
|
|
|
880
904
|
|
|
881
905
|
class RankingPredictionLabel(NamedTuple):
|
|
906
|
+
"""Prediction label for ranking tasks with group and rank information."""
|
|
907
|
+
|
|
882
908
|
group_id: str
|
|
883
909
|
rank: int
|
|
884
910
|
score: float | None = None
|
|
885
911
|
label: str | None = None
|
|
886
912
|
|
|
887
|
-
def validate(self):
|
|
913
|
+
def validate(self) -> None:
|
|
914
|
+
"""Validate the field values and constraints."""
|
|
888
915
|
# Validate existence of required fields: prediction_group_id and rank
|
|
889
916
|
if self.group_id is None or self.rank is None:
|
|
890
917
|
raise ValueError(
|
|
@@ -901,7 +928,7 @@ class RankingPredictionLabel(NamedTuple):
|
|
|
901
928
|
if self.score is not None:
|
|
902
929
|
self._validate_score()
|
|
903
930
|
|
|
904
|
-
def _validate_group_id(self):
|
|
931
|
+
def _validate_group_id(self) -> None:
|
|
905
932
|
if not isinstance(self.group_id, str):
|
|
906
933
|
raise TypeError("Prediction Group ID must be a string")
|
|
907
934
|
if not (1 <= len(self.group_id) <= 36):
|
|
@@ -909,7 +936,7 @@ class RankingPredictionLabel(NamedTuple):
|
|
|
909
936
|
f"Prediction Group ID must have length between 1 and 36. Found {len(self.group_id)}"
|
|
910
937
|
)
|
|
911
938
|
|
|
912
|
-
def _validate_rank(self):
|
|
939
|
+
def _validate_rank(self) -> None:
|
|
913
940
|
if not isinstance(self.rank, int):
|
|
914
941
|
raise TypeError("Prediction Rank must be an int")
|
|
915
942
|
if not (1 <= self.rank <= 100):
|
|
@@ -917,22 +944,25 @@ class RankingPredictionLabel(NamedTuple):
|
|
|
917
944
|
f"Prediction Rank must be between 1 and 100, inclusive. Found {self.rank}"
|
|
918
945
|
)
|
|
919
946
|
|
|
920
|
-
def _validate_label(self):
|
|
947
|
+
def _validate_label(self) -> None:
|
|
921
948
|
if not isinstance(self.label, str):
|
|
922
949
|
raise TypeError("Prediction Label must be a str")
|
|
923
950
|
if self.label == "":
|
|
924
951
|
raise ValueError("Prediction Label must not be an empty string.")
|
|
925
952
|
|
|
926
|
-
def _validate_score(self):
|
|
953
|
+
def _validate_score(self) -> None:
|
|
927
954
|
if not isinstance(self.score, (float, int)):
|
|
928
955
|
raise TypeError("Prediction Score must be a float or an int")
|
|
929
956
|
|
|
930
957
|
|
|
931
958
|
class RankingActualLabel(NamedTuple):
|
|
932
|
-
|
|
959
|
+
"""Actual label for ranking tasks with relevance information."""
|
|
960
|
+
|
|
961
|
+
relevance_labels: list[str] | None = None
|
|
933
962
|
relevance_score: float | None = None
|
|
934
963
|
|
|
935
|
-
def validate(self):
|
|
964
|
+
def validate(self) -> None:
|
|
965
|
+
"""Validate the field values and constraints."""
|
|
936
966
|
# Validate relevance_labels type
|
|
937
967
|
if self.relevance_labels is not None:
|
|
938
968
|
self._validate_relevance_labels(self.relevance_labels)
|
|
@@ -941,7 +971,7 @@ class RankingActualLabel(NamedTuple):
|
|
|
941
971
|
self._validate_relevance_score(self.relevance_score)
|
|
942
972
|
|
|
943
973
|
@staticmethod
|
|
944
|
-
def _validate_relevance_labels(relevance_labels:
|
|
974
|
+
def _validate_relevance_labels(relevance_labels: list[str]) -> None:
|
|
945
975
|
if not is_list_of(relevance_labels, str):
|
|
946
976
|
raise TypeError("Actual Relevance Labels must be a list of strings")
|
|
947
977
|
if any(label == "" for label in relevance_labels):
|
|
@@ -950,17 +980,20 @@ class RankingActualLabel(NamedTuple):
|
|
|
950
980
|
)
|
|
951
981
|
|
|
952
982
|
@staticmethod
|
|
953
|
-
def _validate_relevance_score(relevance_score: float):
|
|
983
|
+
def _validate_relevance_score(relevance_score: float) -> None:
|
|
954
984
|
if not isinstance(relevance_score, (float, int)):
|
|
955
985
|
raise TypeError("Actual Relevance score must be a float or an int")
|
|
956
986
|
|
|
957
987
|
|
|
958
988
|
@dataclass
|
|
959
989
|
class PromptTemplateColumnNames:
|
|
990
|
+
"""Column names for prompt template configuration in LLM schemas."""
|
|
991
|
+
|
|
960
992
|
template_column_name: str | None = None
|
|
961
993
|
template_version_column_name: str | None = None
|
|
962
994
|
|
|
963
|
-
def __iter__(self):
|
|
995
|
+
def __iter__(self) -> Iterator[str | None]:
|
|
996
|
+
"""Iterate over the prompt template column names."""
|
|
964
997
|
return iter(
|
|
965
998
|
(self.template_column_name, self.template_version_column_name)
|
|
966
999
|
)
|
|
@@ -968,21 +1001,27 @@ class PromptTemplateColumnNames:
|
|
|
968
1001
|
|
|
969
1002
|
@dataclass
|
|
970
1003
|
class LLMConfigColumnNames:
|
|
1004
|
+
"""Column names for LLM configuration parameters in schemas."""
|
|
1005
|
+
|
|
971
1006
|
model_column_name: str | None = None
|
|
972
1007
|
params_column_name: str | None = None
|
|
973
1008
|
|
|
974
|
-
def __iter__(self):
|
|
1009
|
+
def __iter__(self) -> Iterator[str | None]:
|
|
1010
|
+
"""Iterate over the LLM config column names."""
|
|
975
1011
|
return iter((self.model_column_name, self.params_column_name))
|
|
976
1012
|
|
|
977
1013
|
|
|
978
1014
|
@dataclass
|
|
979
1015
|
class LLMRunMetadataColumnNames:
|
|
1016
|
+
"""Column names for LLM run metadata fields in schemas."""
|
|
1017
|
+
|
|
980
1018
|
total_token_count_column_name: str | None = None
|
|
981
1019
|
prompt_token_count_column_name: str | None = None
|
|
982
1020
|
response_token_count_column_name: str | None = None
|
|
983
1021
|
response_latency_ms_column_name: str | None = None
|
|
984
1022
|
|
|
985
|
-
def __iter__(self):
|
|
1023
|
+
def __iter__(self) -> Iterator[str | None]:
|
|
1024
|
+
"""Iterate over the LLM run metadata column names."""
|
|
986
1025
|
return iter(
|
|
987
1026
|
(
|
|
988
1027
|
self.total_token_count_column_name,
|
|
@@ -1011,11 +1050,19 @@ class LLMRunMetadataColumnNames:
|
|
|
1011
1050
|
#
|
|
1012
1051
|
@dataclass
|
|
1013
1052
|
class SimilarityReference:
|
|
1053
|
+
"""Reference to a prediction for similarity search operations."""
|
|
1054
|
+
|
|
1014
1055
|
prediction_id: str
|
|
1015
1056
|
reference_column_name: str
|
|
1016
1057
|
prediction_timestamp: datetime | None = None
|
|
1017
1058
|
|
|
1018
|
-
def __post_init__(self):
|
|
1059
|
+
def __post_init__(self) -> None:
|
|
1060
|
+
"""Validate similarity reference fields after initialization.
|
|
1061
|
+
|
|
1062
|
+
Raises:
|
|
1063
|
+
ValueError: If prediction_id or reference_column_name is empty.
|
|
1064
|
+
TypeError: If prediction_timestamp is not a datetime object.
|
|
1065
|
+
"""
|
|
1019
1066
|
if self.prediction_id == "":
|
|
1020
1067
|
raise ValueError("prediction id cannot be empty")
|
|
1021
1068
|
if self.reference_column_name == "":
|
|
@@ -1028,11 +1075,20 @@ class SimilarityReference:
|
|
|
1028
1075
|
|
|
1029
1076
|
@dataclass
|
|
1030
1077
|
class SimilaritySearchParams:
|
|
1031
|
-
|
|
1078
|
+
"""Parameters for configuring similarity search operations."""
|
|
1079
|
+
|
|
1080
|
+
references: list[SimilarityReference]
|
|
1032
1081
|
search_column_name: str
|
|
1033
1082
|
threshold: float = 0
|
|
1034
1083
|
|
|
1035
|
-
def __post_init__(self):
|
|
1084
|
+
def __post_init__(self) -> None:
|
|
1085
|
+
"""Validate similarity search parameters after initialization.
|
|
1086
|
+
|
|
1087
|
+
Raises:
|
|
1088
|
+
ValueError: If references list is invalid, search_column_name is
|
|
1089
|
+
empty, or threshold is out of range.
|
|
1090
|
+
TypeError: If any reference is not a SimilarityReference instance.
|
|
1091
|
+
"""
|
|
1036
1092
|
if (
|
|
1037
1093
|
not self.references
|
|
1038
1094
|
or len(self.references) <= 0
|
|
@@ -1054,23 +1110,28 @@ class SimilaritySearchParams:
|
|
|
1054
1110
|
|
|
1055
1111
|
@dataclass(frozen=True)
|
|
1056
1112
|
class BaseSchema:
|
|
1057
|
-
|
|
1113
|
+
"""Base class for all schema definitions with immutable fields."""
|
|
1114
|
+
|
|
1115
|
+
def replace(self, **changes: object) -> Self:
|
|
1116
|
+
"""Return a new instance with specified fields replaced."""
|
|
1058
1117
|
return replace(self, **changes)
|
|
1059
1118
|
|
|
1060
|
-
def asdict(self) ->
|
|
1119
|
+
def asdict(self) -> dict[str, str]:
|
|
1120
|
+
"""Convert the schema to a dictionary."""
|
|
1061
1121
|
return asdict(self)
|
|
1062
1122
|
|
|
1063
|
-
def get_used_columns(self) ->
|
|
1123
|
+
def get_used_columns(self) -> set[str]:
|
|
1124
|
+
"""Return the set of column names used in this schema."""
|
|
1064
1125
|
return set(self.get_used_columns_counts().keys())
|
|
1065
1126
|
|
|
1066
|
-
def get_used_columns_counts(self) ->
|
|
1127
|
+
def get_used_columns_counts(self) -> dict[str, int]:
|
|
1128
|
+
"""Return a dict mapping column names to their usage count."""
|
|
1067
1129
|
raise NotImplementedError()
|
|
1068
1130
|
|
|
1069
1131
|
|
|
1070
1132
|
@dataclass(frozen=True)
|
|
1071
1133
|
class TypedColumns:
|
|
1072
|
-
"""
|
|
1073
|
-
Optional class used for explicit type enforcement of feature and tag columns in the dataframe.
|
|
1134
|
+
"""Optional class used for explicit type enforcement of feature and tag columns in the dataframe.
|
|
1074
1135
|
|
|
1075
1136
|
Usage:
|
|
1076
1137
|
------
|
|
@@ -1101,30 +1162,31 @@ class TypedColumns:
|
|
|
1101
1162
|
|
|
1102
1163
|
"""
|
|
1103
1164
|
|
|
1104
|
-
inferred:
|
|
1105
|
-
to_str:
|
|
1106
|
-
to_int:
|
|
1107
|
-
to_float:
|
|
1165
|
+
inferred: list[str] | None = None
|
|
1166
|
+
to_str: list[str] | None = None
|
|
1167
|
+
to_int: list[str] | None = None
|
|
1168
|
+
to_float: list[str] | None = None
|
|
1108
1169
|
|
|
1109
|
-
def get_all_column_names(self) ->
|
|
1170
|
+
def get_all_column_names(self) -> list[str]:
|
|
1171
|
+
"""Return all column names across all conversion lists."""
|
|
1110
1172
|
return list(chain.from_iterable(filter(None, self.__dict__.values())))
|
|
1111
1173
|
|
|
1112
|
-
def has_duplicate_columns(self) ->
|
|
1174
|
+
def has_duplicate_columns(self) -> tuple[bool, set[str]]:
|
|
1175
|
+
"""Check for duplicate columns and return (has_duplicates, duplicate_set)."""
|
|
1113
1176
|
# True if there are duplicates within a field's list or across fields.
|
|
1114
1177
|
# Return a set of the duplicate column names.
|
|
1115
1178
|
cols = self.get_all_column_names()
|
|
1116
|
-
duplicates =
|
|
1179
|
+
duplicates = {x for x in cols if cols.count(x) > 1}
|
|
1117
1180
|
return len(duplicates) > 0, duplicates
|
|
1118
1181
|
|
|
1119
1182
|
def is_empty(self) -> bool:
|
|
1183
|
+
"""Return True if no columns are configured for conversion."""
|
|
1120
1184
|
return not self.get_all_column_names()
|
|
1121
1185
|
|
|
1122
1186
|
|
|
1123
1187
|
@dataclass(frozen=True)
|
|
1124
1188
|
class Schema(BaseSchema):
|
|
1125
|
-
"""
|
|
1126
|
-
Used to organize and map column names containing model data within your Pandas dataframe to
|
|
1127
|
-
Arize.
|
|
1189
|
+
"""Used to organize and map column names containing model data within your Pandas dataframe to Arize.
|
|
1128
1190
|
|
|
1129
1191
|
Arguments:
|
|
1130
1192
|
---------
|
|
@@ -1215,15 +1277,15 @@ class Schema(BaseSchema):
|
|
|
1215
1277
|
"""
|
|
1216
1278
|
|
|
1217
1279
|
prediction_id_column_name: str | None = None
|
|
1218
|
-
feature_column_names:
|
|
1219
|
-
tag_column_names:
|
|
1280
|
+
feature_column_names: list[str] | TypedColumns | None = None
|
|
1281
|
+
tag_column_names: list[str] | TypedColumns | None = None
|
|
1220
1282
|
timestamp_column_name: str | None = None
|
|
1221
1283
|
prediction_label_column_name: str | None = None
|
|
1222
1284
|
prediction_score_column_name: str | None = None
|
|
1223
1285
|
actual_label_column_name: str | None = None
|
|
1224
1286
|
actual_score_column_name: str | None = None
|
|
1225
|
-
shap_values_column_names:
|
|
1226
|
-
embedding_feature_column_names:
|
|
1287
|
+
shap_values_column_names: dict[str, str] | None = None
|
|
1288
|
+
embedding_feature_column_names: dict[str, EmbeddingColumnNames] | None = (
|
|
1227
1289
|
None # type:ignore
|
|
1228
1290
|
)
|
|
1229
1291
|
prediction_group_id_column_name: str | None = None
|
|
@@ -1242,7 +1304,7 @@ class Schema(BaseSchema):
|
|
|
1242
1304
|
prompt_template_column_names: PromptTemplateColumnNames | None = None
|
|
1243
1305
|
llm_config_column_names: LLMConfigColumnNames | None = None
|
|
1244
1306
|
llm_run_metadata_column_names: LLMRunMetadataColumnNames | None = None
|
|
1245
|
-
retrieved_document_ids_column_name:
|
|
1307
|
+
retrieved_document_ids_column_name: list[str] | None = None
|
|
1246
1308
|
multi_class_threshold_scores_column_name: str | None = None
|
|
1247
1309
|
semantic_segmentation_prediction_column_names: (
|
|
1248
1310
|
SemanticSegmentationColumnNames | None
|
|
@@ -1257,7 +1319,8 @@ class Schema(BaseSchema):
|
|
|
1257
1319
|
InstanceSegmentationActualColumnNames | None
|
|
1258
1320
|
) = None
|
|
1259
1321
|
|
|
1260
|
-
def get_used_columns_counts(self) ->
|
|
1322
|
+
def get_used_columns_counts(self) -> dict[str, int]:
|
|
1323
|
+
"""Return a dict mapping column names to their usage count."""
|
|
1261
1324
|
columns_used_counts = {}
|
|
1262
1325
|
|
|
1263
1326
|
for field in self.__dataclass_fields__:
|
|
@@ -1364,6 +1427,7 @@ class Schema(BaseSchema):
|
|
|
1364
1427
|
return columns_used_counts
|
|
1365
1428
|
|
|
1366
1429
|
def has_prediction_columns(self) -> bool:
|
|
1430
|
+
"""Return True if prediction columns are configured."""
|
|
1367
1431
|
prediction_cols = (
|
|
1368
1432
|
self.prediction_label_column_name,
|
|
1369
1433
|
self.prediction_score_column_name,
|
|
@@ -1377,6 +1441,7 @@ class Schema(BaseSchema):
|
|
|
1377
1441
|
return any(col is not None for col in prediction_cols)
|
|
1378
1442
|
|
|
1379
1443
|
def has_actual_columns(self) -> bool:
|
|
1444
|
+
"""Return True if actual label columns are configured."""
|
|
1380
1445
|
actual_cols = (
|
|
1381
1446
|
self.actual_label_column_name,
|
|
1382
1447
|
self.actual_score_column_name,
|
|
@@ -1389,13 +1454,16 @@ class Schema(BaseSchema):
|
|
|
1389
1454
|
return any(col is not None for col in actual_cols)
|
|
1390
1455
|
|
|
1391
1456
|
def has_feature_importance_columns(self) -> bool:
|
|
1457
|
+
"""Return True if feature importance columns are configured."""
|
|
1392
1458
|
feature_importance_cols = (self.shap_values_column_names,)
|
|
1393
1459
|
return any(col is not None for col in feature_importance_cols)
|
|
1394
1460
|
|
|
1395
1461
|
def has_typed_columns(self) -> bool:
|
|
1462
|
+
"""Return True if typed columns are configured."""
|
|
1396
1463
|
return any(self.typed_column_fields())
|
|
1397
1464
|
|
|
1398
|
-
def typed_column_fields(self) ->
|
|
1465
|
+
def typed_column_fields(self) -> set[str]:
|
|
1466
|
+
"""Return the set of field names with typed columns."""
|
|
1399
1467
|
return {
|
|
1400
1468
|
field
|
|
1401
1469
|
for field in self.__dataclass_fields__
|
|
@@ -1403,9 +1471,9 @@ class Schema(BaseSchema):
|
|
|
1403
1471
|
}
|
|
1404
1472
|
|
|
1405
1473
|
def is_delayed(self) -> bool:
|
|
1406
|
-
"""
|
|
1407
|
-
|
|
1408
|
-
|
|
1474
|
+
"""Check if the schema has inherently latent information.
|
|
1475
|
+
|
|
1476
|
+
Determines this based on the columns provided by the user.
|
|
1409
1477
|
|
|
1410
1478
|
Returns:
|
|
1411
1479
|
bool: True if the schema is "delayed", i.e., does not possess prediction
|
|
@@ -1418,11 +1486,14 @@ class Schema(BaseSchema):
|
|
|
1418
1486
|
|
|
1419
1487
|
@dataclass(frozen=True)
|
|
1420
1488
|
class CorpusSchema(BaseSchema):
|
|
1489
|
+
"""Schema for corpus data with document identification and content columns."""
|
|
1490
|
+
|
|
1421
1491
|
document_id_column_name: str | None = None
|
|
1422
1492
|
document_version_column_name: str | None = None
|
|
1423
1493
|
document_text_embedding_column_names: EmbeddingColumnNames | None = None
|
|
1424
1494
|
|
|
1425
|
-
def get_used_columns_counts(self) ->
|
|
1495
|
+
def get_used_columns_counts(self) -> dict[str, int]:
|
|
1496
|
+
"""Return a dict mapping column names to their usage count."""
|
|
1426
1497
|
columns_used_counts = {}
|
|
1427
1498
|
|
|
1428
1499
|
if self.document_id_column_name is not None:
|
|
@@ -1459,6 +1530,8 @@ class CorpusSchema(BaseSchema):
|
|
|
1459
1530
|
|
|
1460
1531
|
@unique
|
|
1461
1532
|
class ArizeTypes(Enum):
|
|
1533
|
+
"""Enum representing supported data types in Arize platform."""
|
|
1534
|
+
|
|
1462
1535
|
STR = 0
|
|
1463
1536
|
FLOAT = 1
|
|
1464
1537
|
INT = 2
|
|
@@ -1466,11 +1539,21 @@ class ArizeTypes(Enum):
|
|
|
1466
1539
|
|
|
1467
1540
|
@dataclass(frozen=True)
|
|
1468
1541
|
class TypedValue:
|
|
1542
|
+
"""Container for a value with its associated Arize type."""
|
|
1543
|
+
|
|
1469
1544
|
type: ArizeTypes
|
|
1470
1545
|
value: str | bool | float | int
|
|
1471
1546
|
|
|
1472
1547
|
|
|
1473
1548
|
def is_json_str(s: str) -> bool:
|
|
1549
|
+
"""Check if a string is valid JSON.
|
|
1550
|
+
|
|
1551
|
+
Args:
|
|
1552
|
+
s: The string to validate.
|
|
1553
|
+
|
|
1554
|
+
Returns:
|
|
1555
|
+
True if the string is valid JSON, False otherwise.
|
|
1556
|
+
"""
|
|
1474
1557
|
try:
|
|
1475
1558
|
json.loads(s)
|
|
1476
1559
|
except ValueError:
|
|
@@ -1484,25 +1567,51 @@ T = TypeVar("T", bound=type)
|
|
|
1484
1567
|
|
|
1485
1568
|
|
|
1486
1569
|
def is_array_of(arr: Sequence[object], tp: T) -> bool:
|
|
1570
|
+
"""Check if a value is a numpy array with all elements of a specific type.
|
|
1571
|
+
|
|
1572
|
+
Args:
|
|
1573
|
+
arr: The sequence to check.
|
|
1574
|
+
tp: The expected type for all elements.
|
|
1575
|
+
|
|
1576
|
+
Returns:
|
|
1577
|
+
True if arr is a numpy array and all elements are of type tp.
|
|
1578
|
+
"""
|
|
1487
1579
|
return isinstance(arr, np.ndarray) and all(isinstance(x, tp) for x in arr)
|
|
1488
1580
|
|
|
1489
1581
|
|
|
1490
1582
|
def is_list_of(lst: Sequence[object], tp: T) -> bool:
|
|
1583
|
+
"""Check if a value is a list with all elements of a specific type.
|
|
1584
|
+
|
|
1585
|
+
Args:
|
|
1586
|
+
lst: The sequence to check.
|
|
1587
|
+
tp: The expected type for all elements.
|
|
1588
|
+
|
|
1589
|
+
Returns:
|
|
1590
|
+
True if lst is a list and all elements are of type tp.
|
|
1591
|
+
"""
|
|
1491
1592
|
return isinstance(lst, list) and all(isinstance(x, tp) for x in lst)
|
|
1492
1593
|
|
|
1493
1594
|
|
|
1494
1595
|
def is_iterable_of(lst: Sequence[object], tp: T) -> bool:
|
|
1596
|
+
"""Check if a value is an iterable with all elements of a specific type.
|
|
1597
|
+
|
|
1598
|
+
Args:
|
|
1599
|
+
lst: The sequence to check.
|
|
1600
|
+
tp: The expected type for all elements.
|
|
1601
|
+
|
|
1602
|
+
Returns:
|
|
1603
|
+
True if lst is an iterable and all elements are of type tp.
|
|
1604
|
+
"""
|
|
1495
1605
|
return isinstance(lst, Iterable) and all(isinstance(x, tp) for x in lst)
|
|
1496
1606
|
|
|
1497
1607
|
|
|
1498
1608
|
def is_dict_of(
|
|
1499
|
-
d:
|
|
1609
|
+
d: dict[object, object],
|
|
1500
1610
|
key_allowed_types: T,
|
|
1501
1611
|
value_allowed_types: T = (),
|
|
1502
1612
|
value_list_allowed_types: T = (),
|
|
1503
1613
|
) -> bool:
|
|
1504
|
-
"""
|
|
1505
|
-
Method to check types are valid for dictionary.
|
|
1614
|
+
"""Method to check types are valid for dictionary.
|
|
1506
1615
|
|
|
1507
1616
|
Arguments:
|
|
1508
1617
|
---------
|
|
@@ -1535,7 +1644,7 @@ def is_dict_of(
|
|
|
1535
1644
|
)
|
|
1536
1645
|
|
|
1537
1646
|
|
|
1538
|
-
def _count_characters_raw_data(data: str |
|
|
1647
|
+
def _count_characters_raw_data(data: str | list[str]) -> int:
|
|
1539
1648
|
character_count = 0
|
|
1540
1649
|
if isinstance(data, str):
|
|
1541
1650
|
character_count = len(data)
|
|
@@ -1551,8 +1660,14 @@ def _count_characters_raw_data(data: str | List[str]) -> int:
|
|
|
1551
1660
|
|
|
1552
1661
|
|
|
1553
1662
|
def add_to_column_count_dictionary(
|
|
1554
|
-
column_dictionary:
|
|
1555
|
-
):
|
|
1663
|
+
column_dictionary: dict[str, int], col: str | None
|
|
1664
|
+
) -> None:
|
|
1665
|
+
"""Increment the count for a column name in a dictionary.
|
|
1666
|
+
|
|
1667
|
+
Args:
|
|
1668
|
+
column_dictionary: Dictionary mapping column names to counts.
|
|
1669
|
+
col: The column name to increment, or None to skip.
|
|
1670
|
+
"""
|
|
1556
1671
|
if col:
|
|
1557
1672
|
if col in column_dictionary:
|
|
1558
1673
|
column_dictionary[col] += 1
|
|
@@ -1560,7 +1675,9 @@ def add_to_column_count_dictionary(
|
|
|
1560
1675
|
column_dictionary[col] = 1
|
|
1561
1676
|
|
|
1562
1677
|
|
|
1563
|
-
def _validate_bounding_box_coordinates(
|
|
1678
|
+
def _validate_bounding_box_coordinates(
|
|
1679
|
+
bounding_box_coordinates: list[float],
|
|
1680
|
+
) -> None:
|
|
1564
1681
|
if not is_list_of(bounding_box_coordinates, float):
|
|
1565
1682
|
raise TypeError(
|
|
1566
1683
|
"Each bounding box's coordinates must be a lists of floats"
|
|
@@ -1586,10 +1703,12 @@ def _validate_bounding_box_coordinates(bounding_box_coordinates: List[float]):
|
|
|
1586
1703
|
f"top-left. Found {bounding_box_coordinates}"
|
|
1587
1704
|
)
|
|
1588
1705
|
|
|
1589
|
-
return
|
|
1706
|
+
return
|
|
1590
1707
|
|
|
1591
1708
|
|
|
1592
|
-
def _validate_polygon_coordinates(
|
|
1709
|
+
def _validate_polygon_coordinates(
|
|
1710
|
+
polygon_coordinates: list[list[float]],
|
|
1711
|
+
) -> None:
|
|
1593
1712
|
if not is_list_of(polygon_coordinates, list):
|
|
1594
1713
|
raise TypeError("Polygon coordinates must be a list of lists of floats")
|
|
1595
1714
|
for coordinates in polygon_coordinates:
|
|
@@ -1651,27 +1770,41 @@ def _validate_polygon_coordinates(polygon_coordinates: List[List[float]]):
|
|
|
1651
1770
|
f"{coordinates}"
|
|
1652
1771
|
)
|
|
1653
1772
|
|
|
1654
|
-
return
|
|
1773
|
+
return
|
|
1655
1774
|
|
|
1656
1775
|
|
|
1657
|
-
def segments_intersect(
|
|
1658
|
-
|
|
1659
|
-
|
|
1776
|
+
def segments_intersect(
|
|
1777
|
+
p1: tuple[float, float],
|
|
1778
|
+
p2: tuple[float, float],
|
|
1779
|
+
p3: tuple[float, float],
|
|
1780
|
+
p4: tuple[float, float],
|
|
1781
|
+
) -> bool:
|
|
1782
|
+
"""Check if two line segments intersect.
|
|
1660
1783
|
|
|
1661
1784
|
Args:
|
|
1662
|
-
p1
|
|
1663
|
-
|
|
1785
|
+
p1: First endpoint of the first line segment (x,y)
|
|
1786
|
+
p2: Second endpoint of the first line segment (x,y)
|
|
1787
|
+
p3: First endpoint of the second line segment (x,y)
|
|
1788
|
+
p4: Second endpoint of the second line segment (x,y)
|
|
1664
1789
|
|
|
1665
1790
|
Returns:
|
|
1666
1791
|
True if the line segments intersect, False otherwise
|
|
1667
1792
|
"""
|
|
1668
1793
|
|
|
1669
1794
|
# Function to calculate direction
|
|
1670
|
-
def orientation(
|
|
1795
|
+
def orientation(
|
|
1796
|
+
p: tuple[float, float],
|
|
1797
|
+
q: tuple[float, float],
|
|
1798
|
+
r: tuple[float, float],
|
|
1799
|
+
) -> float:
|
|
1671
1800
|
return (q[1] - p[1]) * (r[0] - q[0]) - (q[0] - p[0]) * (r[1] - q[1])
|
|
1672
1801
|
|
|
1673
1802
|
# Function to check if point q is on segment pr
|
|
1674
|
-
def on_segment(
|
|
1803
|
+
def on_segment(
|
|
1804
|
+
p: tuple[float, float],
|
|
1805
|
+
q: tuple[float, float],
|
|
1806
|
+
r: tuple[float, float],
|
|
1807
|
+
) -> bool:
|
|
1675
1808
|
return (
|
|
1676
1809
|
q[0] <= max(p[0], r[0])
|
|
1677
1810
|
and q[0] >= min(p[0], r[0])
|
|
@@ -1703,17 +1836,20 @@ def segments_intersect(p1, p2, p3, p4):
|
|
|
1703
1836
|
|
|
1704
1837
|
@unique
|
|
1705
1838
|
class StatusCodes(Enum):
|
|
1839
|
+
"""Enum representing status codes for operations and responses."""
|
|
1840
|
+
|
|
1706
1841
|
UNSET = 0
|
|
1707
1842
|
OK = 1
|
|
1708
1843
|
ERROR = 2
|
|
1709
1844
|
|
|
1710
1845
|
@classmethod
|
|
1711
|
-
def list_codes(cls):
|
|
1846
|
+
def list_codes(cls) -> list[str]:
|
|
1847
|
+
"""Return a list of all status code names."""
|
|
1712
1848
|
return [t.name for t in cls]
|
|
1713
1849
|
|
|
1714
1850
|
|
|
1715
|
-
def convert_element(value):
|
|
1716
|
-
"""Converts scalar or array to python native"""
|
|
1851
|
+
def convert_element(value: object) -> object:
|
|
1852
|
+
"""Converts scalar or array to python native."""
|
|
1717
1853
|
val = getattr(value, "tolist", lambda: value)()
|
|
1718
1854
|
# Check if it's a list since elements from pd indices are converted to a
|
|
1719
1855
|
# scalar whereas pd series/dataframe elements are converted to list of 1
|
|
@@ -1734,7 +1870,7 @@ PredictionLabelTypes = (
|
|
|
1734
1870
|
| bool
|
|
1735
1871
|
| int
|
|
1736
1872
|
| float
|
|
1737
|
-
|
|
|
1873
|
+
| tuple[str, float]
|
|
1738
1874
|
| ObjectDetectionLabel
|
|
1739
1875
|
| RankingPredictionLabel
|
|
1740
1876
|
| MultiClassPredictionLabel
|
|
@@ -1745,7 +1881,7 @@ ActualLabelTypes = (
|
|
|
1745
1881
|
| bool
|
|
1746
1882
|
| int
|
|
1747
1883
|
| float
|
|
1748
|
-
|
|
|
1884
|
+
| tuple[str, float]
|
|
1749
1885
|
| ObjectDetectionLabel
|
|
1750
1886
|
| RankingActualLabel
|
|
1751
1887
|
| MultiClassActualLabel
|