arize-phoenix 4.4.4rc5__py3-none-any.whl → 4.5.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.

Potentially problematic release.

This version of arize-phoenix might be problematic.

Files changed (118)
  1. {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/METADATA +5 -5
  2. {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/RECORD +56 -117
  3. {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/WHEEL +1 -1
  4. phoenix/__init__.py +27 -0
  5. phoenix/config.py +7 -21
  6. phoenix/core/model.py +25 -25
  7. phoenix/core/model_schema.py +62 -64
  8. phoenix/core/model_schema_adapter.py +25 -27
  9. phoenix/db/bulk_inserter.py +14 -54
  10. phoenix/db/insertion/evaluation.py +6 -6
  11. phoenix/db/insertion/helpers.py +2 -13
  12. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +28 -2
  13. phoenix/db/models.py +4 -236
  14. phoenix/inferences/fixtures.py +23 -23
  15. phoenix/inferences/inferences.py +7 -7
  16. phoenix/inferences/validation.py +1 -1
  17. phoenix/server/api/context.py +0 -18
  18. phoenix/server/api/dataloaders/__init__.py +0 -18
  19. phoenix/server/api/dataloaders/span_descendants.py +3 -2
  20. phoenix/server/api/routers/v1/__init__.py +2 -77
  21. phoenix/server/api/routers/v1/evaluations.py +2 -4
  22. phoenix/server/api/routers/v1/spans.py +1 -3
  23. phoenix/server/api/routers/v1/traces.py +4 -1
  24. phoenix/server/api/schema.py +303 -2
  25. phoenix/server/api/types/Cluster.py +19 -19
  26. phoenix/server/api/types/Dataset.py +63 -282
  27. phoenix/server/api/types/DatasetRole.py +23 -0
  28. phoenix/server/api/types/Dimension.py +29 -30
  29. phoenix/server/api/types/EmbeddingDimension.py +34 -40
  30. phoenix/server/api/types/Event.py +16 -16
  31. phoenix/server/api/{mutations/export_events_mutations.py → types/ExportEventsMutation.py} +14 -17
  32. phoenix/server/api/types/Model.py +42 -43
  33. phoenix/server/api/types/Project.py +12 -26
  34. phoenix/server/api/types/Span.py +2 -79
  35. phoenix/server/api/types/TimeSeries.py +6 -6
  36. phoenix/server/api/types/Trace.py +4 -15
  37. phoenix/server/api/types/UMAPPoints.py +1 -1
  38. phoenix/server/api/types/node.py +111 -5
  39. phoenix/server/api/types/pagination.py +52 -10
  40. phoenix/server/app.py +49 -101
  41. phoenix/server/main.py +27 -49
  42. phoenix/server/openapi/docs.py +0 -3
  43. phoenix/server/static/index.js +2595 -3523
  44. phoenix/server/templates/index.html +0 -1
  45. phoenix/services.py +15 -15
  46. phoenix/session/client.py +21 -438
  47. phoenix/session/session.py +37 -47
  48. phoenix/trace/exporter.py +9 -14
  49. phoenix/trace/fixtures.py +7 -133
  50. phoenix/trace/schemas.py +2 -1
  51. phoenix/trace/span_evaluations.py +3 -3
  52. phoenix/trace/trace_dataset.py +6 -6
  53. phoenix/version.py +1 -1
  54. phoenix/datasets/__init__.py +0 -0
  55. phoenix/datasets/evaluators/__init__.py +0 -18
  56. phoenix/datasets/evaluators/code_evaluators.py +0 -99
  57. phoenix/datasets/evaluators/llm_evaluators.py +0 -244
  58. phoenix/datasets/evaluators/utils.py +0 -292
  59. phoenix/datasets/experiments.py +0 -550
  60. phoenix/datasets/tracing.py +0 -85
  61. phoenix/datasets/types.py +0 -178
  62. phoenix/db/insertion/dataset.py +0 -237
  63. phoenix/db/migrations/types.py +0 -29
  64. phoenix/db/migrations/versions/10460e46d750_datasets.py +0 -291
  65. phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -100
  66. phoenix/server/api/dataloaders/dataset_example_spans.py +0 -43
  67. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +0 -85
  68. phoenix/server/api/dataloaders/experiment_error_rates.py +0 -43
  69. phoenix/server/api/dataloaders/experiment_run_counts.py +0 -42
  70. phoenix/server/api/dataloaders/experiment_sequence_number.py +0 -49
  71. phoenix/server/api/dataloaders/project_by_name.py +0 -31
  72. phoenix/server/api/dataloaders/span_projects.py +0 -33
  73. phoenix/server/api/dataloaders/trace_row_ids.py +0 -39
  74. phoenix/server/api/helpers/dataset_helpers.py +0 -179
  75. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +0 -16
  76. phoenix/server/api/input_types/AddSpansToDatasetInput.py +0 -14
  77. phoenix/server/api/input_types/ClearProjectInput.py +0 -15
  78. phoenix/server/api/input_types/CreateDatasetInput.py +0 -12
  79. phoenix/server/api/input_types/DatasetExampleInput.py +0 -14
  80. phoenix/server/api/input_types/DatasetSort.py +0 -17
  81. phoenix/server/api/input_types/DatasetVersionSort.py +0 -16
  82. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +0 -13
  83. phoenix/server/api/input_types/DeleteDatasetInput.py +0 -7
  84. phoenix/server/api/input_types/DeleteExperimentsInput.py +0 -9
  85. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +0 -35
  86. phoenix/server/api/input_types/PatchDatasetInput.py +0 -14
  87. phoenix/server/api/mutations/__init__.py +0 -13
  88. phoenix/server/api/mutations/auth.py +0 -11
  89. phoenix/server/api/mutations/dataset_mutations.py +0 -520
  90. phoenix/server/api/mutations/experiment_mutations.py +0 -65
  91. phoenix/server/api/mutations/project_mutations.py +0 -47
  92. phoenix/server/api/openapi/__init__.py +0 -0
  93. phoenix/server/api/openapi/main.py +0 -6
  94. phoenix/server/api/openapi/schema.py +0 -16
  95. phoenix/server/api/queries.py +0 -503
  96. phoenix/server/api/routers/v1/dataset_examples.py +0 -178
  97. phoenix/server/api/routers/v1/datasets.py +0 -965
  98. phoenix/server/api/routers/v1/experiment_evaluations.py +0 -66
  99. phoenix/server/api/routers/v1/experiment_runs.py +0 -108
  100. phoenix/server/api/routers/v1/experiments.py +0 -174
  101. phoenix/server/api/types/AnnotatorKind.py +0 -10
  102. phoenix/server/api/types/CreateDatasetPayload.py +0 -8
  103. phoenix/server/api/types/DatasetExample.py +0 -85
  104. phoenix/server/api/types/DatasetExampleRevision.py +0 -34
  105. phoenix/server/api/types/DatasetVersion.py +0 -14
  106. phoenix/server/api/types/ExampleRevisionInterface.py +0 -14
  107. phoenix/server/api/types/Experiment.py +0 -140
  108. phoenix/server/api/types/ExperimentAnnotationSummary.py +0 -13
  109. phoenix/server/api/types/ExperimentComparison.py +0 -19
  110. phoenix/server/api/types/ExperimentRun.py +0 -91
  111. phoenix/server/api/types/ExperimentRunAnnotation.py +0 -57
  112. phoenix/server/api/types/Inferences.py +0 -80
  113. phoenix/server/api/types/InferencesRole.py +0 -23
  114. phoenix/utilities/json.py +0 -61
  115. phoenix/utilities/re.py +0 -50
  116. {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/licenses/IP_NOTICE +0 -0
  117. {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/licenses/LICENSE +0 -0
  118. /phoenix/server/api/{helpers/__init__.py → helpers.py} +0 -0

--- a/phoenix/db/migrations/versions/10460e46d750_datasets.py
+++ /dev/null
@@ -1,291 +0,0 @@
-"""datasets
-
-Revision ID: 10460e46d750
-Revises: cf03bd6bae1d
-Create Date: 2024-05-10 11:24:23.985834
-
-"""
-
-from typing import Sequence, Union
-
-import sqlalchemy as sa
-from alembic import op
-from phoenix.db.migrations.types import JSON_
-
-# revision identifiers, used by Alembic.
-revision: str = "10460e46d750"
-down_revision: Union[str, None] = "cf03bd6bae1d"
-branch_labels: Union[str, Sequence[str], None] = None
-depends_on: Union[str, Sequence[str], None] = None
-
-
-def upgrade() -> None:
-    op.create_table(
-        "datasets",
-        sa.Column("id", sa.Integer, primary_key=True),
-        sa.Column("name", sa.String, nullable=False, unique=True),
-        sa.Column("description", sa.String, nullable=True),
-        sa.Column("metadata", JSON_, nullable=False),
-        sa.Column(
-            "created_at",
-            sa.TIMESTAMP(timezone=True),
-            nullable=False,
-            server_default=sa.func.now(),
-        ),
-        sa.Column(
-            "updated_at",
-            sa.TIMESTAMP(timezone=True),
-            nullable=False,
-            server_default=sa.func.now(),
-            onupdate=sa.func.now(),
-        ),
-    )
-    op.create_table(
-        "dataset_versions",
-        sa.Column("id", sa.Integer, primary_key=True),
-        sa.Column(
-            "dataset_id",
-            sa.Integer,
-            sa.ForeignKey("datasets.id", ondelete="CASCADE"),
-            nullable=False,
-            index=True,
-        ),
-        sa.Column("description", sa.String, nullable=True),
-        sa.Column("metadata", JSON_, nullable=False),
-        sa.Column(
-            "created_at",
-            sa.TIMESTAMP(timezone=True),
-            nullable=False,
-            server_default=sa.func.now(),
-        ),
-    )
-    op.create_table(
-        "dataset_examples",
-        sa.Column("id", sa.Integer, primary_key=True),
-        sa.Column(
-            "dataset_id",
-            sa.Integer,
-            sa.ForeignKey("datasets.id", ondelete="CASCADE"),
-            nullable=False,
-            index=True,
-        ),
-        sa.Column(
-            "span_rowid",
-            sa.Integer,
-            sa.ForeignKey("spans.id", ondelete="SET NULL"),
-            nullable=True,
-            index=True,
-        ),
-        sa.Column(
-            "created_at",
-            sa.TIMESTAMP(timezone=True),
-            nullable=False,
-            server_default=sa.func.now(),
-        ),
-    )
-    op.create_table(
-        "dataset_example_revisions",
-        sa.Column("id", sa.Integer, primary_key=True),
-        sa.Column(
-            "dataset_example_id",
-            sa.Integer,
-            sa.ForeignKey("dataset_examples.id", ondelete="CASCADE"),
-            nullable=False,
-            index=True,
-        ),
-        sa.Column(
-            "dataset_version_id",
-            sa.Integer,
-            sa.ForeignKey("dataset_versions.id", ondelete="CASCADE"),
-            nullable=False,
-            index=True,
-        ),
-        sa.Column("input", JSON_, nullable=False),
-        sa.Column("output", JSON_, nullable=False),
-        sa.Column("metadata", JSON_, nullable=False),
-        sa.Column(
-            "revision_kind",
-            sa.String,
-            sa.CheckConstraint(
-                "revision_kind IN ('CREATE', 'PATCH', 'DELETE')",
-                name="valid_revision_kind",
-            ),
-            nullable=False,
-        ),
-        sa.Column(
-            "created_at",
-            sa.TIMESTAMP(timezone=True),
-            nullable=False,
-            server_default=sa.func.now(),
-        ),
-        sa.UniqueConstraint(
-            "dataset_example_id",
-            "dataset_version_id",
-        ),
-    )
-    op.create_table(
-        "experiments",
-        sa.Column("id", sa.Integer, primary_key=True),
-        sa.Column(
-            "dataset_id",
-            sa.Integer,
-            sa.ForeignKey("datasets.id", ondelete="CASCADE"),
-            nullable=False,
-            index=True,
-        ),
-        sa.Column(
-            "dataset_version_id",
-            sa.Integer,
-            sa.ForeignKey("dataset_versions.id", ondelete="CASCADE"),
-            nullable=False,
-            index=True,
-        ),
-        sa.Column(
-            "name",
-            sa.String,
-            nullable=False,
-        ),
-        sa.Column(
-            "description",
-            sa.String,
-            nullable=True,
-        ),
-        sa.Column(
-            "repetitions",
-            sa.Integer,
-            nullable=False,
-        ),
-        sa.Column("metadata", JSON_, nullable=False),
-        sa.Column("project_name", sa.String, index=True),
-        sa.Column(
-            "created_at",
-            sa.TIMESTAMP(timezone=True),
-            nullable=False,
-            server_default=sa.func.now(),
-        ),
-        sa.Column(
-            "updated_at",
-            sa.TIMESTAMP(timezone=True),
-            nullable=False,
-            server_default=sa.func.now(),
-            onupdate=sa.func.now(),
-        ),
-    )
-    op.create_table(
-        "experiment_runs",
-        sa.Column("id", sa.Integer, primary_key=True),
-        sa.Column(
-            "experiment_id",
-            sa.Integer,
-            sa.ForeignKey("experiments.id", ondelete="CASCADE"),
-            nullable=False,
-            index=True,
-        ),
-        sa.Column(
-            "dataset_example_id",
-            sa.Integer,
-            sa.ForeignKey("dataset_examples.id", ondelete="CASCADE"),
-            nullable=False,
-            index=True,
-        ),
-        sa.Column(
-            "repetition_number",
-            sa.Integer,
-            nullable=False,
-        ),
-        sa.Column(
-            "trace_id",
-            sa.String,
-            nullable=True,
-        ),
-        sa.Column("output", JSON_, nullable=False),
-        sa.Column("start_time", sa.TIMESTAMP(timezone=True), nullable=False),
-        sa.Column("end_time", sa.TIMESTAMP(timezone=True), nullable=False),
-        sa.Column(
-            "prompt_token_count",
-            sa.Integer,
-            nullable=True,
-        ),
-        sa.Column(
-            "completion_token_count",
-            sa.Integer,
-            nullable=True,
-        ),
-        sa.Column(
-            "error",
-            sa.String,
-            nullable=True,
-        ),
-        sa.UniqueConstraint(
-            "experiment_id",
-            "dataset_example_id",
-            "repetition_number",
-        ),
-    )
-    op.create_table(
-        "experiment_run_annotations",
-        sa.Column("id", sa.Integer, primary_key=True),
-        sa.Column(
-            "experiment_run_id",
-            sa.Integer,
-            sa.ForeignKey("experiment_runs.id", ondelete="CASCADE"),
-            nullable=False,
-            index=True,
-        ),
-        sa.Column(
-            "name",
-            sa.String,
-            nullable=False,
-        ),
-        sa.Column(
-            "annotator_kind",
-            sa.String,
-            sa.CheckConstraint(
-                "annotator_kind IN ('LLM', 'CODE', 'HUMAN')",
-                name="valid_annotator_kind",
-            ),
-            nullable=False,
-        ),
-        sa.Column(
-            "label",
-            sa.String,
-            nullable=True,
-        ),
-        sa.Column(
-            "score",
-            sa.Float,
-            nullable=True,
-        ),
-        sa.Column(
-            "explanation",
-            sa.String,
-            nullable=True,
-        ),
-        sa.Column(
-            "trace_id",
-            sa.String,
-            nullable=True,
-        ),
-        sa.Column(
-            "error",
-            sa.String,
-            nullable=True,
-        ),
-        sa.Column("metadata", JSON_, nullable=False),
-        sa.Column("start_time", sa.TIMESTAMP(timezone=True), nullable=False),
-        sa.Column("end_time", sa.TIMESTAMP(timezone=True), nullable=False),
-        sa.UniqueConstraint(
-            "experiment_run_id",
-            "name",
-        ),
-    )
-
-
-def downgrade() -> None:
-    op.drop_table("experiment_run_annotations")
-    op.drop_table("experiment_runs")
-    op.drop_table("experiments")
-    op.drop_table("dataset_example_revisions")
-    op.drop_table("dataset_examples")
-    op.drop_table("dataset_versions")
-    op.drop_table("datasets")
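
A migration like the one removed above is driven by Alembic's revision graph: upgrade() runs when moving from cf03bd6bae1d to 10460e46d750, and downgrade() drops the seven tables in reverse dependency order. For context, a minimal sketch of applying or reverting such a revision programmatically; the alembic.ini path is an illustrative assumption, not taken from this package:

    # Hypothetical driver script using Alembic's command API.
    from alembic import command
    from alembic.config import Config

    cfg = Config("alembic.ini")  # assumed location of the Alembic config
    command.upgrade(cfg, "10460e46d750")    # create the dataset/experiment tables
    command.downgrade(cfg, "cf03bd6bae1d")  # revert to the prior revision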

--- a/phoenix/server/api/dataloaders/dataset_example_revisions.py
+++ /dev/null
@@ -1,100 +0,0 @@
-from typing import (
-    AsyncContextManager,
-    Callable,
-    List,
-    Optional,
-    Tuple,
-    Union,
-)
-
-from sqlalchemy import Integer, case, func, literal, or_, select, union
-from sqlalchemy.ext.asyncio import AsyncSession
-from strawberry.dataloader import DataLoader
-from typing_extensions import TypeAlias
-
-from phoenix.db import models
-from phoenix.server.api.types.DatasetExampleRevision import DatasetExampleRevision
-
-ExampleID: TypeAlias = int
-VersionID: TypeAlias = Optional[int]
-Key: TypeAlias = Tuple[ExampleID, Optional[VersionID]]
-Result: TypeAlias = DatasetExampleRevision
-
-
-class DatasetExampleRevisionsDataLoader(DataLoader[Key, Result]):
-    def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
-        super().__init__(load_fn=self._load_fn)
-        self._db = db
-
-    async def _load_fn(self, keys: List[Key]) -> List[Union[Result, ValueError]]:
-        # sqlalchemy has limited SQLite support for VALUES, so use UNION ALL instead.
-        # For details, see https://github.com/sqlalchemy/sqlalchemy/issues/7228
-        keys_subquery = union(
-            *(
-                select(
-                    literal(example_id, Integer).label("example_id"),
-                    literal(version_id, Integer).label("version_id"),
-                )
-                for example_id, version_id in keys
-            )
-        ).subquery()
-        revision_ids = (
-            select(
-                keys_subquery.c.example_id,
-                keys_subquery.c.version_id,
-                func.max(models.DatasetExampleRevision.id).label("revision_id"),
-            )
-            .select_from(keys_subquery)
-            .join(
-                models.DatasetExampleRevision,
-                onclause=keys_subquery.c.example_id
-                == models.DatasetExampleRevision.dataset_example_id,
-            )
-            .where(
-                or_(
-                    keys_subquery.c.version_id.is_(None),
-                    models.DatasetExampleRevision.dataset_version_id <= keys_subquery.c.version_id,
-                )
-            )
-            .group_by(keys_subquery.c.example_id, keys_subquery.c.version_id)
-        ).subquery()
-        query = (
-            select(
-                revision_ids.c.example_id,
-                revision_ids.c.version_id,
-                case(
-                    (
-                        or_(
-                            revision_ids.c.version_id.is_(None),
-                            models.DatasetVersion.id.is_not(None),
-                        ),
-                        True,
-                    ),
-                    else_=False,
-                ).label("is_valid_version"),  # check that non-null versions exist
-                models.DatasetExampleRevision,
-            )
-            .select_from(revision_ids)
-            .join(
-                models.DatasetExampleRevision,
-                onclause=revision_ids.c.revision_id == models.DatasetExampleRevision.id,
-            )
-            .join(
-                models.DatasetVersion,
-                onclause=revision_ids.c.version_id == models.DatasetVersion.id,
-                isouter=True,  # keep rows where the version id is null
-            )
-            .where(models.DatasetExampleRevision.revision_kind != "DELETE")
-        )
-        async with self._db() as session:
-            results = {
-                (example_id, version_id): DatasetExampleRevision.from_orm_revision(revision)
-                async for (
-                    example_id,
-                    version_id,
-                    is_valid_version,
-                    revision,
-                ) in await session.stream(query)
-                if is_valid_version
-            }
-        return [results.get(key, ValueError("Could not find revision.")) for key in keys]
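
The UNION of single-row SELECTs at the top of _load_fn stands in for a SQL VALUES clause, which SQLAlchemy could not reliably render for SQLite at the time (see the issue linked in the code comment). A self-contained sketch of the same trick, with illustrative key pairs that are not taken from the package:

    # Emits SELECT ... UNION SELECT ..., usable as a derived table of keys.
    from sqlalchemy import Integer, literal, select, union

    keys = [(1, 10), (2, None)]  # hypothetical (example_id, version_id) pairs
    keys_subquery = union(
        *(
            select(
                literal(example_id, Integer).label("example_id"),
                literal(version_id, Integer).label("version_id"),
            )
            for example_id, version_id in keys
        )
    ).subquery()
    # keys_subquery.c.example_id / .version_id can now be joined like a real table.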

--- a/phoenix/server/api/dataloaders/dataset_example_spans.py
+++ /dev/null
@@ -1,43 +0,0 @@
-from typing import (
-    AsyncContextManager,
-    Callable,
-    List,
-    Optional,
-)
-
-from sqlalchemy import select
-from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.orm import joinedload
-from strawberry.dataloader import DataLoader
-from typing_extensions import TypeAlias
-
-from phoenix.db import models
-
-ExampleID: TypeAlias = int
-Key: TypeAlias = ExampleID
-Result: TypeAlias = Optional[models.Span]
-
-
-class DatasetExampleSpansDataLoader(DataLoader[Key, Result]):
-    def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
-        super().__init__(load_fn=self._load_fn)
-        self._db = db
-
-    async def _load_fn(self, keys: List[Key]) -> List[Result]:
-        example_ids = keys
-        async with self._db() as session:
-            spans = {
-                example_id: span
-                async for example_id, span in await session.stream(
-                    select(models.DatasetExample.id, models.Span)
-                    .select_from(models.DatasetExample)
-                    .join(models.Span, models.DatasetExample.span_rowid == models.Span.id)
-                    .where(models.DatasetExample.id.in_(example_ids))
-                    .options(
-                        joinedload(models.Span.trace, innerjoin=True).load_only(
-                            models.Trace.trace_id
-                        )
-                    )
-                )
-            }
-        return [spans.get(example_id) for example_id in example_ids]

--- a/phoenix/server/api/dataloaders/experiment_annotation_summaries.py
+++ /dev/null
@@ -1,85 +0,0 @@
-from collections import defaultdict
-from dataclasses import dataclass
-from typing import (
-    AsyncContextManager,
-    Callable,
-    DefaultDict,
-    List,
-    Optional,
-)
-
-from sqlalchemy import func, select
-from sqlalchemy.ext.asyncio import AsyncSession
-from strawberry.dataloader import AbstractCache, DataLoader
-from typing_extensions import TypeAlias
-
-from phoenix.db import models
-
-
-@dataclass
-class ExperimentAnnotationSummary:
-    annotation_name: str
-    min_score: float
-    max_score: float
-    mean_score: float
-    count: int
-    error_count: int
-
-
-ExperimentID: TypeAlias = int
-Key: TypeAlias = ExperimentID
-Result: TypeAlias = List[ExperimentAnnotationSummary]
-
-
-class ExperimentAnnotationSummaryDataLoader(DataLoader[Key, Result]):
-    def __init__(
-        self,
-        db: Callable[[], AsyncContextManager[AsyncSession]],
-        cache_map: Optional[AbstractCache[Key, Result]] = None,
-    ) -> None:
-        super().__init__(load_fn=self._load_fn)
-        self._db = db
-
-    async def _load_fn(self, keys: List[Key]) -> List[Result]:
-        experiment_ids = keys
-        summaries: DefaultDict[ExperimentID, Result] = defaultdict(list)
-        async with self._db() as session:
-            async for (
-                experiment_id,
-                annotation_name,
-                min_score,
-                max_score,
-                mean_score,
-                count,
-                error_count,
-            ) in await session.stream(
-                select(
-                    models.ExperimentRun.experiment_id,
-                    models.ExperimentRunAnnotation.name,
-                    func.min(models.ExperimentRunAnnotation.score),
-                    func.max(models.ExperimentRunAnnotation.score),
-                    func.avg(models.ExperimentRunAnnotation.score),
-                    func.count(),
-                    func.count(models.ExperimentRunAnnotation.error),
-                )
-                .join(
-                    models.ExperimentRun,
-                    models.ExperimentRunAnnotation.experiment_run_id == models.ExperimentRun.id,
-                )
-                .where(models.ExperimentRun.experiment_id.in_(experiment_ids))
-                .group_by(models.ExperimentRun.experiment_id, models.ExperimentRunAnnotation.name)
-            ):
-                summaries[experiment_id].append(
-                    ExperimentAnnotationSummary(
-                        annotation_name=annotation_name,
-                        min_score=min_score,
-                        max_score=max_score,
-                        mean_score=mean_score,
-                        count=count,
-                        error_count=error_count,
-                    )
-                )
-        return [
-            sorted(summaries[experiment_id], key=lambda summary: summary.annotation_name)
-            for experiment_id in experiment_ids
-        ]

--- a/phoenix/server/api/dataloaders/experiment_error_rates.py
+++ /dev/null
@@ -1,43 +0,0 @@
-from typing import (
-    AsyncContextManager,
-    Callable,
-    List,
-    Optional,
-)
-
-from sqlalchemy import func, select
-from sqlalchemy.ext.asyncio import AsyncSession
-from strawberry.dataloader import DataLoader
-from typing_extensions import TypeAlias
-
-from phoenix.db import models
-
-ExperimentID: TypeAlias = int
-ErrorRate: TypeAlias = float
-Key: TypeAlias = ExperimentID
-Result: TypeAlias = Optional[ErrorRate]
-
-
-class ExperimentErrorRatesDataLoader(DataLoader[Key, Result]):
-    def __init__(
-        self,
-        db: Callable[[], AsyncContextManager[AsyncSession]],
-    ) -> None:
-        super().__init__(load_fn=self._load_fn)
-        self._db = db
-
-    async def _load_fn(self, keys: List[Key]) -> List[Result]:
-        experiment_ids = keys
-        async with self._db() as session:
-            error_rates = {
-                experiment_id: error_rate
-                async for experiment_id, error_rate in await session.stream(
-                    select(
-                        models.ExperimentRun.experiment_id,
-                        func.count(models.ExperimentRun.error) / func.count(),
-                    )
-                    .group_by(models.ExperimentRun.experiment_id)
-                    .where(models.ExperimentRun.experiment_id.in_(experiment_ids))
-                )
-            }
-        return [error_rates.get(experiment_id) for experiment_id in experiment_ids]
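
The ratio above leans on SQL COUNT semantics: COUNT(error) counts only rows where error is non-NULL, while COUNT(*) counts every row, so the quotient is the fraction of runs that recorded an error. Note that some backends (e.g. SQLite) truncate integer division, so a portable version would cast one operand to a float. A minimal sketch with a hypothetical table standing in for experiment_runs:

    from sqlalchemy import Column, Float, Integer, MetaData, String, Table, cast, func, select

    runs = Table(  # hypothetical stand-in for the experiment_runs table
        "runs",
        MetaData(),
        Column("id", Integer, primary_key=True),
        Column("error", String, nullable=True),  # NULL means the run succeeded
    )
    # COUNT(error) skips NULLs; the cast avoids integer truncation.
    error_rate = select(cast(func.count(runs.c.error), Float) / func.count()).select_from(runs)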

--- a/phoenix/server/api/dataloaders/experiment_run_counts.py
+++ /dev/null
@@ -1,42 +0,0 @@
-from typing import (
-    AsyncContextManager,
-    Callable,
-    List,
-)
-
-from sqlalchemy import func, select
-from sqlalchemy.ext.asyncio import AsyncSession
-from strawberry.dataloader import DataLoader
-from typing_extensions import TypeAlias
-
-from phoenix.db import models
-
-ExperimentID: TypeAlias = int
-RunCount: TypeAlias = int
-Key: TypeAlias = ExperimentID
-Result: TypeAlias = RunCount
-
-
-class ExperimentRunCountsDataLoader(DataLoader[Key, Result]):
-    def __init__(
-        self,
-        db: Callable[[], AsyncContextManager[AsyncSession]],
-    ) -> None:
-        super().__init__(load_fn=self._load_fn)
-        self._db = db
-
-    async def _load_fn(self, keys: List[Key]) -> List[Result]:
-        experiment_ids = keys
-        async with self._db() as session:
-            run_counts = {
-                experiment_id: run_count
-                async for experiment_id, run_count in await session.stream(
-                    select(models.ExperimentRun.experiment_id, func.count())
-                    .where(models.ExperimentRun.experiment_id.in_(set(experiment_ids)))
-                    .group_by(models.ExperimentRun.experiment_id)
-                )
-            }
-        return [
-            run_counts.get(experiment_id, ValueError(f"Unknown experiment: {experiment_id}"))
-            for experiment_id in experiment_ids
-        ]

--- a/phoenix/server/api/dataloaders/experiment_sequence_number.py
+++ /dev/null
@@ -1,49 +0,0 @@
-from typing import (
-    AsyncContextManager,
-    Callable,
-    List,
-    Optional,
-)
-
-from sqlalchemy import distinct, func, select
-from sqlalchemy.ext.asyncio import AsyncSession
-from strawberry.dataloader import DataLoader
-from typing_extensions import TypeAlias
-
-from phoenix.db import models
-
-ExperimentId: TypeAlias = int
-Key: TypeAlias = ExperimentId
-Result: TypeAlias = Optional[int]
-
-
-class ExperimentSequenceNumberDataLoader(DataLoader[Key, Result]):
-    def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
-        super().__init__(load_fn=self._load_fn)
-        self._db = db
-
-    async def _load_fn(self, keys: List[Key]) -> List[Result]:
-        experiment_ids = keys
-        dataset_ids = (
-            select(distinct(models.Experiment.dataset_id))
-            .where(models.Experiment.id.in_(experiment_ids))
-            .scalar_subquery()
-        )
-        row_number = (
-            func.row_number().over(
-                partition_by=models.Experiment.dataset_id,
-                order_by=models.Experiment.id,
-            )
-        ).label("row_number")
-        subq = (
-            select(models.Experiment.id, row_number)
-            .where(models.Experiment.dataset_id.in_(dataset_ids))
-            .subquery()
-        )
-        stmt = select(subq).where(subq.c.id.in_(experiment_ids))
-        async with self._db() as session:
-            result = {
-                experiment_id: sequence_number
-                async for experiment_id, sequence_number in await session.stream(stmt)
-            }
-        return [result.get(experiment_id) for experiment_id in experiment_ids]

--- a/phoenix/server/api/dataloaders/project_by_name.py
+++ /dev/null
@@ -1,31 +0,0 @@
-from collections import defaultdict
-from typing import AsyncContextManager, Callable, DefaultDict, List, Optional
-
-from sqlalchemy import select
-from sqlalchemy.ext.asyncio import AsyncSession
-from strawberry.dataloader import DataLoader
-from typing_extensions import TypeAlias
-
-from phoenix.db import models
-
-ProjectName: TypeAlias = str
-Key: TypeAlias = ProjectName
-Result: TypeAlias = Optional[models.Project]
-
-
-class ProjectByNameDataLoader(DataLoader[Key, Result]):
-    def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
-        super().__init__(load_fn=self._load_fn)
-        self._db = db
-
-    async def _load_fn(self, keys: List[Key]) -> List[Result]:
-        project_names = list(set(keys))
-        projects_by_name: DefaultDict[Key, Result] = defaultdict(None)
-        async with self._db() as session:
-            data = await session.stream_scalars(
-                select(models.Project).where(models.Project.name.in_(project_names))
-            )
-            async for project in data:
-                projects_by_name[project.name] = project
-
-        return [projects_by_name[project_name] for project_name in project_names]
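
All of the removed loaders follow the same strawberry DataLoader contract: the batch function receives the keys collected during one event-loop tick and must return one result per key, in the same order (several loaders above place a ValueError in a slot to signal a per-key failure, which strawberry raises for that load call). A minimal, self-contained sketch of the pattern with a toy batch function:

    import asyncio
    from typing import List

    from strawberry.dataloader import DataLoader

    async def load_lengths(keys: List[str]) -> List[int]:
        # Called once per batch; must return one result per key, in order.
        return [len(key) for key in keys]

    async def main() -> None:
        loader = DataLoader(load_fn=load_lengths)
        # Both loads below are coalesced into a single load_lengths([...]) call.
        a, b = await asyncio.gather(loader.load("spam"), loader.load("eggs!"))
        print(a, b)  # 4 5

    asyncio.run(main())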