arize-phoenix 4.4.3__py3-none-any.whl → 4.4.4rc0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
Files changed (109)
  1. {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc0.dist-info}/METADATA +4 -4
  2. {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc0.dist-info}/RECORD +108 -55
  3. phoenix/__init__.py +0 -27
  4. phoenix/config.py +21 -7
  5. phoenix/core/model.py +25 -25
  6. phoenix/core/model_schema.py +64 -62
  7. phoenix/core/model_schema_adapter.py +27 -25
  8. phoenix/datasets/__init__.py +0 -0
  9. phoenix/datasets/evaluators.py +275 -0
  10. phoenix/datasets/experiments.py +469 -0
  11. phoenix/datasets/tracing.py +66 -0
  12. phoenix/datasets/types.py +212 -0
  13. phoenix/db/bulk_inserter.py +54 -14
  14. phoenix/db/insertion/dataset.py +234 -0
  15. phoenix/db/insertion/evaluation.py +6 -6
  16. phoenix/db/insertion/helpers.py +13 -2
  17. phoenix/db/migrations/types.py +29 -0
  18. phoenix/db/migrations/versions/10460e46d750_datasets.py +291 -0
  19. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +2 -28
  20. phoenix/db/models.py +230 -3
  21. phoenix/inferences/fixtures.py +23 -23
  22. phoenix/inferences/inferences.py +7 -7
  23. phoenix/inferences/validation.py +1 -1
  24. phoenix/server/api/context.py +16 -0
  25. phoenix/server/api/dataloaders/__init__.py +16 -0
  26. phoenix/server/api/dataloaders/dataset_example_revisions.py +100 -0
  27. phoenix/server/api/dataloaders/dataset_example_spans.py +43 -0
  28. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +85 -0
  29. phoenix/server/api/dataloaders/experiment_error_rates.py +43 -0
  30. phoenix/server/api/dataloaders/experiment_sequence_number.py +49 -0
  31. phoenix/server/api/dataloaders/project_by_name.py +31 -0
  32. phoenix/server/api/dataloaders/span_descendants.py +2 -3
  33. phoenix/server/api/dataloaders/span_projects.py +33 -0
  34. phoenix/server/api/dataloaders/trace_row_ids.py +39 -0
  35. phoenix/server/api/helpers/dataset_helpers.py +178 -0
  36. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
  37. phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
  38. phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
  39. phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
  40. phoenix/server/api/input_types/DatasetSort.py +17 -0
  41. phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
  42. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
  43. phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
  44. phoenix/server/api/input_types/DeleteExperimentsInput.py +9 -0
  45. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
  46. phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
  47. phoenix/server/api/mutations/__init__.py +13 -0
  48. phoenix/server/api/mutations/auth.py +11 -0
  49. phoenix/server/api/mutations/dataset_mutations.py +520 -0
  50. phoenix/server/api/mutations/experiment_mutations.py +65 -0
  51. phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +17 -14
  52. phoenix/server/api/mutations/project_mutations.py +42 -0
  53. phoenix/server/api/queries.py +503 -0
  54. phoenix/server/api/routers/v1/__init__.py +77 -2
  55. phoenix/server/api/routers/v1/dataset_examples.py +178 -0
  56. phoenix/server/api/routers/v1/datasets.py +861 -0
  57. phoenix/server/api/routers/v1/evaluations.py +4 -2
  58. phoenix/server/api/routers/v1/experiment_evaluations.py +65 -0
  59. phoenix/server/api/routers/v1/experiment_runs.py +108 -0
  60. phoenix/server/api/routers/v1/experiments.py +174 -0
  61. phoenix/server/api/routers/v1/spans.py +3 -1
  62. phoenix/server/api/routers/v1/traces.py +1 -4
  63. phoenix/server/api/schema.py +2 -303
  64. phoenix/server/api/types/AnnotatorKind.py +10 -0
  65. phoenix/server/api/types/Cluster.py +19 -19
  66. phoenix/server/api/types/CreateDatasetPayload.py +8 -0
  67. phoenix/server/api/types/Dataset.py +282 -63
  68. phoenix/server/api/types/DatasetExample.py +85 -0
  69. phoenix/server/api/types/DatasetExampleRevision.py +34 -0
  70. phoenix/server/api/types/DatasetVersion.py +14 -0
  71. phoenix/server/api/types/Dimension.py +30 -29
  72. phoenix/server/api/types/EmbeddingDimension.py +40 -34
  73. phoenix/server/api/types/Event.py +16 -16
  74. phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
  75. phoenix/server/api/types/Experiment.py +135 -0
  76. phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
  77. phoenix/server/api/types/ExperimentComparison.py +19 -0
  78. phoenix/server/api/types/ExperimentRun.py +91 -0
  79. phoenix/server/api/types/ExperimentRunAnnotation.py +57 -0
  80. phoenix/server/api/types/Inferences.py +80 -0
  81. phoenix/server/api/types/InferencesRole.py +23 -0
  82. phoenix/server/api/types/Model.py +43 -42
  83. phoenix/server/api/types/Project.py +26 -12
  84. phoenix/server/api/types/Span.py +78 -2
  85. phoenix/server/api/types/TimeSeries.py +6 -6
  86. phoenix/server/api/types/Trace.py +15 -4
  87. phoenix/server/api/types/UMAPPoints.py +1 -1
  88. phoenix/server/api/types/node.py +5 -111
  89. phoenix/server/api/types/pagination.py +10 -52
  90. phoenix/server/app.py +99 -49
  91. phoenix/server/main.py +49 -27
  92. phoenix/server/openapi/docs.py +3 -0
  93. phoenix/server/static/index.js +2246 -1368
  94. phoenix/server/templates/index.html +1 -0
  95. phoenix/services.py +15 -15
  96. phoenix/session/client.py +316 -21
  97. phoenix/session/session.py +47 -37
  98. phoenix/trace/exporter.py +14 -9
  99. phoenix/trace/fixtures.py +133 -7
  100. phoenix/trace/span_evaluations.py +3 -3
  101. phoenix/trace/trace_dataset.py +6 -6
  102. phoenix/utilities/json.py +61 -0
  103. phoenix/utilities/re.py +50 -0
  104. phoenix/version.py +1 -1
  105. phoenix/server/api/types/DatasetRole.py +0 -23
  106. {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc0.dist-info}/WHEEL +0 -0
  107. {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc0.dist-info}/licenses/IP_NOTICE +0 -0
  108. {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc0.dist-info}/licenses/LICENSE +0 -0
  109. /phoenix/server/api/{helpers.py → helpers/__init__.py} +0 -0
phoenix/datasets/types.py
@@ -0,0 +1,212 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass, field
+ from datetime import datetime
+ from types import MappingProxyType
+ from typing import (
+     TYPE_CHECKING,
+     Any,
+     Dict,
+     List,
+     Mapping,
+     Optional,
+     Protocol,
+     Sequence,
+     Union,
+     runtime_checkable,
+ )
+
+ from typing_extensions import TypeAlias
+
+ JSONSerializable: TypeAlias = Optional[Union[Dict[str, Any], List[Any], str, int, float, bool]]
+
+ ExperimentId: TypeAlias = str
+ DatasetId: TypeAlias = str
+ DatasetVersionId: TypeAlias = str
+ ExampleId: TypeAlias = str
+ RepetitionNumber: TypeAlias = int
+ ExperimentRunId: TypeAlias = str
+ TraceId: TypeAlias = str
+
+
+ @dataclass(frozen=True)
+ class Example:
+     id: ExampleId
+     updated_at: datetime
+     input: Mapping[str, JSONSerializable]
+     output: Mapping[str, JSONSerializable]
+     metadata: Mapping[str, JSONSerializable] = field(default_factory=lambda: MappingProxyType({}))
+
+     @classmethod
+     def from_dict(cls, obj: Mapping[str, Any]) -> Example:
+         return cls(
+             input=obj["input"],
+             output=obj["output"],
+             metadata=obj.get("metadata") or {},
+             id=obj["id"],
+             updated_at=obj["updated_at"],
+         )
+
+
+ @dataclass(frozen=True)
+ class Dataset:
+     id: DatasetId
+     version_id: DatasetVersionId
+     examples: Sequence[Example]
+
+
+ @dataclass(frozen=True)
+ class TestCase:
+     example: Example
+     repetition_number: RepetitionNumber
+
+
+ @dataclass(frozen=True)
+ class Experiment:
+     id: ExperimentId
+     dataset_id: DatasetId
+     dataset_version_id: DatasetVersionId
+     project_name: Optional[str] = None
+
+
+ @dataclass(frozen=True)
+ class ExperimentResult:
+     result: JSONSerializable
+
+     @classmethod
+     def from_dict(cls, obj: Optional[Mapping[str, Any]]) -> Optional[ExperimentResult]:
+         if not obj:
+             return None
+         return cls(result=obj["result"])
+
+
+ @dataclass(frozen=True)
+ class ExperimentRun:
+     start_time: datetime
+     end_time: datetime
+     experiment_id: ExperimentId
+     dataset_example_id: ExampleId
+     repetition_number: RepetitionNumber
+     output: Optional[ExperimentResult] = None
+     error: Optional[str] = None
+     id: Optional[ExperimentRunId] = None
+     trace_id: Optional[TraceId] = None
+
+     @classmethod
+     def from_dict(cls, obj: Mapping[str, Any]) -> ExperimentRun:
+         return cls(
+             start_time=obj["start_time"],
+             end_time=obj["end_time"],
+             experiment_id=obj["experiment_id"],
+             dataset_example_id=obj["dataset_example_id"],
+             repetition_number=obj.get("repetition_number") or 1,
+             output=ExperimentResult.from_dict(obj["output"]),
+             error=obj.get("error"),
+             id=obj.get("id"),
+             trace_id=obj.get("trace_id"),
+         )
+
+     def __post_init__(self) -> None:
+         if bool(self.output) == bool(self.error):
+             raise ValueError("Must specify either result or error")
+
+
+ @dataclass(frozen=True)
+ class EvaluationResult:
+     score: Optional[float] = None
+     label: Optional[str] = None
+     explanation: Optional[str] = None
+     metadata: Mapping[str, JSONSerializable] = field(default_factory=lambda: MappingProxyType({}))
+
+     @classmethod
+     def from_dict(cls, obj: Optional[Mapping[str, Any]]) -> Optional[EvaluationResult]:
+         if not obj:
+             return None
+         return cls(
+             score=obj.get("score"),
+             label=obj.get("label"),
+             explanation=obj.get("explanation"),
+             metadata=obj.get("metadata") or {},
+         )
+
+     def __post_init__(self) -> None:
+         if self.score is None and not self.label and not self.explanation:
+             raise ValueError("Must specify one of score, label, or explanation")
+
+
+ @dataclass(frozen=True)
+ class ExperimentEvaluationRun:
+     experiment_run_id: ExperimentRunId
+     start_time: datetime
+     end_time: datetime
+     name: str
+     annotator_kind: str
+     error: Optional[str] = None
+     result: Optional[EvaluationResult] = None
+     id: Optional[str] = None
+     trace_id: Optional[TraceId] = None
+
+     @classmethod
+     def from_dict(cls, obj: Mapping[str, Any]) -> ExperimentEvaluationRun:
+         return cls(
+             experiment_run_id=obj["experiment_run_id"],
+             start_time=obj["start_time"],
+             end_time=obj["end_time"],
+             name=obj["name"],
+             annotator_kind=obj["annotator_kind"],
+             error=obj.get("error"),
+             result=EvaluationResult.from_dict(obj.get("result")),
+             id=obj.get("id"),
+             trace_id=obj.get("trace_id"),
+         )
+
+     def __post_init__(self) -> None:
+         if bool(self.result) == bool(self.error):
+             raise ValueError("Must specify either result or error")
+
+
+ class _HasName(Protocol):
+     name: str
+
+
+ class _HasKind(Protocol):
+     @property
+     def annotator_kind(self) -> str: ...
+
+
+ @runtime_checkable
+ class CanEvaluate(_HasName, _HasKind, Protocol):
+     def evaluate(
+         self,
+         example: Example,
+         experiment_run: ExperimentRun,
+     ) -> EvaluationResult: ...
+
+
+ @runtime_checkable
+ class CanAsyncEvaluate(_HasName, _HasKind, Protocol):
+     async def async_evaluate(
+         self,
+         example: Example,
+         experiment_run: ExperimentRun,
+     ) -> EvaluationResult: ...
+
+
+ ExperimentEvaluator: TypeAlias = Union[CanEvaluate, CanAsyncEvaluate]
+
+
+ # Someday we'll do type checking in unit tests.
+ if TYPE_CHECKING:
+
+     class _EvaluatorDummy:
+         annotator_kind: str
+         name: str
+
+         def evaluate(self, _: Example, __: ExperimentRun) -> EvaluationResult:
+             raise NotImplementedError
+
+         async def async_evaluate(self, _: Example, __: ExperimentRun) -> EvaluationResult:
+             raise NotImplementedError
+
+     _: ExperimentEvaluator
+     _ = _EvaluatorDummy()
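
Because `CanEvaluate` and `CanAsyncEvaluate` are runtime-checkable structural protocols, any object exposing `name`, `annotator_kind`, and a matching `evaluate` method qualifies as an `ExperimentEvaluator`. A minimal sketch of a conforming evaluator follows; the `ExactMatch` class, its `annotator_kind` value, and the comparison logic are illustrative assumptions, not part of the package:

from phoenix.datasets.types import CanEvaluate, EvaluationResult, Example, ExperimentRun

class ExactMatch:
    name = "exact_match"     # hypothetical evaluator name
    annotator_kind = "CODE"  # assumed annotator kind

    def evaluate(self, example: Example, experiment_run: ExperimentRun) -> EvaluationResult:
        # Score 1.0 when the run's output equals the example's expected output.
        actual = experiment_run.output.result if experiment_run.output else None
        return EvaluationResult(score=float(actual == dict(example.output)))

# The protocol is @runtime_checkable, so structural conformance can be verified:
assert isinstance(ExactMatch(), CanEvaluate)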
phoenix/db/bulk_inserter.py
@@ -1,5 +1,6 @@
  import asyncio
  import logging
+ from asyncio import Queue
  from dataclasses import dataclass, field
  from datetime import datetime, timezone
  from itertools import islice
@@ -14,6 +15,7 @@ from typing import (
      Optional,
      Set,
      Tuple,
+     cast,
  )

  from cachetools import LRUCache
@@ -22,10 +24,11 @@ from typing_extensions import TypeAlias

  import phoenix.trace.v1 as pb
  from phoenix.db.insertion.evaluation import (
-     EvaluationInsertionResult,
+     EvaluationInsertionEvent,
      InsertEvaluationError,
      insert_evaluation,
  )
+ from phoenix.db.insertion.helpers import DataManipulation, DataManipulationEvent
  from phoenix.db.insertion.span import SpanInsertionEvent, insert_span
  from phoenix.server.api.dataloaders import CacheForDataLoaders
  from phoenix.trace.schemas import Span
@@ -46,23 +49,29 @@ class BulkInserter:
          db: Callable[[], AsyncContextManager[AsyncSession]],
          *,
          cache_for_dataloaders: Optional[CacheForDataLoaders] = None,
+         initial_batch_of_operations: Iterable[DataManipulation] = (),
          initial_batch_of_spans: Optional[Iterable[Tuple[Span, str]]] = None,
          initial_batch_of_evaluations: Optional[Iterable[pb.Evaluation]] = None,
          sleep: float = 0.1,
-         max_num_per_transaction: int = 1000,
+         max_ops_per_transaction: int = 1000,
+         max_queue_size: int = 1000,
          enable_prometheus: bool = False,
      ) -> None:
          """
          :param db: A function to initiate a new database session.
          :param initial_batch_of_spans: Initial batch of spans to insert.
          :param sleep: The time to sleep between bulk insertions
-         :param max_num_per_transaction: The maximum number of items to insert in a single
-             transaction. Multiple transactions will be used if there are more items in the batch.
+         :param max_ops_per_transaction: The maximum number of operations to dequeue from
+             the operations queue for each transaction.
+         :param max_queue_size: The maximum length of the operations queue.
+         :param enable_prometheus: Whether Prometheus is enabled.
          """
          self._db = db
          self._running = False
          self._sleep = sleep
-         self._max_num_per_transaction = max_num_per_transaction
+         self._max_ops_per_transaction = max_ops_per_transaction
+         self._operations: Optional[Queue[DataManipulation]] = None
+         self._max_queue_size = max_queue_size
          self._spans: List[Tuple[Span, str]] = (
              [] if initial_batch_of_spans is None else list(initial_batch_of_spans)
          )
@@ -81,27 +90,58 @@ class BulkInserter:

      async def __aenter__(
          self,
-     ) -> Tuple[Callable[[Span, str], Awaitable[None]], Callable[[pb.Evaluation], Awaitable[None]]]:
+     ) -> Tuple[
+         Callable[[Span, str], Awaitable[None]],
+         Callable[[pb.Evaluation], Awaitable[None]],
+         Callable[[DataManipulation], None],
+     ]:
          self._running = True
+         self._operations = Queue(maxsize=self._max_queue_size)
          self._task = asyncio.create_task(self._bulk_insert())
-         return self._queue_span, self._queue_evaluation
+         return (
+             self._queue_span,
+             self._queue_evaluation,
+             self._enqueue_operation,
+         )

      async def __aexit__(self, *args: Any) -> None:
+         self._operations = None
          self._running = False

+     def _enqueue_operation(self, operation: DataManipulation) -> None:
+         cast("Queue[DataManipulation]", self._operations).put_nowait(operation)
+
      async def _queue_span(self, span: Span, project_name: str) -> None:
          self._spans.append((span, project_name))

      async def _queue_evaluation(self, evaluation: pb.Evaluation) -> None:
          self._evaluations.append(evaluation)

+     async def _process_events(self, events: Iterable[Optional[DataManipulationEvent]]) -> None: ...
+
      async def _bulk_insert(self) -> None:
+         assert isinstance(self._operations, Queue)
          spans_buffer, evaluations_buffer = None, None
          # start first insert immediately if the inserter has not run recently
-         while self._spans or self._evaluations or self._running:
-             if not (self._spans or self._evaluations):
+         while self._running or not self._operations.empty() or self._spans or self._evaluations:
+             if self._operations.empty() and not (self._spans or self._evaluations):
                  await asyncio.sleep(self._sleep)
                  continue
+             ops_remaining, events = self._max_ops_per_transaction, []
+             async with self._db() as session:
+                 while ops_remaining and not self._operations.empty():
+                     ops_remaining -= 1
+                     op = await self._operations.get()
+                     try:
+                         async with session.begin_nested():
+                             events.append(await op(session))
+                     except Exception as e:
+                         if self._enable_prometheus:
+                             from phoenix.server.prometheus import BULK_LOADER_EXCEPTIONS
+
+                             BULK_LOADER_EXCEPTIONS.inc()
+                         logger.exception(str(e))
+             await self._process_events(events)
              # It's important to grab the buffers at the same time so there's
              # no race condition, since an eval insertion will fail if the span
              # it references doesn't exist. Grabbing the eval buffer later may
@@ -130,11 +170,11 @@ class BulkInserter:

      async def _insert_spans(self, spans: List[Tuple[Span, str]]) -> TransactionResult:
          transaction_result = TransactionResult()
-         for i in range(0, len(spans), self._max_num_per_transaction):
+         for i in range(0, len(spans), self._max_ops_per_transaction):
              try:
                  start = perf_counter()
                  async with self._db() as session:
-                     for span, project_name in islice(spans, i, i + self._max_num_per_transaction):
+                     for span, project_name in islice(spans, i, i + self._max_ops_per_transaction):
                          if self._enable_prometheus:
                              from phoenix.server.prometheus import BULK_LOADER_SPAN_INSERTIONS

@@ -169,16 +209,16 @@ class BulkInserter:

      async def _insert_evaluations(self, evaluations: List[pb.Evaluation]) -> TransactionResult:
          transaction_result = TransactionResult()
-         for i in range(0, len(evaluations), self._max_num_per_transaction):
+         for i in range(0, len(evaluations), self._max_ops_per_transaction):
              try:
                  start = perf_counter()
                  async with self._db() as session:
-                     for evaluation in islice(evaluations, i, i + self._max_num_per_transaction):
+                     for evaluation in islice(evaluations, i, i + self._max_ops_per_transaction):
                          if self._enable_prometheus:
                              from phoenix.server.prometheus import BULK_LOADER_EVALUATION_INSERTIONS

                              BULK_LOADER_EVALUATION_INSERTIONS.inc()
-                         result: Optional[EvaluationInsertionResult] = None
+                         result: Optional[EvaluationInsertionEvent] = None
                          try:
                              async with session.begin_nested():
                                  result = await insert_evaluation(session, evaluation)
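
With this change, `__aenter__` returns a third callable for enqueuing a `DataManipulation`: any async callable that accepts an `AsyncSession` and returns an optional `DataManipulationEvent`. Each dequeued operation runs in its own nested transaction, so one failure does not poison the rest of the batch. A sketch of a conforming operation; `delete_old_traces` and its query are hypothetical:

from sqlalchemy import delete
from sqlalchemy.ext.asyncio import AsyncSession

from phoenix.db import models

async def delete_old_traces(session: AsyncSession) -> None:
    # A DataManipulation: awaits DML against the session; returning None (no event) is allowed.
    await session.execute(delete(models.Trace).where(models.Trace.id < 100))

# Assuming the BulkInserter is entered as an async context manager:
# queue_span, queue_evaluation, enqueue_operation = await inserter.__aenter__()
# enqueue_operation(delete_old_traces)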
phoenix/db/insertion/dataset.py
@@ -0,0 +1,234 @@
+ import logging
+ from dataclasses import dataclass
+ from datetime import datetime, timezone
+ from enum import Enum
+ from itertools import chain
+ from typing import (
+     Any,
+     Awaitable,
+     FrozenSet,
+     Iterable,
+     Iterator,
+     Mapping,
+     Optional,
+     Sequence,
+     Union,
+     cast,
+ )
+
+ from sqlalchemy import insert, select
+ from sqlalchemy.ext.asyncio import AsyncSession
+ from typing_extensions import TypeAlias
+
+ from phoenix.db import models
+ from phoenix.db.insertion.helpers import DataManipulationEvent
+
+ logger = logging.getLogger(__name__)
+
+ DatasetId: TypeAlias = int
+ DatasetVersionId: TypeAlias = int
+ DatasetExampleId: TypeAlias = int
+ DatasetExampleRevisionId: TypeAlias = int
+ SpanRowId: TypeAlias = int
+ Examples: TypeAlias = Iterable[Mapping[str, Any]]
+
+
+ @dataclass(frozen=True)
+ class DatasetExampleAdditionEvent(DataManipulationEvent):
+     dataset_id: DatasetId
+
+
+ async def insert_dataset(
+     session: AsyncSession,
+     name: str,
+     description: Optional[str] = None,
+     metadata: Optional[Mapping[str, Any]] = None,
+     created_at: Optional[datetime] = None,
+ ) -> DatasetId:
+     id_ = await session.scalar(
+         insert(models.Dataset)
+         .values(
+             name=name,
+             description=description,
+             metadata_=metadata,
+             created_at=created_at,
+         )
+         .returning(models.Dataset.id)
+     )
+     return cast(DatasetId, id_)
+
+
+ async def insert_dataset_version(
+     session: AsyncSession,
+     dataset_id: DatasetId,
+     description: Optional[str] = None,
+     metadata: Optional[Mapping[str, Any]] = None,
+     created_at: Optional[datetime] = None,
+ ) -> DatasetVersionId:
+     id_ = await session.scalar(
+         insert(models.DatasetVersion)
+         .values(
+             dataset_id=dataset_id,
+             description=description,
+             metadata_=metadata,
+             created_at=created_at,
+         )
+         .returning(models.DatasetVersion.id)
+     )
+     return cast(DatasetVersionId, id_)
+
+
+ async def insert_dataset_example(
+     session: AsyncSession,
+     dataset_id: DatasetId,
+     span_rowid: Optional[SpanRowId] = None,
+     created_at: Optional[datetime] = None,
+ ) -> DatasetExampleId:
+     id_ = await session.scalar(
+         insert(models.DatasetExample)
+         .values(
+             dataset_id=dataset_id,
+             span_rowid=span_rowid,
+             created_at=created_at,
+         )
+         .returning(models.DatasetExample.id)
+     )
+     return cast(DatasetExampleId, id_)
+
+
+ class RevisionKind(Enum):
+     CREATE = "CREATE"
+     PATCH = "PATCH"
+     DELETE = "DELETE"
+
+     @classmethod
+     def _missing_(cls, v: Any) -> "RevisionKind":
+         if isinstance(v, str) and v and v.isascii() and not v.isupper():
+             return cls(v.upper())
+         raise ValueError(f"Invalid revision kind: {v}")
+
+
+ async def insert_dataset_example_revision(
+     session: AsyncSession,
+     dataset_version_id: DatasetVersionId,
+     dataset_example_id: DatasetExampleId,
+     input: Mapping[str, Any],
+     output: Mapping[str, Any],
+     metadata: Optional[Mapping[str, Any]] = None,
+     revision_kind: RevisionKind = RevisionKind.CREATE,
+     created_at: Optional[datetime] = None,
+ ) -> DatasetExampleRevisionId:
+     id_ = await session.scalar(
+         insert(models.DatasetExampleRevision)
+         .values(
+             dataset_version_id=dataset_version_id,
+             dataset_example_id=dataset_example_id,
+             input=input,
+             output=output,
+             metadata_=metadata,
+             revision_kind=revision_kind.value,
+             created_at=created_at,
+         )
+         .returning(models.DatasetExampleRevision.id)
+     )
+     return cast(DatasetExampleRevisionId, id_)
+
+
+ class DatasetAction(Enum):
+     CREATE = "create"
+     APPEND = "append"
+
+     @classmethod
+     def _missing_(cls, v: Any) -> "DatasetAction":
+         if isinstance(v, str) and v and v.isascii() and not v.islower():
+             return cls(v.lower())
+         raise ValueError(f"Invalid dataset action: {v}")
+
+
+ async def add_dataset_examples(
+     session: AsyncSession,
+     name: str,
+     examples: Union[Examples, Awaitable[Examples]],
+     input_keys: Sequence[str],
+     output_keys: Sequence[str],
+     metadata_keys: Sequence[str] = (),
+     description: Optional[str] = None,
+     metadata: Optional[Mapping[str, Any]] = None,
+     action: DatasetAction = DatasetAction.CREATE,
+ ) -> Optional[DatasetExampleAdditionEvent]:
+     keys = DatasetKeys(frozenset(input_keys), frozenset(output_keys), frozenset(metadata_keys))
+     created_at = datetime.now(timezone.utc)
+     dataset_id: Optional[DatasetId] = None
+     if action is DatasetAction.APPEND and name:
+         dataset_id = await session.scalar(
+             select(models.Dataset.id).where(models.Dataset.name == name)
+         )
+     if action is DatasetAction.CREATE or dataset_id is None:
+         try:
+             dataset_id = await insert_dataset(
+                 session=session,
+                 name=name,
+                 description=description,
+                 metadata=metadata,
+                 created_at=created_at,
+             )
+         except Exception:
+             logger.exception(
+                 f"Failed to insert dataset: {input_keys=}, {output_keys=}, {metadata_keys=}"
+             )
+             raise
+     try:
+         dataset_version_id = await insert_dataset_version(
+             session=session,
+             dataset_id=dataset_id,
+             created_at=created_at,
+         )
+     except Exception:
+         logger.exception(f"Failed to insert dataset version for {dataset_id=}")
+         raise
+     for example in (await examples) if isinstance(examples, Awaitable) else examples:
+         try:
+             dataset_example_id = await insert_dataset_example(
+                 session=session,
+                 dataset_id=dataset_id,
+                 created_at=created_at,
+             )
+         except Exception:
+             logger.exception(f"Failed to insert dataset example for {dataset_id=}")
+             raise
+         try:
+             await insert_dataset_example_revision(
+                 session=session,
+                 dataset_version_id=dataset_version_id,
+                 dataset_example_id=dataset_example_id,
+                 input={key: example.get(key) for key in keys.input},
+                 output={key: example.get(key) for key in keys.output},
+                 metadata={key: example.get(key) for key in keys.metadata},
+                 created_at=created_at,
+             )
+         except Exception:
+             logger.exception(
+                 f"Failed to insert dataset example revision for {dataset_version_id=}, "
+                 f"{dataset_example_id=}"
+             )
+             raise
+     return DatasetExampleAdditionEvent(dataset_id=dataset_id)
+
+
+ @dataclass(frozen=True)
+ class DatasetKeys:
+     input: FrozenSet[str]
+     output: FrozenSet[str]
+     metadata: FrozenSet[str]
+
+     def __iter__(self) -> Iterator[str]:
+         yield from sorted(set(chain(self.input, self.output, self.metadata)))
+
+     def check_differences(self, column_headers_set: FrozenSet[str]) -> None:
+         for category, keys in (
+             ("input", self.input),
+             ("output", self.output),
+             ("metadata", self.metadata),
+         ):
+             if diff := keys.difference(column_headers_set):
+                 raise ValueError(f"{category} keys not found in table column headers: {diff}")
phoenix/db/insertion/evaluation.py
@@ -15,24 +15,24 @@ class InsertEvaluationError(PhoenixException):
      pass


- class EvaluationInsertionResult(NamedTuple):
+ class EvaluationInsertionEvent(NamedTuple):
      project_rowid: int
      evaluation_name: str


- class SpanEvaluationInsertionEvent(EvaluationInsertionResult): ...
+ class SpanEvaluationInsertionEvent(EvaluationInsertionEvent): ...


- class TraceEvaluationInsertionEvent(EvaluationInsertionResult): ...
+ class TraceEvaluationInsertionEvent(EvaluationInsertionEvent): ...


- class DocumentEvaluationInsertionEvent(EvaluationInsertionResult): ...
+ class DocumentEvaluationInsertionEvent(EvaluationInsertionEvent): ...


  async def insert_evaluation(
      session: AsyncSession,
      evaluation: pb.Evaluation,
- ) -> Optional[EvaluationInsertionResult]:
+ ) -> Optional[EvaluationInsertionEvent]:
      evaluation_name = evaluation.name
      result = evaluation.result
      label = result.label.value if result.HasField("label") else None
@@ -160,7 +160,7 @@ async def _insert_document_evaluation(
      label: Optional[str],
      score: Optional[float],
      explanation: Optional[str],
- ) -> EvaluationInsertionResult:
+ ) -> EvaluationInsertionEvent:
      dialect = SupportedSQLDialect(session.bind.dialect.name)
      stmt = (
          select(
phoenix/db/insertion/helpers.py
@@ -1,14 +1,25 @@
+ from abc import ABC
  from enum import Enum, auto
- from typing import Any, Mapping, Optional, Sequence
+ from typing import Any, Awaitable, Callable, Mapping, Optional, Sequence

  from sqlalchemy import Insert, insert
  from sqlalchemy.dialects.postgresql import insert as insert_postgresql
  from sqlalchemy.dialects.sqlite import insert as insert_sqlite
- from typing_extensions import assert_never
+ from sqlalchemy.ext.asyncio import AsyncSession
+ from typing_extensions import TypeAlias, assert_never

  from phoenix.db.helpers import SupportedSQLDialect


+ class DataManipulationEvent(ABC):
+     """
+     Execution of DML (Data Manipulation Language) statements.
+     """
+
+
+ DataManipulation: TypeAlias = Callable[[AsyncSession], Awaitable[Optional[DataManipulationEvent]]]
+
+
  class OnConflict(Enum):
      DO_NOTHING = auto()
      DO_UPDATE = auto()
phoenix/db/migrations/types.py
@@ -0,0 +1,29 @@
+ from typing import Any
+
+ from sqlalchemy import JSON
+ from sqlalchemy.dialects import postgresql
+ from sqlalchemy.ext.compiler import compiles
+
+
+ class JSONB(JSON):
+     # See https://docs.sqlalchemy.org/en/20/core/custom_types.html
+     __visit_name__ = "JSONB"
+
+
+ @compiles(JSONB, "sqlite")  # type: ignore
+ def _(*args: Any, **kwargs: Any) -> str:
+     # See https://docs.sqlalchemy.org/en/20/core/custom_types.html
+     return "JSONB"
+
+
+ JSON_ = (
+     JSON()
+     .with_variant(
+         postgresql.JSONB(),  # type: ignore
+         "postgresql",
+     )
+     .with_variant(
+         JSONB(),
+         "sqlite",
+     )
+ )
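
`JSON_` is a dialect-aware type: it renders as native `JSONB` on PostgreSQL, emits the literal type name `JSONB` on SQLite via the `@compiles` hook, and falls back to generic `JSON` elsewhere. A sketch of how a mapped column could adopt it; the `Sample` model below is illustrative, not the package's actual schema:

from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

from phoenix.db.migrations.types import JSON_

class Base(DeclarativeBase):
    pass

class Sample(Base):  # illustrative model, not part of the package
    __tablename__ = "sample"
    id: Mapped[int] = mapped_column(primary_key=True)
    # Compiles to JSONB on postgresql and sqlite, generic JSON on other dialects.
    metadata_: Mapped[dict] = mapped_column("metadata", JSON_, nullable=False)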