arize-phoenix 12.7.1__py3-none-any.whl → 12.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of arize-phoenix might be problematic.

Files changed (76)
  1. {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/METADATA +3 -1
  2. {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/RECORD +76 -73
  3. phoenix/config.py +131 -9
  4. phoenix/db/engines.py +127 -14
  5. phoenix/db/iam_auth.py +64 -0
  6. phoenix/db/pg_config.py +10 -0
  7. phoenix/server/api/context.py +23 -0
  8. phoenix/server/api/dataloaders/__init__.py +6 -0
  9. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +0 -2
  10. phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
  11. phoenix/server/api/dataloaders/span_costs.py +3 -9
  12. phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
  13. phoenix/server/api/helpers/playground_clients.py +3 -3
  14. phoenix/server/api/input_types/PromptVersionInput.py +47 -1
  15. phoenix/server/api/mutations/annotation_config_mutations.py +2 -2
  16. phoenix/server/api/mutations/api_key_mutations.py +2 -15
  17. phoenix/server/api/mutations/chat_mutations.py +3 -2
  18. phoenix/server/api/mutations/dataset_label_mutations.py +109 -157
  19. phoenix/server/api/mutations/dataset_mutations.py +8 -8
  20. phoenix/server/api/mutations/dataset_split_mutations.py +13 -9
  21. phoenix/server/api/mutations/model_mutations.py +4 -4
  22. phoenix/server/api/mutations/project_session_annotations_mutations.py +4 -7
  23. phoenix/server/api/mutations/prompt_label_mutations.py +3 -3
  24. phoenix/server/api/mutations/prompt_mutations.py +24 -117
  25. phoenix/server/api/mutations/prompt_version_tag_mutations.py +8 -5
  26. phoenix/server/api/mutations/span_annotations_mutations.py +10 -5
  27. phoenix/server/api/mutations/trace_annotations_mutations.py +9 -4
  28. phoenix/server/api/mutations/user_mutations.py +4 -4
  29. phoenix/server/api/queries.py +80 -213
  30. phoenix/server/api/subscriptions.py +4 -4
  31. phoenix/server/api/types/Annotation.py +90 -23
  32. phoenix/server/api/types/ApiKey.py +13 -17
  33. phoenix/server/api/types/Dataset.py +88 -48
  34. phoenix/server/api/types/DatasetExample.py +34 -30
  35. phoenix/server/api/types/DatasetLabel.py +47 -13
  36. phoenix/server/api/types/DatasetSplit.py +87 -21
  37. phoenix/server/api/types/DatasetVersion.py +49 -4
  38. phoenix/server/api/types/DocumentAnnotation.py +182 -62
  39. phoenix/server/api/types/Experiment.py +146 -55
  40. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +10 -1
  41. phoenix/server/api/types/ExperimentRun.py +118 -61
  42. phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
  43. phoenix/server/api/types/GenerativeModel.py +95 -42
  44. phoenix/server/api/types/ModelInterface.py +7 -2
  45. phoenix/server/api/types/PlaygroundModel.py +12 -2
  46. phoenix/server/api/types/Project.py +70 -75
  47. phoenix/server/api/types/ProjectSession.py +69 -37
  48. phoenix/server/api/types/ProjectSessionAnnotation.py +166 -47
  49. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
  50. phoenix/server/api/types/Prompt.py +82 -44
  51. phoenix/server/api/types/PromptLabel.py +47 -13
  52. phoenix/server/api/types/PromptVersion.py +11 -8
  53. phoenix/server/api/types/PromptVersionTag.py +65 -25
  54. phoenix/server/api/types/Span.py +116 -115
  55. phoenix/server/api/types/SpanAnnotation.py +189 -42
  56. phoenix/server/api/types/SystemApiKey.py +65 -1
  57. phoenix/server/api/types/Trace.py +45 -44
  58. phoenix/server/api/types/TraceAnnotation.py +144 -48
  59. phoenix/server/api/types/User.py +103 -33
  60. phoenix/server/api/types/UserApiKey.py +73 -26
  61. phoenix/server/app.py +29 -0
  62. phoenix/server/cost_tracking/model_cost_manifest.json +2 -2
  63. phoenix/server/static/.vite/manifest.json +43 -43
  64. phoenix/server/static/assets/{components-BLK5vehh.js → components-v927s3NF.js} +471 -484
  65. phoenix/server/static/assets/{index-BP0Shd90.js → index-DrD9eSrN.js} +20 -16
  66. phoenix/server/static/assets/{pages-DIVgyYyy.js → pages-GVybXa_W.js} +754 -753
  67. phoenix/server/static/assets/{vendor-3BvTzoBp.js → vendor-D-csRHGZ.js} +1 -1
  68. phoenix/server/static/assets/{vendor-arizeai-C6_oC0y8.js → vendor-arizeai-BJLCG_Gc.js} +1 -1
  69. phoenix/server/static/assets/{vendor-codemirror-DPnZGAZA.js → vendor-codemirror-Cr963DyP.js} +3 -3
  70. phoenix/server/static/assets/{vendor-recharts-CjgSbsB0.js → vendor-recharts-DgmPLgIp.js} +1 -1
  71. phoenix/server/static/assets/{vendor-shiki-CJyhDG0E.js → vendor-shiki-wYOt1s7u.js} +1 -1
  72. phoenix/version.py +1 -1
  73. {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/WHEEL +0 -0
  74. {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/entry_points.txt +0 -0
  75. {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/licenses/IP_NOTICE +0 -0
  76. {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/types/ApiKey.py
@@ -3,25 +3,21 @@ from typing import Optional
 
 import strawberry
 
-from phoenix.db.models import ApiKey as ORMApiKey
-
 
 @strawberry.interface
 class ApiKey:
-    name: str = strawberry.field(description="Name of the API key.")
-    description: Optional[str] = strawberry.field(description="Description of the API key.")
-    created_at: datetime = strawberry.field(
-        description="The date and time the API key was created."
-    )
-    expires_at: Optional[datetime] = strawberry.field(
-        description="The date and time the API key will expire."
-    )
+    @strawberry.field(description="Name of the API key.")  # type: ignore
+    async def name(self) -> str:
+        raise NotImplementedError
+
+    @strawberry.field(description="Description of the API key.")  # type: ignore
+    async def description(self) -> Optional[str]:
+        raise NotImplementedError
 
+    @strawberry.field(description="The date and time the API key was created.")  # type: ignore
+    async def created_at(self) -> datetime:
+        raise NotImplementedError
 
-def to_gql_api_key(api_key: ORMApiKey) -> ApiKey:
-    return ApiKey(
-        name=api_key.name,
-        description=api_key.description,
-        created_at=api_key.created_at,
-        expires_at=api_key.expires_at,
-    )
+    @strawberry.field(description="The date and time the API key will expire.")  # type: ignore
+    async def expires_at(self) -> Optional[datetime]:
+        raise NotImplementedError
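
With 12.9.0, ApiKey stops being a concrete data class and becomes a pure interface: every field is an async resolver that raises NotImplementedError, and the concrete types (SystemApiKey and UserApiKey, both also rewritten in this release) are expected to override them. A minimal sketch of what a conforming implementation could look like — the db_record fallback is an assumption modeled on the other types in this diff, not a copy of the actual SystemApiKey.py:

from typing import Optional

import strawberry

from phoenix.db import models


@strawberry.type
class SketchApiKey(ApiKey):
    # Assumed: a preloaded ORM row, mirroring the db_record pattern that the
    # Dataset* types below adopt in this same release.
    db_record: strawberry.Private[models.ApiKey]

    @strawberry.field(description="Name of the API key.")  # type: ignore
    async def name(self) -> str:
        return self.db_record.name

    @strawberry.field(description="Description of the API key.")  # type: ignore
    async def description(self) -> Optional[str]:
        return self.db_record.description

created_at and expires_at would follow the same shape.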
phoenix/server/api/types/Dataset.py
@@ -1,6 +1,6 @@
 from collections.abc import AsyncIterable
 from datetime import datetime
-from typing import ClassVar, Optional, cast
+from typing import Optional, cast
 
 import strawberry
 from sqlalchemy import Text, and_, func, or_, select
@@ -18,8 +18,8 @@ from phoenix.server.api.types.DatasetExample import DatasetExample
 from phoenix.server.api.types.DatasetExperimentAnnotationSummary import (
     DatasetExperimentAnnotationSummary,
 )
-from phoenix.server.api.types.DatasetLabel import DatasetLabel, to_gql_dataset_label
-from phoenix.server.api.types.DatasetSplit import DatasetSplit, to_gql_dataset_split
+from phoenix.server.api.types.DatasetLabel import DatasetLabel
+from phoenix.server.api.types.DatasetSplit import DatasetSplit
 from phoenix.server.api.types.DatasetVersion import DatasetVersion
 from phoenix.server.api.types.Experiment import Experiment, to_gql_experiment
 from phoenix.server.api.types.node import from_global_id_with_expected_type
@@ -33,13 +33,77 @@ from phoenix.server.api.types.SortDir import SortDir
 
 @strawberry.type
 class Dataset(Node):
-    _table: ClassVar[type[models.Base]] = models.Experiment
-    id_attr: NodeID[int]
-    name: str
-    description: Optional[str]
-    metadata: JSON
-    created_at: datetime
-    updated_at: datetime
+    id: NodeID[int]
+    db_record: strawberry.Private[Optional[models.Dataset]] = None
+
+    def __post_init__(self) -> None:
+        if self.db_record and self.id != self.db_record.id:
+            raise ValueError("Dataset ID mismatch")
+
+    @strawberry.field
+    async def name(
+        self,
+        info: Info[Context, None],
+    ) -> str:
+        if self.db_record:
+            val = self.db_record.name
+        else:
+            val = await info.context.data_loaders.dataset_fields.load(
+                (self.id, models.Dataset.name),
+            )
+        return val
+
+    @strawberry.field
+    async def description(
+        self,
+        info: Info[Context, None],
+    ) -> Optional[str]:
+        if self.db_record:
+            val = self.db_record.description
+        else:
+            val = await info.context.data_loaders.dataset_fields.load(
+                (self.id, models.Dataset.description),
+            )
+        return val
+
+    @strawberry.field
+    async def metadata(
+        self,
+        info: Info[Context, None],
+    ) -> JSON:
+        if self.db_record:
+            val = self.db_record.metadata_
+        else:
+            val = await info.context.data_loaders.dataset_fields.load(
+                (self.id, models.Dataset.metadata_),
+            )
+        return val
+
+    @strawberry.field
+    async def created_at(
+        self,
+        info: Info[Context, None],
+    ) -> datetime:
+        if self.db_record:
+            val = self.db_record.created_at
+        else:
+            val = await info.context.data_loaders.dataset_fields.load(
+                (self.id, models.Dataset.created_at),
+            )
+        return val
+
+    @strawberry.field
+    async def updated_at(
+        self,
+        info: Info[Context, None],
+    ) -> datetime:
+        if self.db_record:
+            val = self.db_record.updated_at
+        else:
+            val = await info.context.data_loaders.dataset_fields.load(
+                (self.id, models.Dataset.updated_at),
+            )
+        return val
 
     @strawberry.field
     async def versions(
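
This is the recurring pattern of the 12.9.0 release: a GraphQL node now carries only its primary key plus an optional preloaded ORM row (db_record), and each field resolver falls back to a batched dataloader keyed by (row id, ORM column) when no row was preloaded. List resolvers can hand out lightweight nodes without eagerly copying every column, while N field reads still collapse into one query per batch. A minimal sketch of what such a loader could look like — the real loaders live in phoenix/server/api/dataloaders (note the +6 in dataloaders/__init__.py above) and may differ in detail:

from typing import Any

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from strawberry.dataloader import DataLoader

from phoenix.db import models


def make_dataset_fields_loader(
    db: async_sessionmaker[AsyncSession],
) -> DataLoader[tuple[int, Any], Any]:
    async def load(keys: list[tuple[int, Any]]) -> list[Any]:
        # One SELECT for the whole batch, regardless of how many nodes and
        # fields the GraphQL query touched.
        ids = {row_id for row_id, _ in keys}
        async with db() as session:
            rows = {
                row.id: row
                for row in await session.scalars(
                    select(models.Dataset).where(models.Dataset.id.in_(ids))
                )
            }
        # Answer each (row id, column) key with that row's column value.
        return [getattr(rows[row_id], column.key) for row_id, column in keys]

    return DataLoader(load_fn=load)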
@@ -58,7 +122,7 @@ class Dataset(Node):
             before=before if isinstance(before, CursorString) else None,
         )
         async with info.context.db() as session:
-            stmt = select(models.DatasetVersion).filter_by(dataset_id=self.id_attr)
+            stmt = select(models.DatasetVersion).filter_by(dataset_id=self.id)
             if sort:
                 # For now assume the the column names match 1:1 with the enum values
                 sort_col = getattr(models.DatasetVersion, sort.col.value)
@@ -69,15 +133,7 @@ class Dataset(Node):
             else:
                 stmt = stmt.order_by(models.DatasetVersion.created_at.desc())
             versions = await session.scalars(stmt)
-            data = [
-                DatasetVersion(
-                    id_attr=version.id,
-                    description=version.description,
-                    metadata=version.metadata_,
-                    created_at=version.created_at,
-                )
-                for version in versions
-            ]
+            data = [DatasetVersion(id=version.id, db_record=version) for version in versions]
         return connection_from_list(data=data, args=args)
 
     @strawberry.field(
@@ -90,7 +146,7 @@ class Dataset(Node):
         dataset_version_id: Optional[GlobalID] = UNSET,
         split_ids: Optional[list[GlobalID]] = UNSET,
     ) -> int:
-        dataset_id = self.id_attr
+        dataset_id = self.id
         version_id = (
             from_global_id_with_expected_type(
                 global_id=dataset_version_id,
@@ -180,7 +236,7 @@ class Dataset(Node):
             last=last,
             before=before if isinstance(before, CursorString) else None,
         )
-        dataset_id = self.id_attr
+        dataset_id = self.id
         version_id = (
             from_global_id_with_expected_type(
                 global_id=dataset_version_id, expected_type_name=DatasetVersion.__name__
@@ -261,9 +317,9 @@ class Dataset(Node):
         async with info.context.db() as session:
             dataset_examples = [
                 DatasetExample(
-                    id_attr=example.id,
+                    id=example.id,
+                    db_record=example,
                     version_id=version_id,
-                    created_at=example.created_at,
                 )
                 async for example in await session.stream_scalars(query)
             ]
@@ -272,8 +328,8 @@ class Dataset(Node):
     @strawberry.field
     async def splits(self, info: Info[Context, None]) -> list[DatasetSplit]:
         return [
-            to_gql_dataset_split(split)
-            for split in await info.context.data_loaders.dataset_dataset_splits.load(self.id_attr)
+            DatasetSplit(id=split.id, db_record=split)
+            for split in await info.context.data_loaders.dataset_dataset_splits.load(self.id)
         ]
 
     @strawberry.field(
@@ -285,9 +341,7 @@ class Dataset(Node):
         info: Info[Context, None],
         dataset_version_id: Optional[GlobalID] = UNSET,
     ) -> int:
-        stmt = select(count(models.Experiment.id)).where(
-            models.Experiment.dataset_id == self.id_attr
-        )
+        stmt = select(count(models.Experiment.id)).where(models.Experiment.dataset_id == self.id)
         version_id = (
             from_global_id_with_expected_type(
                 global_id=dataset_version_id,
@@ -320,7 +374,7 @@ class Dataset(Node):
             last=last,
             before=before if isinstance(before, CursorString) else None,
         )
-        dataset_id = self.id_attr
+        dataset_id = self.id
         row_number = func.row_number().over(order_by=models.Experiment.id).label("row_number")
         query = (
             select(models.Experiment, row_number)
@@ -363,7 +417,7 @@ class Dataset(Node):
     async def experiment_annotation_summaries(
         self, info: Info[Context, None]
     ) -> list[DatasetExperimentAnnotationSummary]:
-        dataset_id = self.id_attr
+        dataset_id = self.id
         query = (
             select(
                 models.ExperimentRunAnnotation.name.label("annotation_name"),
@@ -396,24 +450,10 @@ class Dataset(Node):
     @strawberry.field
     async def labels(self, info: Info[Context, None]) -> list[DatasetLabel]:
         return [
-            to_gql_dataset_label(label)
-            for label in await info.context.data_loaders.dataset_labels.load(self.id_attr)
+            DatasetLabel(id=label.id, db_record=label)
+            for label in await info.context.data_loaders.dataset_labels.load(self.id)
         ]
 
     @strawberry.field
     def last_updated_at(self, info: Info[Context, None]) -> Optional[datetime]:
-        return info.context.last_updated_at.get(self._table, self.id_attr)
-
-
-def to_gql_dataset(dataset: models.Dataset) -> Dataset:
-    """
-    Converts an ORM dataset to a GraphQL dataset.
-    """
-    return Dataset(
-        id_attr=dataset.id,
-        name=dataset.name,
-        description=dataset.description,
-        metadata=dataset.metadata_,
-        created_at=dataset.created_at,
-        updated_at=dataset.updated_at,
-    )
+        return info.context.last_updated_at.get(models.Dataset, self.id)
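
The to_gql_dataset helper is gone along with the change above; call sites now construct the node directly from whatever row they already hold, or from just a row id. Illustrative only, not a specific call site from this release:

from phoenix.db import models


def to_node(orm_dataset: models.Dataset) -> Dataset:
    # With a row in hand, fields resolve locally from db_record.
    return Dataset(id=orm_dataset.id, db_record=orm_dataset)


def to_thin_node(dataset_rowid: int) -> Dataset:
    # With only an id, every requested field goes through the batched
    # dataset_fields dataloader instead.
    return Dataset(id=dataset_rowid)

The __post_init__ guard rejects any db_record whose primary key disagrees with the declared id.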
phoenix/server/api/types/DatasetExample.py
@@ -1,9 +1,8 @@
 from datetime import datetime
-from typing import Optional
+from typing import TYPE_CHECKING, Annotated, Optional
 
 import strawberry
 from sqlalchemy import select
-from sqlalchemy.orm import joinedload
 from strawberry import UNSET
 from strawberry.relay.types import Connection, GlobalID, Node, NodeID
 from strawberry.types import Info
@@ -12,34 +11,49 @@ from phoenix.db import models
 from phoenix.server.api.context import Context
 from phoenix.server.api.exceptions import BadRequest
 from phoenix.server.api.types.DatasetExampleRevision import DatasetExampleRevision
-from phoenix.server.api.types.DatasetSplit import DatasetSplit, to_gql_dataset_split
+from phoenix.server.api.types.DatasetSplit import DatasetSplit
 from phoenix.server.api.types.DatasetVersion import DatasetVersion
 from phoenix.server.api.types.ExperimentRepeatedRunGroup import (
     ExperimentRepeatedRunGroup,
 )
-from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
+from phoenix.server.api.types.ExperimentRun import ExperimentRun
 from phoenix.server.api.types.node import from_global_id_with_expected_type
 from phoenix.server.api.types.pagination import (
     ConnectionArgs,
     CursorString,
     connection_from_list,
 )
-from phoenix.server.api.types.Span import Span
+
+if TYPE_CHECKING:
+    from .Span import Span
 
 
 @strawberry.type
 class DatasetExample(Node):
-    id_attr: NodeID[int]
-    created_at: datetime
+    id: NodeID[int]
+    db_record: strawberry.Private[Optional[models.DatasetExample]] = None
     version_id: strawberry.Private[Optional[int]] = None
 
+    def __post_init__(self) -> None:
+        if self.db_record and self.id != self.db_record.id:
+            raise ValueError("DatasetExample ID mismatch")
+
+    @strawberry.field
+    async def created_at(self, info: Info[Context, None]) -> datetime:
+        if self.db_record:
+            val = self.db_record.created_at
+        else:
+            val = await info.context.data_loaders.dataset_example_fields.load(
+                (self.id, models.DatasetExample.created_at),
+            )
+        return val
+
     @strawberry.field
     async def revision(
         self,
         info: Info[Context, None],
         dataset_version_id: Optional[GlobalID] = UNSET,
     ) -> DatasetExampleRevision:
-        example_id = self.id_attr
         version_id: Optional[int] = None
         if dataset_version_id:
             version_id = from_global_id_with_expected_type(
@@ -47,18 +61,18 @@ class DatasetExample(Node):
             )
         elif self.version_id is not None:
             version_id = self.version_id
-        return await info.context.data_loaders.dataset_example_revisions.load(
-            (example_id, version_id)
-        )
+        return await info.context.data_loaders.dataset_example_revisions.load((self.id, version_id))
 
     @strawberry.field
     async def span(
         self,
         info: Info[Context, None],
-    ) -> Optional[Span]:
+    ) -> Optional[Annotated["Span", strawberry.lazy(".Span")]]:
+        from .Span import Span
+
         return (
-            Span(span_rowid=span.id, db_span=span)
-            if (span := await info.context.data_loaders.dataset_example_spans.load(self.id_attr))
+            Span(id=span.id, db_record=span)
+            if (span := await info.context.data_loaders.dataset_example_spans.load(self.id))
             else None
         )
 
@@ -78,12 +92,10 @@ class DatasetExample(Node):
             last=last,
             before=before if isinstance(before, CursorString) else None,
         )
-        example_id = self.id_attr
         query = (
             select(models.ExperimentRun)
-            .options(joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id))
             .join(models.Experiment, models.Experiment.id == models.ExperimentRun.experiment_id)
-            .where(models.ExperimentRun.dataset_example_id == example_id)
+            .where(models.ExperimentRun.dataset_example_id == self.id)
             .order_by(
                 models.ExperimentRun.experiment_id.asc(),
                 models.ExperimentRun.repetition_number.asc(),
@@ -100,7 +112,7 @@ class DatasetExample(Node):
             query = query.where(models.ExperimentRun.experiment_id.in_(experiment_db_ids))
         async with info.context.db() as session:
             runs = (await session.scalars(query)).all()
-        return connection_from_list([to_gql_experiment_run(run) for run in runs], args)
+        return connection_from_list([ExperimentRun(id=run.id, db_record=run) for run in runs], args)
 
     @strawberry.field
     async def experiment_repeated_run_groups(
@@ -108,7 +120,6 @@ class DatasetExample(Node):
         info: Info[Context, None],
         experiment_ids: list[GlobalID],
     ) -> list[ExperimentRepeatedRunGroup]:
-        example_rowid = self.id_attr
         experiment_rowids = []
         for experiment_id in experiment_ids:
             try:
@@ -121,14 +132,14 @@ class DatasetExample(Node):
                 experiment_rowids.append(experiment_rowid)
         repeated_run_groups = (
             await info.context.data_loaders.experiment_repeated_run_groups.load_many(
-                [(experiment_rowid, example_rowid) for experiment_rowid in experiment_rowids]
+                [(experiment_rowid, self.id) for experiment_rowid in experiment_rowids]
             )
         )
         return [
             ExperimentRepeatedRunGroup(
                 experiment_rowid=group.experiment_rowid,
                 dataset_example_rowid=group.dataset_example_rowid,
-                runs=[to_gql_experiment_run(run) for run in group.runs],
+                cached_runs=[ExperimentRun(id=run.id, db_record=run) for run in group.runs],
             )
             for group in repeated_run_groups
         ]
@@ -139,13 +150,6 @@ class DatasetExample(Node):
         info: Info[Context, None],
     ) -> list[DatasetSplit]:
         return [
-            to_gql_dataset_split(split)
-            for split in await info.context.data_loaders.dataset_example_splits.load(self.id_attr)
+            DatasetSplit(id=split.id, db_record=split)
+            for split in await info.context.data_loaders.dataset_example_splits.load(self.id)
         ]
-
-
-def to_gql_dataset_example(example: models.DatasetExample) -> DatasetExample:
-    return DatasetExample(
-        id_attr=example.id,
-        created_at=example.created_at,
-    )
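
The span change is Strawberry's standard recipe for breaking a circular import between GraphQL type modules: the module-level import survives only under TYPE_CHECKING, the return annotation becomes Annotated["Span", strawberry.lazy(".Span")] so the schema resolves the type at build time, and the resolver re-imports the class when it actually runs. Condensed to its general shape (a restatement of the diff above, with a stand-in for the dataloader call):

from typing import TYPE_CHECKING, Annotated, Optional

import strawberry

if TYPE_CHECKING:
    # Seen by type checkers only; never executed, so no cycle at runtime.
    from .Span import Span


@strawberry.type
class Example:
    @strawberry.field
    async def span(self) -> Optional[Annotated["Span", strawberry.lazy(".Span")]]:
        from .Span import Span  # deferred to call time

        record = await load_span_record()  # stand-in, not a Phoenix API
        return Span(id=record.id, db_record=record) if record else None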
phoenix/server/api/types/DatasetLabel.py
@@ -2,22 +2,56 @@ from typing import Optional
 
 import strawberry
 from strawberry.relay import Node, NodeID
+from strawberry.types import Info
 
 from phoenix.db import models
+from phoenix.server.api.context import Context
 
 
 @strawberry.type
 class DatasetLabel(Node):
-    id_attr: NodeID[int]
-    name: str
-    description: Optional[str]
-    color: str
-
-
-def to_gql_dataset_label(dataset_label: models.DatasetLabel) -> DatasetLabel:
-    return DatasetLabel(
-        id_attr=dataset_label.id,
-        name=dataset_label.name,
-        description=dataset_label.description,
-        color=dataset_label.color,
-    )
+    id: NodeID[int]
+    db_record: strawberry.Private[Optional[models.DatasetLabel]] = None
+
+    def __post_init__(self) -> None:
+        if self.db_record and self.id != self.db_record.id:
+            raise ValueError("DatasetLabel ID mismatch")
+
+    @strawberry.field
+    async def name(
+        self,
+        info: Info[Context, None],
+    ) -> str:
+        if self.db_record:
+            val = self.db_record.name
+        else:
+            val = await info.context.data_loaders.dataset_label_fields.load(
+                (self.id, models.DatasetLabel.name),
+            )
+        return val
+
+    @strawberry.field
+    async def description(
+        self,
+        info: Info[Context, None],
+    ) -> Optional[str]:
+        if self.db_record:
+            val = self.db_record.description
+        else:
+            val = await info.context.data_loaders.dataset_label_fields.load(
+                (self.id, models.DatasetLabel.description),
+            )
+        return val
+
+    @strawberry.field
+    async def color(
+        self,
+        info: Info[Context, None],
+    ) -> str:
+        if self.db_record:
+            val = self.db_record.color
+        else:
+            val = await info.context.data_loaders.dataset_label_fields.load(
+                (self.id, models.DatasetLabel.color),
+            )
+        return val
phoenix/server/api/types/DatasetSplit.py
@@ -1,32 +1,98 @@
 from datetime import datetime
-from typing import ClassVar, Optional
+from typing import Optional
 
 import strawberry
 from strawberry.relay import Node, NodeID
 from strawberry.scalars import JSON
+from strawberry.types import Info
 
 from phoenix.db import models
+from phoenix.server.api.context import Context
 
 
 @strawberry.type
 class DatasetSplit(Node):
-    _table: ClassVar[type[models.Base]] = models.DatasetSplit
-    id_attr: NodeID[int]
-    name: str
-    description: Optional[str]
-    metadata: JSON
-    color: str
-    created_at: datetime
-    updated_at: datetime
-
-
-def to_gql_dataset_split(dataset_split: models.DatasetSplit) -> DatasetSplit:
-    return DatasetSplit(
-        id_attr=dataset_split.id,
-        name=dataset_split.name,
-        description=dataset_split.description,
-        color=dataset_split.color or "#ffffff",
-        metadata=dataset_split.metadata_,
-        created_at=dataset_split.created_at,
-        updated_at=dataset_split.updated_at,
-    )
+    id: NodeID[int]
+    db_record: strawberry.Private[Optional[models.DatasetSplit]] = None
+
+    def __post_init__(self) -> None:
+        if self.db_record and self.id != self.db_record.id:
+            raise ValueError("DatasetSplit ID mismatch")
+
+    @strawberry.field
+    async def name(
+        self,
+        info: Info[Context, None],
+    ) -> str:
+        if self.db_record:
+            val = self.db_record.name
+        else:
+            val = await info.context.data_loaders.dataset_split_fields.load(
+                (self.id, models.DatasetSplit.name),
+            )
+        return val
+
+    @strawberry.field
+    async def description(
+        self,
+        info: Info[Context, None],
+    ) -> Optional[str]:
+        if self.db_record:
+            val = self.db_record.description
+        else:
+            val = await info.context.data_loaders.dataset_split_fields.load(
+                (self.id, models.DatasetSplit.description),
+            )
+        return val
+
+    @strawberry.field
+    async def metadata(
+        self,
+        info: Info[Context, None],
+    ) -> JSON:
+        if self.db_record:
+            val = self.db_record.metadata_
+        else:
+            val = await info.context.data_loaders.dataset_split_fields.load(
+                (self.id, models.DatasetSplit.metadata_),
+            )
+        return val
+
+    @strawberry.field
+    async def color(
+        self,
+        info: Info[Context, None],
+    ) -> str:
+        if self.db_record:
+            val = self.db_record.color
+        else:
+            val = await info.context.data_loaders.dataset_split_fields.load(
+                (self.id, models.DatasetSplit.color),
+            )
+        return val
+
+    @strawberry.field
+    async def created_at(
+        self,
+        info: Info[Context, None],
+    ) -> datetime:
+        if self.db_record:
+            val = self.db_record.created_at
+        else:
+            val = await info.context.data_loaders.dataset_split_fields.load(
+                (self.id, models.DatasetSplit.created_at),
+            )
+        return val
+
+    @strawberry.field
+    async def updated_at(
+        self,
+        info: Info[Context, None],
+    ) -> datetime:
+        if self.db_record:
+            val = self.db_record.updated_at
+        else:
+            val = await info.context.data_loaders.dataset_split_fields.load(
+                (self.id, models.DatasetSplit.updated_at),
+            )
+        return val
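
One behavioral nuance in the DatasetSplit rewrite: the removed to_gql_dataset_split helper coerced a falsy color to "#ffffff", whereas the new color resolver returns the stored value as-is. If NULL or empty colors can exist in older rows, a consumer that relied on the old default would need to reinstate it itself, e.g. (hypothetical consumer-side code, not part of this release):

from typing import Optional


def display_color(color: Optional[str]) -> str:
    # Reapply the fallback that to_gql_dataset_split used to provide.
    return color or "#ffffff"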
phoenix/server/api/types/DatasetVersion.py
@@ -4,11 +4,56 @@ from typing import Optional
 import strawberry
 from strawberry.relay import Node, NodeID
 from strawberry.scalars import JSON
+from strawberry.types import Info
+
+from phoenix.db import models
+from phoenix.server.api.context import Context
 
 
 @strawberry.type
 class DatasetVersion(Node):
-    id_attr: NodeID[int]
-    description: Optional[str]
-    metadata: JSON
-    created_at: datetime
+    id: NodeID[int]
+    db_record: strawberry.Private[Optional[models.DatasetVersion]] = None
+
+    def __post_init__(self) -> None:
+        if self.db_record and self.id != self.db_record.id:
+            raise ValueError("DatasetVersion ID mismatch")
+
+    @strawberry.field
+    async def description(
+        self,
+        info: Info[Context, None],
+    ) -> Optional[str]:
+        if self.db_record:
+            val = self.db_record.description
+        else:
+            val = await info.context.data_loaders.dataset_version_fields.load(
+                (self.id, models.DatasetVersion.description),
+            )
+        return val
+
+    @strawberry.field
+    async def metadata(
+        self,
+        info: Info[Context, None],
+    ) -> JSON:
+        if self.db_record:
+            val = self.db_record.metadata_
+        else:
+            val = await info.context.data_loaders.dataset_version_fields.load(
+                (self.id, models.DatasetVersion.metadata_),
+            )
+        return val
+
+    @strawberry.field
+    async def created_at(
+        self,
+        info: Info[Context, None],
+    ) -> datetime:
+        if self.db_record:
+            val = self.db_record.created_at
+        else:
+            val = await info.context.data_loaders.dataset_version_fields.load(
+                (self.id, models.DatasetVersion.created_at),
+            )
+        return val