arize 8.0.0a14__py3-none-any.whl → 8.0.0a16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arize/__init__.py +70 -1
- arize/_flight/client.py +163 -43
- arize/_flight/types.py +1 -0
- arize/_generated/api_client/__init__.py +5 -1
- arize/_generated/api_client/api/datasets_api.py +6 -6
- arize/_generated/api_client/api/experiments_api.py +924 -61
- arize/_generated/api_client/api_client.py +1 -1
- arize/_generated/api_client/configuration.py +1 -1
- arize/_generated/api_client/exceptions.py +1 -1
- arize/_generated/api_client/models/__init__.py +3 -1
- arize/_generated/api_client/models/dataset.py +2 -2
- arize/_generated/api_client/models/dataset_version.py +1 -1
- arize/_generated/api_client/models/datasets_create_request.py +3 -3
- arize/_generated/api_client/models/datasets_list200_response.py +1 -1
- arize/_generated/api_client/models/datasets_list_examples200_response.py +1 -1
- arize/_generated/api_client/models/error.py +1 -1
- arize/_generated/api_client/models/experiment.py +6 -6
- arize/_generated/api_client/models/experiments_create_request.py +98 -0
- arize/_generated/api_client/models/experiments_list200_response.py +1 -1
- arize/_generated/api_client/models/experiments_runs_list200_response.py +92 -0
- arize/_generated/api_client/rest.py +1 -1
- arize/_generated/api_client/test/test_dataset.py +2 -1
- arize/_generated/api_client/test/test_dataset_version.py +1 -1
- arize/_generated/api_client/test/test_datasets_api.py +1 -1
- arize/_generated/api_client/test/test_datasets_create_request.py +2 -1
- arize/_generated/api_client/test/test_datasets_list200_response.py +1 -1
- arize/_generated/api_client/test/test_datasets_list_examples200_response.py +1 -1
- arize/_generated/api_client/test/test_error.py +1 -1
- arize/_generated/api_client/test/test_experiment.py +6 -1
- arize/_generated/api_client/test/test_experiments_api.py +23 -2
- arize/_generated/api_client/test/test_experiments_create_request.py +61 -0
- arize/_generated/api_client/test/test_experiments_list200_response.py +1 -1
- arize/_generated/api_client/test/test_experiments_runs_list200_response.py +56 -0
- arize/_generated/api_client_README.md +13 -8
- arize/client.py +19 -2
- arize/config.py +50 -3
- arize/constants/config.py +8 -2
- arize/constants/openinference.py +14 -0
- arize/constants/pyarrow.py +1 -0
- arize/datasets/__init__.py +0 -70
- arize/datasets/client.py +106 -19
- arize/datasets/errors.py +61 -0
- arize/datasets/validation.py +46 -0
- arize/experiments/client.py +455 -0
- arize/experiments/evaluators/__init__.py +0 -0
- arize/experiments/evaluators/base.py +255 -0
- arize/experiments/evaluators/exceptions.py +10 -0
- arize/experiments/evaluators/executors.py +502 -0
- arize/experiments/evaluators/rate_limiters.py +277 -0
- arize/experiments/evaluators/types.py +122 -0
- arize/experiments/evaluators/utils.py +198 -0
- arize/experiments/functions.py +920 -0
- arize/experiments/tracing.py +276 -0
- arize/experiments/types.py +394 -0
- arize/models/client.py +4 -1
- arize/spans/client.py +16 -20
- arize/utils/arrow.py +4 -3
- arize/utils/openinference_conversion.py +56 -0
- arize/utils/proto.py +13 -0
- arize/utils/size.py +22 -0
- arize/version.py +1 -1
- {arize-8.0.0a14.dist-info → arize-8.0.0a16.dist-info}/METADATA +3 -1
- {arize-8.0.0a14.dist-info → arize-8.0.0a16.dist-info}/RECORD +65 -44
- {arize-8.0.0a14.dist-info → arize-8.0.0a16.dist-info}/WHEEL +0 -0
- {arize-8.0.0a14.dist-info → arize-8.0.0a16.dist-info}/licenses/LICENSE.md +0 -0
arize/experiments/client.py
CHANGED
|
@@ -1,4 +1,49 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
|
|
5
|
+
|
|
6
|
+
import opentelemetry.sdk.trace as trace_sdk
|
|
7
|
+
import pyarrow as pa
|
|
8
|
+
from openinference.semconv.resource import ResourceAttributes
|
|
9
|
+
from opentelemetry import trace
|
|
10
|
+
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import (
|
|
11
|
+
OTLPSpanExporter as GrpcSpanExporter,
|
|
12
|
+
)
|
|
13
|
+
from opentelemetry.sdk.resources import Resource
|
|
14
|
+
from opentelemetry.sdk.trace.export import (
|
|
15
|
+
ConsoleSpanExporter,
|
|
16
|
+
SimpleSpanProcessor,
|
|
17
|
+
)
|
|
18
|
+
from opentelemetry.trace import Tracer
|
|
19
|
+
|
|
20
|
+
from arize._flight.client import ArizeFlightClient
|
|
21
|
+
from arize._flight.types import FlightRequestType
|
|
1
22
|
from arize.config import SDKConfiguration
|
|
23
|
+
from arize.exceptions.base import INVALID_ARROW_CONVERSION_MSG
|
|
24
|
+
from arize.experiments.evaluators.base import Evaluators
|
|
25
|
+
from arize.experiments.evaluators.types import EvaluationResultFieldNames
|
|
26
|
+
from arize.experiments.functions import (
|
|
27
|
+
run_experiment,
|
|
28
|
+
transform_to_experiment_format,
|
|
29
|
+
)
|
|
30
|
+
from arize.experiments.types import (
|
|
31
|
+
ExperimentTask,
|
|
32
|
+
ExperimentTaskResultFieldNames,
|
|
33
|
+
)
|
|
34
|
+
from arize.utils.openinference_conversion import (
|
|
35
|
+
convert_boolean_columns_to_str,
|
|
36
|
+
convert_default_columns_to_json_str,
|
|
37
|
+
)
|
|
38
|
+
from arize.utils.size import get_payload_size_mb
|
|
39
|
+
|
|
40
|
+
if TYPE_CHECKING:
|
|
41
|
+
import pandas as pd
|
|
42
|
+
|
|
43
|
+
from arize._generated.api_client.models.experiment import Experiment
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
logger = logging.getLogger(__name__)
|
|
2
47
|
|
|
3
48
|
|
|
4
49
|
class ExperimentsClient:
|
|
@@ -7,4 +52,414 @@ class ExperimentsClient:
|
|
|
7
52
|
from arize._generated import api_client as gen
|
|
8
53
|
|
|
9
54
|
self._api = gen.ExperimentsApi(self._sdk_config.get_generated_client())
|
|
55
|
+
# TODO(Kiko): Space ID should not be needed,
|
|
56
|
+
# should work on server tech debt to remove this
|
|
57
|
+
self._datasets_api = gen.DatasetsApi(
|
|
58
|
+
self._sdk_config.get_generated_client()
|
|
59
|
+
)
|
|
10
60
|
self.list = self._api.experiments_list
|
|
61
|
+
self.get = self._api.experiments_get
|
|
62
|
+
self.delete = self._api.experiments_delete
|
|
63
|
+
self.list_runs = self._api.experiments_runs_list # REST ?
|
|
64
|
+
|
|
65
|
+
# Custom methods
|
|
66
|
+
self.create = self._create_experiment
|
|
67
|
+
self.run = self._run_experiment
|
|
68
|
+
|
|
69
|
+
def _run_experiment(
    self,
    name: str,
    dataset_id: str,
    task: ExperimentTask,
    dataset_df: pd.DataFrame | None = None,
    evaluators: Evaluators | None = None,
    dry_run: bool = False,
    concurrency: int = 3,
    set_global_tracer_provider: bool = False,
    exit_on_error: bool = False,
) -> Tuple[str, pd.DataFrame] | None:
    """
    Run an experiment on a dataset and upload the results.

    The dataset's space ID is looked up from ``dataset_id``. Unless
    ``dry_run`` is set, an experiment is initialized on the Flight server,
    the task (and evaluators) are run over the dataset, and the resulting
    runs are uploaded as an Arrow table.

    Args:
        name (str): The name of the experiment.
        dataset_id (str): The ID of the dataset the experiment runs against.
        task (ExperimentTask): The task to be performed in the experiment.
        dataset_df (Optional[pd.DataFrame], optional): The dataset as a
            pandas DataFrame. If not provided, the dataset examples are
            downloaded using ``dataset_id``. Defaults to None.
        evaluators (Optional[Evaluators], optional): The evaluators to use
            in the experiment. Defaults to None.
        dry_run (bool): If True, only the first 10 rows are used, spans are
            exported to the console, and the experiment result is not
            uploaded to Arize. Defaults to False.
        concurrency (int): The number of concurrent tasks to run.
            Defaults to 3.
        set_global_tracer_provider (bool): If True, sets the global tracer
            provider for the experiment. Defaults to False.
        exit_on_error (bool): If True, the experiment will stop running on
            first occurrence of an error. Defaults to False.

    Returns:
        Tuple[str, pd.DataFrame]:
            A tuple of experiment ID and experiment result DataFrame.
            If dry_run is True, the experiment ID will be an empty string.

    Raises:
        ValueError: If the dataset is empty.
        RuntimeError: If experiment initialization, dataset download, or
            result upload fails.
        pa.ArrowInvalid: If the result DataFrame cannot be converted to an
            Arrow table.
    """
    # TODO(Kiko): Space ID should not be needed,
    # should work on server tech debt to remove this
    dataset = self._datasets_api.datasets_get(dataset_id=dataset_id)
    space_id = dataset.space_id

    with ArizeFlightClient(
        api_key=self._sdk_config.api_key,
        host=self._sdk_config.flight_server_host,
        port=self._sdk_config.flight_server_port,
        scheme=self._sdk_config.flight_scheme,
        request_verify=self._sdk_config.request_verify,
        max_chunksize=self._sdk_config.pyarrow_max_chunksize,
    ) as flight_client:
        # set up initial experiment and trace model
        if dry_run:
            trace_model_name = "traces_for_dry_run"
            experiment_id = "experiment_id_for_dry_run"
        else:
            response = None
            try:
                response = flight_client.init_experiment(
                    space_id=space_id,
                    dataset_id=dataset_id,
                    experiment_name=name,
                )
            except Exception as e:
                msg = f"Error during request: {str(e)}"
                logger.error(msg)
                raise RuntimeError(msg) from e

            if response is None:
                # This should not happen with proper Flight client implementation,
                # but we handle it defensively
                msg = (
                    "No response received from flight server during request"
                )
                logger.error(msg)
                raise RuntimeError(msg)
            experiment_id, trace_model_name = response

        # download dataset if not provided
        if dataset_df is None:
            try:
                response = flight_client.get_dataset_examples(
                    space_id=space_id,
                    dataset_id=dataset_id,
                )
            except Exception as e:
                msg = f"Error during request: {str(e)}"
                logger.error(msg)
                raise RuntimeError(msg) from e
            if response is None:
                # This should not happen with proper Flight client implementation,
                # but we handle it defensively
                msg = (
                    "No response received from flight server during request"
                )
                logger.error(msg)
                raise RuntimeError(msg)
            # Use the downloaded examples as the dataset. Without this
            # assignment dataset_df stays None and the emptiness check below
            # always raises, so the download path could never succeed.
            # NOTE(review): assumes get_dataset_examples returns a pandas
            # DataFrame (it is used as one below) — confirm in the Flight client.
            dataset_df = response

        if dataset_df is None or dataset_df.empty:
            raise ValueError(f"Dataset {dataset_id} is empty")

        # Work on a copy so the caller's DataFrame is never mutated.
        input_df = dataset_df.copy()
        if dry_run:
            # only dry_run experiment on a subset (first 10 rows) of the dataset
            input_df = input_df.head(10)

        # trace model and resource for the experiment
        tracer, resource = _get_tracer_resource(
            project_name=trace_model_name,
            space_id=space_id,
            api_key=self._sdk_config.api_key,
            endpoint=self._sdk_config.otlp_url,
            dry_run=dry_run,
            set_global_tracer_provider=set_global_tracer_provider,
        )

        output_df = run_experiment(
            experiment_name=name,
            experiment_id=experiment_id,
            dataset=input_df,
            task=task,
            tracer=tracer,
            resource=resource,
            evaluators=evaluators,
            concurrency=concurrency,
            exit_on_error=exit_on_error,
        )
        output_df = convert_default_columns_to_json_str(output_df)
        output_df = convert_boolean_columns_to_str(output_df)
        if dry_run:
            # Nothing is uploaded in a dry run, so there is no real ID to return.
            return "", output_df

        # Convert to Arrow table
        try:
            logger.debug("Converting data to Arrow format")
            pa_table = pa.Table.from_pandas(output_df, preserve_index=False)
        except pa.ArrowInvalid as e:
            logger.error(f"{INVALID_ARROW_CONVERSION_MSG}: {str(e)}")
            raise pa.ArrowInvalid(
                f"Error converting to Arrow format: {str(e)}"
            ) from e
        except Exception as e:
            logger.error(f"Unexpected error creating Arrow table: {str(e)}")
            raise

        request_type = FlightRequestType.LOG_EXPERIMENT_DATA
        post_resp = None
        try:
            # The ID returned by init_experiment is passed through the
            # `experiment_name` keyword of log_arrow_table.
            post_resp = flight_client.log_arrow_table(
                space_id=space_id,
                pa_table=pa_table,
                dataset_id=dataset_id,
                experiment_name=experiment_id,
                request_type=request_type,
            )
        except Exception as e:
            msg = f"Error during update request: {str(e)}"
            logger.error(msg)
            raise RuntimeError(msg) from e

        if post_resp is None:
            # This should not happen with proper Flight client implementation,
            # but we handle it defensively
            msg = "No response received from flight server during request"
            logger.error(msg)
            raise RuntimeError(msg)

        return str(post_resp.experiment_id), output_df  # type: ignore
|
|
245
|
+
|
|
246
|
+
def _create_experiment(
    self,
    name: str,
    dataset_id: str,
    experiment_runs: List[Dict[str, Any]] | pd.DataFrame,
    task_fields: ExperimentTaskResultFieldNames,
    evaluator_columns: Dict[str, EvaluationResultFieldNames] | None = None,
    force_http: bool = False,
) -> Experiment:
    """
    Log pre-computed experiment runs to Arize as a new experiment.

    The runs are first converted with ``transform_to_experiment_format``.
    If the payload size of ``experiment_runs`` is within
    ``max_http_payload_size_mb`` (or ``force_http`` is True) the experiment
    is created through the REST API; otherwise it is uploaded via
    gRPC + Flight.

    Args:
        name (str): The name of the experiment.
        dataset_id (str): The ID of the dataset associated with the
            experiment.
        experiment_runs (List[Dict[str, Any]] | pd.DataFrame): The runs to
            be logged, as a list of dicts or a pandas DataFrame.
        task_fields (ExperimentTaskResultFieldNames): Mapping that tells
            which fields of ``experiment_runs`` hold the task results.
        evaluator_columns (Optional[Dict[str, EvaluationResultFieldNames]]):
            Per-evaluator mapping of the fields holding evaluator results,
            keyed by evaluator name. Defaults to None.
        force_http (bool): If True, always use the REST API regardless of
            payload size. Defaults to False.

    Examples:
        >>> # NOTE(review): the field-name constructors below are carried
        >>> # over from the previous docstring; confirm their keyword
        >>> # arguments against arize.experiments.types / evaluators.types.
        >>> df = pd.DataFrame({
        ...     "example_id": ["1", "2"],
        ...     "result": ["success", "failure"],
        ...     "accuracy": [0.95, 0.85],
        ...     "ground_truth": ["A", "B"],
        ...     "explanation_text": ["Good match", "Poor match"],
        ... })
        >>> task_fields = ExperimentTaskResultFieldNames(
        ...     example_id="example_id", result="result"
        ... )
        >>> evaluator_fields = EvaluationResultFieldNames(
        ...     score="accuracy",
        ...     label="ground_truth",
        ...     explanation="explanation_text",
        ... )
        >>> experiments_client.create(
        ...     name="my_experiment",
        ...     dataset_id="my_dataset_id",
        ...     experiment_runs=df,
        ...     task_fields=task_fields,
        ...     evaluator_columns={"my_evaluator": evaluator_fields},
        ... )

    Returns:
        Experiment: The created experiment.

    Raises:
        TypeError: If ``experiment_runs`` is neither a list nor a pandas
            DataFrame.
        RuntimeError: If the Flight upload path fails.
    """
    # pandas is only imported under TYPE_CHECKING at module level, so it must
    # be imported here for the runtime isinstance check below; otherwise the
    # check raises NameError on every call.
    import pandas as pd

    if not isinstance(experiment_runs, (list, pd.DataFrame)):
        raise TypeError(
            "Experiment runs must be a list of dicts or a pandas DataFrame"
        )
    # transform experiment data to experiment format
    experiment_df = transform_to_experiment_format(
        experiment_runs, task_fields, evaluator_columns
    )

    below_threshold = (
        get_payload_size_mb(experiment_runs)
        <= self._sdk_config.max_http_payload_size_mb
    )
    if below_threshold or force_http:
        from arize._generated import api_client as gen

        data = experiment_df.to_dict(orient="records")

        body = gen.ExperimentsCreateRequest(
            name=name,
            datasetId=dataset_id,
            experimentRuns=data,
        )
        return self._api.experiments_create(experiments_create_request=body)

    # If we have too many examples, try to convert to a dataframe
    # and log via gRPC + flight
    logger.info(
        f"Uploading {len(experiment_df)} experiment runs via REST may be slow. "
        "Trying for more efficient upload via gRPC + Flight."
    )

    # TODO(Kiko): Space ID should not be needed,
    # should work on server tech debt to remove this
    dataset = self._datasets_api.datasets_get(dataset_id=dataset_id)
    space_id = dataset.space_id

    return self._create_experiment_via_flight(
        name=name,
        dataset_id=dataset_id,
        space_id=space_id,
        experiment_df=experiment_df,
    )
|
|
354
|
+
|
|
355
|
+
def _create_experiment_via_flight(
    self,
    name: str,
    dataset_id: str,
    space_id: str,
    experiment_df: pd.DataFrame,
) -> Experiment:
    """
    Create an experiment by uploading its runs through the Flight server.

    Initializes the experiment, converts ``experiment_df`` to an Arrow
    table, logs the table, and finally fetches the created experiment
    through ``self.get``.

    Args:
        name (str): The name of the experiment.
        dataset_id (str): The ID of the dataset the experiment belongs to.
        space_id (str): The ID of the space that owns the dataset.
        experiment_df (pd.DataFrame): The experiment runs to upload.

    Returns:
        Experiment: The experiment fetched after the upload.

    Raises:
        RuntimeError: If a Flight request fails or yields no response.
        pa.ArrowInvalid: If the DataFrame cannot be converted to Arrow.
    """
    no_response_msg = "No response received from flight server during request"
    sdk_config = self._sdk_config
    connection_options = dict(
        api_key=sdk_config.api_key,
        host=sdk_config.flight_server_host,
        port=sdk_config.flight_server_port,
        scheme=sdk_config.flight_scheme,
        request_verify=sdk_config.request_verify,
        max_chunksize=sdk_config.pyarrow_max_chunksize,
    )

    with ArizeFlightClient(**connection_options) as flight_client:
        # Step 1: register the experiment (and its trace model) server-side.
        try:
            init_result = flight_client.init_experiment(
                space_id=space_id,
                dataset_id=dataset_id,
                experiment_name=name,
            )
        except Exception as e:
            msg = f"Error during request: {str(e)}"
            logger.error(msg)
            raise RuntimeError(msg) from e

        if init_result is None:
            # Defensive: a well-behaved Flight client always answers.
            logger.error(no_response_msg)
            raise RuntimeError(no_response_msg)
        # Only the ID is needed here; the second element is discarded.
        flight_experiment_id, _ = init_result

        # Step 2: turn the runs into an Arrow table.
        try:
            logger.debug("Converting data to Arrow format")
            arrow_table = pa.Table.from_pandas(
                experiment_df, preserve_index=False
            )
        except pa.ArrowInvalid as e:
            logger.error(f"{INVALID_ARROW_CONVERSION_MSG}: {str(e)}")
            raise pa.ArrowInvalid(
                f"Error converting to Arrow format: {str(e)}"
            ) from e
        except Exception as e:
            logger.error(f"Unexpected error creating Arrow table: {str(e)}")
            raise

        # Step 3: upload the table; the ID from step 1 travels in the
        # `experiment_name` keyword.
        try:
            upload_result = flight_client.log_arrow_table(
                space_id=space_id,
                pa_table=arrow_table,
                dataset_id=dataset_id,
                experiment_name=flight_experiment_id,
                request_type=FlightRequestType.LOG_EXPERIMENT_DATA,
            )
        except Exception as e:
            msg = f"Error during update request: {str(e)}"
            logger.error(msg)
            raise RuntimeError(msg) from e

        if upload_result is None:
            # Defensive: a well-behaved Flight client always answers.
            logger.error(no_response_msg)
            raise RuntimeError(no_response_msg)

        # Step 4: fetch the experiment object by the ID the upload returned.
        return self.get(
            experiment_id=str(upload_result.experiment_id)  # type: ignore
        )
|
|
431
|
+
|
|
432
|
+
|
|
433
|
+
def _get_tracer_resource(
    project_name: str,
    space_id: str,
    api_key: str,
    endpoint: str,
    dry_run: bool = False,
    set_global_tracer_provider: bool = False,
) -> Tuple[Tracer, Resource]:
    """
    Build the tracer and OpenTelemetry resource used by an experiment.

    Args:
        project_name (str): Value stored under the resource's project-name
            attribute.
        space_id (str): Sent as the ``arize-space-id`` exporter header.
        api_key (str): Sent as the ``authorization`` exporter header.
        endpoint (str): OTLP endpoint for the gRPC exporter; an
            ``http://`` prefix selects an insecure connection.
        dry_run (bool): If True, spans are exported to the console instead
            of the gRPC endpoint. Defaults to False.
        set_global_tracer_provider (bool): If True, the new provider is
            installed as the global tracer provider. Defaults to False.

    Returns:
        Tuple[Tracer, Resource]: The tracer obtained from the new provider
        and the resource it was created with.
    """
    resource = Resource({ResourceAttributes.PROJECT_NAME: project_name})
    provider = trace_sdk.TracerProvider(resource=resource)

    # Evaluated up front (even for dry runs), exactly as before.
    exporter_headers = {
        "authorization": api_key,
        "arize-space-id": space_id,
        "arize-interface": "otel",
    }
    use_insecure_channel = endpoint.startswith("http://")

    if dry_run:
        span_exporter = ConsoleSpanExporter()
    else:
        span_exporter = GrpcSpanExporter(
            endpoint=endpoint,
            insecure=use_insecure_channel,
            headers=exporter_headers,
        )
    provider.add_span_processor(SimpleSpanProcessor(span_exporter))

    if set_global_tracer_provider:
        trace.set_tracer_provider(provider)

    experiment_tracer = provider.get_tracer(__name__)
    return experiment_tracer, resource
|
|
File without changes
|