arize-phoenix 12.8.0__py3-none-any.whl → 12.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (70) hide show
  1. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/METADATA +3 -1
  2. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/RECORD +70 -67
  3. phoenix/config.py +131 -9
  4. phoenix/db/engines.py +127 -14
  5. phoenix/db/iam_auth.py +64 -0
  6. phoenix/db/pg_config.py +10 -0
  7. phoenix/server/api/context.py +23 -0
  8. phoenix/server/api/dataloaders/__init__.py +6 -0
  9. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +0 -2
  10. phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
  11. phoenix/server/api/dataloaders/span_costs.py +3 -9
  12. phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
  13. phoenix/server/api/helpers/playground_clients.py +3 -3
  14. phoenix/server/api/input_types/PromptVersionInput.py +47 -1
  15. phoenix/server/api/mutations/annotation_config_mutations.py +2 -2
  16. phoenix/server/api/mutations/api_key_mutations.py +2 -15
  17. phoenix/server/api/mutations/chat_mutations.py +3 -2
  18. phoenix/server/api/mutations/dataset_label_mutations.py +12 -6
  19. phoenix/server/api/mutations/dataset_mutations.py +8 -8
  20. phoenix/server/api/mutations/dataset_split_mutations.py +13 -9
  21. phoenix/server/api/mutations/model_mutations.py +4 -4
  22. phoenix/server/api/mutations/project_session_annotations_mutations.py +4 -7
  23. phoenix/server/api/mutations/prompt_label_mutations.py +3 -3
  24. phoenix/server/api/mutations/prompt_mutations.py +24 -117
  25. phoenix/server/api/mutations/prompt_version_tag_mutations.py +8 -5
  26. phoenix/server/api/mutations/span_annotations_mutations.py +10 -5
  27. phoenix/server/api/mutations/trace_annotations_mutations.py +9 -4
  28. phoenix/server/api/mutations/user_mutations.py +4 -4
  29. phoenix/server/api/queries.py +65 -210
  30. phoenix/server/api/subscriptions.py +4 -4
  31. phoenix/server/api/types/Annotation.py +90 -23
  32. phoenix/server/api/types/ApiKey.py +13 -17
  33. phoenix/server/api/types/Dataset.py +88 -48
  34. phoenix/server/api/types/DatasetExample.py +34 -30
  35. phoenix/server/api/types/DatasetLabel.py +47 -13
  36. phoenix/server/api/types/DatasetSplit.py +87 -21
  37. phoenix/server/api/types/DatasetVersion.py +49 -4
  38. phoenix/server/api/types/DocumentAnnotation.py +182 -62
  39. phoenix/server/api/types/Experiment.py +146 -55
  40. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +10 -1
  41. phoenix/server/api/types/ExperimentRun.py +118 -61
  42. phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
  43. phoenix/server/api/types/GenerativeModel.py +95 -42
  44. phoenix/server/api/types/ModelInterface.py +7 -2
  45. phoenix/server/api/types/PlaygroundModel.py +12 -2
  46. phoenix/server/api/types/Project.py +70 -75
  47. phoenix/server/api/types/ProjectSession.py +69 -37
  48. phoenix/server/api/types/ProjectSessionAnnotation.py +166 -47
  49. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
  50. phoenix/server/api/types/Prompt.py +82 -44
  51. phoenix/server/api/types/PromptLabel.py +47 -13
  52. phoenix/server/api/types/PromptVersion.py +11 -8
  53. phoenix/server/api/types/PromptVersionTag.py +65 -25
  54. phoenix/server/api/types/Span.py +116 -115
  55. phoenix/server/api/types/SpanAnnotation.py +189 -42
  56. phoenix/server/api/types/SystemApiKey.py +65 -1
  57. phoenix/server/api/types/Trace.py +45 -44
  58. phoenix/server/api/types/TraceAnnotation.py +144 -48
  59. phoenix/server/api/types/User.py +103 -33
  60. phoenix/server/api/types/UserApiKey.py +73 -26
  61. phoenix/server/app.py +29 -0
  62. phoenix/server/static/.vite/manifest.json +9 -9
  63. phoenix/server/static/assets/{components-Bem6_7MW.js → components-v927s3NF.js} +427 -397
  64. phoenix/server/static/assets/{index-NdiXbuNL.js → index-DrD9eSrN.js} +9 -5
  65. phoenix/server/static/assets/{pages-CEJgMVKU.js → pages-GVybXa_W.js} +489 -486
  66. phoenix/version.py +1 -1
  67. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/WHEEL +0 -0
  68. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/entry_points.txt +0 -0
  69. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/licenses/IP_NOTICE +0 -0
  70. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/licenses/LICENSE +0 -0
@@ -12,10 +12,7 @@ from strawberry.types import Info
12
12
  from phoenix.db import models
13
13
  from phoenix.server.api.context import Context
14
14
  from phoenix.server.api.types.CostBreakdown import CostBreakdown
15
- from phoenix.server.api.types.ExperimentRunAnnotation import (
16
- ExperimentRunAnnotation,
17
- to_gql_experiment_run_annotation,
18
- )
15
+ from phoenix.server.api.types.ExperimentRunAnnotation import ExperimentRunAnnotation
19
16
  from phoenix.server.api.types.pagination import (
20
17
  ConnectionArgs,
21
18
  CursorString,
@@ -26,23 +23,100 @@ from phoenix.server.api.types.SpanCostSummary import SpanCostSummary
26
23
  from phoenix.server.api.types.Trace import Trace
27
24
 
28
25
  if TYPE_CHECKING:
29
- from phoenix.server.api.types.DatasetExample import DatasetExample
26
+ from .DatasetExample import DatasetExample
27
+ from .Trace import Trace
30
28
 
31
29
 
32
30
  @strawberry.type
33
31
  class ExperimentRun(Node):
34
- id_attr: NodeID[int]
35
- experiment_id: GlobalID
36
- repetition_number: int
37
- trace_id: Optional[str]
38
- output: Optional[JSON]
39
- start_time: datetime
40
- end_time: datetime
41
- error: Optional[str]
32
+ id: NodeID[int]
33
+ db_record: strawberry.Private[Optional[models.ExperimentRun]] = None
34
+
35
+ def __post_init__(self) -> None:
36
+ if self.db_record and self.id != self.db_record.id:
37
+ raise ValueError("ExperimentRun ID mismatch")
38
+
39
+ @strawberry.field
40
+ async def experiment_id(self, info: Info[Context, None]) -> GlobalID:
41
+ from .Experiment import Experiment
42
+
43
+ if self.db_record:
44
+ experiment_id = self.db_record.experiment_id
45
+ else:
46
+ experiment_id = await info.context.data_loaders.experiment_run_fields.load(
47
+ (self.id, models.ExperimentRun.experiment_id),
48
+ )
49
+ return GlobalID(Experiment.__name__, str(experiment_id))
50
+
51
+ @strawberry.field
52
+ async def repetition_number(self, info: Info[Context, None]) -> int:
53
+ if self.db_record:
54
+ val = self.db_record.repetition_number
55
+ else:
56
+ val = await info.context.data_loaders.experiment_run_fields.load(
57
+ (self.id, models.ExperimentRun.repetition_number),
58
+ )
59
+ return val
60
+
61
+ @strawberry.field
62
+ async def trace_id(self, info: Info[Context, None]) -> Optional[str]:
63
+ if self.db_record:
64
+ val = self.db_record.trace_id
65
+ else:
66
+ val = await info.context.data_loaders.experiment_run_fields.load(
67
+ (self.id, models.ExperimentRun.trace_id),
68
+ )
69
+ return val
70
+
71
+ @strawberry.field
72
+ async def output(self, info: Info[Context, None]) -> Optional[JSON]:
73
+ if self.db_record:
74
+ output_dict = self.db_record.output
75
+ else:
76
+ output_dict = await info.context.data_loaders.experiment_run_fields.load(
77
+ (self.id, models.ExperimentRun.output),
78
+ )
79
+ return output_dict.get("task_output") if output_dict else None
42
80
 
43
81
  @strawberry.field
44
- def latency_ms(self) -> float:
45
- return (self.end_time - self.start_time).total_seconds() * 1000
82
+ async def start_time(self, info: Info[Context, None]) -> datetime:
83
+ if self.db_record:
84
+ val = self.db_record.start_time
85
+ else:
86
+ val = await info.context.data_loaders.experiment_run_fields.load(
87
+ (self.id, models.ExperimentRun.start_time),
88
+ )
89
+ return val
90
+
91
+ @strawberry.field
92
+ async def end_time(self, info: Info[Context, None]) -> datetime:
93
+ if self.db_record:
94
+ val = self.db_record.end_time
95
+ else:
96
+ val = await info.context.data_loaders.experiment_run_fields.load(
97
+ (self.id, models.ExperimentRun.end_time),
98
+ )
99
+ return val
100
+
101
+ @strawberry.field
102
+ async def error(self, info: Info[Context, None]) -> Optional[str]:
103
+ if self.db_record:
104
+ val = self.db_record.error
105
+ else:
106
+ val = await info.context.data_loaders.experiment_run_fields.load(
107
+ (self.id, models.ExperimentRun.error),
108
+ )
109
+ return val
110
+
111
+ @strawberry.field
112
+ async def latency_ms(self, info: Info[Context, None]) -> float:
113
+ if self.db_record:
114
+ val = self.db_record.latency_ms
115
+ else:
116
+ val = await info.context.data_loaders.experiment_run_fields.load(
117
+ (self.id, models.ExperimentRun.latency_ms),
118
+ )
119
+ return val
46
120
 
47
121
  @strawberry.field
48
122
  async def annotations(
@@ -59,45 +133,49 @@ class ExperimentRun(Node):
59
133
  last=last,
60
134
  before=before if isinstance(before, CursorString) else None,
61
135
  )
62
- run_id = self.id_attr
63
- annotations = await info.context.data_loaders.experiment_run_annotations.load(run_id)
136
+ annotations = await info.context.data_loaders.experiment_run_annotations.load(self.id)
64
137
  return connection_from_list(
65
- [to_gql_experiment_run_annotation(annotation) for annotation in annotations], args
138
+ [
139
+ ExperimentRunAnnotation(id=annotation.id, db_record=annotation)
140
+ for annotation in annotations
141
+ ],
142
+ args,
66
143
  )
67
144
 
68
145
  @strawberry.field
69
- async def trace(self, info: Info) -> Optional[Trace]:
70
- if not self.trace_id:
146
+ async def trace(
147
+ self, info: Info[Context, None]
148
+ ) -> Optional[Annotated["Trace", strawberry.lazy(".Trace")]]:
149
+ if self.db_record:
150
+ trace_id = self.db_record.trace_id
151
+ else:
152
+ trace_id = await info.context.data_loaders.experiment_run_fields.load(
153
+ (self.id, models.ExperimentRun.trace_id),
154
+ )
155
+ if not trace_id:
71
156
  return None
72
- dataloader = info.context.data_loaders.trace_by_trace_ids
73
- if (trace := await dataloader.load(self.trace_id)) is None:
157
+ loader = info.context.data_loaders.trace_by_trace_ids
158
+ if (trace := await loader.load(trace_id)) is None:
74
159
  return None
75
- return Trace(trace_rowid=trace.id, db_trace=trace)
160
+ from .Trace import Trace
161
+
162
+ return Trace(id=trace.id, db_record=trace)
76
163
 
77
164
  @strawberry.field
78
165
  async def example(
79
- self, info: Info
166
+ self, info: Info[Context, None]
80
167
  ) -> Annotated[
81
- "DatasetExample", strawberry.lazy("phoenix.server.api.types.DatasetExample")
168
+ "DatasetExample", strawberry.lazy(".DatasetExample")
82
169
  ]: # use lazy types to avoid circular import: https://strawberry.rocks/docs/types/lazy
83
- from phoenix.server.api.types.DatasetExample import DatasetExample
170
+ from .DatasetExample import DatasetExample
84
171
 
85
- (
86
- example,
87
- version_id,
88
- ) = await info.context.data_loaders.dataset_examples_and_versions_by_experiment_run.load(
89
- self.id_attr
90
- )
91
- return DatasetExample(
92
- id_attr=example.id,
93
- created_at=example.created_at,
94
- version_id=version_id,
95
- )
172
+ loader = info.context.data_loaders.dataset_examples_and_versions_by_experiment_run
173
+ (example, version_id) = await loader.load(self.id)
174
+ return DatasetExample(id=example.id, db_record=example, version_id=version_id)
96
175
 
97
176
  @strawberry.field
98
177
  async def cost_summary(self, info: Info[Context, None]) -> SpanCostSummary:
99
- run_id = self.id_attr
100
- summary = await info.context.data_loaders.span_cost_summary_by_experiment_run.load(run_id)
178
+ summary = await info.context.data_loaders.span_cost_summary_by_experiment_run.load(self.id)
101
179
  return SpanCostSummary(
102
180
  prompt=CostBreakdown(
103
181
  tokens=summary.prompt.tokens,
@@ -117,8 +195,6 @@ class ExperimentRun(Node):
117
195
  async def cost_detail_summary_entries(
118
196
  self, info: Info[Context, None]
119
197
  ) -> list[SpanCostDetailSummaryEntry]:
120
- run_id = self.id_attr
121
-
122
198
  stmt = (
123
199
  select(
124
200
  models.SpanCostDetail.token_type,
@@ -131,7 +207,7 @@ class ExperimentRun(Node):
131
207
  .join(models.Span, models.SpanCost.span_rowid == models.Span.id)
132
208
  .join(models.Trace, models.Span.trace_rowid == models.Trace.id)
133
209
  .join(models.ExperimentRun, models.ExperimentRun.trace_id == models.Trace.trace_id)
134
- .where(models.ExperimentRun.id == run_id)
210
+ .where(models.ExperimentRun.id == self.id)
135
211
  .group_by(models.SpanCostDetail.token_type, models.SpanCostDetail.is_prompt)
136
212
  )
137
213
 
@@ -145,22 +221,3 @@ class ExperimentRun(Node):
145
221
  )
146
222
  async for token_type, is_prompt, cost, tokens in data
147
223
  ]
148
-
149
-
150
- def to_gql_experiment_run(run: models.ExperimentRun) -> ExperimentRun:
151
- """
152
- Converts an ORM experiment run to a GraphQL ExperimentRun.
153
- """
154
-
155
- from phoenix.server.api.types.Experiment import Experiment
156
-
157
- return ExperimentRun(
158
- id_attr=run.id,
159
- experiment_id=GlobalID(Experiment.__name__, str(run.experiment_id)),
160
- repetition_number=run.repetition_number,
161
- trace_id=run.trace.trace_id if run.trace else None,
162
- output=run.output.get("task_output"),
163
- start_time=run.start_time,
164
- end_time=run.end_time,
165
- error=run.error,
166
- )
@@ -1,56 +1,175 @@
1
1
  from datetime import datetime
2
+ from math import isfinite
2
3
  from typing import Optional
3
4
 
4
5
  import strawberry
5
6
  from strawberry import Info
6
- from strawberry.relay import Node, NodeID
7
+ from strawberry.relay import GlobalID, Node, NodeID
7
8
  from strawberry.scalars import JSON
8
9
 
9
10
  from phoenix.db import models
11
+ from phoenix.server.api.context import Context
10
12
  from phoenix.server.api.types.AnnotatorKind import ExperimentRunAnnotatorKind
11
13
  from phoenix.server.api.types.Trace import Trace
12
14
 
13
15
 
14
16
  @strawberry.type
15
17
  class ExperimentRunAnnotation(Node):
16
- id_attr: NodeID[int]
17
- name: str
18
- annotator_kind: ExperimentRunAnnotatorKind
19
- label: Optional[str]
20
- score: Optional[float]
21
- explanation: Optional[str]
22
- error: Optional[str]
23
- metadata: JSON
24
- start_time: datetime
25
- end_time: datetime
26
- trace_id: Optional[str]
27
-
28
- @strawberry.field
29
- async def trace(self, info: Info) -> Optional[Trace]:
30
- if not self.trace_id:
18
+ id: NodeID[int]
19
+ db_record: strawberry.Private[Optional[models.ExperimentRunAnnotation]] = None
20
+
21
+ def __post_init__(self) -> None:
22
+ if self.db_record and self.id != self.db_record.id:
23
+ raise ValueError("ExperimentRunAnnotation ID mismatch")
24
+
25
+ @strawberry.field(description="Name of the annotation, e.g. 'helpfulness' or 'relevance'.") # type: ignore
26
+ async def name(
27
+ self,
28
+ info: Info[Context, None],
29
+ ) -> str:
30
+ if self.db_record:
31
+ val = self.db_record.name
32
+ else:
33
+ val = await info.context.data_loaders.experiment_run_annotation_fields.load(
34
+ (self.id, models.ExperimentRunAnnotation.name),
35
+ )
36
+ return val
37
+
38
+ @strawberry.field(description="The kind of annotator that produced the annotation.") # type: ignore
39
+ async def annotator_kind(
40
+ self,
41
+ info: Info[Context, None],
42
+ ) -> ExperimentRunAnnotatorKind:
43
+ if self.db_record:
44
+ val = self.db_record.annotator_kind
45
+ else:
46
+ val = await info.context.data_loaders.experiment_run_annotation_fields.load(
47
+ (self.id, models.ExperimentRunAnnotation.annotator_kind),
48
+ )
49
+ return ExperimentRunAnnotatorKind(val)
50
+
51
+ @strawberry.field(
52
+ description="Value of the annotation in the form of a string, e.g. 'helpful' or 'not helpful'. Note that the label is not necessarily binary." # noqa: E501
53
+ ) # type: ignore
54
+ async def label(
55
+ self,
56
+ info: Info[Context, None],
57
+ ) -> Optional[str]:
58
+ if self.db_record:
59
+ val = self.db_record.label
60
+ else:
61
+ val = await info.context.data_loaders.experiment_run_annotation_fields.load(
62
+ (self.id, models.ExperimentRunAnnotation.label),
63
+ )
64
+ return val
65
+
66
+ @strawberry.field(description="Value of the annotation in the form of a numeric score.") # type: ignore
67
+ async def score(
68
+ self,
69
+ info: Info[Context, None],
70
+ ) -> Optional[float]:
71
+ if self.db_record:
72
+ val = self.db_record.score
73
+ else:
74
+ val = await info.context.data_loaders.experiment_run_annotation_fields.load(
75
+ (self.id, models.ExperimentRunAnnotation.score),
76
+ )
77
+ return val if val is not None and isfinite(val) else None
78
+
79
+ @strawberry.field(
80
+ description="The annotator's explanation for the annotation result (i.e. score or label, or both) given to the subject." # noqa: E501
81
+ ) # type: ignore
82
+ async def explanation(
83
+ self,
84
+ info: Info[Context, None],
85
+ ) -> Optional[str]:
86
+ if self.db_record:
87
+ val = self.db_record.explanation
88
+ else:
89
+ val = await info.context.data_loaders.experiment_run_annotation_fields.load(
90
+ (self.id, models.ExperimentRunAnnotation.explanation),
91
+ )
92
+ return val
93
+
94
+ @strawberry.field(description="Error message if the annotation failed to produce a result.") # type: ignore
95
+ async def error(
96
+ self,
97
+ info: Info[Context, None],
98
+ ) -> Optional[str]:
99
+ if self.db_record:
100
+ val = self.db_record.error
101
+ else:
102
+ val = await info.context.data_loaders.experiment_run_annotation_fields.load(
103
+ (self.id, models.ExperimentRunAnnotation.error),
104
+ )
105
+ return val
106
+
107
+ @strawberry.field(description="Metadata about the annotation.") # type: ignore
108
+ async def metadata(
109
+ self,
110
+ info: Info[Context, None],
111
+ ) -> JSON:
112
+ if self.db_record:
113
+ val = self.db_record.metadata_
114
+ else:
115
+ val = await info.context.data_loaders.experiment_run_annotation_fields.load(
116
+ (self.id, models.ExperimentRunAnnotation.metadata_),
117
+ )
118
+ return val
119
+
120
+ @strawberry.field(description="The date and time when the annotation was created.") # type: ignore
121
+ async def start_time(
122
+ self,
123
+ info: Info[Context, None],
124
+ ) -> datetime:
125
+ if self.db_record:
126
+ val = self.db_record.start_time
127
+ else:
128
+ val = await info.context.data_loaders.experiment_run_annotation_fields.load(
129
+ (self.id, models.ExperimentRunAnnotation.start_time),
130
+ )
131
+ return val
132
+
133
+ @strawberry.field(description="The date and time when the annotation was last updated.") # type: ignore
134
+ async def end_time(
135
+ self,
136
+ info: Info[Context, None],
137
+ ) -> datetime:
138
+ if self.db_record:
139
+ val = self.db_record.end_time
140
+ else:
141
+ val = await info.context.data_loaders.experiment_run_annotation_fields.load(
142
+ (self.id, models.ExperimentRunAnnotation.end_time),
143
+ )
144
+ return val
145
+
146
+ @strawberry.field(description="The identifier of the trace associated with the annotation.") # type: ignore
147
+ async def trace_id(
148
+ self,
149
+ info: Info[Context, None],
150
+ ) -> Optional[GlobalID]:
151
+ if self.db_record:
152
+ val = self.db_record.trace_id
153
+ else:
154
+ val = await info.context.data_loaders.experiment_run_annotation_fields.load(
155
+ (self.id, models.ExperimentRunAnnotation.trace_id),
156
+ )
157
+ return None if val is None else GlobalID(type_name=Trace.__name__, node_id=val)
158
+
159
+ @strawberry.field(description="The trace associated with the annotation.") # type: ignore
160
+ async def trace(
161
+ self,
162
+ info: Info[Context, None],
163
+ ) -> Optional[Trace]:
164
+ if self.db_record:
165
+ trace_id = self.db_record.trace_id
166
+ else:
167
+ trace_id = await info.context.data_loaders.experiment_run_annotation_fields.load(
168
+ (self.id, models.ExperimentRunAnnotation.trace_id),
169
+ )
170
+ if not trace_id:
31
171
  return None
32
172
  dataloader = info.context.data_loaders.trace_by_trace_ids
33
- if (trace := await dataloader.load(self.trace_id)) is None:
173
+ if (trace := await dataloader.load(trace_id)) is None:
34
174
  return None
35
- return Trace(trace_rowid=trace.id, db_trace=trace)
36
-
37
-
38
- def to_gql_experiment_run_annotation(
39
- annotation: models.ExperimentRunAnnotation,
40
- ) -> ExperimentRunAnnotation:
41
- """
42
- Converts an ORM experiment run annotation to a GraphQL ExperimentRunAnnotation.
43
- """
44
- return ExperimentRunAnnotation(
45
- id_attr=annotation.id,
46
- name=annotation.name,
47
- annotator_kind=ExperimentRunAnnotatorKind(annotation.annotator_kind),
48
- label=annotation.label,
49
- score=annotation.score,
50
- explanation=annotation.explanation,
51
- error=annotation.error,
52
- metadata=annotation.metadata_,
53
- start_time=annotation.start_time,
54
- end_time=annotation.end_time,
55
- trace_id=annotation.trace_id,
56
- )
175
+ return Trace(id=trace.id, db_record=trace)
@@ -4,7 +4,6 @@ from typing import TYPE_CHECKING, Optional
4
4
 
5
5
  import strawberry
6
6
  from openinference.semconv.trace import OpenInferenceLLMProviderValues
7
- from sqlalchemy import inspect
8
7
  from strawberry.relay import Node, NodeID
9
8
  from strawberry.relay.types import GlobalID
10
9
  from strawberry.types import Info
@@ -37,20 +36,98 @@ CachedCostSummaryKey: TypeAlias = tuple[Optional[ProjectId], TimeRangeKey]
37
36
 
38
37
  @strawberry.type
39
38
  class GenerativeModel(Node, ModelInterface):
40
- id_attr: NodeID[int]
41
- name: str
42
- provider: Optional[str]
43
- name_pattern: str
44
- kind: GenerativeModelKind
45
- created_at: datetime
46
- updated_at: datetime
47
- provider_key: Optional[GenerativeProviderKey]
48
- costs: strawberry.Private[Optional[list[models.TokenPrice]]] = None
49
- start_time: Optional[datetime] = None
39
+ id: NodeID[int]
40
+ db_record: strawberry.Private[Optional[models.GenerativeModel]] = None
50
41
  cached_cost_summary: strawberry.Private[
51
42
  Optional[dict[CachedCostSummaryKey, SpanCostSummary]]
52
43
  ] = None
53
44
 
45
+ def __post_init__(self) -> None:
46
+ if self.db_record and self.id != self.db_record.id:
47
+ raise ValueError("GenerativeModel ID mismatch")
48
+
49
+ @strawberry.field
50
+ async def name(self, info: Info[Context, None]) -> str:
51
+ if self.db_record:
52
+ val = self.db_record.name
53
+ else:
54
+ val = await info.context.data_loaders.generative_model_fields.load(
55
+ (self.id, models.GenerativeModel.name),
56
+ )
57
+ return val
58
+
59
+ @strawberry.field
60
+ async def provider(self, info: Info[Context, None]) -> Optional[str]:
61
+ if self.db_record:
62
+ provider = self.db_record.provider
63
+ else:
64
+ provider = await info.context.data_loaders.generative_model_fields.load(
65
+ (self.id, models.GenerativeModel.provider),
66
+ )
67
+ return provider or None
68
+
69
+ @strawberry.field
70
+ async def name_pattern(self, info: Info[Context, None]) -> str:
71
+ if self.db_record:
72
+ pattern = self.db_record.name_pattern.pattern
73
+ else:
74
+ name_pattern_obj = await info.context.data_loaders.generative_model_fields.load(
75
+ (self.id, models.GenerativeModel.name_pattern),
76
+ )
77
+ pattern = name_pattern_obj.pattern
78
+ assert isinstance(pattern, str)
79
+ return pattern
80
+
81
+ @strawberry.field
82
+ async def kind(self, info: Info[Context, None]) -> GenerativeModelKind:
83
+ if self.db_record:
84
+ is_built_in = self.db_record.is_built_in
85
+ else:
86
+ is_built_in = await info.context.data_loaders.generative_model_fields.load(
87
+ (self.id, models.GenerativeModel.is_built_in),
88
+ )
89
+ return GenerativeModelKind.BUILT_IN if is_built_in else GenerativeModelKind.CUSTOM
90
+
91
+ @strawberry.field
92
+ async def created_at(self, info: Info[Context, None]) -> datetime:
93
+ if self.db_record:
94
+ val = self.db_record.created_at
95
+ else:
96
+ val = await info.context.data_loaders.generative_model_fields.load(
97
+ (self.id, models.GenerativeModel.created_at),
98
+ )
99
+ return val
100
+
101
+ @strawberry.field
102
+ async def updated_at(self, info: Info[Context, None]) -> datetime:
103
+ if self.db_record:
104
+ val = self.db_record.updated_at
105
+ else:
106
+ val = await info.context.data_loaders.generative_model_fields.load(
107
+ (self.id, models.GenerativeModel.updated_at),
108
+ )
109
+ return val
110
+
111
+ @strawberry.field
112
+ async def provider_key(self, info: Info[Context, None]) -> Optional[GenerativeProviderKey]:
113
+ if self.db_record:
114
+ provider = self.db_record.provider
115
+ else:
116
+ provider = await info.context.data_loaders.generative_model_fields.load(
117
+ (self.id, models.GenerativeModel.provider),
118
+ )
119
+ return _semconv_provider_to_gql_generative_provider_key(provider) if provider else None
120
+
121
+ @strawberry.field
122
+ async def start_time(self, info: Info[Context, None]) -> Optional[datetime]:
123
+ if self.db_record:
124
+ val = self.db_record.start_time
125
+ else:
126
+ val = await info.context.data_loaders.generative_model_fields.load(
127
+ (self.id, models.GenerativeModel.start_time),
128
+ )
129
+ return val
130
+
54
131
  def add_cached_cost_summary(
55
132
  self, project_id: Optional[int], time_range: TimeRange, cost_summary: SpanCostSummary
56
133
  ) -> None:
@@ -61,11 +138,10 @@ class GenerativeModel(Node, ModelInterface):
61
138
  self.cached_cost_summary[cache_key] = cost_summary
62
139
 
63
140
  @strawberry.field
64
- async def token_prices(self) -> list[TokenPrice]:
65
- if self.costs is None:
66
- raise NotImplementedError
67
- token_prices: list[TokenPrice] = list()
68
- for cost in self.costs:
141
+ async def token_prices(self, info: Info[Context, None]) -> list[TokenPrice]:
142
+ costs = await info.context.data_loaders.token_prices_by_model.load(self.id)
143
+ token_prices: list[TokenPrice] = []
144
+ for cost in costs:
69
145
  token_prices.append(
70
146
  TokenPrice(
71
147
  token_type=cost.token_type,
@@ -100,7 +176,7 @@ class GenerativeModel(Node, ModelInterface):
100
176
  )
101
177
 
102
178
  loader = info.context.data_loaders.span_cost_summary_by_generative_model
103
- summary = await loader.load(self.id_attr)
179
+ summary = await loader.load(self.id)
104
180
  return SpanCostSummary(
105
181
  prompt=CostBreakdown(
106
182
  tokens=summary.prompt.tokens,
@@ -122,7 +198,7 @@ class GenerativeModel(Node, ModelInterface):
122
198
  info: Info[Context, None],
123
199
  ) -> list[SpanCostDetailSummaryEntry]:
124
200
  loader = info.context.data_loaders.span_cost_detail_summary_entries_by_generative_model
125
- summary = await loader.load(self.id_attr)
201
+ summary = await loader.load(self.id)
126
202
  return [
127
203
  SpanCostDetailSummaryEntry(
128
204
  token_type=entry.token_type,
@@ -137,30 +213,7 @@ class GenerativeModel(Node, ModelInterface):
137
213
 
138
214
  @strawberry.field
139
215
  async def last_used_at(self, info: Info[Context, None]) -> Optional[datetime]:
140
- model_id = self.id_attr
141
- return await info.context.data_loaders.last_used_times_by_generative_model_id.load(model_id)
142
-
143
-
144
- def to_gql_generative_model(
145
- model: models.GenerativeModel,
146
- ) -> GenerativeModel:
147
- costs_are_loaded = isinstance(inspect(model).attrs.token_prices.loaded_value, list)
148
- name_pattern = model.name_pattern.pattern
149
- assert isinstance(name_pattern, str)
150
- return GenerativeModel(
151
- id_attr=model.id,
152
- name=model.name,
153
- provider=model.provider or None,
154
- name_pattern=name_pattern,
155
- kind=GenerativeModelKind.BUILT_IN if model.is_built_in else GenerativeModelKind.CUSTOM,
156
- created_at=model.created_at,
157
- updated_at=model.updated_at,
158
- start_time=model.start_time,
159
- provider_key=_semconv_provider_to_gql_generative_provider_key(model.provider)
160
- if model.provider
161
- else None,
162
- costs=model.token_prices if costs_are_loaded else None,
163
- )
216
+ return await info.context.data_loaders.last_used_times_by_generative_model_id.load(self.id)
164
217
 
165
218
 
166
219
  def _semconv_provider_to_gql_generative_provider_key(
@@ -7,5 +7,10 @@ from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
7
7
 
8
8
  @strawberry.interface
9
9
  class ModelInterface:
10
- name: str
11
- provider_key: Optional[GenerativeProviderKey]
10
+ @strawberry.field
11
+ async def name(self) -> str:
12
+ raise NotImplementedError
13
+
14
+ @strawberry.field
15
+ async def provider_key(self) -> Optional[GenerativeProviderKey]:
16
+ raise NotImplementedError
@@ -1,10 +1,20 @@
1
1
  import strawberry
2
+ from strawberry.types import Info
2
3
 
4
+ from phoenix.server.api.context import Context
3
5
  from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
4
6
  from phoenix.server.api.types.ModelInterface import ModelInterface
5
7
 
6
8
 
7
9
  @strawberry.type
8
10
  class PlaygroundModel(ModelInterface):
9
- name: str
10
- provider_key: GenerativeProviderKey # PlaygroundModel always has a provider_key
11
+ name_value: strawberry.Private[str]
12
+ provider_key_value: strawberry.Private[GenerativeProviderKey]
13
+
14
+ @strawberry.field
15
+ async def name(self, info: Info[Context, None]) -> str:
16
+ return self.name_value
17
+
18
+ @strawberry.field
19
+ async def provider_key(self, info: Info[Context, None]) -> GenerativeProviderKey:
20
+ return self.provider_key_value