arize 8.0.0a14__py3-none-any.whl → 8.0.0a16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arize/__init__.py +70 -1
- arize/_flight/client.py +163 -43
- arize/_flight/types.py +1 -0
- arize/_generated/api_client/__init__.py +5 -1
- arize/_generated/api_client/api/datasets_api.py +6 -6
- arize/_generated/api_client/api/experiments_api.py +924 -61
- arize/_generated/api_client/api_client.py +1 -1
- arize/_generated/api_client/configuration.py +1 -1
- arize/_generated/api_client/exceptions.py +1 -1
- arize/_generated/api_client/models/__init__.py +3 -1
- arize/_generated/api_client/models/dataset.py +2 -2
- arize/_generated/api_client/models/dataset_version.py +1 -1
- arize/_generated/api_client/models/datasets_create_request.py +3 -3
- arize/_generated/api_client/models/datasets_list200_response.py +1 -1
- arize/_generated/api_client/models/datasets_list_examples200_response.py +1 -1
- arize/_generated/api_client/models/error.py +1 -1
- arize/_generated/api_client/models/experiment.py +6 -6
- arize/_generated/api_client/models/experiments_create_request.py +98 -0
- arize/_generated/api_client/models/experiments_list200_response.py +1 -1
- arize/_generated/api_client/models/experiments_runs_list200_response.py +92 -0
- arize/_generated/api_client/rest.py +1 -1
- arize/_generated/api_client/test/test_dataset.py +2 -1
- arize/_generated/api_client/test/test_dataset_version.py +1 -1
- arize/_generated/api_client/test/test_datasets_api.py +1 -1
- arize/_generated/api_client/test/test_datasets_create_request.py +2 -1
- arize/_generated/api_client/test/test_datasets_list200_response.py +1 -1
- arize/_generated/api_client/test/test_datasets_list_examples200_response.py +1 -1
- arize/_generated/api_client/test/test_error.py +1 -1
- arize/_generated/api_client/test/test_experiment.py +6 -1
- arize/_generated/api_client/test/test_experiments_api.py +23 -2
- arize/_generated/api_client/test/test_experiments_create_request.py +61 -0
- arize/_generated/api_client/test/test_experiments_list200_response.py +1 -1
- arize/_generated/api_client/test/test_experiments_runs_list200_response.py +56 -0
- arize/_generated/api_client_README.md +13 -8
- arize/client.py +19 -2
- arize/config.py +50 -3
- arize/constants/config.py +8 -2
- arize/constants/openinference.py +14 -0
- arize/constants/pyarrow.py +1 -0
- arize/datasets/__init__.py +0 -70
- arize/datasets/client.py +106 -19
- arize/datasets/errors.py +61 -0
- arize/datasets/validation.py +46 -0
- arize/experiments/client.py +455 -0
- arize/experiments/evaluators/__init__.py +0 -0
- arize/experiments/evaluators/base.py +255 -0
- arize/experiments/evaluators/exceptions.py +10 -0
- arize/experiments/evaluators/executors.py +502 -0
- arize/experiments/evaluators/rate_limiters.py +277 -0
- arize/experiments/evaluators/types.py +122 -0
- arize/experiments/evaluators/utils.py +198 -0
- arize/experiments/functions.py +920 -0
- arize/experiments/tracing.py +276 -0
- arize/experiments/types.py +394 -0
- arize/models/client.py +4 -1
- arize/spans/client.py +16 -20
- arize/utils/arrow.py +4 -3
- arize/utils/openinference_conversion.py +56 -0
- arize/utils/proto.py +13 -0
- arize/utils/size.py +22 -0
- arize/version.py +1 -1
- {arize-8.0.0a14.dist-info → arize-8.0.0a16.dist-info}/METADATA +3 -1
- {arize-8.0.0a14.dist-info → arize-8.0.0a16.dist-info}/RECORD +65 -44
- {arize-8.0.0a14.dist-info → arize-8.0.0a16.dist-info}/WHEEL +0 -0
- {arize-8.0.0a14.dist-info → arize-8.0.0a16.dist-info}/licenses/LICENSE.md +0 -0
arize/experiments/tracing.py
ADDED
@@ -0,0 +1,276 @@
+from __future__ import annotations
+
+import inspect
+import json
+from contextlib import contextmanager
+from contextvars import ContextVar
+from threading import Lock
+from typing import (
+    Any,
+    Callable,
+    Iterable,
+    Iterator,
+    List,
+    Mapping,
+    Sequence,
+    cast,
+)
+
+import numpy as np
+from openinference.semconv import trace
+from openinference.semconv.trace import DocumentAttributes, SpanAttributes
+from opentelemetry.sdk.resources import Resource
+from opentelemetry.sdk.trace import ReadableSpan
+from opentelemetry.trace import INVALID_TRACE_ID
+from typing_extensions import assert_never
+from wrapt import apply_patch, resolve_path, wrap_function_wrapper
+
+
+class SpanModifier:
+    """
+    A class that modifies spans with the specified resource attributes.
+    """
+
+    __slots__ = ("_resource",)
+
+    def __init__(self, resource: Resource) -> None:
+        self._resource = resource
+
+    def modify_resource(self, span: ReadableSpan) -> None:
+        """
+        Takes a span and merges in the resource attributes specified in the constructor.
+
+        Args:
+            span: ReadableSpan: the span to modify
+
+        """
+        if (ctx := span._context) is None or ctx.span_id == INVALID_TRACE_ID:
+            return
+        span._resource = span._resource.merge(self._resource)
+
+
+_ACTIVE_MODIFIER: ContextVar[SpanModifier | None] = ContextVar(
+    "active_modifier"
+)
+
+
+def override_span(
+    init: Callable[..., None], span: ReadableSpan, args: Any, kwargs: Any
+) -> None:
+    init(*args, **kwargs)
+    if isinstance(span_modifier := _ACTIVE_MODIFIER.get(None), SpanModifier):
+        span_modifier.modify_resource(span)
+
+
+_SPAN_INIT_MONKEY_PATCH_LOCK = Lock()
+_SPAN_INIT_MONKEY_PATCH_COUNT = 0
+_SPAN_INIT_MODULE = ReadableSpan.__init__.__module__
+_SPAN_INIT_NAME = ReadableSpan.__init__.__qualname__
+_SPAN_INIT_PARENT, _SPAN_INIT_ATTR, _SPAN_INIT_ORIGINAL = resolve_path(
+    _SPAN_INIT_MODULE, _SPAN_INIT_NAME
+)
+
+
+@contextmanager
+def _monkey_patch_span_init() -> Iterator[None]:
+    global _SPAN_INIT_MONKEY_PATCH_COUNT
+    with _SPAN_INIT_MONKEY_PATCH_LOCK:
+        _SPAN_INIT_MONKEY_PATCH_COUNT += 1
+        if _SPAN_INIT_MONKEY_PATCH_COUNT == 1:
+            wrap_function_wrapper(
+                module=_SPAN_INIT_MODULE,
+                name=_SPAN_INIT_NAME,
+                wrapper=override_span,
+            )
+    yield
+    with _SPAN_INIT_MONKEY_PATCH_LOCK:
+        _SPAN_INIT_MONKEY_PATCH_COUNT -= 1
+        if _SPAN_INIT_MONKEY_PATCH_COUNT == 0:
+            apply_patch(_SPAN_INIT_PARENT, _SPAN_INIT_ATTR, _SPAN_INIT_ORIGINAL)
+
+
+@contextmanager
+def capture_spans(resource: Resource) -> Iterator[SpanModifier]:
+    """
+    A context manager that captures spans and modifies them with the specified resources.
+
+    Args:
+        resource: Resource: The resource to merge into the spans created within the context.
+
+    Returns:
+        modifier: Iterator[SpanModifier]: The span modifier that is active within the context.
+
+    """
+    modifier = SpanModifier(resource)
+    with _monkey_patch_span_init():
+        token = _ACTIVE_MODIFIER.set(modifier)
+        yield modifier
+        _ACTIVE_MODIFIER.reset(token)
+
+
+"""
+Span attribute keys have a special relationship with the `.` separator. When
+a span attribute is ingested from protobuf, it's in the form of a key value
+pair such as `("llm.token_count.completion", 123)`. What we need to do is to split
+the key by the `.` separator and turn it into part of a nested dictionary such
+as {"llm": {"token_count": {"completion": 123}}}. We also need to reverse this
+process, which is to flatten the nested dictionary into a list of key value
+pairs. This module provides functions to do both of these operations.
+
+Note that digit keys are treated as indices of a nested array. For example,
+the digits inside `("retrieval.documents.0.document.content", 'A')` and
+`("retrieval.documents.1.document.content": 'B')` turn the sub-keys following
+them into a nested list of dictionaries i.e.
+{`retrieval: {"documents": [{"document": {"content": "A"}}, {"document":
+{"content": "B"}}]}`.
+"""
+
+
+DOCUMENT_METADATA = DocumentAttributes.DOCUMENT_METADATA
+LLM_PROMPT_TEMPLATE_VARIABLES = SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES
+METADATA = SpanAttributes.METADATA
+TOOL_PARAMETERS = SpanAttributes.TOOL_PARAMETERS
+
+# attributes interpreted as JSON strings during ingestion
+JSON_STRING_ATTRIBUTES = (
+    DOCUMENT_METADATA,
+    LLM_PROMPT_TEMPLATE_VARIABLES,
+    METADATA,
+    TOOL_PARAMETERS,
+)
+
+SEMANTIC_CONVENTIONS: List[str] = sorted(
+    # e.g. "input.value", "llm.token_count.total", etc.
+    (
+        cast(str, getattr(klass, attr))
+        for name in dir(trace)
+        if name.endswith("Attributes")
+        and inspect.isclass(klass := getattr(trace, name))
+        for attr in dir(klass)
+        if attr.isupper()
+    ),
+    key=len,
+    reverse=True,
+)  # sorted so the longer strings go first
+
+
+def flatten(
+    obj: Mapping[str, Any] | Iterable[Any],
+    *,
+    prefix: str = "",
+    separator: str = ".",
+    recurse_on_sequence: bool = False,
+    json_string_attributes: Sequence[str] | None = None,
+) -> Iterator[tuple[str, Any]]:
+    """
+    Flatten a nested dictionary or a sequence of dictionaries into a list of
+    key value pairs. If `recurse_on_sequence` is True, then the function will
+    also recursively flatten nested sequences of dictionaries. If
+    `json_string_attributes` is provided, then the function will interpret the
+    attributes in the list as JSON strings and convert them into dictionaries.
+    The `prefix` argument is used to prefix the keys in the output list, but
+    it's mostly used internally to facilitate recursion.
+    """
+    if isinstance(obj, Mapping):
+        yield from _flatten_mapping(
+            obj,
+            prefix=prefix,
+            recurse_on_sequence=recurse_on_sequence,
+            json_string_attributes=json_string_attributes,
+            separator=separator,
+        )
+    elif isinstance(obj, Iterable):
+        yield from _flatten_sequence(
+            obj,
+            prefix=prefix,
+            recurse_on_sequence=recurse_on_sequence,
+            json_string_attributes=json_string_attributes,
+            separator=separator,
+        )
+    else:
+        assert_never(obj)
+
+
+def has_mapping(sequence: Iterable[Any]) -> bool:
+    """
+    Check if a sequence contains a dictionary. We don't flatten sequences that
+    only contain primitive types, such as strings, integers, etc. Conversely,
+    we'll only un-flatten digit sub-keys if it can be interpreted the index of
+    an array of dictionaries.
+    """
+    return any(isinstance(item, Mapping) for item in sequence)
+
+
+def _flatten_mapping(
+    mapping: Mapping[str, Any],
+    *,
+    prefix: str = "",
+    recurse_on_sequence: bool = False,
+    json_string_attributes: Sequence[str] | None = None,
+    separator: str = ".",
+) -> Iterator[tuple[str, Any]]:
+    """
+    Flatten a nested dictionary into a list of key value pairs. If `recurse_on_sequence`
+    is True, then the function will also recursively flatten nested sequences of dictionaries.
+    If `json_string_attributes` is provided, then the function will interpret the attributes
+    in the list as JSON strings and convert them into dictionaries. The `prefix` argument is
+    used to prefix the keys in the output list, but it's mostly used internally to facilitate
+    recursion.
+    """
+    for key, value in mapping.items():
+        prefixed_key = f"{prefix}{separator}{key}" if prefix else key
+        if isinstance(value, Mapping):
+            if json_string_attributes and prefixed_key.endswith(
+                JSON_STRING_ATTRIBUTES
+            ):
+                yield prefixed_key, json.dumps(value)
+            else:
+                yield from _flatten_mapping(
+                    value,
+                    prefix=prefixed_key,
+                    recurse_on_sequence=recurse_on_sequence,
+                    json_string_attributes=json_string_attributes,
+                    separator=separator,
+                )
+        elif (
+            isinstance(value, (Sequence, np.ndarray))
+        ) and recurse_on_sequence:
+            yield from _flatten_sequence(
+                value,
+                prefix=prefixed_key,
+                recurse_on_sequence=recurse_on_sequence,
+                json_string_attributes=json_string_attributes,
+                separator=separator,
+            )
+        elif value is not None:
+            yield prefixed_key, value
+
+
+def _flatten_sequence(
+    sequence: Iterable[Any],
+    *,
+    prefix: str = "",
+    recurse_on_sequence: bool = False,
+    json_string_attributes: Sequence[str] | None = None,
+    separator: str = ".",
+) -> Iterator[tuple[str, Any]]:
+    """
+    Flatten a sequence of dictionaries into a list of key value pairs. If `recurse_on_sequence`
+    is True, then the function will also recursively flatten nested sequences of dictionaries.
+    If `json_string_attributes` is provided, then the function will interpret the attributes
+    in the list as JSON strings and convert them into dictionaries. The `prefix` argument is
+    used to prefix the keys in the output list, but it's mostly used internally to facilitate
+    recursion.
+    """
+    if isinstance(sequence, str) or not has_mapping(sequence):
+        yield prefix, sequence
+    for idx, obj in enumerate(sequence):
+        if not isinstance(obj, Mapping):
+            continue
+        yield from _flatten_mapping(
+            obj,
+            prefix=f"{prefix}{separator}{idx}" if prefix else f"{idx}",
+            recurse_on_sequence=recurse_on_sequence,
+            json_string_attributes=json_string_attributes,
+            separator=separator,
+        )
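
Editor's note: `capture_spans` works by temporarily wrapping `ReadableSpan.__init__` so every span created while the context is active gets the given resource attributes merged in, without reconfiguring the global TracerProvider. A minimal usage sketch (not from the package docs; the tracer setup and the `experiment.id` attribute key are illustrative assumptions):

from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider

from arize.experiments.tracing import capture_spans

tracer = TracerProvider().get_tracer(__name__)
extra = Resource.create({"experiment.id": "exp_123"})  # illustrative attribute key

with capture_spans(extra):
    # Any span started here is stamped with the extra resource attributes.
    with tracer.start_as_current_span("task"):
        pass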
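
The dotted-key convention described in the module docstring (including digit sub-keys becoming list indices) can be seen end to end by running `flatten` on the docstring's own example, sketched here (dict output order may vary):

from arize.experiments.tracing import flatten

nested = {
    "llm": {"token_count": {"completion": 123}},
    "retrieval": {"documents": [{"document": {"content": "A"}},
                                {"document": {"content": "B"}}]},
}
print(dict(flatten(nested, recurse_on_sequence=True)))
# {'llm.token_count.completion': 123,
#  'retrieval.documents.0.document.content': 'A',
#  'retrieval.documents.1.document.content': 'B'}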
arize/experiments/types.py
ADDED
@@ -0,0 +1,394 @@
+from __future__ import annotations
+
+import json
+import textwrap
+from copy import copy, deepcopy
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from importlib.metadata import version
+from random import getrandbits
+from typing import (
+    Any,
+    Awaitable,
+    Callable,
+    Iterable,
+    Mapping,
+    cast,
+)
+
+import pandas as pd
+from wrapt import ObjectProxy
+
+from arize.experiments.evaluators.types import (
+    EvaluationResult,
+    JSONSerializable,
+)
+
+ExperimentId = str
+# DatasetId= str
+# DatasetVersionId= str
+ExampleId = str
+RepetitionNumber = int
+ExperimentRunId = str
+TraceId = str
+
+
+@dataclass(frozen=True)
+class Example:
+    """
+    Represents an example in an experiment dataset.
+    Args:
+        id: The unique identifier for the example.
+        updated_at: The timestamp when the example was last updated.
+        input: The input data for the example.
+        output: The output data for the example.
+        metadata: Additional metadata for the example.
+        dataset_row: The original dataset row containing the example data.
+    """
+
+    id: ExampleId = field(default_factory=str)
+    updated_at: datetime = field(default_factory=datetime.now)
+    input: Mapping[str, JSONSerializable] = field(default_factory=dict)
+    output: Mapping[str, JSONSerializable] = field(default_factory=dict)
+    metadata: Mapping[str, JSONSerializable] = field(default_factory=dict)
+    dataset_row: Mapping[str, JSONSerializable] = field(default_factory=dict)
+
+    def __post_init__(self) -> None:
+        if self.dataset_row is not None:
+            object.__setattr__(
+                self, "dataset_row", _make_read_only(self.dataset_row)
+            )
+            if "attributes.input.value" in self.dataset_row:
+                object.__setattr__(
+                    self,
+                    "input",
+                    _make_read_only(self.dataset_row["attributes.input.value"]),
+                )
+            if "attributes.output.value" in self.dataset_row:
+                object.__setattr__(
+                    self,
+                    "output",
+                    _make_read_only(
+                        self.dataset_row["attributes.output.value"]
+                    ),
+                )
+            if "attributes.metadata" in self.dataset_row:
+                object.__setattr__(
+                    self,
+                    "metadata",
+                    _make_read_only(self.dataset_row["attributes.metadata"]),
+                )
+            if "id" in self.dataset_row:
+                object.__setattr__(self, "id", self.dataset_row["id"])
+            if "updated_at" in self.dataset_row:
+                object.__setattr__(
+                    self, "updated_at", self.dataset_row["updated_at"]
+                )
+        else:
+            object.__setattr__(self, "input", self.input)
+            object.__setattr__(self, "output", self.output)
+            object.__setattr__(self, "metadata", self.metadata)
+
+    @classmethod
+    def from_dict(cls, obj: Mapping[str, Any]) -> Example:
+        return cls(
+            id=obj["id"],
+            input=obj["input"],
+            output=obj["output"],
+            metadata=obj.get("metadata") or {},
+            updated_at=obj["updated_at"],
+        )
+
+    def __repr__(self) -> str:
+        spaces = " " * 4
+        name = self.__class__.__name__
+        identifiers = [f'{spaces}id="{self.id}",']
+        contents = []
+        for key in ("input", "output", "metadata", "dataset_row"):
+            value = getattr(self, key, None)
+            if value:
+                contents.append(
+                    spaces
+                    + f"{_blue(key)}="
+                    + json.dumps(
+                        _shorten(value),
+                        ensure_ascii=False,
+                        sort_keys=True,
+                        indent=len(spaces),
+                    )
+                    .replace("\n", f"\n{spaces}")
+                    .replace(' "..."\n', " ...\n")
+                    + ","
+                )
+        return "\n".join([f"{name}(", *identifiers, *contents, ")"])
+
+
+def _shorten(obj: Any, width: int = 50) -> Any:
+    if isinstance(obj, str):
+        return textwrap.shorten(obj, width=width, placeholder="...")
+    if isinstance(obj, dict):
+        return {k: _shorten(v) for k, v in obj.items()}
+    if isinstance(obj, list):
+        if len(obj) > 2:
+            return [_shorten(v) for v in obj[:2]] + ["..."]
+        return [_shorten(v) for v in obj]
+    return obj
+
+
+def _make_read_only(obj: Any) -> Any:
+    if isinstance(obj, dict):
+        return _ReadOnly({k: _make_read_only(v) for k, v in obj.items()})
+    if isinstance(obj, str):
+        return obj
+    if isinstance(obj, list):
+        return _ReadOnly(list(map(_make_read_only, obj)))
+    return obj
+
+
+class _ReadOnly(ObjectProxy):  # type: ignore[misc]
+    def __setitem__(self, *args: Any, **kwargs: Any) -> Any:
+        raise NotImplementedError
+
+    def __delitem__(self, *args: Any, **kwargs: Any) -> Any:
+        raise NotImplementedError
+
+    def __iadd__(self, *args: Any, **kwargs: Any) -> Any:
+        raise NotImplementedError
+
+    def pop(self, *args: Any, **kwargs: Any) -> Any:
+        raise NotImplementedError
+
+    def append(self, *args: Any, **kwargs: Any) -> Any:
+        raise NotImplementedError
+
+    def __copy__(self, *args: Any, **kwargs: Any) -> Any:
+        return copy(self.__wrapped__)
+
+    def __deepcopy__(self, *args: Any, **kwargs: Any) -> Any:
+        return deepcopy(self.__wrapped__)
+
+    def __repr__(self) -> str:
+        return repr(self.__wrapped__)
+
+    def __str__(self) -> str:
+        return str(self.__wrapped__)
+
+
+def _blue(text: str) -> str:
+    return f"\033[1m\033[94m{text}\033[0m"
+
+
+@dataclass(frozen=True)
+class TestCase:
+    example: Example
+    repetition_number: RepetitionNumber
+
+
+EXP_ID: ExperimentId = "EXP_ID"
+
+
+def _exp_id() -> str:
+    suffix = getrandbits(24).to_bytes(3, "big").hex()
+    return f"{EXP_ID}_{suffix}"
+
+
+@dataclass(frozen=True)
+class ExperimentRun:
+    """
+    Represents a single run of an experiment.
+    Args:
+        start_time: The start time of the experiment run.
+        end_time: The end time of the experiment run.
+        experiment_id: The unique identifier for the experiment.
+        dataset_example_id: The unique identifier for the dataset example.
+        repetition_number: The repetition number of the experiment run.
+        output: The output of the experiment run.
+        error: The error message if the experiment run failed.
+        id: The unique identifier for the experiment run.
+        trace_id: The trace identifier for the experiment run.
+    """
+
+    start_time: datetime
+    end_time: datetime
+    experiment_id: ExperimentId
+    dataset_example_id: ExampleId
+    repetition_number: RepetitionNumber
+    output: JSONSerializable
+    error: str | None = None
+    id: ExperimentRunId = field(default_factory=_exp_id)
+    trace_id: TraceId | None = None
+
+    @classmethod
+    def from_dict(cls, obj: Mapping[str, Any]) -> ExperimentRun:
+        return cls(
+            start_time=obj["start_time"],
+            end_time=obj["end_time"],
+            experiment_id=obj["experiment_id"],
+            dataset_example_id=obj["dataset_example_id"],
+            repetition_number=obj.get("repetition_number") or 1,
+            output=_make_read_only(obj.get("output")),
+            error=obj.get("error"),
+            id=obj["id"],
+            trace_id=obj.get("trace_id"),
+        )
+
+    def __post_init__(self) -> None:
+        if (self.output is None) == (self.error is None):
+            raise ValueError(
+                "Must specify exactly one of experiment_run_output or error"
+            )
+
+
+@dataclass(frozen=True)
+class ExperimentEvaluationRun:
+    """
+    Represents a single evaluation run of an experiment.
+    Args:
+        experiment_run_id: The unique identifier for the experiment run.
+        start_time: The start time of the evaluation run.
+        end_time: The end time of the evaluation run.
+        name: The name of the evaluation run.
+        annotator_kind: The kind of annotator used in the evaluation run.
+        error: The error message if the evaluation run failed.
+        result (Optional[EvaluationResult]): The result of the evaluation run.
+        id (str): The unique identifier for the evaluation run.
+        trace_id (Optional[TraceId]): The trace identifier for the evaluation run.
+    """
+
+    experiment_run_id: ExperimentRunId
+    start_time: datetime
+    end_time: datetime
+    name: str
+    annotator_kind: str
+    error: str | None = None
+    result: EvaluationResult | None = None
+    id: str = field(default_factory=_exp_id)
+    trace_id: TraceId | None = None
+
+    @classmethod
+    def from_dict(cls, obj: Mapping[str, Any]) -> ExperimentEvaluationRun:
+        return cls(
+            experiment_run_id=obj["experiment_run_id"],
+            start_time=obj["start_time"],
+            end_time=obj["end_time"],
+            name=obj["name"],
+            annotator_kind=obj["annotator_kind"],
+            error=obj.get("error"),
+            result=EvaluationResult.from_dict(obj.get("result")),
+            id=obj["id"],
+            trace_id=obj.get("trace_id"),
+        )
+
+    def __post_init__(self) -> None:
+        if bool(self.result) == bool(self.error):
+            raise ValueError("Must specify either result or error")
+
+
+_LOCAL_TIMEZONE = datetime.now(timezone.utc).astimezone().tzinfo
+
+
+def local_now() -> datetime:
+    return datetime.now(timezone.utc).astimezone(tz=_LOCAL_TIMEZONE)
+
+
+@dataclass(frozen=True)
+class _HasStats:
+    _title: str = field(repr=False, default="")
+    _timestamp: datetime = field(repr=False, default_factory=local_now)
+    stats: pd.DataFrame = field(repr=False, default_factory=pd.DataFrame)
+
+    @property
+    def title(self) -> str:
+        return f"{self._title} ({self._timestamp:%x %I:%M %p %z})"
+
+    def __str__(self) -> str:
+        try:
+            assert int(version("pandas").split(".")[0]) >= 1
+            # `tabulate` is used by pandas >= 1.0 in DataFrame.to_markdown()
+            import tabulate  # noqa: F401
+        except (AssertionError, ImportError):
+            text = self.stats.__str__()
+        else:
+            text = self.stats.to_markdown(index=False)
+        return f"{self.title}\n{'-' * len(self.title)}\n" + text  # type: ignore
+
+
+@dataclass(frozen=True)
+class _TaskSummary(_HasStats):
+    """
+    Summary statistics of experiment task executions.
+
+    **Users should not instantiate this object directly.**
+    """
+
+    _title: str = "Tasks Summary"
+
+    @classmethod
+    def from_task_runs(
+        cls, n_examples: int, task_runs: Iterable[ExperimentRun | None]
+    ) -> _TaskSummary:
+        df = pd.DataFrame.from_records(
+            [
+                {
+                    "example_id": run.dataset_example_id,
+                    "error": run.error,
+                }
+                for run in task_runs
+                if run is not None
+            ]
+        )
+        n_runs = len(df)
+        n_errors = 0 if df.empty else df.loc[:, "error"].astype(bool).sum()
+        record = {
+            "n_examples": n_examples,
+            "n_runs": n_runs,
+            "n_errors": n_errors,
+            **(
+                dict(top_error=_top_string(df.loc[:, "error"]))
+                if n_errors
+                else {}
+            ),
+        }
+        stats = pd.DataFrame.from_records([record])
+        summary: _TaskSummary = object.__new__(cls)
+        summary.__init__(stats=stats)  # type: ignore[misc]
+        return summary
+
+    @classmethod
+    def __new__(cls, *args: Any, **kwargs: Any) -> Any:
+        # Direct instantiation by users is discouraged.
+        raise NotImplementedError
+
+    @classmethod
+    def __init_subclass__(cls, **kwargs: Any) -> None:
+        # Direct sub-classing by users is discouraged.
+        raise NotImplementedError
+
+
+def _top_string(s: pd.Series, length: int = 100) -> str | None:
+    if (cnt := s.dropna().str.slice(0, length).value_counts()).empty:
+        return None
+    return cast(str, cnt.sort_values(ascending=False).index[0])
+
+
+@dataclass
+class ExperimentTaskResultFieldNames:
+    """Column names for mapping experiment task results in a DataFrame.
+
+    Args:
+        example_id: Name of column containing example IDs.
+            The ID values must match the id of the dataset rows.
+        result: Name of column containing task results
+    """
+
+    example_id: str
+    result: str
+
+
+TaskOutput = JSONSerializable
+ExampleOutput = Mapping[str, JSONSerializable]
+ExampleMetadata = Mapping[str, JSONSerializable]
+ExampleInput = Mapping[str, JSONSerializable]
+ExperimentTask = (
+    Callable[[Example], TaskOutput] | Callable[[Example], Awaitable[TaskOutput]]
+)
arize/models/client.py
CHANGED
@@ -584,7 +584,7 @@ class MLModelsClient:
             # pyarrow will err if a mixed type column exist in the dataset even if
             # the column is not specified in schema. Caveat: There may be other
             # error conditions that we're currently not aware of.
-            pa_table = pa.Table.from_pandas(dataframe)
+            pa_table = pa.Table.from_pandas(dataframe, preserve_index=False)
         except pa.ArrowInvalid as e:
             logger.error(f"{INVALID_ARROW_CONVERSION_MSG}: {str(e)}")
             raise pa.ArrowInvalid(
@@ -660,6 +660,7 @@ class MLModelsClient:
             headers=headers,
             timeout=timeout,
             verify=self._sdk_config.request_verify,
+            max_chunksize=self._sdk_config.pyarrow_max_chunksize,
             tmp_dir=tmp_dir,
         )

@@ -688,6 +689,7 @@ class MLModelsClient:
             port=self._sdk_config.flight_server_port,
             scheme=self._sdk_config.flight_scheme,
             request_verify=self._sdk_config.request_verify,
+            max_chunksize=self._sdk_config.pyarrow_max_chunksize,
         ) as flight_client:
             exporter = ArizeExportClient(
                 flight_client=flight_client,
@@ -732,6 +734,7 @@ class MLModelsClient:
             port=self._sdk_config.flight_server_port,
             scheme=self._sdk_config.flight_scheme,
             request_verify=self._sdk_config.request_verify,
+            max_chunksize=self._sdk_config.pyarrow_max_chunksize,
         ) as flight_client:
             exporter = ArizeExportClient(
                 flight_client=flight_client,