arize 8.0.0a21__py3-none-any.whl → 8.0.0a23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arize/__init__.py +17 -9
- arize/_exporter/client.py +55 -36
- arize/_exporter/parsers/tracing_data_parser.py +41 -30
- arize/_exporter/validation.py +3 -3
- arize/_flight/client.py +208 -77
- arize/_generated/api_client/__init__.py +30 -6
- arize/_generated/api_client/api/__init__.py +1 -0
- arize/_generated/api_client/api/datasets_api.py +864 -190
- arize/_generated/api_client/api/experiments_api.py +167 -131
- arize/_generated/api_client/api/projects_api.py +1197 -0
- arize/_generated/api_client/api_client.py +2 -2
- arize/_generated/api_client/configuration.py +42 -34
- arize/_generated/api_client/exceptions.py +2 -2
- arize/_generated/api_client/models/__init__.py +15 -4
- arize/_generated/api_client/models/dataset.py +10 -10
- arize/_generated/api_client/models/dataset_example.py +111 -0
- arize/_generated/api_client/models/dataset_example_update.py +100 -0
- arize/_generated/api_client/models/dataset_version.py +13 -13
- arize/_generated/api_client/models/datasets_create_request.py +16 -8
- arize/_generated/api_client/models/datasets_examples_insert_request.py +100 -0
- arize/_generated/api_client/models/datasets_examples_list200_response.py +106 -0
- arize/_generated/api_client/models/datasets_examples_update_request.py +102 -0
- arize/_generated/api_client/models/datasets_list200_response.py +10 -4
- arize/_generated/api_client/models/experiment.py +14 -16
- arize/_generated/api_client/models/experiment_run.py +108 -0
- arize/_generated/api_client/models/experiment_run_create.py +102 -0
- arize/_generated/api_client/models/experiments_create_request.py +16 -10
- arize/_generated/api_client/models/experiments_list200_response.py +10 -4
- arize/_generated/api_client/models/experiments_runs_list200_response.py +19 -5
- arize/_generated/api_client/models/{error.py → pagination_metadata.py} +13 -11
- arize/_generated/api_client/models/primitive_value.py +172 -0
- arize/_generated/api_client/models/problem.py +100 -0
- arize/_generated/api_client/models/project.py +99 -0
- arize/_generated/api_client/models/{datasets_list_examples200_response.py → projects_create_request.py} +13 -11
- arize/_generated/api_client/models/projects_list200_response.py +106 -0
- arize/_generated/api_client/rest.py +2 -2
- arize/_generated/api_client/test/test_dataset.py +4 -2
- arize/_generated/api_client/test/test_dataset_example.py +56 -0
- arize/_generated/api_client/test/test_dataset_example_update.py +52 -0
- arize/_generated/api_client/test/test_dataset_version.py +7 -2
- arize/_generated/api_client/test/test_datasets_api.py +27 -13
- arize/_generated/api_client/test/test_datasets_create_request.py +8 -4
- arize/_generated/api_client/test/{test_datasets_list_examples200_response.py → test_datasets_examples_insert_request.py} +19 -15
- arize/_generated/api_client/test/test_datasets_examples_list200_response.py +66 -0
- arize/_generated/api_client/test/test_datasets_examples_update_request.py +61 -0
- arize/_generated/api_client/test/test_datasets_list200_response.py +9 -3
- arize/_generated/api_client/test/test_experiment.py +2 -4
- arize/_generated/api_client/test/test_experiment_run.py +56 -0
- arize/_generated/api_client/test/test_experiment_run_create.py +54 -0
- arize/_generated/api_client/test/test_experiments_api.py +6 -6
- arize/_generated/api_client/test/test_experiments_create_request.py +9 -6
- arize/_generated/api_client/test/test_experiments_list200_response.py +9 -5
- arize/_generated/api_client/test/test_experiments_runs_list200_response.py +15 -5
- arize/_generated/api_client/test/test_pagination_metadata.py +53 -0
- arize/_generated/api_client/test/{test_error.py → test_primitive_value.py} +13 -14
- arize/_generated/api_client/test/test_problem.py +57 -0
- arize/_generated/api_client/test/test_project.py +58 -0
- arize/_generated/api_client/test/test_projects_api.py +59 -0
- arize/_generated/api_client/test/test_projects_create_request.py +54 -0
- arize/_generated/api_client/test/test_projects_list200_response.py +70 -0
- arize/_generated/api_client_README.md +43 -29
- arize/_generated/protocol/flight/flight_pb2.py +400 -0
- arize/_lazy.py +27 -19
- arize/client.py +269 -55
- arize/config.py +365 -116
- arize/constants/__init__.py +1 -0
- arize/constants/config.py +11 -4
- arize/constants/ml.py +6 -4
- arize/constants/openinference.py +2 -0
- arize/constants/pyarrow.py +2 -0
- arize/constants/spans.py +3 -1
- arize/datasets/__init__.py +1 -0
- arize/datasets/client.py +299 -84
- arize/datasets/errors.py +32 -2
- arize/datasets/validation.py +18 -8
- arize/embeddings/__init__.py +2 -0
- arize/embeddings/auto_generator.py +23 -19
- arize/embeddings/base_generators.py +89 -36
- arize/embeddings/constants.py +2 -0
- arize/embeddings/cv_generators.py +26 -4
- arize/embeddings/errors.py +27 -5
- arize/embeddings/nlp_generators.py +31 -12
- arize/embeddings/tabular_generators.py +32 -20
- arize/embeddings/usecases.py +12 -2
- arize/exceptions/__init__.py +1 -0
- arize/exceptions/auth.py +11 -1
- arize/exceptions/base.py +29 -4
- arize/exceptions/models.py +21 -2
- arize/exceptions/parameters.py +31 -0
- arize/exceptions/spaces.py +12 -1
- arize/exceptions/types.py +86 -7
- arize/exceptions/values.py +220 -20
- arize/experiments/__init__.py +1 -0
- arize/experiments/client.py +390 -286
- arize/experiments/evaluators/__init__.py +1 -0
- arize/experiments/evaluators/base.py +74 -41
- arize/experiments/evaluators/exceptions.py +6 -3
- arize/experiments/evaluators/executors.py +121 -73
- arize/experiments/evaluators/rate_limiters.py +106 -57
- arize/experiments/evaluators/types.py +34 -7
- arize/experiments/evaluators/utils.py +65 -27
- arize/experiments/functions.py +103 -101
- arize/experiments/tracing.py +52 -44
- arize/experiments/types.py +56 -31
- arize/logging.py +54 -22
- arize/models/__init__.py +1 -0
- arize/models/batch_validation/__init__.py +1 -0
- arize/models/batch_validation/errors.py +543 -65
- arize/models/batch_validation/validator.py +339 -300
- arize/models/bounded_executor.py +20 -7
- arize/models/casting.py +75 -29
- arize/models/client.py +326 -107
- arize/models/proto.py +95 -40
- arize/models/stream_validation.py +42 -14
- arize/models/surrogate_explainer/__init__.py +1 -0
- arize/models/surrogate_explainer/mimic.py +24 -13
- arize/pre_releases.py +43 -0
- arize/projects/__init__.py +1 -0
- arize/projects/client.py +129 -0
- arize/regions.py +40 -0
- arize/spans/__init__.py +1 -0
- arize/spans/client.py +130 -106
- arize/spans/columns.py +13 -0
- arize/spans/conversion.py +54 -38
- arize/spans/validation/__init__.py +1 -0
- arize/spans/validation/annotations/__init__.py +1 -0
- arize/spans/validation/annotations/annotations_validation.py +6 -4
- arize/spans/validation/annotations/dataframe_form_validation.py +13 -11
- arize/spans/validation/annotations/value_validation.py +35 -11
- arize/spans/validation/common/__init__.py +1 -0
- arize/spans/validation/common/argument_validation.py +33 -8
- arize/spans/validation/common/dataframe_form_validation.py +35 -9
- arize/spans/validation/common/errors.py +211 -11
- arize/spans/validation/common/value_validation.py +80 -13
- arize/spans/validation/evals/__init__.py +1 -0
- arize/spans/validation/evals/dataframe_form_validation.py +28 -8
- arize/spans/validation/evals/evals_validation.py +34 -4
- arize/spans/validation/evals/value_validation.py +26 -3
- arize/spans/validation/metadata/__init__.py +1 -1
- arize/spans/validation/metadata/argument_validation.py +14 -5
- arize/spans/validation/metadata/dataframe_form_validation.py +26 -10
- arize/spans/validation/metadata/value_validation.py +24 -10
- arize/spans/validation/spans/__init__.py +1 -0
- arize/spans/validation/spans/dataframe_form_validation.py +34 -13
- arize/spans/validation/spans/spans_validation.py +35 -4
- arize/spans/validation/spans/value_validation.py +76 -7
- arize/types.py +293 -157
- arize/utils/__init__.py +1 -0
- arize/utils/arrow.py +31 -15
- arize/utils/cache.py +34 -6
- arize/utils/dataframe.py +19 -2
- arize/utils/online_tasks/__init__.py +2 -0
- arize/utils/online_tasks/dataframe_preprocessor.py +53 -41
- arize/utils/openinference_conversion.py +44 -5
- arize/utils/proto.py +10 -0
- arize/utils/size.py +5 -3
- arize/version.py +3 -1
- {arize-8.0.0a21.dist-info → arize-8.0.0a23.dist-info}/METADATA +4 -3
- arize-8.0.0a23.dist-info/RECORD +174 -0
- {arize-8.0.0a21.dist-info → arize-8.0.0a23.dist-info}/WHEEL +1 -1
- arize-8.0.0a23.dist-info/licenses/LICENSE +176 -0
- arize-8.0.0a23.dist-info/licenses/NOTICE +13 -0
- arize/_generated/protocol/flight/export_pb2.py +0 -61
- arize/_generated/protocol/flight/ingest_pb2.py +0 -365
- arize-8.0.0a21.dist-info/RECORD +0 -146
- arize-8.0.0a21.dist-info/licenses/LICENSE.md +0 -12
arize/experiments/client.py
CHANGED
|
@@ -1,10 +1,12 @@
|
|
|
1
|
+
"""Client implementation for managing experiments in the Arize platform."""
|
|
2
|
+
|
|
1
3
|
from __future__ import annotations
|
|
2
4
|
|
|
3
|
-
import hashlib
|
|
4
5
|
import logging
|
|
5
|
-
from typing import TYPE_CHECKING
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
6
7
|
|
|
7
8
|
import opentelemetry.sdk.trace as trace_sdk
|
|
9
|
+
import pandas as pd
|
|
8
10
|
import pyarrow as pa
|
|
9
11
|
from openinference.semconv.resource import ResourceAttributes
|
|
10
12
|
from opentelemetry import trace
|
|
@@ -16,23 +18,16 @@ from opentelemetry.sdk.trace.export import (
|
|
|
16
18
|
ConsoleSpanExporter,
|
|
17
19
|
SimpleSpanProcessor,
|
|
18
20
|
)
|
|
19
|
-
from opentelemetry.trace import Tracer
|
|
20
21
|
|
|
21
22
|
from arize._flight.client import ArizeFlightClient
|
|
22
23
|
from arize._flight.types import FlightRequestType
|
|
23
24
|
from arize._generated.api_client import models
|
|
24
|
-
from arize.config import SDKConfiguration
|
|
25
25
|
from arize.exceptions.base import INVALID_ARROW_CONVERSION_MSG
|
|
26
|
-
from arize.experiments.evaluators.base import Evaluators
|
|
27
|
-
from arize.experiments.evaluators.types import EvaluationResultFieldNames
|
|
28
26
|
from arize.experiments.functions import (
|
|
29
27
|
run_experiment,
|
|
30
28
|
transform_to_experiment_format,
|
|
31
29
|
)
|
|
32
|
-
from arize.
|
|
33
|
-
ExperimentTask,
|
|
34
|
-
ExperimentTaskResultFieldNames,
|
|
35
|
-
)
|
|
30
|
+
from arize.pre_releases import ReleaseStage, prerelease_endpoint
|
|
36
31
|
from arize.utils.cache import cache_resource, load_cached_resource
|
|
37
32
|
from arize.utils.openinference_conversion import (
|
|
38
33
|
convert_boolean_columns_to_str,
|
|
@@ -41,16 +36,31 @@ from arize.utils.openinference_conversion import (
|
|
|
41
36
|
from arize.utils.size import get_payload_size_mb
|
|
42
37
|
|
|
43
38
|
if TYPE_CHECKING:
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
from arize.
|
|
47
|
-
|
|
39
|
+
from opentelemetry.trace import Tracer
|
|
40
|
+
|
|
41
|
+
from arize.config import SDKConfiguration
|
|
42
|
+
from arize.experiments.evaluators.base import Evaluators
|
|
43
|
+
from arize.experiments.evaluators.types import EvaluationResultFieldNames
|
|
44
|
+
from arize.experiments.types import (
|
|
45
|
+
ExperimentTask,
|
|
46
|
+
ExperimentTaskResultFieldNames,
|
|
47
|
+
)
|
|
48
48
|
|
|
49
49
|
logger = logging.getLogger(__name__)
|
|
50
50
|
|
|
51
51
|
|
|
52
52
|
class ExperimentsClient:
|
|
53
|
-
|
|
53
|
+
"""Client for managing experiments including creation, execution, and result tracking."""
|
|
54
|
+
|
|
55
|
+
def __init__(self, *, sdk_config: SDKConfiguration) -> None:
|
|
56
|
+
"""Create an experiments sub-client.
|
|
57
|
+
|
|
58
|
+
The experiments client is a thin wrapper around the generated REST API client,
|
|
59
|
+
using the shared generated API client owned by `SDKConfiguration`.
|
|
60
|
+
|
|
61
|
+
Args:
|
|
62
|
+
sdk_config: Resolved SDK configuration.
|
|
63
|
+
"""
|
|
54
64
|
self._sdk_config = sdk_config
|
|
55
65
|
from arize._generated import api_client as gen
|
|
56
66
|
|
|
@@ -61,16 +71,277 @@ class ExperimentsClient:
|
|
|
61
71
|
self._sdk_config.get_generated_client()
|
|
62
72
|
)
|
|
63
73
|
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
self
|
|
74
|
+
@prerelease_endpoint(key="experiments.list", stage=ReleaseStage.BETA)
|
|
75
|
+
def list(
|
|
76
|
+
self,
|
|
77
|
+
*,
|
|
78
|
+
dataset_id: str | None = None,
|
|
79
|
+
limit: int = 100,
|
|
80
|
+
cursor: str | None = None,
|
|
81
|
+
) -> models.ExperimentsList200Response:
|
|
82
|
+
"""List experiments the user has access to.
|
|
83
|
+
|
|
84
|
+
To filter experiments by the dataset they were run on, provide `dataset_id`.
|
|
67
85
|
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
86
|
+
Args:
|
|
87
|
+
dataset_id: Optional dataset ID to filter experiments.
|
|
88
|
+
limit: Maximum number of experiments to return. The server enforces an
|
|
89
|
+
upper bound.
|
|
90
|
+
cursor: Opaque pagination cursor returned from a previous response.
|
|
91
|
+
|
|
92
|
+
Returns:
|
|
93
|
+
A response object with the experiments and pagination information.
|
|
94
|
+
|
|
95
|
+
Raises:
|
|
96
|
+
arize._generated.api_client.exceptions.ApiException: If the REST API
|
|
97
|
+
returns an error response (e.g. 401/403/429).
|
|
98
|
+
"""
|
|
99
|
+
return self._api.experiments_list(
|
|
100
|
+
dataset_id=dataset_id,
|
|
101
|
+
limit=limit,
|
|
102
|
+
cursor=cursor,
|
|
103
|
+
)
|
|
72
104
|
|
|
73
|
-
|
|
105
|
+
@prerelease_endpoint(key="experiments.create", stage=ReleaseStage.BETA)
|
|
106
|
+
def create(
|
|
107
|
+
self,
|
|
108
|
+
*,
|
|
109
|
+
name: str,
|
|
110
|
+
dataset_id: str,
|
|
111
|
+
experiment_runs: list[dict[str, object]] | pd.DataFrame,
|
|
112
|
+
task_fields: ExperimentTaskResultFieldNames,
|
|
113
|
+
evaluator_columns: dict[str, EvaluationResultFieldNames] | None = None,
|
|
114
|
+
force_http: bool = False,
|
|
115
|
+
) -> models.Experiment:
|
|
116
|
+
"""Create an experiment with one or more experiment runs.
|
|
117
|
+
|
|
118
|
+
Experiments are composed of runs. Each run must include:
|
|
119
|
+
- `example_id`: ID of an existing example in the dataset/version
|
|
120
|
+
- `output`: Model/task output for the matching example
|
|
121
|
+
|
|
122
|
+
You may include any additional user-defined fields per run (e.g. `model`,
|
|
123
|
+
`latency_ms`, `temperature`, `prompt`, `tool_calls`, etc.) that can be used
|
|
124
|
+
for analysis or filtering.
|
|
125
|
+
|
|
126
|
+
This method transforms the input runs into the server's expected experiment
|
|
127
|
+
format using `task_fields` and optional `evaluator_columns`.
|
|
128
|
+
|
|
129
|
+
Transport selection:
|
|
130
|
+
- If the payload is below the configured REST payload threshold (or
|
|
131
|
+
`force_http=True`), this method uploads via REST.
|
|
132
|
+
- Otherwise, it attempts a more efficient upload path via gRPC + Flight.
|
|
133
|
+
|
|
134
|
+
Args:
|
|
135
|
+
name: Experiment name. Must be unique within the target dataset.
|
|
136
|
+
dataset_id: Dataset ID to attach the experiment to.
|
|
137
|
+
experiment_runs: Experiment runs either as:
|
|
138
|
+
- a list of JSON-like dicts, or
|
|
139
|
+
- a pandas DataFrame.
|
|
140
|
+
task_fields: Mapping that identifies the columns/fields containing the
|
|
141
|
+
task results (e.g. `example_id`, output fields).
|
|
142
|
+
evaluator_columns: Optional mapping describing evaluator result columns.
|
|
143
|
+
force_http: If True, force REST upload even if the payload exceeds the
|
|
144
|
+
configured REST payload threshold.
|
|
145
|
+
|
|
146
|
+
Returns:
|
|
147
|
+
The created experiment object.
|
|
148
|
+
|
|
149
|
+
Raises:
|
|
150
|
+
TypeError: If `experiment_runs` is not a list of dicts or a DataFrame.
|
|
151
|
+
RuntimeError: If the Flight upload path is selected and the Flight request
|
|
152
|
+
fails.
|
|
153
|
+
arize._generated.api_client.exceptions.ApiException: If the REST API
|
|
154
|
+
returns an error response (e.g. 400/401/403/409/429).
|
|
155
|
+
"""
|
|
156
|
+
if not isinstance(experiment_runs, list | pd.DataFrame):
|
|
157
|
+
raise TypeError(
|
|
158
|
+
"Experiment runs must be a list of dicts or a pandas DataFrame"
|
|
159
|
+
)
|
|
160
|
+
# transform experiment data to experiment format
|
|
161
|
+
experiment_df = transform_to_experiment_format(
|
|
162
|
+
experiment_runs, task_fields, evaluator_columns
|
|
163
|
+
)
|
|
164
|
+
|
|
165
|
+
below_threshold = (
|
|
166
|
+
get_payload_size_mb(experiment_runs)
|
|
167
|
+
<= self._sdk_config.max_http_payload_size_mb
|
|
168
|
+
)
|
|
169
|
+
if below_threshold or force_http:
|
|
170
|
+
from arize._generated import api_client as gen
|
|
171
|
+
|
|
172
|
+
data = experiment_df.to_dict(orient="records")
|
|
173
|
+
|
|
174
|
+
body = gen.ExperimentsCreateRequest(
|
|
175
|
+
name=name,
|
|
176
|
+
dataset_id=dataset_id,
|
|
177
|
+
experiment_runs=data, # type: ignore
|
|
178
|
+
)
|
|
179
|
+
return self._api.experiments_create(experiments_create_request=body)
|
|
180
|
+
|
|
181
|
+
# If we have too many examples, try to convert to a dataframe
|
|
182
|
+
# and log via gRPC + flight
|
|
183
|
+
logger.info(
|
|
184
|
+
f"Uploading {len(experiment_df)} experiment runs via REST may be slow. "
|
|
185
|
+
"Trying for more efficient upload via gRPC + Flight."
|
|
186
|
+
)
|
|
187
|
+
|
|
188
|
+
# TODO(Kiko): Space ID should not be needed,
|
|
189
|
+
# should work on server tech debt to remove this
|
|
190
|
+
dataset = self._datasets_api.datasets_get(dataset_id=dataset_id)
|
|
191
|
+
space_id = dataset.space_id
|
|
192
|
+
|
|
193
|
+
return self._create_experiment_via_flight(
|
|
194
|
+
name=name,
|
|
195
|
+
dataset_id=dataset_id,
|
|
196
|
+
space_id=space_id,
|
|
197
|
+
experiment_df=experiment_df,
|
|
198
|
+
)
|
|
199
|
+
|
|
200
|
+
@prerelease_endpoint(key="experiments.get", stage=ReleaseStage.BETA)
|
|
201
|
+
def get(self, *, experiment_id: str) -> models.Experiment:
|
|
202
|
+
"""Get an experiment by ID.
|
|
203
|
+
|
|
204
|
+
The response does not include the experiment's runs. Use `list_runs()` to
|
|
205
|
+
retrieve runs for an experiment.
|
|
206
|
+
|
|
207
|
+
Args:
|
|
208
|
+
experiment_id: Experiment ID to retrieve.
|
|
209
|
+
|
|
210
|
+
Returns:
|
|
211
|
+
The experiment object.
|
|
212
|
+
|
|
213
|
+
Raises:
|
|
214
|
+
arize._generated.api_client.exceptions.ApiException: If the REST API
|
|
215
|
+
returns an error response (e.g. 401/403/404/429).
|
|
216
|
+
"""
|
|
217
|
+
return self._api.experiments_get(experiment_id=experiment_id)
|
|
218
|
+
|
|
219
|
+
@prerelease_endpoint(key="experiments.delete", stage=ReleaseStage.BETA)
|
|
220
|
+
def delete(self, *, experiment_id: str) -> None:
|
|
221
|
+
"""Delete an experiment by ID.
|
|
222
|
+
|
|
223
|
+
This operation is irreversible.
|
|
224
|
+
|
|
225
|
+
Args:
|
|
226
|
+
experiment_id: Experiment ID to delete.
|
|
227
|
+
|
|
228
|
+
Returns: This method returns None on success (common empty 204 response)
|
|
229
|
+
|
|
230
|
+
Raises:
|
|
231
|
+
arize._generated.api_client.exceptions.ApiException: If the REST API
|
|
232
|
+
returns an error response (e.g. 401/403/404/429).
|
|
233
|
+
"""
|
|
234
|
+
return self._api.experiments_delete(
|
|
235
|
+
experiment_id=experiment_id,
|
|
236
|
+
)
|
|
237
|
+
|
|
238
|
+
@prerelease_endpoint(key="experiments.list_runs", stage=ReleaseStage.BETA)
|
|
239
|
+
def list_runs(
|
|
240
|
+
self,
|
|
241
|
+
*,
|
|
242
|
+
experiment_id: str,
|
|
243
|
+
limit: int = 100,
|
|
244
|
+
all: bool = False,
|
|
245
|
+
) -> models.ExperimentsRunsList200Response:
|
|
246
|
+
"""List runs for an experiment.
|
|
247
|
+
|
|
248
|
+
Runs are returned in insertion order.
|
|
249
|
+
|
|
250
|
+
Pagination notes:
|
|
251
|
+
- The response includes `pagination` for forward compatibility.
|
|
252
|
+
- Cursor pagination may not be fully implemented by the server yet.
|
|
253
|
+
- If `all=True`, this method retrieves all runs via the Flight path and
|
|
254
|
+
returns them in a single response with `has_more=False`.
|
|
255
|
+
|
|
256
|
+
Args:
|
|
257
|
+
experiment_id: Experiment ID to list runs for.
|
|
258
|
+
limit: Maximum number of runs to return when `all=False`. The server
|
|
259
|
+
enforces an upper bound.
|
|
260
|
+
all: If True, fetch all runs (ignores `limit`) via Flight and return a
|
|
261
|
+
single response.
|
|
262
|
+
|
|
263
|
+
Returns:
|
|
264
|
+
A response object containing `experiment_runs` and `pagination` metadata.
|
|
265
|
+
|
|
266
|
+
Raises:
|
|
267
|
+
RuntimeError: If the Flight request fails or returns no response when
|
|
268
|
+
`all=True`.
|
|
269
|
+
arize._generated.api_client.exceptions.ApiException: If the REST API
|
|
270
|
+
returns an error response when `all=False` (e.g. 401/403/404/429).
|
|
271
|
+
"""
|
|
272
|
+
if not all:
|
|
273
|
+
return self._api.experiments_runs_list(
|
|
274
|
+
experiment_id=experiment_id,
|
|
275
|
+
limit=limit,
|
|
276
|
+
)
|
|
277
|
+
|
|
278
|
+
experiment = self.get(experiment_id=experiment_id)
|
|
279
|
+
experiment_updated_at = getattr(experiment, "updated_at", None)
|
|
280
|
+
# TODO(Kiko): Space ID should not be needed,
|
|
281
|
+
# should work on server tech debt to remove this
|
|
282
|
+
dataset = self._datasets_api.datasets_get(
|
|
283
|
+
dataset_id=experiment.dataset_id
|
|
284
|
+
)
|
|
285
|
+
space_id = dataset.space_id
|
|
286
|
+
|
|
287
|
+
experiment_df = None
|
|
288
|
+
# try to load dataset from cache
|
|
289
|
+
if self._sdk_config.enable_caching:
|
|
290
|
+
experiment_df = load_cached_resource(
|
|
291
|
+
cache_dir=self._sdk_config.cache_dir,
|
|
292
|
+
resource="experiment",
|
|
293
|
+
resource_id=experiment_id,
|
|
294
|
+
resource_updated_at=experiment_updated_at,
|
|
295
|
+
)
|
|
296
|
+
if experiment_df is not None:
|
|
297
|
+
return models.ExperimentsRunsList200Response(
|
|
298
|
+
experimentRuns=experiment_df.to_dict(orient="records"), # type: ignore
|
|
299
|
+
pagination=models.PaginationMetadata(
|
|
300
|
+
has_more=False, # Note that all=True
|
|
301
|
+
),
|
|
302
|
+
)
|
|
303
|
+
|
|
304
|
+
with ArizeFlightClient(
|
|
305
|
+
api_key=self._sdk_config.api_key,
|
|
306
|
+
host=self._sdk_config.flight_host,
|
|
307
|
+
port=self._sdk_config.flight_port,
|
|
308
|
+
scheme=self._sdk_config.flight_scheme,
|
|
309
|
+
request_verify=self._sdk_config.request_verify,
|
|
310
|
+
max_chunksize=self._sdk_config.pyarrow_max_chunksize,
|
|
311
|
+
) as flight_client:
|
|
312
|
+
try:
|
|
313
|
+
experiment_df = flight_client.get_experiment_runs(
|
|
314
|
+
space_id=space_id,
|
|
315
|
+
experiment_id=experiment_id,
|
|
316
|
+
)
|
|
317
|
+
except Exception as e:
|
|
318
|
+
msg = f"Error during request: {e!s}"
|
|
319
|
+
logger.exception(msg)
|
|
320
|
+
raise RuntimeError(msg) from e
|
|
321
|
+
if experiment_df is None:
|
|
322
|
+
# This should not happen with proper Flight client implementation,
|
|
323
|
+
# but we handle it defensively
|
|
324
|
+
msg = "No response received from flight server during request"
|
|
325
|
+
logger.error(msg)
|
|
326
|
+
raise RuntimeError(msg)
|
|
327
|
+
|
|
328
|
+
# cache experiment for future use
|
|
329
|
+
cache_resource(
|
|
330
|
+
cache_dir=self._sdk_config.cache_dir,
|
|
331
|
+
resource="experiment",
|
|
332
|
+
resource_id=experiment_id,
|
|
333
|
+
resource_updated_at=experiment_updated_at,
|
|
334
|
+
resource_data=experiment_df,
|
|
335
|
+
)
|
|
336
|
+
|
|
337
|
+
return models.ExperimentsRunsList200Response(
|
|
338
|
+
experimentRuns=experiment_df.to_dict(orient="records"), # type: ignore
|
|
339
|
+
pagination=models.PaginationMetadata(
|
|
340
|
+
has_more=False, # Note that all=True
|
|
341
|
+
),
|
|
342
|
+
)
|
|
343
|
+
|
|
344
|
+
def run(
|
|
74
345
|
self,
|
|
75
346
|
*,
|
|
76
347
|
name: str,
|
|
@@ -82,37 +353,46 @@ class ExperimentsClient:
|
|
|
82
353
|
concurrency: int = 3,
|
|
83
354
|
set_global_tracer_provider: bool = False,
|
|
84
355
|
exit_on_error: bool = False,
|
|
85
|
-
) ->
|
|
86
|
-
"""
|
|
87
|
-
|
|
356
|
+
) -> tuple[models.Experiment | None, pd.DataFrame]:
|
|
357
|
+
"""Run an experiment on a dataset and optionally upload results.
|
|
358
|
+
|
|
359
|
+
This method executes a task against dataset examples, optionally evaluates
|
|
360
|
+
outputs, and (when `dry_run=False`) uploads results to Arize.
|
|
361
|
+
|
|
362
|
+
High-level flow:
|
|
363
|
+
1) Resolve the dataset and `space_id`.
|
|
364
|
+
2) Download dataset examples (or load from cache if enabled).
|
|
365
|
+
3) Run the task and evaluators with configurable concurrency.
|
|
366
|
+
4) If not a dry run, upload experiment runs and return the created
|
|
367
|
+
experiment plus the results dataframe.
|
|
88
368
|
|
|
89
|
-
|
|
90
|
-
|
|
369
|
+
Notes:
|
|
370
|
+
- If `dry_run=True`, no data is uploaded and the returned experiment is
|
|
371
|
+
`None`.
|
|
372
|
+
- When `enable_caching=True`, dataset examples may be cached and reused.
|
|
91
373
|
|
|
92
374
|
Args:
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
set_global_tracer_provider (bool): If True, sets the global tracer provider for the experiment.
|
|
105
|
-
Defaults to False.
|
|
106
|
-
exit_on_error (bool): If True, the experiment will stop running on first occurrence of an error.
|
|
375
|
+
name: Experiment name.
|
|
376
|
+
dataset_id: Dataset ID to run the experiment against.
|
|
377
|
+
task: The task to execute for each dataset example.
|
|
378
|
+
evaluators: Optional evaluators used to score outputs.
|
|
379
|
+
dry_run: If True, do not upload results to Arize.
|
|
380
|
+
dry_run_count: Number of dataset rows to use when `dry_run=True`.
|
|
381
|
+
concurrency: Number of concurrent tasks to run.
|
|
382
|
+
set_global_tracer_provider: If True, sets the global OpenTelemetry tracer
|
|
383
|
+
provider for the experiment run.
|
|
384
|
+
exit_on_error: If True, stop on the first error encountered during
|
|
385
|
+
execution.
|
|
107
386
|
|
|
108
387
|
Returns:
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
If dry_run is True, the experiment ID will be an empty string.
|
|
388
|
+
If `dry_run=True`, returns `(None, results_df)`.
|
|
389
|
+
If `dry_run=False`, returns `(experiment, results_df)`.
|
|
112
390
|
|
|
113
391
|
Raises:
|
|
114
|
-
|
|
115
|
-
|
|
392
|
+
RuntimeError: If Flight operations (init/download/upload) fail or return
|
|
393
|
+
no response.
|
|
394
|
+
pa.ArrowInvalid: If converting results to Arrow fails.
|
|
395
|
+
Exception: For unexpected errors during Arrow conversion.
|
|
116
396
|
"""
|
|
117
397
|
# TODO(Kiko): Space ID should not be needed,
|
|
118
398
|
# should work on server tech debt to remove this
|
|
@@ -122,8 +402,8 @@ class ExperimentsClient:
|
|
|
122
402
|
|
|
123
403
|
with ArizeFlightClient(
|
|
124
404
|
api_key=self._sdk_config.api_key,
|
|
125
|
-
host=self._sdk_config.
|
|
126
|
-
port=self._sdk_config.
|
|
405
|
+
host=self._sdk_config.flight_host,
|
|
406
|
+
port=self._sdk_config.flight_port,
|
|
127
407
|
scheme=self._sdk_config.flight_scheme,
|
|
128
408
|
request_verify=self._sdk_config.request_verify,
|
|
129
409
|
max_chunksize=self._sdk_config.pyarrow_max_chunksize,
|
|
@@ -141,8 +421,8 @@ class ExperimentsClient:
|
|
|
141
421
|
experiment_name=name,
|
|
142
422
|
)
|
|
143
423
|
except Exception as e:
|
|
144
|
-
msg = f"Error during request: {
|
|
145
|
-
logger.
|
|
424
|
+
msg = f"Error during request: {e!s}"
|
|
425
|
+
logger.exception(msg)
|
|
146
426
|
raise RuntimeError(msg) from e
|
|
147
427
|
|
|
148
428
|
if response is None:
|
|
@@ -173,8 +453,8 @@ class ExperimentsClient:
|
|
|
173
453
|
dataset_id=dataset_id,
|
|
174
454
|
)
|
|
175
455
|
except Exception as e:
|
|
176
|
-
msg = f"Error during request: {
|
|
177
|
-
logger.
|
|
456
|
+
msg = f"Error during request: {e!s}"
|
|
457
|
+
logger.exception(msg)
|
|
178
458
|
raise RuntimeError(msg) from e
|
|
179
459
|
if dataset_df is None:
|
|
180
460
|
# This should not happen with proper Flight client implementation,
|
|
@@ -232,12 +512,12 @@ class ExperimentsClient:
|
|
|
232
512
|
logger.debug("Converting data to Arrow format")
|
|
233
513
|
pa_table = pa.Table.from_pandas(output_df, preserve_index=False)
|
|
234
514
|
except pa.ArrowInvalid as e:
|
|
235
|
-
logger.
|
|
515
|
+
logger.exception(INVALID_ARROW_CONVERSION_MSG)
|
|
236
516
|
raise pa.ArrowInvalid(
|
|
237
|
-
f"Error converting to Arrow format: {
|
|
517
|
+
f"Error converting to Arrow format: {e!s}"
|
|
238
518
|
) from e
|
|
239
|
-
except Exception
|
|
240
|
-
logger.
|
|
519
|
+
except Exception:
|
|
520
|
+
logger.exception("Unexpected error creating Arrow table")
|
|
241
521
|
raise
|
|
242
522
|
|
|
243
523
|
request_type = FlightRequestType.LOG_EXPERIMENT_DATA
|
|
@@ -247,12 +527,12 @@ class ExperimentsClient:
|
|
|
247
527
|
space_id=space_id,
|
|
248
528
|
pa_table=pa_table,
|
|
249
529
|
dataset_id=dataset_id,
|
|
250
|
-
experiment_name=
|
|
530
|
+
experiment_name=name,
|
|
251
531
|
request_type=request_type,
|
|
252
532
|
)
|
|
253
533
|
except Exception as e:
|
|
254
|
-
msg = f"Error during update request: {
|
|
255
|
-
logger.
|
|
534
|
+
msg = f"Error during update request: {e!s}"
|
|
535
|
+
logger.exception(msg)
|
|
256
536
|
raise RuntimeError(msg) from e
|
|
257
537
|
|
|
258
538
|
if post_resp is None:
|
|
@@ -267,200 +547,32 @@ class ExperimentsClient:
|
|
|
267
547
|
)
|
|
268
548
|
return experiment, output_df
|
|
269
549
|
|
|
270
|
-
def _create_experiment(
|
|
271
|
-
self,
|
|
272
|
-
*,
|
|
273
|
-
name: str,
|
|
274
|
-
dataset_id: str,
|
|
275
|
-
experiment_runs: List[Dict[str, Any]] | pd.DataFrame,
|
|
276
|
-
task_fields: ExperimentTaskResultFieldNames,
|
|
277
|
-
evaluator_columns: Dict[str, EvaluationResultFieldNames] | None = None,
|
|
278
|
-
force_http: bool = False,
|
|
279
|
-
) -> Experiment:
|
|
280
|
-
"""
|
|
281
|
-
Log an experiment to Arize.
|
|
282
|
-
|
|
283
|
-
Args:
|
|
284
|
-
space_id (str): The ID of the space where the experiment will be logged.
|
|
285
|
-
experiment_name (str): The name of the experiment.
|
|
286
|
-
experiment_df (pd.DataFrame): The data to be logged.
|
|
287
|
-
task_columns (ExperimentTaskResultColumnNames): The column names for task results.
|
|
288
|
-
evaluator_columns (Optional[Dict[str, EvaluationResultColumnNames]]):
|
|
289
|
-
The column names for evaluator results.
|
|
290
|
-
dataset_id (str, optional): The ID of the dataset associated with the experiment.
|
|
291
|
-
Required if dataset_name is not provided. Defaults to "".
|
|
292
|
-
dataset_name (str, optional): The name of the dataset associated with the experiment.
|
|
293
|
-
Required if dataset_id is not provided. Defaults to "".
|
|
294
|
-
|
|
295
|
-
Examples:
|
|
296
|
-
>>> # Example DataFrame:
|
|
297
|
-
>>> df = pd.DataFrame({
|
|
298
|
-
... "example_id": ["1", "2"],
|
|
299
|
-
... "result": ["success", "failure"],
|
|
300
|
-
... "accuracy": [0.95, 0.85],
|
|
301
|
-
... "ground_truth": ["A", "B"],
|
|
302
|
-
... "explanation_text": ["Good match", "Poor match"],
|
|
303
|
-
... "confidence": [0.9, 0.7],
|
|
304
|
-
... "model_version": ["v1", "v2"],
|
|
305
|
-
... "custom_metric": [0.8, 0.6],
|
|
306
|
-
...})
|
|
307
|
-
...
|
|
308
|
-
>>> # Define column mappings for task
|
|
309
|
-
>>> task_cols = ExperimentTaskResultColumnNames(
|
|
310
|
-
... example_id="example_id", result="result"
|
|
311
|
-
...)
|
|
312
|
-
>>> # Define column mappings for evaluator
|
|
313
|
-
>>> evaluator_cols = EvaluationResultColumnNames(
|
|
314
|
-
... score="accuracy",
|
|
315
|
-
... label="ground_truth",
|
|
316
|
-
... explanation="explanation_text",
|
|
317
|
-
... metadata={
|
|
318
|
-
... "confidence": None, # Will use "confidence" column
|
|
319
|
-
... "version": "model_version", # Will use "model_version" column
|
|
320
|
-
... "custom_metric": None, # Will use "custom_metric" column
|
|
321
|
-
... },
|
|
322
|
-
... )
|
|
323
|
-
>>> # Use with ArizeDatasetsClient.log_experiment()
|
|
324
|
-
>>> ArizeDatasetsClient.log_experiment(
|
|
325
|
-
... space_id="my_space_id",
|
|
326
|
-
... experiment_name="my_experiment",
|
|
327
|
-
... experiment_df=df,
|
|
328
|
-
... task_columns=task_cols,
|
|
329
|
-
... evaluator_columns={"my_evaluator": evaluator_cols},
|
|
330
|
-
... dataset_name="my_dataset_name",
|
|
331
|
-
... )
|
|
332
|
-
|
|
333
|
-
Returns:
|
|
334
|
-
Optional[str]: The ID of the logged experiment, or None if the logging failed.
|
|
335
|
-
"""
|
|
336
|
-
if not isinstance(experiment_runs, (list, pd.DataFrame)):
|
|
337
|
-
raise TypeError(
|
|
338
|
-
"Examples must be a list of dicts or a pandas DataFrame"
|
|
339
|
-
)
|
|
340
|
-
# transform experiment data to experiment format
|
|
341
|
-
experiment_df = transform_to_experiment_format(
|
|
342
|
-
experiment_runs, task_fields, evaluator_columns
|
|
343
|
-
)
|
|
344
|
-
|
|
345
|
-
below_threshold = (
|
|
346
|
-
get_payload_size_mb(experiment_runs)
|
|
347
|
-
<= self._sdk_config.max_http_payload_size_mb
|
|
348
|
-
)
|
|
349
|
-
if below_threshold or force_http:
|
|
350
|
-
from arize._generated import api_client as gen
|
|
351
|
-
|
|
352
|
-
data = experiment_df.to_dict(orient="records")
|
|
353
|
-
|
|
354
|
-
body = gen.ExperimentsCreateRequest(
|
|
355
|
-
name=name,
|
|
356
|
-
datasetId=dataset_id,
|
|
357
|
-
experimentRuns=data,
|
|
358
|
-
)
|
|
359
|
-
return self._api.experiments_create(experiments_create_request=body)
|
|
360
|
-
|
|
361
|
-
# If we have too many examples, try to convert to a dataframe
|
|
362
|
-
# and log via gRPC + flight
|
|
363
|
-
logger.info(
|
|
364
|
-
f"Uploading {len(experiment_df)} experiment runs via REST may be slow. "
|
|
365
|
-
"Trying for more efficient upload via gRPC + Flight."
|
|
366
|
-
)
|
|
367
|
-
|
|
368
|
-
# TODO(Kiko): Space ID should not be needed,
|
|
369
|
-
# should work on server tech debt to remove this
|
|
370
|
-
dataset = self._datasets_api.datasets_get(dataset_id=dataset_id)
|
|
371
|
-
space_id = dataset.space_id
|
|
372
|
-
|
|
373
|
-
return self._create_experiment_via_flight(
|
|
374
|
-
name=name,
|
|
375
|
-
dataset_id=dataset_id,
|
|
376
|
-
space_id=space_id,
|
|
377
|
-
experiment_df=experiment_df,
|
|
378
|
-
)
|
|
379
|
-
|
|
380
|
-
def _list_runs(
|
|
381
|
-
self,
|
|
382
|
-
*,
|
|
383
|
-
experiment_id: str,
|
|
384
|
-
limit: int = 100,
|
|
385
|
-
all: bool = False,
|
|
386
|
-
):
|
|
387
|
-
if not all:
|
|
388
|
-
return self._api.experiments_runs_list(
|
|
389
|
-
experiment_id=experiment_id,
|
|
390
|
-
limit=limit,
|
|
391
|
-
)
|
|
392
|
-
|
|
393
|
-
experiment = self.get(experiment_id=experiment_id)
|
|
394
|
-
experiment_updated_at = getattr(experiment, "updated_at", None)
|
|
395
|
-
# TODO(Kiko): Space ID should not be needed,
|
|
396
|
-
# should work on server tech debt to remove this
|
|
397
|
-
dataset = self._datasets_api.datasets_get(
|
|
398
|
-
dataset_id=experiment.dataset_id
|
|
399
|
-
)
|
|
400
|
-
space_id = dataset.space_id
|
|
401
|
-
|
|
402
|
-
experiment_df = None
|
|
403
|
-
# try to load dataset from cache
|
|
404
|
-
if self._sdk_config.enable_caching:
|
|
405
|
-
experiment_df = load_cached_resource(
|
|
406
|
-
cache_dir=self._sdk_config.cache_dir,
|
|
407
|
-
resource="experiment",
|
|
408
|
-
resource_id=experiment_id,
|
|
409
|
-
resource_updated_at=experiment_updated_at,
|
|
410
|
-
)
|
|
411
|
-
if experiment_df is not None:
|
|
412
|
-
return models.ExperimentsRunsList200Response(
|
|
413
|
-
experimentRuns=experiment_df.to_dict(orient="records")
|
|
414
|
-
)
|
|
415
|
-
|
|
416
|
-
with ArizeFlightClient(
|
|
417
|
-
api_key=self._sdk_config.api_key,
|
|
418
|
-
host=self._sdk_config.flight_server_host,
|
|
419
|
-
port=self._sdk_config.flight_server_port,
|
|
420
|
-
scheme=self._sdk_config.flight_scheme,
|
|
421
|
-
request_verify=self._sdk_config.request_verify,
|
|
422
|
-
max_chunksize=self._sdk_config.pyarrow_max_chunksize,
|
|
423
|
-
) as flight_client:
|
|
424
|
-
try:
|
|
425
|
-
experiment_df = flight_client.get_experiment_runs(
|
|
426
|
-
space_id=space_id,
|
|
427
|
-
experiment_id=experiment_id,
|
|
428
|
-
)
|
|
429
|
-
except Exception as e:
|
|
430
|
-
msg = f"Error during request: {str(e)}"
|
|
431
|
-
logger.error(msg)
|
|
432
|
-
raise RuntimeError(msg) from e
|
|
433
|
-
if experiment_df is None:
|
|
434
|
-
# This should not happen with proper Flight client implementation,
|
|
435
|
-
# but we handle it defensively
|
|
436
|
-
msg = "No response received from flight server during request"
|
|
437
|
-
logger.error(msg)
|
|
438
|
-
raise RuntimeError(msg)
|
|
439
|
-
|
|
440
|
-
# cache dataset for future use
|
|
441
|
-
cache_resource(
|
|
442
|
-
cache_dir=self._sdk_config.cache_dir,
|
|
443
|
-
resource="dataset",
|
|
444
|
-
resource_id=experiment_id,
|
|
445
|
-
resource_updated_at=experiment_updated_at,
|
|
446
|
-
resource_data=experiment_df,
|
|
447
|
-
)
|
|
448
|
-
|
|
449
|
-
return models.ExperimentsRunsList200Response(
|
|
450
|
-
experimentRuns=experiment_df.to_dict(orient="records")
|
|
451
|
-
)
|
|
452
|
-
|
|
453
550
|
def _create_experiment_via_flight(
|
|
454
551
|
self,
|
|
455
552
|
name: str,
|
|
456
553
|
dataset_id: str,
|
|
457
554
|
space_id: str,
|
|
458
555
|
experiment_df: pd.DataFrame,
|
|
459
|
-
) -> Experiment:
|
|
556
|
+
) -> models.Experiment:
|
|
557
|
+
"""Internal method to create an experiment using Flight protocol for large datasets."""
|
|
558
|
+
# Convert to Arrow table
|
|
559
|
+
try:
|
|
560
|
+
logger.debug("Converting data to Arrow format")
|
|
561
|
+
pa_table = pa.Table.from_pandas(experiment_df, preserve_index=False)
|
|
562
|
+
except pa.ArrowInvalid as e:
|
|
563
|
+
logger.exception(INVALID_ARROW_CONVERSION_MSG)
|
|
564
|
+
raise pa.ArrowInvalid(
|
|
565
|
+
f"Error converting to Arrow format: {e!s}"
|
|
566
|
+
) from e
|
|
567
|
+
except Exception:
|
|
568
|
+
logger.exception("Unexpected error creating Arrow table")
|
|
569
|
+
raise
|
|
570
|
+
|
|
571
|
+
experiment_id = ""
|
|
460
572
|
with ArizeFlightClient(
|
|
461
573
|
api_key=self._sdk_config.api_key,
|
|
462
|
-
host=self._sdk_config.
|
|
463
|
-
port=self._sdk_config.
|
|
574
|
+
host=self._sdk_config.flight_host,
|
|
575
|
+
port=self._sdk_config.flight_port,
|
|
464
576
|
scheme=self._sdk_config.flight_scheme,
|
|
465
577
|
request_verify=self._sdk_config.request_verify,
|
|
466
578
|
max_chunksize=self._sdk_config.pyarrow_max_chunksize,
|
|
@@ -474,8 +586,8 @@ class ExperimentsClient:
|
|
|
474
586
|
experiment_name=name,
|
|
475
587
|
)
|
|
476
588
|
except Exception as e:
|
|
477
|
-
msg = f"Error during request: {
|
|
478
|
-
logger.
|
|
589
|
+
msg = f"Error during request: {e!s}"
|
|
590
|
+
logger.exception(msg)
|
|
479
591
|
raise RuntimeError(msg) from e
|
|
480
592
|
|
|
481
593
|
if response is None:
|
|
@@ -484,49 +596,39 @@ class ExperimentsClient:
|
|
|
484
596
|
msg = "No response received from flight server during request"
|
|
485
597
|
logger.error(msg)
|
|
486
598
|
raise RuntimeError(msg)
|
|
487
|
-
experiment_id, _ = response
|
|
488
599
|
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
logger.error(f"{INVALID_ARROW_CONVERSION_MSG}: {str(e)}")
|
|
495
|
-
raise pa.ArrowInvalid(
|
|
496
|
-
f"Error converting to Arrow format: {str(e)}"
|
|
497
|
-
) from e
|
|
498
|
-
except Exception as e:
|
|
499
|
-
logger.error(f"Unexpected error creating Arrow table: {str(e)}")
|
|
500
|
-
raise
|
|
600
|
+
experiment_id, _ = response
|
|
601
|
+
if not experiment_id:
|
|
602
|
+
msg = "No experiment ID received from flight server during request"
|
|
603
|
+
logger.error(msg)
|
|
604
|
+
raise RuntimeError(msg)
|
|
501
605
|
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
606
|
+
request_type = FlightRequestType.LOG_EXPERIMENT_DATA
|
|
607
|
+
post_resp = None
|
|
608
|
+
try:
|
|
609
|
+
post_resp = flight_client.log_arrow_table(
|
|
610
|
+
space_id=space_id,
|
|
611
|
+
pa_table=pa_table,
|
|
612
|
+
dataset_id=dataset_id,
|
|
613
|
+
experiment_name=name,
|
|
614
|
+
request_type=request_type,
|
|
615
|
+
)
|
|
616
|
+
except Exception as e:
|
|
617
|
+
msg = f"Error during update request: {e!s}"
|
|
618
|
+
logger.exception(msg)
|
|
619
|
+
raise RuntimeError(msg) from e
|
|
516
620
|
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
621
|
+
if post_resp is None:
|
|
622
|
+
# This should not happen with proper Flight client implementation,
|
|
623
|
+
# but we handle it defensively
|
|
624
|
+
msg = "No response received from flight server during request"
|
|
625
|
+
logger.error(msg)
|
|
626
|
+
raise RuntimeError(msg)
|
|
523
627
|
|
|
524
|
-
|
|
628
|
+
return self.get(
|
|
525
629
|
experiment_id=str(post_resp.experiment_id) # type: ignore
|
|
526
630
|
)
|
|
527
631
|
|
|
528
|
-
return experiment
|
|
529
|
-
|
|
530
632
|
|
|
531
633
|
def _get_tracer_resource(
|
|
532
634
|
project_name: str,
|
|
@@ -535,7 +637,8 @@ def _get_tracer_resource(
|
|
|
535
637
|
endpoint: str,
|
|
536
638
|
dry_run: bool = False,
|
|
537
639
|
set_global_tracer_provider: bool = False,
|
|
538
|
-
) ->
|
|
640
|
+
) -> tuple[Tracer, Resource]:
|
|
641
|
+
"""Initialize and return an OpenTelemetry tracer and resource for experiment tracing."""
|
|
539
642
|
resource = Resource(
|
|
540
643
|
{
|
|
541
644
|
ResourceAttributes.PROJECT_NAME: project_name,
|
|
@@ -547,7 +650,8 @@ def _get_tracer_resource(
|
|
|
547
650
|
"arize-space-id": space_id,
|
|
548
651
|
"arize-interface": "otel",
|
|
549
652
|
}
|
|
550
|
-
|
|
653
|
+
use_tls = any(endpoint.startswith(v) for v in ["https://", "grpc+tls://"])
|
|
654
|
+
insecure = not use_tls
|
|
551
655
|
exporter = (
|
|
552
656
|
ConsoleSpanExporter()
|
|
553
657
|
if dry_run
|