arize 8.0.0a22__py3-none-any.whl → 8.0.0b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arize/__init__.py +28 -19
- arize/_exporter/client.py +56 -37
- arize/_exporter/parsers/tracing_data_parser.py +41 -30
- arize/_exporter/validation.py +3 -3
- arize/_flight/client.py +207 -76
- arize/_generated/api_client/__init__.py +30 -6
- arize/_generated/api_client/api/__init__.py +1 -0
- arize/_generated/api_client/api/datasets_api.py +864 -190
- arize/_generated/api_client/api/experiments_api.py +167 -131
- arize/_generated/api_client/api/projects_api.py +1197 -0
- arize/_generated/api_client/api_client.py +2 -2
- arize/_generated/api_client/configuration.py +42 -34
- arize/_generated/api_client/exceptions.py +2 -2
- arize/_generated/api_client/models/__init__.py +15 -4
- arize/_generated/api_client/models/dataset.py +10 -10
- arize/_generated/api_client/models/dataset_example.py +111 -0
- arize/_generated/api_client/models/dataset_example_update.py +100 -0
- arize/_generated/api_client/models/dataset_version.py +13 -13
- arize/_generated/api_client/models/datasets_create_request.py +16 -8
- arize/_generated/api_client/models/datasets_examples_insert_request.py +100 -0
- arize/_generated/api_client/models/datasets_examples_list200_response.py +106 -0
- arize/_generated/api_client/models/datasets_examples_update_request.py +102 -0
- arize/_generated/api_client/models/datasets_list200_response.py +10 -4
- arize/_generated/api_client/models/experiment.py +14 -16
- arize/_generated/api_client/models/experiment_run.py +108 -0
- arize/_generated/api_client/models/experiment_run_create.py +102 -0
- arize/_generated/api_client/models/experiments_create_request.py +16 -10
- arize/_generated/api_client/models/experiments_list200_response.py +10 -4
- arize/_generated/api_client/models/experiments_runs_list200_response.py +19 -5
- arize/_generated/api_client/models/{error.py → pagination_metadata.py} +13 -11
- arize/_generated/api_client/models/primitive_value.py +172 -0
- arize/_generated/api_client/models/problem.py +100 -0
- arize/_generated/api_client/models/project.py +99 -0
- arize/_generated/api_client/models/{datasets_list_examples200_response.py → projects_create_request.py} +13 -11
- arize/_generated/api_client/models/projects_list200_response.py +106 -0
- arize/_generated/api_client/rest.py +2 -2
- arize/_generated/api_client/test/test_dataset.py +4 -2
- arize/_generated/api_client/test/test_dataset_example.py +56 -0
- arize/_generated/api_client/test/test_dataset_example_update.py +52 -0
- arize/_generated/api_client/test/test_dataset_version.py +7 -2
- arize/_generated/api_client/test/test_datasets_api.py +27 -13
- arize/_generated/api_client/test/test_datasets_create_request.py +8 -4
- arize/_generated/api_client/test/{test_datasets_list_examples200_response.py → test_datasets_examples_insert_request.py} +19 -15
- arize/_generated/api_client/test/test_datasets_examples_list200_response.py +66 -0
- arize/_generated/api_client/test/test_datasets_examples_update_request.py +61 -0
- arize/_generated/api_client/test/test_datasets_list200_response.py +9 -3
- arize/_generated/api_client/test/test_experiment.py +2 -4
- arize/_generated/api_client/test/test_experiment_run.py +56 -0
- arize/_generated/api_client/test/test_experiment_run_create.py +54 -0
- arize/_generated/api_client/test/test_experiments_api.py +6 -6
- arize/_generated/api_client/test/test_experiments_create_request.py +9 -6
- arize/_generated/api_client/test/test_experiments_list200_response.py +9 -5
- arize/_generated/api_client/test/test_experiments_runs_list200_response.py +15 -5
- arize/_generated/api_client/test/test_pagination_metadata.py +53 -0
- arize/_generated/api_client/test/{test_error.py → test_primitive_value.py} +13 -14
- arize/_generated/api_client/test/test_problem.py +57 -0
- arize/_generated/api_client/test/test_project.py +58 -0
- arize/_generated/api_client/test/test_projects_api.py +59 -0
- arize/_generated/api_client/test/test_projects_create_request.py +54 -0
- arize/_generated/api_client/test/test_projects_list200_response.py +70 -0
- arize/_generated/api_client_README.md +43 -29
- arize/_generated/protocol/flight/flight_pb2.py +400 -0
- arize/_lazy.py +27 -19
- arize/client.py +181 -58
- arize/config.py +324 -116
- arize/constants/__init__.py +1 -0
- arize/constants/config.py +11 -4
- arize/constants/ml.py +6 -4
- arize/constants/openinference.py +2 -0
- arize/constants/pyarrow.py +2 -0
- arize/constants/spans.py +3 -1
- arize/datasets/__init__.py +1 -0
- arize/datasets/client.py +304 -84
- arize/datasets/errors.py +32 -2
- arize/datasets/validation.py +18 -8
- arize/embeddings/__init__.py +2 -0
- arize/embeddings/auto_generator.py +23 -19
- arize/embeddings/base_generators.py +89 -36
- arize/embeddings/constants.py +2 -0
- arize/embeddings/cv_generators.py +26 -4
- arize/embeddings/errors.py +27 -5
- arize/embeddings/nlp_generators.py +43 -18
- arize/embeddings/tabular_generators.py +46 -31
- arize/embeddings/usecases.py +12 -2
- arize/exceptions/__init__.py +1 -0
- arize/exceptions/auth.py +11 -1
- arize/exceptions/base.py +29 -4
- arize/exceptions/models.py +21 -2
- arize/exceptions/parameters.py +31 -0
- arize/exceptions/spaces.py +12 -1
- arize/exceptions/types.py +86 -7
- arize/exceptions/values.py +220 -20
- arize/experiments/__init__.py +13 -0
- arize/experiments/client.py +394 -285
- arize/experiments/evaluators/__init__.py +1 -0
- arize/experiments/evaluators/base.py +74 -41
- arize/experiments/evaluators/exceptions.py +6 -3
- arize/experiments/evaluators/executors.py +121 -73
- arize/experiments/evaluators/rate_limiters.py +106 -57
- arize/experiments/evaluators/types.py +34 -7
- arize/experiments/evaluators/utils.py +65 -27
- arize/experiments/functions.py +103 -101
- arize/experiments/tracing.py +52 -44
- arize/experiments/types.py +56 -31
- arize/logging.py +54 -22
- arize/ml/__init__.py +1 -0
- arize/ml/batch_validation/__init__.py +1 -0
- arize/{models → ml}/batch_validation/errors.py +545 -67
- arize/{models → ml}/batch_validation/validator.py +344 -303
- arize/ml/bounded_executor.py +47 -0
- arize/{models → ml}/casting.py +118 -108
- arize/{models → ml}/client.py +339 -118
- arize/{models → ml}/proto.py +97 -42
- arize/{models → ml}/stream_validation.py +43 -15
- arize/ml/surrogate_explainer/__init__.py +1 -0
- arize/{models → ml}/surrogate_explainer/mimic.py +25 -10
- arize/{types.py → ml/types.py} +355 -354
- arize/pre_releases.py +44 -0
- arize/projects/__init__.py +1 -0
- arize/projects/client.py +134 -0
- arize/regions.py +40 -0
- arize/spans/__init__.py +1 -0
- arize/spans/client.py +204 -175
- arize/spans/columns.py +13 -0
- arize/spans/conversion.py +60 -37
- arize/spans/validation/__init__.py +1 -0
- arize/spans/validation/annotations/__init__.py +1 -0
- arize/spans/validation/annotations/annotations_validation.py +6 -4
- arize/spans/validation/annotations/dataframe_form_validation.py +13 -11
- arize/spans/validation/annotations/value_validation.py +35 -11
- arize/spans/validation/common/__init__.py +1 -0
- arize/spans/validation/common/argument_validation.py +33 -8
- arize/spans/validation/common/dataframe_form_validation.py +35 -9
- arize/spans/validation/common/errors.py +211 -11
- arize/spans/validation/common/value_validation.py +81 -14
- arize/spans/validation/evals/__init__.py +1 -0
- arize/spans/validation/evals/dataframe_form_validation.py +28 -8
- arize/spans/validation/evals/evals_validation.py +34 -4
- arize/spans/validation/evals/value_validation.py +26 -3
- arize/spans/validation/metadata/__init__.py +1 -1
- arize/spans/validation/metadata/argument_validation.py +14 -5
- arize/spans/validation/metadata/dataframe_form_validation.py +26 -10
- arize/spans/validation/metadata/value_validation.py +24 -10
- arize/spans/validation/spans/__init__.py +1 -0
- arize/spans/validation/spans/dataframe_form_validation.py +35 -14
- arize/spans/validation/spans/spans_validation.py +35 -4
- arize/spans/validation/spans/value_validation.py +78 -8
- arize/utils/__init__.py +1 -0
- arize/utils/arrow.py +31 -15
- arize/utils/cache.py +34 -6
- arize/utils/dataframe.py +20 -3
- arize/utils/online_tasks/__init__.py +2 -0
- arize/utils/online_tasks/dataframe_preprocessor.py +58 -47
- arize/utils/openinference_conversion.py +44 -5
- arize/utils/proto.py +10 -0
- arize/utils/size.py +5 -3
- arize/utils/types.py +105 -0
- arize/version.py +3 -1
- {arize-8.0.0a22.dist-info → arize-8.0.0b0.dist-info}/METADATA +13 -6
- arize-8.0.0b0.dist-info/RECORD +175 -0
- {arize-8.0.0a22.dist-info → arize-8.0.0b0.dist-info}/WHEEL +1 -1
- arize-8.0.0b0.dist-info/licenses/LICENSE +176 -0
- arize-8.0.0b0.dist-info/licenses/NOTICE +13 -0
- arize/_generated/protocol/flight/export_pb2.py +0 -61
- arize/_generated/protocol/flight/ingest_pb2.py +0 -365
- arize/models/__init__.py +0 -0
- arize/models/batch_validation/__init__.py +0 -0
- arize/models/bounded_executor.py +0 -34
- arize/models/surrogate_explainer/__init__.py +0 -0
- arize-8.0.0a22.dist-info/RECORD +0 -146
- arize-8.0.0a22.dist-info/licenses/LICENSE.md +0 -12
arize/{types.py → ml/types.py}
RENAMED
|
@@ -1,20 +1,16 @@
|
|
|
1
|
-
|
|
1
|
+
"""Common type definitions and data models used across the ML Client."""
|
|
2
|
+
|
|
2
3
|
import logging
|
|
3
4
|
import math
|
|
5
|
+
from collections.abc import Iterator
|
|
4
6
|
from dataclasses import asdict, dataclass, replace
|
|
5
7
|
from datetime import datetime
|
|
6
8
|
from decimal import Decimal
|
|
7
9
|
from enum import Enum, unique
|
|
8
10
|
from itertools import chain
|
|
9
11
|
from typing import (
|
|
10
|
-
Dict,
|
|
11
|
-
Iterable,
|
|
12
|
-
List,
|
|
13
12
|
NamedTuple,
|
|
14
|
-
|
|
15
|
-
Set,
|
|
16
|
-
Tuple,
|
|
17
|
-
TypeVar,
|
|
13
|
+
Self,
|
|
18
14
|
)
|
|
19
15
|
|
|
20
16
|
import numpy as np
|
|
@@ -42,12 +38,15 @@ from arize.exceptions.parameters import InvalidValueType
|
|
|
42
38
|
# )
|
|
43
39
|
# from arize.utils.errors import InvalidValueType
|
|
44
40
|
from arize.logging import get_truncation_warning_message
|
|
41
|
+
from arize.utils.types import is_dict_of, is_iterable_of, is_list_of
|
|
45
42
|
|
|
46
43
|
logger = logging.getLogger(__name__)
|
|
47
44
|
|
|
48
45
|
|
|
49
46
|
@unique
|
|
50
47
|
class ModelTypes(Enum):
|
|
48
|
+
"""Enum representing supported model types in Arize."""
|
|
49
|
+
|
|
51
50
|
NUMERIC = 1
|
|
52
51
|
SCORE_CATEGORICAL = 2
|
|
53
52
|
RANKING = 3
|
|
@@ -58,7 +57,8 @@ class ModelTypes(Enum):
|
|
|
58
57
|
MULTI_CLASS = 8
|
|
59
58
|
|
|
60
59
|
@classmethod
|
|
61
|
-
def list_types(cls):
|
|
60
|
+
def list_types(cls) -> list[str]:
|
|
61
|
+
"""Return a list of all type names in this enum."""
|
|
62
62
|
return [t.name for t in cls]
|
|
63
63
|
|
|
64
64
|
|
|
@@ -70,7 +70,10 @@ CATEGORICAL_MODEL_TYPES = [
|
|
|
70
70
|
|
|
71
71
|
|
|
72
72
|
class DocEnum(Enum):
|
|
73
|
-
|
|
73
|
+
"""Enum subclass supporting inline documentation for enum members."""
|
|
74
|
+
|
|
75
|
+
def __new__(cls, value: object, doc: str | None = None) -> Self:
|
|
76
|
+
"""Create a new enum instance with optional documentation."""
|
|
74
77
|
self = object.__new__(
|
|
75
78
|
cls
|
|
76
79
|
) # calling super().__new__(value) here would fail
|
|
@@ -80,13 +83,13 @@ class DocEnum(Enum):
|
|
|
80
83
|
return self
|
|
81
84
|
|
|
82
85
|
def __repr__(self) -> str:
|
|
86
|
+
"""Return a string representation including documentation."""
|
|
83
87
|
return f"{self.name} metrics include: {self.__doc__}"
|
|
84
88
|
|
|
85
89
|
|
|
86
90
|
@unique
|
|
87
91
|
class Metrics(DocEnum):
|
|
88
|
-
"""
|
|
89
|
-
Metric groupings, used for validation of schema columns in log() call.
|
|
92
|
+
"""Metric groupings, used for validation of schema columns in log() call.
|
|
90
93
|
|
|
91
94
|
See docstring descriptions of the Enum with __doc__ or __repr__(), e.g.:
|
|
92
95
|
Metrics.RANKING.__doc__
|
|
@@ -105,6 +108,8 @@ class Metrics(DocEnum):
|
|
|
105
108
|
|
|
106
109
|
@unique
|
|
107
110
|
class Environments(Enum):
|
|
111
|
+
"""Enum representing deployment environments for models."""
|
|
112
|
+
|
|
108
113
|
TRAINING = 1
|
|
109
114
|
VALIDATION = 2
|
|
110
115
|
PRODUCTION = 3
|
|
@@ -114,11 +119,18 @@ class Environments(Enum):
|
|
|
114
119
|
|
|
115
120
|
@dataclass
|
|
116
121
|
class EmbeddingColumnNames:
|
|
122
|
+
"""Column names for embedding feature data."""
|
|
123
|
+
|
|
117
124
|
vector_column_name: str = ""
|
|
118
125
|
data_column_name: str | None = None
|
|
119
126
|
link_to_data_column_name: str | None = None
|
|
120
127
|
|
|
121
|
-
def __post_init__(self):
|
|
128
|
+
def __post_init__(self) -> None:
|
|
129
|
+
"""Validate that vector column name is specified.
|
|
130
|
+
|
|
131
|
+
Raises:
|
|
132
|
+
ValueError: If vector_column_name is empty.
|
|
133
|
+
"""
|
|
122
134
|
if not self.vector_column_name:
|
|
123
135
|
raise ValueError(
|
|
124
136
|
"embedding_features require a vector to be specified. You can "
|
|
@@ -126,7 +138,8 @@ class EmbeddingColumnNames:
|
|
|
126
138
|
"(from arize.pandas.embeddings) if you do not have them"
|
|
127
139
|
)
|
|
128
140
|
|
|
129
|
-
def __iter__(self):
|
|
141
|
+
def __iter__(self) -> Iterator[str | None]:
|
|
142
|
+
"""Iterate over the embedding column names."""
|
|
130
143
|
return iter(
|
|
131
144
|
(
|
|
132
145
|
self.vector_column_name,
|
|
@@ -137,24 +150,23 @@ class EmbeddingColumnNames:
|
|
|
137
150
|
|
|
138
151
|
|
|
139
152
|
class Embedding(NamedTuple):
|
|
140
|
-
vector
|
|
141
|
-
|
|
153
|
+
"""Container for embedding vector data with optional raw data and links."""
|
|
154
|
+
|
|
155
|
+
vector: list[float]
|
|
156
|
+
data: str | list[str] | None = None
|
|
142
157
|
link_to_data: str | None = None
|
|
143
158
|
|
|
144
159
|
def validate(self, emb_name: str | int | float) -> None:
|
|
145
|
-
"""
|
|
146
|
-
Validates that the embedding object passed is of the correct format.
|
|
147
|
-
That is, validations must be passed for vector, data & link_to_data.
|
|
160
|
+
"""Validates that the embedding object passed is of the correct format.
|
|
148
161
|
|
|
149
|
-
|
|
150
|
-
---------
|
|
151
|
-
emb_name (str, int, float): Name of the embedding feature the
|
|
152
|
-
vector belongs to
|
|
162
|
+
Ensures validations are passed for vector, data, and link_to_data fields.
|
|
153
163
|
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
164
|
+
Args:
|
|
165
|
+
emb_name: Name of the embedding feature the
|
|
166
|
+
vector belongs to.
|
|
157
167
|
|
|
168
|
+
Raises:
|
|
169
|
+
TypeError: If the embedding fields are of the wrong type.
|
|
158
170
|
"""
|
|
159
171
|
if self.vector is not None:
|
|
160
172
|
self._validate_embedding_vector(emb_name)
|
|
@@ -167,29 +179,23 @@ class Embedding(NamedTuple):
|
|
|
167
179
|
if self.link_to_data is not None:
|
|
168
180
|
self._validate_embedding_link_to_data(emb_name, self.link_to_data)
|
|
169
181
|
|
|
170
|
-
return
|
|
182
|
+
return
|
|
171
183
|
|
|
172
184
|
def _validate_embedding_vector(
|
|
173
185
|
self,
|
|
174
186
|
emb_name: str | int | float,
|
|
175
187
|
) -> None:
|
|
176
|
-
"""
|
|
177
|
-
Validates that the embedding vector passed is of the correct format.
|
|
178
|
-
That is:
|
|
179
|
-
1. Type must be list or convertible to list (like numpy arrays,
|
|
180
|
-
pandas Series)
|
|
181
|
-
2. List must not be empty
|
|
182
|
-
3. Elements in list must be floats
|
|
183
|
-
|
|
184
|
-
Arguments:
|
|
185
|
-
---------
|
|
186
|
-
emb_name (str, int, float): Name of the embedding feature the vector
|
|
187
|
-
belongs to
|
|
188
|
+
"""Validates that the embedding vector passed is of the correct format.
|
|
188
189
|
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
TypeError: If the embedding does not satisfy requirements above
|
|
190
|
+
Requirements: 1) Type must be list or convertible to list (like numpy arrays,
|
|
191
|
+
pandas Series), 2) List must not be empty, 3) Elements in list must be floats.
|
|
192
192
|
|
|
193
|
+
Args:
|
|
194
|
+
emb_name: Name of the embedding feature the vector
|
|
195
|
+
belongs to.
|
|
196
|
+
|
|
197
|
+
Raises:
|
|
198
|
+
TypeError: If the embedding does not satisfy requirements above.
|
|
193
199
|
"""
|
|
194
200
|
if not Embedding._is_valid_iterable(self.vector):
|
|
195
201
|
raise TypeError(
|
|
@@ -209,21 +215,19 @@ class Embedding(NamedTuple):
|
|
|
209
215
|
|
|
210
216
|
@staticmethod
|
|
211
217
|
def _validate_embedding_data(
|
|
212
|
-
emb_name: str | int | float, data: str |
|
|
218
|
+
emb_name: str | int | float, data: str | list[str]
|
|
213
219
|
) -> None:
|
|
214
|
-
"""
|
|
215
|
-
Validates that the embedding raw data field is of the correct format. That is:
|
|
216
|
-
1. Must be string or list of strings (NLP case)
|
|
220
|
+
"""Validates that the embedding raw data field is of the correct format.
|
|
217
221
|
|
|
218
|
-
|
|
219
|
-
---------
|
|
220
|
-
emb_name (str, int, float): Name of the embedding feature the vector belongs to
|
|
221
|
-
data (str, int, float): Raw data associated with the embedding feature. Typically raw text.
|
|
222
|
+
Requirement: Must be string or list of strings (NLP case).
|
|
222
223
|
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
224
|
+
Args:
|
|
225
|
+
emb_name: Name of the embedding feature the vector belongs to.
|
|
226
|
+
data: Raw data associated with the embedding feature.
|
|
227
|
+
Typically raw text.
|
|
226
228
|
|
|
229
|
+
Raises:
|
|
230
|
+
TypeError: If the embedding does not satisfy requirements above.
|
|
227
231
|
"""
|
|
228
232
|
# Validate that data is a string or iterable of strings
|
|
229
233
|
is_string = isinstance(data, str)
|
|
@@ -247,7 +251,7 @@ class Embedding(NamedTuple):
|
|
|
247
251
|
f"Embedding data field must not contain more than {MAX_RAW_DATA_CHARACTERS} characters. "
|
|
248
252
|
f"Found {character_count}."
|
|
249
253
|
)
|
|
250
|
-
|
|
254
|
+
if character_count > MAX_RAW_DATA_CHARACTERS_TRUNCATION:
|
|
251
255
|
logger.warning(
|
|
252
256
|
get_truncation_warning_message(
|
|
253
257
|
"Embedding raw data fields",
|
|
@@ -259,20 +263,17 @@ class Embedding(NamedTuple):
|
|
|
259
263
|
def _validate_embedding_link_to_data(
|
|
260
264
|
emb_name: str | int | float, link_to_data: str
|
|
261
265
|
) -> None:
|
|
262
|
-
"""
|
|
263
|
-
Validates that the embedding link to data field is of the correct format. That is:
|
|
264
|
-
1. Must be string
|
|
266
|
+
"""Validates that the embedding link to data field is of the correct format.
|
|
265
267
|
|
|
266
|
-
|
|
267
|
-
---------
|
|
268
|
-
emb_name (str, int, float): Name of the embedding feature the vector belongs to
|
|
269
|
-
link_to_data (str): Link to source data of embedding feature, typically an image file on
|
|
270
|
-
cloud storage
|
|
268
|
+
Requirement: Must be string.
|
|
271
269
|
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
270
|
+
Args:
|
|
271
|
+
emb_name: Name of the embedding feature the vector belongs to.
|
|
272
|
+
link_to_data: Link to source data of embedding feature, typically an
|
|
273
|
+
image file on cloud storage.
|
|
275
274
|
|
|
275
|
+
Raises:
|
|
276
|
+
TypeError: If the embedding does not satisfy requirements above.
|
|
276
277
|
"""
|
|
277
278
|
if not isinstance(link_to_data, str):
|
|
278
279
|
raise TypeError(
|
|
@@ -282,22 +283,18 @@ class Embedding(NamedTuple):
|
|
|
282
283
|
|
|
283
284
|
@staticmethod
|
|
284
285
|
def _is_valid_iterable(
|
|
285
|
-
data: str |
|
|
286
|
+
data: str | list[str] | list[float] | np.ndarray,
|
|
286
287
|
) -> bool:
|
|
287
|
-
"""
|
|
288
|
-
Validates that the input data field is of the correct iterable type. That is:
|
|
289
|
-
1. List or
|
|
290
|
-
2. numpy array or
|
|
291
|
-
3. pandas Series
|
|
288
|
+
"""Validates that the input data field is of the correct iterable type.
|
|
292
289
|
|
|
293
|
-
|
|
294
|
-
---------
|
|
295
|
-
data: input iterable
|
|
290
|
+
Accepted types: 1) List, 2) numpy array, or 3) pandas Series.
|
|
296
291
|
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
True if the data type is one of the accepted iterable types, false otherwise
|
|
292
|
+
Args:
|
|
293
|
+
data: Input iterable.
|
|
300
294
|
|
|
295
|
+
Returns:
|
|
296
|
+
True if the data type is one of the accepted iterable types,
|
|
297
|
+
false otherwise.
|
|
301
298
|
"""
|
|
302
299
|
return any(isinstance(data, t) for t in (list, np.ndarray))
|
|
303
300
|
|
|
@@ -327,12 +324,15 @@ class Embedding(NamedTuple):
|
|
|
327
324
|
|
|
328
325
|
|
|
329
326
|
class LLMRunMetadata(NamedTuple):
|
|
327
|
+
"""Metadata for LLM execution including token counts and latency."""
|
|
328
|
+
|
|
330
329
|
total_token_count: int | None = None
|
|
331
330
|
prompt_token_count: int | None = None
|
|
332
331
|
response_token_count: int | None = None
|
|
333
332
|
response_latency_ms: int | float | None = None
|
|
334
333
|
|
|
335
334
|
def validate(self) -> None:
|
|
335
|
+
"""Validate the field values and constraints."""
|
|
336
336
|
allowed_types = (int, float, np.int16, np.int32, np.float16, np.float32)
|
|
337
337
|
if not isinstance(self.total_token_count, allowed_types):
|
|
338
338
|
raise InvalidValueType(
|
|
@@ -361,22 +361,20 @@ class LLMRunMetadata(NamedTuple):
|
|
|
361
361
|
|
|
362
362
|
|
|
363
363
|
class ObjectDetectionColumnNames(NamedTuple):
|
|
364
|
-
"""
|
|
365
|
-
|
|
366
|
-
actual schema parameter.
|
|
364
|
+
"""Used to log object detection prediction and actual values.
|
|
365
|
+
|
|
366
|
+
These values are assigned to the prediction or actual schema parameter.
|
|
367
367
|
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
bounding_boxes_coordinates_column_name (str): Column name containing the coordinates of the
|
|
368
|
+
Args:
|
|
369
|
+
bounding_boxes_coordinates_column_name: Column name containing the coordinates of the
|
|
371
370
|
rectangular outline that locates an object within an image or video. Pascal VOC format
|
|
372
371
|
required. The contents of this column must be a List[List[float]].
|
|
373
|
-
categories_column_name
|
|
372
|
+
categories_column_name: Column name containing the predefined classes or labels used
|
|
374
373
|
by the model to classify the detected objects. The contents of this column must be List[str].
|
|
375
|
-
scores_column_names
|
|
374
|
+
scores_column_names: Column name containing the confidence scores that the
|
|
376
375
|
model assigns to it's predictions, indicating how certain the model is that the predicted
|
|
377
376
|
class is contained within the bounding box. This argument is only applicable for prediction
|
|
378
377
|
values. The contents of this column must be List[float].
|
|
379
|
-
|
|
380
378
|
"""
|
|
381
379
|
|
|
382
380
|
bounding_boxes_coordinates_column_name: str
|
|
@@ -385,19 +383,17 @@ class ObjectDetectionColumnNames(NamedTuple):
|
|
|
385
383
|
|
|
386
384
|
|
|
387
385
|
class SemanticSegmentationColumnNames(NamedTuple):
|
|
388
|
-
"""
|
|
389
|
-
Used to log semantic segmentation prediction and actual values that are assigned to the prediction or
|
|
390
|
-
actual schema parameter.
|
|
386
|
+
"""Used to log semantic segmentation prediction and actual values.
|
|
391
387
|
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
388
|
+
These values are assigned to the prediction or actual schema parameter.
|
|
389
|
+
|
|
390
|
+
Args:
|
|
391
|
+
polygon_coordinates_column_name: Column name containing the coordinates of the vertices
|
|
395
392
|
of the polygon mask within an image or video. The first sublist contains the
|
|
396
393
|
coordinates of the outline of the polygon. The subsequent sublists contain the coordinates
|
|
397
394
|
of any cutouts within the polygon. The contents of this column must be a List[List[float]].
|
|
398
|
-
categories_column_name
|
|
395
|
+
categories_column_name: Column name containing the predefined classes or labels used
|
|
399
396
|
by the model to classify the detected objects. The contents of this column must be List[str].
|
|
400
|
-
|
|
401
397
|
"""
|
|
402
398
|
|
|
403
399
|
polygon_coordinates_column_name: str
|
|
@@ -405,25 +401,22 @@ class SemanticSegmentationColumnNames(NamedTuple):
|
|
|
405
401
|
|
|
406
402
|
|
|
407
403
|
class InstanceSegmentationPredictionColumnNames(NamedTuple):
|
|
408
|
-
"""
|
|
409
|
-
Used to log instance segmentation prediction values that are assigned to the prediction schema parameter.
|
|
404
|
+
"""Used to log instance segmentation prediction values for the prediction schema parameter.
|
|
410
405
|
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
polygon_coordinates_column_name (str): Column name containing the coordinates of the vertices
|
|
406
|
+
Args:
|
|
407
|
+
polygon_coordinates_column_name: Column name containing the coordinates of the vertices
|
|
414
408
|
of the polygon mask within an image or video. The first sublist contains the
|
|
415
409
|
coordinates of the outline of the polygon. The subsequent sublists contain the coordinates
|
|
416
410
|
of any cutouts within the polygon. The contents of this column must be a List[List[float]].
|
|
417
|
-
categories_column_name
|
|
411
|
+
categories_column_name: Column name containing the predefined classes or labels used
|
|
418
412
|
by the model to classify the detected objects. The contents of this column must be List[str].
|
|
419
|
-
scores_column_name
|
|
413
|
+
scores_column_name: Column name containing the confidence scores that the
|
|
420
414
|
model assigns to it's predictions, indicating how certain the model is that the predicted
|
|
421
415
|
class is contained within the bounding box. This argument is only applicable for prediction
|
|
422
416
|
values. The contents of this column must be List[float].
|
|
423
|
-
bounding_boxes_coordinates_column_name
|
|
417
|
+
bounding_boxes_coordinates_column_name: Column name containing the coordinates of the
|
|
424
418
|
rectangular outline that locates an object within an image or video. Pascal VOC format
|
|
425
419
|
required. The contents of this column must be a List[List[float]].
|
|
426
|
-
|
|
427
420
|
"""
|
|
428
421
|
|
|
429
422
|
polygon_coordinates_column_name: str
|
|
@@ -433,20 +426,17 @@ class InstanceSegmentationPredictionColumnNames(NamedTuple):
|
|
|
433
426
|
|
|
434
427
|
|
|
435
428
|
class InstanceSegmentationActualColumnNames(NamedTuple):
|
|
436
|
-
"""
|
|
437
|
-
Used to log instance segmentation actual values that are assigned to the actual schema parameter.
|
|
429
|
+
"""Used to log instance segmentation actual values that are assigned to the actual schema parameter.
|
|
438
430
|
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
polygon_coordinates_column_name (str): Column name containing the coordinates of the
|
|
431
|
+
Args:
|
|
432
|
+
polygon_coordinates_column_name: Column name containing the coordinates of the
|
|
442
433
|
polygon that locates an object within an image or video. The contents of this column
|
|
443
434
|
must be a List[List[float]].
|
|
444
|
-
categories_column_name
|
|
435
|
+
categories_column_name: Column name containing the predefined classes or labels used
|
|
445
436
|
by the model to classify the detected objects. The contents of this column must be List[str].
|
|
446
|
-
bounding_boxes_coordinates_column_name
|
|
437
|
+
bounding_boxes_coordinates_column_name: Column name containing the coordinates of the
|
|
447
438
|
rectangular outline that locates an object within an image or video. Pascal VOC format
|
|
448
439
|
required. The contents of this column must be a List[List[float]].
|
|
449
|
-
|
|
450
440
|
"""
|
|
451
441
|
|
|
452
442
|
polygon_coordinates_column_name: str
|
|
@@ -455,12 +445,15 @@ class InstanceSegmentationActualColumnNames(NamedTuple):
|
|
|
455
445
|
|
|
456
446
|
|
|
457
447
|
class ObjectDetectionLabel(NamedTuple):
|
|
458
|
-
|
|
459
|
-
|
|
448
|
+
"""Label data for object detection tasks with bounding boxes and categories."""
|
|
449
|
+
|
|
450
|
+
bounding_boxes_coordinates: list[list[float]]
|
|
451
|
+
categories: list[str]
|
|
460
452
|
# Actual Object Detection Labels won't have scores
|
|
461
|
-
scores:
|
|
453
|
+
scores: list[float] | None = None
|
|
462
454
|
|
|
463
|
-
def validate(self, prediction_or_actual: str):
|
|
455
|
+
def validate(self, prediction_or_actual: str) -> None:
|
|
456
|
+
"""Validate the object detection label fields and constraints."""
|
|
464
457
|
# Validate bounding boxes
|
|
465
458
|
self._validate_bounding_boxes_coordinates()
|
|
466
459
|
# Validate categories
|
|
@@ -470,7 +463,7 @@ class ObjectDetectionLabel(NamedTuple):
|
|
|
470
463
|
# Validate we have the same number of bounding boxes, categories and scores
|
|
471
464
|
self._validate_count_match()
|
|
472
465
|
|
|
473
|
-
def _validate_bounding_boxes_coordinates(self):
|
|
466
|
+
def _validate_bounding_boxes_coordinates(self) -> None:
|
|
474
467
|
if not is_list_of(self.bounding_boxes_coordinates, list):
|
|
475
468
|
raise TypeError(
|
|
476
469
|
"Object Detection Label bounding boxes must be a list of lists of floats"
|
|
@@ -478,14 +471,14 @@ class ObjectDetectionLabel(NamedTuple):
|
|
|
478
471
|
for coordinates in self.bounding_boxes_coordinates:
|
|
479
472
|
_validate_bounding_box_coordinates(coordinates)
|
|
480
473
|
|
|
481
|
-
def _validate_categories(self):
|
|
474
|
+
def _validate_categories(self) -> None:
|
|
482
475
|
# Allows for categories as empty strings
|
|
483
476
|
if not is_list_of(self.categories, str):
|
|
484
477
|
raise TypeError(
|
|
485
478
|
"Object Detection Label categories must be a list of strings"
|
|
486
479
|
)
|
|
487
480
|
|
|
488
|
-
def _validate_scores(self, prediction_or_actual: str):
|
|
481
|
+
def _validate_scores(self, prediction_or_actual: str) -> None:
|
|
489
482
|
if self.scores is None:
|
|
490
483
|
if prediction_or_actual == "prediction":
|
|
491
484
|
raise ValueError(
|
|
@@ -507,7 +500,7 @@ class ObjectDetectionLabel(NamedTuple):
|
|
|
507
500
|
f"{self.scores}"
|
|
508
501
|
)
|
|
509
502
|
|
|
510
|
-
def _validate_count_match(self):
|
|
503
|
+
def _validate_count_match(self) -> None:
|
|
511
504
|
n_bounding_boxes = len(self.bounding_boxes_coordinates)
|
|
512
505
|
if n_bounding_boxes == 0:
|
|
513
506
|
raise ValueError(
|
|
@@ -534,10 +527,13 @@ class ObjectDetectionLabel(NamedTuple):
|
|
|
534
527
|
|
|
535
528
|
|
|
536
529
|
class SemanticSegmentationLabel(NamedTuple):
|
|
537
|
-
|
|
538
|
-
categories: List[str]
|
|
530
|
+
"""Label data for semantic segmentation with polygon coordinates and categories."""
|
|
539
531
|
|
|
540
|
-
|
|
532
|
+
polygon_coordinates: list[list[float]]
|
|
533
|
+
categories: list[str]
|
|
534
|
+
|
|
535
|
+
def validate(self) -> None:
|
|
536
|
+
"""Validate the field values and constraints."""
|
|
541
537
|
# Validate polygon coordinates
|
|
542
538
|
self._validate_polygon_coordinates()
|
|
543
539
|
# Validate categories
|
|
@@ -545,17 +541,17 @@ class SemanticSegmentationLabel(NamedTuple):
|
|
|
545
541
|
# Validate we have the same number of polygon coordinates and categories
|
|
546
542
|
self._validate_count_match()
|
|
547
543
|
|
|
548
|
-
def _validate_polygon_coordinates(self):
|
|
544
|
+
def _validate_polygon_coordinates(self) -> None:
|
|
549
545
|
_validate_polygon_coordinates(self.polygon_coordinates)
|
|
550
546
|
|
|
551
|
-
def _validate_categories(self):
|
|
547
|
+
def _validate_categories(self) -> None:
|
|
552
548
|
# Allows for categories as empty strings
|
|
553
549
|
if not is_list_of(self.categories, str):
|
|
554
550
|
raise TypeError(
|
|
555
551
|
"Semantic Segmentation Label categories must be a list of strings"
|
|
556
552
|
)
|
|
557
553
|
|
|
558
|
-
def _validate_count_match(self):
|
|
554
|
+
def _validate_count_match(self) -> None:
|
|
559
555
|
n_polygon_coordinates = len(self.polygon_coordinates)
|
|
560
556
|
if n_polygon_coordinates == 0:
|
|
561
557
|
raise ValueError(
|
|
@@ -573,12 +569,15 @@ class SemanticSegmentationLabel(NamedTuple):
|
|
|
573
569
|
|
|
574
570
|
|
|
575
571
|
class InstanceSegmentationPredictionLabel(NamedTuple):
|
|
576
|
-
|
|
577
|
-
categories: List[str]
|
|
578
|
-
scores: List[float] | None = None
|
|
579
|
-
bounding_boxes_coordinates: List[List[float]] | None = None
|
|
572
|
+
"""Prediction label for instance segmentation with polygons and category information."""
|
|
580
573
|
|
|
581
|
-
|
|
574
|
+
polygon_coordinates: list[list[float]]
|
|
575
|
+
categories: list[str]
|
|
576
|
+
scores: list[float] | None = None
|
|
577
|
+
bounding_boxes_coordinates: list[list[float]] | None = None
|
|
578
|
+
|
|
579
|
+
def validate(self) -> None:
|
|
580
|
+
"""Validate the field values and constraints."""
|
|
582
581
|
# Validate polygon coordinates
|
|
583
582
|
self._validate_polygon_coordinates()
|
|
584
583
|
# Validate categories
|
|
@@ -590,17 +589,17 @@ class InstanceSegmentationPredictionLabel(NamedTuple):
|
|
|
590
589
|
# Validate we have the same number of polygon coordinates and categories
|
|
591
590
|
self._validate_count_match()
|
|
592
591
|
|
|
593
|
-
def _validate_polygon_coordinates(self):
|
|
592
|
+
def _validate_polygon_coordinates(self) -> None:
|
|
594
593
|
_validate_polygon_coordinates(self.polygon_coordinates)
|
|
595
594
|
|
|
596
|
-
def _validate_categories(self):
|
|
595
|
+
def _validate_categories(self) -> None:
|
|
597
596
|
# Allows for categories as empty strings
|
|
598
597
|
if not is_list_of(self.categories, str):
|
|
599
598
|
raise TypeError(
|
|
600
599
|
"Instance Segmentation Prediction Label categories must be a list of strings"
|
|
601
600
|
)
|
|
602
601
|
|
|
603
|
-
def _validate_scores(self):
|
|
602
|
+
def _validate_scores(self) -> None:
|
|
604
603
|
if self.scores is not None:
|
|
605
604
|
if not is_list_of(self.scores, float):
|
|
606
605
|
raise TypeError(
|
|
@@ -613,7 +612,7 @@ class InstanceSegmentationPredictionLabel(NamedTuple):
|
|
|
613
612
|
f"{self.scores}"
|
|
614
613
|
)
|
|
615
614
|
|
|
616
|
-
def _validate_bounding_boxes(self):
|
|
615
|
+
def _validate_bounding_boxes(self) -> None:
|
|
617
616
|
if self.bounding_boxes_coordinates is not None:
|
|
618
617
|
if not is_list_of(self.bounding_boxes_coordinates, list):
|
|
619
618
|
raise TypeError(
|
|
@@ -622,7 +621,7 @@ class InstanceSegmentationPredictionLabel(NamedTuple):
|
|
|
622
621
|
for coordinates in self.bounding_boxes_coordinates:
|
|
623
622
|
_validate_bounding_box_coordinates(coordinates)
|
|
624
623
|
|
|
625
|
-
def _validate_count_match(self):
|
|
624
|
+
def _validate_count_match(self) -> None:
|
|
626
625
|
n_polygon_coordinates = len(self.polygon_coordinates)
|
|
627
626
|
if n_polygon_coordinates == 0:
|
|
628
627
|
raise ValueError(
|
|
@@ -657,11 +656,14 @@ class InstanceSegmentationPredictionLabel(NamedTuple):
|
|
|
657
656
|
|
|
658
657
|
|
|
659
658
|
class InstanceSegmentationActualLabel(NamedTuple):
|
|
660
|
-
|
|
661
|
-
categories: List[str]
|
|
662
|
-
bounding_boxes_coordinates: List[List[float]] | None = None
|
|
659
|
+
"""Actual label for instance segmentation with polygon coordinates and categories."""
|
|
663
660
|
|
|
664
|
-
|
|
661
|
+
polygon_coordinates: list[list[float]]
|
|
662
|
+
categories: list[str]
|
|
663
|
+
bounding_boxes_coordinates: list[list[float]] | None = None
|
|
664
|
+
|
|
665
|
+
def validate(self) -> None:
|
|
666
|
+
"""Validate the field values and constraints."""
|
|
665
667
|
# Validate polygon coordinates
|
|
666
668
|
self._validate_polygon_coordinates()
|
|
667
669
|
# Validate categories
|
|
@@ -671,17 +673,17 @@ class InstanceSegmentationActualLabel(NamedTuple):
|
|
|
671
673
|
# Validate we have the same number of polygon coordinates and categories
|
|
672
674
|
self._validate_count_match()
|
|
673
675
|
|
|
674
|
-
def _validate_polygon_coordinates(self):
|
|
676
|
+
def _validate_polygon_coordinates(self) -> None:
|
|
675
677
|
_validate_polygon_coordinates(self.polygon_coordinates)
|
|
676
678
|
|
|
677
|
-
def _validate_categories(self):
|
|
679
|
+
def _validate_categories(self) -> None:
|
|
678
680
|
# Allows for categories as empty strings
|
|
679
681
|
if not is_list_of(self.categories, str):
|
|
680
682
|
raise TypeError(
|
|
681
683
|
"Instance Segmentation Actual Label categories must be a list of strings"
|
|
682
684
|
)
|
|
683
685
|
|
|
684
|
-
def _validate_bounding_boxes(self):
|
|
686
|
+
def _validate_bounding_boxes(self) -> None:
|
|
685
687
|
if self.bounding_boxes_coordinates is not None:
|
|
686
688
|
if not is_list_of(self.bounding_boxes_coordinates, list):
|
|
687
689
|
raise TypeError(
|
|
@@ -690,7 +692,7 @@ class InstanceSegmentationActualLabel(NamedTuple):
|
|
|
690
692
|
for coordinates in self.bounding_boxes_coordinates:
|
|
691
693
|
_validate_bounding_box_coordinates(coordinates)
|
|
692
694
|
|
|
693
|
-
def _validate_count_match(self):
|
|
695
|
+
def _validate_count_match(self) -> None:
|
|
694
696
|
n_polygon_coordinates = len(self.polygon_coordinates)
|
|
695
697
|
if n_polygon_coordinates == 0:
|
|
696
698
|
raise ValueError(
|
|
@@ -717,27 +719,24 @@ class InstanceSegmentationActualLabel(NamedTuple):
|
|
|
717
719
|
|
|
718
720
|
|
|
719
721
|
class MultiClassPredictionLabel(NamedTuple):
|
|
720
|
-
"""
|
|
721
|
-
Used to log multi class prediction label
|
|
722
|
+
"""Used to log multi class prediction label.
|
|
722
723
|
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
prediction_scores (Dict[str, Union[float, int]]): the prediction scores of the classes.
|
|
727
|
-
threshold_scores (Optional[Dict[str, Union[float, int]]]): the threshold scores of the classes.
|
|
724
|
+
Args:
|
|
725
|
+
prediction_scores: The prediction scores of the classes.
|
|
726
|
+
threshold_scores: The threshold scores of the classes.
|
|
728
727
|
Only Multi Label will have threshold scores.
|
|
729
|
-
|
|
730
728
|
"""
|
|
731
729
|
|
|
732
|
-
prediction_scores:
|
|
733
|
-
threshold_scores:
|
|
730
|
+
prediction_scores: dict[str, float | int]
|
|
731
|
+
threshold_scores: dict[str, float | int] | None = None
|
|
734
732
|
|
|
735
|
-
def validate(self):
|
|
733
|
+
def validate(self) -> None:
|
|
734
|
+
"""Validate the field values and constraints."""
|
|
736
735
|
# Validate scores
|
|
737
736
|
self._validate_prediction_scores()
|
|
738
737
|
self._validate_threshold_scores()
|
|
739
738
|
|
|
740
|
-
def _validate_prediction_scores(self):
|
|
739
|
+
def _validate_prediction_scores(self) -> None:
|
|
741
740
|
# prediction dictionary validations
|
|
742
741
|
if not is_dict_of(
|
|
743
742
|
self.prediction_scores,
|
|
@@ -778,7 +777,7 @@ class MultiClassPredictionLabel(NamedTuple):
|
|
|
778
777
|
"invalid. All scores (values in dictionary) must be between 0 and 1, inclusive."
|
|
779
778
|
)
|
|
780
779
|
|
|
781
|
-
def _validate_threshold_scores(self):
|
|
780
|
+
def _validate_threshold_scores(self) -> None:
|
|
782
781
|
if self.threshold_scores is None or len(self.threshold_scores) == 0:
|
|
783
782
|
return
|
|
784
783
|
if not is_dict_of(
|
|
@@ -822,24 +821,21 @@ class MultiClassPredictionLabel(NamedTuple):
|
|
|
822
821
|
|
|
823
822
|
|
|
824
823
|
class MultiClassActualLabel(NamedTuple):
|
|
825
|
-
"""
|
|
826
|
-
Used to log multi class actual label
|
|
827
|
-
|
|
828
|
-
Arguments:
|
|
829
|
-
---------
|
|
830
|
-
MultiClassActualLabel
|
|
831
|
-
actual_scores (Dict[str, Union[float, int]]): the actual scores of the classes.
|
|
832
|
-
Any class in actual_scores with a score of 1 will be sent to arize
|
|
824
|
+
"""Used to log multi class actual label.
|
|
833
825
|
|
|
826
|
+
Args:
|
|
827
|
+
actual_scores: The actual scores of the classes.
|
|
828
|
+
Any class in actual_scores with a score of 1 will be sent to arize.
|
|
834
829
|
"""
|
|
835
830
|
|
|
836
|
-
actual_scores:
|
|
831
|
+
actual_scores: dict[str, float | int]
|
|
837
832
|
|
|
838
|
-
def validate(self):
|
|
833
|
+
def validate(self) -> None:
|
|
834
|
+
"""Validate the field values and constraints."""
|
|
839
835
|
# Validate scores
|
|
840
836
|
self._validate_actual_scores()
|
|
841
837
|
|
|
842
|
-
def _validate_actual_scores(self):
|
|
838
|
+
def _validate_actual_scores(self) -> None:
|
|
843
839
|
if not is_dict_of(
|
|
844
840
|
self.actual_scores,
|
|
845
841
|
key_allowed_types=str,
|
|
@@ -879,12 +875,15 @@ class MultiClassActualLabel(NamedTuple):
|
|
|
879
875
|
|
|
880
876
|
|
|
881
877
|
class RankingPredictionLabel(NamedTuple):
|
|
878
|
+
"""Prediction label for ranking tasks with group and rank information."""
|
|
879
|
+
|
|
882
880
|
group_id: str
|
|
883
881
|
rank: int
|
|
884
882
|
score: float | None = None
|
|
885
883
|
label: str | None = None
|
|
886
884
|
|
|
887
|
-
def validate(self):
|
|
885
|
+
def validate(self) -> None:
|
|
886
|
+
"""Validate the field values and constraints."""
|
|
888
887
|
# Validate existence of required fields: prediction_group_id and rank
|
|
889
888
|
if self.group_id is None or self.rank is None:
|
|
890
889
|
raise ValueError(
|
|
@@ -901,7 +900,7 @@ class RankingPredictionLabel(NamedTuple):
|
|
|
901
900
|
if self.score is not None:
|
|
902
901
|
self._validate_score()
|
|
903
902
|
|
|
904
|
-
def _validate_group_id(self):
|
|
903
|
+
def _validate_group_id(self) -> None:
|
|
905
904
|
if not isinstance(self.group_id, str):
|
|
906
905
|
raise TypeError("Prediction Group ID must be a string")
|
|
907
906
|
if not (1 <= len(self.group_id) <= 36):
|
|
@@ -909,7 +908,7 @@ class RankingPredictionLabel(NamedTuple):
|
|
|
909
908
|
f"Prediction Group ID must have length between 1 and 36. Found {len(self.group_id)}"
|
|
910
909
|
)
|
|
911
910
|
|
|
912
|
-
def _validate_rank(self):
|
|
911
|
+
def _validate_rank(self) -> None:
|
|
913
912
|
if not isinstance(self.rank, int):
|
|
914
913
|
raise TypeError("Prediction Rank must be an int")
|
|
915
914
|
if not (1 <= self.rank <= 100):
|
|
@@ -917,22 +916,25 @@ class RankingPredictionLabel(NamedTuple):
|
|
|
917
916
|
f"Prediction Rank must be between 1 and 100, inclusive. Found {self.rank}"
|
|
918
917
|
)
|
|
919
918
|
|
|
920
|
-
def _validate_label(self):
|
|
919
|
+
def _validate_label(self) -> None:
|
|
921
920
|
if not isinstance(self.label, str):
|
|
922
921
|
raise TypeError("Prediction Label must be a str")
|
|
923
922
|
if self.label == "":
|
|
924
923
|
raise ValueError("Prediction Label must not be an empty string.")
|
|
925
924
|
|
|
926
|
-
def _validate_score(self):
|
|
925
|
+
def _validate_score(self) -> None:
|
|
927
926
|
if not isinstance(self.score, (float, int)):
|
|
928
927
|
raise TypeError("Prediction Score must be a float or an int")
|
|
929
928
|
|
|
930
929
|
|
|
931
930
|
class RankingActualLabel(NamedTuple):
|
|
932
|
-
|
|
931
|
+
"""Actual label for ranking tasks with relevance information."""
|
|
932
|
+
|
|
933
|
+
relevance_labels: list[str] | None = None
|
|
933
934
|
relevance_score: float | None = None
|
|
934
935
|
|
|
935
|
-
def validate(self):
|
|
936
|
+
def validate(self) -> None:
|
|
937
|
+
"""Validate the field values and constraints."""
|
|
936
938
|
# Validate relevance_labels type
|
|
937
939
|
if self.relevance_labels is not None:
|
|
938
940
|
self._validate_relevance_labels(self.relevance_labels)
|
|
@@ -941,7 +943,16 @@ class RankingActualLabel(NamedTuple):
|
|
|
941
943
|
self._validate_relevance_score(self.relevance_score)
|
|
942
944
|
|
|
943
945
|
@staticmethod
|
|
944
|
-
def _validate_relevance_labels(relevance_labels:
|
|
946
|
+
def _validate_relevance_labels(relevance_labels: list[str]) -> None:
|
|
947
|
+
"""Validate relevance labels.
|
|
948
|
+
|
|
949
|
+
Args:
|
|
950
|
+
relevance_labels: List of relevance labels to validate.
|
|
951
|
+
|
|
952
|
+
Raises:
|
|
953
|
+
TypeError: If relevance_labels is not a list of strings.
|
|
954
|
+
ValueError: If any label is an empty string.
|
|
955
|
+
"""
|
|
945
956
|
if not is_list_of(relevance_labels, str):
|
|
946
957
|
raise TypeError("Actual Relevance Labels must be a list of strings")
|
|
947
958
|
if any(label == "" for label in relevance_labels):
|
|
@@ -950,17 +961,28 @@ class RankingActualLabel(NamedTuple):
|
|
|
950
961
|
)
|
|
951
962
|
|
|
952
963
|
@staticmethod
|
|
953
|
-
def _validate_relevance_score(relevance_score: float):
|
|
964
|
+
def _validate_relevance_score(relevance_score: float) -> None:
|
|
965
|
+
"""Validate relevance score.
|
|
966
|
+
|
|
967
|
+
Args:
|
|
968
|
+
relevance_score: Relevance score to validate.
|
|
969
|
+
|
|
970
|
+
Raises:
|
|
971
|
+
TypeError: If relevance_score is not a float or int.
|
|
972
|
+
"""
|
|
954
973
|
if not isinstance(relevance_score, (float, int)):
|
|
955
974
|
raise TypeError("Actual Relevance score must be a float or an int")
|
|
956
975
|
|
|
957
976
|
|
|
958
977
|
@dataclass
|
|
959
978
|
class PromptTemplateColumnNames:
|
|
979
|
+
"""Column names for prompt template configuration in LLM schemas."""
|
|
980
|
+
|
|
960
981
|
template_column_name: str | None = None
|
|
961
982
|
template_version_column_name: str | None = None
|
|
962
983
|
|
|
963
|
-
def __iter__(self):
|
|
984
|
+
def __iter__(self) -> Iterator[str | None]:
|
|
985
|
+
"""Iterate over the prompt template column names."""
|
|
964
986
|
return iter(
|
|
965
987
|
(self.template_column_name, self.template_version_column_name)
|
|
966
988
|
)
|
|
@@ -968,21 +990,27 @@ class PromptTemplateColumnNames:
|
|
|
968
990
|
|
|
969
991
|
@dataclass
|
|
970
992
|
class LLMConfigColumnNames:
|
|
993
|
+
"""Column names for LLM configuration parameters in schemas."""
|
|
994
|
+
|
|
971
995
|
model_column_name: str | None = None
|
|
972
996
|
params_column_name: str | None = None
|
|
973
997
|
|
|
974
|
-
def __iter__(self):
|
|
998
|
+
def __iter__(self) -> Iterator[str | None]:
|
|
999
|
+
"""Iterate over the LLM config column names."""
|
|
975
1000
|
return iter((self.model_column_name, self.params_column_name))
|
|
976
1001
|
|
|
977
1002
|
|
|
978
1003
|
@dataclass
|
|
979
1004
|
class LLMRunMetadataColumnNames:
|
|
1005
|
+
"""Column names for LLM run metadata fields in schemas."""
|
|
1006
|
+
|
|
980
1007
|
total_token_count_column_name: str | None = None
|
|
981
1008
|
prompt_token_count_column_name: str | None = None
|
|
982
1009
|
response_token_count_column_name: str | None = None
|
|
983
1010
|
response_latency_ms_column_name: str | None = None
|
|
984
1011
|
|
|
985
|
-
def __iter__(self):
|
|
1012
|
+
def __iter__(self) -> Iterator[str | None]:
|
|
1013
|
+
"""Iterate over the LLM run metadata column names."""
|
|
986
1014
|
return iter(
|
|
987
1015
|
(
|
|
988
1016
|
self.total_token_count_column_name,
|
|
@@ -1011,11 +1039,19 @@ class LLMRunMetadataColumnNames:
|
|
|
1011
1039
|
#
|
|
1012
1040
|
@dataclass
|
|
1013
1041
|
class SimilarityReference:
|
|
1042
|
+
"""Reference to a prediction for similarity search operations."""
|
|
1043
|
+
|
|
1014
1044
|
prediction_id: str
|
|
1015
1045
|
reference_column_name: str
|
|
1016
1046
|
prediction_timestamp: datetime | None = None
|
|
1017
1047
|
|
|
1018
|
-
def __post_init__(self):
|
|
1048
|
+
def __post_init__(self) -> None:
|
|
1049
|
+
"""Validate similarity reference fields after initialization.
|
|
1050
|
+
|
|
1051
|
+
Raises:
|
|
1052
|
+
ValueError: If prediction_id or reference_column_name is empty.
|
|
1053
|
+
TypeError: If prediction_timestamp is not a datetime object.
|
|
1054
|
+
"""
|
|
1019
1055
|
if self.prediction_id == "":
|
|
1020
1056
|
raise ValueError("prediction id cannot be empty")
|
|
1021
1057
|
if self.reference_column_name == "":
|
|
@@ -1028,11 +1064,20 @@ class SimilarityReference:
|
|
|
1028
1064
|
|
|
1029
1065
|
@dataclass
|
|
1030
1066
|
class SimilaritySearchParams:
|
|
1031
|
-
|
|
1067
|
+
"""Parameters for configuring similarity search operations."""
|
|
1068
|
+
|
|
1069
|
+
references: list[SimilarityReference]
|
|
1032
1070
|
search_column_name: str
|
|
1033
1071
|
threshold: float = 0
|
|
1034
1072
|
|
|
1035
|
-
def __post_init__(self):
|
|
1073
|
+
def __post_init__(self) -> None:
|
|
1074
|
+
"""Validate similarity search parameters after initialization.
|
|
1075
|
+
|
|
1076
|
+
Raises:
|
|
1077
|
+
ValueError: If references list is invalid, search_column_name is
|
|
1078
|
+
empty, or threshold is out of range.
|
|
1079
|
+
TypeError: If any reference is not a SimilarityReference instance.
|
|
1080
|
+
"""
|
|
1036
1081
|
if (
|
|
1037
1082
|
not self.references
|
|
1038
1083
|
or len(self.references) <= 0
|
|
@@ -1054,176 +1099,157 @@ class SimilaritySearchParams:
|
|
|
1054
1099
|
|
|
1055
1100
|
@dataclass(frozen=True)
|
|
1056
1101
|
class BaseSchema:
|
|
1057
|
-
|
|
1102
|
+
"""Base class for all schema definitions with immutable fields."""
|
|
1103
|
+
|
|
1104
|
+
def replace(self, **changes: object) -> Self:
|
|
1105
|
+
"""Return a new instance with specified fields replaced."""
|
|
1058
1106
|
return replace(self, **changes)
|
|
1059
1107
|
|
|
1060
|
-
def asdict(self) ->
|
|
1108
|
+
def asdict(self) -> dict[str, str]:
|
|
1109
|
+
"""Convert the schema to a dictionary."""
|
|
1061
1110
|
return asdict(self)
|
|
1062
1111
|
|
|
1063
|
-
def get_used_columns(self) ->
|
|
1112
|
+
def get_used_columns(self) -> set[str]:
|
|
1113
|
+
"""Return the set of column names used in this schema."""
|
|
1064
1114
|
return set(self.get_used_columns_counts().keys())
|
|
1065
1115
|
|
|
1066
|
-
def get_used_columns_counts(self) ->
|
|
1116
|
+
def get_used_columns_counts(self) -> dict[str, int]:
|
|
1117
|
+
"""Return a dict mapping column names to their usage count."""
|
|
1067
1118
|
raise NotImplementedError()
|
|
1068
1119
|
|
|
1069
1120
|
|
|
1070
1121
|
@dataclass(frozen=True)
|
|
1071
1122
|
class TypedColumns:
|
|
1072
|
-
"""
|
|
1073
|
-
|
|
1074
|
-
|
|
1075
|
-
Usage:
|
|
1076
|
-
------
|
|
1077
|
-
When initializing a Schema, use TypedColumns in place of a list of string column names.
|
|
1078
|
-
e.g. feature_column_names=TypedColumns(
|
|
1079
|
-
inferred=["feature_1", "feature_2"],
|
|
1080
|
-
to_str=["feature_3"],
|
|
1081
|
-
to_int=["feature_4"]
|
|
1082
|
-
)
|
|
1123
|
+
"""Optional class used for explicit type enforcement of feature and tag columns in the dataframe.
|
|
1124
|
+
|
|
1125
|
+
When initializing a Schema, use TypedColumns in place of a list of string column names::
|
|
1083
1126
|
|
|
1084
|
-
|
|
1085
|
-
|
|
1086
|
-
|
|
1087
|
-
|
|
1088
|
-
|
|
1089
|
-
vs. not using TypedColumns at all.
|
|
1090
|
-
to_str (Optional[List[str]]): List of columns that should be cast to pandas StringDType.
|
|
1091
|
-
to_int (Optional[List[str]]): List of columns that should be cast to pandas Int64DType.
|
|
1092
|
-
to_float (Optional[List[str]]): List of columns that should be cast to pandas Float64DType.
|
|
1127
|
+
feature_column_names = TypedColumns(
|
|
1128
|
+
inferred=["feature_1", "feature_2"],
|
|
1129
|
+
to_str=["feature_3"],
|
|
1130
|
+
to_int=["feature_4"],
|
|
1131
|
+
)
|
|
1093
1132
|
|
|
1094
1133
|
Notes:
|
|
1095
|
-
-----
|
|
1096
1134
|
- If a TypedColumns object is included in a Schema, pandas version 1.0.0 or higher is required.
|
|
1097
1135
|
- Pandas StringDType is still considered an experimental field.
|
|
1098
1136
|
- Columns not present in any field will not be captured in the Schema.
|
|
1099
1137
|
- StringDType, Int64DType, and Float64DType are all nullable column types.
|
|
1100
|
-
|
|
1101
|
-
|
|
1138
|
+
Null values will be ingested and represented in Arize as empty values.
|
|
1102
1139
|
"""
|
|
1103
1140
|
|
|
1104
|
-
inferred:
|
|
1105
|
-
to_str:
|
|
1106
|
-
to_int:
|
|
1107
|
-
to_float:
|
|
1141
|
+
inferred: list[str] | None = None
|
|
1142
|
+
to_str: list[str] | None = None
|
|
1143
|
+
to_int: list[str] | None = None
|
|
1144
|
+
to_float: list[str] | None = None
|
|
1108
1145
|
|
|
1109
|
-
def get_all_column_names(self) ->
|
|
1146
|
+
def get_all_column_names(self) -> list[str]:
|
|
1147
|
+
"""Return all column names across all conversion lists."""
|
|
1110
1148
|
return list(chain.from_iterable(filter(None, self.__dict__.values())))
|
|
1111
1149
|
|
|
1112
|
-
def has_duplicate_columns(self) ->
|
|
1150
|
+
def has_duplicate_columns(self) -> tuple[bool, set[str]]:
|
|
1151
|
+
"""Check for duplicate columns and return (has_duplicates, duplicate_set)."""
|
|
1113
1152
|
# True if there are duplicates within a field's list or across fields.
|
|
1114
1153
|
# Return a set of the duplicate column names.
|
|
1115
1154
|
cols = self.get_all_column_names()
|
|
1116
|
-
duplicates =
|
|
1155
|
+
duplicates = {x for x in cols if cols.count(x) > 1}
|
|
1117
1156
|
return len(duplicates) > 0, duplicates
|
|
1118
1157
|
|
|
1119
1158
|
def is_empty(self) -> bool:
|
|
1159
|
+
"""Return True if no columns are configured for conversion."""
|
|
1120
1160
|
return not self.get_all_column_names()
|
|
1121
1161
|
|
|
1122
1162
|
|
|
1123
1163
|
@dataclass(frozen=True)
|
|
1124
1164
|
class Schema(BaseSchema):
|
|
1125
|
-
"""
|
|
1126
|
-
Used to organize and map column names containing model data within your Pandas dataframe to
|
|
1127
|
-
Arize.
|
|
1165
|
+
"""Used to organize and map column names containing model data within your Pandas dataframe to Arize.
|
|
1128
1166
|
|
|
1129
|
-
|
|
1130
|
-
|
|
1131
|
-
prediction_id_column_name (str, optional): Column name for the predictions unique identifier.
|
|
1167
|
+
Args:
|
|
1168
|
+
prediction_id_column_name: Column name for the predictions unique identifier.
|
|
1132
1169
|
Unique IDs are used to match a prediction to delayed actuals or feature importances in Arize.
|
|
1133
1170
|
If prediction ids are not provided, it will default to an empty string "" and, when possible,
|
|
1134
1171
|
Arize will create a random prediction id on the server side. Prediction id must be a string column
|
|
1135
1172
|
with each row indicating a unique prediction event.
|
|
1136
|
-
feature_column_names
|
|
1173
|
+
feature_column_names: Column names for features.
|
|
1137
1174
|
The content of feature columns can be int, float, string. If TypedColumns is used,
|
|
1138
1175
|
the columns will be cast to the provided types prior to logging.
|
|
1139
|
-
tag_column_names
|
|
1176
|
+
tag_column_names: Column names for tags. The content of tag
|
|
1140
1177
|
columns can be int, float, string. If TypedColumns is used,
|
|
1141
1178
|
the columns will be cast to the provided types prior to logging.
|
|
1142
|
-
timestamp_column_name
|
|
1179
|
+
timestamp_column_name: Column name for timestamps. The content of this
|
|
1143
1180
|
column must be int Unix Timestamps in seconds.
|
|
1144
|
-
prediction_label_column_name
|
|
1181
|
+
prediction_label_column_name: Column name for categorical prediction values.
|
|
1145
1182
|
The content of this column must be convertible to string.
|
|
1146
|
-
prediction_score_column_name
|
|
1183
|
+
prediction_score_column_name: Column name for numeric prediction values. The
|
|
1147
1184
|
content of this column must be int/float or list of dictionaries mapping class names to
|
|
1148
1185
|
int/float scores in the case of MULTI_CLASS model types.
|
|
1149
|
-
actual_label_column_name
|
|
1186
|
+
actual_label_column_name: Column name for categorical ground truth values.
|
|
1150
1187
|
The content of this column must be convertible to string.
|
|
1151
|
-
actual_score_column_name
|
|
1188
|
+
actual_score_column_name: Column name for numeric ground truth values. The
|
|
1152
1189
|
content of this column must be int/float or list of dictionaries mapping class names to
|
|
1153
1190
|
int/float scores in the case of MULTI_CLASS model types.
|
|
1154
|
-
shap_values_column_names
|
|
1191
|
+
shap_values_column_names: Dictionary mapping feature column name
|
|
1155
1192
|
and corresponding SHAP feature importance column name. e.g.
|
|
1156
1193
|
{{"feat_A": "feat_A_shap", "feat_B": "feat_B_shap"}}
|
|
1157
|
-
embedding_feature_column_names
|
|
1194
|
+
embedding_feature_column_names: Dictionary
|
|
1158
1195
|
mapping embedding display names to EmbeddingColumnNames objects.
|
|
1159
|
-
prediction_group_id_column_name
|
|
1196
|
+
prediction_group_id_column_name: Column name for ranking groups or lists in
|
|
1160
1197
|
ranking models. The content of this column must be string and is limited to 128 characters.
|
|
1161
|
-
rank_column_name
|
|
1198
|
+
rank_column_name: Column name for rank of each element on the its group or
|
|
1162
1199
|
list. The content of this column must be integer between 1-100.
|
|
1163
|
-
relevance_score_column_name
|
|
1200
|
+
relevance_score_column_name: Column name for ranking model type numeric
|
|
1164
1201
|
ground truth values. The content of this column must be int/float.
|
|
1165
|
-
relevance_labels_column_name
|
|
1202
|
+
relevance_labels_column_name: Column name for ranking model type categorical
|
|
1166
1203
|
ground truth values. The content of this column must be a string.
|
|
1167
|
-
object_detection_prediction_column_names
|
|
1204
|
+
object_detection_prediction_column_names:
|
|
1168
1205
|
ObjectDetectionColumnNames object containing information defining the predicted bounding
|
|
1169
1206
|
boxes' coordinates, categories, and scores.
|
|
1170
|
-
object_detection_actual_column_names
|
|
1207
|
+
object_detection_actual_column_names:
|
|
1171
1208
|
ObjectDetectionColumnNames object containing information defining the actual bounding
|
|
1172
1209
|
boxes' coordinates, categories, and scores.
|
|
1173
|
-
prompt_column_names
|
|
1210
|
+
prompt_column_names: column names for text that is passed
|
|
1174
1211
|
to the GENERATIVE_LLM model. It accepts a string (if sending only a text column) or
|
|
1175
1212
|
EmbeddingColumnNames object containing the embedding vector data (required) and raw text
|
|
1176
1213
|
(optional) for the input text your model acts on.
|
|
1177
|
-
response_column_names
|
|
1214
|
+
response_column_names: column names for text generated by
|
|
1178
1215
|
the GENERATIVE_LLM model. It accepts a string (if sending only a text column) or
|
|
1179
1216
|
EmbeddingColumnNames object containing the embedding vector data (required) and raw text
|
|
1180
1217
|
(optional) for the text your model generates.
|
|
1181
|
-
prompt_template_column_names
|
|
1218
|
+
prompt_template_column_names: PromptTemplateColumnNames object
|
|
1182
1219
|
containing the prompt template and the prompt template version.
|
|
1183
|
-
llm_config_column_names
|
|
1220
|
+
llm_config_column_names: LLMConfigColumnNames object containing
|
|
1184
1221
|
the LLM's model name and its hyper parameters used at inference.
|
|
1185
|
-
llm_run_metadata_column_names
|
|
1222
|
+
llm_run_metadata_column_names: LLMRunMetadataColumnNames
|
|
1186
1223
|
object containing token counts and latency metrics
|
|
1187
|
-
retrieved_document_ids_column_name
|
|
1224
|
+
retrieved_document_ids_column_name: Column name for retrieved document ids.
|
|
1188
1225
|
The content of this column must be lists with entries convertible to strings.
|
|
1189
|
-
multi_class_threshold_scores_column_name
|
|
1226
|
+
multi_class_threshold_scores_column_name:
|
|
1190
1227
|
Column name for dictionary that maps class names to threshold values. The
|
|
1191
1228
|
content of this column must be dictionary of str -> int/float.
|
|
1192
|
-
semantic_segmentation_prediction_column_names
|
|
1229
|
+
semantic_segmentation_prediction_column_names:
|
|
1193
1230
|
SemanticSegmentationColumnNames object containing information defining the predicted
|
|
1194
1231
|
polygon coordinates and categories.
|
|
1195
|
-
semantic_segmentation_actual_column_names
|
|
1232
|
+
semantic_segmentation_actual_column_names:
|
|
1196
1233
|
SemanticSegmentationColumnNames object containing information defining the actual
|
|
1197
1234
|
polygon coordinates and categories.
|
|
1198
|
-
instance_segmentation_prediction_column_names
|
|
1235
|
+
instance_segmentation_prediction_column_names:
|
|
1199
1236
|
InstanceSegmentationPredictionColumnNames object containing information defining the predicted
|
|
1200
1237
|
polygon coordinates, categories, scores, and bounding box coordinates.
|
|
1201
|
-
instance_segmentation_actual_column_names
|
|
1238
|
+
instance_segmentation_actual_column_names:
|
|
1202
1239
|
InstanceSegmentationActualColumnNames object containing information defining the actual
|
|
1203
1240
|
polygon coordinates, categories, scores, and bounding box coordinates.
|
|
1204
|
-
|
|
1205
|
-
Methods:
|
|
1206
|
-
-------
|
|
1207
|
-
replace(**changes):
|
|
1208
|
-
Replaces fields of the schema
|
|
1209
|
-
asdict():
|
|
1210
|
-
Returns the schema as a dictionary. Warning: the types are not maintained, fields are
|
|
1211
|
-
converted to strings.
|
|
1212
|
-
get_used_columns():
|
|
1213
|
-
Returns a set with the unique collection of columns to be used from the dataframe.
|
|
1214
|
-
|
|
1215
1241
|
"""
|
|
1216
1242
|
|
|
1217
1243
|
prediction_id_column_name: str | None = None
|
|
1218
|
-
feature_column_names:
|
|
1219
|
-
tag_column_names:
|
|
1244
|
+
feature_column_names: list[str] | TypedColumns | None = None
|
|
1245
|
+
tag_column_names: list[str] | TypedColumns | None = None
|
|
1220
1246
|
timestamp_column_name: str | None = None
|
|
1221
1247
|
prediction_label_column_name: str | None = None
|
|
1222
1248
|
prediction_score_column_name: str | None = None
|
|
1223
1249
|
actual_label_column_name: str | None = None
|
|
1224
1250
|
actual_score_column_name: str | None = None
|
|
1225
|
-
shap_values_column_names:
|
|
1226
|
-
embedding_feature_column_names:
|
|
1251
|
+
shap_values_column_names: dict[str, str] | None = None
|
|
1252
|
+
embedding_feature_column_names: dict[str, EmbeddingColumnNames] | None = (
|
|
1227
1253
|
None # type:ignore
|
|
1228
1254
|
)
|
|
1229
1255
|
prediction_group_id_column_name: str | None = None
|
|
@@ -1242,7 +1268,7 @@ class Schema(BaseSchema):
|
|
|
1242
1268
|
prompt_template_column_names: PromptTemplateColumnNames | None = None
|
|
1243
1269
|
llm_config_column_names: LLMConfigColumnNames | None = None
|
|
1244
1270
|
llm_run_metadata_column_names: LLMRunMetadataColumnNames | None = None
|
|
1245
|
-
retrieved_document_ids_column_name:
|
|
1271
|
+
retrieved_document_ids_column_name: list[str] | None = None
|
|
1246
1272
|
multi_class_threshold_scores_column_name: str | None = None
|
|
1247
1273
|
semantic_segmentation_prediction_column_names: (
|
|
1248
1274
|
SemanticSegmentationColumnNames | None
|
|
@@ -1257,7 +1283,8 @@ class Schema(BaseSchema):
|
|
|
1257
1283
|
InstanceSegmentationActualColumnNames | None
|
|
1258
1284
|
) = None
|
|
1259
1285
|
|
|
1260
|
-
def get_used_columns_counts(self) ->
|
|
1286
|
+
def get_used_columns_counts(self) -> dict[str, int]:
|
|
1287
|
+
"""Return a dict mapping column names to their usage count."""
|
|
1261
1288
|
columns_used_counts = {}
|
|
1262
1289
|
|
|
1263
1290
|
for field in self.__dataclass_fields__:
|
|
@@ -1364,6 +1391,7 @@ class Schema(BaseSchema):
|
|
|
1364
1391
|
return columns_used_counts
|
|
1365
1392
|
|
|
1366
1393
|
def has_prediction_columns(self) -> bool:
|
|
1394
|
+
"""Return True if prediction columns are configured."""
|
|
1367
1395
|
prediction_cols = (
|
|
1368
1396
|
self.prediction_label_column_name,
|
|
1369
1397
|
self.prediction_score_column_name,
|
|
@@ -1377,6 +1405,7 @@ class Schema(BaseSchema):
|
|
|
1377
1405
|
return any(col is not None for col in prediction_cols)
|
|
1378
1406
|
|
|
1379
1407
|
def has_actual_columns(self) -> bool:
|
|
1408
|
+
"""Return True if actual label columns are configured."""
|
|
1380
1409
|
actual_cols = (
|
|
1381
1410
|
self.actual_label_column_name,
|
|
1382
1411
|
self.actual_score_column_name,
|
|
@@ -1389,13 +1418,16 @@ class Schema(BaseSchema):
|
|
|
1389
1418
|
return any(col is not None for col in actual_cols)
|
|
1390
1419
|
|
|
1391
1420
|
def has_feature_importance_columns(self) -> bool:
|
|
1421
|
+
"""Return True if feature importance columns are configured."""
|
|
1392
1422
|
feature_importance_cols = (self.shap_values_column_names,)
|
|
1393
1423
|
return any(col is not None for col in feature_importance_cols)
|
|
1394
1424
|
|
|
1395
1425
|
def has_typed_columns(self) -> bool:
|
|
1426
|
+
"""Return True if typed columns are configured."""
|
|
1396
1427
|
return any(self.typed_column_fields())
|
|
1397
1428
|
|
|
1398
|
-
def typed_column_fields(self) ->
|
|
1429
|
+
def typed_column_fields(self) -> set[str]:
|
|
1430
|
+
"""Return the set of field names with typed columns."""
|
|
1399
1431
|
return {
|
|
1400
1432
|
field
|
|
1401
1433
|
for field in self.__dataclass_fields__
|
|
@@ -1403,9 +1435,9 @@ class Schema(BaseSchema):
|
|
|
1403
1435
|
}
|
|
1404
1436
|
|
|
1405
1437
|
def is_delayed(self) -> bool:
|
|
1406
|
-
"""
|
|
1407
|
-
|
|
1408
|
-
|
|
1438
|
+
"""Check if the schema has inherently latent information.
|
|
1439
|
+
|
|
1440
|
+
Determines this based on the columns provided by the user.
|
|
1409
1441
|
|
|
1410
1442
|
Returns:
|
|
1411
1443
|
bool: True if the schema is "delayed", i.e., does not possess prediction
|
|
@@ -1418,11 +1450,14 @@ class Schema(BaseSchema):
|
|
|
1418
1450
|
|
|
1419
1451
|
@dataclass(frozen=True)
|
|
1420
1452
|
class CorpusSchema(BaseSchema):
|
|
1453
|
+
"""Schema for corpus data with document identification and content columns."""
|
|
1454
|
+
|
|
1421
1455
|
document_id_column_name: str | None = None
|
|
1422
1456
|
document_version_column_name: str | None = None
|
|
1423
1457
|
document_text_embedding_column_names: EmbeddingColumnNames | None = None
|
|
1424
1458
|
|
|
1425
|
-
def get_used_columns_counts(self) ->
|
|
1459
|
+
def get_used_columns_counts(self) -> dict[str, int]:
|
|
1460
|
+
"""Return a dict mapping column names to their usage count."""
|
|
1426
1461
|
columns_used_counts = {}
|
|
1427
1462
|
|
|
1428
1463
|
if self.document_id_column_name is not None:
|
|
@@ -1459,6 +1494,8 @@ class CorpusSchema(BaseSchema):
|
|
|
1459
1494
|
|
|
1460
1495
|
@unique
|
|
1461
1496
|
class ArizeTypes(Enum):
|
|
1497
|
+
"""Enum representing supported data types in Arize platform."""
|
|
1498
|
+
|
|
1462
1499
|
STR = 0
|
|
1463
1500
|
FLOAT = 1
|
|
1464
1501
|
INT = 2
|
|
@@ -1466,76 +1503,13 @@ class ArizeTypes(Enum):
|
|
|
1466
1503
|
|
|
1467
1504
|
@dataclass(frozen=True)
|
|
1468
1505
|
class TypedValue:
|
|
1506
|
+
"""Container for a value with its associated Arize type."""
|
|
1507
|
+
|
|
1469
1508
|
type: ArizeTypes
|
|
1470
1509
|
value: str | bool | float | int
|
|
1471
1510
|
|
|
1472
1511
|
|
|
1473
|
-
def
|
|
1474
|
-
try:
|
|
1475
|
-
json.loads(s)
|
|
1476
|
-
except ValueError:
|
|
1477
|
-
return False
|
|
1478
|
-
except TypeError:
|
|
1479
|
-
return False
|
|
1480
|
-
return True
|
|
1481
|
-
|
|
1482
|
-
|
|
1483
|
-
T = TypeVar("T", bound=type)
|
|
1484
|
-
|
|
1485
|
-
|
|
1486
|
-
def is_array_of(arr: Sequence[object], tp: T) -> bool:
|
|
1487
|
-
return isinstance(arr, np.ndarray) and all(isinstance(x, tp) for x in arr)
|
|
1488
|
-
|
|
1489
|
-
|
|
1490
|
-
def is_list_of(lst: Sequence[object], tp: T) -> bool:
|
|
1491
|
-
return isinstance(lst, list) and all(isinstance(x, tp) for x in lst)
|
|
1492
|
-
|
|
1493
|
-
|
|
1494
|
-
def is_iterable_of(lst: Sequence[object], tp: T) -> bool:
|
|
1495
|
-
return isinstance(lst, Iterable) and all(isinstance(x, tp) for x in lst)
|
|
1496
|
-
|
|
1497
|
-
|
|
1498
|
-
def is_dict_of(
|
|
1499
|
-
d: Dict[object, object],
|
|
1500
|
-
key_allowed_types: T,
|
|
1501
|
-
value_allowed_types: T = (),
|
|
1502
|
-
value_list_allowed_types: T = (),
|
|
1503
|
-
) -> bool:
|
|
1504
|
-
"""
|
|
1505
|
-
Method to check types are valid for dictionary.
|
|
1506
|
-
|
|
1507
|
-
Arguments:
|
|
1508
|
-
---------
|
|
1509
|
-
d (Dict[object, object]): dictionary itself
|
|
1510
|
-
key_allowed_types (T): all allowed types for keys of dictionary
|
|
1511
|
-
value_allowed_types (T): all allowed types for values of dictionary
|
|
1512
|
-
value_list_allowed_types (T): if value is a list, these are the allowed
|
|
1513
|
-
types for value list
|
|
1514
|
-
|
|
1515
|
-
Returns:
|
|
1516
|
-
-------
|
|
1517
|
-
True if the data types of dictionary match the types specified by the
|
|
1518
|
-
arguments, false otherwise
|
|
1519
|
-
|
|
1520
|
-
"""
|
|
1521
|
-
if value_list_allowed_types and not isinstance(
|
|
1522
|
-
value_list_allowed_types, tuple
|
|
1523
|
-
):
|
|
1524
|
-
value_list_allowed_types = (value_list_allowed_types,)
|
|
1525
|
-
|
|
1526
|
-
return (
|
|
1527
|
-
isinstance(d, dict)
|
|
1528
|
-
and all(isinstance(k, key_allowed_types) for k in d)
|
|
1529
|
-
and all(
|
|
1530
|
-
isinstance(v, value_allowed_types)
|
|
1531
|
-
or any(is_list_of(v, t) for t in value_list_allowed_types)
|
|
1532
|
-
for v in d.values()
|
|
1533
|
-
if value_allowed_types or value_list_allowed_types
|
|
1534
|
-
)
|
|
1535
|
-
)
|
|
1536
|
-
|
|
1537
|
-
|
|
1538
|
-
def _count_characters_raw_data(data: str | List[str]) -> int:
|
|
1512
|
+
def _count_characters_raw_data(data: str | list[str]) -> int:
|
|
1539
1513
|
character_count = 0
|
|
1540
1514
|
if isinstance(data, str):
|
|
1541
1515
|
character_count = len(data)
|
|
@@ -1551,8 +1525,14 @@ def _count_characters_raw_data(data: str | List[str]) -> int:
|
|
|
1551
1525
|
|
|
1552
1526
|
|
|
1553
1527
|
def add_to_column_count_dictionary(
|
|
1554
|
-
column_dictionary:
|
|
1555
|
-
):
|
|
1528
|
+
column_dictionary: dict[str, int], col: str | None
|
|
1529
|
+
) -> None:
|
|
1530
|
+
"""Increment the count for a column name in a dictionary.
|
|
1531
|
+
|
|
1532
|
+
Args:
|
|
1533
|
+
column_dictionary: Dictionary mapping column names to counts.
|
|
1534
|
+
col: The column name to increment, or None to skip.
|
|
1535
|
+
"""
|
|
1556
1536
|
if col:
|
|
1557
1537
|
if col in column_dictionary:
|
|
1558
1538
|
column_dictionary[col] += 1
|
|
@@ -1560,7 +1540,9 @@ def add_to_column_count_dictionary(
|
|
|
1560
1540
|
column_dictionary[col] = 1
|
|
1561
1541
|
|
|
1562
1542
|
|
|
1563
|
-
def _validate_bounding_box_coordinates(
|
|
1543
|
+
def _validate_bounding_box_coordinates(
|
|
1544
|
+
bounding_box_coordinates: list[float],
|
|
1545
|
+
) -> None:
|
|
1564
1546
|
if not is_list_of(bounding_box_coordinates, float):
|
|
1565
1547
|
raise TypeError(
|
|
1566
1548
|
"Each bounding box's coordinates must be a lists of floats"
|
|
@@ -1586,10 +1568,12 @@ def _validate_bounding_box_coordinates(bounding_box_coordinates: List[float]):
|
|
|
1586
1568
|
f"top-left. Found {bounding_box_coordinates}"
|
|
1587
1569
|
)
|
|
1588
1570
|
|
|
1589
|
-
return
|
|
1571
|
+
return
|
|
1590
1572
|
|
|
1591
1573
|
|
|
1592
|
-
def _validate_polygon_coordinates(
|
|
1574
|
+
def _validate_polygon_coordinates(
|
|
1575
|
+
polygon_coordinates: list[list[float]],
|
|
1576
|
+
) -> None:
|
|
1593
1577
|
if not is_list_of(polygon_coordinates, list):
|
|
1594
1578
|
raise TypeError("Polygon coordinates must be a list of lists of floats")
|
|
1595
1579
|
for coordinates in polygon_coordinates:
|
|
@@ -1651,27 +1635,41 @@ def _validate_polygon_coordinates(polygon_coordinates: List[List[float]]):
|
|
|
1651
1635
|
f"{coordinates}"
|
|
1652
1636
|
)
|
|
1653
1637
|
|
|
1654
|
-
return
|
|
1638
|
+
return
|
|
1655
1639
|
|
|
1656
1640
|
|
|
1657
|
-
def segments_intersect(
|
|
1658
|
-
|
|
1659
|
-
|
|
1641
|
+
def segments_intersect(
|
|
1642
|
+
p1: tuple[float, float],
|
|
1643
|
+
p2: tuple[float, float],
|
|
1644
|
+
p3: tuple[float, float],
|
|
1645
|
+
p4: tuple[float, float],
|
|
1646
|
+
) -> bool:
|
|
1647
|
+
"""Check if two line segments intersect.
|
|
1660
1648
|
|
|
1661
1649
|
Args:
|
|
1662
|
-
p1
|
|
1663
|
-
|
|
1650
|
+
p1: First endpoint of the first line segment (x,y)
|
|
1651
|
+
p2: Second endpoint of the first line segment (x,y)
|
|
1652
|
+
p3: First endpoint of the second line segment (x,y)
|
|
1653
|
+
p4: Second endpoint of the second line segment (x,y)
|
|
1664
1654
|
|
|
1665
1655
|
Returns:
|
|
1666
1656
|
True if the line segments intersect, False otherwise
|
|
1667
1657
|
"""
|
|
1668
1658
|
|
|
1669
1659
|
# Function to calculate direction
|
|
1670
|
-
def orientation(
|
|
1660
|
+
def orientation(
|
|
1661
|
+
p: tuple[float, float],
|
|
1662
|
+
q: tuple[float, float],
|
|
1663
|
+
r: tuple[float, float],
|
|
1664
|
+
) -> float:
|
|
1671
1665
|
return (q[1] - p[1]) * (r[0] - q[0]) - (q[0] - p[0]) * (r[1] - q[1])
|
|
1672
1666
|
|
|
1673
1667
|
# Function to check if point q is on segment pr
|
|
1674
|
-
def on_segment(
|
|
1668
|
+
def on_segment(
|
|
1669
|
+
p: tuple[float, float],
|
|
1670
|
+
q: tuple[float, float],
|
|
1671
|
+
r: tuple[float, float],
|
|
1672
|
+
) -> bool:
|
|
1675
1673
|
return (
|
|
1676
1674
|
q[0] <= max(p[0], r[0])
|
|
1677
1675
|
and q[0] >= min(p[0], r[0])
|
|
@@ -1703,17 +1701,20 @@ def segments_intersect(p1, p2, p3, p4):
|
|
|
1703
1701
|
|
|
1704
1702
|
@unique
|
|
1705
1703
|
class StatusCodes(Enum):
|
|
1704
|
+
"""Enum representing status codes for operations and responses."""
|
|
1705
|
+
|
|
1706
1706
|
UNSET = 0
|
|
1707
1707
|
OK = 1
|
|
1708
1708
|
ERROR = 2
|
|
1709
1709
|
|
|
1710
1710
|
@classmethod
|
|
1711
|
-
def list_codes(cls):
|
|
1711
|
+
def list_codes(cls) -> list[str]:
|
|
1712
|
+
"""Return a list of all status code names."""
|
|
1712
1713
|
return [t.name for t in cls]
|
|
1713
1714
|
|
|
1714
1715
|
|
|
1715
|
-
def convert_element(value):
|
|
1716
|
-
"""Converts scalar or array to python native"""
|
|
1716
|
+
def convert_element(value: object) -> object:
|
|
1717
|
+
"""Converts scalar or array to python native."""
|
|
1717
1718
|
val = getattr(value, "tolist", lambda: value)()
|
|
1718
1719
|
# Check if it's a list since elements from pd indices are converted to a
|
|
1719
1720
|
# scalar whereas pd series/dataframe elements are converted to list of 1
|
|
@@ -1734,7 +1735,7 @@ PredictionLabelTypes = (
|
|
|
1734
1735
|
| bool
|
|
1735
1736
|
| int
|
|
1736
1737
|
| float
|
|
1737
|
-
|
|
|
1738
|
+
| tuple[str, float]
|
|
1738
1739
|
| ObjectDetectionLabel
|
|
1739
1740
|
| RankingPredictionLabel
|
|
1740
1741
|
| MultiClassPredictionLabel
|
|
@@ -1745,7 +1746,7 @@ ActualLabelTypes = (
|
|
|
1745
1746
|
| bool
|
|
1746
1747
|
| int
|
|
1747
1748
|
| float
|
|
1748
|
-
|
|
|
1749
|
+
| tuple[str, float]
|
|
1749
1750
|
| ObjectDetectionLabel
|
|
1750
1751
|
| RankingActualLabel
|
|
1751
1752
|
| MultiClassActualLabel
|