arize-phoenix 4.4.4rc4__py3-none-any.whl → 4.4.4rc5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of arize-phoenix might be problematic.

Files changed (31)
  1. {arize_phoenix-4.4.4rc4.dist-info → arize_phoenix-4.4.4rc5.dist-info}/METADATA +2 -2
  2. {arize_phoenix-4.4.4rc4.dist-info → arize_phoenix-4.4.4rc5.dist-info}/RECORD +30 -28
  3. phoenix/datasets/evaluators/code_evaluators.py +25 -53
  4. phoenix/datasets/evaluators/llm_evaluators.py +63 -32
  5. phoenix/datasets/evaluators/utils.py +292 -0
  6. phoenix/datasets/experiments.py +147 -82
  7. phoenix/datasets/tracing.py +19 -0
  8. phoenix/datasets/types.py +18 -52
  9. phoenix/db/insertion/dataset.py +19 -16
  10. phoenix/db/migrations/versions/10460e46d750_datasets.py +2 -2
  11. phoenix/db/models.py +8 -3
  12. phoenix/server/api/context.py +2 -0
  13. phoenix/server/api/dataloaders/__init__.py +2 -0
  14. phoenix/server/api/dataloaders/experiment_run_counts.py +42 -0
  15. phoenix/server/api/helpers/dataset_helpers.py +8 -7
  16. phoenix/server/api/input_types/ClearProjectInput.py +15 -0
  17. phoenix/server/api/mutations/project_mutations.py +9 -4
  18. phoenix/server/api/routers/v1/datasets.py +146 -42
  19. phoenix/server/api/routers/v1/experiment_evaluations.py +1 -0
  20. phoenix/server/api/routers/v1/experiment_runs.py +2 -2
  21. phoenix/server/api/types/Experiment.py +5 -0
  22. phoenix/server/api/types/ExperimentRun.py +1 -1
  23. phoenix/server/api/types/ExperimentRunAnnotation.py +1 -1
  24. phoenix/server/app.py +2 -0
  25. phoenix/server/static/index.js +610 -564
  26. phoenix/session/client.py +124 -2
  27. phoenix/version.py +1 -1
  28. phoenix/datasets/evaluators/_utils.py +0 -13
  29. {arize_phoenix-4.4.4rc4.dist-info → arize_phoenix-4.4.4rc5.dist-info}/WHEEL +0 -0
  30. {arize_phoenix-4.4.4rc4.dist-info → arize_phoenix-4.4.4rc5.dist-info}/licenses/IP_NOTICE +0 -0
  31. {arize_phoenix-4.4.4rc4.dist-info → arize_phoenix-4.4.4rc5.dist-info}/licenses/LICENSE +0 -0
phoenix/datasets/experiments.py CHANGED
@@ -8,11 +8,11 @@ from itertools import product
 from typing import (
     Any,
     Awaitable,
-    Callable,
-    Coroutine,
+    Dict,
     Iterable,
     Mapping,
     Optional,
+    Sequence,
     Tuple,
     Type,
     Union,
@@ -42,20 +42,23 @@ from phoenix.config import (
     get_env_host,
     get_env_port,
 )
+from phoenix.datasets.evaluators.utils import (
+    Evaluator,
+    EvaluatorName,
+    ExperimentEvaluator,
+    create_evaluator,
+)
 from phoenix.datasets.tracing import capture_spans
 from phoenix.datasets.types import (
-    CanAsyncEvaluate,
-    CanEvaluate,
     Dataset,
     EvaluationResult,
     Example,
     Experiment,
     ExperimentEvaluationRun,
-    ExperimentEvaluator,
     ExperimentResult,
     ExperimentRun,
     ExperimentRunId,
-    JSONSerializable,
+    ExperimentTask,
     TestCase,
 )
 from phoenix.evals.executors import get_executor_on_sync_context
@@ -65,11 +68,6 @@ from phoenix.session.session import active_session
 from phoenix.trace.attributes import flatten
 from phoenix.utilities.json import jsonify
 
-ExperimentTask: TypeAlias = Union[
-    Callable[[Example], JSONSerializable],
-    Callable[[Example], Coroutine[None, None, JSONSerializable]],
-]
-
 
 def _get_base_url() -> str:
     host = get_env_host()
@@ -98,10 +96,22 @@ def _get_dataset_experiments_url(*, dataset_id: str) -> str:
     return f"{_get_web_base_url()}datasets/{dataset_id}/experiments"
 
 
-def _phoenix_client() -> httpx.Client:
+def _phoenix_clients() -> Tuple[httpx.Client, httpx.AsyncClient]:
     headers = get_env_client_headers()
-    client = httpx.Client(base_url=_get_base_url(), headers=headers)
-    return client
+    return httpx.Client(
+        base_url=_get_base_url(),
+        headers=headers,
+    ), httpx.AsyncClient(
+        base_url=_get_base_url(),
+        headers=headers,
+    )
+
+
+Evaluators: TypeAlias = Union[
+    ExperimentEvaluator,
+    Sequence[ExperimentEvaluator],
+    Mapping[EvaluatorName, ExperimentEvaluator],
+]
 
 
 def run_experiment(
@@ -111,16 +121,17 @@ def run_experiment(
     experiment_name: Optional[str] = None,
     experiment_description: Optional[str] = None,
     experiment_metadata: Optional[Mapping[str, Any]] = None,
-    evaluators: Optional[Union[ExperimentEvaluator, Iterable[ExperimentEvaluator]]] = None,
+    evaluators: Optional[Evaluators] = None,
     rate_limit_errors: Optional[Union[Type[BaseException], Tuple[Type[BaseException], ...]]] = None,
 ) -> Experiment:
     # Add this to the params once supported in the UI
     repetitions = 1
     assert repetitions > 0, "Must run the experiment at least once."
+    evaluators_by_name = _evaluators_by_name(evaluators)
 
-    client = _phoenix_client()
+    sync_client, async_client = _phoenix_clients()
 
-    experiment_response = client.post(
+    experiment_response = sync_client.post(
         f"/v1/datasets/{dataset.id}/experiments",
         json={
             "version-id": dataset.version_id,
@@ -141,8 +152,8 @@
         SimpleSpanProcessor(OTLPSpanExporter(urljoin(f"{_get_base_url()}", "v1/traces")))
     )
     tracer = tracer_provider.get_tracer(__name__)
-    root_span_name = f"Task: {task.__qualname__}"
-    root_span_kind = CHAIN.value
+    root_span_name = f"Task: {_get_task_name(task)}"
+    root_span_kind = CHAIN
 
     dataset_experiments_url = _get_dataset_experiments_url(dataset_id=dataset.id)
     experiment_compare_url = _get_experiment_url(dataset_id=dataset.id, experiment_id=experiment_id)
@@ -207,6 +218,10 @@ def run_experiment(
             error=repr(error) if error else None,
             trace_id=_str_trace_id(span.get_span_context().trace_id),  # type: ignore[no-untyped-call]
         )
+        resp = sync_client.post(
+            f"/v1/experiments/{experiment_id}/runs", json=jsonify(experiment_run)
+        )
+        resp.raise_for_status()
         return experiment_run
 
     async def async_run_experiment(test_case: TestCase) -> ExperimentRun:
@@ -257,6 +272,10 @@ def run_experiment(
             error=repr(error) if error else None,
             trace_id=_str_trace_id(span.get_span_context().trace_id),  # type: ignore[no-untyped-call]
         )
+        resp = await async_client.post(
+            f"/v1/experiments/{experiment_id}/runs", json=jsonify(experiment_run)
+        )
+        resp.raise_for_status()
         return experiment_run
 
     rate_limited_sync_run_experiment = functools.reduce(
@@ -279,12 +298,7 @@ def run_experiment(
         TestCase(example=ex, repetition_number=rep)
         for ex, rep in product(dataset.examples, range(1, repetitions + 1))
     ]
-    experiment_payloads, _execution_details = executor.run(test_cases)
-    for payload in experiment_payloads:
-        if payload is not None:
-            resp = client.post(f"/v1/experiments/{experiment_id}/runs", json=jsonify(payload))
-            resp.raise_for_status()
-
+    _, _execution_details = executor.run(test_cases)
     experiment = Experiment(
         id=experiment_id,
         dataset_id=dataset.id,
@@ -293,26 +307,34 @@ def run_experiment(
     )
 
     print("✅ Task runs completed.")
-    print("🧠 Evaluation started.")
 
-    if evaluators is not None:
-        _evaluate_experiment(experiment, evaluators, dataset.examples, client)
+    if evaluators_by_name:
+        _evaluate_experiment(
+            experiment,
+            evaluators=evaluators_by_name,
+            dataset_examples=dataset.examples,
+            clients=(sync_client, async_client),
+        )
 
     return experiment
 
 
 def evaluate_experiment(
     experiment: Experiment,
-    evaluators: Union[ExperimentEvaluator, Iterable[ExperimentEvaluator]],
+    evaluators: Union[
+        ExperimentEvaluator,
+        Sequence[ExperimentEvaluator],
+        Mapping[EvaluatorName, ExperimentEvaluator],
+    ],
 ) -> None:
-    client = _phoenix_client()
+    sync_client, async_client = _phoenix_clients()
     dataset_id = experiment.dataset_id
     dataset_version_id = experiment.dataset_version_id
 
     dataset_examples = [
         Example.from_dict(ex)
         for ex in (
-            client.get(
+            sync_client.get(
                 f"/v1/datasets/{dataset_id}/examples",
                 params={"version-id": str(dataset_version_id)},
             )
@@ -321,26 +343,29 @@ def evaluate_experiment(
             .get("examples", [])
         )
     ]
-    _evaluate_experiment(experiment, evaluators, dataset_examples, client)
-
-
-ExperimentEvaluatorName: TypeAlias = str
+    _evaluate_experiment(
+        experiment,
+        evaluators=evaluators,
+        dataset_examples=dataset_examples,
+        clients=(sync_client, async_client),
+    )
 
 
 def _evaluate_experiment(
     experiment: Experiment,
-    evaluators: Union[ExperimentEvaluator, Iterable[ExperimentEvaluator]],
+    *,
+    evaluators: Evaluators,
     dataset_examples: Iterable[Example],
-    client: httpx.Client,
+    clients: Tuple[httpx.Client, httpx.AsyncClient],
 ) -> None:
-    if isinstance(evaluators, (CanEvaluate, CanAsyncEvaluate)):
-        evaluators = [evaluators]
-
+    evaluators_by_name = _evaluators_by_name(evaluators)
+    if not evaluators_by_name:
+        raise ValueError("Must specify at least one Evaluator")
     experiment_id = experiment.id
-
+    sync_client, async_client = clients
     experiment_runs = [
         ExperimentRun.from_dict(exp_run)
-        for exp_run in client.get(f"/v1/experiments/{experiment_id}/runs").json()
+        for exp_run in sync_client.get(f"/v1/experiments/{experiment_id}/runs").json()
     ]
 
     # not all dataset examples have associated experiment runs, so we need to pair them up
@@ -350,9 +375,9 @@ def _evaluate_experiment(
         example = examples_by_id.get(exp_run.dataset_example_id)
         if example:
             example_run_pairs.append((deepcopy(example), exp_run))
-    evaluation_inputs = [
-        (example, run, evaluator.name, evaluator)
-        for (example, run), evaluator in product(example_run_pairs, evaluators)
+    evaluation_input = [
+        (example, run, evaluator)
+        for (example, run), evaluator in product(example_run_pairs, evaluators_by_name.values())
     ]
 
     project_name = "evaluators"
@@ -362,36 +387,34 @@ def _evaluate_experiment(
         SimpleSpanProcessor(OTLPSpanExporter(urljoin(f"{_get_base_url()}", "v1/traces")))
     )
     tracer = tracer_provider.get_tracer(__name__)
-    root_span_kind = "EVALUATOR"
+    root_span_kind = EVALUATOR
 
     def sync_evaluate_run(
-        obj: Tuple[Example, ExperimentRun, ExperimentEvaluatorName, ExperimentEvaluator],
+        obj: Tuple[Example, ExperimentRun, Evaluator],
     ) -> ExperimentEvaluationRun:
-        example, experiment_run, name, evaluator = obj
+        example, experiment_run, evaluator = obj
         result: Optional[EvaluationResult] = None
         error: Optional[BaseException] = None
         status = Status(StatusCode.OK)
-        root_span_name = f"Evaluation: {name}"
+        root_span_name = f"Evaluation: {evaluator.name}"
         with ExitStack() as stack:
             span: Span = stack.enter_context(
                 tracer.start_as_current_span(root_span_name, context=Context())
             )
             stack.enter_context(capture_spans(resource))
             try:
-                # Do not use keyword arguments, which can fail at runtime
-                # even when function obeys protocol, because keyword arguments
-                # are implementation details.
-                if not isinstance(evaluator, CanEvaluate):
-                    raise RuntimeError("Task is async but running in sync context")
-                _output = evaluator.evaluate(example, experiment_run)
-                if isinstance(_output, Awaitable):
-                    raise RuntimeError("Task is async but running in sync context")
-                result = _output
+                result = evaluator.evaluate(
+                    output=None if experiment_run.output is None else experiment_run.output.result,
+                    expected=example.output,
+                    input=example.input,
+                    metadata=example.metadata,
+                )
             except BaseException as exc:
                 span.record_exception(exc)
                 status = Status(StatusCode.ERROR, f"{type(exc).__name__}: {exc}")
                 error = exc
-            span.set_attributes(dict(flatten(jsonify(result), recurse_on_sequence=True)))
+            if result:
+                span.set_attributes(dict(flatten(jsonify(result), recurse_on_sequence=True)))
            span.set_attribute(OPENINFERENCE_SPAN_KIND, root_span_kind)
            span.set_status(status)
 
@@ -400,43 +423,41 @@ def _evaluate_experiment(
             start_time=_decode_unix_nano(cast(int, span.start_time)),
             end_time=_decode_unix_nano(cast(int, span.end_time)),
             name=evaluator.name,
-            annotator_kind=evaluator.annotator_kind,
+            annotator_kind=evaluator.kind,
             error=repr(error) if error else None,
             result=result,
             trace_id=_str_trace_id(span.get_span_context().trace_id),  # type: ignore[no-untyped-call]
         )
+        resp = sync_client.post("/v1/experiment_evaluations", json=jsonify(evaluator_payload))
+        resp.raise_for_status()
         return evaluator_payload
 
     async def async_evaluate_run(
-        obj: Tuple[Example, ExperimentRun, ExperimentEvaluatorName, ExperimentEvaluator],
+        obj: Tuple[Example, ExperimentRun, Evaluator],
     ) -> ExperimentEvaluationRun:
-        example, experiment_run, name, evaluator = obj
+        example, experiment_run, evaluator = obj
         result: Optional[EvaluationResult] = None
         error: Optional[BaseException] = None
         status = Status(StatusCode.OK)
-        root_span_name = f"Evaluation: {name}"
+        root_span_name = f"Evaluation: {evaluator.name}"
         with ExitStack() as stack:
             span: Span = stack.enter_context(
                 tracer.start_as_current_span(root_span_name, context=Context())
             )
             stack.enter_context(capture_spans(resource))
             try:
-                # Do not use keyword arguments, which can fail at runtime
-                # even when function obeys protocol, because keyword arguments
-                # are implementation details.
-                if isinstance(evaluator, CanAsyncEvaluate):
-                    result = await evaluator.async_evaluate(example, experiment_run)
-                else:
-                    _output = evaluator.evaluate(example, experiment_run)
-                    if isinstance(_output, Awaitable):
-                        result = await _output
-                    else:
-                        result = _output
+                result = await evaluator.async_evaluate(
+                    output=None if experiment_run.output is None else experiment_run.output.result,
+                    expected=example.output,
+                    input=example.input,
+                    metadata=example.metadata,
+                )
             except BaseException as exc:
                 span.record_exception(exc)
                 status = Status(StatusCode.ERROR, f"{type(exc).__name__}: {exc}")
                 error = exc
-            span.set_attributes(dict(flatten(jsonify(result), recurse_on_sequence=True)))
+            if result:
+                span.set_attributes(dict(flatten(jsonify(result), recurse_on_sequence=True)))
            span.set_attribute(OPENINFERENCE_SPAN_KIND, root_span_kind)
            span.set_status(status)
 
@@ -445,11 +466,15 @@ def _evaluate_experiment(
             start_time=_decode_unix_nano(cast(int, span.start_time)),
             end_time=_decode_unix_nano(cast(int, span.end_time)),
             name=evaluator.name,
-            annotator_kind=evaluator.annotator_kind,
+            annotator_kind=evaluator.kind,
             error=repr(error) if error else None,
             result=result,
             trace_id=_str_trace_id(span.get_span_context().trace_id),  # type: ignore[no-untyped-call]
         )
+        resp = await async_client.post(
+            "/v1/experiment_evaluations", json=jsonify(evaluator_payload)
+        )
+        resp.raise_for_status()
         return evaluator_payload
 
     executor = get_executor_on_sync_context(
@@ -460,11 +485,38 @@ def _evaluate_experiment(
         fallback_return_value=None,
         tqdm_bar_format=get_tqdm_progress_bar_formatter("running experiment evaluations"),
     )
-    evaluation_payloads, _execution_details = executor.run(evaluation_inputs)
-    for payload in evaluation_payloads:
-        if payload is not None:
-            resp = client.post("/v1/experiment_evaluations", json=jsonify(payload))
-            resp.raise_for_status()
+    print("🧠 Evaluation started.")
+    _, _execution_details = executor.run(evaluation_input)
+
+
+def _evaluators_by_name(obj: Optional[Evaluators]) -> Mapping[EvaluatorName, Evaluator]:
+    evaluators_by_name: Dict[EvaluatorName, Evaluator] = {}
+    if obj is None:
+        return evaluators_by_name
+    if isinstance(mapping := obj, Mapping):
+        for name, value in mapping.items():
+            evaluator = (
+                create_evaluator(name=name)(value) if not isinstance(value, Evaluator) else value
+            )
+            name = evaluator.name
+            if name in evaluators_by_name:
+                raise ValueError(f"Two evaluators have the same name: {name}")
+            evaluators_by_name[name] = evaluator
+    elif isinstance(seq := obj, Sequence):
+        for value in seq:
+            evaluator = create_evaluator()(value) if not isinstance(value, Evaluator) else value
+            name = evaluator.name
+            if name in evaluators_by_name:
+                raise ValueError(f"Two evaluators have the same name: {name}")
+            evaluators_by_name[name] = evaluator
+    else:
+        assert not isinstance(obj, Mapping) and not isinstance(obj, Sequence)
+        evaluator = create_evaluator()(obj) if not isinstance(obj, Evaluator) else obj
+        name = evaluator.name
+        if name in evaluators_by_name:
+            raise ValueError(f"Two evaluators have the same name: {name}")
+        evaluators_by_name[name] = evaluator
+    return evaluators_by_name
 
 
 def _str_trace_id(id_: int) -> str:
@@ -475,11 +527,24 @@ def _decode_unix_nano(time_unix_nano: int) -> datetime:
     return datetime.fromtimestamp(time_unix_nano / 1e9, tz=timezone.utc)
 
 
+def _get_task_name(task: ExperimentTask) -> str:
+    """
+    Makes a best-effort attempt to get the name of the task.
+    """
+
+    if isinstance(task, functools.partial):
+        return task.func.__qualname__
+    if hasattr(task, "__qualname__"):
+        return task.__qualname__
+    return str(task)
+
+
 INPUT_VALUE = SpanAttributes.INPUT_VALUE
 OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE
 INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE
 OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE
 OPENINFERENCE_SPAN_KIND = SpanAttributes.OPENINFERENCE_SPAN_KIND
 
-CHAIN = OpenInferenceSpanKindValues.CHAIN
+CHAIN = OpenInferenceSpanKindValues.CHAIN.value
+EVALUATOR = OpenInferenceSpanKindValues.EVALUATOR.value
 JSON = OpenInferenceMimeTypeValues.JSON
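
Note: judging from the new `Evaluators` alias and the `_evaluators_by_name` normalization above, `run_experiment` now accepts a single evaluator, a sequence of evaluators, or a mapping of evaluator names to evaluators, with bare callables wrapped via `create_evaluator`. A minimal call-site sketch follows; the `dataset`, task, and evaluator are hypothetical, and the assumption that a wrapped callable receives the task output is inferred from this diff, not confirmed by it:

```python
from phoenix.datasets.experiments import run_experiment

# `dataset` is assumed to be a Dataset obtained elsewhere (e.g. via the Phoenix client).

def echo_task(example):
    # Hypothetical task: takes an Example, returns JSON-serializable output.
    return {"text": example.input.get("question", "")}

def not_empty(output):
    # Hypothetical evaluator as a bare callable; the diff suggests it is wrapped
    # by create_evaluator() and keyed by an inferred name.
    return bool(output and output.get("text"))

# All three forms should normalize to the same name -> Evaluator mapping:
run_experiment(dataset, echo_task, evaluators=not_empty)                 # single evaluator
run_experiment(dataset, echo_task, evaluators=[not_empty])               # sequence
run_experiment(dataset, echo_task, evaluators={"not_empty": not_empty})  # mapping
```
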
phoenix/datasets/tracing.py CHANGED
@@ -12,12 +12,22 @@ from wrapt import apply_patch, resolve_path, wrap_function_wrapper
 
 
 class SpanModifier:
+    """
+    A class that modifies spans with the specified resource attributes.
+    """
+
     __slots__ = ("_resource",)
 
     def __init__(self, resource: Resource) -> None:
         self._resource = resource
 
     def modify_resource(self, span: ReadableSpan) -> None:
+        """
+        Takes a span and merges in the resource attributes specified in the constructor.
+
+        Args:
+            span: ReadableSpan: the span to modify
+        """
         if (ctx := span._context) is None or ctx.span_id == INVALID_TRACE_ID:
             return
         span._resource = span._resource.merge(self._resource)
@@ -59,6 +69,15 @@ def _monkey_patch_span_init() -> Iterator[None]:
 
 @contextmanager
 def capture_spans(resource: Resource) -> Iterator[SpanModifier]:
+    """
+    A context manager that captures spans and modifies them with the specified resources.
+
+    Args:
+        resource: Resource: The resource to merge into the spans created within the context.
+
+    Returns:
+        modifier: Iterator[SpanModifier]: The span modifier that is active within the context.
+    """
     modifier = SpanModifier(resource)
     with _monkey_patch_span_init():
         token = _ACTIVE_MODIFIER.set(modifier)
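
Note: the docstrings added above describe `capture_spans` as a context manager that merges the given resource attributes into spans created inside the context. A short usage sketch under that reading (the attribute values are illustrative):

```python
from opentelemetry import trace
from opentelemetry.sdk.resources import Resource

from phoenix.datasets.tracing import capture_spans

resource = Resource({"experiment_id": "exp-123"})  # illustrative attributes
tracer = trace.get_tracer(__name__)

with capture_spans(resource):
    # Spans started here should come out with the resource attributes merged in.
    with tracer.start_as_current_span("task"):
        pass
```
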
phoenix/datasets/types.py CHANGED
@@ -2,22 +2,27 @@ from __future__ import annotations
 
 from dataclasses import dataclass, field
 from datetime import datetime
-from types import MappingProxyType
+from enum import Enum
 from typing import (
-    TYPE_CHECKING,
     Any,
+    Awaitable,
+    Callable,
     Dict,
     List,
     Mapping,
     Optional,
-    Protocol,
     Sequence,
     Union,
-    runtime_checkable,
 )
 
 from typing_extensions import TypeAlias
 
+
+class AnnotatorKind(Enum):
+    CODE = "CODE"
+    LLM = "LLM"
+
+
 JSONSerializable: TypeAlias = Optional[Union[Dict[str, Any], List[Any], str, int, float, bool]]
 
 ExperimentId: TypeAlias = str
@@ -28,6 +33,8 @@ RepetitionNumber: TypeAlias = int
 ExperimentRunId: TypeAlias = str
 TraceId: TypeAlias = str
 
+TaskOutput: TypeAlias = JSONSerializable
+
 
 @dataclass(frozen=True)
 class Example:
@@ -35,7 +42,7 @@ class Example:
     updated_at: datetime
     input: Mapping[str, JSONSerializable]
     output: Mapping[str, JSONSerializable]
-    metadata: Mapping[str, JSONSerializable] = field(default_factory=lambda: MappingProxyType({}))
+    metadata: Mapping[str, JSONSerializable] = field(default_factory=dict)
 
     @classmethod
     def from_dict(cls, obj: Mapping[str, Any]) -> Example:
@@ -71,7 +78,7 @@ class Experiment:
 
 @dataclass(frozen=True)
 class ExperimentResult:
-    result: JSONSerializable
+    result: TaskOutput
 
     @classmethod
     def from_dict(cls, obj: Optional[Mapping[str, Any]]) -> Optional[ExperimentResult]:
@@ -116,7 +123,7 @@ class EvaluationResult:
     score: Optional[float] = None
     label: Optional[str] = None
     explanation: Optional[str] = None
-    metadata: Mapping[str, JSONSerializable] = field(default_factory=lambda: MappingProxyType({}))
+    metadata: Mapping[str, JSONSerializable] = field(default_factory=dict)
 
     @classmethod
     def from_dict(cls, obj: Optional[Mapping[str, Any]]) -> Optional[EvaluationResult]:
@@ -165,48 +172,7 @@ class ExperimentEvaluationRun:
             ValueError("Must specify either result or error")
 
 
-class _HasName(Protocol):
-    name: str
-
-
-class _HasKind(Protocol):
-    @property
-    def annotator_kind(self) -> str: ...
-
-
-@runtime_checkable
-class CanEvaluate(_HasName, _HasKind, Protocol):
-    def evaluate(
-        self,
-        example: Example,
-        experiment_run: ExperimentRun,
-    ) -> EvaluationResult: ...
-
-
-@runtime_checkable
-class CanAsyncEvaluate(_HasName, _HasKind, Protocol):
-    async def async_evaluate(
-        self,
-        example: Example,
-        experiment_run: ExperimentRun,
-    ) -> EvaluationResult: ...
-
-
-ExperimentEvaluator: TypeAlias = Union[CanEvaluate, CanAsyncEvaluate]
-
-
-# Someday we'll do type checking in unit tests.
-if TYPE_CHECKING:
-
-    class _EvaluatorDummy:
-        annotator_kind: str
-        name: str
-
-        def evaluate(self, _: Example, __: ExperimentRun) -> EvaluationResult:
-            raise NotImplementedError
-
-        async def async_evaluate(self, _: Example, __: ExperimentRun) -> EvaluationResult:
-            raise NotImplementedError
-
-    _: ExperimentEvaluator
-    _ = _EvaluatorDummy()
+ExperimentTask: TypeAlias = Union[
+    Callable[[Example], TaskOutput],
+    Callable[[Example], Awaitable[TaskOutput]],
+]
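
Note: with the `CanEvaluate`/`CanAsyncEvaluate` protocols removed, `ExperimentTask` is now defined here as a plain alias over sync or async callables of an `Example`. A short sketch of both shapes (the task bodies are illustrative):

```python
from phoenix.datasets.types import Example, ExperimentTask, TaskOutput

def sync_task(example: Example) -> TaskOutput:
    # Any JSON-serializable value is a valid TaskOutput.
    return {"answer": example.input.get("question")}

async def async_task(example: Example) -> TaskOutput:
    return [1, 2, 3]

task: ExperimentTask = sync_task  # async_task satisfies the alias as well
```
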
phoenix/db/insertion/dataset.py CHANGED
@@ -1,17 +1,17 @@
 import logging
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from datetime import datetime, timezone
 from enum import Enum
 from itertools import chain
 from typing import (
     Any,
     Awaitable,
+    Dict,
     FrozenSet,
     Iterable,
     Iterator,
     Mapping,
     Optional,
-    Sequence,
     Union,
     cast,
 )
@@ -30,7 +30,16 @@ DatasetVersionId: TypeAlias = int
 DatasetExampleId: TypeAlias = int
 DatasetExampleRevisionId: TypeAlias = int
 SpanRowId: TypeAlias = int
-Examples: TypeAlias = Iterable[Mapping[str, Any]]
+
+
+@dataclass(frozen=True)
+class ExampleContent:
+    input: Dict[str, Any] = field(default_factory=dict)
+    output: Dict[str, Any] = field(default_factory=dict)
+    metadata: Dict[str, Any] = field(default_factory=dict)
+
+
+Examples: TypeAlias = Iterable[ExampleContent]
 
 
 @dataclass(frozen=True)
@@ -149,14 +158,10 @@ async def add_dataset_examples(
     session: AsyncSession,
     name: str,
     examples: Union[Examples, Awaitable[Examples]],
-    input_keys: Sequence[str],
-    output_keys: Sequence[str],
-    metadata_keys: Sequence[str] = (),
     description: Optional[str] = None,
     metadata: Optional[Mapping[str, Any]] = None,
     action: DatasetAction = DatasetAction.CREATE,
 ) -> Optional[DatasetExampleAdditionEvent]:
-    keys = DatasetKeys(frozenset(input_keys), frozenset(output_keys), frozenset(metadata_keys))
     created_at = datetime.now(timezone.utc)
     dataset_id: Optional[DatasetId] = None
     if action is DatasetAction.APPEND and name:
@@ -173,9 +178,7 @@
             created_at=created_at,
         )
     except Exception:
-        logger.exception(
-            f"Fail to insert dataset: {input_keys=}, {output_keys=}, {metadata_keys=}"
-        )
+        logger.exception(f"Failed to insert dataset: {name=}")
         raise
     try:
         dataset_version_id = await insert_dataset_version(
@@ -184,7 +187,7 @@
             created_at=created_at,
         )
     except Exception:
-        logger.exception(f"Fail to insert dataset version for {dataset_id=}")
+        logger.exception(f"Failed to insert dataset version for {dataset_id=}")
         raise
     for example in (await examples) if isinstance(examples, Awaitable) else examples:
         try:
@@ -194,21 +197,21 @@
                 created_at=created_at,
             )
         except Exception:
-            logger.exception(f"Fail to insert dataset example for {dataset_id=}")
+            logger.exception(f"Failed to insert dataset example for {dataset_id=}")
             raise
         try:
             await insert_dataset_example_revision(
                 session=session,
                 dataset_version_id=dataset_version_id,
                 dataset_example_id=dataset_example_id,
-                input={key: example.get(key) for key in keys.input},
-                output={key: example.get(key) for key in keys.output},
-                metadata={key: example.get(key) for key in keys.metadata},
+                input=example.input,
+                output=example.output,
+                metadata=example.metadata,
                 created_at=created_at,
             )
         except Exception:
             logger.exception(
-                f"Fail to insert dataset example revision for {dataset_version_id=}, "
+                f"Failed to insert dataset example revision for {dataset_version_id=}, "
                 f"{dataset_example_id=}"
             )
             raise
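
Note: `add_dataset_examples` no longer takes `input_keys`/`output_keys`/`metadata_keys`; callers now appear to pass pre-structured `ExampleContent` records. A sketch of what an examples payload might look like (field values are illustrative):

```python
from phoenix.db.insertion.dataset import ExampleContent

examples = [
    ExampleContent(
        input={"question": "What is Phoenix?"},
        output={"answer": "An open-source LLM tracing and evaluation tool."},
        metadata={"source": "docs"},
    ),
    # All three fields default to empty dicts, so partial records are allowed.
    ExampleContent(input={"question": "What changed in 4.4.4rc5?"}),
]
```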