arize-phoenix 11.23.1__py3-none-any.whl → 12.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +61 -36
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/RECORD +212 -162
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
- phoenix/__generated__/__init__.py +0 -0
- phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
- phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
- phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
- phoenix/__init__.py +2 -1
- phoenix/auth.py +27 -2
- phoenix/config.py +1594 -81
- phoenix/db/README.md +546 -28
- phoenix/db/bulk_inserter.py +119 -116
- phoenix/db/engines.py +140 -33
- phoenix/db/facilitator.py +22 -1
- phoenix/db/helpers.py +818 -65
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/insertion/dataset.py +133 -1
- phoenix/db/insertion/document_annotation.py +9 -6
- phoenix/db/insertion/evaluation.py +2 -3
- phoenix/db/insertion/helpers.py +2 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/span_annotation.py +3 -4
- phoenix/db/insertion/trace_annotation.py +3 -4
- phoenix/db/insertion/types.py +41 -18
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +364 -56
- phoenix/db/pg_config.py +10 -0
- phoenix/db/types/trace_retention.py +7 -6
- phoenix/experiments/functions.py +69 -19
- phoenix/inferences/inferences.py +1 -2
- phoenix/server/api/auth.py +9 -0
- phoenix/server/api/auth_messages.py +46 -0
- phoenix/server/api/context.py +60 -0
- phoenix/server/api/dataloaders/__init__.py +36 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
- phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
- phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
- phoenix/server/api/dataloaders/document_evaluations.py +6 -9
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
- phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
- phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
- phoenix/server/api/dataloaders/record_counts.py +37 -10
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project.py +28 -14
- phoenix/server/api/dataloaders/span_costs.py +3 -9
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/exceptions.py +5 -1
- phoenix/server/api/helpers/playground_clients.py +263 -83
- phoenix/server/api/helpers/playground_spans.py +2 -1
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
- phoenix/server/api/helpers/prompts/models.py +61 -19
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/ChatCompletionInput.py +3 -0
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/DatasetFilter.py +5 -2
- phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +3 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +158 -1
- phoenix/server/api/input_types/PromptVersionInput.py +47 -1
- phoenix/server/api/input_types/SpanSort.py +3 -2
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/input_types/UserRoleInput.py +1 -0
- phoenix/server/api/mutations/__init__.py +8 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
- phoenix/server/api/mutations/api_key_mutations.py +15 -20
- phoenix/server/api/mutations/chat_mutations.py +106 -37
- phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
- phoenix/server/api/mutations/dataset_mutations.py +21 -16
- phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/model_mutations.py +11 -9
- phoenix/server/api/mutations/project_mutations.py +4 -4
- phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
- phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
- phoenix/server/api/mutations/prompt_mutations.py +65 -129
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
- phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
- phoenix/server/api/mutations/trace_annotations_mutations.py +13 -8
- phoenix/server/api/mutations/trace_mutations.py +3 -3
- phoenix/server/api/mutations/user_mutations.py +55 -26
- phoenix/server/api/queries.py +501 -617
- phoenix/server/api/routers/__init__.py +2 -2
- phoenix/server/api/routers/auth.py +141 -87
- phoenix/server/api/routers/ldap.py +229 -0
- phoenix/server/api/routers/oauth2.py +349 -101
- phoenix/server/api/routers/v1/__init__.py +22 -4
- phoenix/server/api/routers/v1/annotation_configs.py +19 -30
- phoenix/server/api/routers/v1/annotations.py +455 -13
- phoenix/server/api/routers/v1/datasets.py +355 -68
- phoenix/server/api/routers/v1/documents.py +142 -0
- phoenix/server/api/routers/v1/evaluations.py +20 -28
- phoenix/server/api/routers/v1/experiment_evaluations.py +16 -6
- phoenix/server/api/routers/v1/experiment_runs.py +335 -59
- phoenix/server/api/routers/v1/experiments.py +475 -47
- phoenix/server/api/routers/v1/projects.py +16 -50
- phoenix/server/api/routers/v1/prompts.py +50 -39
- phoenix/server/api/routers/v1/sessions.py +108 -0
- phoenix/server/api/routers/v1/spans.py +156 -96
- phoenix/server/api/routers/v1/traces.py +51 -77
- phoenix/server/api/routers/v1/users.py +64 -24
- phoenix/server/api/routers/v1/utils.py +3 -7
- phoenix/server/api/subscriptions.py +257 -93
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/AuthMethod.py +1 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
- phoenix/server/api/types/Dataset.py +199 -72
- phoenix/server/api/types/DatasetExample.py +88 -18
- phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
- phoenix/server/api/types/DatasetLabel.py +57 -0
- phoenix/server/api/types/DatasetSplit.py +98 -0
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +212 -0
- phoenix/server/api/types/Experiment.py +215 -68
- phoenix/server/api/types/ExperimentComparison.py +3 -9
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
- phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
- phoenix/server/api/types/ExperimentRun.py +120 -70
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +95 -42
- phoenix/server/api/types/GenerativeProvider.py +1 -1
- phoenix/server/api/types/ModelInterface.py +7 -2
- phoenix/server/api/types/PlaygroundModel.py +12 -2
- phoenix/server/api/types/Project.py +218 -185
- phoenix/server/api/types/ProjectSession.py +146 -29
- phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +119 -39
- phoenix/server/api/types/PromptLabel.py +42 -25
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/Span.py +130 -123
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/Trace.py +184 -53
- phoenix/server/api/types/TraceAnnotation.py +149 -50
- phoenix/server/api/types/User.py +128 -33
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/api/types/node.py +10 -0
- phoenix/server/api/types/pagination.py +11 -2
- phoenix/server/app.py +154 -36
- phoenix/server/authorization.py +5 -4
- phoenix/server/bearer_auth.py +13 -5
- phoenix/server/cost_tracking/cost_model_lookup.py +42 -14
- phoenix/server/cost_tracking/model_cost_manifest.json +1085 -194
- phoenix/server/daemons/generative_model_store.py +61 -9
- phoenix/server/daemons/span_cost_calculator.py +10 -8
- phoenix/server/dml_event.py +13 -0
- phoenix/server/email/sender.py +29 -2
- phoenix/server/grpc_server.py +9 -9
- phoenix/server/jwt_store.py +8 -6
- phoenix/server/ldap.py +1449 -0
- phoenix/server/main.py +9 -3
- phoenix/server/oauth2.py +330 -12
- phoenix/server/prometheus.py +43 -6
- phoenix/server/rate_limiters.py +4 -9
- phoenix/server/retention.py +33 -20
- phoenix/server/session_filters.py +49 -0
- phoenix/server/static/.vite/manifest.json +51 -53
- phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
- phoenix/server/static/assets/{index-BPCwGQr8.js → index-CTQoemZv.js} +42 -35
- phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
- phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
- phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
- phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
- phoenix/server/static/assets/{vendor-recharts-Bw30oz1A.js → vendor-recharts-V9cwpXsm.js} +7 -7
- phoenix/server/static/assets/{vendor-shiki-DZajAPeq.js → vendor-shiki-Do--csgv.js} +1 -1
- phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
- phoenix/server/templates/index.html +7 -1
- phoenix/server/thread_server.py +1 -2
- phoenix/server/utils.py +74 -0
- phoenix/session/client.py +55 -1
- phoenix/session/data_extractor.py +5 -0
- phoenix/session/evaluation.py +8 -4
- phoenix/session/session.py +44 -8
- phoenix/settings.py +2 -0
- phoenix/trace/attributes.py +80 -13
- phoenix/trace/dsl/query.py +2 -0
- phoenix/trace/projects.py +5 -0
- phoenix/utilities/template_formatters.py +1 -1
- phoenix/version.py +1 -1
- phoenix/server/api/types/Evaluation.py +0 -39
- phoenix/server/static/assets/components-D0DWAf0l.js +0 -5650
- phoenix/server/static/assets/pages-Creyamao.js +0 -8612
- phoenix/server/static/assets/vendor-CU36oj8y.js +0 -905
- phoenix/server/static/assets/vendor-CqDb5u4o.css +0 -1
- phoenix/server/static/assets/vendor-arizeai-Ctgw0e1G.js +0 -168
- phoenix/server/static/assets/vendor-codemirror-Cojjzqb9.js +0 -25
- phoenix/server/static/assets/vendor-three-BLWp5bic.js +0 -2998
- phoenix/utilities/deprecation.py +0 -31
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, Annotated, Optional
|
|
|
3
3
|
|
|
4
4
|
import strawberry
|
|
5
5
|
from sqlalchemy import func, select
|
|
6
|
-
from sqlalchemy.orm import load_only
|
|
7
6
|
from sqlalchemy.sql.functions import coalesce
|
|
8
7
|
from strawberry import UNSET
|
|
9
8
|
from strawberry.relay import Connection, GlobalID, Node, NodeID
|
|
@@ -13,10 +12,7 @@ from strawberry.types import Info
|
|
|
13
12
|
from phoenix.db import models
|
|
14
13
|
from phoenix.server.api.context import Context
|
|
15
14
|
from phoenix.server.api.types.CostBreakdown import CostBreakdown
|
|
16
|
-
from phoenix.server.api.types.ExperimentRunAnnotation import
|
|
17
|
-
ExperimentRunAnnotation,
|
|
18
|
-
to_gql_experiment_run_annotation,
|
|
19
|
-
)
|
|
15
|
+
from phoenix.server.api.types.ExperimentRunAnnotation import ExperimentRunAnnotation
|
|
20
16
|
from phoenix.server.api.types.pagination import (
|
|
21
17
|
ConnectionArgs,
|
|
22
18
|
CursorString,
|
|
@@ -27,18 +23,100 @@ from phoenix.server.api.types.SpanCostSummary import SpanCostSummary
|
|
|
27
23
|
from phoenix.server.api.types.Trace import Trace
|
|
28
24
|
|
|
29
25
|
if TYPE_CHECKING:
|
|
30
|
-
from
|
|
26
|
+
from .DatasetExample import DatasetExample
|
|
27
|
+
from .Trace import Trace
|
|
31
28
|
|
|
32
29
|
|
|
33
30
|
@strawberry.type
|
|
34
31
|
class ExperimentRun(Node):
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
32
|
+
id: NodeID[int]
|
|
33
|
+
db_record: strawberry.Private[Optional[models.ExperimentRun]] = None
|
|
34
|
+
|
|
35
|
+
def __post_init__(self) -> None:
|
|
36
|
+
if self.db_record and self.id != self.db_record.id:
|
|
37
|
+
raise ValueError("ExperimentRun ID mismatch")
|
|
38
|
+
|
|
39
|
+
@strawberry.field
|
|
40
|
+
async def experiment_id(self, info: Info[Context, None]) -> GlobalID:
|
|
41
|
+
from .Experiment import Experiment
|
|
42
|
+
|
|
43
|
+
if self.db_record:
|
|
44
|
+
experiment_id = self.db_record.experiment_id
|
|
45
|
+
else:
|
|
46
|
+
experiment_id = await info.context.data_loaders.experiment_run_fields.load(
|
|
47
|
+
(self.id, models.ExperimentRun.experiment_id),
|
|
48
|
+
)
|
|
49
|
+
return GlobalID(Experiment.__name__, str(experiment_id))
|
|
50
|
+
|
|
51
|
+
@strawberry.field
|
|
52
|
+
async def repetition_number(self, info: Info[Context, None]) -> int:
|
|
53
|
+
if self.db_record:
|
|
54
|
+
val = self.db_record.repetition_number
|
|
55
|
+
else:
|
|
56
|
+
val = await info.context.data_loaders.experiment_run_fields.load(
|
|
57
|
+
(self.id, models.ExperimentRun.repetition_number),
|
|
58
|
+
)
|
|
59
|
+
return val
|
|
60
|
+
|
|
61
|
+
@strawberry.field
|
|
62
|
+
async def trace_id(self, info: Info[Context, None]) -> Optional[str]:
|
|
63
|
+
if self.db_record:
|
|
64
|
+
val = self.db_record.trace_id
|
|
65
|
+
else:
|
|
66
|
+
val = await info.context.data_loaders.experiment_run_fields.load(
|
|
67
|
+
(self.id, models.ExperimentRun.trace_id),
|
|
68
|
+
)
|
|
69
|
+
return val
|
|
70
|
+
|
|
71
|
+
@strawberry.field
|
|
72
|
+
async def output(self, info: Info[Context, None]) -> Optional[JSON]:
|
|
73
|
+
if self.db_record:
|
|
74
|
+
output_dict = self.db_record.output
|
|
75
|
+
else:
|
|
76
|
+
output_dict = await info.context.data_loaders.experiment_run_fields.load(
|
|
77
|
+
(self.id, models.ExperimentRun.output),
|
|
78
|
+
)
|
|
79
|
+
return output_dict.get("task_output") if output_dict else None
|
|
80
|
+
|
|
81
|
+
@strawberry.field
|
|
82
|
+
async def start_time(self, info: Info[Context, None]) -> datetime:
|
|
83
|
+
if self.db_record:
|
|
84
|
+
val = self.db_record.start_time
|
|
85
|
+
else:
|
|
86
|
+
val = await info.context.data_loaders.experiment_run_fields.load(
|
|
87
|
+
(self.id, models.ExperimentRun.start_time),
|
|
88
|
+
)
|
|
89
|
+
return val
|
|
90
|
+
|
|
91
|
+
@strawberry.field
|
|
92
|
+
async def end_time(self, info: Info[Context, None]) -> datetime:
|
|
93
|
+
if self.db_record:
|
|
94
|
+
val = self.db_record.end_time
|
|
95
|
+
else:
|
|
96
|
+
val = await info.context.data_loaders.experiment_run_fields.load(
|
|
97
|
+
(self.id, models.ExperimentRun.end_time),
|
|
98
|
+
)
|
|
99
|
+
return val
|
|
100
|
+
|
|
101
|
+
@strawberry.field
|
|
102
|
+
async def error(self, info: Info[Context, None]) -> Optional[str]:
|
|
103
|
+
if self.db_record:
|
|
104
|
+
val = self.db_record.error
|
|
105
|
+
else:
|
|
106
|
+
val = await info.context.data_loaders.experiment_run_fields.load(
|
|
107
|
+
(self.id, models.ExperimentRun.error),
|
|
108
|
+
)
|
|
109
|
+
return val
|
|
110
|
+
|
|
111
|
+
@strawberry.field
|
|
112
|
+
async def latency_ms(self, info: Info[Context, None]) -> float:
|
|
113
|
+
if self.db_record:
|
|
114
|
+
val = self.db_record.latency_ms
|
|
115
|
+
else:
|
|
116
|
+
val = await info.context.data_loaders.experiment_run_fields.load(
|
|
117
|
+
(self.id, models.ExperimentRun.latency_ms),
|
|
118
|
+
)
|
|
119
|
+
return val
|
|
42
120
|
|
|
43
121
|
@strawberry.field
|
|
44
122
|
async def annotations(
|
|
@@ -55,57 +133,49 @@ class ExperimentRun(Node):
|
|
|
55
133
|
last=last,
|
|
56
134
|
before=before if isinstance(before, CursorString) else None,
|
|
57
135
|
)
|
|
58
|
-
|
|
59
|
-
annotations = await info.context.data_loaders.experiment_run_annotations.load(run_id)
|
|
136
|
+
annotations = await info.context.data_loaders.experiment_run_annotations.load(self.id)
|
|
60
137
|
return connection_from_list(
|
|
61
|
-
[
|
|
138
|
+
[
|
|
139
|
+
ExperimentRunAnnotation(id=annotation.id, db_record=annotation)
|
|
140
|
+
for annotation in annotations
|
|
141
|
+
],
|
|
142
|
+
args,
|
|
62
143
|
)
|
|
63
144
|
|
|
64
145
|
@strawberry.field
|
|
65
|
-
async def trace(
|
|
66
|
-
|
|
146
|
+
async def trace(
|
|
147
|
+
self, info: Info[Context, None]
|
|
148
|
+
) -> Optional[Annotated["Trace", strawberry.lazy(".Trace")]]:
|
|
149
|
+
if self.db_record:
|
|
150
|
+
trace_id = self.db_record.trace_id
|
|
151
|
+
else:
|
|
152
|
+
trace_id = await info.context.data_loaders.experiment_run_fields.load(
|
|
153
|
+
(self.id, models.ExperimentRun.trace_id),
|
|
154
|
+
)
|
|
155
|
+
if not trace_id:
|
|
67
156
|
return None
|
|
68
|
-
|
|
69
|
-
if (trace := await
|
|
157
|
+
loader = info.context.data_loaders.trace_by_trace_ids
|
|
158
|
+
if (trace := await loader.load(trace_id)) is None:
|
|
70
159
|
return None
|
|
71
|
-
|
|
160
|
+
from .Trace import Trace
|
|
161
|
+
|
|
162
|
+
return Trace(id=trace.id, db_record=trace)
|
|
72
163
|
|
|
73
164
|
@strawberry.field
|
|
74
165
|
async def example(
|
|
75
|
-
self, info: Info
|
|
166
|
+
self, info: Info[Context, None]
|
|
76
167
|
) -> Annotated[
|
|
77
|
-
"DatasetExample", strawberry.lazy("
|
|
168
|
+
"DatasetExample", strawberry.lazy(".DatasetExample")
|
|
78
169
|
]: # use lazy types to avoid circular import: https://strawberry.rocks/docs/types/lazy
|
|
79
|
-
from
|
|
170
|
+
from .DatasetExample import DatasetExample
|
|
80
171
|
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
select(models.DatasetExample, models.Experiment.dataset_version_id)
|
|
85
|
-
.select_from(models.ExperimentRun)
|
|
86
|
-
.join(
|
|
87
|
-
models.DatasetExample,
|
|
88
|
-
models.DatasetExample.id == models.ExperimentRun.dataset_example_id,
|
|
89
|
-
)
|
|
90
|
-
.join(
|
|
91
|
-
models.Experiment,
|
|
92
|
-
models.Experiment.id == models.ExperimentRun.experiment_id,
|
|
93
|
-
)
|
|
94
|
-
.where(models.ExperimentRun.id == self.id_attr)
|
|
95
|
-
.options(load_only(models.DatasetExample.id, models.DatasetExample.created_at))
|
|
96
|
-
)
|
|
97
|
-
) is not None
|
|
98
|
-
example, version_id = result.first()
|
|
99
|
-
return DatasetExample(
|
|
100
|
-
id_attr=example.id,
|
|
101
|
-
created_at=example.created_at,
|
|
102
|
-
version_id=version_id,
|
|
103
|
-
)
|
|
172
|
+
loader = info.context.data_loaders.dataset_examples_and_versions_by_experiment_run
|
|
173
|
+
(example, version_id) = await loader.load(self.id)
|
|
174
|
+
return DatasetExample(id=example.id, db_record=example, version_id=version_id)
|
|
104
175
|
|
|
105
176
|
@strawberry.field
|
|
106
177
|
async def cost_summary(self, info: Info[Context, None]) -> SpanCostSummary:
|
|
107
|
-
|
|
108
|
-
summary = await info.context.data_loaders.span_cost_summary_by_experiment_run.load(run_id)
|
|
178
|
+
summary = await info.context.data_loaders.span_cost_summary_by_experiment_run.load(self.id)
|
|
109
179
|
return SpanCostSummary(
|
|
110
180
|
prompt=CostBreakdown(
|
|
111
181
|
tokens=summary.prompt.tokens,
|
|
@@ -125,8 +195,6 @@ class ExperimentRun(Node):
|
|
|
125
195
|
async def cost_detail_summary_entries(
|
|
126
196
|
self, info: Info[Context, None]
|
|
127
197
|
) -> list[SpanCostDetailSummaryEntry]:
|
|
128
|
-
run_id = self.id_attr
|
|
129
|
-
|
|
130
198
|
stmt = (
|
|
131
199
|
select(
|
|
132
200
|
models.SpanCostDetail.token_type,
|
|
@@ -139,7 +207,7 @@ class ExperimentRun(Node):
|
|
|
139
207
|
.join(models.Span, models.SpanCost.span_rowid == models.Span.id)
|
|
140
208
|
.join(models.Trace, models.Span.trace_rowid == models.Trace.id)
|
|
141
209
|
.join(models.ExperimentRun, models.ExperimentRun.trace_id == models.Trace.trace_id)
|
|
142
|
-
.where(models.ExperimentRun.id ==
|
|
210
|
+
.where(models.ExperimentRun.id == self.id)
|
|
143
211
|
.group_by(models.SpanCostDetail.token_type, models.SpanCostDetail.is_prompt)
|
|
144
212
|
)
|
|
145
213
|
|
|
@@ -153,21 +221,3 @@ class ExperimentRun(Node):
|
|
|
153
221
|
)
|
|
154
222
|
async for token_type, is_prompt, cost, tokens in data
|
|
155
223
|
]
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
def to_gql_experiment_run(run: models.ExperimentRun) -> ExperimentRun:
|
|
159
|
-
"""
|
|
160
|
-
Converts an ORM experiment run to a GraphQL ExperimentRun.
|
|
161
|
-
"""
|
|
162
|
-
|
|
163
|
-
from phoenix.server.api.types.Experiment import Experiment
|
|
164
|
-
|
|
165
|
-
return ExperimentRun(
|
|
166
|
-
id_attr=run.id,
|
|
167
|
-
experiment_id=GlobalID(Experiment.__name__, str(run.experiment_id)),
|
|
168
|
-
trace_id=run.trace.trace_id if run.trace else None,
|
|
169
|
-
output=run.output.get("task_output"),
|
|
170
|
-
start_time=run.start_time,
|
|
171
|
-
end_time=run.end_time,
|
|
172
|
-
error=run.error,
|
|
173
|
-
)
|
|
@@ -1,56 +1,175 @@
|
|
|
1
1
|
from datetime import datetime
|
|
2
|
+
from math import isfinite
|
|
2
3
|
from typing import Optional
|
|
3
4
|
|
|
4
5
|
import strawberry
|
|
5
6
|
from strawberry import Info
|
|
6
|
-
from strawberry.relay import Node, NodeID
|
|
7
|
+
from strawberry.relay import GlobalID, Node, NodeID
|
|
7
8
|
from strawberry.scalars import JSON
|
|
8
9
|
|
|
9
10
|
from phoenix.db import models
|
|
11
|
+
from phoenix.server.api.context import Context
|
|
10
12
|
from phoenix.server.api.types.AnnotatorKind import ExperimentRunAnnotatorKind
|
|
11
13
|
from phoenix.server.api.types.Trace import Trace
|
|
12
14
|
|
|
13
15
|
|
|
14
16
|
@strawberry.type
|
|
15
17
|
class ExperimentRunAnnotation(Node):
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
18
|
+
id: NodeID[int]
|
|
19
|
+
db_record: strawberry.Private[Optional[models.ExperimentRunAnnotation]] = None
|
|
20
|
+
|
|
21
|
+
def __post_init__(self) -> None:
|
|
22
|
+
if self.db_record and self.id != self.db_record.id:
|
|
23
|
+
raise ValueError("ExperimentRunAnnotation ID mismatch")
|
|
24
|
+
|
|
25
|
+
@strawberry.field(description="Name of the annotation, e.g. 'helpfulness' or 'relevance'.") # type: ignore
|
|
26
|
+
async def name(
|
|
27
|
+
self,
|
|
28
|
+
info: Info[Context, None],
|
|
29
|
+
) -> str:
|
|
30
|
+
if self.db_record:
|
|
31
|
+
val = self.db_record.name
|
|
32
|
+
else:
|
|
33
|
+
val = await info.context.data_loaders.experiment_run_annotation_fields.load(
|
|
34
|
+
(self.id, models.ExperimentRunAnnotation.name),
|
|
35
|
+
)
|
|
36
|
+
return val
|
|
37
|
+
|
|
38
|
+
@strawberry.field(description="The kind of annotator that produced the annotation.") # type: ignore
|
|
39
|
+
async def annotator_kind(
|
|
40
|
+
self,
|
|
41
|
+
info: Info[Context, None],
|
|
42
|
+
) -> ExperimentRunAnnotatorKind:
|
|
43
|
+
if self.db_record:
|
|
44
|
+
val = self.db_record.annotator_kind
|
|
45
|
+
else:
|
|
46
|
+
val = await info.context.data_loaders.experiment_run_annotation_fields.load(
|
|
47
|
+
(self.id, models.ExperimentRunAnnotation.annotator_kind),
|
|
48
|
+
)
|
|
49
|
+
return ExperimentRunAnnotatorKind(val)
|
|
50
|
+
|
|
51
|
+
@strawberry.field(
|
|
52
|
+
description="Value of the annotation in the form of a string, e.g. 'helpful' or 'not helpful'. Note that the label is not necessarily binary." # noqa: E501
|
|
53
|
+
) # type: ignore
|
|
54
|
+
async def label(
|
|
55
|
+
self,
|
|
56
|
+
info: Info[Context, None],
|
|
57
|
+
) -> Optional[str]:
|
|
58
|
+
if self.db_record:
|
|
59
|
+
val = self.db_record.label
|
|
60
|
+
else:
|
|
61
|
+
val = await info.context.data_loaders.experiment_run_annotation_fields.load(
|
|
62
|
+
(self.id, models.ExperimentRunAnnotation.label),
|
|
63
|
+
)
|
|
64
|
+
return val
|
|
65
|
+
|
|
66
|
+
@strawberry.field(description="Value of the annotation in the form of a numeric score.") # type: ignore
|
|
67
|
+
async def score(
|
|
68
|
+
self,
|
|
69
|
+
info: Info[Context, None],
|
|
70
|
+
) -> Optional[float]:
|
|
71
|
+
if self.db_record:
|
|
72
|
+
val = self.db_record.score
|
|
73
|
+
else:
|
|
74
|
+
val = await info.context.data_loaders.experiment_run_annotation_fields.load(
|
|
75
|
+
(self.id, models.ExperimentRunAnnotation.score),
|
|
76
|
+
)
|
|
77
|
+
return val if val is not None and isfinite(val) else None
|
|
78
|
+
|
|
79
|
+
@strawberry.field(
|
|
80
|
+
description="The annotator's explanation for the annotation result (i.e. score or label, or both) given to the subject." # noqa: E501
|
|
81
|
+
) # type: ignore
|
|
82
|
+
async def explanation(
|
|
83
|
+
self,
|
|
84
|
+
info: Info[Context, None],
|
|
85
|
+
) -> Optional[str]:
|
|
86
|
+
if self.db_record:
|
|
87
|
+
val = self.db_record.explanation
|
|
88
|
+
else:
|
|
89
|
+
val = await info.context.data_loaders.experiment_run_annotation_fields.load(
|
|
90
|
+
(self.id, models.ExperimentRunAnnotation.explanation),
|
|
91
|
+
)
|
|
92
|
+
return val
|
|
93
|
+
|
|
94
|
+
@strawberry.field(description="Error message if the annotation failed to produce a result.") # type: ignore
|
|
95
|
+
async def error(
|
|
96
|
+
self,
|
|
97
|
+
info: Info[Context, None],
|
|
98
|
+
) -> Optional[str]:
|
|
99
|
+
if self.db_record:
|
|
100
|
+
val = self.db_record.error
|
|
101
|
+
else:
|
|
102
|
+
val = await info.context.data_loaders.experiment_run_annotation_fields.load(
|
|
103
|
+
(self.id, models.ExperimentRunAnnotation.error),
|
|
104
|
+
)
|
|
105
|
+
return val
|
|
106
|
+
|
|
107
|
+
@strawberry.field(description="Metadata about the annotation.") # type: ignore
|
|
108
|
+
async def metadata(
|
|
109
|
+
self,
|
|
110
|
+
info: Info[Context, None],
|
|
111
|
+
) -> JSON:
|
|
112
|
+
if self.db_record:
|
|
113
|
+
val = self.db_record.metadata_
|
|
114
|
+
else:
|
|
115
|
+
val = await info.context.data_loaders.experiment_run_annotation_fields.load(
|
|
116
|
+
(self.id, models.ExperimentRunAnnotation.metadata_),
|
|
117
|
+
)
|
|
118
|
+
return val
|
|
119
|
+
|
|
120
|
+
@strawberry.field(description="The date and time when the annotation was created.") # type: ignore
|
|
121
|
+
async def start_time(
|
|
122
|
+
self,
|
|
123
|
+
info: Info[Context, None],
|
|
124
|
+
) -> datetime:
|
|
125
|
+
if self.db_record:
|
|
126
|
+
val = self.db_record.start_time
|
|
127
|
+
else:
|
|
128
|
+
val = await info.context.data_loaders.experiment_run_annotation_fields.load(
|
|
129
|
+
(self.id, models.ExperimentRunAnnotation.start_time),
|
|
130
|
+
)
|
|
131
|
+
return val
|
|
132
|
+
|
|
133
|
+
@strawberry.field(description="The date and time when the annotation was last updated.") # type: ignore
|
|
134
|
+
async def end_time(
|
|
135
|
+
self,
|
|
136
|
+
info: Info[Context, None],
|
|
137
|
+
) -> datetime:
|
|
138
|
+
if self.db_record:
|
|
139
|
+
val = self.db_record.end_time
|
|
140
|
+
else:
|
|
141
|
+
val = await info.context.data_loaders.experiment_run_annotation_fields.load(
|
|
142
|
+
(self.id, models.ExperimentRunAnnotation.end_time),
|
|
143
|
+
)
|
|
144
|
+
return val
|
|
145
|
+
|
|
146
|
+
@strawberry.field(description="The identifier of the trace associated with the annotation.") # type: ignore
|
|
147
|
+
async def trace_id(
|
|
148
|
+
self,
|
|
149
|
+
info: Info[Context, None],
|
|
150
|
+
) -> Optional[GlobalID]:
|
|
151
|
+
if self.db_record:
|
|
152
|
+
val = self.db_record.trace_id
|
|
153
|
+
else:
|
|
154
|
+
val = await info.context.data_loaders.experiment_run_annotation_fields.load(
|
|
155
|
+
(self.id, models.ExperimentRunAnnotation.trace_id),
|
|
156
|
+
)
|
|
157
|
+
return None if val is None else GlobalID(type_name=Trace.__name__, node_id=val)
|
|
158
|
+
|
|
159
|
+
@strawberry.field(description="The trace associated with the annotation.") # type: ignore
|
|
160
|
+
async def trace(
|
|
161
|
+
self,
|
|
162
|
+
info: Info[Context, None],
|
|
163
|
+
) -> Optional[Trace]:
|
|
164
|
+
if self.db_record:
|
|
165
|
+
trace_id = self.db_record.trace_id
|
|
166
|
+
else:
|
|
167
|
+
trace_id = await info.context.data_loaders.experiment_run_annotation_fields.load(
|
|
168
|
+
(self.id, models.ExperimentRunAnnotation.trace_id),
|
|
169
|
+
)
|
|
170
|
+
if not trace_id:
|
|
31
171
|
return None
|
|
32
172
|
dataloader = info.context.data_loaders.trace_by_trace_ids
|
|
33
|
-
if (trace := await dataloader.load(
|
|
173
|
+
if (trace := await dataloader.load(trace_id)) is None:
|
|
34
174
|
return None
|
|
35
|
-
return Trace(
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
def to_gql_experiment_run_annotation(
|
|
39
|
-
annotation: models.ExperimentRunAnnotation,
|
|
40
|
-
) -> ExperimentRunAnnotation:
|
|
41
|
-
"""
|
|
42
|
-
Converts an ORM experiment run annotation to a GraphQL ExperimentRunAnnotation.
|
|
43
|
-
"""
|
|
44
|
-
return ExperimentRunAnnotation(
|
|
45
|
-
id_attr=annotation.id,
|
|
46
|
-
name=annotation.name,
|
|
47
|
-
annotator_kind=ExperimentRunAnnotatorKind(annotation.annotator_kind),
|
|
48
|
-
label=annotation.label,
|
|
49
|
-
score=annotation.score,
|
|
50
|
-
explanation=annotation.explanation,
|
|
51
|
-
error=annotation.error,
|
|
52
|
-
metadata=annotation.metadata_,
|
|
53
|
-
start_time=annotation.start_time,
|
|
54
|
-
end_time=annotation.end_time,
|
|
55
|
-
trace_id=annotation.trace_id,
|
|
56
|
-
)
|
|
175
|
+
return Trace(id=trace.id, db_record=trace)
|
|
@@ -4,7 +4,6 @@ from typing import TYPE_CHECKING, Optional
|
|
|
4
4
|
|
|
5
5
|
import strawberry
|
|
6
6
|
from openinference.semconv.trace import OpenInferenceLLMProviderValues
|
|
7
|
-
from sqlalchemy import inspect
|
|
8
7
|
from strawberry.relay import Node, NodeID
|
|
9
8
|
from strawberry.relay.types import GlobalID
|
|
10
9
|
from strawberry.types import Info
|
|
@@ -37,20 +36,98 @@ CachedCostSummaryKey: TypeAlias = tuple[Optional[ProjectId], TimeRangeKey]
|
|
|
37
36
|
|
|
38
37
|
@strawberry.type
|
|
39
38
|
class GenerativeModel(Node, ModelInterface):
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
provider: Optional[str]
|
|
43
|
-
name_pattern: str
|
|
44
|
-
kind: GenerativeModelKind
|
|
45
|
-
created_at: datetime
|
|
46
|
-
updated_at: datetime
|
|
47
|
-
provider_key: Optional[GenerativeProviderKey]
|
|
48
|
-
costs: strawberry.Private[Optional[list[models.TokenPrice]]] = None
|
|
49
|
-
start_time: Optional[datetime] = None
|
|
39
|
+
id: NodeID[int]
|
|
40
|
+
db_record: strawberry.Private[Optional[models.GenerativeModel]] = None
|
|
50
41
|
cached_cost_summary: strawberry.Private[
|
|
51
42
|
Optional[dict[CachedCostSummaryKey, SpanCostSummary]]
|
|
52
43
|
] = None
|
|
53
44
|
|
|
45
|
+
def __post_init__(self) -> None:
|
|
46
|
+
if self.db_record and self.id != self.db_record.id:
|
|
47
|
+
raise ValueError("GenerativeModel ID mismatch")
|
|
48
|
+
|
|
49
|
+
@strawberry.field
|
|
50
|
+
async def name(self, info: Info[Context, None]) -> str:
|
|
51
|
+
if self.db_record:
|
|
52
|
+
val = self.db_record.name
|
|
53
|
+
else:
|
|
54
|
+
val = await info.context.data_loaders.generative_model_fields.load(
|
|
55
|
+
(self.id, models.GenerativeModel.name),
|
|
56
|
+
)
|
|
57
|
+
return val
|
|
58
|
+
|
|
59
|
+
@strawberry.field
|
|
60
|
+
async def provider(self, info: Info[Context, None]) -> Optional[str]:
|
|
61
|
+
if self.db_record:
|
|
62
|
+
provider = self.db_record.provider
|
|
63
|
+
else:
|
|
64
|
+
provider = await info.context.data_loaders.generative_model_fields.load(
|
|
65
|
+
(self.id, models.GenerativeModel.provider),
|
|
66
|
+
)
|
|
67
|
+
return provider or None
|
|
68
|
+
|
|
69
|
+
@strawberry.field
|
|
70
|
+
async def name_pattern(self, info: Info[Context, None]) -> str:
|
|
71
|
+
if self.db_record:
|
|
72
|
+
pattern = self.db_record.name_pattern.pattern
|
|
73
|
+
else:
|
|
74
|
+
name_pattern_obj = await info.context.data_loaders.generative_model_fields.load(
|
|
75
|
+
(self.id, models.GenerativeModel.name_pattern),
|
|
76
|
+
)
|
|
77
|
+
pattern = name_pattern_obj.pattern
|
|
78
|
+
assert isinstance(pattern, str)
|
|
79
|
+
return pattern
|
|
80
|
+
|
|
81
|
+
@strawberry.field
|
|
82
|
+
async def kind(self, info: Info[Context, None]) -> GenerativeModelKind:
|
|
83
|
+
if self.db_record:
|
|
84
|
+
is_built_in = self.db_record.is_built_in
|
|
85
|
+
else:
|
|
86
|
+
is_built_in = await info.context.data_loaders.generative_model_fields.load(
|
|
87
|
+
(self.id, models.GenerativeModel.is_built_in),
|
|
88
|
+
)
|
|
89
|
+
return GenerativeModelKind.BUILT_IN if is_built_in else GenerativeModelKind.CUSTOM
|
|
90
|
+
|
|
91
|
+
@strawberry.field
|
|
92
|
+
async def created_at(self, info: Info[Context, None]) -> datetime:
|
|
93
|
+
if self.db_record:
|
|
94
|
+
val = self.db_record.created_at
|
|
95
|
+
else:
|
|
96
|
+
val = await info.context.data_loaders.generative_model_fields.load(
|
|
97
|
+
(self.id, models.GenerativeModel.created_at),
|
|
98
|
+
)
|
|
99
|
+
return val
|
|
100
|
+
|
|
101
|
+
@strawberry.field
|
|
102
|
+
async def updated_at(self, info: Info[Context, None]) -> datetime:
|
|
103
|
+
if self.db_record:
|
|
104
|
+
val = self.db_record.updated_at
|
|
105
|
+
else:
|
|
106
|
+
val = await info.context.data_loaders.generative_model_fields.load(
|
|
107
|
+
(self.id, models.GenerativeModel.updated_at),
|
|
108
|
+
)
|
|
109
|
+
return val
|
|
110
|
+
|
|
111
|
+
@strawberry.field
|
|
112
|
+
async def provider_key(self, info: Info[Context, None]) -> Optional[GenerativeProviderKey]:
|
|
113
|
+
if self.db_record:
|
|
114
|
+
provider = self.db_record.provider
|
|
115
|
+
else:
|
|
116
|
+
provider = await info.context.data_loaders.generative_model_fields.load(
|
|
117
|
+
(self.id, models.GenerativeModel.provider),
|
|
118
|
+
)
|
|
119
|
+
return _semconv_provider_to_gql_generative_provider_key(provider) if provider else None
|
|
120
|
+
|
|
121
|
+
@strawberry.field
|
|
122
|
+
async def start_time(self, info: Info[Context, None]) -> Optional[datetime]:
|
|
123
|
+
if self.db_record:
|
|
124
|
+
val = self.db_record.start_time
|
|
125
|
+
else:
|
|
126
|
+
val = await info.context.data_loaders.generative_model_fields.load(
|
|
127
|
+
(self.id, models.GenerativeModel.start_time),
|
|
128
|
+
)
|
|
129
|
+
return val
|
|
130
|
+
|
|
54
131
|
def add_cached_cost_summary(
|
|
55
132
|
self, project_id: Optional[int], time_range: TimeRange, cost_summary: SpanCostSummary
|
|
56
133
|
) -> None:
|
|
@@ -61,11 +138,10 @@ class GenerativeModel(Node, ModelInterface):
|
|
|
61
138
|
self.cached_cost_summary[cache_key] = cost_summary
|
|
62
139
|
|
|
63
140
|
@strawberry.field
|
|
64
|
-
async def token_prices(self) -> list[TokenPrice]:
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
for cost in self.costs:
|
|
141
|
+
async def token_prices(self, info: Info[Context, None]) -> list[TokenPrice]:
|
|
142
|
+
costs = await info.context.data_loaders.token_prices_by_model.load(self.id)
|
|
143
|
+
token_prices: list[TokenPrice] = []
|
|
144
|
+
for cost in costs:
|
|
69
145
|
token_prices.append(
|
|
70
146
|
TokenPrice(
|
|
71
147
|
token_type=cost.token_type,
|
|
@@ -100,7 +176,7 @@ class GenerativeModel(Node, ModelInterface):
|
|
|
100
176
|
)
|
|
101
177
|
|
|
102
178
|
loader = info.context.data_loaders.span_cost_summary_by_generative_model
|
|
103
|
-
summary = await loader.load(self.
|
|
179
|
+
summary = await loader.load(self.id)
|
|
104
180
|
return SpanCostSummary(
|
|
105
181
|
prompt=CostBreakdown(
|
|
106
182
|
tokens=summary.prompt.tokens,
|
|
@@ -122,7 +198,7 @@ class GenerativeModel(Node, ModelInterface):
|
|
|
122
198
|
info: Info[Context, None],
|
|
123
199
|
) -> list[SpanCostDetailSummaryEntry]:
|
|
124
200
|
loader = info.context.data_loaders.span_cost_detail_summary_entries_by_generative_model
|
|
125
|
-
summary = await loader.load(self.
|
|
201
|
+
summary = await loader.load(self.id)
|
|
126
202
|
return [
|
|
127
203
|
SpanCostDetailSummaryEntry(
|
|
128
204
|
token_type=entry.token_type,
|
|
@@ -137,30 +213,7 @@ class GenerativeModel(Node, ModelInterface):
|
|
|
137
213
|
|
|
138
214
|
@strawberry.field
|
|
139
215
|
async def last_used_at(self, info: Info[Context, None]) -> Optional[datetime]:
|
|
140
|
-
|
|
141
|
-
return await info.context.data_loaders.last_used_times_by_generative_model_id.load(model_id)
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
def to_gql_generative_model(
|
|
145
|
-
model: models.GenerativeModel,
|
|
146
|
-
) -> GenerativeModel:
|
|
147
|
-
costs_are_loaded = isinstance(inspect(model).attrs.token_prices.loaded_value, list)
|
|
148
|
-
name_pattern = model.name_pattern.pattern
|
|
149
|
-
assert isinstance(name_pattern, str)
|
|
150
|
-
return GenerativeModel(
|
|
151
|
-
id_attr=model.id,
|
|
152
|
-
name=model.name,
|
|
153
|
-
provider=model.provider or None,
|
|
154
|
-
name_pattern=name_pattern,
|
|
155
|
-
kind=GenerativeModelKind.BUILT_IN if model.is_built_in else GenerativeModelKind.CUSTOM,
|
|
156
|
-
created_at=model.created_at,
|
|
157
|
-
updated_at=model.updated_at,
|
|
158
|
-
start_time=model.start_time,
|
|
159
|
-
provider_key=_semconv_provider_to_gql_generative_provider_key(model.provider)
|
|
160
|
-
if model.provider
|
|
161
|
-
else None,
|
|
162
|
-
costs=model.token_prices if costs_are_loaded else None,
|
|
163
|
-
)
|
|
216
|
+
return await info.context.data_loaders.last_used_times_by_generative_model_id.load(self.id)
|
|
164
217
|
|
|
165
218
|
|
|
166
219
|
def _semconv_provider_to_gql_generative_provider_key(
|
|
@@ -7,5 +7,10 @@ from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
|
|
|
7
7
|
|
|
8
8
|
@strawberry.interface
|
|
9
9
|
class ModelInterface:
|
|
10
|
-
|
|
11
|
-
|
|
10
|
+
@strawberry.field
|
|
11
|
+
async def name(self) -> str:
|
|
12
|
+
raise NotImplementedError
|
|
13
|
+
|
|
14
|
+
@strawberry.field
|
|
15
|
+
async def provider_key(self) -> Optional[GenerativeProviderKey]:
|
|
16
|
+
raise NotImplementedError
|