arize 8.0.0a22__py3-none-any.whl → 8.0.0b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arize/__init__.py +28 -19
- arize/_exporter/client.py +56 -37
- arize/_exporter/parsers/tracing_data_parser.py +41 -30
- arize/_exporter/validation.py +3 -3
- arize/_flight/client.py +207 -76
- arize/_generated/api_client/__init__.py +30 -6
- arize/_generated/api_client/api/__init__.py +1 -0
- arize/_generated/api_client/api/datasets_api.py +864 -190
- arize/_generated/api_client/api/experiments_api.py +167 -131
- arize/_generated/api_client/api/projects_api.py +1197 -0
- arize/_generated/api_client/api_client.py +2 -2
- arize/_generated/api_client/configuration.py +42 -34
- arize/_generated/api_client/exceptions.py +2 -2
- arize/_generated/api_client/models/__init__.py +15 -4
- arize/_generated/api_client/models/dataset.py +10 -10
- arize/_generated/api_client/models/dataset_example.py +111 -0
- arize/_generated/api_client/models/dataset_example_update.py +100 -0
- arize/_generated/api_client/models/dataset_version.py +13 -13
- arize/_generated/api_client/models/datasets_create_request.py +16 -8
- arize/_generated/api_client/models/datasets_examples_insert_request.py +100 -0
- arize/_generated/api_client/models/datasets_examples_list200_response.py +106 -0
- arize/_generated/api_client/models/datasets_examples_update_request.py +102 -0
- arize/_generated/api_client/models/datasets_list200_response.py +10 -4
- arize/_generated/api_client/models/experiment.py +14 -16
- arize/_generated/api_client/models/experiment_run.py +108 -0
- arize/_generated/api_client/models/experiment_run_create.py +102 -0
- arize/_generated/api_client/models/experiments_create_request.py +16 -10
- arize/_generated/api_client/models/experiments_list200_response.py +10 -4
- arize/_generated/api_client/models/experiments_runs_list200_response.py +19 -5
- arize/_generated/api_client/models/{error.py → pagination_metadata.py} +13 -11
- arize/_generated/api_client/models/primitive_value.py +172 -0
- arize/_generated/api_client/models/problem.py +100 -0
- arize/_generated/api_client/models/project.py +99 -0
- arize/_generated/api_client/models/{datasets_list_examples200_response.py → projects_create_request.py} +13 -11
- arize/_generated/api_client/models/projects_list200_response.py +106 -0
- arize/_generated/api_client/rest.py +2 -2
- arize/_generated/api_client/test/test_dataset.py +4 -2
- arize/_generated/api_client/test/test_dataset_example.py +56 -0
- arize/_generated/api_client/test/test_dataset_example_update.py +52 -0
- arize/_generated/api_client/test/test_dataset_version.py +7 -2
- arize/_generated/api_client/test/test_datasets_api.py +27 -13
- arize/_generated/api_client/test/test_datasets_create_request.py +8 -4
- arize/_generated/api_client/test/{test_datasets_list_examples200_response.py → test_datasets_examples_insert_request.py} +19 -15
- arize/_generated/api_client/test/test_datasets_examples_list200_response.py +66 -0
- arize/_generated/api_client/test/test_datasets_examples_update_request.py +61 -0
- arize/_generated/api_client/test/test_datasets_list200_response.py +9 -3
- arize/_generated/api_client/test/test_experiment.py +2 -4
- arize/_generated/api_client/test/test_experiment_run.py +56 -0
- arize/_generated/api_client/test/test_experiment_run_create.py +54 -0
- arize/_generated/api_client/test/test_experiments_api.py +6 -6
- arize/_generated/api_client/test/test_experiments_create_request.py +9 -6
- arize/_generated/api_client/test/test_experiments_list200_response.py +9 -5
- arize/_generated/api_client/test/test_experiments_runs_list200_response.py +15 -5
- arize/_generated/api_client/test/test_pagination_metadata.py +53 -0
- arize/_generated/api_client/test/{test_error.py → test_primitive_value.py} +13 -14
- arize/_generated/api_client/test/test_problem.py +57 -0
- arize/_generated/api_client/test/test_project.py +58 -0
- arize/_generated/api_client/test/test_projects_api.py +59 -0
- arize/_generated/api_client/test/test_projects_create_request.py +54 -0
- arize/_generated/api_client/test/test_projects_list200_response.py +70 -0
- arize/_generated/api_client_README.md +43 -29
- arize/_generated/protocol/flight/flight_pb2.py +400 -0
- arize/_lazy.py +27 -19
- arize/client.py +181 -58
- arize/config.py +324 -116
- arize/constants/__init__.py +1 -0
- arize/constants/config.py +11 -4
- arize/constants/ml.py +6 -4
- arize/constants/openinference.py +2 -0
- arize/constants/pyarrow.py +2 -0
- arize/constants/spans.py +3 -1
- arize/datasets/__init__.py +1 -0
- arize/datasets/client.py +304 -84
- arize/datasets/errors.py +32 -2
- arize/datasets/validation.py +18 -8
- arize/embeddings/__init__.py +2 -0
- arize/embeddings/auto_generator.py +23 -19
- arize/embeddings/base_generators.py +89 -36
- arize/embeddings/constants.py +2 -0
- arize/embeddings/cv_generators.py +26 -4
- arize/embeddings/errors.py +27 -5
- arize/embeddings/nlp_generators.py +43 -18
- arize/embeddings/tabular_generators.py +46 -31
- arize/embeddings/usecases.py +12 -2
- arize/exceptions/__init__.py +1 -0
- arize/exceptions/auth.py +11 -1
- arize/exceptions/base.py +29 -4
- arize/exceptions/models.py +21 -2
- arize/exceptions/parameters.py +31 -0
- arize/exceptions/spaces.py +12 -1
- arize/exceptions/types.py +86 -7
- arize/exceptions/values.py +220 -20
- arize/experiments/__init__.py +13 -0
- arize/experiments/client.py +394 -285
- arize/experiments/evaluators/__init__.py +1 -0
- arize/experiments/evaluators/base.py +74 -41
- arize/experiments/evaluators/exceptions.py +6 -3
- arize/experiments/evaluators/executors.py +121 -73
- arize/experiments/evaluators/rate_limiters.py +106 -57
- arize/experiments/evaluators/types.py +34 -7
- arize/experiments/evaluators/utils.py +65 -27
- arize/experiments/functions.py +103 -101
- arize/experiments/tracing.py +52 -44
- arize/experiments/types.py +56 -31
- arize/logging.py +54 -22
- arize/ml/__init__.py +1 -0
- arize/ml/batch_validation/__init__.py +1 -0
- arize/{models → ml}/batch_validation/errors.py +545 -67
- arize/{models → ml}/batch_validation/validator.py +344 -303
- arize/ml/bounded_executor.py +47 -0
- arize/{models → ml}/casting.py +118 -108
- arize/{models → ml}/client.py +339 -118
- arize/{models → ml}/proto.py +97 -42
- arize/{models → ml}/stream_validation.py +43 -15
- arize/ml/surrogate_explainer/__init__.py +1 -0
- arize/{models → ml}/surrogate_explainer/mimic.py +25 -10
- arize/{types.py → ml/types.py} +355 -354
- arize/pre_releases.py +44 -0
- arize/projects/__init__.py +1 -0
- arize/projects/client.py +134 -0
- arize/regions.py +40 -0
- arize/spans/__init__.py +1 -0
- arize/spans/client.py +204 -175
- arize/spans/columns.py +13 -0
- arize/spans/conversion.py +60 -37
- arize/spans/validation/__init__.py +1 -0
- arize/spans/validation/annotations/__init__.py +1 -0
- arize/spans/validation/annotations/annotations_validation.py +6 -4
- arize/spans/validation/annotations/dataframe_form_validation.py +13 -11
- arize/spans/validation/annotations/value_validation.py +35 -11
- arize/spans/validation/common/__init__.py +1 -0
- arize/spans/validation/common/argument_validation.py +33 -8
- arize/spans/validation/common/dataframe_form_validation.py +35 -9
- arize/spans/validation/common/errors.py +211 -11
- arize/spans/validation/common/value_validation.py +81 -14
- arize/spans/validation/evals/__init__.py +1 -0
- arize/spans/validation/evals/dataframe_form_validation.py +28 -8
- arize/spans/validation/evals/evals_validation.py +34 -4
- arize/spans/validation/evals/value_validation.py +26 -3
- arize/spans/validation/metadata/__init__.py +1 -1
- arize/spans/validation/metadata/argument_validation.py +14 -5
- arize/spans/validation/metadata/dataframe_form_validation.py +26 -10
- arize/spans/validation/metadata/value_validation.py +24 -10
- arize/spans/validation/spans/__init__.py +1 -0
- arize/spans/validation/spans/dataframe_form_validation.py +35 -14
- arize/spans/validation/spans/spans_validation.py +35 -4
- arize/spans/validation/spans/value_validation.py +78 -8
- arize/utils/__init__.py +1 -0
- arize/utils/arrow.py +31 -15
- arize/utils/cache.py +34 -6
- arize/utils/dataframe.py +20 -3
- arize/utils/online_tasks/__init__.py +2 -0
- arize/utils/online_tasks/dataframe_preprocessor.py +58 -47
- arize/utils/openinference_conversion.py +44 -5
- arize/utils/proto.py +10 -0
- arize/utils/size.py +5 -3
- arize/utils/types.py +105 -0
- arize/version.py +3 -1
- {arize-8.0.0a22.dist-info → arize-8.0.0b0.dist-info}/METADATA +13 -6
- arize-8.0.0b0.dist-info/RECORD +175 -0
- {arize-8.0.0a22.dist-info → arize-8.0.0b0.dist-info}/WHEEL +1 -1
- arize-8.0.0b0.dist-info/licenses/LICENSE +176 -0
- arize-8.0.0b0.dist-info/licenses/NOTICE +13 -0
- arize/_generated/protocol/flight/export_pb2.py +0 -61
- arize/_generated/protocol/flight/ingest_pb2.py +0 -365
- arize/models/__init__.py +0 -0
- arize/models/batch_validation/__init__.py +0 -0
- arize/models/bounded_executor.py +0 -34
- arize/models/surrogate_explainer/__init__.py +0 -0
- arize-8.0.0a22.dist-info/RECORD +0 -146
- arize-8.0.0a22.dist-info/licenses/LICENSE.md +0 -12
arize/{models → ml}/client.py
RENAMED
|
@@ -1,11 +1,14 @@
|
|
|
1
|
+
"""Client implementation for managing ML models in the Arize platform."""
|
|
2
|
+
|
|
1
3
|
# type: ignore[pb2]
|
|
2
4
|
from __future__ import annotations
|
|
3
5
|
|
|
4
6
|
import copy
|
|
5
7
|
import logging
|
|
6
8
|
import time
|
|
7
|
-
from typing import TYPE_CHECKING
|
|
9
|
+
from typing import TYPE_CHECKING
|
|
8
10
|
|
|
11
|
+
from arize._generated.protocol.rec import public_pb2 as pb2
|
|
9
12
|
from arize._lazy import require
|
|
10
13
|
from arize.constants.ml import (
|
|
11
14
|
LLM_RUN_METADATA_PROMPT_TOKEN_COUNT_TAG_NAME,
|
|
@@ -30,19 +33,20 @@ from arize.exceptions.parameters import (
|
|
|
30
33
|
)
|
|
31
34
|
from arize.exceptions.spaces import MissingSpaceIDError
|
|
32
35
|
from arize.logging import get_truncation_warning_message
|
|
33
|
-
from arize.
|
|
34
|
-
from arize.
|
|
35
|
-
from arize.
|
|
36
|
+
from arize.ml.bounded_executor import BoundedExecutor
|
|
37
|
+
from arize.ml.casting import cast_dictionary, cast_typed_columns
|
|
38
|
+
from arize.ml.stream_validation import (
|
|
36
39
|
validate_and_convert_prediction_id,
|
|
37
40
|
validate_label,
|
|
38
41
|
)
|
|
39
|
-
from arize.types import (
|
|
42
|
+
from arize.ml.types import (
|
|
40
43
|
CATEGORICAL_MODEL_TYPES,
|
|
41
44
|
NUMERIC_MODEL_TYPES,
|
|
42
45
|
ActualLabelTypes,
|
|
43
46
|
BaseSchema,
|
|
44
47
|
CorpusSchema,
|
|
45
48
|
Embedding,
|
|
49
|
+
EmbeddingColumnNames,
|
|
46
50
|
Environments,
|
|
47
51
|
LLMRunMetadata,
|
|
48
52
|
Metrics,
|
|
@@ -53,8 +57,8 @@ from arize.types import (
|
|
|
53
57
|
SimilaritySearchParams,
|
|
54
58
|
TypedValue,
|
|
55
59
|
convert_element,
|
|
56
|
-
is_list_of,
|
|
57
60
|
)
|
|
61
|
+
from arize.utils.types import is_list_of
|
|
58
62
|
|
|
59
63
|
if TYPE_CHECKING:
|
|
60
64
|
import concurrent.futures as cf
|
|
@@ -64,12 +68,7 @@ if TYPE_CHECKING:
|
|
|
64
68
|
import requests
|
|
65
69
|
from requests_futures.sessions import FuturesSession
|
|
66
70
|
|
|
67
|
-
from arize._generated.protocol.rec import public_pb2 as pb2
|
|
68
71
|
from arize.config import SDKConfiguration
|
|
69
|
-
from arize.types import (
|
|
70
|
-
EmbeddingColumnNames,
|
|
71
|
-
Schema,
|
|
72
|
-
)
|
|
73
72
|
|
|
74
73
|
|
|
75
74
|
logger = logging.getLogger(__name__)
|
|
@@ -96,7 +95,18 @@ _MIMIC_EXTRA = "mimic-explainer"
|
|
|
96
95
|
|
|
97
96
|
|
|
98
97
|
class MLModelsClient:
|
|
99
|
-
|
|
98
|
+
"""Client for logging ML model predictions and actuals to Arize.
|
|
99
|
+
|
|
100
|
+
This class is primarily intended for internal use within the SDK. Users are
|
|
101
|
+
highly encouraged to access resource-specific functionality via
|
|
102
|
+
:class:`arize.ArizeClient`.
|
|
103
|
+
"""
|
|
104
|
+
|
|
105
|
+
def __init__(self, *, sdk_config: SDKConfiguration) -> None:
|
|
106
|
+
"""
|
|
107
|
+
Args:
|
|
108
|
+
sdk_config: Resolved SDK configuration.
|
|
109
|
+
""" # noqa: D205, D212
|
|
100
110
|
self._sdk_config = sdk_config
|
|
101
111
|
|
|
102
112
|
# internal cache for the futures session
|
|
@@ -114,24 +124,89 @@ class MLModelsClient:
|
|
|
114
124
|
prediction_timestamp: int | None = None,
|
|
115
125
|
prediction_label: PredictionLabelTypes | None = None,
|
|
116
126
|
actual_label: ActualLabelTypes | None = None,
|
|
117
|
-
features:
|
|
127
|
+
features: dict[str, str | bool | float | int | list[str] | TypedValue]
|
|
118
128
|
| None = None,
|
|
119
|
-
embedding_features:
|
|
120
|
-
shap_values:
|
|
121
|
-
tags:
|
|
129
|
+
embedding_features: dict[str, Embedding] | None = None,
|
|
130
|
+
shap_values: dict[str, float] | None = None,
|
|
131
|
+
tags: dict[str, str | bool | float | int | TypedValue] | None = None,
|
|
122
132
|
batch_id: str | None = None,
|
|
123
133
|
prompt: str | Embedding | None = None,
|
|
124
134
|
response: str | Embedding | None = None,
|
|
125
135
|
prompt_template: str | None = None,
|
|
126
136
|
prompt_template_version: str | None = None,
|
|
127
137
|
llm_model_name: str | None = None,
|
|
128
|
-
llm_params:
|
|
138
|
+
llm_params: dict[str, str | bool | float | int] | None = None,
|
|
129
139
|
llm_run_metadata: LLMRunMetadata | None = None,
|
|
130
140
|
timeout: float | None = None,
|
|
131
141
|
) -> cf.Future:
|
|
142
|
+
"""Log a single model prediction or actual to Arize asynchronously.
|
|
143
|
+
|
|
144
|
+
This method sends a single prediction, actual, or both to Arize for ML monitoring.
|
|
145
|
+
The request is made asynchronously and returns a Future that can be used to check
|
|
146
|
+
the status or retrieve the response.
|
|
147
|
+
|
|
148
|
+
Args:
|
|
149
|
+
space_id: The space ID where the model resides.
|
|
150
|
+
model_name: A unique name to identify your model in the Arize platform.
|
|
151
|
+
model_type: The type of model. Supported types: BINARY, MULTI_CLASS, REGRESSION,
|
|
152
|
+
RANKING, OBJECT_DETECTION. Note: GENERATIVE_LLM is not supported; use the
|
|
153
|
+
spans module instead.
|
|
154
|
+
environment: The environment this data belongs to (PRODUCTION, TRAINING, or
|
|
155
|
+
VALIDATION).
|
|
156
|
+
model_version: Optional version identifier for the model.
|
|
157
|
+
prediction_id: Unique identifier for this prediction. If not provided, one
|
|
158
|
+
will be auto-generated for PRODUCTION environment.
|
|
159
|
+
prediction_timestamp: Unix timestamp (seconds) for when the prediction was made.
|
|
160
|
+
If not provided, the current time is used. Must be within 1 year in the
|
|
161
|
+
future and 2 years in the past from the current time.
|
|
162
|
+
prediction_label: The prediction output from your model. Type depends on
|
|
163
|
+
model_type (e.g., string for categorical, float for numeric).
|
|
164
|
+
actual_label: The ground truth label. Type depends on model_type.
|
|
165
|
+
features: Dictionary of feature name to feature value. Values can be str, bool,
|
|
166
|
+
float, int, list[str], or TypedValue.
|
|
167
|
+
embedding_features: Dictionary of embedding feature name to Embedding object.
|
|
168
|
+
Maximum 50 embeddings per record. Object detection models support only 1.
|
|
169
|
+
shap_values: Dictionary of feature name to SHAP value (float) for feature
|
|
170
|
+
importance/explainability.
|
|
171
|
+
tags: Dictionary of metadata tags. Tag names cannot end with "_shap" or be
|
|
172
|
+
reserved names. Values must be under 1000 characters (warning at 100).
|
|
173
|
+
batch_id: Required for VALIDATION environment; identifies the validation batch.
|
|
174
|
+
prompt: For generative models, the prompt text or embedding sent to the model.
|
|
175
|
+
response: For generative models, the response text or embedding from the model.
|
|
176
|
+
prompt_template: Template used to generate the prompt.
|
|
177
|
+
prompt_template_version: Version identifier for the prompt template.
|
|
178
|
+
llm_model_name: Name of the LLM model used (e.g., "gpt-4").
|
|
179
|
+
llm_params: Dictionary of LLM configuration parameters (e.g., temperature,
|
|
180
|
+
max_tokens).
|
|
181
|
+
llm_run_metadata: Metadata about the LLM run including token counts and latency.
|
|
182
|
+
timeout: Maximum time (in seconds) to wait for the request to complete.
|
|
183
|
+
|
|
184
|
+
Returns:
|
|
185
|
+
A concurrent.futures.Future object representing the async request. Call
|
|
186
|
+
.result() to block and retrieve the Response object, or check .done() for
|
|
187
|
+
completion status.
|
|
188
|
+
|
|
189
|
+
Raises:
|
|
190
|
+
ValueError: If model_type is GENERATIVE_LLM, or if validation environment is
|
|
191
|
+
missing batch_id, or if training/validation environment is missing
|
|
192
|
+
prediction or actual, or if timestamp is out of range, or if no data
|
|
193
|
+
is provided (must have prediction_label, actual_label, tags, or shap_values),
|
|
194
|
+
or if tag names end with "_shap" or exceed length limits.
|
|
195
|
+
MissingSpaceIDError: If space_id is not provided or empty.
|
|
196
|
+
MissingModelNameError: If model_name is not provided or empty.
|
|
197
|
+
InvalidValueType: If features, tags, or other parameters have incorrect types.
|
|
198
|
+
InvalidNumberOfEmbeddings: If more than 50 embedding features are provided.
|
|
199
|
+
KeyError: If tag names include reserved names.
|
|
200
|
+
|
|
201
|
+
Notes:
|
|
202
|
+
- Timestamps must be within 1 year future and 2 years past from current time
|
|
203
|
+
- Tag values are truncated at 1000 characters, with warnings at 100 characters
|
|
204
|
+
- For GENERATIVE_LLM models, use the spans module or OTEL tracing instead
|
|
205
|
+
- The Future returned can be monitored for request status asynchronously
|
|
206
|
+
"""
|
|
132
207
|
require(_STREAM_EXTRA, _STREAM_DEPS)
|
|
133
208
|
from arize._generated.protocol.rec import public_pb2 as pb2
|
|
134
|
-
from arize.
|
|
209
|
+
from arize.ml.proto import (
|
|
135
210
|
get_pb_dictionary,
|
|
136
211
|
get_pb_label,
|
|
137
212
|
get_pb_timestamp,
|
|
@@ -179,16 +254,15 @@ class MLModelsClient:
|
|
|
179
254
|
_validate_mapping_key(feat_name, "features")
|
|
180
255
|
if is_list_of(feat_value, str):
|
|
181
256
|
continue
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
)
|
|
257
|
+
val = convert_element(feat_value)
|
|
258
|
+
if val is not None and not isinstance(
|
|
259
|
+
val, (str, bool, float, int)
|
|
260
|
+
):
|
|
261
|
+
raise InvalidValueType(
|
|
262
|
+
f"feature '{feat_name}'",
|
|
263
|
+
feat_value,
|
|
264
|
+
"one of: bool, int, float, str",
|
|
265
|
+
)
|
|
192
266
|
|
|
193
267
|
# Validate embedding_features type
|
|
194
268
|
if embedding_features:
|
|
@@ -247,7 +321,7 @@ class MLModelsClient:
|
|
|
247
321
|
f"{MAX_TAG_LENGTH}. The tag {tag_name} with value {tag_value} has "
|
|
248
322
|
f"{len(str(val))} characters."
|
|
249
323
|
)
|
|
250
|
-
|
|
324
|
+
if len(str(val)) > MAX_TAG_LENGTH_TRUNCATION:
|
|
251
325
|
logger.warning(
|
|
252
326
|
get_truncation_warning_message(
|
|
253
327
|
"tags", MAX_TAG_LENGTH_TRUNCATION
|
|
@@ -304,9 +378,7 @@ class MLModelsClient:
|
|
|
304
378
|
if embedding_features or prompt or response:
|
|
305
379
|
# NOTE: Deep copy is necessary to avoid side effects on the original input dictionary
|
|
306
380
|
combined_embedding_features = (
|
|
307
|
-
|
|
308
|
-
if embedding_features
|
|
309
|
-
else {}
|
|
381
|
+
embedding_features.copy() if embedding_features else {}
|
|
310
382
|
)
|
|
311
383
|
# Map prompt as embedding features for generative models
|
|
312
384
|
if prompt is not None:
|
|
@@ -453,8 +525,7 @@ class MLModelsClient:
|
|
|
453
525
|
indexes=None,
|
|
454
526
|
)
|
|
455
527
|
|
|
456
|
-
|
|
457
|
-
def log_batch(
|
|
528
|
+
def log(
|
|
458
529
|
self,
|
|
459
530
|
*,
|
|
460
531
|
space_id: str,
|
|
@@ -466,17 +537,69 @@ class MLModelsClient:
|
|
|
466
537
|
model_version: str = "",
|
|
467
538
|
batch_id: str = "",
|
|
468
539
|
validate: bool = True,
|
|
469
|
-
metrics_validation:
|
|
540
|
+
metrics_validation: list[Metrics] | None = None,
|
|
470
541
|
surrogate_explainability: bool = False,
|
|
471
542
|
timeout: float | None = None,
|
|
472
543
|
tmp_dir: str = "",
|
|
473
|
-
sync: bool = False,
|
|
474
544
|
) -> requests.Response:
|
|
545
|
+
"""Log a batch of model predictions and actuals to Arize from a pandas DataFrame.
|
|
546
|
+
|
|
547
|
+
This method uploads multiple records to Arize in a single batch operation using
|
|
548
|
+
Apache Arrow format for efficient transfer. The dataframe structure is defined
|
|
549
|
+
by the provided schema which maps dataframe columns to Arize data fields.
|
|
550
|
+
|
|
551
|
+
Args:
|
|
552
|
+
space_id: The space ID where the model resides.
|
|
553
|
+
model_name: A unique name to identify your model in the Arize platform.
|
|
554
|
+
model_type: The type of model. Supported types: BINARY, MULTI_CLASS, REGRESSION,
|
|
555
|
+
RANKING, OBJECT_DETECTION. Note: GENERATIVE_LLM is not supported; use the
|
|
556
|
+
spans module instead.
|
|
557
|
+
dataframe: Pandas DataFrame containing the data to upload. Columns should
|
|
558
|
+
correspond to the schema field mappings.
|
|
559
|
+
schema: Schema object (Schema or CorpusSchema) that defines the mapping between
|
|
560
|
+
dataframe columns and Arize data fields (e.g., prediction_label_column_name,
|
|
561
|
+
feature_column_names, etc.).
|
|
562
|
+
environment: The environment this data belongs to (PRODUCTION, TRAINING,
|
|
563
|
+
VALIDATION, or CORPUS).
|
|
564
|
+
model_version: Optional version identifier for the model.
|
|
565
|
+
batch_id: Required for VALIDATION environment; identifies the validation batch.
|
|
566
|
+
validate: When True, performs comprehensive validation before sending data.
|
|
567
|
+
Includes checks for required fields, data types, and value constraints.
|
|
568
|
+
metrics_validation: Optional list of metric families to validate against.
|
|
569
|
+
surrogate_explainability: When True, automatically generates SHAP values using
|
|
570
|
+
MIMIC surrogate explainer. Requires the 'mimic-explainer' extra. Has no
|
|
571
|
+
effect if shap_values_column_names is already specified in schema.
|
|
572
|
+
timeout: Maximum time (in seconds) to wait for the request to complete.
|
|
573
|
+
tmp_dir: Optional temporary directory to store serialized Arrow data before
|
|
574
|
+
upload.
|
|
575
|
+
|
|
576
|
+
Returns:
|
|
577
|
+
A requests.Response object from the upload request. Check .status_code for
|
|
578
|
+
success (200) or error conditions.
|
|
579
|
+
|
|
580
|
+
Raises:
|
|
581
|
+
MissingSpaceIDError: If space_id is not provided or empty.
|
|
582
|
+
MissingModelNameError: If model_name is not provided or empty.
|
|
583
|
+
ValueError: If model_type is GENERATIVE_LLM, or if environment is CORPUS with
|
|
584
|
+
non-CorpusSchema, or if training/validation records are incomplete.
|
|
585
|
+
ValidationFailure: If validate=True and validation checks fail. Contains list
|
|
586
|
+
of validation error messages.
|
|
587
|
+
pa.ArrowInvalid: If the dataframe cannot be converted to Arrow format, typically
|
|
588
|
+
due to mixed types in columns not specified in the schema.
|
|
589
|
+
|
|
590
|
+
Notes:
|
|
591
|
+
- Categorical dtype columns are automatically converted to string
|
|
592
|
+
- Extraneous columns not in the schema are removed before upload
|
|
593
|
+
- Surrogate explainability requires 'mimic-explainer' extra
|
|
594
|
+
- For GENERATIVE_LLM models, use the spans module or OTEL tracing instead
|
|
595
|
+
- If logging actuals without predictions, ensure predictions were logged first
|
|
596
|
+
- Data is sent via Apache Arrow for efficient large batch transfers
|
|
597
|
+
"""
|
|
475
598
|
require(_BATCH_EXTRA, _BATCH_DEPS)
|
|
476
599
|
import pandas.api.types as ptypes
|
|
477
600
|
import pyarrow as pa
|
|
478
601
|
|
|
479
|
-
from arize.
|
|
602
|
+
from arize.ml.batch_validation.validator import Validator
|
|
480
603
|
from arize.utils.arrow import post_arrow_table
|
|
481
604
|
from arize.utils.dataframe import remove_extraneous_columns
|
|
482
605
|
|
|
@@ -506,8 +629,8 @@ class MLModelsClient:
|
|
|
506
629
|
# Thus we can only offer this functionality with pandas>=1.0.0.
|
|
507
630
|
try:
|
|
508
631
|
dataframe, schema = cast_typed_columns(dataframe, schema)
|
|
509
|
-
except Exception
|
|
510
|
-
logger.
|
|
632
|
+
except Exception:
|
|
633
|
+
logger.exception("Error casting typed columns")
|
|
511
634
|
raise
|
|
512
635
|
|
|
513
636
|
logger.debug("Performing required validation.")
|
|
@@ -546,7 +669,7 @@ class MLModelsClient:
|
|
|
546
669
|
|
|
547
670
|
# always validate pd.Category is not present, if yes, convert to string
|
|
548
671
|
has_cat_col = any(
|
|
549
|
-
|
|
672
|
+
ptypes.is_categorical_dtype(x) for x in dataframe.dtypes
|
|
550
673
|
)
|
|
551
674
|
if has_cat_col:
|
|
552
675
|
cat_cols = [
|
|
@@ -554,12 +677,18 @@ class MLModelsClient:
|
|
|
554
677
|
for col_name, col_cat in dataframe.dtypes.items()
|
|
555
678
|
if col_cat.name == "category"
|
|
556
679
|
]
|
|
557
|
-
cat_str_map = dict(
|
|
680
|
+
cat_str_map = dict(
|
|
681
|
+
zip(
|
|
682
|
+
cat_cols,
|
|
683
|
+
["str"] * len(cat_cols),
|
|
684
|
+
strict=True,
|
|
685
|
+
)
|
|
686
|
+
)
|
|
558
687
|
dataframe = dataframe.astype(cat_str_map)
|
|
559
688
|
|
|
560
689
|
if surrogate_explainability:
|
|
561
690
|
require(_MIMIC_EXTRA, _MIMIC_DEPS)
|
|
562
|
-
from arize.
|
|
691
|
+
from arize.ml.surrogate_explainer.mimic import Mimic
|
|
563
692
|
|
|
564
693
|
logger.debug("Running surrogate_explainability.")
|
|
565
694
|
if schema.shap_values_column_names:
|
|
@@ -588,12 +717,12 @@ class MLModelsClient:
|
|
|
588
717
|
# error conditions that we're currently not aware of.
|
|
589
718
|
pa_table = pa.Table.from_pandas(dataframe, preserve_index=False)
|
|
590
719
|
except pa.ArrowInvalid as e:
|
|
591
|
-
logger.
|
|
720
|
+
logger.exception(INVALID_ARROW_CONVERSION_MSG)
|
|
592
721
|
raise pa.ArrowInvalid(
|
|
593
|
-
f"Error converting to Arrow format: {
|
|
722
|
+
f"Error converting to Arrow format: {e!s}"
|
|
594
723
|
) from e
|
|
595
|
-
except Exception
|
|
596
|
-
logger.
|
|
724
|
+
except Exception:
|
|
725
|
+
logger.exception("Unexpected error creating Arrow table")
|
|
597
726
|
raise
|
|
598
727
|
|
|
599
728
|
if validate:
|
|
@@ -678,18 +807,53 @@ class MLModelsClient:
|
|
|
678
807
|
model_version: str = "",
|
|
679
808
|
batch_id: str = "",
|
|
680
809
|
where: str = "",
|
|
681
|
-
columns:
|
|
810
|
+
columns: list | None = None,
|
|
682
811
|
similarity_search_params: SimilaritySearchParams | None = None,
|
|
683
812
|
stream_chunk_size: int | None = None,
|
|
684
813
|
) -> pd.DataFrame:
|
|
814
|
+
"""Export model data from Arize to a pandas DataFrame.
|
|
815
|
+
|
|
816
|
+
Retrieves prediction and optional actual data for a model within a specified time
|
|
817
|
+
range and returns it as a pandas DataFrame for analysis.
|
|
818
|
+
|
|
819
|
+
Args:
|
|
820
|
+
space_id: The space ID where the model resides.
|
|
821
|
+
model_name: The name of the model to export data from.
|
|
822
|
+
environment: The environment to export from (PRODUCTION, TRAINING, or VALIDATION).
|
|
823
|
+
start_time: Start of the time range (inclusive) as a datetime object.
|
|
824
|
+
end_time: End of the time range (inclusive) as a datetime object.
|
|
825
|
+
include_actuals: When True, includes actual labels in the export. When False,
|
|
826
|
+
only predictions are returned.
|
|
827
|
+
model_version: Optional model version to filter by. Empty string returns all
|
|
828
|
+
versions.
|
|
829
|
+
batch_id: Optional batch ID to filter by (for VALIDATION environment).
|
|
830
|
+
where: Optional SQL-like WHERE clause to filter rows (e.g., "feature_x > 0.5").
|
|
831
|
+
columns: Optional list of column names to include. If None, all columns are
|
|
832
|
+
returned.
|
|
833
|
+
similarity_search_params: Optional parameters for embedding similarity search
|
|
834
|
+
filtering.
|
|
835
|
+
stream_chunk_size: Optional chunk size for streaming large result sets.
|
|
836
|
+
|
|
837
|
+
Returns:
|
|
838
|
+
A pandas DataFrame containing the exported data with columns for predictions,
|
|
839
|
+
actuals (if requested), features, tags, timestamps, and other model metadata.
|
|
840
|
+
|
|
841
|
+
Raises:
|
|
842
|
+
RuntimeError: If the Flight client request fails or returns no response.
|
|
843
|
+
|
|
844
|
+
Notes:
|
|
845
|
+
- Uses Apache Arrow Flight for efficient data transfer
|
|
846
|
+
- Large exports may benefit from specifying stream_chunk_size
|
|
847
|
+
- The where clause supports SQL-like filtering syntax
|
|
848
|
+
"""
|
|
685
849
|
require(_BATCH_EXTRA, _BATCH_DEPS)
|
|
686
850
|
from arize._exporter.client import ArizeExportClient
|
|
687
851
|
from arize._flight.client import ArizeFlightClient
|
|
688
852
|
|
|
689
853
|
with ArizeFlightClient(
|
|
690
854
|
api_key=self._sdk_config.api_key,
|
|
691
|
-
host=self._sdk_config.
|
|
692
|
-
port=self._sdk_config.
|
|
855
|
+
host=self._sdk_config.flight_host,
|
|
856
|
+
port=self._sdk_config.flight_port,
|
|
693
857
|
scheme=self._sdk_config.flight_scheme,
|
|
694
858
|
request_verify=self._sdk_config.request_verify,
|
|
695
859
|
max_chunksize=self._sdk_config.pyarrow_max_chunksize,
|
|
@@ -724,18 +888,53 @@ class MLModelsClient:
|
|
|
724
888
|
model_version: str = "",
|
|
725
889
|
batch_id: str = "",
|
|
726
890
|
where: str = "",
|
|
727
|
-
columns:
|
|
891
|
+
columns: list | None = None,
|
|
728
892
|
similarity_search_params: SimilaritySearchParams | None = None,
|
|
729
893
|
stream_chunk_size: int | None = None,
|
|
730
894
|
) -> pd.DataFrame:
|
|
895
|
+
"""Export model data from Arize to a Parquet file and return as DataFrame.
|
|
896
|
+
|
|
897
|
+
Retrieves prediction and optional actual data for a model within a specified time
|
|
898
|
+
range, saves it as a Parquet file, and returns it as a pandas DataFrame.
|
|
899
|
+
|
|
900
|
+
Args:
|
|
901
|
+
space_id: The space ID where the model resides.
|
|
902
|
+
model_name: The name of the model to export data from.
|
|
903
|
+
environment: The environment to export from (PRODUCTION, TRAINING, or VALIDATION).
|
|
904
|
+
start_time: Start of the time range (inclusive) as a datetime object.
|
|
905
|
+
end_time: End of the time range (inclusive) as a datetime object.
|
|
906
|
+
include_actuals: When True, includes actual labels in the export. When False,
|
|
907
|
+
only predictions are returned.
|
|
908
|
+
model_version: Optional model version to filter by. Empty string returns all
|
|
909
|
+
versions.
|
|
910
|
+
batch_id: Optional batch ID to filter by (for VALIDATION environment).
|
|
911
|
+
where: Optional SQL-like WHERE clause to filter rows (e.g., "feature_x > 0.5").
|
|
912
|
+
columns: Optional list of column names to include. If None, all columns are
|
|
913
|
+
returned.
|
|
914
|
+
similarity_search_params: Optional parameters for embedding similarity search
|
|
915
|
+
filtering.
|
|
916
|
+
stream_chunk_size: Optional chunk size for streaming large result sets.
|
|
917
|
+
|
|
918
|
+
Returns:
|
|
919
|
+
A pandas DataFrame containing the exported data. The data is also saved to a
|
|
920
|
+
Parquet file by the underlying export client.
|
|
921
|
+
|
|
922
|
+
Raises:
|
|
923
|
+
RuntimeError: If the Flight client request fails or returns no response.
|
|
924
|
+
|
|
925
|
+
Notes:
|
|
926
|
+
- Uses Apache Arrow Flight for efficient data transfer
|
|
927
|
+
- The Parquet file location is managed by the ArizeExportClient
|
|
928
|
+
- Large exports may benefit from specifying stream_chunk_size
|
|
929
|
+
"""
|
|
731
930
|
require(_BATCH_EXTRA, _BATCH_DEPS)
|
|
732
931
|
from arize._exporter.client import ArizeExportClient
|
|
733
932
|
from arize._flight.client import ArizeFlightClient
|
|
734
933
|
|
|
735
934
|
with ArizeFlightClient(
|
|
736
935
|
api_key=self._sdk_config.api_key,
|
|
737
|
-
host=self._sdk_config.
|
|
738
|
-
port=self._sdk_config.
|
|
936
|
+
host=self._sdk_config.flight_host,
|
|
937
|
+
port=self._sdk_config.flight_port,
|
|
739
938
|
scheme=self._sdk_config.flight_scheme,
|
|
740
939
|
request_verify=self._sdk_config.request_verify,
|
|
741
940
|
max_chunksize=self._sdk_config.pyarrow_max_chunksize,
|
|
@@ -759,6 +958,7 @@ class MLModelsClient:
|
|
|
759
958
|
)
|
|
760
959
|
|
|
761
960
|
def _ensure_session(self) -> FuturesSession:
|
|
961
|
+
"""Lazily initialize and return the FuturesSession for async streaming requests."""
|
|
762
962
|
from requests_futures.sessions import FuturesSession
|
|
763
963
|
|
|
764
964
|
session = object.__getattribute__(self, "_session")
|
|
@@ -778,10 +978,11 @@ class MLModelsClient:
|
|
|
778
978
|
def _post(
|
|
779
979
|
self,
|
|
780
980
|
record: pb2.Record,
|
|
781
|
-
headers:
|
|
981
|
+
headers: dict[str, str],
|
|
782
982
|
timeout: float | None,
|
|
783
|
-
indexes:
|
|
784
|
-
):
|
|
983
|
+
indexes: tuple,
|
|
984
|
+
) -> object:
|
|
985
|
+
"""Post a record to Arize via async HTTP request with protobuf JSON serialization."""
|
|
785
986
|
from google.protobuf.json_format import MessageToDict
|
|
786
987
|
|
|
787
988
|
session = self._ensure_session()
|
|
@@ -801,9 +1002,10 @@ class MLModelsClient:
|
|
|
801
1002
|
return resp
|
|
802
1003
|
|
|
803
1004
|
|
|
804
|
-
def _validate_mapping_key(key_name: str, name: str):
|
|
1005
|
+
def _validate_mapping_key(key_name: str, name: str) -> None:
|
|
1006
|
+
"""Validate that a mapping key (feature/tag name) is a string and doesn't end with '_shap'."""
|
|
805
1007
|
if not isinstance(key_name, str):
|
|
806
|
-
raise
|
|
1008
|
+
raise TypeError(
|
|
807
1009
|
f"{name} dictionary key {key_name} must be named with string, type used: {type(key_name)}"
|
|
808
1010
|
)
|
|
809
1011
|
if key_name.endswith("_shap"):
|
|
@@ -813,7 +1015,8 @@ def _validate_mapping_key(key_name: str, name: str):
|
|
|
813
1015
|
return
|
|
814
1016
|
|
|
815
1017
|
|
|
816
|
-
def _is_timestamp_in_range(now: int, ts: int):
|
|
1018
|
+
def _is_timestamp_in_range(now: int, ts: int) -> bool:
|
|
1019
|
+
"""Check if a timestamp is within the acceptable range (1 year future, 2 years past)."""
|
|
817
1020
|
max_time = now + (MAX_FUTURE_YEARS_FROM_CURRENT_TIME * 365 * 24 * 60 * 60)
|
|
818
1021
|
min_time = now - (MAX_PAST_YEARS_FROM_CURRENT_TIME * 365 * 24 * 60 * 60)
|
|
819
1022
|
return min_time <= ts <= max_time
|
|
@@ -826,7 +1029,8 @@ def _get_pb_schema(
|
|
|
826
1029
|
model_type: ModelTypes,
|
|
827
1030
|
environment: Environments,
|
|
828
1031
|
batch_id: str,
|
|
829
|
-
):
|
|
1032
|
+
) -> object:
|
|
1033
|
+
"""Construct a protocol buffer Schema from the user's Schema for batch logging."""
|
|
830
1034
|
s = pb2.Schema()
|
|
831
1035
|
s.constants.model_id = model_id
|
|
832
1036
|
|
|
@@ -874,48 +1078,52 @@ def _get_pb_schema(
|
|
|
874
1078
|
|
|
875
1079
|
if model_type == ModelTypes.OBJECT_DETECTION:
|
|
876
1080
|
if schema.object_detection_prediction_column_names is not None:
|
|
877
|
-
|
|
878
|
-
|
|
1081
|
+
obj_det_pred = schema.object_detection_prediction_column_names
|
|
1082
|
+
pred_labels = (
|
|
1083
|
+
s.arrow_schema.prediction_object_detection_label_column_names
|
|
879
1084
|
)
|
|
880
|
-
|
|
881
|
-
|
|
1085
|
+
pred_labels.bboxes_coordinates_column_name = (
|
|
1086
|
+
obj_det_pred.bounding_boxes_coordinates_column_name
|
|
882
1087
|
)
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
|
|
1088
|
+
pred_labels.bboxes_categories_column_name = (
|
|
1089
|
+
obj_det_pred.categories_column_name
|
|
1090
|
+
)
|
|
1091
|
+
if obj_det_pred.scores_column_name is not None:
|
|
1092
|
+
pred_labels.bboxes_scores_column_name = (
|
|
1093
|
+
obj_det_pred.scores_column_name
|
|
889
1094
|
)
|
|
890
1095
|
|
|
891
1096
|
if schema.semantic_segmentation_prediction_column_names is not None:
|
|
892
|
-
|
|
893
|
-
|
|
1097
|
+
seg_pred_cols = schema.semantic_segmentation_prediction_column_names
|
|
1098
|
+
pred_seg_labels = s.arrow_schema.prediction_semantic_segmentation_label_column_names
|
|
1099
|
+
pred_seg_labels.polygons_coordinates_column_name = (
|
|
1100
|
+
seg_pred_cols.polygon_coordinates_column_name
|
|
894
1101
|
)
|
|
895
|
-
|
|
896
|
-
|
|
1102
|
+
pred_seg_labels.polygons_categories_column_name = (
|
|
1103
|
+
seg_pred_cols.categories_column_name
|
|
897
1104
|
)
|
|
898
1105
|
|
|
899
1106
|
if schema.instance_segmentation_prediction_column_names is not None:
|
|
900
|
-
|
|
901
|
-
schema.instance_segmentation_prediction_column_names
|
|
1107
|
+
inst_seg_pred_cols = (
|
|
1108
|
+
schema.instance_segmentation_prediction_column_names
|
|
902
1109
|
)
|
|
903
|
-
s.arrow_schema.prediction_instance_segmentation_label_column_names
|
|
904
|
-
|
|
1110
|
+
pred_inst_seg_labels = s.arrow_schema.prediction_instance_segmentation_label_column_names
|
|
1111
|
+
pred_inst_seg_labels.polygons_coordinates_column_name = (
|
|
1112
|
+
inst_seg_pred_cols.polygon_coordinates_column_name
|
|
905
1113
|
)
|
|
906
|
-
|
|
907
|
-
|
|
908
|
-
|
|
909
|
-
|
|
910
|
-
|
|
911
|
-
|
|
1114
|
+
pred_inst_seg_labels.polygons_categories_column_name = (
|
|
1115
|
+
inst_seg_pred_cols.categories_column_name
|
|
1116
|
+
)
|
|
1117
|
+
if inst_seg_pred_cols.scores_column_name is not None:
|
|
1118
|
+
pred_inst_seg_labels.polygons_scores_column_name = (
|
|
1119
|
+
inst_seg_pred_cols.scores_column_name
|
|
912
1120
|
)
|
|
913
1121
|
if (
|
|
914
|
-
|
|
1122
|
+
inst_seg_pred_cols.bounding_boxes_coordinates_column_name
|
|
915
1123
|
is not None
|
|
916
1124
|
):
|
|
917
|
-
|
|
918
|
-
|
|
1125
|
+
pred_inst_seg_labels.bboxes_coordinates_column_name = (
|
|
1126
|
+
inst_seg_pred_cols.bounding_boxes_coordinates_column_name
|
|
919
1127
|
)
|
|
920
1128
|
|
|
921
1129
|
if schema.prediction_score_column_name is not None:
|
|
@@ -1038,50 +1246,61 @@ def _get_pb_schema(
|
|
|
1038
1246
|
|
|
1039
1247
|
if model_type == ModelTypes.OBJECT_DETECTION:
|
|
1040
1248
|
if schema.object_detection_actual_column_names is not None:
|
|
1041
|
-
|
|
1042
|
-
|
|
1249
|
+
obj_det_actual = schema.object_detection_actual_column_names
|
|
1250
|
+
actual_labels = (
|
|
1251
|
+
s.arrow_schema.actual_object_detection_label_column_names
|
|
1043
1252
|
)
|
|
1044
|
-
|
|
1045
|
-
|
|
1253
|
+
actual_labels.bboxes_coordinates_column_name = (
|
|
1254
|
+
obj_det_actual.bounding_boxes_coordinates_column_name
|
|
1046
1255
|
)
|
|
1047
|
-
|
|
1048
|
-
|
|
1049
|
-
|
|
1050
|
-
|
|
1051
|
-
|
|
1052
|
-
|
|
1256
|
+
actual_labels.bboxes_categories_column_name = (
|
|
1257
|
+
obj_det_actual.categories_column_name
|
|
1258
|
+
)
|
|
1259
|
+
if obj_det_actual.scores_column_name is not None:
|
|
1260
|
+
actual_labels.bboxes_scores_column_name = (
|
|
1261
|
+
obj_det_actual.scores_column_name
|
|
1053
1262
|
)
|
|
1054
1263
|
|
|
1055
1264
|
if schema.semantic_segmentation_actual_column_names is not None:
|
|
1056
|
-
|
|
1057
|
-
|
|
1265
|
+
sem_seg_actual = schema.semantic_segmentation_actual_column_names
|
|
1266
|
+
sem_seg_labels = (
|
|
1267
|
+
s.arrow_schema.actual_semantic_segmentation_label_column_names
|
|
1268
|
+
)
|
|
1269
|
+
sem_seg_labels.polygons_coordinates_column_name = (
|
|
1270
|
+
sem_seg_actual.polygon_coordinates_column_name
|
|
1058
1271
|
)
|
|
1059
|
-
|
|
1060
|
-
|
|
1272
|
+
sem_seg_labels.polygons_categories_column_name = (
|
|
1273
|
+
sem_seg_actual.categories_column_name
|
|
1061
1274
|
)
|
|
1062
1275
|
|
|
1063
1276
|
if schema.instance_segmentation_actual_column_names is not None:
|
|
1064
|
-
|
|
1065
|
-
|
|
1277
|
+
inst_seg_actual = schema.instance_segmentation_actual_column_names
|
|
1278
|
+
inst_seg_labels = (
|
|
1279
|
+
s.arrow_schema.actual_instance_segmentation_label_column_names
|
|
1066
1280
|
)
|
|
1067
|
-
|
|
1068
|
-
|
|
1281
|
+
inst_seg_labels.polygons_coordinates_column_name = (
|
|
1282
|
+
inst_seg_actual.polygon_coordinates_column_name
|
|
1283
|
+
)
|
|
1284
|
+
inst_seg_labels.polygons_categories_column_name = (
|
|
1285
|
+
inst_seg_actual.categories_column_name
|
|
1069
1286
|
)
|
|
1070
1287
|
if (
|
|
1071
|
-
|
|
1288
|
+
inst_seg_actual.bounding_boxes_coordinates_column_name
|
|
1072
1289
|
is not None
|
|
1073
1290
|
):
|
|
1074
|
-
|
|
1075
|
-
|
|
1291
|
+
inst_seg_labels.bboxes_coordinates_column_name = (
|
|
1292
|
+
inst_seg_actual.bounding_boxes_coordinates_column_name
|
|
1076
1293
|
)
|
|
1077
1294
|
|
|
1078
1295
|
if model_type == ModelTypes.GENERATIVE_LLM:
|
|
1079
1296
|
if schema.prompt_template_column_names is not None:
|
|
1080
|
-
|
|
1081
|
-
|
|
1297
|
+
prompt_template_names = schema.prompt_template_column_names
|
|
1298
|
+
arrow_prompt_names = s.arrow_schema.prompt_template_column_names
|
|
1299
|
+
arrow_prompt_names.template_column_name = (
|
|
1300
|
+
prompt_template_names.template_column_name
|
|
1082
1301
|
)
|
|
1083
|
-
|
|
1084
|
-
|
|
1302
|
+
arrow_prompt_names.template_version_column_name = (
|
|
1303
|
+
prompt_template_names.template_version_column_name
|
|
1085
1304
|
)
|
|
1086
1305
|
if schema.llm_config_column_names is not None:
|
|
1087
1306
|
s.arrow_schema.llm_config_column_names.model_column_name = (
|
|
@@ -1114,6 +1333,7 @@ def _get_pb_schema_corpus(
|
|
|
1114
1333
|
schema: CorpusSchema,
|
|
1115
1334
|
model_id: str,
|
|
1116
1335
|
) -> pb2.Schema:
|
|
1336
|
+
"""Construct a protocol buffer Schema from CorpusSchema for document corpus logging."""
|
|
1117
1337
|
s = pb2.Schema()
|
|
1118
1338
|
s.constants.model_id = model_id
|
|
1119
1339
|
s.constants.environment = pb2.Schema.Environment.CORPUS
|
|
@@ -1127,11 +1347,12 @@ def _get_pb_schema_corpus(
|
|
|
1127
1347
|
schema.document_version_column_name
|
|
1128
1348
|
)
|
|
1129
1349
|
if schema.document_text_embedding_column_names is not None:
|
|
1130
|
-
|
|
1131
|
-
s.arrow_schema.document_column_names.text_column_name
|
|
1132
|
-
|
|
1133
|
-
|
|
1134
|
-
|
|
1135
|
-
|
|
1136
|
-
|
|
1350
|
+
doc_text_emb_cols = schema.document_text_embedding_column_names
|
|
1351
|
+
doc_text_col = s.arrow_schema.document_column_names.text_column_name
|
|
1352
|
+
doc_text_col.vector_column_name = doc_text_emb_cols.vector_column_name
|
|
1353
|
+
doc_text_col.data_column_name = doc_text_emb_cols.data_column_name
|
|
1354
|
+
if doc_text_emb_cols.link_to_data_column_name is not None:
|
|
1355
|
+
doc_text_col.link_to_data_column_name = (
|
|
1356
|
+
doc_text_emb_cols.link_to_data_column_name
|
|
1357
|
+
)
|
|
1137
1358
|
return s
|