arize 8.0.0b1__py3-none-any.whl → 8.0.0b4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arize/__init__.py +9 -2
- arize/_client_factory.py +50 -0
- arize/_exporter/client.py +18 -17
- arize/_exporter/parsers/tracing_data_parser.py +9 -4
- arize/_exporter/validation.py +1 -1
- arize/_flight/client.py +37 -17
- arize/_generated/api_client/api/datasets_api.py +6 -6
- arize/_generated/api_client/api/experiments_api.py +6 -6
- arize/_generated/api_client/api/projects_api.py +3 -3
- arize/_lazy.py +61 -10
- arize/client.py +66 -50
- arize/config.py +175 -48
- arize/constants/config.py +1 -0
- arize/constants/ml.py +9 -16
- arize/constants/spans.py +5 -10
- arize/datasets/client.py +45 -28
- arize/datasets/errors.py +1 -1
- arize/datasets/validation.py +2 -2
- arize/embeddings/auto_generator.py +16 -9
- arize/embeddings/base_generators.py +15 -9
- arize/embeddings/cv_generators.py +2 -2
- arize/embeddings/errors.py +2 -2
- arize/embeddings/nlp_generators.py +8 -8
- arize/embeddings/tabular_generators.py +6 -6
- arize/exceptions/base.py +0 -52
- arize/exceptions/config.py +22 -0
- arize/exceptions/parameters.py +1 -330
- arize/exceptions/values.py +8 -5
- arize/experiments/__init__.py +4 -0
- arize/experiments/client.py +31 -18
- arize/experiments/evaluators/base.py +12 -9
- arize/experiments/evaluators/executors.py +16 -7
- arize/experiments/evaluators/rate_limiters.py +3 -1
- arize/experiments/evaluators/types.py +9 -7
- arize/experiments/evaluators/utils.py +7 -5
- arize/experiments/functions.py +128 -58
- arize/experiments/tracing.py +4 -1
- arize/experiments/types.py +34 -31
- arize/logging.py +54 -33
- arize/ml/batch_validation/errors.py +10 -1004
- arize/ml/batch_validation/validator.py +351 -291
- arize/ml/bounded_executor.py +25 -6
- arize/ml/casting.py +51 -33
- arize/ml/client.py +43 -35
- arize/ml/proto.py +21 -22
- arize/ml/stream_validation.py +64 -27
- arize/ml/surrogate_explainer/mimic.py +18 -10
- arize/ml/types.py +27 -67
- arize/pre_releases.py +10 -6
- arize/projects/client.py +9 -4
- arize/py.typed +0 -0
- arize/regions.py +11 -11
- arize/spans/client.py +125 -31
- arize/spans/columns.py +32 -36
- arize/spans/conversion.py +12 -11
- arize/spans/validation/annotations/dataframe_form_validation.py +1 -1
- arize/spans/validation/annotations/value_validation.py +11 -14
- arize/spans/validation/common/argument_validation.py +3 -3
- arize/spans/validation/common/dataframe_form_validation.py +7 -7
- arize/spans/validation/common/value_validation.py +11 -14
- arize/spans/validation/evals/dataframe_form_validation.py +4 -4
- arize/spans/validation/evals/evals_validation.py +6 -6
- arize/spans/validation/evals/value_validation.py +1 -1
- arize/spans/validation/metadata/argument_validation.py +1 -1
- arize/spans/validation/metadata/dataframe_form_validation.py +2 -2
- arize/spans/validation/metadata/value_validation.py +23 -1
- arize/spans/validation/spans/dataframe_form_validation.py +2 -2
- arize/spans/validation/spans/spans_validation.py +6 -6
- arize/utils/arrow.py +38 -2
- arize/utils/cache.py +2 -2
- arize/utils/dataframe.py +4 -4
- arize/utils/online_tasks/dataframe_preprocessor.py +15 -11
- arize/utils/openinference_conversion.py +10 -10
- arize/utils/proto.py +0 -1
- arize/utils/types.py +6 -6
- arize/version.py +1 -1
- {arize-8.0.0b1.dist-info → arize-8.0.0b4.dist-info}/METADATA +32 -7
- {arize-8.0.0b1.dist-info → arize-8.0.0b4.dist-info}/RECORD +81 -78
- {arize-8.0.0b1.dist-info → arize-8.0.0b4.dist-info}/WHEEL +0 -0
- {arize-8.0.0b1.dist-info → arize-8.0.0b4.dist-info}/licenses/LICENSE +0 -0
- {arize-8.0.0b1.dist-info → arize-8.0.0b4.dist-info}/licenses/NOTICE +0 -0
arize/exceptions/parameters.py
CHANGED
|
@@ -3,159 +3,6 @@
|
|
|
3
3
|
from arize.constants.ml import MAX_NUMBER_OF_EMBEDDINGS
|
|
4
4
|
from arize.exceptions.base import ValidationError
|
|
5
5
|
|
|
6
|
-
# class MissingPredictionIdColumnForDelayedRecords(ValidationError):
|
|
7
|
-
# def __repr__(self) -> str:
|
|
8
|
-
# return "Missing_Prediction_Id_Column_For_Delayed_Records"
|
|
9
|
-
#
|
|
10
|
-
# def __init__(self, has_actual_info, has_feature_importance_info) -> None:
|
|
11
|
-
# self.has_actual_info = has_actual_info
|
|
12
|
-
# self.has_feature_importance_info = has_feature_importance_info
|
|
13
|
-
#
|
|
14
|
-
# def error_message(self) -> str:
|
|
15
|
-
# actual = "actual" if self.has_actual_info else ""
|
|
16
|
-
# feat_imp = (
|
|
17
|
-
# "feature importance" if self.has_feature_importance_info else ""
|
|
18
|
-
# )
|
|
19
|
-
# if self.has_actual_info and self.has_feature_importance_info:
|
|
20
|
-
# msg = " and ".join([actual, feat_imp])
|
|
21
|
-
# else:
|
|
22
|
-
# msg = "".join([actual, feat_imp])
|
|
23
|
-
#
|
|
24
|
-
# return (
|
|
25
|
-
# "Missing 'prediction_id_column_name'. While prediction id is optional for most cases, "
|
|
26
|
-
# "it is required when sending delayed actuals, i.e. when sending actual or feature importances "
|
|
27
|
-
# f"without predictions. In this case, {msg} information was found (without predictions). "
|
|
28
|
-
# "To learn more about delayed joins, please see the docs at "
|
|
29
|
-
# "https://docs.arize.com/arize/sending-data-guides/how-to-send-delayed-actuals"
|
|
30
|
-
# )
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
# class MissingColumns(ValidationError):
|
|
34
|
-
# def __repr__(self) -> str:
|
|
35
|
-
# return "Missing_Columns"
|
|
36
|
-
#
|
|
37
|
-
# def __init__(self, cols: Iterable) -> None:
|
|
38
|
-
# self.missing_cols = set(cols)
|
|
39
|
-
#
|
|
40
|
-
# def error_message(self) -> str:
|
|
41
|
-
# return (
|
|
42
|
-
# "The following columns are declared in the schema "
|
|
43
|
-
# "but are not found in the dataframe: "
|
|
44
|
-
# f"{', '.join(map(str, self.missing_cols))}."
|
|
45
|
-
# )
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
# class MissingRequiredColumnsMetricsValidation(ValidationError):
|
|
49
|
-
# """
|
|
50
|
-
# This error is used only for model mapping validations.
|
|
51
|
-
# """
|
|
52
|
-
#
|
|
53
|
-
# def __repr__(self) -> str:
|
|
54
|
-
# return "Missing_Columns_Required_By_Metrics_Validation"
|
|
55
|
-
#
|
|
56
|
-
# def __init__(
|
|
57
|
-
# self, model_type: ModelTypes, metrics: List[Metrics], cols: Iterable
|
|
58
|
-
# ) -> None:
|
|
59
|
-
# self.model_type = model_type
|
|
60
|
-
# self.metrics = metrics
|
|
61
|
-
# self.missing_cols = cols
|
|
62
|
-
#
|
|
63
|
-
# def error_message(self) -> str:
|
|
64
|
-
# return (
|
|
65
|
-
# f"For logging data for a {self.model_type.name} model with support for metrics "
|
|
66
|
-
# f"{', '.join(m.name for m in self.metrics)}, "
|
|
67
|
-
# f"schema must include: {', '.join(map(str, self.missing_cols))}."
|
|
68
|
-
# )
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
# class ReservedColumns(ValidationError):
|
|
72
|
-
# def __repr__(self) -> str:
|
|
73
|
-
# return "Reserved_Columns"
|
|
74
|
-
#
|
|
75
|
-
# def __init__(self, cols: Iterable) -> None:
|
|
76
|
-
# self.reserved_columns = cols
|
|
77
|
-
#
|
|
78
|
-
# def error_message(self) -> str:
|
|
79
|
-
# return (
|
|
80
|
-
# "The following columns are reserved and can only be specified "
|
|
81
|
-
# "in the proper fields of the schema: "
|
|
82
|
-
# f"{', '.join(map(str, self.reserved_columns))}."
|
|
83
|
-
# )
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
# class InvalidModelTypeAndMetricsCombination(ValidationError):
|
|
87
|
-
# """
|
|
88
|
-
# This error is used only for model mapping validations.
|
|
89
|
-
# """
|
|
90
|
-
#
|
|
91
|
-
# def __repr__(self) -> str:
|
|
92
|
-
# return "Invalid_ModelType_And_Metrics_Combination"
|
|
93
|
-
#
|
|
94
|
-
# def __init__(
|
|
95
|
-
# self,
|
|
96
|
-
# model_type: ModelTypes,
|
|
97
|
-
# metrics: List[Metrics],
|
|
98
|
-
# suggested_model_metric_combinations: List[List[str]],
|
|
99
|
-
# ) -> None:
|
|
100
|
-
# self.model_type = model_type
|
|
101
|
-
# self.metrics = metrics
|
|
102
|
-
# self.suggested_combinations = suggested_model_metric_combinations
|
|
103
|
-
#
|
|
104
|
-
# def error_message(self) -> str:
|
|
105
|
-
# valid_combos = ", or \n".join(
|
|
106
|
-
# "[" + ", ".join(combo) + "]"
|
|
107
|
-
# for combo in self.suggested_combinations
|
|
108
|
-
# )
|
|
109
|
-
# return (
|
|
110
|
-
# f"Invalid combination of model type {self.model_type.name} and metrics: "
|
|
111
|
-
# f"{', '.join(m.name for m in self.metrics)}. "
|
|
112
|
-
# f"Valid Metric combinations for this model type:\n{valid_combos}.\n"
|
|
113
|
-
# )
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
# class InvalidShapSuffix(ValidationError):
|
|
117
|
-
# def __repr__(self) -> str:
|
|
118
|
-
# return "Invalid_SHAP_Suffix"
|
|
119
|
-
#
|
|
120
|
-
# def __init__(self, cols: Iterable) -> None:
|
|
121
|
-
# self.invalid_column_names = cols
|
|
122
|
-
#
|
|
123
|
-
# def error_message(self) -> str:
|
|
124
|
-
# return (
|
|
125
|
-
# "The following features or tags must not be named with a `_shap` suffix: "
|
|
126
|
-
# f"{', '.join(map(str, self.invalid_column_names))}."
|
|
127
|
-
# )
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
# class InvalidModelType(ValidationError):
|
|
131
|
-
# def __repr__(self) -> str:
|
|
132
|
-
# return "Invalid_Model_Type"
|
|
133
|
-
#
|
|
134
|
-
# def error_message(self) -> str:
|
|
135
|
-
# return (
|
|
136
|
-
# "Model type not valid. Choose one of the following: "
|
|
137
|
-
# f"{', '.join('ModelTypes.' + mt.name for mt in ModelTypes)}. "
|
|
138
|
-
# )
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
# class InvalidEnvironment(ValidationError):
|
|
142
|
-
# def __repr__(self) -> str:
|
|
143
|
-
# return "Invalid_Environment"
|
|
144
|
-
#
|
|
145
|
-
# def error_message(self) -> str:
|
|
146
|
-
# return (
|
|
147
|
-
# "Environment not valid. Choose one of the following: "
|
|
148
|
-
# f"{', '.join('Environments.' + env.name for env in Environments)}. "
|
|
149
|
-
# )
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
# class InvalidBatchId(ValidationError):
|
|
153
|
-
# def __repr__(self) -> str:
|
|
154
|
-
# return "Invalid_Batch_ID"
|
|
155
|
-
#
|
|
156
|
-
# def error_message(self) -> str:
|
|
157
|
-
# return "Batch ID must be a nonempty string if logging to validation environment."
|
|
158
|
-
|
|
159
6
|
|
|
160
7
|
class InvalidModelVersion(ValidationError):
|
|
161
8
|
"""Raised when model version is empty or invalid."""
|
|
@@ -169,14 +16,6 @@ class InvalidModelVersion(ValidationError):
|
|
|
169
16
|
return "Model version must be a nonempty string."
|
|
170
17
|
|
|
171
18
|
|
|
172
|
-
# class InvalidModelId(ValidationError):
|
|
173
|
-
# def __repr__(self) -> str:
|
|
174
|
-
# return "Invalid_Model_ID"
|
|
175
|
-
#
|
|
176
|
-
# def error_message(self) -> str:
|
|
177
|
-
# return "Model ID must be a nonempty string."
|
|
178
|
-
|
|
179
|
-
|
|
180
19
|
class InvalidProjectName(ValidationError):
|
|
181
20
|
"""Raised when project name is empty or invalid."""
|
|
182
21
|
|
|
@@ -193,174 +32,6 @@ class InvalidProjectName(ValidationError):
|
|
|
193
32
|
)
|
|
194
33
|
|
|
195
34
|
|
|
196
|
-
# class MissingPredActShap(ValidationError):
|
|
197
|
-
# def __repr__(self) -> str:
|
|
198
|
-
# return "Missing_Pred_or_Act_or_SHAP"
|
|
199
|
-
#
|
|
200
|
-
# def error_message(self) -> str:
|
|
201
|
-
# return (
|
|
202
|
-
# "The schema must specify at least one of the following: "
|
|
203
|
-
# "prediction label, actual label, or SHAP value column names"
|
|
204
|
-
# )
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
# class MissingPreprodPredAct(ValidationError):
|
|
208
|
-
# def __repr__(self) -> str:
|
|
209
|
-
# return "Missing_Preproduction_Pred_and_Act"
|
|
210
|
-
#
|
|
211
|
-
# def error_message(self) -> str:
|
|
212
|
-
# return "For logging pre-production data, the schema must specify both "
|
|
213
|
-
# "prediction and actual label columns."
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
# class MissingPreprodAct(ValidationError):
|
|
217
|
-
# def __repr__(self) -> str:
|
|
218
|
-
# return "Missing_Preproduction_Act"
|
|
219
|
-
#
|
|
220
|
-
# def error_message(self) -> str:
|
|
221
|
-
# return "For logging pre-production data, the schema must specify actual label column."
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
# class MissingPreprodPredActNumericAndCategorical(ValidationError):
|
|
225
|
-
# def __repr__(self) -> str:
|
|
226
|
-
# return "Missing_Preproduction_Pred_and_Act_Numeric_and_Categorical"
|
|
227
|
-
#
|
|
228
|
-
# def error_message(self) -> str:
|
|
229
|
-
# return (
|
|
230
|
-
# "For logging pre-production data for a numeric or a categorical model, "
|
|
231
|
-
# "the schema must specify both prediction and actual label or score columns."
|
|
232
|
-
# )
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
# class MissingRequiredColumnsForRankingModel(ValidationError):
|
|
236
|
-
# def __repr__(self) -> str:
|
|
237
|
-
# return "Missing_Required_Columns_For_Ranking_Model"
|
|
238
|
-
#
|
|
239
|
-
# def error_message(self) -> str:
|
|
240
|
-
# return (
|
|
241
|
-
# "For logging data for a ranking model, schema must specify: "
|
|
242
|
-
# "prediction_group_id_column_name and rank_column_name"
|
|
243
|
-
# )
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
# class MissingCVPredAct(ValidationError):
|
|
247
|
-
# def __repr__(self) -> str:
|
|
248
|
-
# return "Missing_CV_Prediction_or_Actual"
|
|
249
|
-
#
|
|
250
|
-
# def __init__(self, environment: Environments):
|
|
251
|
-
# self.environment = environment
|
|
252
|
-
#
|
|
253
|
-
# def error_message(self) -> str:
|
|
254
|
-
# if self.environment in (Environments.TRAINING, Environments.VALIDATION):
|
|
255
|
-
# env = "pre-production"
|
|
256
|
-
# opt = "and"
|
|
257
|
-
# elif self.environment == Environments.PRODUCTION:
|
|
258
|
-
# env = "production"
|
|
259
|
-
# opt = "or"
|
|
260
|
-
# else:
|
|
261
|
-
# raise TypeError("Invalid environment")
|
|
262
|
-
# return (
|
|
263
|
-
# f"For logging {env} data for an Object Detection model,"
|
|
264
|
-
# "the schema must specify one of: "
|
|
265
|
-
# f"('object_detection_prediction_column_names' {opt} "
|
|
266
|
-
# f"'object_detection_actual_column_names') "
|
|
267
|
-
# f"or ('semantic_segmentation_prediction_column_names' {opt} "
|
|
268
|
-
# f"'semantic_segmentation_actual_column_names') "
|
|
269
|
-
# f"or ('instance_segmentation_prediction_column_names' {opt} "
|
|
270
|
-
# f"'instance_segmentation_actual_column_names')"
|
|
271
|
-
# )
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
# class MultipleCVPredAct(ValidationError):
|
|
275
|
-
# def __repr__(self) -> str:
|
|
276
|
-
# return "Multiple_CV_Prediction_or_Actual"
|
|
277
|
-
#
|
|
278
|
-
# def __init__(self, environment: Environments):
|
|
279
|
-
# self.environment = environment
|
|
280
|
-
#
|
|
281
|
-
# def error_message(self) -> str:
|
|
282
|
-
# return (
|
|
283
|
-
# "The schema must only specify one of the following: "
|
|
284
|
-
# "'object_detection_prediction_column_names'/'object_detection_actual_column_names'"
|
|
285
|
-
# "'semantic_segmentation_prediction_column_names'/'semantic_segmentation_actual_column_names'"
|
|
286
|
-
# "'instance_segmentation_prediction_column_names'/'instance_segmentation_actual_column_names'"
|
|
287
|
-
# )
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
# class InvalidPredActCVColumnNamesForModelType(ValidationError):
|
|
291
|
-
# def __repr__(self) -> str:
|
|
292
|
-
# return "Invalid_CV_Prediction_or_Actual_Column_Names_for_Model_Type"
|
|
293
|
-
#
|
|
294
|
-
# def __init__(
|
|
295
|
-
# self,
|
|
296
|
-
# invalid_model_type: ModelTypes,
|
|
297
|
-
# ) -> None:
|
|
298
|
-
# self.invalid_model_type = invalid_model_type
|
|
299
|
-
#
|
|
300
|
-
# def error_message(self) -> str:
|
|
301
|
-
# return (
|
|
302
|
-
# f"Cannot use 'object_detection_prediction_column_names' or "
|
|
303
|
-
# f"'object_detection_actual_column_names' or "
|
|
304
|
-
# f"'semantic_segmentation_prediction_column_names' or "
|
|
305
|
-
# f"'semantic_segmentation_actual_column_names' or "
|
|
306
|
-
# f"'instance_segmentation_prediction_column_names' or "
|
|
307
|
-
# f"'instance_segmentation_actual_column_names' for {self.invalid_model_type} model "
|
|
308
|
-
# f"type. They are only allowed for ModelTypes.OBJECT_DETECTION models"
|
|
309
|
-
# )
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
# class MissingReqPredActColumnNamesForMultiClass(ValidationError):
|
|
313
|
-
# def __repr__(self) -> str:
|
|
314
|
-
# return "Missing_Required_Prediction_or_Actual_Column_Names_for_Multi_Class_Model_Type"
|
|
315
|
-
#
|
|
316
|
-
# def error_message(self) -> str:
|
|
317
|
-
# return (
|
|
318
|
-
# "For logging data for a multi class model, schema must specify: "
|
|
319
|
-
# "prediction_scores_column_name and/or actual_score_column_name. "
|
|
320
|
-
# "Optionally, you may include multi_class_threshold_scores_column_name"
|
|
321
|
-
# " (must include prediction_scores_column_name)"
|
|
322
|
-
# )
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
# class InvalidPredActColumnNamesForModelType(ValidationError):
|
|
326
|
-
# def __repr__(self) -> str:
|
|
327
|
-
# return "Invalid_Prediction_or_Actual_Column_Names_for_Model_Type"
|
|
328
|
-
#
|
|
329
|
-
# def __init__(
|
|
330
|
-
# self,
|
|
331
|
-
# invalid_model_type: ModelTypes,
|
|
332
|
-
# allowed_fields: List[str],
|
|
333
|
-
# wrong_columns: List[str],
|
|
334
|
-
# ) -> None:
|
|
335
|
-
# self.invalid_model_type = invalid_model_type
|
|
336
|
-
# self.allowed_fields = allowed_fields
|
|
337
|
-
# self.wrong_columns = wrong_columns
|
|
338
|
-
#
|
|
339
|
-
# def error_message(self) -> str:
|
|
340
|
-
# allowed_col_msg = ""
|
|
341
|
-
# if self.allowed_fields is not None:
|
|
342
|
-
# allowed_col_msg = f" Allowed Schema fields are {log_a_list(self.allowed_fields, 'and')}"
|
|
343
|
-
# return (
|
|
344
|
-
# f"Invalid Schema fields for {self.invalid_model_type} model type. {allowed_col_msg}"
|
|
345
|
-
# "The following columns of your dataframe are sent as an invalid schema field: "
|
|
346
|
-
# f"{log_a_list(self.wrong_columns, 'and')}"
|
|
347
|
-
# )
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
# class DuplicateColumnsInDataframe(ValidationError):
|
|
351
|
-
# def __repr__(self) -> str:
|
|
352
|
-
# return "Duplicate_Columns_In_Dataframe"
|
|
353
|
-
#
|
|
354
|
-
# def __init__(self, cols: Iterable) -> None:
|
|
355
|
-
# self.duplicate_cols = cols
|
|
356
|
-
#
|
|
357
|
-
# def error_message(self) -> str:
|
|
358
|
-
# return (
|
|
359
|
-
# "The following columns are present in the schema and have duplicates in the dataframe: "
|
|
360
|
-
# f"{self.duplicate_cols}. "
|
|
361
|
-
# )
|
|
362
|
-
|
|
363
|
-
|
|
364
35
|
class InvalidNumberOfEmbeddings(ValidationError):
|
|
365
36
|
"""Raised when number of embedding features exceeds the maximum allowed."""
|
|
366
37
|
|
|
@@ -390,7 +61,7 @@ class InvalidValueType(Exception):
|
|
|
390
61
|
def __init__(
|
|
391
62
|
self,
|
|
392
63
|
value_name: str,
|
|
393
|
-
value:
|
|
64
|
+
value: object,
|
|
394
65
|
correct_type: str,
|
|
395
66
|
) -> None:
|
|
396
67
|
"""Initialize the exception with value type validation context.
|
arize/exceptions/values.py
CHANGED
|
@@ -533,14 +533,15 @@ class InvalidMultiClassClassNameLength(ValidationError):
|
|
|
533
533
|
err_msg = ""
|
|
534
534
|
for col, class_names in self.invalid_col_class_name.items():
|
|
535
535
|
# limit to 10
|
|
536
|
-
|
|
536
|
+
class_names_list = (
|
|
537
537
|
list(class_names)[:10]
|
|
538
538
|
if len(class_names) > 10
|
|
539
539
|
else list(class_names)
|
|
540
540
|
)
|
|
541
541
|
err_msg += (
|
|
542
|
-
f"Found some invalid class names: {log_a_list(
|
|
543
|
-
f" names must have at least one character and
|
|
542
|
+
f"Found some invalid class names: {log_a_list(class_names_list, 'and')} "
|
|
543
|
+
f"in the {col} column. Class names must have at least one character and "
|
|
544
|
+
f"less than {MAX_MULTI_CLASS_NAME_LENGTH}.\n"
|
|
544
545
|
)
|
|
545
546
|
return err_msg
|
|
546
547
|
|
|
@@ -565,9 +566,11 @@ class InvalidMultiClassPredScoreValue(ValidationError):
|
|
|
565
566
|
err_msg = ""
|
|
566
567
|
for col, scores in self.invalid_col_class_scores.items():
|
|
567
568
|
# limit to 10
|
|
568
|
-
|
|
569
|
+
scores_list = (
|
|
570
|
+
list(scores)[:10] if len(scores) > 10 else list(scores)
|
|
571
|
+
)
|
|
569
572
|
err_msg += (
|
|
570
|
-
f"Found some invalid scores: {log_a_list(
|
|
573
|
+
f"Found some invalid scores: {log_a_list(scores_list, 'and')} in the {col} column that was "
|
|
571
574
|
"invalid. All scores (values in dictionary) must be between 0 and 1, inclusive. \n"
|
|
572
575
|
)
|
|
573
576
|
return err_msg
|
arize/experiments/__init__.py
CHANGED
|
@@ -1,5 +1,8 @@
|
|
|
1
1
|
"""Experiment tracking and evaluation functionality for the Arize SDK."""
|
|
2
2
|
|
|
3
|
+
from arize.experiments.evaluators.base import (
|
|
4
|
+
Evaluator,
|
|
5
|
+
)
|
|
3
6
|
from arize.experiments.evaluators.types import (
|
|
4
7
|
EvaluationResult,
|
|
5
8
|
EvaluationResultFieldNames,
|
|
@@ -9,5 +12,6 @@ from arize.experiments.types import ExperimentTaskFieldNames
|
|
|
9
12
|
__all__ = [
|
|
10
13
|
"EvaluationResult",
|
|
11
14
|
"EvaluationResultFieldNames",
|
|
15
|
+
"Evaluator",
|
|
12
16
|
"ExperimentTaskFieldNames",
|
|
13
17
|
]
|
arize/experiments/client.py
CHANGED
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
import logging
|
|
6
|
-
from typing import TYPE_CHECKING
|
|
6
|
+
from typing import TYPE_CHECKING, cast
|
|
7
7
|
|
|
8
8
|
import opentelemetry.sdk.trace as trace_sdk
|
|
9
9
|
import pandas as pd
|
|
@@ -36,8 +36,13 @@ from arize.utils.openinference_conversion import (
|
|
|
36
36
|
from arize.utils.size import get_payload_size_mb
|
|
37
37
|
|
|
38
38
|
if TYPE_CHECKING:
|
|
39
|
+
# builtins is needed to use builtins.list in type annotations because
|
|
40
|
+
# the class has a list() method that shadows the built-in list type
|
|
41
|
+
import builtins
|
|
42
|
+
|
|
39
43
|
from opentelemetry.trace import Tracer
|
|
40
44
|
|
|
45
|
+
from arize._generated.api_client.api_client import ApiClient
|
|
41
46
|
from arize.config import SDKConfiguration
|
|
42
47
|
from arize.experiments.evaluators.base import Evaluators
|
|
43
48
|
from arize.experiments.evaluators.types import EvaluationResultFieldNames
|
|
@@ -61,20 +66,22 @@ class ExperimentsClient:
|
|
|
61
66
|
:class:`arize.config.SDKConfiguration`.
|
|
62
67
|
"""
|
|
63
68
|
|
|
64
|
-
def __init__(
|
|
69
|
+
def __init__(
|
|
70
|
+
self, *, sdk_config: SDKConfiguration, generated_client: ApiClient
|
|
71
|
+
) -> None:
|
|
65
72
|
"""
|
|
66
73
|
Args:
|
|
67
74
|
sdk_config: Resolved SDK configuration.
|
|
75
|
+
generated_client: Shared generated API client instance.
|
|
68
76
|
""" # noqa: D205, D212
|
|
69
77
|
self._sdk_config = sdk_config
|
|
70
78
|
from arize._generated import api_client as gen
|
|
71
79
|
|
|
72
|
-
|
|
80
|
+
# Use the provided client directly for both APIs
|
|
81
|
+
self._api = gen.ExperimentsApi(generated_client)
|
|
73
82
|
# TODO(Kiko): Space ID should not be needed,
|
|
74
83
|
# should work on server tech debt to remove this
|
|
75
|
-
self._datasets_api = gen.DatasetsApi(
|
|
76
|
-
self._sdk_config.get_generated_client()
|
|
77
|
-
)
|
|
84
|
+
self._datasets_api = gen.DatasetsApi(generated_client)
|
|
78
85
|
|
|
79
86
|
@prerelease_endpoint(key="experiments.list", stage=ReleaseStage.BETA)
|
|
80
87
|
def list(
|
|
@@ -113,7 +120,7 @@ class ExperimentsClient:
|
|
|
113
120
|
*,
|
|
114
121
|
name: str,
|
|
115
122
|
dataset_id: str,
|
|
116
|
-
experiment_runs: list[dict[str, object]] | pd.DataFrame,
|
|
123
|
+
experiment_runs: builtins.list[dict[str, object]] | pd.DataFrame,
|
|
117
124
|
task_fields: ExperimentTaskFieldNames,
|
|
118
125
|
evaluator_columns: dict[str, EvaluationResultFieldNames] | None = None,
|
|
119
126
|
force_http: bool = False,
|
|
@@ -141,7 +148,7 @@ class ExperimentsClient:
|
|
|
141
148
|
dataset_id: Dataset ID to attach the experiment to.
|
|
142
149
|
experiment_runs: Experiment runs either as:
|
|
143
150
|
- a list of JSON-like dicts, or
|
|
144
|
-
- a pandas
|
|
151
|
+
- a :class:`pandas.DataFrame`.
|
|
145
152
|
task_fields: Mapping that identifies the columns/fields containing the
|
|
146
153
|
task results (e.g. `example_id`, output fields).
|
|
147
154
|
evaluator_columns: Optional mapping describing evaluator result columns.
|
|
@@ -178,7 +185,7 @@ class ExperimentsClient:
|
|
|
178
185
|
body = gen.ExperimentsCreateRequest(
|
|
179
186
|
name=name,
|
|
180
187
|
dataset_id=dataset_id,
|
|
181
|
-
experiment_runs=data,
|
|
188
|
+
experiment_runs=cast("list[gen.ExperimentRunCreate]", data),
|
|
182
189
|
)
|
|
183
190
|
return self._api.experiments_create(experiments_create_request=body)
|
|
184
191
|
|
|
@@ -229,7 +236,8 @@ class ExperimentsClient:
|
|
|
229
236
|
Args:
|
|
230
237
|
experiment_id: Experiment ID to delete.
|
|
231
238
|
|
|
232
|
-
Returns:
|
|
239
|
+
Returns:
|
|
240
|
+
This method returns None on success (common empty 204 response).
|
|
233
241
|
|
|
234
242
|
Raises:
|
|
235
243
|
arize._generated.api_client.exceptions.ApiException: If the REST API
|
|
@@ -299,7 +307,10 @@ class ExperimentsClient:
|
|
|
299
307
|
)
|
|
300
308
|
if experiment_df is not None:
|
|
301
309
|
return models.ExperimentsRunsList200Response(
|
|
302
|
-
|
|
310
|
+
experiment_runs=cast(
|
|
311
|
+
"list[models.ExperimentRun]",
|
|
312
|
+
experiment_df.to_dict(orient="records"),
|
|
313
|
+
),
|
|
303
314
|
pagination=models.PaginationMetadata(
|
|
304
315
|
has_more=False, # Note that all=True
|
|
305
316
|
),
|
|
@@ -339,7 +350,10 @@ class ExperimentsClient:
|
|
|
339
350
|
)
|
|
340
351
|
|
|
341
352
|
return models.ExperimentsRunsList200Response(
|
|
342
|
-
|
|
353
|
+
experiment_runs=cast(
|
|
354
|
+
"list[models.ExperimentRun]",
|
|
355
|
+
experiment_df.to_dict(orient="records"),
|
|
356
|
+
),
|
|
343
357
|
pagination=models.PaginationMetadata(
|
|
344
358
|
has_more=False, # Note that all=True
|
|
345
359
|
),
|
|
@@ -357,6 +371,7 @@ class ExperimentsClient:
|
|
|
357
371
|
concurrency: int = 3,
|
|
358
372
|
set_global_tracer_provider: bool = False,
|
|
359
373
|
exit_on_error: bool = False,
|
|
374
|
+
timeout: int = 120,
|
|
360
375
|
) -> tuple[models.Experiment | None, pd.DataFrame]:
|
|
361
376
|
"""Run an experiment on a dataset and optionally upload results.
|
|
362
377
|
|
|
@@ -387,6 +402,7 @@ class ExperimentsClient:
|
|
|
387
402
|
provider for the experiment run.
|
|
388
403
|
exit_on_error: If True, stop on the first error encountered during
|
|
389
404
|
execution.
|
|
405
|
+
timeout: The timeout in seconds for each task execution. Defaults to 120.
|
|
390
406
|
|
|
391
407
|
Returns:
|
|
392
408
|
If `dry_run=True`, returns `(None, results_df)`.
|
|
@@ -505,6 +521,7 @@ class ExperimentsClient:
|
|
|
505
521
|
evaluators=evaluators,
|
|
506
522
|
concurrency=concurrency,
|
|
507
523
|
exit_on_error=exit_on_error,
|
|
524
|
+
timeout=timeout,
|
|
508
525
|
)
|
|
509
526
|
output_df = convert_default_columns_to_json_str(output_df)
|
|
510
527
|
output_df = convert_boolean_columns_to_str(output_df)
|
|
@@ -546,9 +563,7 @@ class ExperimentsClient:
|
|
|
546
563
|
logger.error(msg)
|
|
547
564
|
raise RuntimeError(msg)
|
|
548
565
|
|
|
549
|
-
experiment = self.get(
|
|
550
|
-
experiment_id=str(post_resp.experiment_id) # type: ignore
|
|
551
|
-
)
|
|
566
|
+
experiment = self.get(experiment_id=str(post_resp.experiment_id))
|
|
552
567
|
return experiment, output_df
|
|
553
568
|
|
|
554
569
|
def _create_experiment_via_flight(
|
|
@@ -629,9 +644,7 @@ class ExperimentsClient:
|
|
|
629
644
|
logger.error(msg)
|
|
630
645
|
raise RuntimeError(msg)
|
|
631
646
|
|
|
632
|
-
return self.get(
|
|
633
|
-
experiment_id=str(post_resp.experiment_id) # type: ignore
|
|
634
|
-
)
|
|
647
|
+
return self.get(experiment_id=str(post_resp.experiment_id))
|
|
635
648
|
|
|
636
649
|
|
|
637
650
|
def _get_tracer_resource(
|
|
@@ -7,7 +7,7 @@ import inspect
|
|
|
7
7
|
from abc import ABC
|
|
8
8
|
from collections.abc import Awaitable, Callable, Mapping, Sequence
|
|
9
9
|
from types import MappingProxyType
|
|
10
|
-
from typing import TYPE_CHECKING
|
|
10
|
+
from typing import TYPE_CHECKING, Any, cast
|
|
11
11
|
|
|
12
12
|
from arize.experiments.evaluators.types import (
|
|
13
13
|
AnnotatorKind,
|
|
@@ -79,10 +79,10 @@ class Evaluator(ABC):
|
|
|
79
79
|
method and the asynchronous `async_evaluate` method, but it is not required.
|
|
80
80
|
|
|
81
81
|
Args:
|
|
82
|
-
dataset_row (
|
|
82
|
+
dataset_row (Mapping[str, JSONSerializable] | :obj:`None`): A row from the dataset.
|
|
83
83
|
input (ExampleInput): The input provided for evaluation.
|
|
84
|
-
output (
|
|
85
|
-
experiment_output (
|
|
84
|
+
output (TaskOutput | :obj:`None`): The output produced by the task.
|
|
85
|
+
experiment_output (TaskOutput | :obj:`None`): The experiment output for comparison.
|
|
86
86
|
dataset_output (ExampleOutput): The expected output from the dataset.
|
|
87
87
|
metadata (ExampleMetadata): Metadata associated with the example.
|
|
88
88
|
**kwargs (Any): Additional keyword arguments.
|
|
@@ -112,10 +112,10 @@ class Evaluator(ABC):
|
|
|
112
112
|
method and the synchronous `evaluate` method, but it is not required.
|
|
113
113
|
|
|
114
114
|
Args:
|
|
115
|
-
dataset_row (
|
|
115
|
+
dataset_row (Mapping[str, JSONSerializable] | :obj:`None`): A row from the dataset.
|
|
116
116
|
input (ExampleInput): The input provided for evaluation.
|
|
117
|
-
output (
|
|
118
|
-
experiment_output (
|
|
117
|
+
output (TaskOutput | :obj:`None`): The output produced by the task.
|
|
118
|
+
experiment_output (TaskOutput | :obj:`None`): The experiment output for comparison.
|
|
119
119
|
dataset_output (ExampleOutput): The expected output from the dataset.
|
|
120
120
|
metadata (ExampleMetadata): Metadata associated with the example.
|
|
121
121
|
**kwargs (Any): Additional keyword arguments.
|
|
@@ -162,7 +162,9 @@ class Evaluator(ABC):
|
|
|
162
162
|
f"`evaluate()` method should be callable, got {type(evaluate)}"
|
|
163
163
|
)
|
|
164
164
|
# need to remove the first param, i.e. `self`
|
|
165
|
-
_validate_sig(
|
|
165
|
+
_validate_sig(
|
|
166
|
+
functools.partial(evaluate, cast("Any", None)), "evaluate"
|
|
167
|
+
)
|
|
166
168
|
return
|
|
167
169
|
if async_evaluate := super_cls.__dict__.get(
|
|
168
170
|
Evaluator.async_evaluate.__name__
|
|
@@ -175,7 +177,8 @@ class Evaluator(ABC):
|
|
|
175
177
|
)
|
|
176
178
|
# need to remove the first param, i.e. `self`
|
|
177
179
|
_validate_sig(
|
|
178
|
-
functools.partial(async_evaluate, None),
|
|
180
|
+
functools.partial(async_evaluate, cast("Any", None)),
|
|
181
|
+
"async_evaluate",
|
|
179
182
|
)
|
|
180
183
|
return
|
|
181
184
|
raise ValueError(
|