arize-phoenix 4.4.3__py3-none-any.whl → 4.4.4rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (112)
  1. {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc1.dist-info}/METADATA +4 -4
  2. {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc1.dist-info}/RECORD +111 -55
  3. {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc1.dist-info}/WHEEL +1 -1
  4. phoenix/__init__.py +0 -27
  5. phoenix/config.py +21 -7
  6. phoenix/core/model.py +25 -25
  7. phoenix/core/model_schema.py +64 -62
  8. phoenix/core/model_schema_adapter.py +27 -25
  9. phoenix/datasets/__init__.py +0 -0
  10. phoenix/datasets/evaluators.py +275 -0
  11. phoenix/datasets/experiments.py +469 -0
  12. phoenix/datasets/tracing.py +66 -0
  13. phoenix/datasets/types.py +212 -0
  14. phoenix/db/bulk_inserter.py +54 -14
  15. phoenix/db/insertion/dataset.py +234 -0
  16. phoenix/db/insertion/evaluation.py +6 -6
  17. phoenix/db/insertion/helpers.py +13 -2
  18. phoenix/db/migrations/types.py +29 -0
  19. phoenix/db/migrations/versions/10460e46d750_datasets.py +291 -0
  20. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +2 -28
  21. phoenix/db/models.py +230 -3
  22. phoenix/inferences/fixtures.py +23 -23
  23. phoenix/inferences/inferences.py +7 -7
  24. phoenix/inferences/validation.py +1 -1
  25. phoenix/server/api/context.py +16 -0
  26. phoenix/server/api/dataloaders/__init__.py +16 -0
  27. phoenix/server/api/dataloaders/dataset_example_revisions.py +100 -0
  28. phoenix/server/api/dataloaders/dataset_example_spans.py +43 -0
  29. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +85 -0
  30. phoenix/server/api/dataloaders/experiment_error_rates.py +43 -0
  31. phoenix/server/api/dataloaders/experiment_sequence_number.py +49 -0
  32. phoenix/server/api/dataloaders/project_by_name.py +31 -0
  33. phoenix/server/api/dataloaders/span_descendants.py +2 -3
  34. phoenix/server/api/dataloaders/span_projects.py +33 -0
  35. phoenix/server/api/dataloaders/trace_row_ids.py +39 -0
  36. phoenix/server/api/helpers/dataset_helpers.py +178 -0
  37. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
  38. phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
  39. phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
  40. phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
  41. phoenix/server/api/input_types/DatasetSort.py +17 -0
  42. phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
  43. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
  44. phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
  45. phoenix/server/api/input_types/DeleteExperimentsInput.py +9 -0
  46. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
  47. phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
  48. phoenix/server/api/mutations/__init__.py +13 -0
  49. phoenix/server/api/mutations/auth.py +11 -0
  50. phoenix/server/api/mutations/dataset_mutations.py +520 -0
  51. phoenix/server/api/mutations/experiment_mutations.py +65 -0
  52. phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +17 -14
  53. phoenix/server/api/mutations/project_mutations.py +42 -0
  54. phoenix/server/api/openapi/__init__.py +0 -0
  55. phoenix/server/api/openapi/main.py +6 -0
  56. phoenix/server/api/openapi/schema.py +15 -0
  57. phoenix/server/api/queries.py +503 -0
  58. phoenix/server/api/routers/v1/__init__.py +77 -2
  59. phoenix/server/api/routers/v1/dataset_examples.py +178 -0
  60. phoenix/server/api/routers/v1/datasets.py +861 -0
  61. phoenix/server/api/routers/v1/evaluations.py +4 -2
  62. phoenix/server/api/routers/v1/experiment_evaluations.py +65 -0
  63. phoenix/server/api/routers/v1/experiment_runs.py +108 -0
  64. phoenix/server/api/routers/v1/experiments.py +174 -0
  65. phoenix/server/api/routers/v1/spans.py +3 -1
  66. phoenix/server/api/routers/v1/traces.py +1 -4
  67. phoenix/server/api/schema.py +2 -303
  68. phoenix/server/api/types/AnnotatorKind.py +10 -0
  69. phoenix/server/api/types/Cluster.py +19 -19
  70. phoenix/server/api/types/CreateDatasetPayload.py +8 -0
  71. phoenix/server/api/types/Dataset.py +282 -63
  72. phoenix/server/api/types/DatasetExample.py +85 -0
  73. phoenix/server/api/types/DatasetExampleRevision.py +34 -0
  74. phoenix/server/api/types/DatasetVersion.py +14 -0
  75. phoenix/server/api/types/Dimension.py +30 -29
  76. phoenix/server/api/types/EmbeddingDimension.py +40 -34
  77. phoenix/server/api/types/Event.py +16 -16
  78. phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
  79. phoenix/server/api/types/Experiment.py +135 -0
  80. phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
  81. phoenix/server/api/types/ExperimentComparison.py +19 -0
  82. phoenix/server/api/types/ExperimentRun.py +91 -0
  83. phoenix/server/api/types/ExperimentRunAnnotation.py +57 -0
  84. phoenix/server/api/types/Inferences.py +80 -0
  85. phoenix/server/api/types/InferencesRole.py +23 -0
  86. phoenix/server/api/types/Model.py +43 -42
  87. phoenix/server/api/types/Project.py +26 -12
  88. phoenix/server/api/types/Span.py +78 -2
  89. phoenix/server/api/types/TimeSeries.py +6 -6
  90. phoenix/server/api/types/Trace.py +15 -4
  91. phoenix/server/api/types/UMAPPoints.py +1 -1
  92. phoenix/server/api/types/node.py +5 -111
  93. phoenix/server/api/types/pagination.py +10 -52
  94. phoenix/server/app.py +99 -49
  95. phoenix/server/main.py +49 -27
  96. phoenix/server/openapi/docs.py +3 -0
  97. phoenix/server/static/index.js +2246 -1368
  98. phoenix/server/templates/index.html +1 -0
  99. phoenix/services.py +15 -15
  100. phoenix/session/client.py +316 -21
  101. phoenix/session/session.py +47 -37
  102. phoenix/trace/exporter.py +14 -9
  103. phoenix/trace/fixtures.py +133 -7
  104. phoenix/trace/span_evaluations.py +3 -3
  105. phoenix/trace/trace_dataset.py +6 -6
  106. phoenix/utilities/json.py +61 -0
  107. phoenix/utilities/re.py +50 -0
  108. phoenix/version.py +1 -1
  109. phoenix/server/api/types/DatasetRole.py +0 -23
  110. {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc1.dist-info}/licenses/IP_NOTICE +0 -0
  111. {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc1.dist-info}/licenses/LICENSE +0 -0
  112. /phoenix/server/api/{helpers.py → helpers/__init__.py} +0 -0
phoenix/db/migrations/versions/10460e46d750_datasets.py ADDED
@@ -0,0 +1,291 @@
1
+ """datasets
2
+
3
+ Revision ID: 10460e46d750
4
+ Revises: cf03bd6bae1d
5
+ Create Date: 2024-05-10 11:24:23.985834
6
+
7
+ """
8
+
9
+ from typing import Sequence, Union
10
+
11
+ import sqlalchemy as sa
12
+ from alembic import op
13
+ from phoenix.db.migrations.types import JSON_
14
+
15
# revision identifiers, used by Alembic.
# This revision adds the dataset and experiment tables on top of revision
# ``cf03bd6bae1d`` (the initial schema migration).
revision: str = "10460e46d750"
down_revision: Union[str, None] = "cf03bd6bae1d"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
20
+
21
+
22
def upgrade() -> None:
    """Create the dataset and experiment tables.

    Tables are created parents-first: ``datasets``, ``dataset_versions``,
    ``dataset_examples``, ``dataset_example_revisions``, ``experiments``,
    ``experiment_runs`` and ``experiment_run_annotations``.
    """
    # NOTE: ``onupdate`` on the ``updated_at`` columns below is a client-side
    # SQLAlchemy default; it is not emitted into the CREATE TABLE DDL, so the
    # database itself does not refresh these columns.
    op.create_table(
        "datasets",
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column("name", sa.String, nullable=False, unique=True),
        sa.Column("description", sa.String, nullable=True),
        sa.Column("metadata", JSON_, nullable=False),
        sa.Column(
            "created_at",
            sa.TIMESTAMP(timezone=True),
            nullable=False,
            server_default=sa.func.now(),
        ),
        sa.Column(
            "updated_at",
            sa.TIMESTAMP(timezone=True),
            nullable=False,
            server_default=sa.func.now(),
            onupdate=sa.func.now(),
        ),
    )
    op.create_table(
        "dataset_versions",
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column(
            "dataset_id",
            sa.Integer,
            sa.ForeignKey("datasets.id", ondelete="CASCADE"),
            nullable=False,
            index=True,
        ),
        sa.Column("description", sa.String, nullable=True),
        sa.Column("metadata", JSON_, nullable=False),
        sa.Column(
            "created_at",
            sa.TIMESTAMP(timezone=True),
            nullable=False,
            server_default=sa.func.now(),
        ),
    )
    op.create_table(
        "dataset_examples",
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column(
            "dataset_id",
            sa.Integer,
            sa.ForeignKey("datasets.id", ondelete="CASCADE"),
            nullable=False,
            index=True,
        ),
        sa.Column(
            "span_rowid",
            sa.Integer,
            # SET NULL so that deleting the source span (e.g. through a
            # cascading trace delete) unlinks the example instead of failing
            # with a foreign-key violation. The column is already nullable.
            # NOTE(review): databases that applied an earlier build of this
            # revision keep the FK without an ON DELETE action and would need
            # a follow-up migration.
            sa.ForeignKey("spans.id", ondelete="SET NULL"),
            nullable=True,
            index=True,
        ),
        sa.Column(
            "created_at",
            sa.TIMESTAMP(timezone=True),
            nullable=False,
            server_default=sa.func.now(),
        ),
    )
    op.create_table(
        "dataset_example_revisions",
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column(
            "dataset_example_id",
            sa.Integer,
            sa.ForeignKey("dataset_examples.id", ondelete="CASCADE"),
            nullable=False,
            index=True,
        ),
        sa.Column(
            "dataset_version_id",
            sa.Integer,
            sa.ForeignKey("dataset_versions.id", ondelete="CASCADE"),
            nullable=False,
            index=True,
        ),
        sa.Column("input", JSON_, nullable=False),
        sa.Column("output", JSON_, nullable=False),
        sa.Column("metadata", JSON_, nullable=False),
        sa.Column(
            "revision_kind",
            sa.String,
            sa.CheckConstraint(
                "revision_kind IN ('CREATE', 'PATCH', 'DELETE')",
                name="valid_revision_kind",
            ),
            nullable=False,
        ),
        sa.Column(
            "created_at",
            sa.TIMESTAMP(timezone=True),
            nullable=False,
            server_default=sa.func.now(),
        ),
        # At most one revision of a given example per dataset version.
        sa.UniqueConstraint(
            "dataset_example_id",
            "dataset_version_id",
        ),
    )
    op.create_table(
        "experiments",
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column(
            "dataset_id",
            sa.Integer,
            sa.ForeignKey("datasets.id", ondelete="CASCADE"),
            nullable=False,
            index=True,
        ),
        sa.Column(
            "dataset_version_id",
            sa.Integer,
            sa.ForeignKey("dataset_versions.id", ondelete="CASCADE"),
            nullable=False,
            index=True,
        ),
        sa.Column(
            "name",
            sa.String,
            nullable=False,
        ),
        sa.Column(
            "description",
            sa.String,
            nullable=True,
        ),
        sa.Column(
            "repetitions",
            sa.Integer,
            nullable=False,
        ),
        sa.Column("metadata", JSON_, nullable=False),
        sa.Column("project_name", sa.String, index=True),
        sa.Column(
            "created_at",
            sa.TIMESTAMP(timezone=True),
            nullable=False,
            server_default=sa.func.now(),
        ),
        sa.Column(
            "updated_at",
            sa.TIMESTAMP(timezone=True),
            nullable=False,
            server_default=sa.func.now(),
            onupdate=sa.func.now(),
        ),
    )
    op.create_table(
        "experiment_runs",
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column(
            "experiment_id",
            sa.Integer,
            sa.ForeignKey("experiments.id", ondelete="CASCADE"),
            nullable=False,
            index=True,
        ),
        sa.Column(
            "dataset_example_id",
            sa.Integer,
            sa.ForeignKey("dataset_examples.id", ondelete="CASCADE"),
            nullable=False,
            index=True,
        ),
        sa.Column(
            "repetition_number",
            sa.Integer,
            nullable=False,
        ),
        # Plain string, not a foreign key to traces.
        sa.Column(
            "trace_id",
            sa.String,
            nullable=True,
        ),
        sa.Column("output", JSON_, nullable=True),
        sa.Column("start_time", sa.TIMESTAMP(timezone=True), nullable=False),
        sa.Column("end_time", sa.TIMESTAMP(timezone=True), nullable=False),
        sa.Column(
            "prompt_token_count",
            sa.Integer,
            nullable=True,
        ),
        sa.Column(
            "completion_token_count",
            sa.Integer,
            nullable=True,
        ),
        sa.Column(
            "error",
            sa.String,
            nullable=True,
        ),
        # One run per (experiment, example, repetition).
        sa.UniqueConstraint(
            "experiment_id",
            "dataset_example_id",
            "repetition_number",
        ),
    )
    op.create_table(
        "experiment_run_annotations",
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column(
            "experiment_run_id",
            sa.Integer,
            sa.ForeignKey("experiment_runs.id", ondelete="CASCADE"),
            nullable=False,
            index=True,
        ),
        sa.Column(
            "name",
            sa.String,
            nullable=False,
        ),
        sa.Column(
            "annotator_kind",
            sa.String,
            sa.CheckConstraint(
                "annotator_kind IN ('LLM', 'CODE', 'HUMAN')",
                name="valid_annotator_kind",
            ),
            nullable=False,
        ),
        sa.Column(
            "label",
            sa.String,
            nullable=True,
        ),
        sa.Column(
            "score",
            sa.Float,
            nullable=True,
        ),
        sa.Column(
            "explanation",
            sa.String,
            nullable=True,
        ),
        sa.Column(
            "trace_id",
            sa.String,
            nullable=True,
        ),
        sa.Column(
            "error",
            sa.String,
            nullable=True,
        ),
        sa.Column("metadata", JSON_, nullable=False),
        sa.Column("start_time", sa.TIMESTAMP(timezone=True), nullable=False),
        sa.Column("end_time", sa.TIMESTAMP(timezone=True), nullable=False),
        # Annotation names are unique within a run.
        sa.UniqueConstraint(
            "experiment_run_id",
            "name",
        ),
    )
282
+
283
+
284
def downgrade() -> None:
    """Drop every table created by ``upgrade``.

    Tables go in reverse dependency order, so each referencing table is
    dropped before the table it references.
    """
    for table_name in (
        "experiment_run_annotations",
        "experiment_runs",
        "experiments",
        "dataset_example_revisions",
        "dataset_examples",
        "dataset_versions",
        "datasets",
    ):
        op.drop_table(table_name)
phoenix/db/migrations/versions/cf03bd6bae1d_init.py CHANGED
@@ -6,13 +6,11 @@ Create Date: 2024-04-03 19:41:48.871555
6
6
 
7
7
  """
8
8
 
9
- from typing import Any, Sequence, Union
9
+ from typing import Sequence, Union
10
10
 
11
11
  import sqlalchemy as sa
12
12
  from alembic import op
13
- from sqlalchemy import JSON
14
- from sqlalchemy.dialects import postgresql
15
- from sqlalchemy.ext.compiler import compiles
13
+ from phoenix.db.migrations.types import JSON_
16
14
 
17
15
  # revision identifiers, used by Alembic.
18
16
  revision: str = "cf03bd6bae1d"
@@ -21,30 +19,6 @@ branch_labels: Union[str, Sequence[str], None] = None
21
19
  depends_on: Union[str, Sequence[str], None] = None
22
20
 
23
21
 
24
- class JSONB(JSON):
25
- # See https://docs.sqlalchemy.org/en/20/core/custom_types.html
26
- __visit_name__ = "JSONB"
27
-
28
-
29
- @compiles(JSONB, "sqlite") # type: ignore
30
- def _(*args: Any, **kwargs: Any) -> str:
31
- # See https://docs.sqlalchemy.org/en/20/core/custom_types.html
32
- return "JSONB"
33
-
34
-
35
- JSON_ = (
36
- JSON()
37
- .with_variant(
38
- postgresql.JSONB(), # type: ignore
39
- "postgresql",
40
- )
41
- .with_variant(
42
- JSONB(),
43
- "sqlite",
44
- )
45
- )
46
-
47
-
48
22
  def upgrade() -> None:
49
23
  projects_table = op.create_table(
50
24
  "projects",
phoenix/db/models.py CHANGED
@@ -15,12 +15,14 @@ from sqlalchemy import (
15
15
  String,
16
16
  TypeDecorator,
17
17
  UniqueConstraint,
18
+ case,
18
19
  func,
19
20
  insert,
21
+ select,
20
22
  text,
21
23
  )
22
24
  from sqlalchemy.dialects import postgresql
23
- from sqlalchemy.ext.asyncio import AsyncEngine
25
+ from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
24
26
  from sqlalchemy.ext.compiler import compiles
25
27
  from sqlalchemy.ext.hybrid import hybrid_property
26
28
  from sqlalchemy.orm import (
@@ -59,6 +61,24 @@ JSON_ = (
59
61
  )
60
62
 
61
63
 
64
class JsonDict(TypeDecorator[Dict[str, Any]]):
    """JSON column type whose bound value is always a dict.

    Any value passed for binding that is not a ``dict`` (including ``None``)
    is replaced with an empty dict before being sent to the database.
    """

    # See https://docs.sqlalchemy.org/en/20/core/custom_types.html
    cache_ok = True
    impl = JSON_

    def process_bind_param(self, value: Optional[Dict[str, Any]], _: Dialect) -> Dict[str, Any]:
        if not isinstance(value, dict):
            return {}
        return value
71
+
72
+
73
class JsonList(TypeDecorator[List[Any]]):
    """JSON column type whose bound value is always a list.

    Any value passed for binding that is not a ``list`` (including ``None``)
    is replaced with an empty list before being sent to the database.
    """

    # See https://docs.sqlalchemy.org/en/20/core/custom_types.html
    cache_ok = True
    impl = JSON_

    def process_bind_param(self, value: Optional[List[Any]], _: Dialect) -> List[Any]:
        if isinstance(value, list):
            return value
        return []
80
+
81
+
62
82
  class UtcTimeStamp(TypeDecorator[datetime]):
63
83
  # See # See https://docs.sqlalchemy.org/en/20/core/custom_types.html
64
84
  cache_ok = True
@@ -84,8 +104,8 @@ class Base(DeclarativeBase):
84
104
  }
85
105
  )
86
106
  type_annotation_map = {
87
- Dict[str, Any]: JSON_,
88
- List[Dict[str, Any]]: JSON_,
107
+ Dict[str, Any]: JsonDict,
108
+ List[Dict[str, Any]]: JsonList,
89
109
  }
90
110
 
91
111
 
@@ -154,6 +174,10 @@ class Trace(Base):
154
174
  cascade="all, delete-orphan",
155
175
  uselist=True,
156
176
  )
177
+ experiment_runs: Mapped[List["ExperimentRun"]] = relationship(
178
+ primaryjoin="foreign(ExperimentRun.trace_id) == Trace.trace_id",
179
+ back_populates="trace",
180
+ )
157
181
  __table_args__ = (
158
182
  UniqueConstraint(
159
183
  "trace_id",
@@ -203,6 +227,7 @@ class Span(Base):
203
227
 
204
228
  trace: Mapped["Trace"] = relationship("Trace", back_populates="spans")
205
229
  document_annotations: Mapped[List["DocumentAnnotation"]] = relationship(back_populates="span")
230
+ dataset_examples: Mapped[List["DatasetExample"]] = relationship(back_populates="span")
206
231
 
207
232
  __table_args__ = (
208
233
  UniqueConstraint(
@@ -376,3 +401,205 @@ class DocumentAnnotation(Base):
376
401
  "document_position",
377
402
  ),
378
403
  )
404
+
405
+
406
class Dataset(Base):
    """A named dataset whose examples are tracked through revision rows."""

    __tablename__ = "datasets"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(unique=True)
    description: Mapped[Optional[str]]
    # Python attribute is ``metadata_`` because ``metadata`` is reserved on
    # SQLAlchemy declarative classes; the column itself is named "metadata".
    metadata_: Mapped[Dict[str, Any]] = mapped_column("metadata")
    created_at: Mapped[datetime] = mapped_column(UtcTimeStamp, server_default=func.now())
    updated_at: Mapped[datetime] = mapped_column(
        UtcTimeStamp, server_default=func.now(), onupdate=func.now()
    )

    @staticmethod
    def _example_count_query(dataset_id: Any) -> Any:
        """Build a SELECT returning the net number of examples in a dataset.

        Each CREATE revision counts +1 and each DELETE revision -1; other
        revision kinds (PATCH) count 0. ``SUM`` over zero rows is NULL, so the
        result is wrapped in ``COALESCE(..., 0)`` to report 0 for a dataset
        with no revisions.

        ``dataset_id`` may be a plain id or a column expression such as
        ``Dataset.id`` (for the correlated hybrid expression).
        """
        return (
            select(
                func.coalesce(
                    func.sum(
                        case(
                            (DatasetExampleRevision.revision_kind == "CREATE", 1),
                            (DatasetExampleRevision.revision_kind == "DELETE", -1),
                            else_=0,
                        )
                    ),
                    0,
                )
            )
            .select_from(DatasetExampleRevision)
            .join(
                DatasetExample,
                onclause=DatasetExample.id == DatasetExampleRevision.dataset_example_id,
            )
            .filter(DatasetExample.dataset_id == dataset_id)
        )

    @hybrid_property
    def example_count(self) -> Optional[int]:
        """Return the count stored by ``load_example_count``, or None if not loaded."""
        if hasattr(self, "_example_count_value"):
            assert isinstance(self._example_count_value, int)
            return self._example_count_value
        return None

    @example_count.inplace.expression
    def _example_count(cls) -> ColumnElement[int]:
        # Scalar subquery correlated on ``cls.id``; usable in SELECT / ORDER BY.
        return cls._example_count_query(cls.id).label("example_count")

    async def load_example_count(self, session: AsyncSession) -> None:
        """Query and cache this dataset's example count on the instance.

        A no-op if the count was already loaded. With the COALESCE in
        ``_example_count_query`` the cached value is always an int (0 for an
        empty dataset), so ``example_count`` never sees None here.
        """
        if not hasattr(self, "_example_count_value"):
            self._example_count_value = await session.scalar(
                self._example_count_query(self.id)
            )
464
+
465
+
466
class DatasetVersion(Base):
    """A version of a dataset.

    Example revisions and experiments reference a version through their
    ``dataset_version_id`` foreign keys.
    """

    __tablename__ = "dataset_versions"
    id: Mapped[int] = mapped_column(primary_key=True)
    # Deleting the dataset removes its versions (ON DELETE CASCADE).
    dataset_id: Mapped[int] = mapped_column(
        ForeignKey("datasets.id", ondelete="CASCADE"),
        index=True,
    )
    description: Mapped[Optional[str]]
    # Stored in the "metadata" column; ``metadata`` is reserved on declarative classes.
    metadata_: Mapped[Dict[str, Any]] = mapped_column("metadata")
    created_at: Mapped[datetime] = mapped_column(UtcTimeStamp, server_default=func.now())
476
+
477
+
478
class DatasetExample(Base):
    """An example belonging to a dataset.

    The example's input/output/metadata are stored on ``DatasetExampleRevision``
    rows, one per dataset version in which the example changed.
    """

    __tablename__ = "dataset_examples"
    id: Mapped[int] = mapped_column(primary_key=True)
    # Deleting the dataset removes its examples (ON DELETE CASCADE).
    dataset_id: Mapped[int] = mapped_column(
        ForeignKey("datasets.id", ondelete="CASCADE"),
        index=True,
    )
    # Optional link to the span the example was created from. SET NULL so that
    # a database-level span delete (e.g. a cascading trace delete) unlinks the
    # example instead of raising a foreign-key violation.
    span_rowid: Mapped[Optional[int]] = mapped_column(
        ForeignKey("spans.id", ondelete="SET NULL"),
        index=True,
        nullable=True,
    )
    created_at: Mapped[datetime] = mapped_column(UtcTimeStamp, server_default=func.now())

    span: Mapped[Optional[Span]] = relationship(back_populates="dataset_examples")
493
+
494
+
495
class DatasetExampleRevision(Base):
    """The content of a dataset example as of a particular dataset version.

    ``revision_kind`` records whether the example was created, patched, or
    deleted in that version; an example has at most one revision per version.
    """

    __tablename__ = "dataset_example_revisions"
    id: Mapped[int] = mapped_column(primary_key=True)
    # Revisions are removed with their example (ON DELETE CASCADE).
    dataset_example_id: Mapped[int] = mapped_column(
        ForeignKey("dataset_examples.id", ondelete="CASCADE"),
        index=True,
    )
    # Revisions are removed with their dataset version (ON DELETE CASCADE).
    dataset_version_id: Mapped[int] = mapped_column(
        ForeignKey("dataset_versions.id", ondelete="CASCADE"),
        index=True,
    )
    input: Mapped[Dict[str, Any]]
    output: Mapped[Dict[str, Any]]
    # Stored in the "metadata" column; ``metadata`` is reserved on declarative classes.
    metadata_: Mapped[Dict[str, Any]] = mapped_column("metadata")
    # Restricted by a CHECK constraint to CREATE, PATCH or DELETE.
    revision_kind: Mapped[str] = mapped_column(
        CheckConstraint(
            "revision_kind IN ('CREATE', 'PATCH', 'DELETE')", name="valid_revision_kind"
        ),
    )
    created_at: Mapped[datetime] = mapped_column(UtcTimeStamp, server_default=func.now())

    __table_args__ = (
        # At most one revision of a given example per dataset version.
        UniqueConstraint(
            "dataset_example_id",
            "dataset_version_id",
        ),
    )
522
+
523
+
524
class Experiment(Base):
    """An experiment run against a specific version of a dataset."""

    __tablename__ = "experiments"
    id: Mapped[int] = mapped_column(primary_key=True)
    # Deleting the dataset removes its experiments (ON DELETE CASCADE).
    dataset_id: Mapped[int] = mapped_column(
        ForeignKey("datasets.id", ondelete="CASCADE"),
        index=True,
    )
    # The dataset version the experiment was run against (ON DELETE CASCADE).
    dataset_version_id: Mapped[int] = mapped_column(
        ForeignKey("dataset_versions.id", ondelete="CASCADE"),
        index=True,
    )
    name: Mapped[str]
    description: Mapped[Optional[str]]
    # Presumably how many times each example is run; pairs with
    # ExperimentRun.repetition_number — confirm.
    repetitions: Mapped[int]
    # Stored in the "metadata" column; ``metadata`` is reserved on declarative classes.
    metadata_: Mapped[Dict[str, Any]] = mapped_column("metadata")
    # NOTE(review): a plain indexed name, not a foreign key to projects;
    # presumably the project holding the experiment's traces — confirm.
    project_name: Mapped[Optional[str]] = mapped_column(index=True)
    created_at: Mapped[datetime] = mapped_column(UtcTimeStamp, server_default=func.now())
    # ``onupdate`` is applied by SQLAlchemy on UPDATEs it issues for this model;
    # the database does not refresh this column by itself.
    updated_at: Mapped[datetime] = mapped_column(
        UtcTimeStamp, server_default=func.now(), onupdate=func.now()
    )
544
+
545
+
546
class ExperimentRun(Base):
    """One run of an experiment on a single dataset example.

    ``repetition_number`` distinguishes repeated runs of the same example within
    an experiment; the (experiment, example, repetition) triple is unique.
    """

    __tablename__ = "experiment_runs"
    id: Mapped[int] = mapped_column(primary_key=True)
    # Deleting the experiment removes its runs (ON DELETE CASCADE).
    experiment_id: Mapped[int] = mapped_column(
        ForeignKey("experiments.id", ondelete="CASCADE"),
        index=True,
    )
    # Deleting the dataset example removes runs made on it (ON DELETE CASCADE).
    dataset_example_id: Mapped[int] = mapped_column(
        ForeignKey("dataset_examples.id", ondelete="CASCADE"),
        index=True,
    )
    repetition_number: Mapped[int]
    # Plain string, not a foreign key; joined to Trace.trace_id by ``trace`` below.
    trace_id: Mapped[Optional[str]]
    # NOTE(review): Optional[Dict[str, Any]] resolves to JsonDict through
    # Base.type_annotation_map, and JsonDict.process_bind_param replaces any
    # non-dict value (including None) with {}; a None output may therefore be
    # persisted as {} even though the column is nullable. Confirm intended.
    output: Mapped[Optional[Dict[str, Any]]]
    start_time: Mapped[datetime] = mapped_column(UtcTimeStamp)
    end_time: Mapped[datetime] = mapped_column(UtcTimeStamp)
    prompt_token_count: Mapped[Optional[int]]
    completion_token_count: Mapped[Optional[int]]
    error: Mapped[Optional[str]]

    # No database FK backs this join, so ``foreign()`` marks
    # ExperimentRun.trace_id as the referencing side of the relationship.
    trace: Mapped["Trace"] = relationship(
        primaryjoin="foreign(ExperimentRun.trace_id) == Trace.trace_id",
        back_populates="experiment_runs",
    )

    __table_args__ = (
        # One run per (experiment, example, repetition).
        UniqueConstraint(
            "experiment_id",
            "dataset_example_id",
            "repetition_number",
        ),
    )
578
+
579
+
580
class ExperimentRunAnnotation(Base):
    """A named evaluation (label/score/explanation) attached to an experiment run."""

    __tablename__ = "experiment_run_annotations"
    id: Mapped[int] = mapped_column(primary_key=True)
    # Deleting the run removes its annotations (ON DELETE CASCADE).
    experiment_run_id: Mapped[int] = mapped_column(
        ForeignKey("experiment_runs.id", ondelete="CASCADE"),
        index=True,
    )
    # Unique within a run (see __table_args__).
    name: Mapped[str]
    # Restricted by a CHECK constraint to LLM, CODE or HUMAN.
    annotator_kind: Mapped[str] = mapped_column(
        CheckConstraint("annotator_kind IN ('LLM', 'CODE', 'HUMAN')", name="valid_annotator_kind"),
    )
    label: Mapped[Optional[str]]
    score: Mapped[Optional[float]]
    explanation: Mapped[Optional[str]]
    # Plain string; no foreign key to traces.
    trace_id: Mapped[Optional[str]]
    error: Mapped[Optional[str]]
    # Stored in the "metadata" column; ``metadata`` is reserved on declarative classes.
    metadata_: Mapped[Dict[str, Any]] = mapped_column("metadata")
    start_time: Mapped[datetime] = mapped_column(UtcTimeStamp)
    end_time: Mapped[datetime] = mapped_column(UtcTimeStamp)

    __table_args__ = (
        # Annotation names are unique within a run.
        UniqueConstraint(
            "experiment_run_id",
            "name",
        ),
    )