arize 8.0.0a22__py3-none-any.whl → 8.0.0b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (171)
  1. arize/__init__.py +28 -19
  2. arize/_exporter/client.py +56 -37
  3. arize/_exporter/parsers/tracing_data_parser.py +41 -30
  4. arize/_exporter/validation.py +3 -3
  5. arize/_flight/client.py +207 -76
  6. arize/_generated/api_client/__init__.py +30 -6
  7. arize/_generated/api_client/api/__init__.py +1 -0
  8. arize/_generated/api_client/api/datasets_api.py +864 -190
  9. arize/_generated/api_client/api/experiments_api.py +167 -131
  10. arize/_generated/api_client/api/projects_api.py +1197 -0
  11. arize/_generated/api_client/api_client.py +2 -2
  12. arize/_generated/api_client/configuration.py +42 -34
  13. arize/_generated/api_client/exceptions.py +2 -2
  14. arize/_generated/api_client/models/__init__.py +15 -4
  15. arize/_generated/api_client/models/dataset.py +10 -10
  16. arize/_generated/api_client/models/dataset_example.py +111 -0
  17. arize/_generated/api_client/models/dataset_example_update.py +100 -0
  18. arize/_generated/api_client/models/dataset_version.py +13 -13
  19. arize/_generated/api_client/models/datasets_create_request.py +16 -8
  20. arize/_generated/api_client/models/datasets_examples_insert_request.py +100 -0
  21. arize/_generated/api_client/models/datasets_examples_list200_response.py +106 -0
  22. arize/_generated/api_client/models/datasets_examples_update_request.py +102 -0
  23. arize/_generated/api_client/models/datasets_list200_response.py +10 -4
  24. arize/_generated/api_client/models/experiment.py +14 -16
  25. arize/_generated/api_client/models/experiment_run.py +108 -0
  26. arize/_generated/api_client/models/experiment_run_create.py +102 -0
  27. arize/_generated/api_client/models/experiments_create_request.py +16 -10
  28. arize/_generated/api_client/models/experiments_list200_response.py +10 -4
  29. arize/_generated/api_client/models/experiments_runs_list200_response.py +19 -5
  30. arize/_generated/api_client/models/{error.py → pagination_metadata.py} +13 -11
  31. arize/_generated/api_client/models/primitive_value.py +172 -0
  32. arize/_generated/api_client/models/problem.py +100 -0
  33. arize/_generated/api_client/models/project.py +99 -0
  34. arize/_generated/api_client/models/{datasets_list_examples200_response.py → projects_create_request.py} +13 -11
  35. arize/_generated/api_client/models/projects_list200_response.py +106 -0
  36. arize/_generated/api_client/rest.py +2 -2
  37. arize/_generated/api_client/test/test_dataset.py +4 -2
  38. arize/_generated/api_client/test/test_dataset_example.py +56 -0
  39. arize/_generated/api_client/test/test_dataset_example_update.py +52 -0
  40. arize/_generated/api_client/test/test_dataset_version.py +7 -2
  41. arize/_generated/api_client/test/test_datasets_api.py +27 -13
  42. arize/_generated/api_client/test/test_datasets_create_request.py +8 -4
  43. arize/_generated/api_client/test/{test_datasets_list_examples200_response.py → test_datasets_examples_insert_request.py} +19 -15
  44. arize/_generated/api_client/test/test_datasets_examples_list200_response.py +66 -0
  45. arize/_generated/api_client/test/test_datasets_examples_update_request.py +61 -0
  46. arize/_generated/api_client/test/test_datasets_list200_response.py +9 -3
  47. arize/_generated/api_client/test/test_experiment.py +2 -4
  48. arize/_generated/api_client/test/test_experiment_run.py +56 -0
  49. arize/_generated/api_client/test/test_experiment_run_create.py +54 -0
  50. arize/_generated/api_client/test/test_experiments_api.py +6 -6
  51. arize/_generated/api_client/test/test_experiments_create_request.py +9 -6
  52. arize/_generated/api_client/test/test_experiments_list200_response.py +9 -5
  53. arize/_generated/api_client/test/test_experiments_runs_list200_response.py +15 -5
  54. arize/_generated/api_client/test/test_pagination_metadata.py +53 -0
  55. arize/_generated/api_client/test/{test_error.py → test_primitive_value.py} +13 -14
  56. arize/_generated/api_client/test/test_problem.py +57 -0
  57. arize/_generated/api_client/test/test_project.py +58 -0
  58. arize/_generated/api_client/test/test_projects_api.py +59 -0
  59. arize/_generated/api_client/test/test_projects_create_request.py +54 -0
  60. arize/_generated/api_client/test/test_projects_list200_response.py +70 -0
  61. arize/_generated/api_client_README.md +43 -29
  62. arize/_generated/protocol/flight/flight_pb2.py +400 -0
  63. arize/_lazy.py +27 -19
  64. arize/client.py +181 -58
  65. arize/config.py +324 -116
  66. arize/constants/__init__.py +1 -0
  67. arize/constants/config.py +11 -4
  68. arize/constants/ml.py +6 -4
  69. arize/constants/openinference.py +2 -0
  70. arize/constants/pyarrow.py +2 -0
  71. arize/constants/spans.py +3 -1
  72. arize/datasets/__init__.py +1 -0
  73. arize/datasets/client.py +304 -84
  74. arize/datasets/errors.py +32 -2
  75. arize/datasets/validation.py +18 -8
  76. arize/embeddings/__init__.py +2 -0
  77. arize/embeddings/auto_generator.py +23 -19
  78. arize/embeddings/base_generators.py +89 -36
  79. arize/embeddings/constants.py +2 -0
  80. arize/embeddings/cv_generators.py +26 -4
  81. arize/embeddings/errors.py +27 -5
  82. arize/embeddings/nlp_generators.py +43 -18
  83. arize/embeddings/tabular_generators.py +46 -31
  84. arize/embeddings/usecases.py +12 -2
  85. arize/exceptions/__init__.py +1 -0
  86. arize/exceptions/auth.py +11 -1
  87. arize/exceptions/base.py +29 -4
  88. arize/exceptions/models.py +21 -2
  89. arize/exceptions/parameters.py +31 -0
  90. arize/exceptions/spaces.py +12 -1
  91. arize/exceptions/types.py +86 -7
  92. arize/exceptions/values.py +220 -20
  93. arize/experiments/__init__.py +13 -0
  94. arize/experiments/client.py +394 -285
  95. arize/experiments/evaluators/__init__.py +1 -0
  96. arize/experiments/evaluators/base.py +74 -41
  97. arize/experiments/evaluators/exceptions.py +6 -3
  98. arize/experiments/evaluators/executors.py +121 -73
  99. arize/experiments/evaluators/rate_limiters.py +106 -57
  100. arize/experiments/evaluators/types.py +34 -7
  101. arize/experiments/evaluators/utils.py +65 -27
  102. arize/experiments/functions.py +103 -101
  103. arize/experiments/tracing.py +52 -44
  104. arize/experiments/types.py +56 -31
  105. arize/logging.py +54 -22
  106. arize/ml/__init__.py +1 -0
  107. arize/ml/batch_validation/__init__.py +1 -0
  108. arize/{models → ml}/batch_validation/errors.py +545 -67
  109. arize/{models → ml}/batch_validation/validator.py +344 -303
  110. arize/ml/bounded_executor.py +47 -0
  111. arize/{models → ml}/casting.py +118 -108
  112. arize/{models → ml}/client.py +339 -118
  113. arize/{models → ml}/proto.py +97 -42
  114. arize/{models → ml}/stream_validation.py +43 -15
  115. arize/ml/surrogate_explainer/__init__.py +1 -0
  116. arize/{models → ml}/surrogate_explainer/mimic.py +25 -10
  117. arize/{types.py → ml/types.py} +355 -354
  118. arize/pre_releases.py +44 -0
  119. arize/projects/__init__.py +1 -0
  120. arize/projects/client.py +134 -0
  121. arize/regions.py +40 -0
  122. arize/spans/__init__.py +1 -0
  123. arize/spans/client.py +204 -175
  124. arize/spans/columns.py +13 -0
  125. arize/spans/conversion.py +60 -37
  126. arize/spans/validation/__init__.py +1 -0
  127. arize/spans/validation/annotations/__init__.py +1 -0
  128. arize/spans/validation/annotations/annotations_validation.py +6 -4
  129. arize/spans/validation/annotations/dataframe_form_validation.py +13 -11
  130. arize/spans/validation/annotations/value_validation.py +35 -11
  131. arize/spans/validation/common/__init__.py +1 -0
  132. arize/spans/validation/common/argument_validation.py +33 -8
  133. arize/spans/validation/common/dataframe_form_validation.py +35 -9
  134. arize/spans/validation/common/errors.py +211 -11
  135. arize/spans/validation/common/value_validation.py +81 -14
  136. arize/spans/validation/evals/__init__.py +1 -0
  137. arize/spans/validation/evals/dataframe_form_validation.py +28 -8
  138. arize/spans/validation/evals/evals_validation.py +34 -4
  139. arize/spans/validation/evals/value_validation.py +26 -3
  140. arize/spans/validation/metadata/__init__.py +1 -1
  141. arize/spans/validation/metadata/argument_validation.py +14 -5
  142. arize/spans/validation/metadata/dataframe_form_validation.py +26 -10
  143. arize/spans/validation/metadata/value_validation.py +24 -10
  144. arize/spans/validation/spans/__init__.py +1 -0
  145. arize/spans/validation/spans/dataframe_form_validation.py +35 -14
  146. arize/spans/validation/spans/spans_validation.py +35 -4
  147. arize/spans/validation/spans/value_validation.py +78 -8
  148. arize/utils/__init__.py +1 -0
  149. arize/utils/arrow.py +31 -15
  150. arize/utils/cache.py +34 -6
  151. arize/utils/dataframe.py +20 -3
  152. arize/utils/online_tasks/__init__.py +2 -0
  153. arize/utils/online_tasks/dataframe_preprocessor.py +58 -47
  154. arize/utils/openinference_conversion.py +44 -5
  155. arize/utils/proto.py +10 -0
  156. arize/utils/size.py +5 -3
  157. arize/utils/types.py +105 -0
  158. arize/version.py +3 -1
  159. {arize-8.0.0a22.dist-info → arize-8.0.0b0.dist-info}/METADATA +13 -6
  160. arize-8.0.0b0.dist-info/RECORD +175 -0
  161. {arize-8.0.0a22.dist-info → arize-8.0.0b0.dist-info}/WHEEL +1 -1
  162. arize-8.0.0b0.dist-info/licenses/LICENSE +176 -0
  163. arize-8.0.0b0.dist-info/licenses/NOTICE +13 -0
  164. arize/_generated/protocol/flight/export_pb2.py +0 -61
  165. arize/_generated/protocol/flight/ingest_pb2.py +0 -365
  166. arize/models/__init__.py +0 -0
  167. arize/models/batch_validation/__init__.py +0 -0
  168. arize/models/bounded_executor.py +0 -34
  169. arize/models/surrogate_explainer/__init__.py +0 -0
  170. arize-8.0.0a22.dist-info/RECORD +0 -146
  171. arize-8.0.0a22.dist-info/licenses/LICENSE.md +0 -12
@@ -1,10 +1,12 @@
1
+ """Client implementation for managing experiments in the Arize platform."""
2
+
1
3
  from __future__ import annotations
2
4
 
3
- import hashlib
4
5
  import logging
5
- from typing import TYPE_CHECKING, Any, Dict, List, Tuple
6
+ from typing import TYPE_CHECKING
6
7
 
7
8
  import opentelemetry.sdk.trace as trace_sdk
9
+ import pandas as pd
8
10
  import pyarrow as pa
9
11
  from openinference.semconv.resource import ResourceAttributes
10
12
  from opentelemetry import trace
@@ -16,23 +18,16 @@ from opentelemetry.sdk.trace.export import (
16
18
  ConsoleSpanExporter,
17
19
  SimpleSpanProcessor,
18
20
  )
19
- from opentelemetry.trace import Tracer
20
21
 
21
22
  from arize._flight.client import ArizeFlightClient
22
23
  from arize._flight.types import FlightRequestType
23
24
  from arize._generated.api_client import models
24
- from arize.config import SDKConfiguration
25
25
  from arize.exceptions.base import INVALID_ARROW_CONVERSION_MSG
26
- from arize.experiments.evaluators.base import Evaluators
27
- from arize.experiments.evaluators.types import EvaluationResultFieldNames
28
26
  from arize.experiments.functions import (
29
27
  run_experiment,
30
28
  transform_to_experiment_format,
31
29
  )
32
- from arize.experiments.types import (
33
- ExperimentTask,
34
- ExperimentTaskResultFieldNames,
35
- )
30
+ from arize.pre_releases import ReleaseStage, prerelease_endpoint
36
31
  from arize.utils.cache import cache_resource, load_cached_resource
37
32
  from arize.utils.openinference_conversion import (
38
33
  convert_boolean_columns_to_str,
@@ -41,16 +36,36 @@ from arize.utils.openinference_conversion import (
41
36
  from arize.utils.size import get_payload_size_mb
42
37
 
43
38
  if TYPE_CHECKING:
44
- import pandas as pd
45
-
46
- from arize._generated.api_client.models.experiment import Experiment
47
-
39
+ from opentelemetry.trace import Tracer
40
+
41
+ from arize.config import SDKConfiguration
42
+ from arize.experiments.evaluators.base import Evaluators
43
+ from arize.experiments.evaluators.types import EvaluationResultFieldNames
44
+ from arize.experiments.types import (
45
+ ExperimentTask,
46
+ ExperimentTaskResultFieldNames,
47
+ )
48
48
 
49
49
  logger = logging.getLogger(__name__)
50
50
 
51
51
 
52
52
  class ExperimentsClient:
53
- def __init__(self, *, sdk_config: SDKConfiguration):
53
+ """Client for managing experiments including creation, execution, and result tracking.
54
+
55
+ This class is primarily intended for internal use within the SDK. Users are
56
+ highly encouraged to access resource-specific functionality via
57
+ :class:`arize.ArizeClient`.
58
+
59
+ The experiments client is a thin wrapper around the generated REST API client,
60
+ using the shared generated API client owned by
61
+ :class:`arize.config.SDKConfiguration`.
62
+ """
63
+
64
+ def __init__(self, *, sdk_config: SDKConfiguration) -> None:
65
+ """
66
+ Args:
67
+ sdk_config: Resolved SDK configuration.
68
+ """ # noqa: D205, D212
54
69
  self._sdk_config = sdk_config
55
70
  from arize._generated import api_client as gen
56
71
 
@@ -61,16 +76,277 @@ class ExperimentsClient:
61
76
  self._sdk_config.get_generated_client()
62
77
  )
63
78
 
64
- self.list = self._api.experiments_list
65
- self.get = self._api.experiments_get
66
- self.delete = self._api.experiments_delete
79
+ @prerelease_endpoint(key="experiments.list", stage=ReleaseStage.BETA)
80
+ def list(
81
+ self,
82
+ *,
83
+ dataset_id: str | None = None,
84
+ limit: int = 100,
85
+ cursor: str | None = None,
86
+ ) -> models.ExperimentsList200Response:
87
+ """List experiments the user has access to.
88
+
89
+ To filter experiments by the dataset they were run on, provide `dataset_id`.
90
+
91
+ Args:
92
+ dataset_id: Optional dataset ID to filter experiments.
93
+ limit: Maximum number of experiments to return. The server enforces an
94
+ upper bound.
95
+ cursor: Opaque pagination cursor returned from a previous response.
96
+
97
+ Returns:
98
+ A response object with the experiments and pagination information.
99
+
100
+ Raises:
101
+ arize._generated.api_client.exceptions.ApiException: If the REST API
102
+ returns an error response (e.g. 401/403/429).
103
+ """
104
+ return self._api.experiments_list(
105
+ dataset_id=dataset_id,
106
+ limit=limit,
107
+ cursor=cursor,
108
+ )
109
+
110
+ @prerelease_endpoint(key="experiments.create", stage=ReleaseStage.BETA)
111
+ def create(
112
+ self,
113
+ *,
114
+ name: str,
115
+ dataset_id: str,
116
+ experiment_runs: list[dict[str, object]] | pd.DataFrame,
117
+ task_fields: ExperimentTaskResultFieldNames,
118
+ evaluator_columns: dict[str, EvaluationResultFieldNames] | None = None,
119
+ force_http: bool = False,
120
+ ) -> models.Experiment:
121
+ """Create an experiment with one or more experiment runs.
122
+
123
+ Experiments are composed of runs. Each run must include:
124
+ - `example_id`: ID of an existing example in the dataset/version
125
+ - `output`: Model/task output for the matching example
126
+
127
+ You may include any additional user-defined fields per run (e.g. `model`,
128
+ `latency_ms`, `temperature`, `prompt`, `tool_calls`, etc.) that can be used
129
+ for analysis or filtering.
130
+
131
+ This method transforms the input runs into the server's expected experiment
132
+ format using `task_fields` and optional `evaluator_columns`.
133
+
134
+ Transport selection:
135
+ - If the payload is below the configured REST payload threshold (or
136
+ `force_http=True`), this method uploads via REST.
137
+ - Otherwise, it attempts a more efficient upload path via gRPC + Flight.
138
+
139
+ Args:
140
+ name: Experiment name. Must be unique within the target dataset.
141
+ dataset_id: Dataset ID to attach the experiment to.
142
+ experiment_runs: Experiment runs either as:
143
+ - a list of JSON-like dicts, or
144
+ - a pandas DataFrame.
145
+ task_fields: Mapping that identifies the columns/fields containing the
146
+ task results (e.g. `example_id`, output fields).
147
+ evaluator_columns: Optional mapping describing evaluator result columns.
148
+ force_http: If True, force REST upload even if the payload exceeds the
149
+ configured REST payload threshold.
150
+
151
+ Returns:
152
+ The created experiment object.
153
+
154
+ Raises:
155
+ TypeError: If `experiment_runs` is not a list of dicts or a DataFrame.
156
+ RuntimeError: If the Flight upload path is selected and the Flight request
157
+ fails.
158
+ arize._generated.api_client.exceptions.ApiException: If the REST API
159
+ returns an error response (e.g. 400/401/403/409/429).
160
+ """
161
+ if not isinstance(experiment_runs, list | pd.DataFrame):
162
+ raise TypeError(
163
+ "Experiment runs must be a list of dicts or a pandas DataFrame"
164
+ )
165
+ # transform experiment data to experiment format
166
+ experiment_df = transform_to_experiment_format(
167
+ experiment_runs, task_fields, evaluator_columns
168
+ )
169
+
170
+ below_threshold = (
171
+ get_payload_size_mb(experiment_runs)
172
+ <= self._sdk_config.max_http_payload_size_mb
173
+ )
174
+ if below_threshold or force_http:
175
+ from arize._generated import api_client as gen
176
+
177
+ data = experiment_df.to_dict(orient="records")
178
+
179
+ body = gen.ExperimentsCreateRequest(
180
+ name=name,
181
+ dataset_id=dataset_id,
182
+ experiment_runs=data, # type: ignore
183
+ )
184
+ return self._api.experiments_create(experiments_create_request=body)
185
+
186
+ # If we have too many examples, try to convert to a dataframe
187
+ # and log via gRPC + flight
188
+ logger.info(
189
+ f"Uploading {len(experiment_df)} experiment runs via REST may be slow. "
190
+ "Trying for more efficient upload via gRPC + Flight."
191
+ )
192
+
193
+ # TODO(Kiko): Space ID should not be needed,
194
+ # should work on server tech debt to remove this
195
+ dataset = self._datasets_api.datasets_get(dataset_id=dataset_id)
196
+ space_id = dataset.space_id
197
+
198
+ return self._create_experiment_via_flight(
199
+ name=name,
200
+ dataset_id=dataset_id,
201
+ space_id=space_id,
202
+ experiment_df=experiment_df,
203
+ )
204
+
205
+ @prerelease_endpoint(key="experiments.get", stage=ReleaseStage.BETA)
206
+ def get(self, *, experiment_id: str) -> models.Experiment:
207
+ """Get an experiment by ID.
208
+
209
+ The response does not include the experiment's runs. Use `list_runs()` to
210
+ retrieve runs for an experiment.
211
+
212
+ Args:
213
+ experiment_id: Experiment ID to retrieve.
214
+
215
+ Returns:
216
+ The experiment object.
217
+
218
+ Raises:
219
+ arize._generated.api_client.exceptions.ApiException: If the REST API
220
+ returns an error response (e.g. 401/403/404/429).
221
+ """
222
+ return self._api.experiments_get(experiment_id=experiment_id)
223
+
224
+ @prerelease_endpoint(key="experiments.delete", stage=ReleaseStage.BETA)
225
+ def delete(self, *, experiment_id: str) -> None:
226
+ """Delete an experiment by ID.
227
+
228
+ This operation is irreversible.
229
+
230
+ Args:
231
+ experiment_id: Experiment ID to delete.
232
+
233
+ Returns: This method returns None on success (common empty 204 response)
234
+
235
+ Raises:
236
+ arize._generated.api_client.exceptions.ApiException: If the REST API
237
+ returns an error response (e.g. 401/403/404/429).
238
+ """
239
+ return self._api.experiments_delete(
240
+ experiment_id=experiment_id,
241
+ )
242
+
243
+ @prerelease_endpoint(key="experiments.list_runs", stage=ReleaseStage.BETA)
244
+ def list_runs(
245
+ self,
246
+ *,
247
+ experiment_id: str,
248
+ limit: int = 100,
249
+ all: bool = False,
250
+ ) -> models.ExperimentsRunsList200Response:
251
+ """List runs for an experiment.
252
+
253
+ Runs are returned in insertion order.
254
+
255
+ Pagination notes:
256
+ - The response includes `pagination` for forward compatibility.
257
+ - Cursor pagination may not be fully implemented by the server yet.
258
+ - If `all=True`, this method retrieves all runs via the Flight path and
259
+ returns them in a single response with `has_more=False`.
260
+
261
+ Args:
262
+ experiment_id: Experiment ID to list runs for.
263
+ limit: Maximum number of runs to return when `all=False`. The server
264
+ enforces an upper bound.
265
+ all: If True, fetch all runs (ignores `limit`) via Flight and return a
266
+ single response.
267
+
268
+ Returns:
269
+ A response object containing `experiment_runs` and `pagination` metadata.
270
+
271
+ Raises:
272
+ RuntimeError: If the Flight request fails or returns no response when
273
+ `all=True`.
274
+ arize._generated.api_client.exceptions.ApiException: If the REST API
275
+ returns an error response when `all=False` (e.g. 401/403/404/429).
276
+ """
277
+ if not all:
278
+ return self._api.experiments_runs_list(
279
+ experiment_id=experiment_id,
280
+ limit=limit,
281
+ )
282
+
283
+ experiment = self.get(experiment_id=experiment_id)
284
+ experiment_updated_at = getattr(experiment, "updated_at", None)
285
+ # TODO(Kiko): Space ID should not be needed,
286
+ # should work on server tech debt to remove this
287
+ dataset = self._datasets_api.datasets_get(
288
+ dataset_id=experiment.dataset_id
289
+ )
290
+ space_id = dataset.space_id
291
+
292
+ experiment_df = None
293
+ # try to load dataset from cache
294
+ if self._sdk_config.enable_caching:
295
+ experiment_df = load_cached_resource(
296
+ cache_dir=self._sdk_config.cache_dir,
297
+ resource="experiment",
298
+ resource_id=experiment_id,
299
+ resource_updated_at=experiment_updated_at,
300
+ )
301
+ if experiment_df is not None:
302
+ return models.ExperimentsRunsList200Response(
303
+ experimentRuns=experiment_df.to_dict(orient="records"), # type: ignore
304
+ pagination=models.PaginationMetadata(
305
+ has_more=False, # Note that all=True
306
+ ),
307
+ )
67
308
 
68
- # Custom methods
69
- self.run = self._run_experiment
70
- self.create = self._create_experiment
71
- self.list_runs = self._api.experiments_runs_list
309
+ with ArizeFlightClient(
310
+ api_key=self._sdk_config.api_key,
311
+ host=self._sdk_config.flight_host,
312
+ port=self._sdk_config.flight_port,
313
+ scheme=self._sdk_config.flight_scheme,
314
+ request_verify=self._sdk_config.request_verify,
315
+ max_chunksize=self._sdk_config.pyarrow_max_chunksize,
316
+ ) as flight_client:
317
+ try:
318
+ experiment_df = flight_client.get_experiment_runs(
319
+ space_id=space_id,
320
+ experiment_id=experiment_id,
321
+ )
322
+ except Exception as e:
323
+ msg = f"Error during request: {e!s}"
324
+ logger.exception(msg)
325
+ raise RuntimeError(msg) from e
326
+ if experiment_df is None:
327
+ # This should not happen with proper Flight client implementation,
328
+ # but we handle it defensively
329
+ msg = "No response received from flight server during request"
330
+ logger.error(msg)
331
+ raise RuntimeError(msg)
72
332
 
73
- def _run_experiment(
333
+ # cache experiment for future use
334
+ cache_resource(
335
+ cache_dir=self._sdk_config.cache_dir,
336
+ resource="experiment",
337
+ resource_id=experiment_id,
338
+ resource_updated_at=experiment_updated_at,
339
+ resource_data=experiment_df,
340
+ )
341
+
342
+ return models.ExperimentsRunsList200Response(
343
+ experimentRuns=experiment_df.to_dict(orient="records"), # type: ignore
344
+ pagination=models.PaginationMetadata(
345
+ has_more=False, # Note that all=True
346
+ ),
347
+ )
348
+
349
+ def run(
74
350
  self,
75
351
  *,
76
352
  name: str,
@@ -82,37 +358,46 @@ class ExperimentsClient:
82
358
  concurrency: int = 3,
83
359
  set_global_tracer_provider: bool = False,
84
360
  exit_on_error: bool = False,
85
- ) -> Tuple[Experiment | None, pd.DataFrame] | None:
86
- """
87
- Run an experiment on a dataset and upload the results.
361
+ ) -> tuple[models.Experiment | None, pd.DataFrame]:
362
+ """Run an experiment on a dataset and optionally upload results.
363
+
364
+ This method executes a task against dataset examples, optionally evaluates
365
+ outputs, and (when `dry_run=False`) uploads results to Arize.
88
366
 
89
- This function initializes an experiment, retrieves or uses a provided dataset,
90
- runs the experiment with specified tasks and evaluators, and uploads the results.
367
+ High-level flow:
368
+ 1) Resolve the dataset and `space_id`.
369
+ 2) Download dataset examples (or load from cache if enabled).
370
+ 3) Run the task and evaluators with configurable concurrency.
371
+ 4) If not a dry run, upload experiment runs and return the created
372
+ experiment plus the results dataframe.
373
+
374
+ Notes:
375
+ - If `dry_run=True`, no data is uploaded and the returned experiment is
376
+ `None`.
377
+ - When `enable_caching=True`, dataset examples may be cached and reused.
91
378
 
92
379
  Args:
93
- experiment_name (str): The name of the experiment.
94
- task (ExperimentTask): The task to be performed in the experiment.
95
- dataset_id (Optional[str], optional): The ID of the dataset to use.
96
- Required if dataset_df and dataset_name are not provided. Defaults to None.
97
- dataset_name (Optional[str], optional): The name of the dataset to use.
98
- Used if dataset_df and dataset_id are not provided. Defaults to None.
99
- evaluators (Optional[Evaluators], optional): The evaluators to use in the experiment.
100
- Defaults to None.
101
- dry_run (bool): If True, the experiment result will not be uploaded to Arize.
102
- Defaults to False.
103
- concurrency (int): The number of concurrent tasks to run. Defaults to 3.
104
- set_global_tracer_provider (bool): If True, sets the global tracer provider for the experiment.
105
- Defaults to False.
106
- exit_on_error (bool): If True, the experiment will stop running on first occurrence of an error.
380
+ name: Experiment name.
381
+ dataset_id: Dataset ID to run the experiment against.
382
+ task: The task to execute for each dataset example.
383
+ evaluators: Optional evaluators used to score outputs.
384
+ dry_run: If True, do not upload results to Arize.
385
+ dry_run_count: Number of dataset rows to use when `dry_run=True`.
386
+ concurrency: Number of concurrent tasks to run.
387
+ set_global_tracer_provider: If True, sets the global OpenTelemetry tracer
388
+ provider for the experiment run.
389
+ exit_on_error: If True, stop on the first error encountered during
390
+ execution.
107
391
 
108
392
  Returns:
109
- Tuple[str, pd.DataFrame]:
110
- A tuple of experiment ID and experiment result DataFrame.
111
- If dry_run is True, the experiment ID will be an empty string.
393
+ If `dry_run=True`, returns `(None, results_df)`.
394
+ If `dry_run=False`, returns `(experiment, results_df)`.
112
395
 
113
396
  Raises:
114
- ValueError: If dataset_id and dataset_name are both not provided, or if the dataset is empty.
115
- RuntimeError: If experiment initialization, dataset download, or result upload fails.
397
+ RuntimeError: If Flight operations (init/download/upload) fail or return
398
+ no response.
399
+ pa.ArrowInvalid: If converting results to Arrow fails.
400
+ Exception: For unexpected errors during Arrow conversion.
116
401
  """
117
402
  # TODO(Kiko): Space ID should not be needed,
118
403
  # should work on server tech debt to remove this
@@ -122,8 +407,8 @@ class ExperimentsClient:
122
407
 
123
408
  with ArizeFlightClient(
124
409
  api_key=self._sdk_config.api_key,
125
- host=self._sdk_config.flight_server_host,
126
- port=self._sdk_config.flight_server_port,
410
+ host=self._sdk_config.flight_host,
411
+ port=self._sdk_config.flight_port,
127
412
  scheme=self._sdk_config.flight_scheme,
128
413
  request_verify=self._sdk_config.request_verify,
129
414
  max_chunksize=self._sdk_config.pyarrow_max_chunksize,
@@ -141,8 +426,8 @@ class ExperimentsClient:
141
426
  experiment_name=name,
142
427
  )
143
428
  except Exception as e:
144
- msg = f"Error during request: {str(e)}"
145
- logger.error(msg)
429
+ msg = f"Error during request: {e!s}"
430
+ logger.exception(msg)
146
431
  raise RuntimeError(msg) from e
147
432
 
148
433
  if response is None:
@@ -173,8 +458,8 @@ class ExperimentsClient:
173
458
  dataset_id=dataset_id,
174
459
  )
175
460
  except Exception as e:
176
- msg = f"Error during request: {str(e)}"
177
- logger.error(msg)
461
+ msg = f"Error during request: {e!s}"
462
+ logger.exception(msg)
178
463
  raise RuntimeError(msg) from e
179
464
  if dataset_df is None:
180
465
  # This should not happen with proper Flight client implementation,
@@ -232,12 +517,12 @@ class ExperimentsClient:
232
517
  logger.debug("Converting data to Arrow format")
233
518
  pa_table = pa.Table.from_pandas(output_df, preserve_index=False)
234
519
  except pa.ArrowInvalid as e:
235
- logger.error(f"{INVALID_ARROW_CONVERSION_MSG}: {str(e)}")
520
+ logger.exception(INVALID_ARROW_CONVERSION_MSG)
236
521
  raise pa.ArrowInvalid(
237
- f"Error converting to Arrow format: {str(e)}"
522
+ f"Error converting to Arrow format: {e!s}"
238
523
  ) from e
239
- except Exception as e:
240
- logger.error(f"Unexpected error creating Arrow table: {str(e)}")
524
+ except Exception:
525
+ logger.exception("Unexpected error creating Arrow table")
241
526
  raise
242
527
 
243
528
  request_type = FlightRequestType.LOG_EXPERIMENT_DATA
@@ -251,8 +536,8 @@ class ExperimentsClient:
251
536
  request_type=request_type,
252
537
  )
253
538
  except Exception as e:
254
- msg = f"Error during update request: {str(e)}"
255
- logger.error(msg)
539
+ msg = f"Error during update request: {e!s}"
540
+ logger.exception(msg)
256
541
  raise RuntimeError(msg) from e
257
542
 
258
543
  if post_resp is None:
@@ -267,200 +552,32 @@ class ExperimentsClient:
267
552
  )
268
553
  return experiment, output_df
269
554
 
270
- def _create_experiment(
271
- self,
272
- *,
273
- name: str,
274
- dataset_id: str,
275
- experiment_runs: List[Dict[str, Any]] | pd.DataFrame,
276
- task_fields: ExperimentTaskResultFieldNames,
277
- evaluator_columns: Dict[str, EvaluationResultFieldNames] | None = None,
278
- force_http: bool = False,
279
- ) -> Experiment:
280
- """
281
- Log an experiment to Arize.
282
-
283
- Args:
284
- space_id (str): The ID of the space where the experiment will be logged.
285
- experiment_name (str): The name of the experiment.
286
- experiment_df (pd.DataFrame): The data to be logged.
287
- task_columns (ExperimentTaskResultColumnNames): The column names for task results.
288
- evaluator_columns (Optional[Dict[str, EvaluationResultColumnNames]]):
289
- The column names for evaluator results.
290
- dataset_id (str, optional): The ID of the dataset associated with the experiment.
291
- Required if dataset_name is not provided. Defaults to "".
292
- dataset_name (str, optional): The name of the dataset associated with the experiment.
293
- Required if dataset_id is not provided. Defaults to "".
294
-
295
- Examples:
296
- >>> # Example DataFrame:
297
- >>> df = pd.DataFrame({
298
- ... "example_id": ["1", "2"],
299
- ... "result": ["success", "failure"],
300
- ... "accuracy": [0.95, 0.85],
301
- ... "ground_truth": ["A", "B"],
302
- ... "explanation_text": ["Good match", "Poor match"],
303
- ... "confidence": [0.9, 0.7],
304
- ... "model_version": ["v1", "v2"],
305
- ... "custom_metric": [0.8, 0.6],
306
- ...})
307
- ...
308
- >>> # Define column mappings for task
309
- >>> task_cols = ExperimentTaskResultColumnNames(
310
- ... example_id="example_id", result="result"
311
- ...)
312
- >>> # Define column mappings for evaluator
313
- >>> evaluator_cols = EvaluationResultColumnNames(
314
- ... score="accuracy",
315
- ... label="ground_truth",
316
- ... explanation="explanation_text",
317
- ... metadata={
318
- ... "confidence": None, # Will use "confidence" column
319
- ... "version": "model_version", # Will use "model_version" column
320
- ... "custom_metric": None, # Will use "custom_metric" column
321
- ... },
322
- ... )
323
- >>> # Use with ArizeDatasetsClient.log_experiment()
324
- >>> ArizeDatasetsClient.log_experiment(
325
- ... space_id="my_space_id",
326
- ... experiment_name="my_experiment",
327
- ... experiment_df=df,
328
- ... task_columns=task_cols,
329
- ... evaluator_columns={"my_evaluator": evaluator_cols},
330
- ... dataset_name="my_dataset_name",
331
- ... )
332
-
333
- Returns:
334
- Optional[str]: The ID of the logged experiment, or None if the logging failed.
335
- """
336
- if not isinstance(experiment_runs, (list, pd.DataFrame)):
337
- raise TypeError(
338
- "Examples must be a list of dicts or a pandas DataFrame"
339
- )
340
- # transform experiment data to experiment format
341
- experiment_df = transform_to_experiment_format(
342
- experiment_runs, task_fields, evaluator_columns
343
- )
344
-
345
- below_threshold = (
346
- get_payload_size_mb(experiment_runs)
347
- <= self._sdk_config.max_http_payload_size_mb
348
- )
349
- if below_threshold or force_http:
350
- from arize._generated import api_client as gen
351
-
352
- data = experiment_df.to_dict(orient="records")
353
-
354
- body = gen.ExperimentsCreateRequest(
355
- name=name,
356
- datasetId=dataset_id,
357
- experimentRuns=data,
358
- )
359
- return self._api.experiments_create(experiments_create_request=body)
360
-
361
- # If we have too many examples, try to convert to a dataframe
362
- # and log via gRPC + flight
363
- logger.info(
364
- f"Uploading {len(experiment_df)} experiment runs via REST may be slow. "
365
- "Trying for more efficient upload via gRPC + Flight."
366
- )
367
-
368
- # TODO(Kiko): Space ID should not be needed,
369
- # should work on server tech debt to remove this
370
- dataset = self._datasets_api.datasets_get(dataset_id=dataset_id)
371
- space_id = dataset.space_id
372
-
373
- return self._create_experiment_via_flight(
374
- name=name,
375
- dataset_id=dataset_id,
376
- space_id=space_id,
377
- experiment_df=experiment_df,
378
- )
379
-
380
- def _list_runs(
381
- self,
382
- *,
383
- experiment_id: str,
384
- limit: int = 100,
385
- all: bool = False,
386
- ):
387
- if not all:
388
- return self._api.experiments_runs_list(
389
- experiment_id=experiment_id,
390
- limit=limit,
391
- )
392
-
393
- experiment = self.get(experiment_id=experiment_id)
394
- experiment_updated_at = getattr(experiment, "updated_at", None)
395
- # TODO(Kiko): Space ID should not be needed,
396
- # should work on server tech debt to remove this
397
- dataset = self._datasets_api.datasets_get(
398
- dataset_id=experiment.dataset_id
399
- )
400
- space_id = dataset.space_id
401
-
402
- experiment_df = None
403
- # try to load dataset from cache
404
- if self._sdk_config.enable_caching:
405
- experiment_df = load_cached_resource(
406
- cache_dir=self._sdk_config.cache_dir,
407
- resource="experiment",
408
- resource_id=experiment_id,
409
- resource_updated_at=experiment_updated_at,
410
- )
411
- if experiment_df is not None:
412
- return models.ExperimentsRunsList200Response(
413
- experimentRuns=experiment_df.to_dict(orient="records")
414
- )
415
-
416
- with ArizeFlightClient(
417
- api_key=self._sdk_config.api_key,
418
- host=self._sdk_config.flight_server_host,
419
- port=self._sdk_config.flight_server_port,
420
- scheme=self._sdk_config.flight_scheme,
421
- request_verify=self._sdk_config.request_verify,
422
- max_chunksize=self._sdk_config.pyarrow_max_chunksize,
423
- ) as flight_client:
424
- try:
425
- experiment_df = flight_client.get_experiment_runs(
426
- space_id=space_id,
427
- experiment_id=experiment_id,
428
- )
429
- except Exception as e:
430
- msg = f"Error during request: {str(e)}"
431
- logger.error(msg)
432
- raise RuntimeError(msg) from e
433
- if experiment_df is None:
434
- # This should not happen with proper Flight client implementation,
435
- # but we handle it defensively
436
- msg = "No response received from flight server during request"
437
- logger.error(msg)
438
- raise RuntimeError(msg)
439
-
440
- # cache dataset for future use
441
- cache_resource(
442
- cache_dir=self._sdk_config.cache_dir,
443
- resource="dataset",
444
- resource_id=experiment_id,
445
- resource_updated_at=experiment_updated_at,
446
- resource_data=experiment_df,
447
- )
448
-
449
- return models.ExperimentsRunsList200Response(
450
- experimentRuns=experiment_df.to_dict(orient="records")
451
- )
452
-
453
555
  def _create_experiment_via_flight(
454
556
  self,
455
557
  name: str,
456
558
  dataset_id: str,
457
559
  space_id: str,
458
560
  experiment_df: pd.DataFrame,
459
- ) -> Experiment:
561
+ ) -> models.Experiment:
562
+ """Internal method to create an experiment using Flight protocol for large datasets."""
563
+ # Convert to Arrow table
564
+ try:
565
+ logger.debug("Converting data to Arrow format")
566
+ pa_table = pa.Table.from_pandas(experiment_df, preserve_index=False)
567
+ except pa.ArrowInvalid as e:
568
+ logger.exception(INVALID_ARROW_CONVERSION_MSG)
569
+ raise pa.ArrowInvalid(
570
+ f"Error converting to Arrow format: {e!s}"
571
+ ) from e
572
+ except Exception:
573
+ logger.exception("Unexpected error creating Arrow table")
574
+ raise
575
+
576
+ experiment_id = ""
460
577
  with ArizeFlightClient(
461
578
  api_key=self._sdk_config.api_key,
462
- host=self._sdk_config.flight_server_host,
463
- port=self._sdk_config.flight_server_port,
579
+ host=self._sdk_config.flight_host,
580
+ port=self._sdk_config.flight_port,
464
581
  scheme=self._sdk_config.flight_scheme,
465
582
  request_verify=self._sdk_config.request_verify,
466
583
  max_chunksize=self._sdk_config.pyarrow_max_chunksize,
@@ -474,8 +591,8 @@ class ExperimentsClient:
474
591
  experiment_name=name,
475
592
  )
476
593
  except Exception as e:
477
- msg = f"Error during request: {str(e)}"
478
- logger.error(msg)
594
+ msg = f"Error during request: {e!s}"
595
+ logger.exception(msg)
479
596
  raise RuntimeError(msg) from e
480
597
 
481
598
  if response is None:
@@ -484,49 +601,39 @@ class ExperimentsClient:
484
601
  msg = "No response received from flight server during request"
485
602
  logger.error(msg)
486
603
  raise RuntimeError(msg)
487
- experiment_id, _ = response
488
604
 
489
- # Convert to Arrow table
490
- try:
491
- logger.debug("Converting data to Arrow format")
492
- pa_table = pa.Table.from_pandas(experiment_df, preserve_index=False)
493
- except pa.ArrowInvalid as e:
494
- logger.error(f"{INVALID_ARROW_CONVERSION_MSG}: {str(e)}")
495
- raise pa.ArrowInvalid(
496
- f"Error converting to Arrow format: {str(e)}"
497
- ) from e
498
- except Exception as e:
499
- logger.error(f"Unexpected error creating Arrow table: {str(e)}")
500
- raise
605
+ experiment_id, _ = response
606
+ if not experiment_id:
607
+ msg = "No experiment ID received from flight server during request"
608
+ logger.error(msg)
609
+ raise RuntimeError(msg)
501
610
 
502
- request_type = FlightRequestType.LOG_EXPERIMENT_DATA
503
- post_resp = None
504
- try:
505
- post_resp = flight_client.log_arrow_table(
506
- space_id=space_id,
507
- pa_table=pa_table,
508
- dataset_id=dataset_id,
509
- experiment_name=experiment_id,
510
- request_type=request_type,
511
- )
512
- except Exception as e:
513
- msg = f"Error during update request: {str(e)}"
514
- logger.error(msg)
515
- raise RuntimeError(msg) from e
611
+ request_type = FlightRequestType.LOG_EXPERIMENT_DATA
612
+ post_resp = None
613
+ try:
614
+ post_resp = flight_client.log_arrow_table(
615
+ space_id=space_id,
616
+ pa_table=pa_table,
617
+ dataset_id=dataset_id,
618
+ experiment_name=name,
619
+ request_type=request_type,
620
+ )
621
+ except Exception as e:
622
+ msg = f"Error during update request: {e!s}"
623
+ logger.exception(msg)
624
+ raise RuntimeError(msg) from e
516
625
 
517
- if post_resp is None:
518
- # This should not happen with proper Flight client implementation,
519
- # but we handle it defensively
520
- msg = "No response received from flight server during request"
521
- logger.error(msg)
522
- raise RuntimeError(msg)
626
+ if post_resp is None:
627
+ # This should not happen with proper Flight client implementation,
628
+ # but we handle it defensively
629
+ msg = "No response received from flight server during request"
630
+ logger.error(msg)
631
+ raise RuntimeError(msg)
523
632
 
524
- experiment = self.get(
633
+ return self.get(
525
634
  experiment_id=str(post_resp.experiment_id) # type: ignore
526
635
  )
527
636
 
528
- return experiment
529
-
530
637
 
531
638
  def _get_tracer_resource(
532
639
  project_name: str,
@@ -535,7 +642,8 @@ def _get_tracer_resource(
535
642
  endpoint: str,
536
643
  dry_run: bool = False,
537
644
  set_global_tracer_provider: bool = False,
538
- ) -> Tuple[Tracer, Resource]:
645
+ ) -> tuple[Tracer, Resource]:
646
+ """Initialize and return an OpenTelemetry tracer and resource for experiment tracing."""
539
647
  resource = Resource(
540
648
  {
541
649
  ResourceAttributes.PROJECT_NAME: project_name,
@@ -547,7 +655,8 @@ def _get_tracer_resource(
547
655
  "arize-space-id": space_id,
548
656
  "arize-interface": "otel",
549
657
  }
550
- insecure = endpoint.startswith("http://")
658
+ use_tls = any(endpoint.startswith(v) for v in ["https://", "grpc+tls://"])
659
+ insecure = not use_tls
551
660
  exporter = (
552
661
  ConsoleSpanExporter()
553
662
  if dry_run