arize-phoenix 12.8.0__py3-none-any.whl → 12.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/METADATA +3 -1
- {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/RECORD +70 -67
- phoenix/config.py +131 -9
- phoenix/db/engines.py +127 -14
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/pg_config.py +10 -0
- phoenix/server/api/context.py +23 -0
- phoenix/server/api/dataloaders/__init__.py +6 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +0 -2
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/span_costs.py +3 -9
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/helpers/playground_clients.py +3 -3
- phoenix/server/api/input_types/PromptVersionInput.py +47 -1
- phoenix/server/api/mutations/annotation_config_mutations.py +2 -2
- phoenix/server/api/mutations/api_key_mutations.py +2 -15
- phoenix/server/api/mutations/chat_mutations.py +3 -2
- phoenix/server/api/mutations/dataset_label_mutations.py +12 -6
- phoenix/server/api/mutations/dataset_mutations.py +8 -8
- phoenix/server/api/mutations/dataset_split_mutations.py +13 -9
- phoenix/server/api/mutations/model_mutations.py +4 -4
- phoenix/server/api/mutations/project_session_annotations_mutations.py +4 -7
- phoenix/server/api/mutations/prompt_label_mutations.py +3 -3
- phoenix/server/api/mutations/prompt_mutations.py +24 -117
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +8 -5
- phoenix/server/api/mutations/span_annotations_mutations.py +10 -5
- phoenix/server/api/mutations/trace_annotations_mutations.py +9 -4
- phoenix/server/api/mutations/user_mutations.py +4 -4
- phoenix/server/api/queries.py +65 -210
- phoenix/server/api/subscriptions.py +4 -4
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/Dataset.py +88 -48
- phoenix/server/api/types/DatasetExample.py +34 -30
- phoenix/server/api/types/DatasetLabel.py +47 -13
- phoenix/server/api/types/DatasetSplit.py +87 -21
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +182 -62
- phoenix/server/api/types/Experiment.py +146 -55
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +10 -1
- phoenix/server/api/types/ExperimentRun.py +118 -61
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +95 -42
- phoenix/server/api/types/ModelInterface.py +7 -2
- phoenix/server/api/types/PlaygroundModel.py +12 -2
- phoenix/server/api/types/Project.py +70 -75
- phoenix/server/api/types/ProjectSession.py +69 -37
- phoenix/server/api/types/ProjectSessionAnnotation.py +166 -47
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +82 -44
- phoenix/server/api/types/PromptLabel.py +47 -13
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/Span.py +116 -115
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/Trace.py +45 -44
- phoenix/server/api/types/TraceAnnotation.py +144 -48
- phoenix/server/api/types/User.py +103 -33
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/app.py +29 -0
- phoenix/server/static/.vite/manifest.json +9 -9
- phoenix/server/static/assets/{components-Bem6_7MW.js → components-v927s3NF.js} +427 -397
- phoenix/server/static/assets/{index-NdiXbuNL.js → index-DrD9eSrN.js} +9 -5
- phoenix/server/static/assets/{pages-CEJgMVKU.js → pages-GVybXa_W.js} +489 -486
- phoenix/version.py +1 -1
- {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from collections.abc import AsyncIterable
|
|
2
2
|
from datetime import datetime
|
|
3
|
-
from typing import
|
|
3
|
+
from typing import Optional, cast
|
|
4
4
|
|
|
5
5
|
import strawberry
|
|
6
6
|
from sqlalchemy import Text, and_, func, or_, select
|
|
@@ -18,8 +18,8 @@ from phoenix.server.api.types.DatasetExample import DatasetExample
|
|
|
18
18
|
from phoenix.server.api.types.DatasetExperimentAnnotationSummary import (
|
|
19
19
|
DatasetExperimentAnnotationSummary,
|
|
20
20
|
)
|
|
21
|
-
from phoenix.server.api.types.DatasetLabel import DatasetLabel
|
|
22
|
-
from phoenix.server.api.types.DatasetSplit import DatasetSplit
|
|
21
|
+
from phoenix.server.api.types.DatasetLabel import DatasetLabel
|
|
22
|
+
from phoenix.server.api.types.DatasetSplit import DatasetSplit
|
|
23
23
|
from phoenix.server.api.types.DatasetVersion import DatasetVersion
|
|
24
24
|
from phoenix.server.api.types.Experiment import Experiment, to_gql_experiment
|
|
25
25
|
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
@@ -33,13 +33,77 @@ from phoenix.server.api.types.SortDir import SortDir
|
|
|
33
33
|
|
|
34
34
|
@strawberry.type
|
|
35
35
|
class Dataset(Node):
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
36
|
+
id: NodeID[int]
|
|
37
|
+
db_record: strawberry.Private[Optional[models.Dataset]] = None
|
|
38
|
+
|
|
39
|
+
def __post_init__(self) -> None:
|
|
40
|
+
if self.db_record and self.id != self.db_record.id:
|
|
41
|
+
raise ValueError("Dataset ID mismatch")
|
|
42
|
+
|
|
43
|
+
@strawberry.field
|
|
44
|
+
async def name(
|
|
45
|
+
self,
|
|
46
|
+
info: Info[Context, None],
|
|
47
|
+
) -> str:
|
|
48
|
+
if self.db_record:
|
|
49
|
+
val = self.db_record.name
|
|
50
|
+
else:
|
|
51
|
+
val = await info.context.data_loaders.dataset_fields.load(
|
|
52
|
+
(self.id, models.Dataset.name),
|
|
53
|
+
)
|
|
54
|
+
return val
|
|
55
|
+
|
|
56
|
+
@strawberry.field
|
|
57
|
+
async def description(
|
|
58
|
+
self,
|
|
59
|
+
info: Info[Context, None],
|
|
60
|
+
) -> Optional[str]:
|
|
61
|
+
if self.db_record:
|
|
62
|
+
val = self.db_record.description
|
|
63
|
+
else:
|
|
64
|
+
val = await info.context.data_loaders.dataset_fields.load(
|
|
65
|
+
(self.id, models.Dataset.description),
|
|
66
|
+
)
|
|
67
|
+
return val
|
|
68
|
+
|
|
69
|
+
@strawberry.field
|
|
70
|
+
async def metadata(
|
|
71
|
+
self,
|
|
72
|
+
info: Info[Context, None],
|
|
73
|
+
) -> JSON:
|
|
74
|
+
if self.db_record:
|
|
75
|
+
val = self.db_record.metadata_
|
|
76
|
+
else:
|
|
77
|
+
val = await info.context.data_loaders.dataset_fields.load(
|
|
78
|
+
(self.id, models.Dataset.metadata_),
|
|
79
|
+
)
|
|
80
|
+
return val
|
|
81
|
+
|
|
82
|
+
@strawberry.field
|
|
83
|
+
async def created_at(
|
|
84
|
+
self,
|
|
85
|
+
info: Info[Context, None],
|
|
86
|
+
) -> datetime:
|
|
87
|
+
if self.db_record:
|
|
88
|
+
val = self.db_record.created_at
|
|
89
|
+
else:
|
|
90
|
+
val = await info.context.data_loaders.dataset_fields.load(
|
|
91
|
+
(self.id, models.Dataset.created_at),
|
|
92
|
+
)
|
|
93
|
+
return val
|
|
94
|
+
|
|
95
|
+
@strawberry.field
|
|
96
|
+
async def updated_at(
|
|
97
|
+
self,
|
|
98
|
+
info: Info[Context, None],
|
|
99
|
+
) -> datetime:
|
|
100
|
+
if self.db_record:
|
|
101
|
+
val = self.db_record.updated_at
|
|
102
|
+
else:
|
|
103
|
+
val = await info.context.data_loaders.dataset_fields.load(
|
|
104
|
+
(self.id, models.Dataset.updated_at),
|
|
105
|
+
)
|
|
106
|
+
return val
|
|
43
107
|
|
|
44
108
|
@strawberry.field
|
|
45
109
|
async def versions(
|
|
@@ -58,7 +122,7 @@ class Dataset(Node):
|
|
|
58
122
|
before=before if isinstance(before, CursorString) else None,
|
|
59
123
|
)
|
|
60
124
|
async with info.context.db() as session:
|
|
61
|
-
stmt = select(models.DatasetVersion).filter_by(dataset_id=self.
|
|
125
|
+
stmt = select(models.DatasetVersion).filter_by(dataset_id=self.id)
|
|
62
126
|
if sort:
|
|
63
127
|
# For now assume the the column names match 1:1 with the enum values
|
|
64
128
|
sort_col = getattr(models.DatasetVersion, sort.col.value)
|
|
@@ -69,15 +133,7 @@ class Dataset(Node):
|
|
|
69
133
|
else:
|
|
70
134
|
stmt = stmt.order_by(models.DatasetVersion.created_at.desc())
|
|
71
135
|
versions = await session.scalars(stmt)
|
|
72
|
-
data = [
|
|
73
|
-
DatasetVersion(
|
|
74
|
-
id_attr=version.id,
|
|
75
|
-
description=version.description,
|
|
76
|
-
metadata=version.metadata_,
|
|
77
|
-
created_at=version.created_at,
|
|
78
|
-
)
|
|
79
|
-
for version in versions
|
|
80
|
-
]
|
|
136
|
+
data = [DatasetVersion(id=version.id, db_record=version) for version in versions]
|
|
81
137
|
return connection_from_list(data=data, args=args)
|
|
82
138
|
|
|
83
139
|
@strawberry.field(
|
|
@@ -90,7 +146,7 @@ class Dataset(Node):
|
|
|
90
146
|
dataset_version_id: Optional[GlobalID] = UNSET,
|
|
91
147
|
split_ids: Optional[list[GlobalID]] = UNSET,
|
|
92
148
|
) -> int:
|
|
93
|
-
dataset_id = self.
|
|
149
|
+
dataset_id = self.id
|
|
94
150
|
version_id = (
|
|
95
151
|
from_global_id_with_expected_type(
|
|
96
152
|
global_id=dataset_version_id,
|
|
@@ -180,7 +236,7 @@ class Dataset(Node):
|
|
|
180
236
|
last=last,
|
|
181
237
|
before=before if isinstance(before, CursorString) else None,
|
|
182
238
|
)
|
|
183
|
-
dataset_id = self.
|
|
239
|
+
dataset_id = self.id
|
|
184
240
|
version_id = (
|
|
185
241
|
from_global_id_with_expected_type(
|
|
186
242
|
global_id=dataset_version_id, expected_type_name=DatasetVersion.__name__
|
|
@@ -261,9 +317,9 @@ class Dataset(Node):
|
|
|
261
317
|
async with info.context.db() as session:
|
|
262
318
|
dataset_examples = [
|
|
263
319
|
DatasetExample(
|
|
264
|
-
|
|
320
|
+
id=example.id,
|
|
321
|
+
db_record=example,
|
|
265
322
|
version_id=version_id,
|
|
266
|
-
created_at=example.created_at,
|
|
267
323
|
)
|
|
268
324
|
async for example in await session.stream_scalars(query)
|
|
269
325
|
]
|
|
@@ -272,8 +328,8 @@ class Dataset(Node):
|
|
|
272
328
|
@strawberry.field
|
|
273
329
|
async def splits(self, info: Info[Context, None]) -> list[DatasetSplit]:
|
|
274
330
|
return [
|
|
275
|
-
|
|
276
|
-
for split in await info.context.data_loaders.dataset_dataset_splits.load(self.
|
|
331
|
+
DatasetSplit(id=split.id, db_record=split)
|
|
332
|
+
for split in await info.context.data_loaders.dataset_dataset_splits.load(self.id)
|
|
277
333
|
]
|
|
278
334
|
|
|
279
335
|
@strawberry.field(
|
|
@@ -285,9 +341,7 @@ class Dataset(Node):
|
|
|
285
341
|
info: Info[Context, None],
|
|
286
342
|
dataset_version_id: Optional[GlobalID] = UNSET,
|
|
287
343
|
) -> int:
|
|
288
|
-
stmt = select(count(models.Experiment.id)).where(
|
|
289
|
-
models.Experiment.dataset_id == self.id_attr
|
|
290
|
-
)
|
|
344
|
+
stmt = select(count(models.Experiment.id)).where(models.Experiment.dataset_id == self.id)
|
|
291
345
|
version_id = (
|
|
292
346
|
from_global_id_with_expected_type(
|
|
293
347
|
global_id=dataset_version_id,
|
|
@@ -320,7 +374,7 @@ class Dataset(Node):
|
|
|
320
374
|
last=last,
|
|
321
375
|
before=before if isinstance(before, CursorString) else None,
|
|
322
376
|
)
|
|
323
|
-
dataset_id = self.
|
|
377
|
+
dataset_id = self.id
|
|
324
378
|
row_number = func.row_number().over(order_by=models.Experiment.id).label("row_number")
|
|
325
379
|
query = (
|
|
326
380
|
select(models.Experiment, row_number)
|
|
@@ -363,7 +417,7 @@ class Dataset(Node):
|
|
|
363
417
|
async def experiment_annotation_summaries(
|
|
364
418
|
self, info: Info[Context, None]
|
|
365
419
|
) -> list[DatasetExperimentAnnotationSummary]:
|
|
366
|
-
dataset_id = self.
|
|
420
|
+
dataset_id = self.id
|
|
367
421
|
query = (
|
|
368
422
|
select(
|
|
369
423
|
models.ExperimentRunAnnotation.name.label("annotation_name"),
|
|
@@ -396,24 +450,10 @@ class Dataset(Node):
|
|
|
396
450
|
@strawberry.field
|
|
397
451
|
async def labels(self, info: Info[Context, None]) -> list[DatasetLabel]:
|
|
398
452
|
return [
|
|
399
|
-
|
|
400
|
-
for label in await info.context.data_loaders.dataset_labels.load(self.
|
|
453
|
+
DatasetLabel(id=label.id, db_record=label)
|
|
454
|
+
for label in await info.context.data_loaders.dataset_labels.load(self.id)
|
|
401
455
|
]
|
|
402
456
|
|
|
403
457
|
@strawberry.field
|
|
404
458
|
def last_updated_at(self, info: Info[Context, None]) -> Optional[datetime]:
|
|
405
|
-
return info.context.last_updated_at.get(
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
def to_gql_dataset(dataset: models.Dataset) -> Dataset:
|
|
409
|
-
"""
|
|
410
|
-
Converts an ORM dataset to a GraphQL dataset.
|
|
411
|
-
"""
|
|
412
|
-
return Dataset(
|
|
413
|
-
id_attr=dataset.id,
|
|
414
|
-
name=dataset.name,
|
|
415
|
-
description=dataset.description,
|
|
416
|
-
metadata=dataset.metadata_,
|
|
417
|
-
created_at=dataset.created_at,
|
|
418
|
-
updated_at=dataset.updated_at,
|
|
419
|
-
)
|
|
459
|
+
return info.context.last_updated_at.get(models.Dataset, self.id)
|
|
@@ -1,9 +1,8 @@
|
|
|
1
1
|
from datetime import datetime
|
|
2
|
-
from typing import Optional
|
|
2
|
+
from typing import TYPE_CHECKING, Annotated, Optional
|
|
3
3
|
|
|
4
4
|
import strawberry
|
|
5
5
|
from sqlalchemy import select
|
|
6
|
-
from sqlalchemy.orm import joinedload
|
|
7
6
|
from strawberry import UNSET
|
|
8
7
|
from strawberry.relay.types import Connection, GlobalID, Node, NodeID
|
|
9
8
|
from strawberry.types import Info
|
|
@@ -12,34 +11,49 @@ from phoenix.db import models
|
|
|
12
11
|
from phoenix.server.api.context import Context
|
|
13
12
|
from phoenix.server.api.exceptions import BadRequest
|
|
14
13
|
from phoenix.server.api.types.DatasetExampleRevision import DatasetExampleRevision
|
|
15
|
-
from phoenix.server.api.types.DatasetSplit import DatasetSplit
|
|
14
|
+
from phoenix.server.api.types.DatasetSplit import DatasetSplit
|
|
16
15
|
from phoenix.server.api.types.DatasetVersion import DatasetVersion
|
|
17
16
|
from phoenix.server.api.types.ExperimentRepeatedRunGroup import (
|
|
18
17
|
ExperimentRepeatedRunGroup,
|
|
19
18
|
)
|
|
20
|
-
from phoenix.server.api.types.ExperimentRun import ExperimentRun
|
|
19
|
+
from phoenix.server.api.types.ExperimentRun import ExperimentRun
|
|
21
20
|
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
22
21
|
from phoenix.server.api.types.pagination import (
|
|
23
22
|
ConnectionArgs,
|
|
24
23
|
CursorString,
|
|
25
24
|
connection_from_list,
|
|
26
25
|
)
|
|
27
|
-
|
|
26
|
+
|
|
27
|
+
if TYPE_CHECKING:
|
|
28
|
+
from .Span import Span
|
|
28
29
|
|
|
29
30
|
|
|
30
31
|
@strawberry.type
|
|
31
32
|
class DatasetExample(Node):
|
|
32
|
-
|
|
33
|
-
|
|
33
|
+
id: NodeID[int]
|
|
34
|
+
db_record: strawberry.Private[Optional[models.DatasetExample]] = None
|
|
34
35
|
version_id: strawberry.Private[Optional[int]] = None
|
|
35
36
|
|
|
37
|
+
def __post_init__(self) -> None:
|
|
38
|
+
if self.db_record and self.id != self.db_record.id:
|
|
39
|
+
raise ValueError("DatasetExample ID mismatch")
|
|
40
|
+
|
|
41
|
+
@strawberry.field
|
|
42
|
+
async def created_at(self, info: Info[Context, None]) -> datetime:
|
|
43
|
+
if self.db_record:
|
|
44
|
+
val = self.db_record.created_at
|
|
45
|
+
else:
|
|
46
|
+
val = await info.context.data_loaders.dataset_example_fields.load(
|
|
47
|
+
(self.id, models.DatasetExample.created_at),
|
|
48
|
+
)
|
|
49
|
+
return val
|
|
50
|
+
|
|
36
51
|
@strawberry.field
|
|
37
52
|
async def revision(
|
|
38
53
|
self,
|
|
39
54
|
info: Info[Context, None],
|
|
40
55
|
dataset_version_id: Optional[GlobalID] = UNSET,
|
|
41
56
|
) -> DatasetExampleRevision:
|
|
42
|
-
example_id = self.id_attr
|
|
43
57
|
version_id: Optional[int] = None
|
|
44
58
|
if dataset_version_id:
|
|
45
59
|
version_id = from_global_id_with_expected_type(
|
|
@@ -47,18 +61,18 @@ class DatasetExample(Node):
|
|
|
47
61
|
)
|
|
48
62
|
elif self.version_id is not None:
|
|
49
63
|
version_id = self.version_id
|
|
50
|
-
return await info.context.data_loaders.dataset_example_revisions.load(
|
|
51
|
-
(example_id, version_id)
|
|
52
|
-
)
|
|
64
|
+
return await info.context.data_loaders.dataset_example_revisions.load((self.id, version_id))
|
|
53
65
|
|
|
54
66
|
@strawberry.field
|
|
55
67
|
async def span(
|
|
56
68
|
self,
|
|
57
69
|
info: Info[Context, None],
|
|
58
|
-
) -> Optional[Span]:
|
|
70
|
+
) -> Optional[Annotated["Span", strawberry.lazy(".Span")]]:
|
|
71
|
+
from .Span import Span
|
|
72
|
+
|
|
59
73
|
return (
|
|
60
|
-
Span(
|
|
61
|
-
if (span := await info.context.data_loaders.dataset_example_spans.load(self.
|
|
74
|
+
Span(id=span.id, db_record=span)
|
|
75
|
+
if (span := await info.context.data_loaders.dataset_example_spans.load(self.id))
|
|
62
76
|
else None
|
|
63
77
|
)
|
|
64
78
|
|
|
@@ -78,12 +92,10 @@ class DatasetExample(Node):
|
|
|
78
92
|
last=last,
|
|
79
93
|
before=before if isinstance(before, CursorString) else None,
|
|
80
94
|
)
|
|
81
|
-
example_id = self.id_attr
|
|
82
95
|
query = (
|
|
83
96
|
select(models.ExperimentRun)
|
|
84
|
-
.options(joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id))
|
|
85
97
|
.join(models.Experiment, models.Experiment.id == models.ExperimentRun.experiment_id)
|
|
86
|
-
.where(models.ExperimentRun.dataset_example_id ==
|
|
98
|
+
.where(models.ExperimentRun.dataset_example_id == self.id)
|
|
87
99
|
.order_by(
|
|
88
100
|
models.ExperimentRun.experiment_id.asc(),
|
|
89
101
|
models.ExperimentRun.repetition_number.asc(),
|
|
@@ -100,7 +112,7 @@ class DatasetExample(Node):
|
|
|
100
112
|
query = query.where(models.ExperimentRun.experiment_id.in_(experiment_db_ids))
|
|
101
113
|
async with info.context.db() as session:
|
|
102
114
|
runs = (await session.scalars(query)).all()
|
|
103
|
-
return connection_from_list([
|
|
115
|
+
return connection_from_list([ExperimentRun(id=run.id, db_record=run) for run in runs], args)
|
|
104
116
|
|
|
105
117
|
@strawberry.field
|
|
106
118
|
async def experiment_repeated_run_groups(
|
|
@@ -108,7 +120,6 @@ class DatasetExample(Node):
|
|
|
108
120
|
info: Info[Context, None],
|
|
109
121
|
experiment_ids: list[GlobalID],
|
|
110
122
|
) -> list[ExperimentRepeatedRunGroup]:
|
|
111
|
-
example_rowid = self.id_attr
|
|
112
123
|
experiment_rowids = []
|
|
113
124
|
for experiment_id in experiment_ids:
|
|
114
125
|
try:
|
|
@@ -121,14 +132,14 @@ class DatasetExample(Node):
|
|
|
121
132
|
experiment_rowids.append(experiment_rowid)
|
|
122
133
|
repeated_run_groups = (
|
|
123
134
|
await info.context.data_loaders.experiment_repeated_run_groups.load_many(
|
|
124
|
-
[(experiment_rowid,
|
|
135
|
+
[(experiment_rowid, self.id) for experiment_rowid in experiment_rowids]
|
|
125
136
|
)
|
|
126
137
|
)
|
|
127
138
|
return [
|
|
128
139
|
ExperimentRepeatedRunGroup(
|
|
129
140
|
experiment_rowid=group.experiment_rowid,
|
|
130
141
|
dataset_example_rowid=group.dataset_example_rowid,
|
|
131
|
-
|
|
142
|
+
cached_runs=[ExperimentRun(id=run.id, db_record=run) for run in group.runs],
|
|
132
143
|
)
|
|
133
144
|
for group in repeated_run_groups
|
|
134
145
|
]
|
|
@@ -139,13 +150,6 @@ class DatasetExample(Node):
|
|
|
139
150
|
info: Info[Context, None],
|
|
140
151
|
) -> list[DatasetSplit]:
|
|
141
152
|
return [
|
|
142
|
-
|
|
143
|
-
for split in await info.context.data_loaders.dataset_example_splits.load(self.
|
|
153
|
+
DatasetSplit(id=split.id, db_record=split)
|
|
154
|
+
for split in await info.context.data_loaders.dataset_example_splits.load(self.id)
|
|
144
155
|
]
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
def to_gql_dataset_example(example: models.DatasetExample) -> DatasetExample:
|
|
148
|
-
return DatasetExample(
|
|
149
|
-
id_attr=example.id,
|
|
150
|
-
created_at=example.created_at,
|
|
151
|
-
)
|
|
@@ -2,22 +2,56 @@ from typing import Optional
|
|
|
2
2
|
|
|
3
3
|
import strawberry
|
|
4
4
|
from strawberry.relay import Node, NodeID
|
|
5
|
+
from strawberry.types import Info
|
|
5
6
|
|
|
6
7
|
from phoenix.db import models
|
|
8
|
+
from phoenix.server.api.context import Context
|
|
7
9
|
|
|
8
10
|
|
|
9
11
|
@strawberry.type
|
|
10
12
|
class DatasetLabel(Node):
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
13
|
+
id: NodeID[int]
|
|
14
|
+
db_record: strawberry.Private[Optional[models.DatasetLabel]] = None
|
|
15
|
+
|
|
16
|
+
def __post_init__(self) -> None:
|
|
17
|
+
if self.db_record and self.id != self.db_record.id:
|
|
18
|
+
raise ValueError("DatasetLabel ID mismatch")
|
|
19
|
+
|
|
20
|
+
@strawberry.field
|
|
21
|
+
async def name(
|
|
22
|
+
self,
|
|
23
|
+
info: Info[Context, None],
|
|
24
|
+
) -> str:
|
|
25
|
+
if self.db_record:
|
|
26
|
+
val = self.db_record.name
|
|
27
|
+
else:
|
|
28
|
+
val = await info.context.data_loaders.dataset_label_fields.load(
|
|
29
|
+
(self.id, models.DatasetLabel.name),
|
|
30
|
+
)
|
|
31
|
+
return val
|
|
32
|
+
|
|
33
|
+
@strawberry.field
|
|
34
|
+
async def description(
|
|
35
|
+
self,
|
|
36
|
+
info: Info[Context, None],
|
|
37
|
+
) -> Optional[str]:
|
|
38
|
+
if self.db_record:
|
|
39
|
+
val = self.db_record.description
|
|
40
|
+
else:
|
|
41
|
+
val = await info.context.data_loaders.dataset_label_fields.load(
|
|
42
|
+
(self.id, models.DatasetLabel.description),
|
|
43
|
+
)
|
|
44
|
+
return val
|
|
45
|
+
|
|
46
|
+
@strawberry.field
|
|
47
|
+
async def color(
|
|
48
|
+
self,
|
|
49
|
+
info: Info[Context, None],
|
|
50
|
+
) -> str:
|
|
51
|
+
if self.db_record:
|
|
52
|
+
val = self.db_record.color
|
|
53
|
+
else:
|
|
54
|
+
val = await info.context.data_loaders.dataset_label_fields.load(
|
|
55
|
+
(self.id, models.DatasetLabel.color),
|
|
56
|
+
)
|
|
57
|
+
return val
|
|
@@ -1,32 +1,98 @@
|
|
|
1
1
|
from datetime import datetime
|
|
2
|
-
from typing import
|
|
2
|
+
from typing import Optional
|
|
3
3
|
|
|
4
4
|
import strawberry
|
|
5
5
|
from strawberry.relay import Node, NodeID
|
|
6
6
|
from strawberry.scalars import JSON
|
|
7
|
+
from strawberry.types import Info
|
|
7
8
|
|
|
8
9
|
from phoenix.db import models
|
|
10
|
+
from phoenix.server.api.context import Context
|
|
9
11
|
|
|
10
12
|
|
|
11
13
|
@strawberry.type
|
|
12
14
|
class DatasetSplit(Node):
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
15
|
+
id: NodeID[int]
|
|
16
|
+
db_record: strawberry.Private[Optional[models.DatasetSplit]] = None
|
|
17
|
+
|
|
18
|
+
def __post_init__(self) -> None:
|
|
19
|
+
if self.db_record and self.id != self.db_record.id:
|
|
20
|
+
raise ValueError("DatasetSplit ID mismatch")
|
|
21
|
+
|
|
22
|
+
@strawberry.field
|
|
23
|
+
async def name(
|
|
24
|
+
self,
|
|
25
|
+
info: Info[Context, None],
|
|
26
|
+
) -> str:
|
|
27
|
+
if self.db_record:
|
|
28
|
+
val = self.db_record.name
|
|
29
|
+
else:
|
|
30
|
+
val = await info.context.data_loaders.dataset_split_fields.load(
|
|
31
|
+
(self.id, models.DatasetSplit.name),
|
|
32
|
+
)
|
|
33
|
+
return val
|
|
34
|
+
|
|
35
|
+
@strawberry.field
|
|
36
|
+
async def description(
|
|
37
|
+
self,
|
|
38
|
+
info: Info[Context, None],
|
|
39
|
+
) -> Optional[str]:
|
|
40
|
+
if self.db_record:
|
|
41
|
+
val = self.db_record.description
|
|
42
|
+
else:
|
|
43
|
+
val = await info.context.data_loaders.dataset_split_fields.load(
|
|
44
|
+
(self.id, models.DatasetSplit.description),
|
|
45
|
+
)
|
|
46
|
+
return val
|
|
47
|
+
|
|
48
|
+
@strawberry.field
|
|
49
|
+
async def metadata(
|
|
50
|
+
self,
|
|
51
|
+
info: Info[Context, None],
|
|
52
|
+
) -> JSON:
|
|
53
|
+
if self.db_record:
|
|
54
|
+
val = self.db_record.metadata_
|
|
55
|
+
else:
|
|
56
|
+
val = await info.context.data_loaders.dataset_split_fields.load(
|
|
57
|
+
(self.id, models.DatasetSplit.metadata_),
|
|
58
|
+
)
|
|
59
|
+
return val
|
|
60
|
+
|
|
61
|
+
@strawberry.field
|
|
62
|
+
async def color(
|
|
63
|
+
self,
|
|
64
|
+
info: Info[Context, None],
|
|
65
|
+
) -> str:
|
|
66
|
+
if self.db_record:
|
|
67
|
+
val = self.db_record.color
|
|
68
|
+
else:
|
|
69
|
+
val = await info.context.data_loaders.dataset_split_fields.load(
|
|
70
|
+
(self.id, models.DatasetSplit.color),
|
|
71
|
+
)
|
|
72
|
+
return val
|
|
73
|
+
|
|
74
|
+
@strawberry.field
|
|
75
|
+
async def created_at(
|
|
76
|
+
self,
|
|
77
|
+
info: Info[Context, None],
|
|
78
|
+
) -> datetime:
|
|
79
|
+
if self.db_record:
|
|
80
|
+
val = self.db_record.created_at
|
|
81
|
+
else:
|
|
82
|
+
val = await info.context.data_loaders.dataset_split_fields.load(
|
|
83
|
+
(self.id, models.DatasetSplit.created_at),
|
|
84
|
+
)
|
|
85
|
+
return val
|
|
86
|
+
|
|
87
|
+
@strawberry.field
|
|
88
|
+
async def updated_at(
|
|
89
|
+
self,
|
|
90
|
+
info: Info[Context, None],
|
|
91
|
+
) -> datetime:
|
|
92
|
+
if self.db_record:
|
|
93
|
+
val = self.db_record.updated_at
|
|
94
|
+
else:
|
|
95
|
+
val = await info.context.data_loaders.dataset_split_fields.load(
|
|
96
|
+
(self.id, models.DatasetSplit.updated_at),
|
|
97
|
+
)
|
|
98
|
+
return val
|
|
@@ -4,11 +4,56 @@ from typing import Optional
|
|
|
4
4
|
import strawberry
|
|
5
5
|
from strawberry.relay import Node, NodeID
|
|
6
6
|
from strawberry.scalars import JSON
|
|
7
|
+
from strawberry.types import Info
|
|
8
|
+
|
|
9
|
+
from phoenix.db import models
|
|
10
|
+
from phoenix.server.api.context import Context
|
|
7
11
|
|
|
8
12
|
|
|
9
13
|
@strawberry.type
|
|
10
14
|
class DatasetVersion(Node):
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
+
id: NodeID[int]
|
|
16
|
+
db_record: strawberry.Private[Optional[models.DatasetVersion]] = None
|
|
17
|
+
|
|
18
|
+
def __post_init__(self) -> None:
|
|
19
|
+
if self.db_record and self.id != self.db_record.id:
|
|
20
|
+
raise ValueError("DatasetVersion ID mismatch")
|
|
21
|
+
|
|
22
|
+
@strawberry.field
|
|
23
|
+
async def description(
|
|
24
|
+
self,
|
|
25
|
+
info: Info[Context, None],
|
|
26
|
+
) -> Optional[str]:
|
|
27
|
+
if self.db_record:
|
|
28
|
+
val = self.db_record.description
|
|
29
|
+
else:
|
|
30
|
+
val = await info.context.data_loaders.dataset_version_fields.load(
|
|
31
|
+
(self.id, models.DatasetVersion.description),
|
|
32
|
+
)
|
|
33
|
+
return val
|
|
34
|
+
|
|
35
|
+
@strawberry.field
|
|
36
|
+
async def metadata(
|
|
37
|
+
self,
|
|
38
|
+
info: Info[Context, None],
|
|
39
|
+
) -> JSON:
|
|
40
|
+
if self.db_record:
|
|
41
|
+
val = self.db_record.metadata_
|
|
42
|
+
else:
|
|
43
|
+
val = await info.context.data_loaders.dataset_version_fields.load(
|
|
44
|
+
(self.id, models.DatasetVersion.metadata_),
|
|
45
|
+
)
|
|
46
|
+
return val
|
|
47
|
+
|
|
48
|
+
@strawberry.field
|
|
49
|
+
async def created_at(
|
|
50
|
+
self,
|
|
51
|
+
info: Info[Context, None],
|
|
52
|
+
) -> datetime:
|
|
53
|
+
if self.db_record:
|
|
54
|
+
val = self.db_record.created_at
|
|
55
|
+
else:
|
|
56
|
+
val = await info.context.data_loaders.dataset_version_fields.load(
|
|
57
|
+
(self.id, models.DatasetVersion.created_at),
|
|
58
|
+
)
|
|
59
|
+
return val
|