arize-phoenix 12.8.0__py3-none-any.whl → 12.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/METADATA +3 -1
- {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/RECORD +70 -67
- phoenix/config.py +131 -9
- phoenix/db/engines.py +127 -14
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/pg_config.py +10 -0
- phoenix/server/api/context.py +23 -0
- phoenix/server/api/dataloaders/__init__.py +6 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +0 -2
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/span_costs.py +3 -9
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/helpers/playground_clients.py +3 -3
- phoenix/server/api/input_types/PromptVersionInput.py +47 -1
- phoenix/server/api/mutations/annotation_config_mutations.py +2 -2
- phoenix/server/api/mutations/api_key_mutations.py +2 -15
- phoenix/server/api/mutations/chat_mutations.py +3 -2
- phoenix/server/api/mutations/dataset_label_mutations.py +12 -6
- phoenix/server/api/mutations/dataset_mutations.py +8 -8
- phoenix/server/api/mutations/dataset_split_mutations.py +13 -9
- phoenix/server/api/mutations/model_mutations.py +4 -4
- phoenix/server/api/mutations/project_session_annotations_mutations.py +4 -7
- phoenix/server/api/mutations/prompt_label_mutations.py +3 -3
- phoenix/server/api/mutations/prompt_mutations.py +24 -117
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +8 -5
- phoenix/server/api/mutations/span_annotations_mutations.py +10 -5
- phoenix/server/api/mutations/trace_annotations_mutations.py +9 -4
- phoenix/server/api/mutations/user_mutations.py +4 -4
- phoenix/server/api/queries.py +65 -210
- phoenix/server/api/subscriptions.py +4 -4
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/Dataset.py +88 -48
- phoenix/server/api/types/DatasetExample.py +34 -30
- phoenix/server/api/types/DatasetLabel.py +47 -13
- phoenix/server/api/types/DatasetSplit.py +87 -21
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +182 -62
- phoenix/server/api/types/Experiment.py +146 -55
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +10 -1
- phoenix/server/api/types/ExperimentRun.py +118 -61
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +95 -42
- phoenix/server/api/types/ModelInterface.py +7 -2
- phoenix/server/api/types/PlaygroundModel.py +12 -2
- phoenix/server/api/types/Project.py +70 -75
- phoenix/server/api/types/ProjectSession.py +69 -37
- phoenix/server/api/types/ProjectSessionAnnotation.py +166 -47
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +82 -44
- phoenix/server/api/types/PromptLabel.py +47 -13
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/Span.py +116 -115
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/Trace.py +45 -44
- phoenix/server/api/types/TraceAnnotation.py +144 -48
- phoenix/server/api/types/User.py +103 -33
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/app.py +29 -0
- phoenix/server/static/.vite/manifest.json +9 -9
- phoenix/server/static/assets/{components-Bem6_7MW.js → components-v927s3NF.js} +427 -397
- phoenix/server/static/assets/{index-NdiXbuNL.js → index-DrD9eSrN.js} +9 -5
- phoenix/server/static/assets/{pages-CEJgMVKU.js → pages-GVybXa_W.js} +489 -486
- phoenix/version.py +1 -1
- {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/queries.py
CHANGED
|
@@ -49,10 +49,10 @@ from phoenix.server.api.input_types.ProjectSort import ProjectColumn, ProjectSor
|
|
|
49
49
|
from phoenix.server.api.input_types.PromptFilter import PromptFilter
|
|
50
50
|
from phoenix.server.api.types.AnnotationConfig import AnnotationConfig, to_gql_annotation_config
|
|
51
51
|
from phoenix.server.api.types.Cluster import Cluster, to_gql_clusters
|
|
52
|
-
from phoenix.server.api.types.Dataset import Dataset
|
|
52
|
+
from phoenix.server.api.types.Dataset import Dataset
|
|
53
53
|
from phoenix.server.api.types.DatasetExample import DatasetExample
|
|
54
|
-
from phoenix.server.api.types.DatasetLabel import DatasetLabel
|
|
55
|
-
from phoenix.server.api.types.DatasetSplit import DatasetSplit
|
|
54
|
+
from phoenix.server.api.types.DatasetLabel import DatasetLabel
|
|
55
|
+
from phoenix.server.api.types.DatasetSplit import DatasetSplit
|
|
56
56
|
from phoenix.server.api.types.Dimension import to_gql_dimension
|
|
57
57
|
from phoenix.server.api.types.EmbeddingDimension import (
|
|
58
58
|
DEFAULT_CLUSTER_SELECTION_EPSILON,
|
|
@@ -61,7 +61,7 @@ from phoenix.server.api.types.EmbeddingDimension import (
|
|
|
61
61
|
to_gql_embedding_dimension,
|
|
62
62
|
)
|
|
63
63
|
from phoenix.server.api.types.Event import create_event_id, unpack_event_id
|
|
64
|
-
from phoenix.server.api.types.Experiment import Experiment
|
|
64
|
+
from phoenix.server.api.types.Experiment import Experiment
|
|
65
65
|
from phoenix.server.api.types.ExperimentComparison import (
|
|
66
66
|
ExperimentComparison,
|
|
67
67
|
)
|
|
@@ -69,9 +69,9 @@ from phoenix.server.api.types.ExperimentRepeatedRunGroup import (
|
|
|
69
69
|
ExperimentRepeatedRunGroup,
|
|
70
70
|
parse_experiment_repeated_run_group_node_id,
|
|
71
71
|
)
|
|
72
|
-
from phoenix.server.api.types.ExperimentRun import ExperimentRun
|
|
72
|
+
from phoenix.server.api.types.ExperimentRun import ExperimentRun
|
|
73
73
|
from phoenix.server.api.types.Functionality import Functionality
|
|
74
|
-
from phoenix.server.api.types.GenerativeModel import GenerativeModel
|
|
74
|
+
from phoenix.server.api.types.GenerativeModel import GenerativeModel
|
|
75
75
|
from phoenix.server.api.types.GenerativeProvider import GenerativeProvider, GenerativeProviderKey
|
|
76
76
|
from phoenix.server.api.types.InferenceModel import InferenceModel
|
|
77
77
|
from phoenix.server.api.types.InferencesRole import AncillaryInferencesRole, InferencesRole
|
|
@@ -89,21 +89,21 @@ from phoenix.server.api.types.pagination import (
|
|
|
89
89
|
)
|
|
90
90
|
from phoenix.server.api.types.PlaygroundModel import PlaygroundModel
|
|
91
91
|
from phoenix.server.api.types.Project import Project
|
|
92
|
-
from phoenix.server.api.types.ProjectSession import ProjectSession
|
|
92
|
+
from phoenix.server.api.types.ProjectSession import ProjectSession
|
|
93
93
|
from phoenix.server.api.types.ProjectTraceRetentionPolicy import ProjectTraceRetentionPolicy
|
|
94
|
-
from phoenix.server.api.types.Prompt import Prompt
|
|
95
|
-
from phoenix.server.api.types.PromptLabel import PromptLabel
|
|
94
|
+
from phoenix.server.api.types.Prompt import Prompt
|
|
95
|
+
from phoenix.server.api.types.PromptLabel import PromptLabel
|
|
96
96
|
from phoenix.server.api.types.PromptVersion import PromptVersion, to_gql_prompt_version
|
|
97
|
-
from phoenix.server.api.types.PromptVersionTag import PromptVersionTag
|
|
97
|
+
from phoenix.server.api.types.PromptVersionTag import PromptVersionTag
|
|
98
98
|
from phoenix.server.api.types.ServerStatus import ServerStatus
|
|
99
99
|
from phoenix.server.api.types.SortDir import SortDir
|
|
100
100
|
from phoenix.server.api.types.Span import Span
|
|
101
|
-
from phoenix.server.api.types.SpanAnnotation import SpanAnnotation
|
|
101
|
+
from phoenix.server.api.types.SpanAnnotation import SpanAnnotation
|
|
102
102
|
from phoenix.server.api.types.SystemApiKey import SystemApiKey
|
|
103
103
|
from phoenix.server.api.types.Trace import Trace
|
|
104
|
-
from phoenix.server.api.types.TraceAnnotation import TraceAnnotation
|
|
105
|
-
from phoenix.server.api.types.User import User
|
|
106
|
-
from phoenix.server.api.types.UserApiKey import UserApiKey
|
|
104
|
+
from phoenix.server.api.types.TraceAnnotation import TraceAnnotation
|
|
105
|
+
from phoenix.server.api.types.User import User
|
|
106
|
+
from phoenix.server.api.types.UserApiKey import UserApiKey
|
|
107
107
|
from phoenix.server.api.types.UserRole import UserRole
|
|
108
108
|
from phoenix.server.api.types.ValidationResult import ValidationResult
|
|
109
109
|
|
|
@@ -208,9 +208,8 @@ class Query:
|
|
|
208
208
|
models.GenerativeModel.provider.nullslast(),
|
|
209
209
|
models.GenerativeModel.name,
|
|
210
210
|
)
|
|
211
|
-
.options(joinedload(models.GenerativeModel.token_prices))
|
|
212
211
|
)
|
|
213
|
-
data = [
|
|
212
|
+
data = [GenerativeModel(id=model.id, db_record=model) for model in result.unique()]
|
|
214
213
|
return connection_from_list(data=data, args=args)
|
|
215
214
|
|
|
216
215
|
@strawberry.field
|
|
@@ -218,7 +217,7 @@ class Query:
|
|
|
218
217
|
if input is not None and input.provider_key is not None:
|
|
219
218
|
supported_model_names = PLAYGROUND_CLIENT_REGISTRY.list_models(input.provider_key)
|
|
220
219
|
supported_models = [
|
|
221
|
-
PlaygroundModel(
|
|
220
|
+
PlaygroundModel(name_value=model_name, provider_key_value=input.provider_key)
|
|
222
221
|
for model_name in supported_model_names
|
|
223
222
|
]
|
|
224
223
|
return supported_models
|
|
@@ -227,7 +226,9 @@ class Query:
|
|
|
227
226
|
all_models: list[PlaygroundModel] = []
|
|
228
227
|
for provider_key, model_name in registered_models:
|
|
229
228
|
if model_name is not None and provider_key is not None:
|
|
230
|
-
all_models.append(
|
|
229
|
+
all_models.append(
|
|
230
|
+
PlaygroundModel(name_value=model_name, provider_key_value=provider_key)
|
|
231
|
+
)
|
|
231
232
|
return all_models
|
|
232
233
|
|
|
233
234
|
@strawberry.field
|
|
@@ -271,7 +272,7 @@ class Query:
|
|
|
271
272
|
)
|
|
272
273
|
async with info.context.db() as session:
|
|
273
274
|
users = await session.stream_scalars(stmt)
|
|
274
|
-
data = [
|
|
275
|
+
data = [User(id=user.id, db_record=user) async for user in users]
|
|
275
276
|
return connection_from_list(data=data, args=args)
|
|
276
277
|
|
|
277
278
|
@strawberry.field
|
|
@@ -301,7 +302,7 @@ class Query:
|
|
|
301
302
|
)
|
|
302
303
|
async with info.context.db() as session:
|
|
303
304
|
api_keys = await session.scalars(stmt)
|
|
304
|
-
return [
|
|
305
|
+
return [UserApiKey(id=api_key.id, db_record=api_key) for api_key in api_keys]
|
|
305
306
|
|
|
306
307
|
@strawberry.field(permission_classes=[IsAdmin]) # type: ignore
|
|
307
308
|
async def system_api_keys(self, info: Info[Context, None]) -> list[SystemApiKey]:
|
|
@@ -313,16 +314,7 @@ class Query:
|
|
|
313
314
|
)
|
|
314
315
|
async with info.context.db() as session:
|
|
315
316
|
api_keys = await session.scalars(stmt)
|
|
316
|
-
return [
|
|
317
|
-
SystemApiKey(
|
|
318
|
-
id_attr=api_key.id,
|
|
319
|
-
name=api_key.name,
|
|
320
|
-
description=api_key.description,
|
|
321
|
-
created_at=api_key.created_at,
|
|
322
|
-
expires_at=api_key.expires_at,
|
|
323
|
-
)
|
|
324
|
-
for api_key in api_keys
|
|
325
|
-
]
|
|
317
|
+
return [SystemApiKey(id=api_key.id, db_record=api_key) for api_key in api_keys]
|
|
326
318
|
|
|
327
319
|
@strawberry.field
|
|
328
320
|
async def projects(
|
|
@@ -363,13 +355,7 @@ class Query:
|
|
|
363
355
|
stmt = exclude_experiment_projects(stmt)
|
|
364
356
|
async with info.context.db() as session:
|
|
365
357
|
projects = await session.stream_scalars(stmt)
|
|
366
|
-
data = [
|
|
367
|
-
Project(
|
|
368
|
-
project_rowid=project.id,
|
|
369
|
-
db_project=project,
|
|
370
|
-
)
|
|
371
|
-
async for project in projects
|
|
372
|
-
]
|
|
358
|
+
data = [Project(id=project.id, db_record=project) async for project in projects]
|
|
373
359
|
return connection_from_list(data=data, args=args)
|
|
374
360
|
|
|
375
361
|
@strawberry.field
|
|
@@ -430,7 +416,7 @@ class Query:
|
|
|
430
416
|
async with info.context.db() as session:
|
|
431
417
|
datasets = await session.scalars(stmt)
|
|
432
418
|
return connection_from_list(
|
|
433
|
-
data=[
|
|
419
|
+
data=[Dataset(id=dataset.id, db_record=dataset) for dataset in datasets], args=args
|
|
434
420
|
)
|
|
435
421
|
|
|
436
422
|
@strawberry.field
|
|
@@ -555,10 +541,11 @@ class Query:
|
|
|
555
541
|
ExperimentRepeatedRunGroup(
|
|
556
542
|
experiment_rowid=experiment_id,
|
|
557
543
|
dataset_example_rowid=example.id,
|
|
558
|
-
|
|
559
|
-
|
|
544
|
+
cached_runs=[
|
|
545
|
+
ExperimentRun(id=run.id, db_record=run)
|
|
560
546
|
for run in sorted(
|
|
561
|
-
runs[example.id][experiment_id],
|
|
547
|
+
runs[example.id][experiment_id],
|
|
548
|
+
key=lambda run: run.repetition_number,
|
|
562
549
|
)
|
|
563
550
|
],
|
|
564
551
|
)
|
|
@@ -566,8 +553,8 @@ class Query:
|
|
|
566
553
|
experiment_comparison = ExperimentComparison(
|
|
567
554
|
id_attr=example.id,
|
|
568
555
|
example=DatasetExample(
|
|
569
|
-
|
|
570
|
-
|
|
556
|
+
id=example.id,
|
|
557
|
+
db_record=example,
|
|
571
558
|
version_id=base_experiment.dataset_version_id,
|
|
572
559
|
),
|
|
573
560
|
repeated_run_groups=repeated_run_groups,
|
|
@@ -908,25 +895,9 @@ class Query:
|
|
|
908
895
|
)
|
|
909
896
|
except Exception:
|
|
910
897
|
raise NotFound(f"Unknown node: {id}")
|
|
911
|
-
|
|
912
|
-
async with info.context.db() as session:
|
|
913
|
-
runs = (
|
|
914
|
-
await session.scalars(
|
|
915
|
-
select(models.ExperimentRun)
|
|
916
|
-
.where(models.ExperimentRun.experiment_id == experiment_rowid)
|
|
917
|
-
.where(models.ExperimentRun.dataset_example_id == dataset_example_rowid)
|
|
918
|
-
.order_by(models.ExperimentRun.repetition_number.asc())
|
|
919
|
-
.options(
|
|
920
|
-
joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id)
|
|
921
|
-
)
|
|
922
|
-
)
|
|
923
|
-
).all()
|
|
924
|
-
if not runs:
|
|
925
|
-
raise NotFound(f"Unknown experiment or dataset example: {id}")
|
|
926
898
|
return ExperimentRepeatedRunGroup(
|
|
927
899
|
experiment_rowid=experiment_rowid,
|
|
928
900
|
dataset_example_rowid=dataset_example_rowid,
|
|
929
|
-
runs=[to_gql_experiment_run(run) for run in runs],
|
|
930
901
|
)
|
|
931
902
|
|
|
932
903
|
global_id = GlobalID.from_id(id)
|
|
@@ -937,111 +908,30 @@ class Query:
|
|
|
937
908
|
elif type_name == "EmbeddingDimension":
|
|
938
909
|
embedding_dimension = info.context.model.embedding_dimensions[node_id]
|
|
939
910
|
return to_gql_embedding_dimension(node_id, embedding_dimension)
|
|
940
|
-
elif type_name ==
|
|
941
|
-
|
|
942
|
-
|
|
943
|
-
|
|
944
|
-
if project is None:
|
|
945
|
-
raise NotFound(f"Unknown project: {id}")
|
|
946
|
-
return Project(
|
|
947
|
-
project_rowid=project.id,
|
|
948
|
-
db_project=project,
|
|
949
|
-
)
|
|
950
|
-
elif type_name == "Trace":
|
|
951
|
-
trace_stmt = select(models.Trace).filter_by(id=node_id)
|
|
952
|
-
async with info.context.db() as session:
|
|
953
|
-
trace = await session.scalar(trace_stmt)
|
|
954
|
-
if trace is None:
|
|
955
|
-
raise NotFound(f"Unknown trace: {id}")
|
|
956
|
-
return Trace(trace_rowid=trace.id, db_trace=trace)
|
|
911
|
+
elif type_name == Project.__name__:
|
|
912
|
+
return Project(id=node_id)
|
|
913
|
+
elif type_name == Trace.__name__:
|
|
914
|
+
return Trace(id=node_id)
|
|
957
915
|
elif type_name == Span.__name__:
|
|
958
|
-
|
|
959
|
-
select(models.Span)
|
|
960
|
-
.options(
|
|
961
|
-
joinedload(models.Span.trace, innerjoin=True).load_only(models.Trace.trace_id)
|
|
962
|
-
)
|
|
963
|
-
.where(models.Span.id == node_id)
|
|
964
|
-
)
|
|
965
|
-
async with info.context.db() as session:
|
|
966
|
-
span = await session.scalar(span_stmt)
|
|
967
|
-
if span is None:
|
|
968
|
-
raise NotFound(f"Unknown span: {id}")
|
|
969
|
-
return Span(span_rowid=span.id, db_span=span)
|
|
916
|
+
return Span(id=node_id)
|
|
970
917
|
elif type_name == Dataset.__name__:
|
|
971
|
-
|
|
972
|
-
async with info.context.db() as session:
|
|
973
|
-
if (dataset := await session.scalar(dataset_stmt)) is None:
|
|
974
|
-
raise NotFound(f"Unknown dataset: {id}")
|
|
975
|
-
return to_gql_dataset(dataset)
|
|
918
|
+
return Dataset(id=node_id)
|
|
976
919
|
elif type_name == DatasetExample.__name__:
|
|
977
|
-
|
|
978
|
-
async with info.context.db() as session:
|
|
979
|
-
example = await session.scalar(
|
|
980
|
-
select(models.DatasetExample).where(models.DatasetExample.id == example_id)
|
|
981
|
-
)
|
|
982
|
-
if not example:
|
|
983
|
-
raise NotFound(f"Unknown dataset example: {id}")
|
|
984
|
-
return DatasetExample(
|
|
985
|
-
id_attr=example.id,
|
|
986
|
-
created_at=example.created_at,
|
|
987
|
-
)
|
|
920
|
+
return DatasetExample(id=node_id)
|
|
988
921
|
elif type_name == DatasetSplit.__name__:
|
|
989
|
-
|
|
990
|
-
dataset_split = await session.scalar(
|
|
991
|
-
select(models.DatasetSplit).where(models.DatasetSplit.id == node_id)
|
|
992
|
-
)
|
|
993
|
-
if not dataset_split:
|
|
994
|
-
raise NotFound(f"Unknown dataset split: {id}")
|
|
995
|
-
return to_gql_dataset_split(dataset_split)
|
|
922
|
+
return DatasetSplit(id=node_id)
|
|
996
923
|
elif type_name == Experiment.__name__:
|
|
997
|
-
|
|
998
|
-
experiment = await session.scalar(
|
|
999
|
-
select(models.Experiment).where(models.Experiment.id == node_id)
|
|
1000
|
-
)
|
|
1001
|
-
if not experiment:
|
|
1002
|
-
raise NotFound(f"Unknown experiment: {id}")
|
|
1003
|
-
return to_gql_experiment(experiment)
|
|
924
|
+
return Experiment(id=node_id)
|
|
1004
925
|
elif type_name == ExperimentRun.__name__:
|
|
1005
|
-
|
|
1006
|
-
if not (
|
|
1007
|
-
run := await session.scalar(
|
|
1008
|
-
select(models.ExperimentRun)
|
|
1009
|
-
.where(models.ExperimentRun.id == node_id)
|
|
1010
|
-
.options(
|
|
1011
|
-
joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id)
|
|
1012
|
-
)
|
|
1013
|
-
)
|
|
1014
|
-
):
|
|
1015
|
-
raise NotFound(f"Unknown experiment run: {id}")
|
|
1016
|
-
return to_gql_experiment_run(run)
|
|
926
|
+
return ExperimentRun(id=node_id)
|
|
1017
927
|
elif type_name == User.__name__:
|
|
1018
928
|
if int((user := info.context.user).identity) != node_id and not user.is_admin:
|
|
1019
929
|
raise Unauthorized(MSG_ADMIN_ONLY)
|
|
1020
|
-
|
|
1021
|
-
if not (
|
|
1022
|
-
user := await session.scalar(
|
|
1023
|
-
select(models.User).where(models.User.id == node_id)
|
|
1024
|
-
)
|
|
1025
|
-
):
|
|
1026
|
-
raise NotFound(f"Unknown user: {id}")
|
|
1027
|
-
return to_gql_user(user)
|
|
930
|
+
return User(id=node_id)
|
|
1028
931
|
elif type_name == ProjectSession.__name__:
|
|
1029
|
-
|
|
1030
|
-
if not (
|
|
1031
|
-
project_session := await session.scalar(
|
|
1032
|
-
select(models.ProjectSession).filter_by(id=node_id)
|
|
1033
|
-
)
|
|
1034
|
-
):
|
|
1035
|
-
raise NotFound(f"Unknown user: {id}")
|
|
1036
|
-
return to_gql_project_session(project_session)
|
|
932
|
+
return ProjectSession(id=node_id)
|
|
1037
933
|
elif type_name == Prompt.__name__:
|
|
1038
|
-
|
|
1039
|
-
if orm_prompt := await session.scalar(
|
|
1040
|
-
select(models.Prompt).where(models.Prompt.id == node_id)
|
|
1041
|
-
):
|
|
1042
|
-
return to_gql_prompt_from_orm(orm_prompt)
|
|
1043
|
-
else:
|
|
1044
|
-
raise NotFound(f"Unknown prompt: {id}")
|
|
934
|
+
return Prompt(id=node_id)
|
|
1045
935
|
elif type_name == PromptVersion.__name__:
|
|
1046
936
|
async with info.context.db() as session:
|
|
1047
937
|
if orm_prompt_version := await session.scalar(
|
|
@@ -1051,51 +941,17 @@ class Query:
|
|
|
1051
941
|
else:
|
|
1052
942
|
raise NotFound(f"Unknown prompt version: {id}")
|
|
1053
943
|
elif type_name == PromptLabel.__name__:
|
|
1054
|
-
|
|
1055
|
-
if not (
|
|
1056
|
-
prompt_label := await session.scalar(
|
|
1057
|
-
select(models.PromptLabel).where(models.PromptLabel.id == node_id)
|
|
1058
|
-
)
|
|
1059
|
-
):
|
|
1060
|
-
raise NotFound(f"Unknown prompt label: {id}")
|
|
1061
|
-
return to_gql_prompt_label(prompt_label)
|
|
944
|
+
return PromptLabel(id=node_id)
|
|
1062
945
|
elif type_name == PromptVersionTag.__name__:
|
|
1063
|
-
|
|
1064
|
-
if not (prompt_version_tag := await session.get(models.PromptVersionTag, node_id)):
|
|
1065
|
-
raise NotFound(f"Unknown prompt version tag: {id}")
|
|
1066
|
-
return to_gql_prompt_version_tag(prompt_version_tag)
|
|
946
|
+
return PromptVersionTag(id=node_id)
|
|
1067
947
|
elif type_name == ProjectTraceRetentionPolicy.__name__:
|
|
1068
|
-
|
|
1069
|
-
db_policy = await session.scalar(
|
|
1070
|
-
select(models.ProjectTraceRetentionPolicy).filter_by(id=node_id)
|
|
1071
|
-
)
|
|
1072
|
-
if not db_policy:
|
|
1073
|
-
raise NotFound(f"Unknown project trace retention policy: {id}")
|
|
1074
|
-
return ProjectTraceRetentionPolicy(id=db_policy.id, db_policy=db_policy)
|
|
948
|
+
return ProjectTraceRetentionPolicy(id=node_id)
|
|
1075
949
|
elif type_name == SpanAnnotation.__name__:
|
|
1076
|
-
|
|
1077
|
-
span_annotation = await session.get(models.SpanAnnotation, node_id)
|
|
1078
|
-
if not span_annotation:
|
|
1079
|
-
raise NotFound(f"Unknown span annotation: {id}")
|
|
1080
|
-
return to_gql_span_annotation(span_annotation)
|
|
950
|
+
return SpanAnnotation(id=node_id)
|
|
1081
951
|
elif type_name == TraceAnnotation.__name__:
|
|
1082
|
-
|
|
1083
|
-
trace_annotation = await session.get(models.TraceAnnotation, node_id)
|
|
1084
|
-
if not trace_annotation:
|
|
1085
|
-
raise NotFound(f"Unknown trace annotation: {id}")
|
|
1086
|
-
return to_gql_trace_annotation(trace_annotation)
|
|
952
|
+
return TraceAnnotation(id=node_id)
|
|
1087
953
|
elif type_name == GenerativeModel.__name__:
|
|
1088
|
-
|
|
1089
|
-
stmt = (
|
|
1090
|
-
select(models.GenerativeModel)
|
|
1091
|
-
.where(models.GenerativeModel.deleted_at.is_(None))
|
|
1092
|
-
.where(models.GenerativeModel.id == node_id)
|
|
1093
|
-
.options(joinedload(models.GenerativeModel.token_prices))
|
|
1094
|
-
)
|
|
1095
|
-
model = await session.scalar(stmt)
|
|
1096
|
-
if not model:
|
|
1097
|
-
raise NotFound(f"Unknown model: {id}")
|
|
1098
|
-
return to_gql_generative_model(model)
|
|
954
|
+
return GenerativeModel(id=node_id)
|
|
1099
955
|
raise NotFound(f"Unknown node type: {type_name}")
|
|
1100
956
|
|
|
1101
957
|
@strawberry.field
|
|
@@ -1107,16 +963,7 @@ class Query:
|
|
|
1107
963
|
return None
|
|
1108
964
|
if isinstance(user, UnauthenticatedUser):
|
|
1109
965
|
return None
|
|
1110
|
-
|
|
1111
|
-
if (
|
|
1112
|
-
user := await session.scalar(
|
|
1113
|
-
select(models.User)
|
|
1114
|
-
.where(models.User.id == int(user.identity))
|
|
1115
|
-
.options(joinedload(models.User.role))
|
|
1116
|
-
)
|
|
1117
|
-
) is None:
|
|
1118
|
-
return None
|
|
1119
|
-
return to_gql_user(user)
|
|
966
|
+
return User(id=int(user.identity))
|
|
1120
967
|
|
|
1121
968
|
@strawberry.field
|
|
1122
969
|
async def prompts(
|
|
@@ -1156,7 +1003,9 @@ class Query:
|
|
|
1156
1003
|
stmt = stmt.distinct()
|
|
1157
1004
|
async with info.context.db() as session:
|
|
1158
1005
|
orm_prompts = await session.stream_scalars(stmt)
|
|
1159
|
-
data = [
|
|
1006
|
+
data = [
|
|
1007
|
+
Prompt(id=orm_prompt.id, db_record=orm_prompt) async for orm_prompt in orm_prompts
|
|
1008
|
+
]
|
|
1160
1009
|
return connection_from_list(
|
|
1161
1010
|
data=data,
|
|
1162
1011
|
args=args,
|
|
@@ -1179,7 +1028,10 @@ class Query:
|
|
|
1179
1028
|
)
|
|
1180
1029
|
async with info.context.db() as session:
|
|
1181
1030
|
prompt_labels = await session.stream_scalars(select(models.PromptLabel))
|
|
1182
|
-
data = [
|
|
1031
|
+
data = [
|
|
1032
|
+
PromptLabel(id=prompt_label.id, db_record=prompt_label)
|
|
1033
|
+
async for prompt_label in prompt_labels
|
|
1034
|
+
]
|
|
1183
1035
|
return connection_from_list(
|
|
1184
1036
|
data=data,
|
|
1185
1037
|
args=args,
|
|
@@ -1204,7 +1056,10 @@ class Query:
|
|
|
1204
1056
|
dataset_labels = await session.scalars(
|
|
1205
1057
|
select(models.DatasetLabel).order_by(models.DatasetLabel.name.asc())
|
|
1206
1058
|
)
|
|
1207
|
-
data = [
|
|
1059
|
+
data = [
|
|
1060
|
+
DatasetLabel(id=dataset_label.id, db_record=dataset_label)
|
|
1061
|
+
for dataset_label in dataset_labels
|
|
1062
|
+
]
|
|
1208
1063
|
return connection_from_list(data=data, args=args)
|
|
1209
1064
|
|
|
1210
1065
|
@strawberry.field
|
|
@@ -1224,7 +1079,7 @@ class Query:
|
|
|
1224
1079
|
)
|
|
1225
1080
|
async with info.context.db() as session:
|
|
1226
1081
|
splits = await session.stream_scalars(select(models.DatasetSplit))
|
|
1227
|
-
data = [
|
|
1082
|
+
data = [DatasetSplit(id=split.id, db_record=split) async for split in splits]
|
|
1228
1083
|
return connection_from_list(
|
|
1229
1084
|
data=data,
|
|
1230
1085
|
args=args,
|
|
@@ -1495,7 +1350,7 @@ class Query:
|
|
|
1495
1350
|
async with info.context.db() as session:
|
|
1496
1351
|
span_rowid = await session.scalar(stmt)
|
|
1497
1352
|
if span_rowid:
|
|
1498
|
-
return Span(
|
|
1353
|
+
return Span(id=span_rowid)
|
|
1499
1354
|
return None
|
|
1500
1355
|
|
|
1501
1356
|
@strawberry.field
|
|
@@ -1508,7 +1363,7 @@ class Query:
|
|
|
1508
1363
|
async with info.context.db() as session:
|
|
1509
1364
|
trace_rowid = await session.scalar(stmt)
|
|
1510
1365
|
if trace_rowid:
|
|
1511
|
-
return Trace(
|
|
1366
|
+
return Trace(id=trace_rowid)
|
|
1512
1367
|
return None
|
|
1513
1368
|
|
|
1514
1369
|
@strawberry.field
|
|
@@ -1521,7 +1376,7 @@ class Query:
|
|
|
1521
1376
|
async with info.context.db() as session:
|
|
1522
1377
|
session_row = await session.scalar(stmt)
|
|
1523
1378
|
if session_row:
|
|
1524
|
-
return
|
|
1379
|
+
return ProjectSession(id=session_row.id, db_record=session_row)
|
|
1525
1380
|
return None
|
|
1526
1381
|
|
|
1527
1382
|
|
|
@@ -64,7 +64,7 @@ from phoenix.server.api.types.Dataset import Dataset
|
|
|
64
64
|
from phoenix.server.api.types.DatasetExample import DatasetExample
|
|
65
65
|
from phoenix.server.api.types.DatasetVersion import DatasetVersion
|
|
66
66
|
from phoenix.server.api.types.Experiment import to_gql_experiment
|
|
67
|
-
from phoenix.server.api.types.ExperimentRun import
|
|
67
|
+
from phoenix.server.api.types.ExperimentRun import ExperimentRun
|
|
68
68
|
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
69
69
|
from phoenix.server.api.types.Span import Span
|
|
70
70
|
from phoenix.server.daemons.span_cost_calculator import SpanCostCalculator
|
|
@@ -194,7 +194,7 @@ class Subscription:
|
|
|
194
194
|
session.add(span_cost)
|
|
195
195
|
|
|
196
196
|
info.context.event_queue.put(SpanInsertEvent(ids=(playground_project_id,)))
|
|
197
|
-
yield ChatCompletionSubscriptionResult(span=Span(
|
|
197
|
+
yield ChatCompletionSubscriptionResult(span=Span(id=db_span.id, db_record=db_span))
|
|
198
198
|
|
|
199
199
|
@strawberry.subscription(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
|
|
200
200
|
async def chat_completion_over_dataset(
|
|
@@ -528,8 +528,8 @@ async def _chat_completion_result_payloads(
|
|
|
528
528
|
await session.flush()
|
|
529
529
|
for example_id, span, run in results:
|
|
530
530
|
yield ChatCompletionSubscriptionResult(
|
|
531
|
-
span=Span(
|
|
532
|
-
experiment_run=
|
|
531
|
+
span=Span(id=span.id, db_record=span) if span else None,
|
|
532
|
+
experiment_run=ExperimentRun(id=run.id, db_record=run),
|
|
533
533
|
dataset_example_id=example_id,
|
|
534
534
|
repetition_number=run.repetition_number,
|
|
535
535
|
)
|
|
@@ -1,31 +1,98 @@
|
|
|
1
1
|
from datetime import datetime
|
|
2
|
-
from typing import Optional
|
|
2
|
+
from typing import TYPE_CHECKING, Annotated, Optional
|
|
3
3
|
|
|
4
4
|
import strawberry
|
|
5
|
+
from strawberry.scalars import JSON
|
|
6
|
+
from strawberry.types import Info
|
|
5
7
|
|
|
6
|
-
from phoenix.server.api.
|
|
8
|
+
from phoenix.server.api.context import Context
|
|
9
|
+
|
|
10
|
+
from .AnnotationSource import AnnotationSource
|
|
11
|
+
from .AnnotatorKind import AnnotatorKind
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from .User import User
|
|
7
15
|
|
|
8
16
|
|
|
9
17
|
@strawberry.interface
|
|
10
18
|
class Annotation:
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
)
|
|
19
|
+
@strawberry.field(description="Name of the annotation, e.g. 'helpfulness' or 'relevance'.") # type: ignore
|
|
20
|
+
async def name(
|
|
21
|
+
self,
|
|
22
|
+
info: Info[Context, None],
|
|
23
|
+
) -> str:
|
|
24
|
+
raise NotImplementedError
|
|
25
|
+
|
|
26
|
+
@strawberry.field(description="The kind of annotator that produced the annotation.") # type: ignore
|
|
27
|
+
async def annotator_kind(
|
|
28
|
+
self,
|
|
29
|
+
info: Info[Context, None],
|
|
30
|
+
) -> AnnotatorKind:
|
|
31
|
+
raise NotImplementedError
|
|
32
|
+
|
|
33
|
+
@strawberry.field(
|
|
34
|
+
description="Value of the annotation in the form of a string, e.g. 'helpful' or 'not helpful'. Note that the label is not necessarily binary." # noqa: E501
|
|
35
|
+
) # type: ignore
|
|
36
|
+
async def label(
|
|
37
|
+
self,
|
|
38
|
+
info: Info[Context, None],
|
|
39
|
+
) -> Optional[str]:
|
|
40
|
+
raise NotImplementedError
|
|
41
|
+
|
|
42
|
+
@strawberry.field(description="Value of the annotation in the form of a numeric score.") # type: ignore
|
|
43
|
+
async def score(
|
|
44
|
+
self,
|
|
45
|
+
info: Info[Context, None],
|
|
46
|
+
) -> Optional[float]:
|
|
47
|
+
raise NotImplementedError
|
|
48
|
+
|
|
49
|
+
@strawberry.field(
|
|
50
|
+
description="The annotator's explanation for the annotation result (i.e. score or label, or both) given to the subject." # noqa: E501
|
|
51
|
+
) # type: ignore
|
|
52
|
+
async def explanation(
|
|
53
|
+
self,
|
|
54
|
+
info: Info[Context, None],
|
|
55
|
+
) -> Optional[str]:
|
|
56
|
+
raise NotImplementedError
|
|
57
|
+
|
|
58
|
+
@strawberry.field(description="Metadata about the annotation.") # type: ignore
|
|
59
|
+
async def metadata(
|
|
60
|
+
self,
|
|
61
|
+
info: Info[Context, None],
|
|
62
|
+
) -> JSON:
|
|
63
|
+
raise NotImplementedError
|
|
64
|
+
|
|
65
|
+
@strawberry.field(description="The source of the annotation.") # type: ignore
|
|
66
|
+
async def source(
|
|
67
|
+
self,
|
|
68
|
+
info: Info[Context, None],
|
|
69
|
+
) -> AnnotationSource:
|
|
70
|
+
raise NotImplementedError
|
|
71
|
+
|
|
72
|
+
@strawberry.field(description="The identifier of the annotation.") # type: ignore
|
|
73
|
+
async def identifier(
|
|
74
|
+
self,
|
|
75
|
+
info: Info[Context, None],
|
|
76
|
+
) -> str:
|
|
77
|
+
raise NotImplementedError
|
|
78
|
+
|
|
79
|
+
@strawberry.field(description="The date and time the annotation was created.") # type: ignore
|
|
80
|
+
async def created_at(
|
|
81
|
+
self,
|
|
82
|
+
info: Info[Context, None],
|
|
83
|
+
) -> datetime:
|
|
84
|
+
raise NotImplementedError
|
|
85
|
+
|
|
86
|
+
@strawberry.field(description="The date and time the annotation was last updated.") # type: ignore
|
|
87
|
+
async def updated_at(
|
|
88
|
+
self,
|
|
89
|
+
info: Info[Context, None],
|
|
90
|
+
) -> datetime:
|
|
91
|
+
raise NotImplementedError
|
|
92
|
+
|
|
93
|
+
@strawberry.field(description="The user that produced the annotation.") # type: ignore
|
|
94
|
+
async def user(
|
|
95
|
+
self,
|
|
96
|
+
info: Info[Context, None],
|
|
97
|
+
) -> Optional[Annotated["User", strawberry.lazy(".User")]]:
|
|
98
|
+
raise NotImplementedError
|
|
@@ -3,25 +3,21 @@ from typing import Optional
|
|
|
3
3
|
|
|
4
4
|
import strawberry
|
|
5
5
|
|
|
6
|
-
from phoenix.db.models import ApiKey as ORMApiKey
|
|
7
|
-
|
|
8
6
|
|
|
9
7
|
@strawberry.interface
|
|
10
8
|
class ApiKey:
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
)
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
)
|
|
9
|
+
@strawberry.field(description="Name of the API key.") # type: ignore
|
|
10
|
+
async def name(self) -> str:
|
|
11
|
+
raise NotImplementedError
|
|
12
|
+
|
|
13
|
+
@strawberry.field(description="Description of the API key.") # type: ignore
|
|
14
|
+
async def description(self) -> Optional[str]:
|
|
15
|
+
raise NotImplementedError
|
|
19
16
|
|
|
17
|
+
@strawberry.field(description="The date and time the API key was created.") # type: ignore
|
|
18
|
+
async def created_at(self) -> datetime:
|
|
19
|
+
raise NotImplementedError
|
|
20
20
|
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
description=api_key.description,
|
|
25
|
-
created_at=api_key.created_at,
|
|
26
|
-
expires_at=api_key.expires_at,
|
|
27
|
-
)
|
|
21
|
+
@strawberry.field(description="The date and time the API key will expire.") # type: ignore
|
|
22
|
+
async def expires_at(self) -> Optional[datetime]:
|
|
23
|
+
raise NotImplementedError
|