arize 8.0.0b0__py3-none-any.whl → 8.0.0b2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arize/__init__.py +1 -1
- arize/_client_factory.py +50 -0
- arize/_flight/client.py +4 -4
- arize/_generated/api_client/__init__.py +0 -2
- arize/_generated/api_client/api/datasets_api.py +6 -6
- arize/_generated/api_client/api/experiments_api.py +6 -6
- arize/_generated/api_client/api/projects_api.py +3 -3
- arize/_generated/api_client/models/__init__.py +0 -1
- arize/_generated/api_client/models/datasets_create_request.py +2 -10
- arize/_generated/api_client/models/datasets_examples_insert_request.py +2 -10
- arize/_generated/api_client/test/test_datasets_create_request.py +2 -6
- arize/_generated/api_client/test/test_datasets_examples_insert_request.py +2 -6
- arize/_generated/api_client/test/test_datasets_examples_list200_response.py +2 -6
- arize/_generated/api_client/test/test_datasets_examples_update_request.py +2 -6
- arize/_generated/api_client/test/test_experiments_create_request.py +2 -6
- arize/_generated/api_client/test/test_experiments_runs_list200_response.py +2 -6
- arize/_generated/api_client_README.md +0 -1
- arize/_lazy.py +25 -9
- arize/client.py +16 -52
- arize/config.py +9 -36
- arize/constants/ml.py +9 -16
- arize/constants/spans.py +5 -10
- arize/datasets/client.py +13 -9
- arize/datasets/errors.py +1 -1
- arize/datasets/validation.py +2 -2
- arize/embeddings/auto_generator.py +2 -2
- arize/embeddings/errors.py +2 -2
- arize/embeddings/tabular_generators.py +1 -1
- arize/exceptions/base.py +0 -52
- arize/exceptions/parameters.py +0 -329
- arize/experiments/__init__.py +2 -2
- arize/experiments/client.py +16 -10
- arize/experiments/evaluators/base.py +6 -6
- arize/experiments/evaluators/executors.py +10 -3
- arize/experiments/evaluators/types.py +2 -2
- arize/experiments/functions.py +24 -17
- arize/experiments/types.py +6 -8
- arize/logging.py +1 -1
- arize/ml/batch_validation/errors.py +10 -1004
- arize/ml/batch_validation/validator.py +273 -225
- arize/ml/casting.py +7 -7
- arize/ml/client.py +12 -11
- arize/ml/proto.py +6 -6
- arize/ml/stream_validation.py +2 -3
- arize/ml/surrogate_explainer/mimic.py +3 -3
- arize/ml/types.py +1 -55
- arize/pre_releases.py +6 -3
- arize/projects/client.py +9 -4
- arize/regions.py +2 -2
- arize/spans/client.py +14 -12
- arize/spans/columns.py +32 -36
- arize/spans/conversion.py +5 -6
- arize/spans/validation/common/argument_validation.py +3 -3
- arize/spans/validation/common/dataframe_form_validation.py +6 -6
- arize/spans/validation/common/value_validation.py +1 -1
- arize/spans/validation/evals/dataframe_form_validation.py +4 -4
- arize/spans/validation/evals/evals_validation.py +6 -6
- arize/spans/validation/metadata/dataframe_form_validation.py +1 -1
- arize/spans/validation/spans/dataframe_form_validation.py +2 -2
- arize/spans/validation/spans/spans_validation.py +6 -6
- arize/utils/arrow.py +2 -2
- arize/utils/cache.py +2 -2
- arize/utils/dataframe.py +4 -4
- arize/utils/online_tasks/dataframe_preprocessor.py +7 -7
- arize/utils/openinference_conversion.py +10 -10
- arize/utils/proto.py +1 -1
- arize/version.py +1 -1
- {arize-8.0.0b0.dist-info → arize-8.0.0b2.dist-info}/METADATA +71 -63
- {arize-8.0.0b0.dist-info → arize-8.0.0b2.dist-info}/RECORD +72 -73
- arize/_generated/api_client/models/primitive_value.py +0 -172
- arize/_generated/api_client/test/test_primitive_value.py +0 -50
- {arize-8.0.0b0.dist-info → arize-8.0.0b2.dist-info}/WHEEL +0 -0
- {arize-8.0.0b0.dist-info → arize-8.0.0b2.dist-info}/licenses/LICENSE +0 -0
- {arize-8.0.0b0.dist-info → arize-8.0.0b2.dist-info}/licenses/NOTICE +0 -0
arize/exceptions/parameters.py
CHANGED
|
@@ -3,159 +3,6 @@
|
|
|
3
3
|
from arize.constants.ml import MAX_NUMBER_OF_EMBEDDINGS
|
|
4
4
|
from arize.exceptions.base import ValidationError
|
|
5
5
|
|
|
6
|
-
# class MissingPredictionIdColumnForDelayedRecords(ValidationError):
|
|
7
|
-
# def __repr__(self) -> str:
|
|
8
|
-
# return "Missing_Prediction_Id_Column_For_Delayed_Records"
|
|
9
|
-
#
|
|
10
|
-
# def __init__(self, has_actual_info, has_feature_importance_info) -> None:
|
|
11
|
-
# self.has_actual_info = has_actual_info
|
|
12
|
-
# self.has_feature_importance_info = has_feature_importance_info
|
|
13
|
-
#
|
|
14
|
-
# def error_message(self) -> str:
|
|
15
|
-
# actual = "actual" if self.has_actual_info else ""
|
|
16
|
-
# feat_imp = (
|
|
17
|
-
# "feature importance" if self.has_feature_importance_info else ""
|
|
18
|
-
# )
|
|
19
|
-
# if self.has_actual_info and self.has_feature_importance_info:
|
|
20
|
-
# msg = " and ".join([actual, feat_imp])
|
|
21
|
-
# else:
|
|
22
|
-
# msg = "".join([actual, feat_imp])
|
|
23
|
-
#
|
|
24
|
-
# return (
|
|
25
|
-
# "Missing 'prediction_id_column_name'. While prediction id is optional for most cases, "
|
|
26
|
-
# "it is required when sending delayed actuals, i.e. when sending actual or feature importances "
|
|
27
|
-
# f"without predictions. In this case, {msg} information was found (without predictions). "
|
|
28
|
-
# "To learn more about delayed joins, please see the docs at "
|
|
29
|
-
# "https://docs.arize.com/arize/sending-data-guides/how-to-send-delayed-actuals"
|
|
30
|
-
# )
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
# class MissingColumns(ValidationError):
|
|
34
|
-
# def __repr__(self) -> str:
|
|
35
|
-
# return "Missing_Columns"
|
|
36
|
-
#
|
|
37
|
-
# def __init__(self, cols: Iterable) -> None:
|
|
38
|
-
# self.missing_cols = set(cols)
|
|
39
|
-
#
|
|
40
|
-
# def error_message(self) -> str:
|
|
41
|
-
# return (
|
|
42
|
-
# "The following columns are declared in the schema "
|
|
43
|
-
# "but are not found in the dataframe: "
|
|
44
|
-
# f"{', '.join(map(str, self.missing_cols))}."
|
|
45
|
-
# )
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
# class MissingRequiredColumnsMetricsValidation(ValidationError):
|
|
49
|
-
# """
|
|
50
|
-
# This error is used only for model mapping validations.
|
|
51
|
-
# """
|
|
52
|
-
#
|
|
53
|
-
# def __repr__(self) -> str:
|
|
54
|
-
# return "Missing_Columns_Required_By_Metrics_Validation"
|
|
55
|
-
#
|
|
56
|
-
# def __init__(
|
|
57
|
-
# self, model_type: ModelTypes, metrics: List[Metrics], cols: Iterable
|
|
58
|
-
# ) -> None:
|
|
59
|
-
# self.model_type = model_type
|
|
60
|
-
# self.metrics = metrics
|
|
61
|
-
# self.missing_cols = cols
|
|
62
|
-
#
|
|
63
|
-
# def error_message(self) -> str:
|
|
64
|
-
# return (
|
|
65
|
-
# f"For logging data for a {self.model_type.name} model with support for metrics "
|
|
66
|
-
# f"{', '.join(m.name for m in self.metrics)}, "
|
|
67
|
-
# f"schema must include: {', '.join(map(str, self.missing_cols))}."
|
|
68
|
-
# )
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
# class ReservedColumns(ValidationError):
|
|
72
|
-
# def __repr__(self) -> str:
|
|
73
|
-
# return "Reserved_Columns"
|
|
74
|
-
#
|
|
75
|
-
# def __init__(self, cols: Iterable) -> None:
|
|
76
|
-
# self.reserved_columns = cols
|
|
77
|
-
#
|
|
78
|
-
# def error_message(self) -> str:
|
|
79
|
-
# return (
|
|
80
|
-
# "The following columns are reserved and can only be specified "
|
|
81
|
-
# "in the proper fields of the schema: "
|
|
82
|
-
# f"{', '.join(map(str, self.reserved_columns))}."
|
|
83
|
-
# )
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
# class InvalidModelTypeAndMetricsCombination(ValidationError):
|
|
87
|
-
# """
|
|
88
|
-
# This error is used only for model mapping validations.
|
|
89
|
-
# """
|
|
90
|
-
#
|
|
91
|
-
# def __repr__(self) -> str:
|
|
92
|
-
# return "Invalid_ModelType_And_Metrics_Combination"
|
|
93
|
-
#
|
|
94
|
-
# def __init__(
|
|
95
|
-
# self,
|
|
96
|
-
# model_type: ModelTypes,
|
|
97
|
-
# metrics: List[Metrics],
|
|
98
|
-
# suggested_model_metric_combinations: List[List[str]],
|
|
99
|
-
# ) -> None:
|
|
100
|
-
# self.model_type = model_type
|
|
101
|
-
# self.metrics = metrics
|
|
102
|
-
# self.suggested_combinations = suggested_model_metric_combinations
|
|
103
|
-
#
|
|
104
|
-
# def error_message(self) -> str:
|
|
105
|
-
# valid_combos = ", or \n".join(
|
|
106
|
-
# "[" + ", ".join(combo) + "]"
|
|
107
|
-
# for combo in self.suggested_combinations
|
|
108
|
-
# )
|
|
109
|
-
# return (
|
|
110
|
-
# f"Invalid combination of model type {self.model_type.name} and metrics: "
|
|
111
|
-
# f"{', '.join(m.name for m in self.metrics)}. "
|
|
112
|
-
# f"Valid Metric combinations for this model type:\n{valid_combos}.\n"
|
|
113
|
-
# )
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
# class InvalidShapSuffix(ValidationError):
|
|
117
|
-
# def __repr__(self) -> str:
|
|
118
|
-
# return "Invalid_SHAP_Suffix"
|
|
119
|
-
#
|
|
120
|
-
# def __init__(self, cols: Iterable) -> None:
|
|
121
|
-
# self.invalid_column_names = cols
|
|
122
|
-
#
|
|
123
|
-
# def error_message(self) -> str:
|
|
124
|
-
# return (
|
|
125
|
-
# "The following features or tags must not be named with a `_shap` suffix: "
|
|
126
|
-
# f"{', '.join(map(str, self.invalid_column_names))}."
|
|
127
|
-
# )
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
# class InvalidModelType(ValidationError):
|
|
131
|
-
# def __repr__(self) -> str:
|
|
132
|
-
# return "Invalid_Model_Type"
|
|
133
|
-
#
|
|
134
|
-
# def error_message(self) -> str:
|
|
135
|
-
# return (
|
|
136
|
-
# "Model type not valid. Choose one of the following: "
|
|
137
|
-
# f"{', '.join('ModelTypes.' + mt.name for mt in ModelTypes)}. "
|
|
138
|
-
# )
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
# class InvalidEnvironment(ValidationError):
|
|
142
|
-
# def __repr__(self) -> str:
|
|
143
|
-
# return "Invalid_Environment"
|
|
144
|
-
#
|
|
145
|
-
# def error_message(self) -> str:
|
|
146
|
-
# return (
|
|
147
|
-
# "Environment not valid. Choose one of the following: "
|
|
148
|
-
# f"{', '.join('Environments.' + env.name for env in Environments)}. "
|
|
149
|
-
# )
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
# class InvalidBatchId(ValidationError):
|
|
153
|
-
# def __repr__(self) -> str:
|
|
154
|
-
# return "Invalid_Batch_ID"
|
|
155
|
-
#
|
|
156
|
-
# def error_message(self) -> str:
|
|
157
|
-
# return "Batch ID must be a nonempty string if logging to validation environment."
|
|
158
|
-
|
|
159
6
|
|
|
160
7
|
class InvalidModelVersion(ValidationError):
|
|
161
8
|
"""Raised when model version is empty or invalid."""
|
|
@@ -169,14 +16,6 @@ class InvalidModelVersion(ValidationError):
|
|
|
169
16
|
return "Model version must be a nonempty string."
|
|
170
17
|
|
|
171
18
|
|
|
172
|
-
# class InvalidModelId(ValidationError):
|
|
173
|
-
# def __repr__(self) -> str:
|
|
174
|
-
# return "Invalid_Model_ID"
|
|
175
|
-
#
|
|
176
|
-
# def error_message(self) -> str:
|
|
177
|
-
# return "Model ID must be a nonempty string."
|
|
178
|
-
|
|
179
|
-
|
|
180
19
|
class InvalidProjectName(ValidationError):
|
|
181
20
|
"""Raised when project name is empty or invalid."""
|
|
182
21
|
|
|
@@ -193,174 +32,6 @@ class InvalidProjectName(ValidationError):
|
|
|
193
32
|
)
|
|
194
33
|
|
|
195
34
|
|
|
196
|
-
# class MissingPredActShap(ValidationError):
|
|
197
|
-
# def __repr__(self) -> str:
|
|
198
|
-
# return "Missing_Pred_or_Act_or_SHAP"
|
|
199
|
-
#
|
|
200
|
-
# def error_message(self) -> str:
|
|
201
|
-
# return (
|
|
202
|
-
# "The schema must specify at least one of the following: "
|
|
203
|
-
# "prediction label, actual label, or SHAP value column names"
|
|
204
|
-
# )
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
# class MissingPreprodPredAct(ValidationError):
|
|
208
|
-
# def __repr__(self) -> str:
|
|
209
|
-
# return "Missing_Preproduction_Pred_and_Act"
|
|
210
|
-
#
|
|
211
|
-
# def error_message(self) -> str:
|
|
212
|
-
# return "For logging pre-production data, the schema must specify both "
|
|
213
|
-
# "prediction and actual label columns."
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
# class MissingPreprodAct(ValidationError):
|
|
217
|
-
# def __repr__(self) -> str:
|
|
218
|
-
# return "Missing_Preproduction_Act"
|
|
219
|
-
#
|
|
220
|
-
# def error_message(self) -> str:
|
|
221
|
-
# return "For logging pre-production data, the schema must specify actual label column."
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
# class MissingPreprodPredActNumericAndCategorical(ValidationError):
|
|
225
|
-
# def __repr__(self) -> str:
|
|
226
|
-
# return "Missing_Preproduction_Pred_and_Act_Numeric_and_Categorical"
|
|
227
|
-
#
|
|
228
|
-
# def error_message(self) -> str:
|
|
229
|
-
# return (
|
|
230
|
-
# "For logging pre-production data for a numeric or a categorical model, "
|
|
231
|
-
# "the schema must specify both prediction and actual label or score columns."
|
|
232
|
-
# )
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
# class MissingRequiredColumnsForRankingModel(ValidationError):
|
|
236
|
-
# def __repr__(self) -> str:
|
|
237
|
-
# return "Missing_Required_Columns_For_Ranking_Model"
|
|
238
|
-
#
|
|
239
|
-
# def error_message(self) -> str:
|
|
240
|
-
# return (
|
|
241
|
-
# "For logging data for a ranking model, schema must specify: "
|
|
242
|
-
# "prediction_group_id_column_name and rank_column_name"
|
|
243
|
-
# )
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
# class MissingCVPredAct(ValidationError):
|
|
247
|
-
# def __repr__(self) -> str:
|
|
248
|
-
# return "Missing_CV_Prediction_or_Actual"
|
|
249
|
-
#
|
|
250
|
-
# def __init__(self, environment: Environments):
|
|
251
|
-
# self.environment = environment
|
|
252
|
-
#
|
|
253
|
-
# def error_message(self) -> str:
|
|
254
|
-
# if self.environment in (Environments.TRAINING, Environments.VALIDATION):
|
|
255
|
-
# env = "pre-production"
|
|
256
|
-
# opt = "and"
|
|
257
|
-
# elif self.environment == Environments.PRODUCTION:
|
|
258
|
-
# env = "production"
|
|
259
|
-
# opt = "or"
|
|
260
|
-
# else:
|
|
261
|
-
# raise TypeError("Invalid environment")
|
|
262
|
-
# return (
|
|
263
|
-
# f"For logging {env} data for an Object Detection model,"
|
|
264
|
-
# "the schema must specify one of: "
|
|
265
|
-
# f"('object_detection_prediction_column_names' {opt} "
|
|
266
|
-
# f"'object_detection_actual_column_names') "
|
|
267
|
-
# f"or ('semantic_segmentation_prediction_column_names' {opt} "
|
|
268
|
-
# f"'semantic_segmentation_actual_column_names') "
|
|
269
|
-
# f"or ('instance_segmentation_prediction_column_names' {opt} "
|
|
270
|
-
# f"'instance_segmentation_actual_column_names')"
|
|
271
|
-
# )
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
# class MultipleCVPredAct(ValidationError):
|
|
275
|
-
# def __repr__(self) -> str:
|
|
276
|
-
# return "Multiple_CV_Prediction_or_Actual"
|
|
277
|
-
#
|
|
278
|
-
# def __init__(self, environment: Environments):
|
|
279
|
-
# self.environment = environment
|
|
280
|
-
#
|
|
281
|
-
# def error_message(self) -> str:
|
|
282
|
-
# return (
|
|
283
|
-
# "The schema must only specify one of the following: "
|
|
284
|
-
# "'object_detection_prediction_column_names'/'object_detection_actual_column_names'"
|
|
285
|
-
# "'semantic_segmentation_prediction_column_names'/'semantic_segmentation_actual_column_names'"
|
|
286
|
-
# "'instance_segmentation_prediction_column_names'/'instance_segmentation_actual_column_names'"
|
|
287
|
-
# )
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
# class InvalidPredActCVColumnNamesForModelType(ValidationError):
|
|
291
|
-
# def __repr__(self) -> str:
|
|
292
|
-
# return "Invalid_CV_Prediction_or_Actual_Column_Names_for_Model_Type"
|
|
293
|
-
#
|
|
294
|
-
# def __init__(
|
|
295
|
-
# self,
|
|
296
|
-
# invalid_model_type: ModelTypes,
|
|
297
|
-
# ) -> None:
|
|
298
|
-
# self.invalid_model_type = invalid_model_type
|
|
299
|
-
#
|
|
300
|
-
# def error_message(self) -> str:
|
|
301
|
-
# return (
|
|
302
|
-
# f"Cannot use 'object_detection_prediction_column_names' or "
|
|
303
|
-
# f"'object_detection_actual_column_names' or "
|
|
304
|
-
# f"'semantic_segmentation_prediction_column_names' or "
|
|
305
|
-
# f"'semantic_segmentation_actual_column_names' or "
|
|
306
|
-
# f"'instance_segmentation_prediction_column_names' or "
|
|
307
|
-
# f"'instance_segmentation_actual_column_names' for {self.invalid_model_type} model "
|
|
308
|
-
# f"type. They are only allowed for ModelTypes.OBJECT_DETECTION models"
|
|
309
|
-
# )
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
# class MissingReqPredActColumnNamesForMultiClass(ValidationError):
|
|
313
|
-
# def __repr__(self) -> str:
|
|
314
|
-
# return "Missing_Required_Prediction_or_Actual_Column_Names_for_Multi_Class_Model_Type"
|
|
315
|
-
#
|
|
316
|
-
# def error_message(self) -> str:
|
|
317
|
-
# return (
|
|
318
|
-
# "For logging data for a multi class model, schema must specify: "
|
|
319
|
-
# "prediction_scores_column_name and/or actual_score_column_name. "
|
|
320
|
-
# "Optionally, you may include multi_class_threshold_scores_column_name"
|
|
321
|
-
# " (must include prediction_scores_column_name)"
|
|
322
|
-
# )
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
# class InvalidPredActColumnNamesForModelType(ValidationError):
|
|
326
|
-
# def __repr__(self) -> str:
|
|
327
|
-
# return "Invalid_Prediction_or_Actual_Column_Names_for_Model_Type"
|
|
328
|
-
#
|
|
329
|
-
# def __init__(
|
|
330
|
-
# self,
|
|
331
|
-
# invalid_model_type: ModelTypes,
|
|
332
|
-
# allowed_fields: List[str],
|
|
333
|
-
# wrong_columns: List[str],
|
|
334
|
-
# ) -> None:
|
|
335
|
-
# self.invalid_model_type = invalid_model_type
|
|
336
|
-
# self.allowed_fields = allowed_fields
|
|
337
|
-
# self.wrong_columns = wrong_columns
|
|
338
|
-
#
|
|
339
|
-
# def error_message(self) -> str:
|
|
340
|
-
# allowed_col_msg = ""
|
|
341
|
-
# if self.allowed_fields is not None:
|
|
342
|
-
# allowed_col_msg = f" Allowed Schema fields are {log_a_list(self.allowed_fields, 'and')}"
|
|
343
|
-
# return (
|
|
344
|
-
# f"Invalid Schema fields for {self.invalid_model_type} model type. {allowed_col_msg}"
|
|
345
|
-
# "The following columns of your dataframe are sent as an invalid schema field: "
|
|
346
|
-
# f"{log_a_list(self.wrong_columns, 'and')}"
|
|
347
|
-
# )
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
# class DuplicateColumnsInDataframe(ValidationError):
|
|
351
|
-
# def __repr__(self) -> str:
|
|
352
|
-
# return "Duplicate_Columns_In_Dataframe"
|
|
353
|
-
#
|
|
354
|
-
# def __init__(self, cols: Iterable) -> None:
|
|
355
|
-
# self.duplicate_cols = cols
|
|
356
|
-
#
|
|
357
|
-
# def error_message(self) -> str:
|
|
358
|
-
# return (
|
|
359
|
-
# "The following columns are present in the schema and have duplicates in the dataframe: "
|
|
360
|
-
# f"{self.duplicate_cols}. "
|
|
361
|
-
# )
|
|
362
|
-
|
|
363
|
-
|
|
364
35
|
class InvalidNumberOfEmbeddings(ValidationError):
|
|
365
36
|
"""Raised when number of embedding features exceeds the maximum allowed."""
|
|
366
37
|
|
arize/experiments/__init__.py
CHANGED
|
@@ -4,10 +4,10 @@ from arize.experiments.evaluators.types import (
|
|
|
4
4
|
EvaluationResult,
|
|
5
5
|
EvaluationResultFieldNames,
|
|
6
6
|
)
|
|
7
|
-
from arize.experiments.types import
|
|
7
|
+
from arize.experiments.types import ExperimentTaskFieldNames
|
|
8
8
|
|
|
9
9
|
__all__ = [
|
|
10
10
|
"EvaluationResult",
|
|
11
11
|
"EvaluationResultFieldNames",
|
|
12
|
-
"
|
|
12
|
+
"ExperimentTaskFieldNames",
|
|
13
13
|
]
|
arize/experiments/client.py
CHANGED
|
@@ -38,12 +38,13 @@ from arize.utils.size import get_payload_size_mb
|
|
|
38
38
|
if TYPE_CHECKING:
|
|
39
39
|
from opentelemetry.trace import Tracer
|
|
40
40
|
|
|
41
|
+
from arize._generated.api_client.api_client import ApiClient
|
|
41
42
|
from arize.config import SDKConfiguration
|
|
42
43
|
from arize.experiments.evaluators.base import Evaluators
|
|
43
44
|
from arize.experiments.evaluators.types import EvaluationResultFieldNames
|
|
44
45
|
from arize.experiments.types import (
|
|
45
46
|
ExperimentTask,
|
|
46
|
-
|
|
47
|
+
ExperimentTaskFieldNames,
|
|
47
48
|
)
|
|
48
49
|
|
|
49
50
|
logger = logging.getLogger(__name__)
|
|
@@ -61,20 +62,22 @@ class ExperimentsClient:
|
|
|
61
62
|
:class:`arize.config.SDKConfiguration`.
|
|
62
63
|
"""
|
|
63
64
|
|
|
64
|
-
def __init__(
|
|
65
|
+
def __init__(
|
|
66
|
+
self, *, sdk_config: SDKConfiguration, generated_client: ApiClient
|
|
67
|
+
) -> None:
|
|
65
68
|
"""
|
|
66
69
|
Args:
|
|
67
70
|
sdk_config: Resolved SDK configuration.
|
|
71
|
+
generated_client: Shared generated API client instance.
|
|
68
72
|
""" # noqa: D205, D212
|
|
69
73
|
self._sdk_config = sdk_config
|
|
70
74
|
from arize._generated import api_client as gen
|
|
71
75
|
|
|
72
|
-
|
|
76
|
+
# Use the provided client directly for both APIs
|
|
77
|
+
self._api = gen.ExperimentsApi(generated_client)
|
|
73
78
|
# TODO(Kiko): Space ID should not be needed,
|
|
74
79
|
# should work on server tech debt to remove this
|
|
75
|
-
self._datasets_api = gen.DatasetsApi(
|
|
76
|
-
self._sdk_config.get_generated_client()
|
|
77
|
-
)
|
|
80
|
+
self._datasets_api = gen.DatasetsApi(generated_client)
|
|
78
81
|
|
|
79
82
|
@prerelease_endpoint(key="experiments.list", stage=ReleaseStage.BETA)
|
|
80
83
|
def list(
|
|
@@ -114,7 +117,7 @@ class ExperimentsClient:
|
|
|
114
117
|
name: str,
|
|
115
118
|
dataset_id: str,
|
|
116
119
|
experiment_runs: list[dict[str, object]] | pd.DataFrame,
|
|
117
|
-
task_fields:
|
|
120
|
+
task_fields: ExperimentTaskFieldNames,
|
|
118
121
|
evaluator_columns: dict[str, EvaluationResultFieldNames] | None = None,
|
|
119
122
|
force_http: bool = False,
|
|
120
123
|
) -> models.Experiment:
|
|
@@ -141,7 +144,7 @@ class ExperimentsClient:
|
|
|
141
144
|
dataset_id: Dataset ID to attach the experiment to.
|
|
142
145
|
experiment_runs: Experiment runs either as:
|
|
143
146
|
- a list of JSON-like dicts, or
|
|
144
|
-
- a pandas
|
|
147
|
+
- a :class:`pandas.DataFrame`.
|
|
145
148
|
task_fields: Mapping that identifies the columns/fields containing the
|
|
146
149
|
task results (e.g. `example_id`, output fields).
|
|
147
150
|
evaluator_columns: Optional mapping describing evaluator result columns.
|
|
@@ -175,7 +178,6 @@ class ExperimentsClient:
|
|
|
175
178
|
from arize._generated import api_client as gen
|
|
176
179
|
|
|
177
180
|
data = experiment_df.to_dict(orient="records")
|
|
178
|
-
|
|
179
181
|
body = gen.ExperimentsCreateRequest(
|
|
180
182
|
name=name,
|
|
181
183
|
dataset_id=dataset_id,
|
|
@@ -230,7 +232,8 @@ class ExperimentsClient:
|
|
|
230
232
|
Args:
|
|
231
233
|
experiment_id: Experiment ID to delete.
|
|
232
234
|
|
|
233
|
-
Returns:
|
|
235
|
+
Returns:
|
|
236
|
+
This method returns None on success (common empty 204 response).
|
|
234
237
|
|
|
235
238
|
Raises:
|
|
236
239
|
arize._generated.api_client.exceptions.ApiException: If the REST API
|
|
@@ -358,6 +361,7 @@ class ExperimentsClient:
|
|
|
358
361
|
concurrency: int = 3,
|
|
359
362
|
set_global_tracer_provider: bool = False,
|
|
360
363
|
exit_on_error: bool = False,
|
|
364
|
+
timeout: int = 120,
|
|
361
365
|
) -> tuple[models.Experiment | None, pd.DataFrame]:
|
|
362
366
|
"""Run an experiment on a dataset and optionally upload results.
|
|
363
367
|
|
|
@@ -388,6 +392,7 @@ class ExperimentsClient:
|
|
|
388
392
|
provider for the experiment run.
|
|
389
393
|
exit_on_error: If True, stop on the first error encountered during
|
|
390
394
|
execution.
|
|
395
|
+
timeout: The timeout in seconds for each task execution. Defaults to 120.
|
|
391
396
|
|
|
392
397
|
Returns:
|
|
393
398
|
If `dry_run=True`, returns `(None, results_df)`.
|
|
@@ -506,6 +511,7 @@ class ExperimentsClient:
|
|
|
506
511
|
evaluators=evaluators,
|
|
507
512
|
concurrency=concurrency,
|
|
508
513
|
exit_on_error=exit_on_error,
|
|
514
|
+
timeout=timeout,
|
|
509
515
|
)
|
|
510
516
|
output_df = convert_default_columns_to_json_str(output_df)
|
|
511
517
|
output_df = convert_boolean_columns_to_str(output_df)
|
|
@@ -79,10 +79,10 @@ class Evaluator(ABC):
|
|
|
79
79
|
method and the asynchronous `async_evaluate` method, but it is not required.
|
|
80
80
|
|
|
81
81
|
Args:
|
|
82
|
-
dataset_row (
|
|
82
|
+
dataset_row (Mapping[str, JSONSerializable] | :obj:`None`): A row from the dataset.
|
|
83
83
|
input (ExampleInput): The input provided for evaluation.
|
|
84
|
-
output (
|
|
85
|
-
experiment_output (
|
|
84
|
+
output (TaskOutput | :obj:`None`): The output produced by the task.
|
|
85
|
+
experiment_output (TaskOutput | :obj:`None`): The experiment output for comparison.
|
|
86
86
|
dataset_output (ExampleOutput): The expected output from the dataset.
|
|
87
87
|
metadata (ExampleMetadata): Metadata associated with the example.
|
|
88
88
|
**kwargs (Any): Additional keyword arguments.
|
|
@@ -112,10 +112,10 @@ class Evaluator(ABC):
|
|
|
112
112
|
method and the synchronous `evaluate` method, but it is not required.
|
|
113
113
|
|
|
114
114
|
Args:
|
|
115
|
-
dataset_row (
|
|
115
|
+
dataset_row (Mapping[str, JSONSerializable] | :obj:`None`): A row from the dataset.
|
|
116
116
|
input (ExampleInput): The input provided for evaluation.
|
|
117
|
-
output (
|
|
118
|
-
experiment_output (
|
|
117
|
+
output (TaskOutput | :obj:`None`): The output produced by the task.
|
|
118
|
+
experiment_output (TaskOutput | :obj:`None`): The experiment output for comparison.
|
|
119
119
|
dataset_output (ExampleOutput): The expected output from the dataset.
|
|
120
120
|
metadata (ExampleMetadata): Metadata associated with the example.
|
|
121
121
|
**kwargs (Any): Additional keyword arguments.
|
|
@@ -94,7 +94,7 @@ class AsyncExecutor(Executor):
|
|
|
94
94
|
|
|
95
95
|
concurrency (int, optional): The number of concurrent consumers. Defaults to 3.
|
|
96
96
|
|
|
97
|
-
tqdm_bar_format (
|
|
97
|
+
tqdm_bar_format (str | :obj:`None`, optional): The format string for the progress bar.
|
|
98
98
|
Defaults to None.
|
|
99
99
|
|
|
100
100
|
max_retries (int, optional): The maximum number of times to retry on exceptions.
|
|
@@ -119,6 +119,7 @@ class AsyncExecutor(Executor):
|
|
|
119
119
|
exit_on_error: bool = True,
|
|
120
120
|
fallback_return_value: Unset | object = _unset,
|
|
121
121
|
termination_signal: signal.Signals = signal.SIGINT,
|
|
122
|
+
timeout: int = 120,
|
|
122
123
|
) -> None:
|
|
123
124
|
"""Initialize the async executor with configuration parameters.
|
|
124
125
|
|
|
@@ -130,6 +131,7 @@ class AsyncExecutor(Executor):
|
|
|
130
131
|
exit_on_error: Whether to exit on first error.
|
|
131
132
|
fallback_return_value: Value to return when execution fails.
|
|
132
133
|
termination_signal: Signal to handle for graceful termination.
|
|
134
|
+
timeout: Timeout for each task in seconds.
|
|
133
135
|
"""
|
|
134
136
|
self.generate = generation_fn
|
|
135
137
|
self.fallback_return_value = fallback_return_value
|
|
@@ -139,6 +141,7 @@ class AsyncExecutor(Executor):
|
|
|
139
141
|
self.exit_on_error = exit_on_error
|
|
140
142
|
self.base_priority = 0
|
|
141
143
|
self.termination_signal = termination_signal
|
|
144
|
+
self.timeout = timeout
|
|
142
145
|
|
|
143
146
|
async def producer(
|
|
144
147
|
self,
|
|
@@ -195,7 +198,7 @@ class AsyncExecutor(Executor):
|
|
|
195
198
|
)
|
|
196
199
|
done, _pending = await asyncio.wait(
|
|
197
200
|
[generate_task, termination_event_watcher],
|
|
198
|
-
timeout=
|
|
201
|
+
timeout=self.timeout,
|
|
199
202
|
return_when=asyncio.FIRST_COMPLETED,
|
|
200
203
|
)
|
|
201
204
|
|
|
@@ -341,7 +344,7 @@ class SyncExecutor(Executor):
|
|
|
341
344
|
generation_fn (Callable[[object], Any]): The generation function that takes an input and
|
|
342
345
|
returns an output.
|
|
343
346
|
|
|
344
|
-
tqdm_bar_format (
|
|
347
|
+
tqdm_bar_format (str | :obj:`None`, optional): The format string for the progress bar. Defaults
|
|
345
348
|
to None.
|
|
346
349
|
|
|
347
350
|
max_retries (int, optional): The maximum number of times to retry on exceptions. Defaults to
|
|
@@ -460,6 +463,7 @@ def get_executor_on_sync_context(
|
|
|
460
463
|
max_retries: int = 10,
|
|
461
464
|
exit_on_error: bool = True,
|
|
462
465
|
fallback_return_value: Unset | object = _unset,
|
|
466
|
+
timeout: int = 120,
|
|
463
467
|
) -> Executor:
|
|
464
468
|
"""Get an appropriate executor based on the current threading context.
|
|
465
469
|
|
|
@@ -475,6 +479,7 @@ def get_executor_on_sync_context(
|
|
|
475
479
|
max_retries: Maximum number of retry attempts. Defaults to 10.
|
|
476
480
|
exit_on_error: Whether to exit on first error. Defaults to True.
|
|
477
481
|
fallback_return_value: Value to return on failure. Defaults to unset.
|
|
482
|
+
timeout: Timeout for each task in seconds. Defaults to 120.
|
|
478
483
|
|
|
479
484
|
Returns:
|
|
480
485
|
An Executor instance configured for the current context.
|
|
@@ -513,6 +518,7 @@ def get_executor_on_sync_context(
|
|
|
513
518
|
max_retries=max_retries,
|
|
514
519
|
exit_on_error=exit_on_error,
|
|
515
520
|
fallback_return_value=fallback_return_value,
|
|
521
|
+
timeout=timeout,
|
|
516
522
|
)
|
|
517
523
|
logger.warning(
|
|
518
524
|
"🐌!! If running inside a notebook, patching the event loop with "
|
|
@@ -533,6 +539,7 @@ def get_executor_on_sync_context(
|
|
|
533
539
|
max_retries=max_retries,
|
|
534
540
|
exit_on_error=exit_on_error,
|
|
535
541
|
fallback_return_value=fallback_return_value,
|
|
542
|
+
timeout=timeout,
|
|
536
543
|
)
|
|
537
544
|
|
|
538
545
|
|
|
@@ -94,14 +94,14 @@ EvaluatorOutput = (
|
|
|
94
94
|
|
|
95
95
|
@dataclass
|
|
96
96
|
class EvaluationResultFieldNames:
|
|
97
|
-
"""Column names for mapping evaluation results in a DataFrame
|
|
97
|
+
"""Column names for mapping evaluation results in a :class:`pandas.DataFrame`.
|
|
98
98
|
|
|
99
99
|
Args:
|
|
100
100
|
score: Optional name of column containing evaluation scores
|
|
101
101
|
label: Optional name of column containing evaluation labels
|
|
102
102
|
explanation: Optional name of column containing evaluation explanations
|
|
103
103
|
metadata: Optional mapping of metadata keys to column names. If a column name
|
|
104
|
-
is None or empty string, the metadata key will be used as the column name.
|
|
104
|
+
is :obj:`None` or empty string, the metadata key will be used as the column name.
|
|
105
105
|
|
|
106
106
|
Examples:
|
|
107
107
|
>>> # Basic usage with score and label columns
|