arize: arize-8.0.0a22-py3-none-any.whl → arize-8.0.0a23-py3-none-any.whl

This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. arize/__init__.py +17 -9
  2. arize/_exporter/client.py +55 -36
  3. arize/_exporter/parsers/tracing_data_parser.py +41 -30
  4. arize/_exporter/validation.py +3 -3
  5. arize/_flight/client.py +207 -76
  6. arize/_generated/api_client/__init__.py +30 -6
  7. arize/_generated/api_client/api/__init__.py +1 -0
  8. arize/_generated/api_client/api/datasets_api.py +864 -190
  9. arize/_generated/api_client/api/experiments_api.py +167 -131
  10. arize/_generated/api_client/api/projects_api.py +1197 -0
  11. arize/_generated/api_client/api_client.py +2 -2
  12. arize/_generated/api_client/configuration.py +42 -34
  13. arize/_generated/api_client/exceptions.py +2 -2
  14. arize/_generated/api_client/models/__init__.py +15 -4
  15. arize/_generated/api_client/models/dataset.py +10 -10
  16. arize/_generated/api_client/models/dataset_example.py +111 -0
  17. arize/_generated/api_client/models/dataset_example_update.py +100 -0
  18. arize/_generated/api_client/models/dataset_version.py +13 -13
  19. arize/_generated/api_client/models/datasets_create_request.py +16 -8
  20. arize/_generated/api_client/models/datasets_examples_insert_request.py +100 -0
  21. arize/_generated/api_client/models/datasets_examples_list200_response.py +106 -0
  22. arize/_generated/api_client/models/datasets_examples_update_request.py +102 -0
  23. arize/_generated/api_client/models/datasets_list200_response.py +10 -4
  24. arize/_generated/api_client/models/experiment.py +14 -16
  25. arize/_generated/api_client/models/experiment_run.py +108 -0
  26. arize/_generated/api_client/models/experiment_run_create.py +102 -0
  27. arize/_generated/api_client/models/experiments_create_request.py +16 -10
  28. arize/_generated/api_client/models/experiments_list200_response.py +10 -4
  29. arize/_generated/api_client/models/experiments_runs_list200_response.py +19 -5
  30. arize/_generated/api_client/models/{error.py → pagination_metadata.py} +13 -11
  31. arize/_generated/api_client/models/primitive_value.py +172 -0
  32. arize/_generated/api_client/models/problem.py +100 -0
  33. arize/_generated/api_client/models/project.py +99 -0
  34. arize/_generated/api_client/models/{datasets_list_examples200_response.py → projects_create_request.py} +13 -11
  35. arize/_generated/api_client/models/projects_list200_response.py +106 -0
  36. arize/_generated/api_client/rest.py +2 -2
  37. arize/_generated/api_client/test/test_dataset.py +4 -2
  38. arize/_generated/api_client/test/test_dataset_example.py +56 -0
  39. arize/_generated/api_client/test/test_dataset_example_update.py +52 -0
  40. arize/_generated/api_client/test/test_dataset_version.py +7 -2
  41. arize/_generated/api_client/test/test_datasets_api.py +27 -13
  42. arize/_generated/api_client/test/test_datasets_create_request.py +8 -4
  43. arize/_generated/api_client/test/{test_datasets_list_examples200_response.py → test_datasets_examples_insert_request.py} +19 -15
  44. arize/_generated/api_client/test/test_datasets_examples_list200_response.py +66 -0
  45. arize/_generated/api_client/test/test_datasets_examples_update_request.py +61 -0
  46. arize/_generated/api_client/test/test_datasets_list200_response.py +9 -3
  47. arize/_generated/api_client/test/test_experiment.py +2 -4
  48. arize/_generated/api_client/test/test_experiment_run.py +56 -0
  49. arize/_generated/api_client/test/test_experiment_run_create.py +54 -0
  50. arize/_generated/api_client/test/test_experiments_api.py +6 -6
  51. arize/_generated/api_client/test/test_experiments_create_request.py +9 -6
  52. arize/_generated/api_client/test/test_experiments_list200_response.py +9 -5
  53. arize/_generated/api_client/test/test_experiments_runs_list200_response.py +15 -5
  54. arize/_generated/api_client/test/test_pagination_metadata.py +53 -0
  55. arize/_generated/api_client/test/{test_error.py → test_primitive_value.py} +13 -14
  56. arize/_generated/api_client/test/test_problem.py +57 -0
  57. arize/_generated/api_client/test/test_project.py +58 -0
  58. arize/_generated/api_client/test/test_projects_api.py +59 -0
  59. arize/_generated/api_client/test/test_projects_create_request.py +54 -0
  60. arize/_generated/api_client/test/test_projects_list200_response.py +70 -0
  61. arize/_generated/api_client_README.md +43 -29
  62. arize/_generated/protocol/flight/flight_pb2.py +400 -0
  63. arize/_lazy.py +27 -19
  64. arize/client.py +268 -55
  65. arize/config.py +365 -116
  66. arize/constants/__init__.py +1 -0
  67. arize/constants/config.py +11 -4
  68. arize/constants/ml.py +6 -4
  69. arize/constants/openinference.py +2 -0
  70. arize/constants/pyarrow.py +2 -0
  71. arize/constants/spans.py +3 -1
  72. arize/datasets/__init__.py +1 -0
  73. arize/datasets/client.py +299 -84
  74. arize/datasets/errors.py +32 -2
  75. arize/datasets/validation.py +18 -8
  76. arize/embeddings/__init__.py +2 -0
  77. arize/embeddings/auto_generator.py +23 -19
  78. arize/embeddings/base_generators.py +89 -36
  79. arize/embeddings/constants.py +2 -0
  80. arize/embeddings/cv_generators.py +26 -4
  81. arize/embeddings/errors.py +27 -5
  82. arize/embeddings/nlp_generators.py +31 -12
  83. arize/embeddings/tabular_generators.py +32 -20
  84. arize/embeddings/usecases.py +12 -2
  85. arize/exceptions/__init__.py +1 -0
  86. arize/exceptions/auth.py +11 -1
  87. arize/exceptions/base.py +29 -4
  88. arize/exceptions/models.py +21 -2
  89. arize/exceptions/parameters.py +31 -0
  90. arize/exceptions/spaces.py +12 -1
  91. arize/exceptions/types.py +86 -7
  92. arize/exceptions/values.py +220 -20
  93. arize/experiments/__init__.py +1 -0
  94. arize/experiments/client.py +389 -285
  95. arize/experiments/evaluators/__init__.py +1 -0
  96. arize/experiments/evaluators/base.py +74 -41
  97. arize/experiments/evaluators/exceptions.py +6 -3
  98. arize/experiments/evaluators/executors.py +121 -73
  99. arize/experiments/evaluators/rate_limiters.py +106 -57
  100. arize/experiments/evaluators/types.py +34 -7
  101. arize/experiments/evaluators/utils.py +65 -27
  102. arize/experiments/functions.py +103 -101
  103. arize/experiments/tracing.py +52 -44
  104. arize/experiments/types.py +56 -31
  105. arize/logging.py +54 -22
  106. arize/models/__init__.py +1 -0
  107. arize/models/batch_validation/__init__.py +1 -0
  108. arize/models/batch_validation/errors.py +543 -65
  109. arize/models/batch_validation/validator.py +339 -300
  110. arize/models/bounded_executor.py +20 -7
  111. arize/models/casting.py +75 -29
  112. arize/models/client.py +326 -107
  113. arize/models/proto.py +95 -40
  114. arize/models/stream_validation.py +42 -14
  115. arize/models/surrogate_explainer/__init__.py +1 -0
  116. arize/models/surrogate_explainer/mimic.py +24 -13
  117. arize/pre_releases.py +43 -0
  118. arize/projects/__init__.py +1 -0
  119. arize/projects/client.py +129 -0
  120. arize/regions.py +40 -0
  121. arize/spans/__init__.py +1 -0
  122. arize/spans/client.py +130 -106
  123. arize/spans/columns.py +13 -0
  124. arize/spans/conversion.py +54 -38
  125. arize/spans/validation/__init__.py +1 -0
  126. arize/spans/validation/annotations/__init__.py +1 -0
  127. arize/spans/validation/annotations/annotations_validation.py +6 -4
  128. arize/spans/validation/annotations/dataframe_form_validation.py +13 -11
  129. arize/spans/validation/annotations/value_validation.py +35 -11
  130. arize/spans/validation/common/__init__.py +1 -0
  131. arize/spans/validation/common/argument_validation.py +33 -8
  132. arize/spans/validation/common/dataframe_form_validation.py +35 -9
  133. arize/spans/validation/common/errors.py +211 -11
  134. arize/spans/validation/common/value_validation.py +80 -13
  135. arize/spans/validation/evals/__init__.py +1 -0
  136. arize/spans/validation/evals/dataframe_form_validation.py +28 -8
  137. arize/spans/validation/evals/evals_validation.py +34 -4
  138. arize/spans/validation/evals/value_validation.py +26 -3
  139. arize/spans/validation/metadata/__init__.py +1 -1
  140. arize/spans/validation/metadata/argument_validation.py +14 -5
  141. arize/spans/validation/metadata/dataframe_form_validation.py +26 -10
  142. arize/spans/validation/metadata/value_validation.py +24 -10
  143. arize/spans/validation/spans/__init__.py +1 -0
  144. arize/spans/validation/spans/dataframe_form_validation.py +34 -13
  145. arize/spans/validation/spans/spans_validation.py +35 -4
  146. arize/spans/validation/spans/value_validation.py +76 -7
  147. arize/types.py +293 -157
  148. arize/utils/__init__.py +1 -0
  149. arize/utils/arrow.py +31 -15
  150. arize/utils/cache.py +34 -6
  151. arize/utils/dataframe.py +19 -2
  152. arize/utils/online_tasks/__init__.py +2 -0
  153. arize/utils/online_tasks/dataframe_preprocessor.py +53 -41
  154. arize/utils/openinference_conversion.py +44 -5
  155. arize/utils/proto.py +10 -0
  156. arize/utils/size.py +5 -3
  157. arize/version.py +3 -1
  158. {arize-8.0.0a22.dist-info → arize-8.0.0a23.dist-info}/METADATA +4 -3
  159. arize-8.0.0a23.dist-info/RECORD +174 -0
  160. {arize-8.0.0a22.dist-info → arize-8.0.0a23.dist-info}/WHEEL +1 -1
  161. arize-8.0.0a23.dist-info/licenses/LICENSE +176 -0
  162. arize-8.0.0a23.dist-info/licenses/NOTICE +13 -0
  163. arize/_generated/protocol/flight/export_pb2.py +0 -61
  164. arize/_generated/protocol/flight/ingest_pb2.py +0 -365
  165. arize-8.0.0a22.dist-info/RECORD +0 -146
  166. arize-8.0.0a22.dist-info/licenses/LICENSE.md +0 -12
@@ -1,10 +1,12 @@
1
+ """Client implementation for managing experiments in the Arize platform."""
2
+
1
3
  from __future__ import annotations
2
4
 
3
- import hashlib
4
5
  import logging
5
- from typing import TYPE_CHECKING, Any, Dict, List, Tuple
6
+ from typing import TYPE_CHECKING
6
7
 
7
8
  import opentelemetry.sdk.trace as trace_sdk
9
+ import pandas as pd
8
10
  import pyarrow as pa
9
11
  from openinference.semconv.resource import ResourceAttributes
10
12
  from opentelemetry import trace
@@ -16,23 +18,16 @@ from opentelemetry.sdk.trace.export import (
16
18
  ConsoleSpanExporter,
17
19
  SimpleSpanProcessor,
18
20
  )
19
- from opentelemetry.trace import Tracer
20
21
 
21
22
  from arize._flight.client import ArizeFlightClient
22
23
  from arize._flight.types import FlightRequestType
23
24
  from arize._generated.api_client import models
24
- from arize.config import SDKConfiguration
25
25
  from arize.exceptions.base import INVALID_ARROW_CONVERSION_MSG
26
- from arize.experiments.evaluators.base import Evaluators
27
- from arize.experiments.evaluators.types import EvaluationResultFieldNames
28
26
  from arize.experiments.functions import (
29
27
  run_experiment,
30
28
  transform_to_experiment_format,
31
29
  )
32
- from arize.experiments.types import (
33
- ExperimentTask,
34
- ExperimentTaskResultFieldNames,
35
- )
30
+ from arize.pre_releases import ReleaseStage, prerelease_endpoint
36
31
  from arize.utils.cache import cache_resource, load_cached_resource
37
32
  from arize.utils.openinference_conversion import (
38
33
  convert_boolean_columns_to_str,
@@ -41,16 +36,31 @@ from arize.utils.openinference_conversion import (
41
36
  from arize.utils.size import get_payload_size_mb
42
37
 
43
38
  if TYPE_CHECKING:
44
- import pandas as pd
45
-
46
- from arize._generated.api_client.models.experiment import Experiment
47
-
39
+ from opentelemetry.trace import Tracer
40
+
41
+ from arize.config import SDKConfiguration
42
+ from arize.experiments.evaluators.base import Evaluators
43
+ from arize.experiments.evaluators.types import EvaluationResultFieldNames
44
+ from arize.experiments.types import (
45
+ ExperimentTask,
46
+ ExperimentTaskResultFieldNames,
47
+ )
48
48
 
49
49
  logger = logging.getLogger(__name__)
50
50
 
51
51
 
52
52
  class ExperimentsClient:
53
- def __init__(self, *, sdk_config: SDKConfiguration):
53
+ """Client for managing experiments including creation, execution, and result tracking."""
54
+
55
+ def __init__(self, *, sdk_config: SDKConfiguration) -> None:
56
+ """Create an experiments sub-client.
57
+
58
+ The experiments client is a thin wrapper around the generated REST API client,
59
+ using the shared generated API client owned by `SDKConfiguration`.
60
+
61
+ Args:
62
+ sdk_config: Resolved SDK configuration.
63
+ """
54
64
  self._sdk_config = sdk_config
55
65
  from arize._generated import api_client as gen
56
66
 
@@ -61,16 +71,277 @@ class ExperimentsClient:
61
71
  self._sdk_config.get_generated_client()
62
72
  )
63
73
 
64
- self.list = self._api.experiments_list
65
- self.get = self._api.experiments_get
66
- self.delete = self._api.experiments_delete
74
+ @prerelease_endpoint(key="experiments.list", stage=ReleaseStage.BETA)
75
+ def list(
76
+ self,
77
+ *,
78
+ dataset_id: str | None = None,
79
+ limit: int = 100,
80
+ cursor: str | None = None,
81
+ ) -> models.ExperimentsList200Response:
82
+ """List experiments the user has access to.
83
+
84
+ To filter experiments by the dataset they were run on, provide `dataset_id`.
85
+
86
+ Args:
87
+ dataset_id: Optional dataset ID to filter experiments.
88
+ limit: Maximum number of experiments to return. The server enforces an
89
+ upper bound.
90
+ cursor: Opaque pagination cursor returned from a previous response.
91
+
92
+ Returns:
93
+ A response object with the experiments and pagination information.
94
+
95
+ Raises:
96
+ arize._generated.api_client.exceptions.ApiException: If the REST API
97
+ returns an error response (e.g. 401/403/429).
98
+ """
99
+ return self._api.experiments_list(
100
+ dataset_id=dataset_id,
101
+ limit=limit,
102
+ cursor=cursor,
103
+ )
104
+
105
+ @prerelease_endpoint(key="experiments.create", stage=ReleaseStage.BETA)
106
+ def create(
107
+ self,
108
+ *,
109
+ name: str,
110
+ dataset_id: str,
111
+ experiment_runs: list[dict[str, object]] | pd.DataFrame,
112
+ task_fields: ExperimentTaskResultFieldNames,
113
+ evaluator_columns: dict[str, EvaluationResultFieldNames] | None = None,
114
+ force_http: bool = False,
115
+ ) -> models.Experiment:
116
+ """Create an experiment with one or more experiment runs.
117
+
118
+ Experiments are composed of runs. Each run must include:
119
+ - `example_id`: ID of an existing example in the dataset/version
120
+ - `output`: Model/task output for the matching example
121
+
122
+ You may include any additional user-defined fields per run (e.g. `model`,
123
+ `latency_ms`, `temperature`, `prompt`, `tool_calls`, etc.) that can be used
124
+ for analysis or filtering.
125
+
126
+ This method transforms the input runs into the server's expected experiment
127
+ format using `task_fields` and optional `evaluator_columns`.
128
+
129
+ Transport selection:
130
+ - If the payload is below the configured REST payload threshold (or
131
+ `force_http=True`), this method uploads via REST.
132
+ - Otherwise, it attempts a more efficient upload path via gRPC + Flight.
133
+
134
+ Args:
135
+ name: Experiment name. Must be unique within the target dataset.
136
+ dataset_id: Dataset ID to attach the experiment to.
137
+ experiment_runs: Experiment runs either as:
138
+ - a list of JSON-like dicts, or
139
+ - a pandas DataFrame.
140
+ task_fields: Mapping that identifies the columns/fields containing the
141
+ task results (e.g. `example_id`, output fields).
142
+ evaluator_columns: Optional mapping describing evaluator result columns.
143
+ force_http: If True, force REST upload even if the payload exceeds the
144
+ configured REST payload threshold.
145
+
146
+ Returns:
147
+ The created experiment object.
148
+
149
+ Raises:
150
+ TypeError: If `experiment_runs` is not a list of dicts or a DataFrame.
151
+ RuntimeError: If the Flight upload path is selected and the Flight request
152
+ fails.
153
+ arize._generated.api_client.exceptions.ApiException: If the REST API
154
+ returns an error response (e.g. 400/401/403/409/429).
155
+ """
156
+ if not isinstance(experiment_runs, list | pd.DataFrame):
157
+ raise TypeError(
158
+ "Experiment runs must be a list of dicts or a pandas DataFrame"
159
+ )
160
+ # transform experiment data to experiment format
161
+ experiment_df = transform_to_experiment_format(
162
+ experiment_runs, task_fields, evaluator_columns
163
+ )
164
+
165
+ below_threshold = (
166
+ get_payload_size_mb(experiment_runs)
167
+ <= self._sdk_config.max_http_payload_size_mb
168
+ )
169
+ if below_threshold or force_http:
170
+ from arize._generated import api_client as gen
171
+
172
+ data = experiment_df.to_dict(orient="records")
173
+
174
+ body = gen.ExperimentsCreateRequest(
175
+ name=name,
176
+ dataset_id=dataset_id,
177
+ experiment_runs=data, # type: ignore
178
+ )
179
+ return self._api.experiments_create(experiments_create_request=body)
180
+
181
+ # If we have too many examples, try to convert to a dataframe
182
+ # and log via gRPC + flight
183
+ logger.info(
184
+ f"Uploading {len(experiment_df)} experiment runs via REST may be slow. "
185
+ "Trying for more efficient upload via gRPC + Flight."
186
+ )
187
+
188
+ # TODO(Kiko): Space ID should not be needed,
189
+ # should work on server tech debt to remove this
190
+ dataset = self._datasets_api.datasets_get(dataset_id=dataset_id)
191
+ space_id = dataset.space_id
192
+
193
+ return self._create_experiment_via_flight(
194
+ name=name,
195
+ dataset_id=dataset_id,
196
+ space_id=space_id,
197
+ experiment_df=experiment_df,
198
+ )
199
+
200
+ @prerelease_endpoint(key="experiments.get", stage=ReleaseStage.BETA)
201
+ def get(self, *, experiment_id: str) -> models.Experiment:
202
+ """Get an experiment by ID.
203
+
204
+ The response does not include the experiment's runs. Use `list_runs()` to
205
+ retrieve runs for an experiment.
206
+
207
+ Args:
208
+ experiment_id: Experiment ID to retrieve.
209
+
210
+ Returns:
211
+ The experiment object.
212
+
213
+ Raises:
214
+ arize._generated.api_client.exceptions.ApiException: If the REST API
215
+ returns an error response (e.g. 401/403/404/429).
216
+ """
217
+ return self._api.experiments_get(experiment_id=experiment_id)
218
+
219
+ @prerelease_endpoint(key="experiments.delete", stage=ReleaseStage.BETA)
220
+ def delete(self, *, experiment_id: str) -> None:
221
+ """Delete an experiment by ID.
222
+
223
+ This operation is irreversible.
224
+
225
+ Args:
226
+ experiment_id: Experiment ID to delete.
227
+
228
+ Returns: This method returns None on success (common empty 204 response)
229
+
230
+ Raises:
231
+ arize._generated.api_client.exceptions.ApiException: If the REST API
232
+ returns an error response (e.g. 401/403/404/429).
233
+ """
234
+ return self._api.experiments_delete(
235
+ experiment_id=experiment_id,
236
+ )
237
+
238
+ @prerelease_endpoint(key="experiments.list_runs", stage=ReleaseStage.BETA)
239
+ def list_runs(
240
+ self,
241
+ *,
242
+ experiment_id: str,
243
+ limit: int = 100,
244
+ all: bool = False,
245
+ ) -> models.ExperimentsRunsList200Response:
246
+ """List runs for an experiment.
247
+
248
+ Runs are returned in insertion order.
249
+
250
+ Pagination notes:
251
+ - The response includes `pagination` for forward compatibility.
252
+ - Cursor pagination may not be fully implemented by the server yet.
253
+ - If `all=True`, this method retrieves all runs via the Flight path and
254
+ returns them in a single response with `has_more=False`.
255
+
256
+ Args:
257
+ experiment_id: Experiment ID to list runs for.
258
+ limit: Maximum number of runs to return when `all=False`. The server
259
+ enforces an upper bound.
260
+ all: If True, fetch all runs (ignores `limit`) via Flight and return a
261
+ single response.
262
+
263
+ Returns:
264
+ A response object containing `experiment_runs` and `pagination` metadata.
265
+
266
+ Raises:
267
+ RuntimeError: If the Flight request fails or returns no response when
268
+ `all=True`.
269
+ arize._generated.api_client.exceptions.ApiException: If the REST API
270
+ returns an error response when `all=False` (e.g. 401/403/404/429).
271
+ """
272
+ if not all:
273
+ return self._api.experiments_runs_list(
274
+ experiment_id=experiment_id,
275
+ limit=limit,
276
+ )
277
+
278
+ experiment = self.get(experiment_id=experiment_id)
279
+ experiment_updated_at = getattr(experiment, "updated_at", None)
280
+ # TODO(Kiko): Space ID should not be needed,
281
+ # should work on server tech debt to remove this
282
+ dataset = self._datasets_api.datasets_get(
283
+ dataset_id=experiment.dataset_id
284
+ )
285
+ space_id = dataset.space_id
286
+
287
+ experiment_df = None
288
+ # try to load dataset from cache
289
+ if self._sdk_config.enable_caching:
290
+ experiment_df = load_cached_resource(
291
+ cache_dir=self._sdk_config.cache_dir,
292
+ resource="experiment",
293
+ resource_id=experiment_id,
294
+ resource_updated_at=experiment_updated_at,
295
+ )
296
+ if experiment_df is not None:
297
+ return models.ExperimentsRunsList200Response(
298
+ experimentRuns=experiment_df.to_dict(orient="records"), # type: ignore
299
+ pagination=models.PaginationMetadata(
300
+ has_more=False, # Note that all=True
301
+ ),
302
+ )
67
303
 
68
- # Custom methods
69
- self.run = self._run_experiment
70
- self.create = self._create_experiment
71
- self.list_runs = self._api.experiments_runs_list
304
+ with ArizeFlightClient(
305
+ api_key=self._sdk_config.api_key,
306
+ host=self._sdk_config.flight_host,
307
+ port=self._sdk_config.flight_port,
308
+ scheme=self._sdk_config.flight_scheme,
309
+ request_verify=self._sdk_config.request_verify,
310
+ max_chunksize=self._sdk_config.pyarrow_max_chunksize,
311
+ ) as flight_client:
312
+ try:
313
+ experiment_df = flight_client.get_experiment_runs(
314
+ space_id=space_id,
315
+ experiment_id=experiment_id,
316
+ )
317
+ except Exception as e:
318
+ msg = f"Error during request: {e!s}"
319
+ logger.exception(msg)
320
+ raise RuntimeError(msg) from e
321
+ if experiment_df is None:
322
+ # This should not happen with proper Flight client implementation,
323
+ # but we handle it defensively
324
+ msg = "No response received from flight server during request"
325
+ logger.error(msg)
326
+ raise RuntimeError(msg)
72
327
 
73
- def _run_experiment(
328
+ # cache experiment for future use
329
+ cache_resource(
330
+ cache_dir=self._sdk_config.cache_dir,
331
+ resource="experiment",
332
+ resource_id=experiment_id,
333
+ resource_updated_at=experiment_updated_at,
334
+ resource_data=experiment_df,
335
+ )
336
+
337
+ return models.ExperimentsRunsList200Response(
338
+ experimentRuns=experiment_df.to_dict(orient="records"), # type: ignore
339
+ pagination=models.PaginationMetadata(
340
+ has_more=False, # Note that all=True
341
+ ),
342
+ )
343
+
344
+ def run(
74
345
  self,
75
346
  *,
76
347
  name: str,
@@ -82,37 +353,46 @@ class ExperimentsClient:
82
353
  concurrency: int = 3,
83
354
  set_global_tracer_provider: bool = False,
84
355
  exit_on_error: bool = False,
85
- ) -> Tuple[Experiment | None, pd.DataFrame] | None:
86
- """
87
- Run an experiment on a dataset and upload the results.
356
+ ) -> tuple[models.Experiment | None, pd.DataFrame]:
357
+ """Run an experiment on a dataset and optionally upload results.
358
+
359
+ This method executes a task against dataset examples, optionally evaluates
360
+ outputs, and (when `dry_run=False`) uploads results to Arize.
88
361
 
89
- This function initializes an experiment, retrieves or uses a provided dataset,
90
- runs the experiment with specified tasks and evaluators, and uploads the results.
362
+ High-level flow:
363
+ 1) Resolve the dataset and `space_id`.
364
+ 2) Download dataset examples (or load from cache if enabled).
365
+ 3) Run the task and evaluators with configurable concurrency.
366
+ 4) If not a dry run, upload experiment runs and return the created
367
+ experiment plus the results dataframe.
368
+
369
+ Notes:
370
+ - If `dry_run=True`, no data is uploaded and the returned experiment is
371
+ `None`.
372
+ - When `enable_caching=True`, dataset examples may be cached and reused.
91
373
 
92
374
  Args:
93
- experiment_name (str): The name of the experiment.
94
- task (ExperimentTask): The task to be performed in the experiment.
95
- dataset_id (Optional[str], optional): The ID of the dataset to use.
96
- Required if dataset_df and dataset_name are not provided. Defaults to None.
97
- dataset_name (Optional[str], optional): The name of the dataset to use.
98
- Used if dataset_df and dataset_id are not provided. Defaults to None.
99
- evaluators (Optional[Evaluators], optional): The evaluators to use in the experiment.
100
- Defaults to None.
101
- dry_run (bool): If True, the experiment result will not be uploaded to Arize.
102
- Defaults to False.
103
- concurrency (int): The number of concurrent tasks to run. Defaults to 3.
104
- set_global_tracer_provider (bool): If True, sets the global tracer provider for the experiment.
105
- Defaults to False.
106
- exit_on_error (bool): If True, the experiment will stop running on first occurrence of an error.
375
+ name: Experiment name.
376
+ dataset_id: Dataset ID to run the experiment against.
377
+ task: The task to execute for each dataset example.
378
+ evaluators: Optional evaluators used to score outputs.
379
+ dry_run: If True, do not upload results to Arize.
380
+ dry_run_count: Number of dataset rows to use when `dry_run=True`.
381
+ concurrency: Number of concurrent tasks to run.
382
+ set_global_tracer_provider: If True, sets the global OpenTelemetry tracer
383
+ provider for the experiment run.
384
+ exit_on_error: If True, stop on the first error encountered during
385
+ execution.
107
386
 
108
387
  Returns:
109
- Tuple[str, pd.DataFrame]:
110
- A tuple of experiment ID and experiment result DataFrame.
111
- If dry_run is True, the experiment ID will be an empty string.
388
+ If `dry_run=True`, returns `(None, results_df)`.
389
+ If `dry_run=False`, returns `(experiment, results_df)`.
112
390
 
113
391
  Raises:
114
- ValueError: If dataset_id and dataset_name are both not provided, or if the dataset is empty.
115
- RuntimeError: If experiment initialization, dataset download, or result upload fails.
392
+ RuntimeError: If Flight operations (init/download/upload) fail or return
393
+ no response.
394
+ pa.ArrowInvalid: If converting results to Arrow fails.
395
+ Exception: For unexpected errors during Arrow conversion.
116
396
  """
117
397
  # TODO(Kiko): Space ID should not be needed,
118
398
  # should work on server tech debt to remove this
@@ -122,8 +402,8 @@ class ExperimentsClient:
122
402
 
123
403
  with ArizeFlightClient(
124
404
  api_key=self._sdk_config.api_key,
125
- host=self._sdk_config.flight_server_host,
126
- port=self._sdk_config.flight_server_port,
405
+ host=self._sdk_config.flight_host,
406
+ port=self._sdk_config.flight_port,
127
407
  scheme=self._sdk_config.flight_scheme,
128
408
  request_verify=self._sdk_config.request_verify,
129
409
  max_chunksize=self._sdk_config.pyarrow_max_chunksize,
@@ -141,8 +421,8 @@ class ExperimentsClient:
141
421
  experiment_name=name,
142
422
  )
143
423
  except Exception as e:
144
- msg = f"Error during request: {str(e)}"
145
- logger.error(msg)
424
+ msg = f"Error during request: {e!s}"
425
+ logger.exception(msg)
146
426
  raise RuntimeError(msg) from e
147
427
 
148
428
  if response is None:
@@ -173,8 +453,8 @@ class ExperimentsClient:
173
453
  dataset_id=dataset_id,
174
454
  )
175
455
  except Exception as e:
176
- msg = f"Error during request: {str(e)}"
177
- logger.error(msg)
456
+ msg = f"Error during request: {e!s}"
457
+ logger.exception(msg)
178
458
  raise RuntimeError(msg) from e
179
459
  if dataset_df is None:
180
460
  # This should not happen with proper Flight client implementation,
@@ -232,12 +512,12 @@ class ExperimentsClient:
232
512
  logger.debug("Converting data to Arrow format")
233
513
  pa_table = pa.Table.from_pandas(output_df, preserve_index=False)
234
514
  except pa.ArrowInvalid as e:
235
- logger.error(f"{INVALID_ARROW_CONVERSION_MSG}: {str(e)}")
515
+ logger.exception(INVALID_ARROW_CONVERSION_MSG)
236
516
  raise pa.ArrowInvalid(
237
- f"Error converting to Arrow format: {str(e)}"
517
+ f"Error converting to Arrow format: {e!s}"
238
518
  ) from e
239
- except Exception as e:
240
- logger.error(f"Unexpected error creating Arrow table: {str(e)}")
519
+ except Exception:
520
+ logger.exception("Unexpected error creating Arrow table")
241
521
  raise
242
522
 
243
523
  request_type = FlightRequestType.LOG_EXPERIMENT_DATA
@@ -251,8 +531,8 @@ class ExperimentsClient:
251
531
  request_type=request_type,
252
532
  )
253
533
  except Exception as e:
254
- msg = f"Error during update request: {str(e)}"
255
- logger.error(msg)
534
+ msg = f"Error during update request: {e!s}"
535
+ logger.exception(msg)
256
536
  raise RuntimeError(msg) from e
257
537
 
258
538
  if post_resp is None:
@@ -267,200 +547,32 @@ class ExperimentsClient:
267
547
  )
268
548
  return experiment, output_df
269
549
 
270
- def _create_experiment(
271
- self,
272
- *,
273
- name: str,
274
- dataset_id: str,
275
- experiment_runs: List[Dict[str, Any]] | pd.DataFrame,
276
- task_fields: ExperimentTaskResultFieldNames,
277
- evaluator_columns: Dict[str, EvaluationResultFieldNames] | None = None,
278
- force_http: bool = False,
279
- ) -> Experiment:
280
- """
281
- Log an experiment to Arize.
282
-
283
- Args:
284
- space_id (str): The ID of the space where the experiment will be logged.
285
- experiment_name (str): The name of the experiment.
286
- experiment_df (pd.DataFrame): The data to be logged.
287
- task_columns (ExperimentTaskResultColumnNames): The column names for task results.
288
- evaluator_columns (Optional[Dict[str, EvaluationResultColumnNames]]):
289
- The column names for evaluator results.
290
- dataset_id (str, optional): The ID of the dataset associated with the experiment.
291
- Required if dataset_name is not provided. Defaults to "".
292
- dataset_name (str, optional): The name of the dataset associated with the experiment.
293
- Required if dataset_id is not provided. Defaults to "".
294
-
295
- Examples:
296
- >>> # Example DataFrame:
297
- >>> df = pd.DataFrame({
298
- ... "example_id": ["1", "2"],
299
- ... "result": ["success", "failure"],
300
- ... "accuracy": [0.95, 0.85],
301
- ... "ground_truth": ["A", "B"],
302
- ... "explanation_text": ["Good match", "Poor match"],
303
- ... "confidence": [0.9, 0.7],
304
- ... "model_version": ["v1", "v2"],
305
- ... "custom_metric": [0.8, 0.6],
306
- ...})
307
- ...
308
- >>> # Define column mappings for task
309
- >>> task_cols = ExperimentTaskResultColumnNames(
310
- ... example_id="example_id", result="result"
311
- ...)
312
- >>> # Define column mappings for evaluator
313
- >>> evaluator_cols = EvaluationResultColumnNames(
314
- ... score="accuracy",
315
- ... label="ground_truth",
316
- ... explanation="explanation_text",
317
- ... metadata={
318
- ... "confidence": None, # Will use "confidence" column
319
- ... "version": "model_version", # Will use "model_version" column
320
- ... "custom_metric": None, # Will use "custom_metric" column
321
- ... },
322
- ... )
323
- >>> # Use with ArizeDatasetsClient.log_experiment()
324
- >>> ArizeDatasetsClient.log_experiment(
325
- ... space_id="my_space_id",
326
- ... experiment_name="my_experiment",
327
- ... experiment_df=df,
328
- ... task_columns=task_cols,
329
- ... evaluator_columns={"my_evaluator": evaluator_cols},
330
- ... dataset_name="my_dataset_name",
331
- ... )
332
-
333
- Returns:
334
- Optional[str]: The ID of the logged experiment, or None if the logging failed.
335
- """
336
- if not isinstance(experiment_runs, (list, pd.DataFrame)):
337
- raise TypeError(
338
- "Examples must be a list of dicts or a pandas DataFrame"
339
- )
340
- # transform experiment data to experiment format
341
- experiment_df = transform_to_experiment_format(
342
- experiment_runs, task_fields, evaluator_columns
343
- )
344
-
345
- below_threshold = (
346
- get_payload_size_mb(experiment_runs)
347
- <= self._sdk_config.max_http_payload_size_mb
348
- )
349
- if below_threshold or force_http:
350
- from arize._generated import api_client as gen
351
-
352
- data = experiment_df.to_dict(orient="records")
353
-
354
- body = gen.ExperimentsCreateRequest(
355
- name=name,
356
- datasetId=dataset_id,
357
- experimentRuns=data,
358
- )
359
- return self._api.experiments_create(experiments_create_request=body)
360
-
361
- # If we have too many examples, try to convert to a dataframe
362
- # and log via gRPC + flight
363
- logger.info(
364
- f"Uploading {len(experiment_df)} experiment runs via REST may be slow. "
365
- "Trying for more efficient upload via gRPC + Flight."
366
- )
367
-
368
- # TODO(Kiko): Space ID should not be needed,
369
- # should work on server tech debt to remove this
370
- dataset = self._datasets_api.datasets_get(dataset_id=dataset_id)
371
- space_id = dataset.space_id
372
-
373
- return self._create_experiment_via_flight(
374
- name=name,
375
- dataset_id=dataset_id,
376
- space_id=space_id,
377
- experiment_df=experiment_df,
378
- )
379
-
380
- def _list_runs(
381
- self,
382
- *,
383
- experiment_id: str,
384
- limit: int = 100,
385
- all: bool = False,
386
- ):
387
- if not all:
388
- return self._api.experiments_runs_list(
389
- experiment_id=experiment_id,
390
- limit=limit,
391
- )
392
-
393
- experiment = self.get(experiment_id=experiment_id)
394
- experiment_updated_at = getattr(experiment, "updated_at", None)
395
- # TODO(Kiko): Space ID should not be needed,
396
- # should work on server tech debt to remove this
397
- dataset = self._datasets_api.datasets_get(
398
- dataset_id=experiment.dataset_id
399
- )
400
- space_id = dataset.space_id
401
-
402
- experiment_df = None
403
- # try to load dataset from cache
404
- if self._sdk_config.enable_caching:
405
- experiment_df = load_cached_resource(
406
- cache_dir=self._sdk_config.cache_dir,
407
- resource="experiment",
408
- resource_id=experiment_id,
409
- resource_updated_at=experiment_updated_at,
410
- )
411
- if experiment_df is not None:
412
- return models.ExperimentsRunsList200Response(
413
- experimentRuns=experiment_df.to_dict(orient="records")
414
- )
415
-
416
- with ArizeFlightClient(
417
- api_key=self._sdk_config.api_key,
418
- host=self._sdk_config.flight_server_host,
419
- port=self._sdk_config.flight_server_port,
420
- scheme=self._sdk_config.flight_scheme,
421
- request_verify=self._sdk_config.request_verify,
422
- max_chunksize=self._sdk_config.pyarrow_max_chunksize,
423
- ) as flight_client:
424
- try:
425
- experiment_df = flight_client.get_experiment_runs(
426
- space_id=space_id,
427
- experiment_id=experiment_id,
428
- )
429
- except Exception as e:
430
- msg = f"Error during request: {str(e)}"
431
- logger.error(msg)
432
- raise RuntimeError(msg) from e
433
- if experiment_df is None:
434
- # This should not happen with proper Flight client implementation,
435
- # but we handle it defensively
436
- msg = "No response received from flight server during request"
437
- logger.error(msg)
438
- raise RuntimeError(msg)
439
-
440
- # cache dataset for future use
441
- cache_resource(
442
- cache_dir=self._sdk_config.cache_dir,
443
- resource="dataset",
444
- resource_id=experiment_id,
445
- resource_updated_at=experiment_updated_at,
446
- resource_data=experiment_df,
447
- )
448
-
449
- return models.ExperimentsRunsList200Response(
450
- experimentRuns=experiment_df.to_dict(orient="records")
451
- )
452
-
453
550
  def _create_experiment_via_flight(
454
551
  self,
455
552
  name: str,
456
553
  dataset_id: str,
457
554
  space_id: str,
458
555
  experiment_df: pd.DataFrame,
459
- ) -> Experiment:
556
+ ) -> models.Experiment:
557
+ """Internal method to create an experiment using Flight protocol for large datasets."""
558
+ # Convert to Arrow table
559
+ try:
560
+ logger.debug("Converting data to Arrow format")
561
+ pa_table = pa.Table.from_pandas(experiment_df, preserve_index=False)
562
+ except pa.ArrowInvalid as e:
563
+ logger.exception(INVALID_ARROW_CONVERSION_MSG)
564
+ raise pa.ArrowInvalid(
565
+ f"Error converting to Arrow format: {e!s}"
566
+ ) from e
567
+ except Exception:
568
+ logger.exception("Unexpected error creating Arrow table")
569
+ raise
570
+
571
+ experiment_id = ""
460
572
  with ArizeFlightClient(
461
573
  api_key=self._sdk_config.api_key,
462
- host=self._sdk_config.flight_server_host,
463
- port=self._sdk_config.flight_server_port,
574
+ host=self._sdk_config.flight_host,
575
+ port=self._sdk_config.flight_port,
464
576
  scheme=self._sdk_config.flight_scheme,
465
577
  request_verify=self._sdk_config.request_verify,
466
578
  max_chunksize=self._sdk_config.pyarrow_max_chunksize,
@@ -474,8 +586,8 @@ class ExperimentsClient:
474
586
  experiment_name=name,
475
587
  )
476
588
  except Exception as e:
477
- msg = f"Error during request: {str(e)}"
478
- logger.error(msg)
589
+ msg = f"Error during request: {e!s}"
590
+ logger.exception(msg)
479
591
  raise RuntimeError(msg) from e
480
592
 
481
593
  if response is None:
@@ -484,49 +596,39 @@ class ExperimentsClient:
484
596
  msg = "No response received from flight server during request"
485
597
  logger.error(msg)
486
598
  raise RuntimeError(msg)
487
- experiment_id, _ = response
488
599
 
489
- # Convert to Arrow table
490
- try:
491
- logger.debug("Converting data to Arrow format")
492
- pa_table = pa.Table.from_pandas(experiment_df, preserve_index=False)
493
- except pa.ArrowInvalid as e:
494
- logger.error(f"{INVALID_ARROW_CONVERSION_MSG}: {str(e)}")
495
- raise pa.ArrowInvalid(
496
- f"Error converting to Arrow format: {str(e)}"
497
- ) from e
498
- except Exception as e:
499
- logger.error(f"Unexpected error creating Arrow table: {str(e)}")
500
- raise
600
+ experiment_id, _ = response
601
+ if not experiment_id:
602
+ msg = "No experiment ID received from flight server during request"
603
+ logger.error(msg)
604
+ raise RuntimeError(msg)
501
605
 
502
- request_type = FlightRequestType.LOG_EXPERIMENT_DATA
503
- post_resp = None
504
- try:
505
- post_resp = flight_client.log_arrow_table(
506
- space_id=space_id,
507
- pa_table=pa_table,
508
- dataset_id=dataset_id,
509
- experiment_name=experiment_id,
510
- request_type=request_type,
511
- )
512
- except Exception as e:
513
- msg = f"Error during update request: {str(e)}"
514
- logger.error(msg)
515
- raise RuntimeError(msg) from e
606
+ request_type = FlightRequestType.LOG_EXPERIMENT_DATA
607
+ post_resp = None
608
+ try:
609
+ post_resp = flight_client.log_arrow_table(
610
+ space_id=space_id,
611
+ pa_table=pa_table,
612
+ dataset_id=dataset_id,
613
+ experiment_name=name,
614
+ request_type=request_type,
615
+ )
616
+ except Exception as e:
617
+ msg = f"Error during update request: {e!s}"
618
+ logger.exception(msg)
619
+ raise RuntimeError(msg) from e
516
620
 
517
- if post_resp is None:
518
- # This should not happen with proper Flight client implementation,
519
- # but we handle it defensively
520
- msg = "No response received from flight server during request"
521
- logger.error(msg)
522
- raise RuntimeError(msg)
621
+ if post_resp is None:
622
+ # This should not happen with proper Flight client implementation,
623
+ # but we handle it defensively
624
+ msg = "No response received from flight server during request"
625
+ logger.error(msg)
626
+ raise RuntimeError(msg)
523
627
 
524
- experiment = self.get(
628
+ return self.get(
525
629
  experiment_id=str(post_resp.experiment_id) # type: ignore
526
630
  )
527
631
 
528
- return experiment
529
-
530
632
 
531
633
  def _get_tracer_resource(
532
634
  project_name: str,
@@ -535,7 +637,8 @@ def _get_tracer_resource(
535
637
  endpoint: str,
536
638
  dry_run: bool = False,
537
639
  set_global_tracer_provider: bool = False,
538
- ) -> Tuple[Tracer, Resource]:
640
+ ) -> tuple[Tracer, Resource]:
641
+ """Initialize and return an OpenTelemetry tracer and resource for experiment tracing."""
539
642
  resource = Resource(
540
643
  {
541
644
  ResourceAttributes.PROJECT_NAME: project_name,
@@ -547,7 +650,8 @@ def _get_tracer_resource(
547
650
  "arize-space-id": space_id,
548
651
  "arize-interface": "otel",
549
652
  }
550
- insecure = endpoint.startswith("http://")
653
+ use_tls = any(endpoint.startswith(v) for v in ["https://", "grpc+tls://"])
654
+ insecure = not use_tls
551
655
  exporter = (
552
656
  ConsoleSpanExporter()
553
657
  if dry_run