arize 8.0.0b1__py3-none-any.whl → 8.0.0b4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. arize/__init__.py +9 -2
  2. arize/_client_factory.py +50 -0
  3. arize/_exporter/client.py +18 -17
  4. arize/_exporter/parsers/tracing_data_parser.py +9 -4
  5. arize/_exporter/validation.py +1 -1
  6. arize/_flight/client.py +37 -17
  7. arize/_generated/api_client/api/datasets_api.py +6 -6
  8. arize/_generated/api_client/api/experiments_api.py +6 -6
  9. arize/_generated/api_client/api/projects_api.py +3 -3
  10. arize/_lazy.py +61 -10
  11. arize/client.py +66 -50
  12. arize/config.py +175 -48
  13. arize/constants/config.py +1 -0
  14. arize/constants/ml.py +9 -16
  15. arize/constants/spans.py +5 -10
  16. arize/datasets/client.py +45 -28
  17. arize/datasets/errors.py +1 -1
  18. arize/datasets/validation.py +2 -2
  19. arize/embeddings/auto_generator.py +16 -9
  20. arize/embeddings/base_generators.py +15 -9
  21. arize/embeddings/cv_generators.py +2 -2
  22. arize/embeddings/errors.py +2 -2
  23. arize/embeddings/nlp_generators.py +8 -8
  24. arize/embeddings/tabular_generators.py +6 -6
  25. arize/exceptions/base.py +0 -52
  26. arize/exceptions/config.py +22 -0
  27. arize/exceptions/parameters.py +1 -330
  28. arize/exceptions/values.py +8 -5
  29. arize/experiments/__init__.py +4 -0
  30. arize/experiments/client.py +31 -18
  31. arize/experiments/evaluators/base.py +12 -9
  32. arize/experiments/evaluators/executors.py +16 -7
  33. arize/experiments/evaluators/rate_limiters.py +3 -1
  34. arize/experiments/evaluators/types.py +9 -7
  35. arize/experiments/evaluators/utils.py +7 -5
  36. arize/experiments/functions.py +128 -58
  37. arize/experiments/tracing.py +4 -1
  38. arize/experiments/types.py +34 -31
  39. arize/logging.py +54 -33
  40. arize/ml/batch_validation/errors.py +10 -1004
  41. arize/ml/batch_validation/validator.py +351 -291
  42. arize/ml/bounded_executor.py +25 -6
  43. arize/ml/casting.py +51 -33
  44. arize/ml/client.py +43 -35
  45. arize/ml/proto.py +21 -22
  46. arize/ml/stream_validation.py +64 -27
  47. arize/ml/surrogate_explainer/mimic.py +18 -10
  48. arize/ml/types.py +27 -67
  49. arize/pre_releases.py +10 -6
  50. arize/projects/client.py +9 -4
  51. arize/py.typed +0 -0
  52. arize/regions.py +11 -11
  53. arize/spans/client.py +125 -31
  54. arize/spans/columns.py +32 -36
  55. arize/spans/conversion.py +12 -11
  56. arize/spans/validation/annotations/dataframe_form_validation.py +1 -1
  57. arize/spans/validation/annotations/value_validation.py +11 -14
  58. arize/spans/validation/common/argument_validation.py +3 -3
  59. arize/spans/validation/common/dataframe_form_validation.py +7 -7
  60. arize/spans/validation/common/value_validation.py +11 -14
  61. arize/spans/validation/evals/dataframe_form_validation.py +4 -4
  62. arize/spans/validation/evals/evals_validation.py +6 -6
  63. arize/spans/validation/evals/value_validation.py +1 -1
  64. arize/spans/validation/metadata/argument_validation.py +1 -1
  65. arize/spans/validation/metadata/dataframe_form_validation.py +2 -2
  66. arize/spans/validation/metadata/value_validation.py +23 -1
  67. arize/spans/validation/spans/dataframe_form_validation.py +2 -2
  68. arize/spans/validation/spans/spans_validation.py +6 -6
  69. arize/utils/arrow.py +38 -2
  70. arize/utils/cache.py +2 -2
  71. arize/utils/dataframe.py +4 -4
  72. arize/utils/online_tasks/dataframe_preprocessor.py +15 -11
  73. arize/utils/openinference_conversion.py +10 -10
  74. arize/utils/proto.py +0 -1
  75. arize/utils/types.py +6 -6
  76. arize/version.py +1 -1
  77. {arize-8.0.0b1.dist-info → arize-8.0.0b4.dist-info}/METADATA +32 -7
  78. {arize-8.0.0b1.dist-info → arize-8.0.0b4.dist-info}/RECORD +81 -78
  79. {arize-8.0.0b1.dist-info → arize-8.0.0b4.dist-info}/WHEEL +0 -0
  80. {arize-8.0.0b1.dist-info → arize-8.0.0b4.dist-info}/licenses/LICENSE +0 -0
  81. {arize-8.0.0b1.dist-info → arize-8.0.0b4.dist-info}/licenses/NOTICE +0 -0
@@ -3,159 +3,6 @@
3
3
  from arize.constants.ml import MAX_NUMBER_OF_EMBEDDINGS
4
4
  from arize.exceptions.base import ValidationError
5
5
 
6
- # class MissingPredictionIdColumnForDelayedRecords(ValidationError):
7
- # def __repr__(self) -> str:
8
- # return "Missing_Prediction_Id_Column_For_Delayed_Records"
9
- #
10
- # def __init__(self, has_actual_info, has_feature_importance_info) -> None:
11
- # self.has_actual_info = has_actual_info
12
- # self.has_feature_importance_info = has_feature_importance_info
13
- #
14
- # def error_message(self) -> str:
15
- # actual = "actual" if self.has_actual_info else ""
16
- # feat_imp = (
17
- # "feature importance" if self.has_feature_importance_info else ""
18
- # )
19
- # if self.has_actual_info and self.has_feature_importance_info:
20
- # msg = " and ".join([actual, feat_imp])
21
- # else:
22
- # msg = "".join([actual, feat_imp])
23
- #
24
- # return (
25
- # "Missing 'prediction_id_column_name'. While prediction id is optional for most cases, "
26
- # "it is required when sending delayed actuals, i.e. when sending actual or feature importances "
27
- # f"without predictions. In this case, {msg} information was found (without predictions). "
28
- # "To learn more about delayed joins, please see the docs at "
29
- # "https://docs.arize.com/arize/sending-data-guides/how-to-send-delayed-actuals"
30
- # )
31
-
32
-
33
- # class MissingColumns(ValidationError):
34
- # def __repr__(self) -> str:
35
- # return "Missing_Columns"
36
- #
37
- # def __init__(self, cols: Iterable) -> None:
38
- # self.missing_cols = set(cols)
39
- #
40
- # def error_message(self) -> str:
41
- # return (
42
- # "The following columns are declared in the schema "
43
- # "but are not found in the dataframe: "
44
- # f"{', '.join(map(str, self.missing_cols))}."
45
- # )
46
-
47
-
48
- # class MissingRequiredColumnsMetricsValidation(ValidationError):
49
- # """
50
- # This error is used only for model mapping validations.
51
- # """
52
- #
53
- # def __repr__(self) -> str:
54
- # return "Missing_Columns_Required_By_Metrics_Validation"
55
- #
56
- # def __init__(
57
- # self, model_type: ModelTypes, metrics: List[Metrics], cols: Iterable
58
- # ) -> None:
59
- # self.model_type = model_type
60
- # self.metrics = metrics
61
- # self.missing_cols = cols
62
- #
63
- # def error_message(self) -> str:
64
- # return (
65
- # f"For logging data for a {self.model_type.name} model with support for metrics "
66
- # f"{', '.join(m.name for m in self.metrics)}, "
67
- # f"schema must include: {', '.join(map(str, self.missing_cols))}."
68
- # )
69
-
70
-
71
- # class ReservedColumns(ValidationError):
72
- # def __repr__(self) -> str:
73
- # return "Reserved_Columns"
74
- #
75
- # def __init__(self, cols: Iterable) -> None:
76
- # self.reserved_columns = cols
77
- #
78
- # def error_message(self) -> str:
79
- # return (
80
- # "The following columns are reserved and can only be specified "
81
- # "in the proper fields of the schema: "
82
- # f"{', '.join(map(str, self.reserved_columns))}."
83
- # )
84
-
85
-
86
- # class InvalidModelTypeAndMetricsCombination(ValidationError):
87
- # """
88
- # This error is used only for model mapping validations.
89
- # """
90
- #
91
- # def __repr__(self) -> str:
92
- # return "Invalid_ModelType_And_Metrics_Combination"
93
- #
94
- # def __init__(
95
- # self,
96
- # model_type: ModelTypes,
97
- # metrics: List[Metrics],
98
- # suggested_model_metric_combinations: List[List[str]],
99
- # ) -> None:
100
- # self.model_type = model_type
101
- # self.metrics = metrics
102
- # self.suggested_combinations = suggested_model_metric_combinations
103
- #
104
- # def error_message(self) -> str:
105
- # valid_combos = ", or \n".join(
106
- # "[" + ", ".join(combo) + "]"
107
- # for combo in self.suggested_combinations
108
- # )
109
- # return (
110
- # f"Invalid combination of model type {self.model_type.name} and metrics: "
111
- # f"{', '.join(m.name for m in self.metrics)}. "
112
- # f"Valid Metric combinations for this model type:\n{valid_combos}.\n"
113
- # )
114
-
115
-
116
- # class InvalidShapSuffix(ValidationError):
117
- # def __repr__(self) -> str:
118
- # return "Invalid_SHAP_Suffix"
119
- #
120
- # def __init__(self, cols: Iterable) -> None:
121
- # self.invalid_column_names = cols
122
- #
123
- # def error_message(self) -> str:
124
- # return (
125
- # "The following features or tags must not be named with a `_shap` suffix: "
126
- # f"{', '.join(map(str, self.invalid_column_names))}."
127
- # )
128
-
129
-
130
- # class InvalidModelType(ValidationError):
131
- # def __repr__(self) -> str:
132
- # return "Invalid_Model_Type"
133
- #
134
- # def error_message(self) -> str:
135
- # return (
136
- # "Model type not valid. Choose one of the following: "
137
- # f"{', '.join('ModelTypes.' + mt.name for mt in ModelTypes)}. "
138
- # )
139
-
140
-
141
- # class InvalidEnvironment(ValidationError):
142
- # def __repr__(self) -> str:
143
- # return "Invalid_Environment"
144
- #
145
- # def error_message(self) -> str:
146
- # return (
147
- # "Environment not valid. Choose one of the following: "
148
- # f"{', '.join('Environments.' + env.name for env in Environments)}. "
149
- # )
150
-
151
-
152
- # class InvalidBatchId(ValidationError):
153
- # def __repr__(self) -> str:
154
- # return "Invalid_Batch_ID"
155
- #
156
- # def error_message(self) -> str:
157
- # return "Batch ID must be a nonempty string if logging to validation environment."
158
-
159
6
 
160
7
  class InvalidModelVersion(ValidationError):
161
8
  """Raised when model version is empty or invalid."""
@@ -169,14 +16,6 @@ class InvalidModelVersion(ValidationError):
169
16
  return "Model version must be a nonempty string."
170
17
 
171
18
 
172
- # class InvalidModelId(ValidationError):
173
- # def __repr__(self) -> str:
174
- # return "Invalid_Model_ID"
175
- #
176
- # def error_message(self) -> str:
177
- # return "Model ID must be a nonempty string."
178
-
179
-
180
19
  class InvalidProjectName(ValidationError):
181
20
  """Raised when project name is empty or invalid."""
182
21
 
@@ -193,174 +32,6 @@ class InvalidProjectName(ValidationError):
193
32
  )
194
33
 
195
34
 
196
- # class MissingPredActShap(ValidationError):
197
- # def __repr__(self) -> str:
198
- # return "Missing_Pred_or_Act_or_SHAP"
199
- #
200
- # def error_message(self) -> str:
201
- # return (
202
- # "The schema must specify at least one of the following: "
203
- # "prediction label, actual label, or SHAP value column names"
204
- # )
205
-
206
-
207
- # class MissingPreprodPredAct(ValidationError):
208
- # def __repr__(self) -> str:
209
- # return "Missing_Preproduction_Pred_and_Act"
210
- #
211
- # def error_message(self) -> str:
212
- # return "For logging pre-production data, the schema must specify both "
213
- # "prediction and actual label columns."
214
-
215
-
216
- # class MissingPreprodAct(ValidationError):
217
- # def __repr__(self) -> str:
218
- # return "Missing_Preproduction_Act"
219
- #
220
- # def error_message(self) -> str:
221
- # return "For logging pre-production data, the schema must specify actual label column."
222
-
223
-
224
- # class MissingPreprodPredActNumericAndCategorical(ValidationError):
225
- # def __repr__(self) -> str:
226
- # return "Missing_Preproduction_Pred_and_Act_Numeric_and_Categorical"
227
- #
228
- # def error_message(self) -> str:
229
- # return (
230
- # "For logging pre-production data for a numeric or a categorical model, "
231
- # "the schema must specify both prediction and actual label or score columns."
232
- # )
233
-
234
-
235
- # class MissingRequiredColumnsForRankingModel(ValidationError):
236
- # def __repr__(self) -> str:
237
- # return "Missing_Required_Columns_For_Ranking_Model"
238
- #
239
- # def error_message(self) -> str:
240
- # return (
241
- # "For logging data for a ranking model, schema must specify: "
242
- # "prediction_group_id_column_name and rank_column_name"
243
- # )
244
-
245
-
246
- # class MissingCVPredAct(ValidationError):
247
- # def __repr__(self) -> str:
248
- # return "Missing_CV_Prediction_or_Actual"
249
- #
250
- # def __init__(self, environment: Environments):
251
- # self.environment = environment
252
- #
253
- # def error_message(self) -> str:
254
- # if self.environment in (Environments.TRAINING, Environments.VALIDATION):
255
- # env = "pre-production"
256
- # opt = "and"
257
- # elif self.environment == Environments.PRODUCTION:
258
- # env = "production"
259
- # opt = "or"
260
- # else:
261
- # raise TypeError("Invalid environment")
262
- # return (
263
- # f"For logging {env} data for an Object Detection model,"
264
- # "the schema must specify one of: "
265
- # f"('object_detection_prediction_column_names' {opt} "
266
- # f"'object_detection_actual_column_names') "
267
- # f"or ('semantic_segmentation_prediction_column_names' {opt} "
268
- # f"'semantic_segmentation_actual_column_names') "
269
- # f"or ('instance_segmentation_prediction_column_names' {opt} "
270
- # f"'instance_segmentation_actual_column_names')"
271
- # )
272
-
273
-
274
- # class MultipleCVPredAct(ValidationError):
275
- # def __repr__(self) -> str:
276
- # return "Multiple_CV_Prediction_or_Actual"
277
- #
278
- # def __init__(self, environment: Environments):
279
- # self.environment = environment
280
- #
281
- # def error_message(self) -> str:
282
- # return (
283
- # "The schema must only specify one of the following: "
284
- # "'object_detection_prediction_column_names'/'object_detection_actual_column_names'"
285
- # "'semantic_segmentation_prediction_column_names'/'semantic_segmentation_actual_column_names'"
286
- # "'instance_segmentation_prediction_column_names'/'instance_segmentation_actual_column_names'"
287
- # )
288
-
289
-
290
- # class InvalidPredActCVColumnNamesForModelType(ValidationError):
291
- # def __repr__(self) -> str:
292
- # return "Invalid_CV_Prediction_or_Actual_Column_Names_for_Model_Type"
293
- #
294
- # def __init__(
295
- # self,
296
- # invalid_model_type: ModelTypes,
297
- # ) -> None:
298
- # self.invalid_model_type = invalid_model_type
299
- #
300
- # def error_message(self) -> str:
301
- # return (
302
- # f"Cannot use 'object_detection_prediction_column_names' or "
303
- # f"'object_detection_actual_column_names' or "
304
- # f"'semantic_segmentation_prediction_column_names' or "
305
- # f"'semantic_segmentation_actual_column_names' or "
306
- # f"'instance_segmentation_prediction_column_names' or "
307
- # f"'instance_segmentation_actual_column_names' for {self.invalid_model_type} model "
308
- # f"type. They are only allowed for ModelTypes.OBJECT_DETECTION models"
309
- # )
310
-
311
-
312
- # class MissingReqPredActColumnNamesForMultiClass(ValidationError):
313
- # def __repr__(self) -> str:
314
- # return "Missing_Required_Prediction_or_Actual_Column_Names_for_Multi_Class_Model_Type"
315
- #
316
- # def error_message(self) -> str:
317
- # return (
318
- # "For logging data for a multi class model, schema must specify: "
319
- # "prediction_scores_column_name and/or actual_score_column_name. "
320
- # "Optionally, you may include multi_class_threshold_scores_column_name"
321
- # " (must include prediction_scores_column_name)"
322
- # )
323
-
324
-
325
- # class InvalidPredActColumnNamesForModelType(ValidationError):
326
- # def __repr__(self) -> str:
327
- # return "Invalid_Prediction_or_Actual_Column_Names_for_Model_Type"
328
- #
329
- # def __init__(
330
- # self,
331
- # invalid_model_type: ModelTypes,
332
- # allowed_fields: List[str],
333
- # wrong_columns: List[str],
334
- # ) -> None:
335
- # self.invalid_model_type = invalid_model_type
336
- # self.allowed_fields = allowed_fields
337
- # self.wrong_columns = wrong_columns
338
- #
339
- # def error_message(self) -> str:
340
- # allowed_col_msg = ""
341
- # if self.allowed_fields is not None:
342
- # allowed_col_msg = f" Allowed Schema fields are {log_a_list(self.allowed_fields, 'and')}"
343
- # return (
344
- # f"Invalid Schema fields for {self.invalid_model_type} model type. {allowed_col_msg}"
345
- # "The following columns of your dataframe are sent as an invalid schema field: "
346
- # f"{log_a_list(self.wrong_columns, 'and')}"
347
- # )
348
-
349
-
350
- # class DuplicateColumnsInDataframe(ValidationError):
351
- # def __repr__(self) -> str:
352
- # return "Duplicate_Columns_In_Dataframe"
353
- #
354
- # def __init__(self, cols: Iterable) -> None:
355
- # self.duplicate_cols = cols
356
- #
357
- # def error_message(self) -> str:
358
- # return (
359
- # "The following columns are present in the schema and have duplicates in the dataframe: "
360
- # f"{self.duplicate_cols}. "
361
- # )
362
-
363
-
364
35
  class InvalidNumberOfEmbeddings(ValidationError):
365
36
  """Raised when number of embedding features exceeds the maximum allowed."""
366
37
 
@@ -390,7 +61,7 @@ class InvalidValueType(Exception):
390
61
  def __init__(
391
62
  self,
392
63
  value_name: str,
393
- value: bool | int | float | str,
64
+ value: object,
394
65
  correct_type: str,
395
66
  ) -> None:
396
67
  """Initialize the exception with value type validation context.
@@ -533,14 +533,15 @@ class InvalidMultiClassClassNameLength(ValidationError):
533
533
  err_msg = ""
534
534
  for col, class_names in self.invalid_col_class_name.items():
535
535
  # limit to 10
536
- class_names = (
536
+ class_names_list = (
537
537
  list(class_names)[:10]
538
538
  if len(class_names) > 10
539
539
  else list(class_names)
540
540
  )
541
541
  err_msg += (
542
- f"Found some invalid class names: {log_a_list(class_names, 'and')} in the {col} column. Class"
543
- f" names must have at least one character and less than {MAX_MULTI_CLASS_NAME_LENGTH}.\n"
542
+ f"Found some invalid class names: {log_a_list(class_names_list, 'and')} "
543
+ f"in the {col} column. Class names must have at least one character and "
544
+ f"less than {MAX_MULTI_CLASS_NAME_LENGTH}.\n"
544
545
  )
545
546
  return err_msg
546
547
 
@@ -565,9 +566,11 @@ class InvalidMultiClassPredScoreValue(ValidationError):
565
566
  err_msg = ""
566
567
  for col, scores in self.invalid_col_class_scores.items():
567
568
  # limit to 10
568
- scores = list(scores)[:10] if len(scores) > 10 else list(scores)
569
+ scores_list = (
570
+ list(scores)[:10] if len(scores) > 10 else list(scores)
571
+ )
569
572
  err_msg += (
570
- f"Found some invalid scores: {log_a_list(scores, 'and')} in the {col} column that was "
573
+ f"Found some invalid scores: {log_a_list(scores_list, 'and')} in the {col} column that was "
571
574
  "invalid. All scores (values in dictionary) must be between 0 and 1, inclusive. \n"
572
575
  )
573
576
  return err_msg
@@ -1,5 +1,8 @@
1
1
  """Experiment tracking and evaluation functionality for the Arize SDK."""
2
2
 
3
+ from arize.experiments.evaluators.base import (
4
+ Evaluator,
5
+ )
3
6
  from arize.experiments.evaluators.types import (
4
7
  EvaluationResult,
5
8
  EvaluationResultFieldNames,
@@ -9,5 +12,6 @@ from arize.experiments.types import ExperimentTaskFieldNames
9
12
  __all__ = [
10
13
  "EvaluationResult",
11
14
  "EvaluationResultFieldNames",
15
+ "Evaluator",
12
16
  "ExperimentTaskFieldNames",
13
17
  ]
@@ -3,7 +3,7 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import logging
6
- from typing import TYPE_CHECKING
6
+ from typing import TYPE_CHECKING, cast
7
7
 
8
8
  import opentelemetry.sdk.trace as trace_sdk
9
9
  import pandas as pd
@@ -36,8 +36,13 @@ from arize.utils.openinference_conversion import (
36
36
  from arize.utils.size import get_payload_size_mb
37
37
 
38
38
  if TYPE_CHECKING:
39
+ # builtins is needed to use builtins.list in type annotations because
40
+ # the class has a list() method that shadows the built-in list type
41
+ import builtins
42
+
39
43
  from opentelemetry.trace import Tracer
40
44
 
45
+ from arize._generated.api_client.api_client import ApiClient
41
46
  from arize.config import SDKConfiguration
42
47
  from arize.experiments.evaluators.base import Evaluators
43
48
  from arize.experiments.evaluators.types import EvaluationResultFieldNames
@@ -61,20 +66,22 @@ class ExperimentsClient:
61
66
  :class:`arize.config.SDKConfiguration`.
62
67
  """
63
68
 
64
- def __init__(self, *, sdk_config: SDKConfiguration) -> None:
69
+ def __init__(
70
+ self, *, sdk_config: SDKConfiguration, generated_client: ApiClient
71
+ ) -> None:
65
72
  """
66
73
  Args:
67
74
  sdk_config: Resolved SDK configuration.
75
+ generated_client: Shared generated API client instance.
68
76
  """ # noqa: D205, D212
69
77
  self._sdk_config = sdk_config
70
78
  from arize._generated import api_client as gen
71
79
 
72
- self._api = gen.ExperimentsApi(self._sdk_config.get_generated_client())
80
+ # Use the provided client directly for both APIs
81
+ self._api = gen.ExperimentsApi(generated_client)
73
82
  # TODO(Kiko): Space ID should not be needed,
74
83
  # should work on server tech debt to remove this
75
- self._datasets_api = gen.DatasetsApi(
76
- self._sdk_config.get_generated_client()
77
- )
84
+ self._datasets_api = gen.DatasetsApi(generated_client)
78
85
 
79
86
  @prerelease_endpoint(key="experiments.list", stage=ReleaseStage.BETA)
80
87
  def list(
@@ -113,7 +120,7 @@ class ExperimentsClient:
113
120
  *,
114
121
  name: str,
115
122
  dataset_id: str,
116
- experiment_runs: list[dict[str, object]] | pd.DataFrame,
123
+ experiment_runs: builtins.list[dict[str, object]] | pd.DataFrame,
117
124
  task_fields: ExperimentTaskFieldNames,
118
125
  evaluator_columns: dict[str, EvaluationResultFieldNames] | None = None,
119
126
  force_http: bool = False,
@@ -141,7 +148,7 @@ class ExperimentsClient:
141
148
  dataset_id: Dataset ID to attach the experiment to.
142
149
  experiment_runs: Experiment runs either as:
143
150
  - a list of JSON-like dicts, or
144
- - a pandas DataFrame.
151
+ - a :class:`pandas.DataFrame`.
145
152
  task_fields: Mapping that identifies the columns/fields containing the
146
153
  task results (e.g. `example_id`, output fields).
147
154
  evaluator_columns: Optional mapping describing evaluator result columns.
@@ -178,7 +185,7 @@ class ExperimentsClient:
178
185
  body = gen.ExperimentsCreateRequest(
179
186
  name=name,
180
187
  dataset_id=dataset_id,
181
- experiment_runs=data, # type: ignore
188
+ experiment_runs=cast("list[gen.ExperimentRunCreate]", data),
182
189
  )
183
190
  return self._api.experiments_create(experiments_create_request=body)
184
191
 
@@ -229,7 +236,8 @@ class ExperimentsClient:
229
236
  Args:
230
237
  experiment_id: Experiment ID to delete.
231
238
 
232
- Returns: This method returns None on success (common empty 204 response)
239
+ Returns:
240
+ This method returns None on success (common empty 204 response).
233
241
 
234
242
  Raises:
235
243
  arize._generated.api_client.exceptions.ApiException: If the REST API
@@ -299,7 +307,10 @@ class ExperimentsClient:
299
307
  )
300
308
  if experiment_df is not None:
301
309
  return models.ExperimentsRunsList200Response(
302
- experimentRuns=experiment_df.to_dict(orient="records"), # type: ignore
310
+ experiment_runs=cast(
311
+ "list[models.ExperimentRun]",
312
+ experiment_df.to_dict(orient="records"),
313
+ ),
303
314
  pagination=models.PaginationMetadata(
304
315
  has_more=False, # Note that all=True
305
316
  ),
@@ -339,7 +350,10 @@ class ExperimentsClient:
339
350
  )
340
351
 
341
352
  return models.ExperimentsRunsList200Response(
342
- experimentRuns=experiment_df.to_dict(orient="records"), # type: ignore
353
+ experiment_runs=cast(
354
+ "list[models.ExperimentRun]",
355
+ experiment_df.to_dict(orient="records"),
356
+ ),
343
357
  pagination=models.PaginationMetadata(
344
358
  has_more=False, # Note that all=True
345
359
  ),
@@ -357,6 +371,7 @@ class ExperimentsClient:
357
371
  concurrency: int = 3,
358
372
  set_global_tracer_provider: bool = False,
359
373
  exit_on_error: bool = False,
374
+ timeout: int = 120,
360
375
  ) -> tuple[models.Experiment | None, pd.DataFrame]:
361
376
  """Run an experiment on a dataset and optionally upload results.
362
377
 
@@ -387,6 +402,7 @@ class ExperimentsClient:
387
402
  provider for the experiment run.
388
403
  exit_on_error: If True, stop on the first error encountered during
389
404
  execution.
405
+ timeout: The timeout in seconds for each task execution. Defaults to 120.
390
406
 
391
407
  Returns:
392
408
  If `dry_run=True`, returns `(None, results_df)`.
@@ -505,6 +521,7 @@ class ExperimentsClient:
505
521
  evaluators=evaluators,
506
522
  concurrency=concurrency,
507
523
  exit_on_error=exit_on_error,
524
+ timeout=timeout,
508
525
  )
509
526
  output_df = convert_default_columns_to_json_str(output_df)
510
527
  output_df = convert_boolean_columns_to_str(output_df)
@@ -546,9 +563,7 @@ class ExperimentsClient:
546
563
  logger.error(msg)
547
564
  raise RuntimeError(msg)
548
565
 
549
- experiment = self.get(
550
- experiment_id=str(post_resp.experiment_id) # type: ignore
551
- )
566
+ experiment = self.get(experiment_id=str(post_resp.experiment_id))
552
567
  return experiment, output_df
553
568
 
554
569
  def _create_experiment_via_flight(
@@ -629,9 +644,7 @@ class ExperimentsClient:
629
644
  logger.error(msg)
630
645
  raise RuntimeError(msg)
631
646
 
632
- return self.get(
633
- experiment_id=str(post_resp.experiment_id) # type: ignore
634
- )
647
+ return self.get(experiment_id=str(post_resp.experiment_id))
635
648
 
636
649
 
637
650
  def _get_tracer_resource(
@@ -7,7 +7,7 @@ import inspect
7
7
  from abc import ABC
8
8
  from collections.abc import Awaitable, Callable, Mapping, Sequence
9
9
  from types import MappingProxyType
10
- from typing import TYPE_CHECKING
10
+ from typing import TYPE_CHECKING, Any, cast
11
11
 
12
12
  from arize.experiments.evaluators.types import (
13
13
  AnnotatorKind,
@@ -79,10 +79,10 @@ class Evaluator(ABC):
79
79
  method and the asynchronous `async_evaluate` method, but it is not required.
80
80
 
81
81
  Args:
82
- dataset_row (Optional[Mapping[str, JSONSerializable]]): A row from the dataset.
82
+ dataset_row (Mapping[str, JSONSerializable] | :obj:`None`): A row from the dataset.
83
83
  input (ExampleInput): The input provided for evaluation.
84
- output (Optional[TaskOutput]): The output produced by the task.
85
- experiment_output (Optional[TaskOutput]): The experiment output for comparison.
84
+ output (TaskOutput | :obj:`None`): The output produced by the task.
85
+ experiment_output (TaskOutput | :obj:`None`): The experiment output for comparison.
86
86
  dataset_output (ExampleOutput): The expected output from the dataset.
87
87
  metadata (ExampleMetadata): Metadata associated with the example.
88
88
  **kwargs (Any): Additional keyword arguments.
@@ -112,10 +112,10 @@ class Evaluator(ABC):
112
112
  method and the synchronous `evaluate` method, but it is not required.
113
113
 
114
114
  Args:
115
- dataset_row (Optional[Mapping[str, JSONSerializable]]): A row from the dataset.
115
+ dataset_row (Mapping[str, JSONSerializable] | :obj:`None`): A row from the dataset.
116
116
  input (ExampleInput): The input provided for evaluation.
117
- output (Optional[TaskOutput]): The output produced by the task.
118
- experiment_output (Optional[TaskOutput]): The experiment output for comparison.
117
+ output (TaskOutput | :obj:`None`): The output produced by the task.
118
+ experiment_output (TaskOutput | :obj:`None`): The experiment output for comparison.
119
119
  dataset_output (ExampleOutput): The expected output from the dataset.
120
120
  metadata (ExampleMetadata): Metadata associated with the example.
121
121
  **kwargs (Any): Additional keyword arguments.
@@ -162,7 +162,9 @@ class Evaluator(ABC):
162
162
  f"`evaluate()` method should be callable, got {type(evaluate)}"
163
163
  )
164
164
  # need to remove the first param, i.e. `self`
165
- _validate_sig(functools.partial(evaluate, None), "evaluate")
165
+ _validate_sig(
166
+ functools.partial(evaluate, cast("Any", None)), "evaluate"
167
+ )
166
168
  return
167
169
  if async_evaluate := super_cls.__dict__.get(
168
170
  Evaluator.async_evaluate.__name__
@@ -175,7 +177,8 @@ class Evaluator(ABC):
175
177
  )
176
178
  # need to remove the first param, i.e. `self`
177
179
  _validate_sig(
178
- functools.partial(async_evaluate, None), "async_evaluate"
180
+ functools.partial(async_evaluate, cast("Any", None)),
181
+ "async_evaluate",
179
182
  )
180
183
  return
181
184
  raise ValueError(