arize 8.0.0a14__py3-none-any.whl → 8.0.0a16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. arize/__init__.py +70 -1
  2. arize/_flight/client.py +163 -43
  3. arize/_flight/types.py +1 -0
  4. arize/_generated/api_client/__init__.py +5 -1
  5. arize/_generated/api_client/api/datasets_api.py +6 -6
  6. arize/_generated/api_client/api/experiments_api.py +924 -61
  7. arize/_generated/api_client/api_client.py +1 -1
  8. arize/_generated/api_client/configuration.py +1 -1
  9. arize/_generated/api_client/exceptions.py +1 -1
  10. arize/_generated/api_client/models/__init__.py +3 -1
  11. arize/_generated/api_client/models/dataset.py +2 -2
  12. arize/_generated/api_client/models/dataset_version.py +1 -1
  13. arize/_generated/api_client/models/datasets_create_request.py +3 -3
  14. arize/_generated/api_client/models/datasets_list200_response.py +1 -1
  15. arize/_generated/api_client/models/datasets_list_examples200_response.py +1 -1
  16. arize/_generated/api_client/models/error.py +1 -1
  17. arize/_generated/api_client/models/experiment.py +6 -6
  18. arize/_generated/api_client/models/experiments_create_request.py +98 -0
  19. arize/_generated/api_client/models/experiments_list200_response.py +1 -1
  20. arize/_generated/api_client/models/experiments_runs_list200_response.py +92 -0
  21. arize/_generated/api_client/rest.py +1 -1
  22. arize/_generated/api_client/test/test_dataset.py +2 -1
  23. arize/_generated/api_client/test/test_dataset_version.py +1 -1
  24. arize/_generated/api_client/test/test_datasets_api.py +1 -1
  25. arize/_generated/api_client/test/test_datasets_create_request.py +2 -1
  26. arize/_generated/api_client/test/test_datasets_list200_response.py +1 -1
  27. arize/_generated/api_client/test/test_datasets_list_examples200_response.py +1 -1
  28. arize/_generated/api_client/test/test_error.py +1 -1
  29. arize/_generated/api_client/test/test_experiment.py +6 -1
  30. arize/_generated/api_client/test/test_experiments_api.py +23 -2
  31. arize/_generated/api_client/test/test_experiments_create_request.py +61 -0
  32. arize/_generated/api_client/test/test_experiments_list200_response.py +1 -1
  33. arize/_generated/api_client/test/test_experiments_runs_list200_response.py +56 -0
  34. arize/_generated/api_client_README.md +13 -8
  35. arize/client.py +19 -2
  36. arize/config.py +50 -3
  37. arize/constants/config.py +8 -2
  38. arize/constants/openinference.py +14 -0
  39. arize/constants/pyarrow.py +1 -0
  40. arize/datasets/__init__.py +0 -70
  41. arize/datasets/client.py +106 -19
  42. arize/datasets/errors.py +61 -0
  43. arize/datasets/validation.py +46 -0
  44. arize/experiments/client.py +455 -0
  45. arize/experiments/evaluators/__init__.py +0 -0
  46. arize/experiments/evaluators/base.py +255 -0
  47. arize/experiments/evaluators/exceptions.py +10 -0
  48. arize/experiments/evaluators/executors.py +502 -0
  49. arize/experiments/evaluators/rate_limiters.py +277 -0
  50. arize/experiments/evaluators/types.py +122 -0
  51. arize/experiments/evaluators/utils.py +198 -0
  52. arize/experiments/functions.py +920 -0
  53. arize/experiments/tracing.py +276 -0
  54. arize/experiments/types.py +394 -0
  55. arize/models/client.py +4 -1
  56. arize/spans/client.py +16 -20
  57. arize/utils/arrow.py +4 -3
  58. arize/utils/openinference_conversion.py +56 -0
  59. arize/utils/proto.py +13 -0
  60. arize/utils/size.py +22 -0
  61. arize/version.py +1 -1
  62. {arize-8.0.0a14.dist-info → arize-8.0.0a16.dist-info}/METADATA +3 -1
  63. {arize-8.0.0a14.dist-info → arize-8.0.0a16.dist-info}/RECORD +65 -44
  64. {arize-8.0.0a14.dist-info → arize-8.0.0a16.dist-info}/WHEEL +0 -0
  65. {arize-8.0.0a14.dist-info → arize-8.0.0a16.dist-info}/licenses/LICENSE.md +0 -0
@@ -1,4 +1,49 @@
+ from __future__ import annotations
+
+ import logging
+ from typing import TYPE_CHECKING, Any, Dict, List, Tuple
+
+ import opentelemetry.sdk.trace as trace_sdk
+ import pyarrow as pa
+ from openinference.semconv.resource import ResourceAttributes
+ from opentelemetry import trace
+ from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import (
+     OTLPSpanExporter as GrpcSpanExporter,
+ )
+ from opentelemetry.sdk.resources import Resource
+ from opentelemetry.sdk.trace.export import (
+     ConsoleSpanExporter,
+     SimpleSpanProcessor,
+ )
+ from opentelemetry.trace import Tracer
+
+ from arize._flight.client import ArizeFlightClient
+ from arize._flight.types import FlightRequestType
  from arize.config import SDKConfiguration
+ from arize.exceptions.base import INVALID_ARROW_CONVERSION_MSG
+ from arize.experiments.evaluators.base import Evaluators
+ from arize.experiments.evaluators.types import EvaluationResultFieldNames
+ from arize.experiments.functions import (
+     run_experiment,
+     transform_to_experiment_format,
+ )
+ from arize.experiments.types import (
+     ExperimentTask,
+     ExperimentTaskResultFieldNames,
+ )
+ from arize.utils.openinference_conversion import (
+     convert_boolean_columns_to_str,
+     convert_default_columns_to_json_str,
+ )
+ from arize.utils.size import get_payload_size_mb
+
+ if TYPE_CHECKING:
+     import pandas as pd
+
+     from arize._generated.api_client.models.experiment import Experiment
+
+
+ logger = logging.getLogger(__name__)


  class ExperimentsClient:
@@ -7,4 +52,414 @@ class ExperimentsClient:
          from arize._generated import api_client as gen

          self._api = gen.ExperimentsApi(self._sdk_config.get_generated_client())
+         # TODO(Kiko): Space ID should not be needed,
+         # should work on server tech debt to remove this
+         self._datasets_api = gen.DatasetsApi(
+             self._sdk_config.get_generated_client()
+         )
          self.list = self._api.experiments_list
+         self.get = self._api.experiments_get
+         self.delete = self._api.experiments_delete
+         self.list_runs = self._api.experiments_runs_list  # REST ?
+
+         # Custom methods
+         self.create = self._create_experiment
+         self.run = self._run_experiment
+
+     def _run_experiment(
+         self,
+         name: str,
+         dataset_id: str,
+         task: ExperimentTask,
+         dataset_df: pd.DataFrame | None = None,
+         evaluators: Evaluators | None = None,
+         dry_run: bool = False,
+         concurrency: int = 3,
+         set_global_tracer_provider: bool = False,
+         exit_on_error: bool = False,
+     ) -> Tuple[str, pd.DataFrame] | None:
+         """
+         Run an experiment on a dataset and upload the results.
+
+         This method initializes an experiment, retrieves or uses a provided dataset,
+         runs the experiment with the specified task and evaluators, and uploads the results.
+
+         Args:
+             name (str): The name of the experiment.
+             dataset_id (str): The ID of the dataset to run the experiment on.
+             task (ExperimentTask): The task to be performed in the experiment.
+             dataset_df (Optional[pd.DataFrame], optional): The dataset as a pandas DataFrame.
+                 If not provided, the dataset examples are downloaded using dataset_id.
+                 Defaults to None.
+             evaluators (Optional[Evaluators], optional): The evaluators to use in the experiment.
+                 Defaults to None.
+             dry_run (bool): If True, the experiment result will not be uploaded to Arize.
+                 Defaults to False.
+             concurrency (int): The number of concurrent tasks to run. Defaults to 3.
+             set_global_tracer_provider (bool): If True, sets the global tracer provider for the
+                 experiment. Defaults to False.
+             exit_on_error (bool): If True, the experiment stops running on the first occurrence
+                 of an error. Defaults to False.
+
+         Returns:
+             Tuple[str, pd.DataFrame]:
+                 A tuple of experiment ID and experiment result DataFrame.
+                 If dry_run is True, the experiment ID will be an empty string.
+
+         Raises:
+             ValueError: If the dataset is empty.
+             RuntimeError: If experiment initialization, dataset download, or result upload fails.
+         """
+         # TODO(Kiko): Space ID should not be needed,
+         # should work on server tech debt to remove this
+         dataset = self._datasets_api.datasets_get(dataset_id=dataset_id)
+         space_id = dataset.space_id
+
+         with ArizeFlightClient(
+             api_key=self._sdk_config.api_key,
+             host=self._sdk_config.flight_server_host,
+             port=self._sdk_config.flight_server_port,
+             scheme=self._sdk_config.flight_scheme,
+             request_verify=self._sdk_config.request_verify,
+             max_chunksize=self._sdk_config.pyarrow_max_chunksize,
+         ) as flight_client:
+             # set up initial experiment and trace model
+             if dry_run:
+                 trace_model_name = "traces_for_dry_run"
+                 experiment_id = "experiment_id_for_dry_run"
+             else:
+                 response = None
+                 try:
+                     response = flight_client.init_experiment(
+                         space_id=space_id,
+                         dataset_id=dataset_id,
+                         experiment_name=name,
+                     )
+                 except Exception as e:
+                     msg = f"Error during request: {str(e)}"
+                     logger.error(msg)
+                     raise RuntimeError(msg) from e
+
+                 if response is None:
+                     # This should not happen with proper Flight client implementation,
+                     # but we handle it defensively
+                     msg = (
+                         "No response received from flight server during request"
+                     )
+                     logger.error(msg)
+                     raise RuntimeError(msg)
+                 experiment_id, trace_model_name = response
+
+             # download dataset if not provided
+             if dataset_df is None:
+                 try:
+                     response = flight_client.get_dataset_examples(
+                         space_id=space_id,
+                         dataset_id=dataset_id,
+                     )
+                 except Exception as e:
+                     msg = f"Error during request: {str(e)}"
+                     logger.error(msg)
+                     raise RuntimeError(msg) from e
+                 if response is None:
+                     # This should not happen with proper Flight client implementation,
+                     # but we handle it defensively
+                     msg = (
+                         "No response received from flight server during request"
+                     )
+                     logger.error(msg)
+                     raise RuntimeError(msg)
+                 # Use the downloaded examples as the experiment input
+                 # (assumes get_dataset_examples returns the examples as a DataFrame)
+                 dataset_df = response
+
+             if dataset_df is None or dataset_df.empty:
+                 raise ValueError(f"Dataset {dataset_id} is empty")
+
+             input_df = dataset_df.copy()
+             if dry_run:
+                 # only dry_run experiment on a subset (first 10 rows) of the dataset
+                 input_df = input_df.head(10)
+
+             # trace model and resource for the experiment
+             tracer, resource = _get_tracer_resource(
+                 project_name=trace_model_name,
+                 space_id=space_id,
+                 api_key=self._sdk_config.api_key,
+                 endpoint=self._sdk_config.otlp_url,
+                 dry_run=dry_run,
+                 set_global_tracer_provider=set_global_tracer_provider,
+             )
+
+             output_df = run_experiment(
+                 experiment_name=name,
+                 experiment_id=experiment_id,
+                 dataset=input_df,
+                 task=task,
+                 tracer=tracer,
+                 resource=resource,
+                 evaluators=evaluators,
+                 concurrency=concurrency,
+                 exit_on_error=exit_on_error,
+             )
+             output_df = convert_default_columns_to_json_str(output_df)
+             output_df = convert_boolean_columns_to_str(output_df)
+             if dry_run:
+                 return "", output_df
+
+             # Convert to Arrow table
+             try:
+                 logger.debug("Converting data to Arrow format")
+                 pa_table = pa.Table.from_pandas(output_df, preserve_index=False)
+             except pa.ArrowInvalid as e:
+                 logger.error(f"{INVALID_ARROW_CONVERSION_MSG}: {str(e)}")
+                 raise pa.ArrowInvalid(
+                     f"Error converting to Arrow format: {str(e)}"
+                 ) from e
+             except Exception as e:
+                 logger.error(f"Unexpected error creating Arrow table: {str(e)}")
+                 raise
+
+             request_type = FlightRequestType.LOG_EXPERIMENT_DATA
+             post_resp = None
+             try:
+                 post_resp = flight_client.log_arrow_table(
+                     space_id=space_id,
+                     pa_table=pa_table,
+                     dataset_id=dataset_id,
+                     experiment_name=experiment_id,
+                     request_type=request_type,
+                 )
+             except Exception as e:
+                 msg = f"Error during update request: {str(e)}"
+                 logger.error(msg)
+                 raise RuntimeError(msg) from e
+
+             if post_resp is None:
+                 # This should not happen with proper Flight client implementation,
+                 # but we handle it defensively
+                 msg = "No response received from flight server during request"
+                 logger.error(msg)
+                 raise RuntimeError(msg)
+
+             return str(post_resp.experiment_id), output_df  # type: ignore
+
+     def _create_experiment(
+         self,
+         name: str,
+         dataset_id: str,
+         experiment_runs: List[Dict[str, Any]] | pd.DataFrame,
+         task_fields: ExperimentTaskResultFieldNames,
+         evaluator_columns: Dict[str, EvaluationResultFieldNames] | None = None,
+         force_http: bool = False,
+     ) -> Experiment:
+         """
+         Log an experiment with pre-computed runs to Arize.
+
+         Args:
+             name (str): The name of the experiment.
+             dataset_id (str): The ID of the dataset the experiment was run against.
+             experiment_runs (List[Dict[str, Any]] | pd.DataFrame): The experiment run data
+                 to be logged, as a list of dicts or a pandas DataFrame.
+             task_fields (ExperimentTaskResultFieldNames): The field names for task results.
+             evaluator_columns (Optional[Dict[str, EvaluationResultFieldNames]]):
+                 The field names for evaluator results, keyed by evaluator name.
+             force_http (bool): If True, upload via REST even when the payload exceeds the
+                 configured HTTP payload size threshold. Defaults to False.
+
+         Examples:
+             >>> # Example DataFrame:
+             >>> df = pd.DataFrame({
+             ...     "example_id": ["1", "2"],
+             ...     "result": ["success", "failure"],
+             ...     "accuracy": [0.95, 0.85],
+             ...     "ground_truth": ["A", "B"],
+             ...     "explanation_text": ["Good match", "Poor match"],
+             ...     "confidence": [0.9, 0.7],
+             ...     "model_version": ["v1", "v2"],
+             ...     "custom_metric": [0.8, 0.6],
+             ... })
+             >>> # Define field mappings for the task
+             >>> task_fields = ExperimentTaskResultFieldNames(
+             ...     example_id="example_id", result="result"
+             ... )
+             >>> # Define field mappings for an evaluator
+             >>> evaluator_fields = EvaluationResultFieldNames(
+             ...     score="accuracy",
+             ...     label="ground_truth",
+             ...     explanation="explanation_text",
+             ...     metadata={
+             ...         "confidence": None,  # Will use "confidence" column
+             ...         "version": "model_version",  # Will use "model_version" column
+             ...         "custom_metric": None,  # Will use "custom_metric" column
+             ...     },
+             ... )
+             >>> # Use with ExperimentsClient.create()
+             >>> experiments_client.create(
+             ...     name="my_experiment",
+             ...     dataset_id="my_dataset_id",
+             ...     experiment_runs=df,
+             ...     task_fields=task_fields,
+             ...     evaluator_columns={"my_evaluator": evaluator_fields},
+             ... )
+
+         Returns:
+             Experiment: The created experiment.
+         """
+         if not isinstance(experiment_runs, (list, pd.DataFrame)):
+             raise TypeError(
+                 "Examples must be a list of dicts or a pandas DataFrame"
+             )
+         # transform experiment data to experiment format
+         experiment_df = transform_to_experiment_format(
+             experiment_runs, task_fields, evaluator_columns
+         )
+
+         below_threshold = (
+             get_payload_size_mb(experiment_runs)
+             <= self._sdk_config.max_http_payload_size_mb
+         )
+         if below_threshold or force_http:
+             from arize._generated import api_client as gen
+
+             data = experiment_df.to_dict(orient="records")
+
+             body = gen.ExperimentsCreateRequest(
+                 name=name,
+                 datasetId=dataset_id,
+                 experimentRuns=data,
+             )
+             return self._api.experiments_create(experiments_create_request=body)
+
+         # If we have too many examples, try to convert to a dataframe
+         # and log via gRPC + flight
+         logger.info(
+             f"Uploading {len(experiment_df)} experiment runs via REST may be slow. "
+             "Trying for more efficient upload via gRPC + Flight."
+         )
+
+         # TODO(Kiko): Space ID should not be needed,
+         # should work on server tech debt to remove this
+         dataset = self._datasets_api.datasets_get(dataset_id=dataset_id)
+         space_id = dataset.space_id
+
+         return self._create_experiment_via_flight(
+             name=name,
+             dataset_id=dataset_id,
+             space_id=space_id,
+             experiment_df=experiment_df,
+         )
+
+     def _create_experiment_via_flight(
+         self,
+         name: str,
+         dataset_id: str,
+         space_id: str,
+         experiment_df: pd.DataFrame,
+     ) -> Experiment:
+         with ArizeFlightClient(
+             api_key=self._sdk_config.api_key,
+             host=self._sdk_config.flight_server_host,
+             port=self._sdk_config.flight_server_port,
+             scheme=self._sdk_config.flight_scheme,
+             request_verify=self._sdk_config.request_verify,
+             max_chunksize=self._sdk_config.pyarrow_max_chunksize,
+         ) as flight_client:
+             # set up initial experiment and trace model
+             response = None
+             try:
+                 response = flight_client.init_experiment(
+                     space_id=space_id,
+                     dataset_id=dataset_id,
+                     experiment_name=name,
+                 )
+             except Exception as e:
+                 msg = f"Error during request: {str(e)}"
+                 logger.error(msg)
+                 raise RuntimeError(msg) from e
+
+             if response is None:
+                 # This should not happen with proper Flight client implementation,
+                 # but we handle it defensively
+                 msg = "No response received from flight server during request"
+                 logger.error(msg)
+                 raise RuntimeError(msg)
+             experiment_id, _ = response
+
+             # Convert to Arrow table
+             try:
+                 logger.debug("Converting data to Arrow format")
+                 pa_table = pa.Table.from_pandas(experiment_df, preserve_index=False)
+             except pa.ArrowInvalid as e:
+                 logger.error(f"{INVALID_ARROW_CONVERSION_MSG}: {str(e)}")
+                 raise pa.ArrowInvalid(
+                     f"Error converting to Arrow format: {str(e)}"
+                 ) from e
+             except Exception as e:
+                 logger.error(f"Unexpected error creating Arrow table: {str(e)}")
+                 raise
+
+             request_type = FlightRequestType.LOG_EXPERIMENT_DATA
+             post_resp = None
+             try:
+                 post_resp = flight_client.log_arrow_table(
+                     space_id=space_id,
+                     pa_table=pa_table,
+                     dataset_id=dataset_id,
+                     experiment_name=experiment_id,
+                     request_type=request_type,
+                 )
+             except Exception as e:
+                 msg = f"Error during update request: {str(e)}"
+                 logger.error(msg)
+                 raise RuntimeError(msg) from e
+
+             if post_resp is None:
+                 # This should not happen with proper Flight client implementation,
+                 # but we handle it defensively
+                 msg = "No response received from flight server during request"
+                 logger.error(msg)
+                 raise RuntimeError(msg)
+
+             experiment = self.get(
+                 experiment_id=str(post_resp.experiment_id)  # type: ignore
+             )
+
+             return experiment
+
+
+ def _get_tracer_resource(
+     project_name: str,
+     space_id: str,
+     api_key: str,
+     endpoint: str,
+     dry_run: bool = False,
+     set_global_tracer_provider: bool = False,
+ ) -> Tuple[Tracer, Resource]:
+     resource = Resource(
+         {
+             ResourceAttributes.PROJECT_NAME: project_name,
+         }
+     )
+     tracer_provider = trace_sdk.TracerProvider(resource=resource)
+     headers = {
+         "authorization": api_key,
+         "arize-space-id": space_id,
+         "arize-interface": "otel",
+     }
+     insecure = endpoint.startswith("http://")
+     exporter = (
+         ConsoleSpanExporter()
+         if dry_run
+         else GrpcSpanExporter(
+             endpoint=endpoint, insecure=insecure, headers=headers
+         )
+     )
+     tracer_provider.add_span_processor(SimpleSpanProcessor(exporter))
+
+     if set_global_tracer_provider:
+         trace.set_tracer_provider(tracer_provider)
+
+     return tracer_provider.get_tracer(__name__), resource
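
For orientation, a minimal usage sketch of the new run path added above. It is illustrative only: the exact constructor arguments of SDKConfiguration and ExperimentsClient live in arize/config.py and arize/experiments/client.py and are not shown in this diff, the task contract defined in arize/experiments/types.py is assumed here to be a plain callable over one dataset example, and the API key and dataset ID are placeholders.

import os

from arize.config import SDKConfiguration
from arize.experiments.client import ExperimentsClient

# Assumption: SDKConfiguration accepts the API key directly and
# ExperimentsClient is constructed from it; adjust to the actual constructors.
config = SDKConfiguration(api_key=os.environ["ARIZE_API_KEY"])
experiments = ExperimentsClient(config)


def echo_task(example):
    # Assumption: the task receives one dataset example/row and returns an output.
    return f"echo: {example}"


# dry_run=True runs the task (and any evaluators) on the first 10 rows only and
# skips the upload; the returned experiment ID is an empty string in that case.
experiment_id, results_df = experiments.run(
    name="my-first-experiment",
    dataset_id="YOUR_DATASET_ID",  # placeholder
    task=echo_task,
    dry_run=True,
    concurrency=3,
)
print(results_df.head())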
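
Similarly, a hedged sketch of the create path for logging pre-computed runs, continuing from the sketch above (the experiments client and placeholders are reused). The run field names and the ExperimentTaskResultFieldNames / EvaluationResultFieldNames constructor arguments are assumptions that mirror the mapping in the _create_experiment docstring; only the import paths and the create(...) keyword arguments come directly from the diff.

from arize.experiments.evaluators.types import EvaluationResultFieldNames
from arize.experiments.types import ExperimentTaskResultFieldNames

# Pre-computed experiment runs; create() accepts a list of dicts or a DataFrame.
runs = [
    {"example_id": "1", "result": "success", "accuracy": 0.95, "ground_truth": "A"},
    {"example_id": "2", "result": "failure", "accuracy": 0.85, "ground_truth": "B"},
]

task_fields = ExperimentTaskResultFieldNames(example_id="example_id", result="result")
eval_fields = EvaluationResultFieldNames(score="accuracy", label="ground_truth")

# Small payloads go over REST; larger ones fall back to gRPC + Flight unless
# force_http=True keeps them on REST.
experiment = experiments.create(
    name="offline-eval",
    dataset_id="YOUR_DATASET_ID",  # placeholder
    experiment_runs=runs,
    task_fields=task_fields,
    evaluator_columns={"accuracy_eval": eval_fields},
)
print(experiment)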