arize-phoenix 4.5.0__py3-none-any.whl → 4.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (123) hide show
  1. {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.2.dist-info}/METADATA +16 -8
  2. {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.2.dist-info}/RECORD +122 -58
  3. {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.2.dist-info}/WHEEL +1 -1
  4. phoenix/__init__.py +0 -27
  5. phoenix/config.py +42 -7
  6. phoenix/core/model.py +25 -25
  7. phoenix/core/model_schema.py +64 -62
  8. phoenix/core/model_schema_adapter.py +27 -25
  9. phoenix/datetime_utils.py +4 -0
  10. phoenix/db/bulk_inserter.py +54 -14
  11. phoenix/db/insertion/dataset.py +237 -0
  12. phoenix/db/insertion/evaluation.py +10 -10
  13. phoenix/db/insertion/helpers.py +17 -14
  14. phoenix/db/insertion/span.py +3 -3
  15. phoenix/db/migrations/types.py +29 -0
  16. phoenix/db/migrations/versions/10460e46d750_datasets.py +291 -0
  17. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +2 -28
  18. phoenix/db/models.py +236 -4
  19. phoenix/experiments/__init__.py +6 -0
  20. phoenix/experiments/evaluators/__init__.py +29 -0
  21. phoenix/experiments/evaluators/base.py +153 -0
  22. phoenix/experiments/evaluators/code_evaluators.py +99 -0
  23. phoenix/experiments/evaluators/llm_evaluators.py +244 -0
  24. phoenix/experiments/evaluators/utils.py +186 -0
  25. phoenix/experiments/functions.py +757 -0
  26. phoenix/experiments/tracing.py +85 -0
  27. phoenix/experiments/types.py +753 -0
  28. phoenix/experiments/utils.py +24 -0
  29. phoenix/inferences/fixtures.py +23 -23
  30. phoenix/inferences/inferences.py +7 -7
  31. phoenix/inferences/validation.py +1 -1
  32. phoenix/server/api/context.py +20 -0
  33. phoenix/server/api/dataloaders/__init__.py +20 -0
  34. phoenix/server/api/dataloaders/average_experiment_run_latency.py +54 -0
  35. phoenix/server/api/dataloaders/dataset_example_revisions.py +100 -0
  36. phoenix/server/api/dataloaders/dataset_example_spans.py +43 -0
  37. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +85 -0
  38. phoenix/server/api/dataloaders/experiment_error_rates.py +43 -0
  39. phoenix/server/api/dataloaders/experiment_run_counts.py +42 -0
  40. phoenix/server/api/dataloaders/experiment_sequence_number.py +49 -0
  41. phoenix/server/api/dataloaders/project_by_name.py +31 -0
  42. phoenix/server/api/dataloaders/span_descendants.py +2 -3
  43. phoenix/server/api/dataloaders/span_projects.py +33 -0
  44. phoenix/server/api/dataloaders/trace_row_ids.py +39 -0
  45. phoenix/server/api/helpers/dataset_helpers.py +179 -0
  46. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
  47. phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
  48. phoenix/server/api/input_types/ClearProjectInput.py +15 -0
  49. phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
  50. phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
  51. phoenix/server/api/input_types/DatasetSort.py +17 -0
  52. phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
  53. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
  54. phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
  55. phoenix/server/api/input_types/DeleteExperimentsInput.py +9 -0
  56. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
  57. phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
  58. phoenix/server/api/mutations/__init__.py +13 -0
  59. phoenix/server/api/mutations/auth.py +11 -0
  60. phoenix/server/api/mutations/dataset_mutations.py +520 -0
  61. phoenix/server/api/mutations/experiment_mutations.py +65 -0
  62. phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +17 -14
  63. phoenix/server/api/mutations/project_mutations.py +47 -0
  64. phoenix/server/api/openapi/__init__.py +0 -0
  65. phoenix/server/api/openapi/main.py +6 -0
  66. phoenix/server/api/openapi/schema.py +16 -0
  67. phoenix/server/api/queries.py +503 -0
  68. phoenix/server/api/routers/v1/__init__.py +77 -2
  69. phoenix/server/api/routers/v1/dataset_examples.py +178 -0
  70. phoenix/server/api/routers/v1/datasets.py +965 -0
  71. phoenix/server/api/routers/v1/evaluations.py +8 -13
  72. phoenix/server/api/routers/v1/experiment_evaluations.py +143 -0
  73. phoenix/server/api/routers/v1/experiment_runs.py +220 -0
  74. phoenix/server/api/routers/v1/experiments.py +302 -0
  75. phoenix/server/api/routers/v1/spans.py +9 -5
  76. phoenix/server/api/routers/v1/traces.py +1 -4
  77. phoenix/server/api/schema.py +2 -303
  78. phoenix/server/api/types/AnnotatorKind.py +10 -0
  79. phoenix/server/api/types/Cluster.py +19 -19
  80. phoenix/server/api/types/CreateDatasetPayload.py +8 -0
  81. phoenix/server/api/types/Dataset.py +282 -63
  82. phoenix/server/api/types/DatasetExample.py +85 -0
  83. phoenix/server/api/types/DatasetExampleRevision.py +34 -0
  84. phoenix/server/api/types/DatasetVersion.py +14 -0
  85. phoenix/server/api/types/Dimension.py +30 -29
  86. phoenix/server/api/types/EmbeddingDimension.py +40 -34
  87. phoenix/server/api/types/Event.py +16 -16
  88. phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
  89. phoenix/server/api/types/Experiment.py +147 -0
  90. phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
  91. phoenix/server/api/types/ExperimentComparison.py +19 -0
  92. phoenix/server/api/types/ExperimentRun.py +91 -0
  93. phoenix/server/api/types/ExperimentRunAnnotation.py +57 -0
  94. phoenix/server/api/types/Inferences.py +80 -0
  95. phoenix/server/api/types/InferencesRole.py +23 -0
  96. phoenix/server/api/types/Model.py +43 -42
  97. phoenix/server/api/types/Project.py +26 -12
  98. phoenix/server/api/types/Span.py +79 -2
  99. phoenix/server/api/types/TimeSeries.py +6 -6
  100. phoenix/server/api/types/Trace.py +15 -4
  101. phoenix/server/api/types/UMAPPoints.py +1 -1
  102. phoenix/server/api/types/node.py +5 -111
  103. phoenix/server/api/types/pagination.py +10 -52
  104. phoenix/server/app.py +103 -49
  105. phoenix/server/main.py +49 -27
  106. phoenix/server/openapi/docs.py +3 -0
  107. phoenix/server/static/index.js +2300 -1294
  108. phoenix/server/templates/index.html +1 -0
  109. phoenix/services.py +15 -15
  110. phoenix/session/client.py +581 -22
  111. phoenix/session/session.py +47 -37
  112. phoenix/trace/exporter.py +14 -9
  113. phoenix/trace/fixtures.py +133 -7
  114. phoenix/trace/schemas.py +1 -2
  115. phoenix/trace/span_evaluations.py +3 -3
  116. phoenix/trace/trace_dataset.py +6 -6
  117. phoenix/utilities/json.py +61 -0
  118. phoenix/utilities/re.py +50 -0
  119. phoenix/version.py +1 -1
  120. phoenix/server/api/types/DatasetRole.py +0 -23
  121. {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.2.dist-info}/licenses/IP_NOTICE +0 -0
  122. {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.2.dist-info}/licenses/LICENSE +0 -0
  123. /phoenix/server/api/{helpers.py → helpers/__init__.py} +0 -0
@@ -0,0 +1,753 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import textwrap
5
+ from collections import Counter
6
+ from copy import copy, deepcopy
7
+ from dataclasses import dataclass, field, fields
8
+ from datetime import datetime
9
+ from enum import Enum
10
+ from functools import cached_property
11
+ from importlib.metadata import version
12
+ from random import getrandbits
13
+ from typing import (
14
+ Any,
15
+ Awaitable,
16
+ Callable,
17
+ Dict,
18
+ FrozenSet,
19
+ Iterable,
20
+ Iterator,
21
+ List,
22
+ Mapping,
23
+ Optional,
24
+ Tuple,
25
+ TypeVar,
26
+ Union,
27
+ cast,
28
+ overload,
29
+ )
30
+
31
+ import pandas as pd
32
+ from typing_extensions import TypeAlias
33
+ from wrapt import ObjectProxy
34
+
35
+ from phoenix.datetime_utils import local_now
36
+ from phoenix.experiments.utils import get_experiment_url
37
+
38
+
39
class AnnotatorKind(Enum):
    """The kind of annotator that produced an evaluation result.

    ``CODE`` marks a deterministic, code-based evaluator; ``LLM`` marks an
    LLM-backed evaluator.
    """

    CODE = "CODE"
    LLM = "LLM"
42
+
43
+
44
# Type aliases shared across the experiments API.  They are all plain
# aliases (no runtime checking); they exist to make signatures self-describing.
JSONSerializable: TypeAlias = Optional[Union[Dict[str, Any], List[Any], str, int, float, bool]]
ExperimentId: TypeAlias = str
DatasetId: TypeAlias = str
DatasetVersionId: TypeAlias = str
ExampleId: TypeAlias = str
RepetitionNumber: TypeAlias = int
ExperimentRunId: TypeAlias = str
TraceId: TypeAlias = str

# The (JSON-serializable) value returned by an experiment task.
TaskOutput: TypeAlias = JSONSerializable

# The three JSON-object sections of a dataset example.
ExampleOutput: TypeAlias = Mapping[str, JSONSerializable]
ExampleMetadata: TypeAlias = Mapping[str, JSONSerializable]
ExampleInput: TypeAlias = Mapping[str, JSONSerializable]

# Components of an evaluation result.
Score: TypeAlias = Optional[Union[bool, int, float]]
Label: TypeAlias = Optional[str]
Explanation: TypeAlias = Optional[str]

EvaluatorName: TypeAlias = str
EvaluatorKind: TypeAlias = str
# An evaluator may return a full EvaluationResult, a bare score/label value,
# or a (score, label, explanation) tuple.
EvaluatorOutput: TypeAlias = Union[
    "EvaluationResult", bool, int, float, str, Tuple[Score, Label, Explanation]
]

# Sentinel prefix used for ids generated locally (i.e. without a server).
DRY_RUN: ExperimentId = "DRY_RUN"


def _dry_run_id() -> str:
    """Generate a locally-unique dry-run id, e.g. ``DRY_RUN_a1b2c3``.

    Uses 24 random bits rendered as 6 lowercase hex characters.
    """
    suffix = getrandbits(24).to_bytes(3, "big").hex()
    return f"{DRY_RUN}_{suffix}"
75
+
76
+
77
@dataclass(frozen=True)
class Example:
    """A single dataset example: input/output/metadata JSON objects plus identity.

    The three mapping fields are wrapped in read-only proxies after
    construction so callers cannot mutate shared state in place.
    """

    id: ExampleId
    updated_at: datetime
    input: Mapping[str, JSONSerializable] = field(default_factory=dict)
    output: Mapping[str, JSONSerializable] = field(default_factory=dict)
    metadata: Mapping[str, JSONSerializable] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # `object.__setattr__` is required because the dataclass is frozen.
        object.__setattr__(self, "input", _make_read_only(self.input))
        object.__setattr__(self, "output", _make_read_only(self.output))
        object.__setattr__(self, "metadata", _make_read_only(self.metadata))

    @classmethod
    def from_dict(cls, obj: Mapping[str, Any]) -> Example:
        """Build an Example from a mapping (presumably a server JSON payload —
        TODO confirm against the client). ``metadata`` may be missing/None and
        defaults to an empty dict; the other keys are required."""
        return cls(
            input=obj["input"],
            output=obj["output"],
            metadata=obj.get("metadata") or {},
            id=obj["id"],
            updated_at=obj["updated_at"],
        )

    def __repr__(self) -> str:
        """Pretty multi-line repr with shortened, ANSI-highlighted JSON sections.

        Only non-empty sections among input/output/metadata are shown.
        """
        spaces = " " * 4
        name = self.__class__.__name__
        identifiers = [f'{spaces}id="{self.id}",']
        contents = [
            spaces
            + f"{_blue(key)}="
            + json.dumps(
                _shorten(value),
                ensure_ascii=False,
                sort_keys=True,
                indent=len(spaces),
            )
            # Re-indent the dumped JSON to line up under the field name, and
            # strip the quotes around the "..." truncation placeholder.
            .replace("\n", f"\n{spaces}")
            .replace(' "..."\n', " ...\n")
            + ","
            for key in ("input", "output", "metadata")
            if (value := getattr(self, key, None))
        ]
        return "\n".join([f"{name}(", *identifiers, *contents, ")"])
120
+
121
+
122
@dataclass(frozen=True)
class Dataset:
    """An immutable collection of examples for one dataset version.

    Supports ``len``, iteration over examples, and integer/slice indexing in
    insertion order.
    """

    id: DatasetId
    version_id: DatasetVersionId
    examples: Mapping[ExampleId, Example] = field(repr=False, default_factory=dict)

    def __post_init__(self) -> None:
        # Frozen dataclass: assign through object.__setattr__.
        object.__setattr__(self, "examples", _ReadOnly(self.examples))

    def __len__(self) -> int:
        return len(self.examples)

    def __iter__(self) -> Iterator[Example]:
        return iter(self.examples.values())

    @cached_property
    def _keys(self) -> Tuple[str, ...]:
        # Snapshot of example ids, for positional indexing.
        return tuple(self.examples.keys())

    @overload
    def __getitem__(self, key: int) -> Example: ...
    @overload
    def __getitem__(self, key: slice) -> List[Example]: ...
    def __getitem__(self, key: Union[int, slice]) -> Union[Example, List[Example]]:
        if isinstance(key, int):
            return self.examples[self._keys[key]]
        return [self.examples[example_id] for example_id in self._keys[key]]

    def as_dataframe(self, drop_empty_columns: bool = True) -> pd.DataFrame:
        """Return the examples as a DataFrame indexed by example_id.

        Mapping fields are deep-copied so the frame is safe to mutate.
        """
        records = [
            {
                "example_id": example.id,
                "input": deepcopy(example.input),
                "output": deepcopy(example.output),
                "metadata": deepcopy(example.metadata),
            }
            for example in self.examples.values()
        ]
        df = pd.DataFrame.from_records(records).set_index("example_id")
        if not drop_empty_columns:
            return df
        nonempty = [col for col, series in df.items() if series.astype(bool).any()]
        return df.reindex(nonempty, axis=1)

    @classmethod
    def from_dict(cls, obj: Mapping[str, Any]) -> Dataset:
        """Build a Dataset from a mapping; examples are keyed by their ids."""
        parsed = [Example.from_dict(item) for item in (obj.get("examples") or ())]
        return cls(
            id=obj["id"],
            version_id=obj["version_id"],
            examples={example.id: example for example in parsed},
        )
174
+
175
+
176
@dataclass(frozen=True)
class TestCase:
    """One unit of experiment work: a single example at a given repetition."""

    # The dataset example to run the task against.
    example: Example
    # 1-based repetition counter for repeated runs of the same example.
    repetition_number: RepetitionNumber
180
+
181
+
182
@dataclass(frozen=True)
class Experiment:
    """Identity of an experiment plus the dataset version it targets."""

    id: ExperimentId
    dataset_id: DatasetId
    dataset_version_id: DatasetVersionId
    repetitions: int
    project_name: str = field(repr=False)

    @classmethod
    def from_dict(cls, obj: Mapping[str, Any]) -> Experiment:
        """Build an Experiment from a mapping; optional keys fall back to
        ``repetitions=1`` and an empty ``project_name``."""
        repetitions = obj.get("repetitions") or 1
        project_name = obj.get("project_name") or ""
        return cls(
            id=obj["id"],
            dataset_id=obj["dataset_id"],
            dataset_version_id=obj["dataset_version_id"],
            repetitions=repetitions,
            project_name=project_name,
        )
199
+
200
+
201
@dataclass(frozen=True)
class ExperimentRunOutput:
    """Wrapper around a task's output; the payload is made read-only."""

    task_output: TaskOutput

    def __post_init__(self) -> None:
        # Frozen dataclass: assign through object.__setattr__.
        object.__setattr__(self, "task_output", _make_read_only(self.task_output))

    @classmethod
    def from_dict(cls, obj: Optional[Mapping[str, Any]]) -> ExperimentRunOutput:
        """Build from a mapping; a missing/empty mapping yields a None output."""
        return cls(task_output=obj["task_output"] if obj else None)
213
+
214
+
215
@dataclass(frozen=True)
class ExperimentRun:
    """The result of running the task against one example (one repetition).

    Exactly one of ``experiment_run_output``/``error`` is expected to carry
    meaningful content, but this is not enforced here — see NOTE below.
    """

    start_time: datetime
    end_time: datetime
    experiment_id: ExperimentId
    dataset_example_id: ExampleId
    repetition_number: RepetitionNumber
    experiment_run_output: ExperimentRunOutput
    error: Optional[str] = None
    id: ExperimentRunId = field(default_factory=_dry_run_id)
    trace_id: Optional[TraceId] = None

    # NOTE(review): the original __post_init__ evaluated
    #     ValueError("Must specify exactly one of experiment_run_output or error")
    # WITHOUT `raise`, making the check a no-op. It cannot simply be raised:
    # an ExperimentRunOutput instance is always truthy, so the condition
    # `bool(self.experiment_run_output) == bool(self.error)` is true for every
    # errored run produced by `from_dict` (which always supplies an output
    # wrapper), and raising would reject all such runs. The dead statement is
    # removed; the invariant should be enforced where runs are created.

    @property
    def output(self) -> Optional[TaskOutput]:
        """A deep copy of the task output, safe for callers to mutate."""
        return deepcopy(self.experiment_run_output.task_output)

    @classmethod
    def from_dict(cls, obj: Mapping[str, Any]) -> ExperimentRun:
        """Build an ExperimentRun from a mapping; ``repetition_number``
        defaults to 1 and ``error``/``trace_id`` may be absent."""
        return cls(
            start_time=obj["start_time"],
            end_time=obj["end_time"],
            experiment_id=obj["experiment_id"],
            dataset_example_id=obj["dataset_example_id"],
            repetition_number=obj.get("repetition_number") or 1,
            experiment_run_output=ExperimentRunOutput.from_dict(obj["experiment_run_output"]),
            error=obj.get("error"),
            id=obj["id"],
            trace_id=obj.get("trace_id"),
        )
248
+
249
+
250
@dataclass(frozen=True)
class EvaluationResult:
    """The outcome of a single evaluator run: score, label, explanation, metadata."""

    score: Optional[float] = None
    label: Optional[str] = None
    explanation: Optional[str] = None
    metadata: Mapping[str, JSONSerializable] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, obj: Optional[Mapping[str, Any]]) -> Optional[EvaluationResult]:
        """Build from a mapping; a missing/empty mapping yields ``None``."""
        if not obj:
            return None
        return cls(
            score=obj.get("score"),
            label=obj.get("label"),
            explanation=obj.get("explanation"),
            metadata=obj.get("metadata") or {},
        )

    def __post_init__(self) -> None:
        # BUG FIX: the original first evaluated
        #     ValueError("Must specify score or label, or both")
        # WITHOUT `raise` (a no-op), and then handled the very same condition
        # by defaulting the score below — so raising would contradict the
        # intended fallback. The dead statement is removed; the documented
        # behavior is: when neither score nor label is given, score defaults
        # to 0.
        if self.score is None and not self.label:
            object.__setattr__(self, "score", 0)
        # Coerce label/explanation to str; empty strings collapse to None.
        for k in ("label", "explanation"):
            if (v := getattr(self, k, None)) is not None:
                object.__setattr__(self, k, str(v) or None)
276
+
277
+
278
@dataclass(frozen=True)
class ExperimentEvaluationRun:
    """One evaluator's annotation of a single experiment run."""

    experiment_run_id: ExperimentRunId
    start_time: datetime
    end_time: datetime
    name: str
    annotator_kind: str
    error: Optional[str] = None
    result: Optional[EvaluationResult] = None
    id: str = field(default_factory=_dry_run_id)
    trace_id: Optional[TraceId] = None

    # NOTE(review): the original __post_init__ evaluated
    #     ValueError("Must specify either result or error")
    # WITHOUT `raise`, making the check a no-op. Raising it would also reject
    # runs deserialized with both `result` and `error` absent (both falsy, so
    # the equality holds), which `from_dict` can legitimately produce. The
    # dead statement is removed; the invariant should be enforced at the point
    # where evaluation runs are created.

    @classmethod
    def from_dict(cls, obj: Mapping[str, Any]) -> ExperimentEvaluationRun:
        """Build from a mapping; ``error``/``result``/``trace_id`` are optional."""
        return cls(
            experiment_run_id=obj["experiment_run_id"],
            start_time=obj["start_time"],
            end_time=obj["end_time"],
            name=obj["name"],
            annotator_kind=obj["annotator_kind"],
            error=obj.get("error"),
            result=EvaluationResult.from_dict(obj.get("result")),
            id=obj["id"],
            trace_id=obj.get("trace_id"),
        )
307
+
308
+
309
# A task is a sync or async callable taking one Example and returning a
# JSON-serializable output.
ExperimentTask: TypeAlias = Union[
    Callable[[Example], TaskOutput],
    Callable[[Example], Awaitable[TaskOutput]],
]
313
+
314
+
315
@dataclass(frozen=True)
class ExperimentParameters:
    """Size parameters of an experiment: example count and repetition count."""

    n_examples: int
    n_repetitions: int = 1

    @property
    def count(self) -> int:
        """Total number of task runs (examples × repetitions)."""
        return self.n_repetitions * self.n_examples
323
+
324
+
325
@dataclass(frozen=True)
class EvaluationParameters:
    """Parameters of an evaluation pass: which evaluators ran, and over what."""

    # Names of the evaluators applied in this pass.
    eval_names: FrozenSet[str]
    # Size parameters of the experiment being evaluated.
    exp_params: ExperimentParameters
329
+
330
+
331
@dataclass(frozen=True)
class _HasStats:
    """Mixin for summary objects: a titled, timestamped stats DataFrame."""

    _title: str = field(repr=False, default="")
    _timestamp: datetime = field(repr=False, default_factory=local_now)
    stats: pd.DataFrame = field(repr=False, default_factory=pd.DataFrame)

    @property
    def title(self) -> str:
        """Title with a locale-formatted local timestamp appended."""
        return f"{self._title} ({self._timestamp:%x %I:%M %p %z})"

    def __str__(self) -> str:
        """Render the title plus the stats as markdown when possible.

        Falls back to the plain DataFrame string if pandas is older than 1.0
        or `tabulate` is unavailable.
        """
        try:
            assert int(version("pandas").split(".")[0]) >= 1
            # `tabulate` is used by pandas >= 1.0 in DataFrame.to_markdown()
            import tabulate  # noqa: F401
        except (AssertionError, ImportError):
            text = self.stats.__str__()
        else:
            text = self.stats.to_markdown(index=False)
        return f"{self.title}\n{'-'*len(self.title)}\n" + text
351
+
352
+
353
@dataclass(frozen=True)
class EvaluationSummary(_HasStats):
    """
    Summary statistics of experiment evaluations.

    Users should not instantiate this directly.
    """

    _title: str = "Experiment Summary"

    @classmethod
    def from_eval_runs(
        cls,
        params: EvaluationParameters,
        *eval_runs: Optional[ExperimentEvaluationRun],
    ) -> EvaluationSummary:
        """Aggregate evaluation runs into per-evaluator summary statistics.

        None entries in *eval_runs* are skipped; if no runs remain, a
        placeholder all-error frame is synthesized so every evaluator name
        still appears in the output.
        """
        df = pd.DataFrame.from_records(
            [
                {
                    "evaluator": run.name,
                    "error": run.error,
                    "score": run.result.score if run.result else None,
                    "label": run.result.label if run.result else None,
                }
                for run in eval_runs
                if run is not None
            ]
        )
        if df.empty:
            df = pd.DataFrame.from_records(
                [
                    {"evaluator": name, "error": True, "score": None, "label": None}
                    for name in params.eval_names
                ]
            )
        # Only include aggregate columns for facets that actually occur.
        has_error = bool(df.loc[:, "error"].astype(bool).sum())
        has_score = bool(df.loc[:, "score"].dropna().count())
        has_label = bool(df.loc[:, "label"].astype(bool).sum())
        agg = {
            **(
                dict(n_errors=("error", "count"), top_error=("error", _top_string))
                if has_error
                else {}
            ),
            **(dict(n_scores=("score", "count"), avg_score=("score", "mean")) if has_score else {}),
            **(
                dict(
                    n_labels=("label", "count"),
                    # Two most common labels as {label: count}, or None if empty.
                    top_2_labels=(
                        "label",
                        lambda s: (dict(Counter(s.dropna()).most_common(2)) or None),
                    ),
                )
                if has_label
                else {}
            ),
        }
        stats = (
            df.groupby("evaluator").agg(**agg)  # type: ignore[call-overload]
            if agg
            else pd.DataFrame()
        )
        # Left-join against the full (sorted) evaluator list so evaluators
        # without aggregates still get a row, each carrying the expected run
        # count `n`.
        sorted_eval_names = sorted(params.eval_names)
        eval_names = pd.DataFrame(
            {
                "evaluator": sorted_eval_names,
                "n": [params.exp_params.count] * len(sorted_eval_names),
            }
        ).set_index("evaluator")
        stats = pd.concat([eval_names, stats], axis=1).reset_index()
        # Bypass the NotImplementedError guard in __new__, then run the
        # dataclass __init__ directly.
        summary: EvaluationSummary = object.__new__(cls)
        summary.__init__(stats=stats)  # type: ignore[misc]
        return summary

    @classmethod
    def __new__(cls, *args: Any, **kwargs: Any) -> Any:
        # Direct instantiation by users is discouraged.
        raise NotImplementedError

    @classmethod
    def __init_subclass__(cls, **kwargs: Any) -> None:
        # Direct sub-classing by users is discouraged.
        raise NotImplementedError
436
+
437
+
438
@dataclass(frozen=True)
class TaskSummary(_HasStats):
    """
    Summary statistics of experiment task executions.

    **Users should not instantiate this object directly.**
    """

    _title: str = "Tasks Summary"

    @classmethod
    def from_task_runs(
        cls,
        params: ExperimentParameters,
        task_runs: Iterable[Optional[ExperimentRun]],
    ) -> "TaskSummary":
        """Summarize task runs: expected vs. actual run counts and error stats.

        None entries in *task_runs* are skipped.
        """
        df = pd.DataFrame.from_records(
            [
                {
                    "example_id": run.dataset_example_id,
                    "error": run.error,
                }
                for run in task_runs
                if run is not None
            ]
        )
        n_runs = len(df)
        # Guard: column access on an empty frame would raise KeyError.
        n_errors = 0 if df.empty else df.loc[:, "error"].astype(bool).sum()
        record = {
            "n_examples": params.count,
            "n_runs": n_runs,
            "n_errors": n_errors,
            # Most common error string, only when errors occurred.
            **(dict(top_error=_top_string(df.loc[:, "error"])) if n_errors else {}),
        }
        stats = pd.DataFrame.from_records([record])
        # Bypass the NotImplementedError guard in __new__, then run the
        # dataclass __init__ directly.
        summary: TaskSummary = object.__new__(cls)
        summary.__init__(stats=stats)  # type: ignore[misc]
        return summary

    @classmethod
    def __new__(cls, *args: Any, **kwargs: Any) -> Any:
        # Direct instantiation by users is discouraged.
        raise NotImplementedError

    @classmethod
    def __init_subclass__(cls, **kwargs: Any) -> None:
        # Direct sub-classing by users is discouraged.
        raise NotImplementedError
486
+
487
+
488
def _top_string(s: "pd.Series[Any]", length: int = 100) -> Optional[str]:
    """Return the most frequent string in *s*, truncated to *length* chars.

    NaN values are ignored; returns None when nothing remains.
    """
    counts = s.dropna().str.slice(0, length).value_counts()
    if counts.empty:
        return None
    top = counts.sort_values(ascending=False).index[0]
    return cast(str, top)
492
+
493
+
494
@dataclass(frozen=True)
class RanExperiment(Experiment):
    """
    An experiment that has been run.

    **Users should not instantiate this object directly.**
    """

    params: ExperimentParameters = field(repr=False)
    dataset: Dataset = field(repr=False)
    runs: Mapping[ExperimentRunId, ExperimentRun] = field(repr=False)
    task_summary: TaskSummary = field(repr=False)
    eval_runs: Tuple[ExperimentEvaluationRun, ...] = field(repr=False, default=())
    eval_summaries: Tuple[EvaluationSummary, ...] = field(repr=False, default=())

    @property
    def url(self) -> str:
        # Link to this experiment in the Phoenix UI.
        return get_experiment_url(dataset_id=self.dataset.id, experiment_id=self.id)

    @property
    def info(self) -> str:
        return f"🔗 View this experiment: {self.url}"

    def __post_init__(self) -> None:
        # Wrap each run in a proxy that also exposes its example's
        # input/output/metadata, when the example is present in the dataset.
        runs = {
            id_: (
                _ExperimentRunWithExample(run, example)
                if (example := self.dataset.examples.get(run.dataset_example_id))
                else run
            )
            for id_, run in self.runs.items()
        }
        # Frozen dataclass: assign through object.__setattr__.
        object.__setattr__(self, "runs", runs)

    def __len__(self) -> int:
        return len(self.runs)

    def __iter__(self) -> Iterator[ExperimentRun]:
        return iter(self.runs.values())

    @cached_property
    def _keys(self) -> Tuple[str, ...]:
        # Snapshot of run ids, for positional indexing.
        return tuple(self.runs.keys())

    @overload
    def __getitem__(self, key: int) -> ExperimentRun: ...
    @overload
    def __getitem__(self, key: slice) -> List[ExperimentRun]: ...
    def __getitem__(self, key: Union[int, slice]) -> Union[ExperimentRun, List[ExperimentRun]]:
        if isinstance(key, int):
            return self.runs[self._keys[key]]
        return [self.runs[k] for k in self._keys[key]]

    def get_evaluations(
        self,
        drop_empty_columns: bool = True,
    ) -> pd.DataFrame:
        """Evaluations as a DataFrame (indexed by run_id), joined with the
        task-run frame from `as_dataframe()`."""
        df = pd.DataFrame.from_records(
            [
                {
                    "run_id": run.experiment_run_id,
                    "name": run.name,
                    "error": run.error,
                    "score": run.result.score if run.result else None,
                    "label": run.result.label if run.result else None,
                    "explanation": run.result.explanation if run.result else None,
                }
                for run in self.eval_runs
            ]
        ).set_index("run_id")
        if drop_empty_columns:
            # Keep only columns with at least one truthy value.
            df = df.reindex([k for k, v in df.items() if v.astype(bool).any()], axis=1)
        return df.join(self.as_dataframe())

    def as_dataframe(self, drop_empty_columns: bool = True) -> pd.DataFrame:
        """Task runs as a DataFrame indexed by run_id, with each run's example
        sections deep-copied alongside its output/error."""
        df = pd.DataFrame.from_records(
            [
                {
                    "run_id": run.id,
                    "error": run.error,
                    "output": deepcopy(run.experiment_run_output.task_output),
                    "input": deepcopy((ex := self.dataset.examples[run.dataset_example_id]).input),
                    "expected": deepcopy(ex.output),
                    "metadata": deepcopy(ex.metadata),
                    "example_id": run.dataset_example_id,
                }
                for run in self.runs.values()
            ]
        ).set_index("run_id")
        if drop_empty_columns:
            # Keep only columns with at least one truthy value.
            return df.reindex([k for k, v in df.items() if v.astype(bool).any()], axis=1)
        return df

    def add(
        self,
        eval_summary: EvaluationSummary,
        *eval_runs: Optional[ExperimentEvaluationRun],
    ) -> "RanExperiment":
        """Return a copy with the given evaluation summary and runs appended.

        Uses `_replace` (not `dataclasses.replace`) because `__new__` raises.
        """
        return _replace(
            self,
            eval_runs=(*self.eval_runs, *filter(bool, eval_runs)),
            eval_summaries=(*self.eval_summaries, eval_summary),
        )

    def __str__(self) -> str:
        # Most recent evaluation summary first, task summary last; the UI link
        # is suppressed for dry runs (which have no server-side experiment).
        summaries = (*reversed(self.eval_summaries), self.task_summary)
        return (
            "\n"
            + ("" if self.id.startswith(DRY_RUN) else f"{self.info}\n\n")
            + "\n\n".join(map(str, summaries))
        )

    @classmethod
    def __new__(cls, *args: Any, **kwargs: Any) -> Any:
        # Direct instantiation by users is discouraged.
        raise NotImplementedError

    @classmethod
    def __init_subclass__(cls, **kwargs: Any) -> None:
        # Direct sub-classing by users is discouraged.
        raise NotImplementedError
615
+
616
+
617
def _asdict(dc: Any) -> Dict[str, Any]:
    """Shallow field dict of a dataclass (non-recursive `dataclasses.asdict`)."""
    return {f.name: getattr(dc, f.name) for f in fields(dc)}


T = TypeVar("T")


def _replace(obj: T, **kwargs: Any) -> T:
    """Copy *obj* with some fields replaced, like `dataclasses.replace`.

    Goes through `object.__new__` so it also works for classes whose
    `__new__` raises to discourage direct instantiation.
    """
    new_obj = object.__new__(obj.__class__)
    attrs = _asdict(obj)
    attrs.update(kwargs)
    new_obj.__init__(**attrs)  # type: ignore[misc]
    return new_obj
629
+
630
+
631
def _shorten(obj: Any, width: int = 50) -> Any:
    """Recursively truncate strings and long lists for compact display.

    Strings are shortened to *width* characters with a "..." placeholder;
    lists longer than two elements keep only their first two plus "...".
    Nested dicts/lists are processed recursively (at the default width).
    """
    if isinstance(obj, str):
        return textwrap.shorten(obj, width=width, placeholder="...")
    if isinstance(obj, dict):
        return {key: _shorten(val) for key, val in obj.items()}
    if isinstance(obj, list):
        shortened = [_shorten(item) for item in obj[:2]]
        if len(obj) > 2:
            shortened.append("...")
        return shortened
    return obj
641
+
642
+
643
def _make_read_only(obj: Any) -> Any:
    """Recursively wrap dicts and lists in read-only proxies.

    Strings and other scalars pass through unchanged (str is excluded
    explicitly since it is iterable but immutable).
    """
    if isinstance(obj, str):
        return obj
    if isinstance(obj, dict):
        return _ReadOnly({key: _make_read_only(val) for key, val in obj.items()})
    if isinstance(obj, list):
        return _ReadOnly([_make_read_only(item) for item in obj])
    return obj
651
+
652
+
653
class _ReadOnly(ObjectProxy):  # type: ignore[misc]
    """Proxy blocking in-place mutation of a wrapped dict/list.

    Reads delegate to the wrapped object via `ObjectProxy`; the common
    mutators raise. Copies intentionally "escape" the proxy, returning plain
    mutable containers.
    """

    def __setitem__(self, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError

    def __delitem__(self, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError

    def __iadd__(self, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError

    def pop(self, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError

    def append(self, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError

    def __copy__(self, *args: Any, **kwargs: Any) -> Any:
        # Returns a plain (mutable) shallow copy of the wrapped object.
        return copy(self.__wrapped__)

    def __deepcopy__(self, *args: Any, **kwargs: Any) -> Any:
        # Returns a plain (mutable) deep copy of the wrapped object.
        return deepcopy(self.__wrapped__)

    def __repr__(self) -> str:
        return repr(self.__wrapped__)

    def __str__(self) -> str:
        return str(self.__wrapped__)
680
+
681
+
682
class _ExperimentRunWithExample(ObjectProxy):  # type: ignore[misc]
    """An ExperimentRun proxy that also exposes its dataset example's fields.

    Attribute access falls through to the wrapped run via `ObjectProxy`; the
    properties below surface deep copies of the example's sections. The
    example itself is stored under `_self_example` (wrapt's convention for
    proxy-local state).
    """

    def __init__(self, wrapped: ExperimentRun, example: Example) -> None:
        super().__init__(wrapped)
        self._self_example = example

    @property
    def expected(self) -> ExampleOutput:
        # Alias for the example's output; deep-copied so callers can mutate.
        return deepcopy(self._self_example.output)

    @property
    def reference(self) -> ExampleOutput:
        # Same as `expected` — a second alias for the example's output.
        return deepcopy(self._self_example.output)

    @property
    def input(self) -> ExampleInput:
        return deepcopy(self._self_example.input)

    @property
    def metadata(self) -> ExampleMetadata:
        return deepcopy(self._self_example.metadata)

    def __repr__(self) -> str:
        """Pretty multi-line repr: ids, then error/output, then the example's
        aliased dicts (only non-empty ones), as shortened indented JSON."""
        spaces = " " * 4
        name = self.__class__.__name__
        identifiers = [
            f'{spaces}id="{self.id}",',
            f'{spaces}example_id="{self.dataset_example_id}",',
        ]
        outputs = [
            *([f'{spaces}error="{self.error}",'] if self.error else []),
            *(
                [
                    f"{spaces}{_blue('output')}="
                    + json.dumps(
                        _shorten(self.output),
                        ensure_ascii=False,
                        sort_keys=True,
                        indent=len(spaces),
                    )
                    # Re-indent and strip quotes around the "..." placeholder.
                    .replace("\n", f"\n{spaces}")
                    .replace(' "..."\n', " ...\n")
                ]
                if not self.error
                else []
            ),
        ]
        dicts = [
            spaces
            + f"{_blue(alias)}={{"
            + (f" # {comment}" if comment else "")
            # [1:] drops the leading "{" since it was emitted above (so the
            # trailing comment can sit on the opening-brace line).
            + json.dumps(
                _shorten(value),
                ensure_ascii=False,
                sort_keys=True,
                indent=len(spaces),
            )[1:]
            .replace("\n", f"\n{spaces}")
            .replace(' "..."\n', " ...\n")
            + ","
            for alias, value, comment in (
                ("expected", self.expected, f"alias for the example.{_blue('output')} dict"),
                ("reference", self.reference, f"alias for the example.{_blue('output')} dict"),
                ("input", self.input, f"alias for the example.{_blue('input')} dict"),
                ("metadata", self.metadata, f"alias for the example.{_blue('metadata')} dict"),
            )
            if value
        ]
        return "\n".join([f"{name}(", *identifiers, *outputs, *dicts, ")"])
750
+
751
+
752
def _blue(text: str) -> str:
    """Wrap *text* in ANSI bold + bright-blue escape codes (reset at end)."""
    return "\033[1m\033[94m" + text + "\033[0m"