arize 8.0.0a14__py3-none-any.whl → 8.0.0a16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. arize/__init__.py +70 -1
  2. arize/_flight/client.py +163 -43
  3. arize/_flight/types.py +1 -0
  4. arize/_generated/api_client/__init__.py +5 -1
  5. arize/_generated/api_client/api/datasets_api.py +6 -6
  6. arize/_generated/api_client/api/experiments_api.py +924 -61
  7. arize/_generated/api_client/api_client.py +1 -1
  8. arize/_generated/api_client/configuration.py +1 -1
  9. arize/_generated/api_client/exceptions.py +1 -1
  10. arize/_generated/api_client/models/__init__.py +3 -1
  11. arize/_generated/api_client/models/dataset.py +2 -2
  12. arize/_generated/api_client/models/dataset_version.py +1 -1
  13. arize/_generated/api_client/models/datasets_create_request.py +3 -3
  14. arize/_generated/api_client/models/datasets_list200_response.py +1 -1
  15. arize/_generated/api_client/models/datasets_list_examples200_response.py +1 -1
  16. arize/_generated/api_client/models/error.py +1 -1
  17. arize/_generated/api_client/models/experiment.py +6 -6
  18. arize/_generated/api_client/models/experiments_create_request.py +98 -0
  19. arize/_generated/api_client/models/experiments_list200_response.py +1 -1
  20. arize/_generated/api_client/models/experiments_runs_list200_response.py +92 -0
  21. arize/_generated/api_client/rest.py +1 -1
  22. arize/_generated/api_client/test/test_dataset.py +2 -1
  23. arize/_generated/api_client/test/test_dataset_version.py +1 -1
  24. arize/_generated/api_client/test/test_datasets_api.py +1 -1
  25. arize/_generated/api_client/test/test_datasets_create_request.py +2 -1
  26. arize/_generated/api_client/test/test_datasets_list200_response.py +1 -1
  27. arize/_generated/api_client/test/test_datasets_list_examples200_response.py +1 -1
  28. arize/_generated/api_client/test/test_error.py +1 -1
  29. arize/_generated/api_client/test/test_experiment.py +6 -1
  30. arize/_generated/api_client/test/test_experiments_api.py +23 -2
  31. arize/_generated/api_client/test/test_experiments_create_request.py +61 -0
  32. arize/_generated/api_client/test/test_experiments_list200_response.py +1 -1
  33. arize/_generated/api_client/test/test_experiments_runs_list200_response.py +56 -0
  34. arize/_generated/api_client_README.md +13 -8
  35. arize/client.py +19 -2
  36. arize/config.py +50 -3
  37. arize/constants/config.py +8 -2
  38. arize/constants/openinference.py +14 -0
  39. arize/constants/pyarrow.py +1 -0
  40. arize/datasets/__init__.py +0 -70
  41. arize/datasets/client.py +106 -19
  42. arize/datasets/errors.py +61 -0
  43. arize/datasets/validation.py +46 -0
  44. arize/experiments/client.py +455 -0
  45. arize/experiments/evaluators/__init__.py +0 -0
  46. arize/experiments/evaluators/base.py +255 -0
  47. arize/experiments/evaluators/exceptions.py +10 -0
  48. arize/experiments/evaluators/executors.py +502 -0
  49. arize/experiments/evaluators/rate_limiters.py +277 -0
  50. arize/experiments/evaluators/types.py +122 -0
  51. arize/experiments/evaluators/utils.py +198 -0
  52. arize/experiments/functions.py +920 -0
  53. arize/experiments/tracing.py +276 -0
  54. arize/experiments/types.py +394 -0
  55. arize/models/client.py +4 -1
  56. arize/spans/client.py +16 -20
  57. arize/utils/arrow.py +4 -3
  58. arize/utils/openinference_conversion.py +56 -0
  59. arize/utils/proto.py +13 -0
  60. arize/utils/size.py +22 -0
  61. arize/version.py +1 -1
  62. {arize-8.0.0a14.dist-info → arize-8.0.0a16.dist-info}/METADATA +3 -1
  63. {arize-8.0.0a14.dist-info → arize-8.0.0a16.dist-info}/RECORD +65 -44
  64. {arize-8.0.0a14.dist-info → arize-8.0.0a16.dist-info}/WHEEL +0 -0
  65. {arize-8.0.0a14.dist-info → arize-8.0.0a16.dist-info}/licenses/LICENSE.md +0 -0
arize/experiments/tracing.py ADDED
@@ -0,0 +1,276 @@
+ from __future__ import annotations
+
+ import inspect
+ import json
+ from contextlib import contextmanager
+ from contextvars import ContextVar
+ from threading import Lock
+ from typing import (
+     Any,
+     Callable,
+     Iterable,
+     Iterator,
+     List,
+     Mapping,
+     Sequence,
+     cast,
+ )
+
+ import numpy as np
+ from openinference.semconv import trace
+ from openinference.semconv.trace import DocumentAttributes, SpanAttributes
+ from opentelemetry.sdk.resources import Resource
+ from opentelemetry.sdk.trace import ReadableSpan
+ from opentelemetry.trace import INVALID_TRACE_ID
+ from typing_extensions import assert_never
+ from wrapt import apply_patch, resolve_path, wrap_function_wrapper
+
+
+ class SpanModifier:
+     """
+     A class that modifies spans with the specified resource attributes.
+     """
+
+     __slots__ = ("_resource",)
+
+     def __init__(self, resource: Resource) -> None:
+         self._resource = resource
+
+     def modify_resource(self, span: ReadableSpan) -> None:
+         """
+         Takes a span and merges in the resource attributes specified in the constructor.
+
+         Args:
+             span: ReadableSpan: the span to modify
+
+         """
+         if (ctx := span._context) is None or ctx.trace_id == INVALID_TRACE_ID:
+             return
+         span._resource = span._resource.merge(self._resource)
+
+
+ _ACTIVE_MODIFIER: ContextVar[SpanModifier | None] = ContextVar(
+     "active_modifier"
+ )
+
+
+ def override_span(
+     init: Callable[..., None], span: ReadableSpan, args: Any, kwargs: Any
+ ) -> None:
+     init(*args, **kwargs)
+     if isinstance(span_modifier := _ACTIVE_MODIFIER.get(None), SpanModifier):
+         span_modifier.modify_resource(span)
+
+
+ _SPAN_INIT_MONKEY_PATCH_LOCK = Lock()
+ _SPAN_INIT_MONKEY_PATCH_COUNT = 0
+ _SPAN_INIT_MODULE = ReadableSpan.__init__.__module__
+ _SPAN_INIT_NAME = ReadableSpan.__init__.__qualname__
+ _SPAN_INIT_PARENT, _SPAN_INIT_ATTR, _SPAN_INIT_ORIGINAL = resolve_path(
+     _SPAN_INIT_MODULE, _SPAN_INIT_NAME
+ )
+
+
+ @contextmanager
+ def _monkey_patch_span_init() -> Iterator[None]:
+     global _SPAN_INIT_MONKEY_PATCH_COUNT
+     with _SPAN_INIT_MONKEY_PATCH_LOCK:
+         _SPAN_INIT_MONKEY_PATCH_COUNT += 1
+         if _SPAN_INIT_MONKEY_PATCH_COUNT == 1:
+             wrap_function_wrapper(
+                 module=_SPAN_INIT_MODULE,
+                 name=_SPAN_INIT_NAME,
+                 wrapper=override_span,
+             )
+     yield
+     with _SPAN_INIT_MONKEY_PATCH_LOCK:
+         _SPAN_INIT_MONKEY_PATCH_COUNT -= 1
+         if _SPAN_INIT_MONKEY_PATCH_COUNT == 0:
+             apply_patch(_SPAN_INIT_PARENT, _SPAN_INIT_ATTR, _SPAN_INIT_ORIGINAL)
+
+
+ @contextmanager
+ def capture_spans(resource: Resource) -> Iterator[SpanModifier]:
+     """
+     A context manager that captures spans and modifies them with the specified resources.
+
+     Args:
+         resource: Resource: The resource to merge into the spans created within the context.
+
+     Returns:
+         modifier: Iterator[SpanModifier]: The span modifier that is active within the context.
+
+     """
+     modifier = SpanModifier(resource)
+     with _monkey_patch_span_init():
+         token = _ACTIVE_MODIFIER.set(modifier)
+         yield modifier
+         _ACTIVE_MODIFIER.reset(token)
+
+
+ """
+ Span attribute keys have a special relationship with the `.` separator. When
+ a span attribute is ingested from protobuf, it's in the form of a key-value
+ pair such as `("llm.token_count.completion", 123)`. What we need to do is split
+ the key by the `.` separator and turn it into part of a nested dictionary such
+ as `{"llm": {"token_count": {"completion": 123}}}`. We also need to reverse this
+ process, which is to flatten the nested dictionary into a list of key-value
+ pairs. This module provides functions to do both of these operations.
+
+ Note that digit keys are treated as indices of a nested array. For example,
+ the digits inside `("retrieval.documents.0.document.content", 'A')` and
+ `("retrieval.documents.1.document.content", 'B')` turn the sub-keys following
+ them into a nested list of dictionaries, i.e.
+ `{"retrieval": {"documents": [{"document": {"content": "A"}}, {"document":
+ {"content": "B"}}]}`.
+ """
+
+
+ DOCUMENT_METADATA = DocumentAttributes.DOCUMENT_METADATA
+ LLM_PROMPT_TEMPLATE_VARIABLES = SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES
+ METADATA = SpanAttributes.METADATA
+ TOOL_PARAMETERS = SpanAttributes.TOOL_PARAMETERS
+
+ # attributes interpreted as JSON strings during ingestion
+ JSON_STRING_ATTRIBUTES = (
+     DOCUMENT_METADATA,
+     LLM_PROMPT_TEMPLATE_VARIABLES,
+     METADATA,
+     TOOL_PARAMETERS,
+ )
+
+ SEMANTIC_CONVENTIONS: List[str] = sorted(
+     # e.g. "input.value", "llm.token_count.total", etc.
+     (
+         cast(str, getattr(klass, attr))
+         for name in dir(trace)
+         if name.endswith("Attributes")
+         and inspect.isclass(klass := getattr(trace, name))
+         for attr in dir(klass)
+         if attr.isupper()
+     ),
+     key=len,
+     reverse=True,
+ )  # sorted so the longer strings go first
+
+
+ def flatten(
+     obj: Mapping[str, Any] | Iterable[Any],
+     *,
+     prefix: str = "",
+     separator: str = ".",
+     recurse_on_sequence: bool = False,
+     json_string_attributes: Sequence[str] | None = None,
+ ) -> Iterator[tuple[str, Any]]:
+     """
+     Flatten a nested dictionary or a sequence of dictionaries into a list of
+     key-value pairs. If `recurse_on_sequence` is True, then the function will
+     also recursively flatten nested sequences of dictionaries. If
+     `json_string_attributes` is provided, mapping values whose keys end with
+     one of those attributes are serialized as JSON strings. The `prefix`
+     argument is used to prefix the keys in the output list, but it's mostly
+     used internally to facilitate recursion.
+     """
+     if isinstance(obj, Mapping):
+         yield from _flatten_mapping(
+             obj,
+             prefix=prefix,
+             recurse_on_sequence=recurse_on_sequence,
+             json_string_attributes=json_string_attributes,
+             separator=separator,
+         )
+     elif isinstance(obj, Iterable):
+         yield from _flatten_sequence(
+             obj,
+             prefix=prefix,
+             recurse_on_sequence=recurse_on_sequence,
+             json_string_attributes=json_string_attributes,
+             separator=separator,
+         )
+     else:
+         assert_never(obj)
+
+
+ def has_mapping(sequence: Iterable[Any]) -> bool:
+     """
+     Check if a sequence contains a dictionary. We don't flatten sequences that
+     only contain primitive types, such as strings, integers, etc. Conversely,
+     we'll only un-flatten digit sub-keys if they can be interpreted as the
+     indices of an array of dictionaries.
+     """
+     return any(isinstance(item, Mapping) for item in sequence)
+
+
+ def _flatten_mapping(
+     mapping: Mapping[str, Any],
+     *,
+     prefix: str = "",
+     recurse_on_sequence: bool = False,
+     json_string_attributes: Sequence[str] | None = None,
+     separator: str = ".",
+ ) -> Iterator[tuple[str, Any]]:
+     """
+     Flatten a nested dictionary into a list of key-value pairs. If
+     `recurse_on_sequence` is True, then the function will also recursively
+     flatten nested sequences of dictionaries. If `json_string_attributes` is
+     provided, mapping values whose keys end with one of those attributes are
+     serialized as JSON strings. The `prefix` argument is used to prefix the
+     keys in the output list, but it's mostly used internally to facilitate
+     recursion.
+     """
+     for key, value in mapping.items():
+         prefixed_key = f"{prefix}{separator}{key}" if prefix else key
+         if isinstance(value, Mapping):
+             if json_string_attributes and prefixed_key.endswith(
+                 JSON_STRING_ATTRIBUTES
+             ):
+                 yield prefixed_key, json.dumps(value)
+             else:
+                 yield from _flatten_mapping(
+                     value,
+                     prefix=prefixed_key,
+                     recurse_on_sequence=recurse_on_sequence,
+                     json_string_attributes=json_string_attributes,
+                     separator=separator,
+                 )
+         elif isinstance(value, (Sequence, np.ndarray)) and recurse_on_sequence:
+             yield from _flatten_sequence(
+                 value,
+                 prefix=prefixed_key,
+                 recurse_on_sequence=recurse_on_sequence,
+                 json_string_attributes=json_string_attributes,
+                 separator=separator,
+             )
+         elif value is not None:
+             yield prefixed_key, value
+
+
+ def _flatten_sequence(
+     sequence: Iterable[Any],
+     *,
+     prefix: str = "",
+     recurse_on_sequence: bool = False,
+     json_string_attributes: Sequence[str] | None = None,
+     separator: str = ".",
+ ) -> Iterator[tuple[str, Any]]:
+     """
+     Flatten a sequence of dictionaries into a list of key-value pairs. If
+     `recurse_on_sequence` is True, then the function will also recursively
+     flatten nested sequences of dictionaries. If `json_string_attributes` is
+     provided, mapping values whose keys end with one of those attributes are
+     serialized as JSON strings. The `prefix` argument is used to prefix the
+     keys in the output list, but it's mostly used internally to facilitate
+     recursion.
+     """
+     if isinstance(sequence, str) or not has_mapping(sequence):
+         yield prefix, sequence
+         return
+     for idx, obj in enumerate(sequence):
+         if not isinstance(obj, Mapping):
+             continue
+         yield from _flatten_mapping(
+             obj,
+             prefix=f"{prefix}{separator}{idx}" if prefix else f"{idx}",
+             recurse_on_sequence=recurse_on_sequence,
+             json_string_attributes=json_string_attributes,
+             separator=separator,
+         )
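
To make the flatten semantics described in the module docstring concrete, here is a minimal usage sketch; the import path is an assumption based on the file list above, and `flatten` yields key-value pairs:

    # Hypothetical usage of the flatten helper shown above; the import path
    # assumes the file lands at arize/experiments/tracing.py as listed.
    from arize.experiments.tracing import flatten

    nested = {
        "llm": {"token_count": {"completion": 123}},
        "retrieval": {
            "documents": [
                {"document": {"content": "A"}},
                {"document": {"content": "B"}},
            ]
        },
    }

    # recurse_on_sequence=True turns list positions into digit sub-keys.
    print(dict(flatten(nested, recurse_on_sequence=True)))
    # {'llm.token_count.completion': 123,
    #  'retrieval.documents.0.document.content': 'A',
    #  'retrieval.documents.1.document.content': 'B'}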
arize/experiments/types.py ADDED
@@ -0,0 +1,394 @@
+ from __future__ import annotations
+
+ import json
+ import textwrap
+ from copy import copy, deepcopy
+ from dataclasses import dataclass, field
+ from datetime import datetime, timezone
+ from importlib.metadata import version
+ from random import getrandbits
+ from typing import (
+     Any,
+     Awaitable,
+     Callable,
+     Iterable,
+     Mapping,
+     cast,
+ )
+
+ import pandas as pd
+ from wrapt import ObjectProxy
+
+ from arize.experiments.evaluators.types import (
+     EvaluationResult,
+     JSONSerializable,
+ )
+
+ ExperimentId = str
+ # DatasetId = str
+ # DatasetVersionId = str
+ ExampleId = str
+ RepetitionNumber = int
+ ExperimentRunId = str
+ TraceId = str
+
+
+ @dataclass(frozen=True)
+ class Example:
+     """
+     Represents an example in an experiment dataset.
+
+     Args:
+         id: The unique identifier for the example.
+         updated_at: The timestamp when the example was last updated.
+         input: The input data for the example.
+         output: The output data for the example.
+         metadata: Additional metadata for the example.
+         dataset_row: The original dataset row containing the example data.
+     """
+
+     id: ExampleId = field(default_factory=str)
+     updated_at: datetime = field(default_factory=datetime.now)
+     input: Mapping[str, JSONSerializable] = field(default_factory=dict)
+     output: Mapping[str, JSONSerializable] = field(default_factory=dict)
+     metadata: Mapping[str, JSONSerializable] = field(default_factory=dict)
+     dataset_row: Mapping[str, JSONSerializable] = field(default_factory=dict)
+
+     def __post_init__(self) -> None:
+         if self.dataset_row is not None:
+             object.__setattr__(
+                 self, "dataset_row", _make_read_only(self.dataset_row)
+             )
+             if "attributes.input.value" in self.dataset_row:
+                 object.__setattr__(
+                     self,
+                     "input",
+                     _make_read_only(self.dataset_row["attributes.input.value"]),
+                 )
+             if "attributes.output.value" in self.dataset_row:
+                 object.__setattr__(
+                     self,
+                     "output",
+                     _make_read_only(
+                         self.dataset_row["attributes.output.value"]
+                     ),
+                 )
+             if "attributes.metadata" in self.dataset_row:
+                 object.__setattr__(
+                     self,
+                     "metadata",
+                     _make_read_only(self.dataset_row["attributes.metadata"]),
+                 )
+             if "id" in self.dataset_row:
+                 object.__setattr__(self, "id", self.dataset_row["id"])
+             if "updated_at" in self.dataset_row:
+                 object.__setattr__(
+                     self, "updated_at", self.dataset_row["updated_at"]
+                 )
+         else:
+             object.__setattr__(self, "input", _make_read_only(self.input))
+             object.__setattr__(self, "output", _make_read_only(self.output))
+             object.__setattr__(
+                 self, "metadata", _make_read_only(self.metadata)
+             )
+
+     @classmethod
+     def from_dict(cls, obj: Mapping[str, Any]) -> Example:
+         return cls(
+             id=obj["id"],
+             input=obj["input"],
+             output=obj["output"],
+             metadata=obj.get("metadata") or {},
+             updated_at=obj["updated_at"],
+         )
+
+     def __repr__(self) -> str:
+         spaces = " " * 4
+         name = self.__class__.__name__
+         identifiers = [f'{spaces}id="{self.id}",']
+         contents = []
+         for key in ("input", "output", "metadata", "dataset_row"):
+             value = getattr(self, key, None)
+             if value:
+                 contents.append(
+                     spaces
+                     + f"{_blue(key)}="
+                     + json.dumps(
+                         _shorten(value),
+                         ensure_ascii=False,
+                         sort_keys=True,
+                         indent=len(spaces),
+                     )
+                     .replace("\n", f"\n{spaces}")
+                     .replace(' "..."\n', " ...\n")
+                     + ","
+                 )
+         return "\n".join([f"{name}(", *identifiers, *contents, ")"])
+
+
+ def _shorten(obj: Any, width: int = 50) -> Any:
+     if isinstance(obj, str):
+         return textwrap.shorten(obj, width=width, placeholder="...")
+     if isinstance(obj, dict):
+         return {k: _shorten(v) for k, v in obj.items()}
+     if isinstance(obj, list):
+         if len(obj) > 2:
+             return [_shorten(v) for v in obj[:2]] + ["..."]
+         return [_shorten(v) for v in obj]
+     return obj
+
+
+ def _make_read_only(obj: Any) -> Any:
+     if isinstance(obj, dict):
+         return _ReadOnly({k: _make_read_only(v) for k, v in obj.items()})
+     if isinstance(obj, str):
+         return obj
+     if isinstance(obj, list):
+         return _ReadOnly(list(map(_make_read_only, obj)))
+     return obj
+
+
+ class _ReadOnly(ObjectProxy):  # type: ignore[misc]
+     def __setitem__(self, *args: Any, **kwargs: Any) -> Any:
+         raise NotImplementedError
+
+     def __delitem__(self, *args: Any, **kwargs: Any) -> Any:
+         raise NotImplementedError
+
+     def __iadd__(self, *args: Any, **kwargs: Any) -> Any:
+         raise NotImplementedError
+
+     def pop(self, *args: Any, **kwargs: Any) -> Any:
+         raise NotImplementedError
+
+     def append(self, *args: Any, **kwargs: Any) -> Any:
+         raise NotImplementedError
+
+     def __copy__(self, *args: Any, **kwargs: Any) -> Any:
+         return copy(self.__wrapped__)
+
+     def __deepcopy__(self, *args: Any, **kwargs: Any) -> Any:
+         return deepcopy(self.__wrapped__)
+
+     def __repr__(self) -> str:
+         return repr(self.__wrapped__)
+
+     def __str__(self) -> str:
+         return str(self.__wrapped__)
+
+
+ def _blue(text: str) -> str:
+     return f"\033[1m\033[94m{text}\033[0m"
+
+
+ @dataclass(frozen=True)
+ class TestCase:
+     example: Example
+     repetition_number: RepetitionNumber
+
+
+ EXP_ID: ExperimentId = "EXP_ID"
+
+
+ def _exp_id() -> str:
+     suffix = getrandbits(24).to_bytes(3, "big").hex()
+     return f"{EXP_ID}_{suffix}"
+
+
+ @dataclass(frozen=True)
+ class ExperimentRun:
+     """
+     Represents a single run of an experiment.
+
+     Args:
+         start_time: The start time of the experiment run.
+         end_time: The end time of the experiment run.
+         experiment_id: The unique identifier for the experiment.
+         dataset_example_id: The unique identifier for the dataset example.
+         repetition_number: The repetition number of the experiment run.
+         output: The output of the experiment run.
+         error: The error message if the experiment run failed.
+         id: The unique identifier for the experiment run.
+         trace_id: The trace identifier for the experiment run.
+     """
+
+     start_time: datetime
+     end_time: datetime
+     experiment_id: ExperimentId
+     dataset_example_id: ExampleId
+     repetition_number: RepetitionNumber
+     output: JSONSerializable
+     error: str | None = None
+     id: ExperimentRunId = field(default_factory=_exp_id)
+     trace_id: TraceId | None = None
+
+     @classmethod
+     def from_dict(cls, obj: Mapping[str, Any]) -> ExperimentRun:
+         return cls(
+             start_time=obj["start_time"],
+             end_time=obj["end_time"],
+             experiment_id=obj["experiment_id"],
+             dataset_example_id=obj["dataset_example_id"],
+             repetition_number=obj.get("repetition_number") or 1,
+             output=_make_read_only(obj.get("output")),
+             error=obj.get("error"),
+             id=obj["id"],
+             trace_id=obj.get("trace_id"),
+         )
+
+     def __post_init__(self) -> None:
+         if (self.output is None) == (self.error is None):
+             raise ValueError(
+                 "Must specify exactly one of output or error"
+             )
+
+
+ @dataclass(frozen=True)
+ class ExperimentEvaluationRun:
+     """
+     Represents a single evaluation run of an experiment.
+
+     Args:
+         experiment_run_id: The unique identifier for the experiment run.
+         start_time: The start time of the evaluation run.
+         end_time: The end time of the evaluation run.
+         name: The name of the evaluation run.
+         annotator_kind: The kind of annotator used in the evaluation run.
+         error: The error message if the evaluation run failed.
+         result: The result of the evaluation run.
+         id: The unique identifier for the evaluation run.
+         trace_id: The trace identifier for the evaluation run.
+     """
+
+     experiment_run_id: ExperimentRunId
+     start_time: datetime
+     end_time: datetime
+     name: str
+     annotator_kind: str
+     error: str | None = None
+     result: EvaluationResult | None = None
+     id: str = field(default_factory=_exp_id)
+     trace_id: TraceId | None = None
+
+     @classmethod
+     def from_dict(cls, obj: Mapping[str, Any]) -> ExperimentEvaluationRun:
+         return cls(
+             experiment_run_id=obj["experiment_run_id"],
+             start_time=obj["start_time"],
+             end_time=obj["end_time"],
+             name=obj["name"],
+             annotator_kind=obj["annotator_kind"],
+             error=obj.get("error"),
+             result=EvaluationResult.from_dict(obj.get("result")),
+             id=obj["id"],
+             trace_id=obj.get("trace_id"),
+         )
+
+     def __post_init__(self) -> None:
+         if bool(self.result) == bool(self.error):
+             raise ValueError("Must specify either result or error")
+
+
+ _LOCAL_TIMEZONE = datetime.now(timezone.utc).astimezone().tzinfo
+
+
+ def local_now() -> datetime:
+     return datetime.now(timezone.utc).astimezone(tz=_LOCAL_TIMEZONE)
+
+
+ @dataclass(frozen=True)
+ class _HasStats:
+     _title: str = field(repr=False, default="")
+     _timestamp: datetime = field(repr=False, default_factory=local_now)
+     stats: pd.DataFrame = field(repr=False, default_factory=pd.DataFrame)
+
+     @property
+     def title(self) -> str:
+         return f"{self._title} ({self._timestamp:%x %I:%M %p %z})"
+
+     def __str__(self) -> str:
+         try:
+             assert int(version("pandas").split(".")[0]) >= 1
+             # `tabulate` is used by pandas >= 1.0 in DataFrame.to_markdown()
+             import tabulate  # noqa: F401
+         except (AssertionError, ImportError):
+             text = self.stats.__str__()
+         else:
+             text = self.stats.to_markdown(index=False)
+         return f"{self.title}\n{'-' * len(self.title)}\n" + text  # type: ignore
+
+
+ @dataclass(frozen=True)
+ class _TaskSummary(_HasStats):
+     """
+     Summary statistics of experiment task executions.
+
+     **Users should not instantiate this object directly.**
+     """
+
+     _title: str = "Tasks Summary"
+
+     @classmethod
+     def from_task_runs(
+         cls, n_examples: int, task_runs: Iterable[ExperimentRun | None]
+     ) -> _TaskSummary:
+         df = pd.DataFrame.from_records(
+             [
+                 {
+                     "example_id": run.dataset_example_id,
+                     "error": run.error,
+                 }
+                 for run in task_runs
+                 if run is not None
+             ]
+         )
+         n_runs = len(df)
+         n_errors = 0 if df.empty else df.loc[:, "error"].astype(bool).sum()
+         record = {
+             "n_examples": n_examples,
+             "n_runs": n_runs,
+             "n_errors": n_errors,
+             **(
+                 dict(top_error=_top_string(df.loc[:, "error"]))
+                 if n_errors
+                 else {}
+             ),
+         }
+         stats = pd.DataFrame.from_records([record])
+         summary: _TaskSummary = object.__new__(cls)
+         summary.__init__(stats=stats)  # type: ignore[misc]
+         return summary
+
+     @classmethod
+     def __new__(cls, *args: Any, **kwargs: Any) -> Any:
+         # Direct instantiation by users is discouraged.
+         raise NotImplementedError
+
+     @classmethod
+     def __init_subclass__(cls, **kwargs: Any) -> None:
+         # Direct sub-classing by users is discouraged.
+         raise NotImplementedError
+
+
+ def _top_string(s: pd.Series, length: int = 100) -> str | None:
+     if (cnt := s.dropna().str.slice(0, length).value_counts()).empty:
+         return None
+     return cast(str, cnt.sort_values(ascending=False).index[0])
+
+
+ @dataclass
+ class ExperimentTaskResultFieldNames:
+     """Column names for mapping experiment task results in a DataFrame.
+
+     Args:
+         example_id: Name of column containing example IDs.
+             The ID values must match the id of the dataset rows.
+         result: Name of column containing task results.
+     """
+
+     example_id: str
+     result: str
+
+
+ TaskOutput = JSONSerializable
+ ExampleOutput = Mapping[str, JSONSerializable]
+ ExampleMetadata = Mapping[str, JSONSerializable]
+ ExampleInput = Mapping[str, JSONSerializable]
+ ExperimentTask = (
+     Callable[[Example], TaskOutput] | Callable[[Example], Awaitable[TaskOutput]]
+ )
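
A minimal sketch of how the new types fit together; the import path and all values below are illustrative assumptions, not taken from the diff. An Example built from a dataset row pulls its input/output from the "attributes.*" columns and freezes them read-only, while ExperimentRun enforces that exactly one of output or error is set:

    # Illustrative only; the import path mirrors the file location above.
    from datetime import datetime, timezone

    from arize.experiments.types import Example, ExperimentRun

    example = Example(
        dataset_row={
            "id": "example-1",
            "updated_at": datetime.now(timezone.utc),
            "attributes.input.value": {"question": "What is Arize?"},
            "attributes.output.value": {"answer": "An observability platform."},
        }
    )
    assert example.input["question"] == "What is Arize?"
    # example.input["question"] = "..."  # would raise: fields are read-only

    # Exactly one of output/error may be set, or __post_init__ raises ValueError.
    run = ExperimentRun(
        start_time=datetime.now(timezone.utc),
        end_time=datetime.now(timezone.utc),
        experiment_id="EXP_ID_0a1b2c",
        dataset_example_id=example.id,
        repetition_number=1,
        output={"answer": "An observability platform."},
    )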
arize/models/client.py CHANGED
@@ -584,7 +584,7 @@ class MLModelsClient:
              # pyarrow will err if a mixed-type column exists in the dataset even if
              # the column is not specified in schema. Caveat: There may be other
              # error conditions that we're currently not aware of.
-             pa_table = pa.Table.from_pandas(dataframe)
+             pa_table = pa.Table.from_pandas(dataframe, preserve_index=False)
          except pa.ArrowInvalid as e:
              logger.error(f"{INVALID_ARROW_CONVERSION_MSG}: {str(e)}")
              raise pa.ArrowInvalid(
@@ -660,6 +660,7 @@ class MLModelsClient:
              headers=headers,
              timeout=timeout,
              verify=self._sdk_config.request_verify,
+             max_chunksize=self._sdk_config.pyarrow_max_chunksize,
              tmp_dir=tmp_dir,
          )
@@ -688,6 +689,7 @@ class MLModelsClient:
              port=self._sdk_config.flight_server_port,
              scheme=self._sdk_config.flight_scheme,
              request_verify=self._sdk_config.request_verify,
+             max_chunksize=self._sdk_config.pyarrow_max_chunksize,
          ) as flight_client:
              exporter = ArizeExportClient(
                  flight_client=flight_client,
@@ -732,6 +734,7 @@ class MLModelsClient:
              port=self._sdk_config.flight_server_port,
              scheme=self._sdk_config.flight_scheme,
              request_verify=self._sdk_config.request_verify,
+             max_chunksize=self._sdk_config.pyarrow_max_chunksize,
          ) as flight_client:
              exporter = ArizeExportClient(
                  flight_client=flight_client,
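
For context on the from_pandas change in the first hunk: a pandas DataFrame with a non-default index has that index serialized into the Arrow table as an extra __index_level_0__ column unless preserve_index=False is passed, which changes the uploaded schema. A small illustration with hypothetical data:

    import pandas as pd
    import pyarrow as pa

    df = pd.DataFrame({"prediction_id": ["a", "b"]}, index=[10, 11])

    # Default: the explicit index becomes an extra column in the Arrow table.
    print(pa.Table.from_pandas(df).column_names)
    # ['prediction_id', '__index_level_0__']

    # With preserve_index=False, only the DataFrame's own columns are kept.
    print(pa.Table.from_pandas(df, preserve_index=False).column_names)
    # ['prediction_id']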