arize-phoenix: arize_phoenix-12.8.0-py3-none-any.whl → arize_phoenix-12.9.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. See the registry's release page for this version for more details.

Files changed (70):
  1. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/METADATA +3 -1
  2. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/RECORD +70 -67
  3. phoenix/config.py +131 -9
  4. phoenix/db/engines.py +127 -14
  5. phoenix/db/iam_auth.py +64 -0
  6. phoenix/db/pg_config.py +10 -0
  7. phoenix/server/api/context.py +23 -0
  8. phoenix/server/api/dataloaders/__init__.py +6 -0
  9. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +0 -2
  10. phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
  11. phoenix/server/api/dataloaders/span_costs.py +3 -9
  12. phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
  13. phoenix/server/api/helpers/playground_clients.py +3 -3
  14. phoenix/server/api/input_types/PromptVersionInput.py +47 -1
  15. phoenix/server/api/mutations/annotation_config_mutations.py +2 -2
  16. phoenix/server/api/mutations/api_key_mutations.py +2 -15
  17. phoenix/server/api/mutations/chat_mutations.py +3 -2
  18. phoenix/server/api/mutations/dataset_label_mutations.py +12 -6
  19. phoenix/server/api/mutations/dataset_mutations.py +8 -8
  20. phoenix/server/api/mutations/dataset_split_mutations.py +13 -9
  21. phoenix/server/api/mutations/model_mutations.py +4 -4
  22. phoenix/server/api/mutations/project_session_annotations_mutations.py +4 -7
  23. phoenix/server/api/mutations/prompt_label_mutations.py +3 -3
  24. phoenix/server/api/mutations/prompt_mutations.py +24 -117
  25. phoenix/server/api/mutations/prompt_version_tag_mutations.py +8 -5
  26. phoenix/server/api/mutations/span_annotations_mutations.py +10 -5
  27. phoenix/server/api/mutations/trace_annotations_mutations.py +9 -4
  28. phoenix/server/api/mutations/user_mutations.py +4 -4
  29. phoenix/server/api/queries.py +65 -210
  30. phoenix/server/api/subscriptions.py +4 -4
  31. phoenix/server/api/types/Annotation.py +90 -23
  32. phoenix/server/api/types/ApiKey.py +13 -17
  33. phoenix/server/api/types/Dataset.py +88 -48
  34. phoenix/server/api/types/DatasetExample.py +34 -30
  35. phoenix/server/api/types/DatasetLabel.py +47 -13
  36. phoenix/server/api/types/DatasetSplit.py +87 -21
  37. phoenix/server/api/types/DatasetVersion.py +49 -4
  38. phoenix/server/api/types/DocumentAnnotation.py +182 -62
  39. phoenix/server/api/types/Experiment.py +146 -55
  40. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +10 -1
  41. phoenix/server/api/types/ExperimentRun.py +118 -61
  42. phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
  43. phoenix/server/api/types/GenerativeModel.py +95 -42
  44. phoenix/server/api/types/ModelInterface.py +7 -2
  45. phoenix/server/api/types/PlaygroundModel.py +12 -2
  46. phoenix/server/api/types/Project.py +70 -75
  47. phoenix/server/api/types/ProjectSession.py +69 -37
  48. phoenix/server/api/types/ProjectSessionAnnotation.py +166 -47
  49. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
  50. phoenix/server/api/types/Prompt.py +82 -44
  51. phoenix/server/api/types/PromptLabel.py +47 -13
  52. phoenix/server/api/types/PromptVersion.py +11 -8
  53. phoenix/server/api/types/PromptVersionTag.py +65 -25
  54. phoenix/server/api/types/Span.py +116 -115
  55. phoenix/server/api/types/SpanAnnotation.py +189 -42
  56. phoenix/server/api/types/SystemApiKey.py +65 -1
  57. phoenix/server/api/types/Trace.py +45 -44
  58. phoenix/server/api/types/TraceAnnotation.py +144 -48
  59. phoenix/server/api/types/User.py +103 -33
  60. phoenix/server/api/types/UserApiKey.py +73 -26
  61. phoenix/server/app.py +29 -0
  62. phoenix/server/static/.vite/manifest.json +9 -9
  63. phoenix/server/static/assets/{components-Bem6_7MW.js → components-v927s3NF.js} +427 -397
  64. phoenix/server/static/assets/{index-NdiXbuNL.js → index-DrD9eSrN.js} +9 -5
  65. phoenix/server/static/assets/{pages-CEJgMVKU.js → pages-GVybXa_W.js} +489 -486
  66. phoenix/version.py +1 -1
  67. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/WHEEL +0 -0
  68. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/entry_points.txt +0 -0
  69. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/licenses/IP_NOTICE +0 -0
  70. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/licenses/LICENSE +0 -0
@@ -49,10 +49,10 @@ from phoenix.server.api.input_types.ProjectSort import ProjectColumn, ProjectSor
49
49
  from phoenix.server.api.input_types.PromptFilter import PromptFilter
50
50
  from phoenix.server.api.types.AnnotationConfig import AnnotationConfig, to_gql_annotation_config
51
51
  from phoenix.server.api.types.Cluster import Cluster, to_gql_clusters
52
- from phoenix.server.api.types.Dataset import Dataset, to_gql_dataset
52
+ from phoenix.server.api.types.Dataset import Dataset
53
53
  from phoenix.server.api.types.DatasetExample import DatasetExample
54
- from phoenix.server.api.types.DatasetLabel import DatasetLabel, to_gql_dataset_label
55
- from phoenix.server.api.types.DatasetSplit import DatasetSplit, to_gql_dataset_split
54
+ from phoenix.server.api.types.DatasetLabel import DatasetLabel
55
+ from phoenix.server.api.types.DatasetSplit import DatasetSplit
56
56
  from phoenix.server.api.types.Dimension import to_gql_dimension
57
57
  from phoenix.server.api.types.EmbeddingDimension import (
58
58
  DEFAULT_CLUSTER_SELECTION_EPSILON,
@@ -61,7 +61,7 @@ from phoenix.server.api.types.EmbeddingDimension import (
61
61
  to_gql_embedding_dimension,
62
62
  )
63
63
  from phoenix.server.api.types.Event import create_event_id, unpack_event_id
64
- from phoenix.server.api.types.Experiment import Experiment, to_gql_experiment
64
+ from phoenix.server.api.types.Experiment import Experiment
65
65
  from phoenix.server.api.types.ExperimentComparison import (
66
66
  ExperimentComparison,
67
67
  )
@@ -69,9 +69,9 @@ from phoenix.server.api.types.ExperimentRepeatedRunGroup import (
69
69
  ExperimentRepeatedRunGroup,
70
70
  parse_experiment_repeated_run_group_node_id,
71
71
  )
72
- from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
72
+ from phoenix.server.api.types.ExperimentRun import ExperimentRun
73
73
  from phoenix.server.api.types.Functionality import Functionality
74
- from phoenix.server.api.types.GenerativeModel import GenerativeModel, to_gql_generative_model
74
+ from phoenix.server.api.types.GenerativeModel import GenerativeModel
75
75
  from phoenix.server.api.types.GenerativeProvider import GenerativeProvider, GenerativeProviderKey
76
76
  from phoenix.server.api.types.InferenceModel import InferenceModel
77
77
  from phoenix.server.api.types.InferencesRole import AncillaryInferencesRole, InferencesRole
@@ -89,21 +89,21 @@ from phoenix.server.api.types.pagination import (
89
89
  )
90
90
  from phoenix.server.api.types.PlaygroundModel import PlaygroundModel
91
91
  from phoenix.server.api.types.Project import Project
92
- from phoenix.server.api.types.ProjectSession import ProjectSession, to_gql_project_session
92
+ from phoenix.server.api.types.ProjectSession import ProjectSession
93
93
  from phoenix.server.api.types.ProjectTraceRetentionPolicy import ProjectTraceRetentionPolicy
94
- from phoenix.server.api.types.Prompt import Prompt, to_gql_prompt_from_orm
95
- from phoenix.server.api.types.PromptLabel import PromptLabel, to_gql_prompt_label
94
+ from phoenix.server.api.types.Prompt import Prompt
95
+ from phoenix.server.api.types.PromptLabel import PromptLabel
96
96
  from phoenix.server.api.types.PromptVersion import PromptVersion, to_gql_prompt_version
97
- from phoenix.server.api.types.PromptVersionTag import PromptVersionTag, to_gql_prompt_version_tag
97
+ from phoenix.server.api.types.PromptVersionTag import PromptVersionTag
98
98
  from phoenix.server.api.types.ServerStatus import ServerStatus
99
99
  from phoenix.server.api.types.SortDir import SortDir
100
100
  from phoenix.server.api.types.Span import Span
101
- from phoenix.server.api.types.SpanAnnotation import SpanAnnotation, to_gql_span_annotation
101
+ from phoenix.server.api.types.SpanAnnotation import SpanAnnotation
102
102
  from phoenix.server.api.types.SystemApiKey import SystemApiKey
103
103
  from phoenix.server.api.types.Trace import Trace
104
- from phoenix.server.api.types.TraceAnnotation import TraceAnnotation, to_gql_trace_annotation
105
- from phoenix.server.api.types.User import User, to_gql_user
106
- from phoenix.server.api.types.UserApiKey import UserApiKey, to_gql_api_key
104
+ from phoenix.server.api.types.TraceAnnotation import TraceAnnotation
105
+ from phoenix.server.api.types.User import User
106
+ from phoenix.server.api.types.UserApiKey import UserApiKey
107
107
  from phoenix.server.api.types.UserRole import UserRole
108
108
  from phoenix.server.api.types.ValidationResult import ValidationResult
109
109
 
@@ -208,9 +208,8 @@ class Query:
208
208
  models.GenerativeModel.provider.nullslast(),
209
209
  models.GenerativeModel.name,
210
210
  )
211
- .options(joinedload(models.GenerativeModel.token_prices))
212
211
  )
213
- data = [to_gql_generative_model(model) for model in result.unique()]
212
+ data = [GenerativeModel(id=model.id, db_record=model) for model in result.unique()]
214
213
  return connection_from_list(data=data, args=args)
215
214
 
216
215
  @strawberry.field
@@ -218,7 +217,7 @@ class Query:
218
217
  if input is not None and input.provider_key is not None:
219
218
  supported_model_names = PLAYGROUND_CLIENT_REGISTRY.list_models(input.provider_key)
220
219
  supported_models = [
221
- PlaygroundModel(name=model_name, provider_key=input.provider_key)
220
+ PlaygroundModel(name_value=model_name, provider_key_value=input.provider_key)
222
221
  for model_name in supported_model_names
223
222
  ]
224
223
  return supported_models
@@ -227,7 +226,9 @@ class Query:
227
226
  all_models: list[PlaygroundModel] = []
228
227
  for provider_key, model_name in registered_models:
229
228
  if model_name is not None and provider_key is not None:
230
- all_models.append(PlaygroundModel(name=model_name, provider_key=provider_key))
229
+ all_models.append(
230
+ PlaygroundModel(name_value=model_name, provider_key_value=provider_key)
231
+ )
231
232
  return all_models
232
233
 
233
234
  @strawberry.field
@@ -271,7 +272,7 @@ class Query:
271
272
  )
272
273
  async with info.context.db() as session:
273
274
  users = await session.stream_scalars(stmt)
274
- data = [to_gql_user(user) async for user in users]
275
+ data = [User(id=user.id, db_record=user) async for user in users]
275
276
  return connection_from_list(data=data, args=args)
276
277
 
277
278
  @strawberry.field
@@ -301,7 +302,7 @@ class Query:
301
302
  )
302
303
  async with info.context.db() as session:
303
304
  api_keys = await session.scalars(stmt)
304
- return [to_gql_api_key(api_key) for api_key in api_keys]
305
+ return [UserApiKey(id=api_key.id, db_record=api_key) for api_key in api_keys]
305
306
 
306
307
  @strawberry.field(permission_classes=[IsAdmin]) # type: ignore
307
308
  async def system_api_keys(self, info: Info[Context, None]) -> list[SystemApiKey]:
@@ -313,16 +314,7 @@ class Query:
313
314
  )
314
315
  async with info.context.db() as session:
315
316
  api_keys = await session.scalars(stmt)
316
- return [
317
- SystemApiKey(
318
- id_attr=api_key.id,
319
- name=api_key.name,
320
- description=api_key.description,
321
- created_at=api_key.created_at,
322
- expires_at=api_key.expires_at,
323
- )
324
- for api_key in api_keys
325
- ]
317
+ return [SystemApiKey(id=api_key.id, db_record=api_key) for api_key in api_keys]
326
318
 
327
319
  @strawberry.field
328
320
  async def projects(
@@ -363,13 +355,7 @@ class Query:
363
355
  stmt = exclude_experiment_projects(stmt)
364
356
  async with info.context.db() as session:
365
357
  projects = await session.stream_scalars(stmt)
366
- data = [
367
- Project(
368
- project_rowid=project.id,
369
- db_project=project,
370
- )
371
- async for project in projects
372
- ]
358
+ data = [Project(id=project.id, db_record=project) async for project in projects]
373
359
  return connection_from_list(data=data, args=args)
374
360
 
375
361
  @strawberry.field
@@ -430,7 +416,7 @@ class Query:
430
416
  async with info.context.db() as session:
431
417
  datasets = await session.scalars(stmt)
432
418
  return connection_from_list(
433
- data=[to_gql_dataset(dataset) for dataset in datasets], args=args
419
+ data=[Dataset(id=dataset.id, db_record=dataset) for dataset in datasets], args=args
434
420
  )
435
421
 
436
422
  @strawberry.field
@@ -555,10 +541,11 @@ class Query:
555
541
  ExperimentRepeatedRunGroup(
556
542
  experiment_rowid=experiment_id,
557
543
  dataset_example_rowid=example.id,
558
- runs=[
559
- to_gql_experiment_run(run)
544
+ cached_runs=[
545
+ ExperimentRun(id=run.id, db_record=run)
560
546
  for run in sorted(
561
- runs[example.id][experiment_id], key=lambda run: run.id
547
+ runs[example.id][experiment_id],
548
+ key=lambda run: run.repetition_number,
562
549
  )
563
550
  ],
564
551
  )
@@ -566,8 +553,8 @@ class Query:
566
553
  experiment_comparison = ExperimentComparison(
567
554
  id_attr=example.id,
568
555
  example=DatasetExample(
569
- id_attr=example.id,
570
- created_at=example.created_at,
556
+ id=example.id,
557
+ db_record=example,
571
558
  version_id=base_experiment.dataset_version_id,
572
559
  ),
573
560
  repeated_run_groups=repeated_run_groups,
@@ -908,25 +895,9 @@ class Query:
908
895
  )
909
896
  except Exception:
910
897
  raise NotFound(f"Unknown node: {id}")
911
-
912
- async with info.context.db() as session:
913
- runs = (
914
- await session.scalars(
915
- select(models.ExperimentRun)
916
- .where(models.ExperimentRun.experiment_id == experiment_rowid)
917
- .where(models.ExperimentRun.dataset_example_id == dataset_example_rowid)
918
- .order_by(models.ExperimentRun.repetition_number.asc())
919
- .options(
920
- joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id)
921
- )
922
- )
923
- ).all()
924
- if not runs:
925
- raise NotFound(f"Unknown experiment or dataset example: {id}")
926
898
  return ExperimentRepeatedRunGroup(
927
899
  experiment_rowid=experiment_rowid,
928
900
  dataset_example_rowid=dataset_example_rowid,
929
- runs=[to_gql_experiment_run(run) for run in runs],
930
901
  )
931
902
 
932
903
  global_id = GlobalID.from_id(id)
@@ -937,111 +908,30 @@ class Query:
937
908
  elif type_name == "EmbeddingDimension":
938
909
  embedding_dimension = info.context.model.embedding_dimensions[node_id]
939
910
  return to_gql_embedding_dimension(node_id, embedding_dimension)
940
- elif type_name == "Project":
941
- project_stmt = select(models.Project).filter_by(id=node_id)
942
- async with info.context.db() as session:
943
- project = await session.scalar(project_stmt)
944
- if project is None:
945
- raise NotFound(f"Unknown project: {id}")
946
- return Project(
947
- project_rowid=project.id,
948
- db_project=project,
949
- )
950
- elif type_name == "Trace":
951
- trace_stmt = select(models.Trace).filter_by(id=node_id)
952
- async with info.context.db() as session:
953
- trace = await session.scalar(trace_stmt)
954
- if trace is None:
955
- raise NotFound(f"Unknown trace: {id}")
956
- return Trace(trace_rowid=trace.id, db_trace=trace)
911
+ elif type_name == Project.__name__:
912
+ return Project(id=node_id)
913
+ elif type_name == Trace.__name__:
914
+ return Trace(id=node_id)
957
915
  elif type_name == Span.__name__:
958
- span_stmt = (
959
- select(models.Span)
960
- .options(
961
- joinedload(models.Span.trace, innerjoin=True).load_only(models.Trace.trace_id)
962
- )
963
- .where(models.Span.id == node_id)
964
- )
965
- async with info.context.db() as session:
966
- span = await session.scalar(span_stmt)
967
- if span is None:
968
- raise NotFound(f"Unknown span: {id}")
969
- return Span(span_rowid=span.id, db_span=span)
916
+ return Span(id=node_id)
970
917
  elif type_name == Dataset.__name__:
971
- dataset_stmt = select(models.Dataset).where(models.Dataset.id == node_id)
972
- async with info.context.db() as session:
973
- if (dataset := await session.scalar(dataset_stmt)) is None:
974
- raise NotFound(f"Unknown dataset: {id}")
975
- return to_gql_dataset(dataset)
918
+ return Dataset(id=node_id)
976
919
  elif type_name == DatasetExample.__name__:
977
- example_id = node_id
978
- async with info.context.db() as session:
979
- example = await session.scalar(
980
- select(models.DatasetExample).where(models.DatasetExample.id == example_id)
981
- )
982
- if not example:
983
- raise NotFound(f"Unknown dataset example: {id}")
984
- return DatasetExample(
985
- id_attr=example.id,
986
- created_at=example.created_at,
987
- )
920
+ return DatasetExample(id=node_id)
988
921
  elif type_name == DatasetSplit.__name__:
989
- async with info.context.db() as session:
990
- dataset_split = await session.scalar(
991
- select(models.DatasetSplit).where(models.DatasetSplit.id == node_id)
992
- )
993
- if not dataset_split:
994
- raise NotFound(f"Unknown dataset split: {id}")
995
- return to_gql_dataset_split(dataset_split)
922
+ return DatasetSplit(id=node_id)
996
923
  elif type_name == Experiment.__name__:
997
- async with info.context.db() as session:
998
- experiment = await session.scalar(
999
- select(models.Experiment).where(models.Experiment.id == node_id)
1000
- )
1001
- if not experiment:
1002
- raise NotFound(f"Unknown experiment: {id}")
1003
- return to_gql_experiment(experiment)
924
+ return Experiment(id=node_id)
1004
925
  elif type_name == ExperimentRun.__name__:
1005
- async with info.context.db() as session:
1006
- if not (
1007
- run := await session.scalar(
1008
- select(models.ExperimentRun)
1009
- .where(models.ExperimentRun.id == node_id)
1010
- .options(
1011
- joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id)
1012
- )
1013
- )
1014
- ):
1015
- raise NotFound(f"Unknown experiment run: {id}")
1016
- return to_gql_experiment_run(run)
926
+ return ExperimentRun(id=node_id)
1017
927
  elif type_name == User.__name__:
1018
928
  if int((user := info.context.user).identity) != node_id and not user.is_admin:
1019
929
  raise Unauthorized(MSG_ADMIN_ONLY)
1020
- async with info.context.db() as session:
1021
- if not (
1022
- user := await session.scalar(
1023
- select(models.User).where(models.User.id == node_id)
1024
- )
1025
- ):
1026
- raise NotFound(f"Unknown user: {id}")
1027
- return to_gql_user(user)
930
+ return User(id=node_id)
1028
931
  elif type_name == ProjectSession.__name__:
1029
- async with info.context.db() as session:
1030
- if not (
1031
- project_session := await session.scalar(
1032
- select(models.ProjectSession).filter_by(id=node_id)
1033
- )
1034
- ):
1035
- raise NotFound(f"Unknown user: {id}")
1036
- return to_gql_project_session(project_session)
932
+ return ProjectSession(id=node_id)
1037
933
  elif type_name == Prompt.__name__:
1038
- async with info.context.db() as session:
1039
- if orm_prompt := await session.scalar(
1040
- select(models.Prompt).where(models.Prompt.id == node_id)
1041
- ):
1042
- return to_gql_prompt_from_orm(orm_prompt)
1043
- else:
1044
- raise NotFound(f"Unknown prompt: {id}")
934
+ return Prompt(id=node_id)
1045
935
  elif type_name == PromptVersion.__name__:
1046
936
  async with info.context.db() as session:
1047
937
  if orm_prompt_version := await session.scalar(
@@ -1051,51 +941,17 @@ class Query:
1051
941
  else:
1052
942
  raise NotFound(f"Unknown prompt version: {id}")
1053
943
  elif type_name == PromptLabel.__name__:
1054
- async with info.context.db() as session:
1055
- if not (
1056
- prompt_label := await session.scalar(
1057
- select(models.PromptLabel).where(models.PromptLabel.id == node_id)
1058
- )
1059
- ):
1060
- raise NotFound(f"Unknown prompt label: {id}")
1061
- return to_gql_prompt_label(prompt_label)
944
+ return PromptLabel(id=node_id)
1062
945
  elif type_name == PromptVersionTag.__name__:
1063
- async with info.context.db() as session:
1064
- if not (prompt_version_tag := await session.get(models.PromptVersionTag, node_id)):
1065
- raise NotFound(f"Unknown prompt version tag: {id}")
1066
- return to_gql_prompt_version_tag(prompt_version_tag)
946
+ return PromptVersionTag(id=node_id)
1067
947
  elif type_name == ProjectTraceRetentionPolicy.__name__:
1068
- async with info.context.db() as session:
1069
- db_policy = await session.scalar(
1070
- select(models.ProjectTraceRetentionPolicy).filter_by(id=node_id)
1071
- )
1072
- if not db_policy:
1073
- raise NotFound(f"Unknown project trace retention policy: {id}")
1074
- return ProjectTraceRetentionPolicy(id=db_policy.id, db_policy=db_policy)
948
+ return ProjectTraceRetentionPolicy(id=node_id)
1075
949
  elif type_name == SpanAnnotation.__name__:
1076
- async with info.context.db() as session:
1077
- span_annotation = await session.get(models.SpanAnnotation, node_id)
1078
- if not span_annotation:
1079
- raise NotFound(f"Unknown span annotation: {id}")
1080
- return to_gql_span_annotation(span_annotation)
950
+ return SpanAnnotation(id=node_id)
1081
951
  elif type_name == TraceAnnotation.__name__:
1082
- async with info.context.db() as session:
1083
- trace_annotation = await session.get(models.TraceAnnotation, node_id)
1084
- if not trace_annotation:
1085
- raise NotFound(f"Unknown trace annotation: {id}")
1086
- return to_gql_trace_annotation(trace_annotation)
952
+ return TraceAnnotation(id=node_id)
1087
953
  elif type_name == GenerativeModel.__name__:
1088
- async with info.context.db() as session:
1089
- stmt = (
1090
- select(models.GenerativeModel)
1091
- .where(models.GenerativeModel.deleted_at.is_(None))
1092
- .where(models.GenerativeModel.id == node_id)
1093
- .options(joinedload(models.GenerativeModel.token_prices))
1094
- )
1095
- model = await session.scalar(stmt)
1096
- if not model:
1097
- raise NotFound(f"Unknown model: {id}")
1098
- return to_gql_generative_model(model)
954
+ return GenerativeModel(id=node_id)
1099
955
  raise NotFound(f"Unknown node type: {type_name}")
1100
956
 
1101
957
  @strawberry.field
@@ -1107,16 +963,7 @@ class Query:
1107
963
  return None
1108
964
  if isinstance(user, UnauthenticatedUser):
1109
965
  return None
1110
- async with info.context.db() as session:
1111
- if (
1112
- user := await session.scalar(
1113
- select(models.User)
1114
- .where(models.User.id == int(user.identity))
1115
- .options(joinedload(models.User.role))
1116
- )
1117
- ) is None:
1118
- return None
1119
- return to_gql_user(user)
966
+ return User(id=int(user.identity))
1120
967
 
1121
968
  @strawberry.field
1122
969
  async def prompts(
@@ -1156,7 +1003,9 @@ class Query:
1156
1003
  stmt = stmt.distinct()
1157
1004
  async with info.context.db() as session:
1158
1005
  orm_prompts = await session.stream_scalars(stmt)
1159
- data = [to_gql_prompt_from_orm(orm_prompt) async for orm_prompt in orm_prompts]
1006
+ data = [
1007
+ Prompt(id=orm_prompt.id, db_record=orm_prompt) async for orm_prompt in orm_prompts
1008
+ ]
1160
1009
  return connection_from_list(
1161
1010
  data=data,
1162
1011
  args=args,
@@ -1179,7 +1028,10 @@ class Query:
1179
1028
  )
1180
1029
  async with info.context.db() as session:
1181
1030
  prompt_labels = await session.stream_scalars(select(models.PromptLabel))
1182
- data = [to_gql_prompt_label(prompt_label) async for prompt_label in prompt_labels]
1031
+ data = [
1032
+ PromptLabel(id=prompt_label.id, db_record=prompt_label)
1033
+ async for prompt_label in prompt_labels
1034
+ ]
1183
1035
  return connection_from_list(
1184
1036
  data=data,
1185
1037
  args=args,
@@ -1204,7 +1056,10 @@ class Query:
1204
1056
  dataset_labels = await session.scalars(
1205
1057
  select(models.DatasetLabel).order_by(models.DatasetLabel.name.asc())
1206
1058
  )
1207
- data = [to_gql_dataset_label(dataset_label) for dataset_label in dataset_labels]
1059
+ data = [
1060
+ DatasetLabel(id=dataset_label.id, db_record=dataset_label)
1061
+ for dataset_label in dataset_labels
1062
+ ]
1208
1063
  return connection_from_list(data=data, args=args)
1209
1064
 
1210
1065
  @strawberry.field
@@ -1224,7 +1079,7 @@ class Query:
1224
1079
  )
1225
1080
  async with info.context.db() as session:
1226
1081
  splits = await session.stream_scalars(select(models.DatasetSplit))
1227
- data = [to_gql_dataset_split(split) async for split in splits]
1082
+ data = [DatasetSplit(id=split.id, db_record=split) async for split in splits]
1228
1083
  return connection_from_list(
1229
1084
  data=data,
1230
1085
  args=args,
@@ -1495,7 +1350,7 @@ class Query:
1495
1350
  async with info.context.db() as session:
1496
1351
  span_rowid = await session.scalar(stmt)
1497
1352
  if span_rowid:
1498
- return Span(span_rowid=span_rowid)
1353
+ return Span(id=span_rowid)
1499
1354
  return None
1500
1355
 
1501
1356
  @strawberry.field
@@ -1508,7 +1363,7 @@ class Query:
1508
1363
  async with info.context.db() as session:
1509
1364
  trace_rowid = await session.scalar(stmt)
1510
1365
  if trace_rowid:
1511
- return Trace(trace_rowid=trace_rowid)
1366
+ return Trace(id=trace_rowid)
1512
1367
  return None
1513
1368
 
1514
1369
  @strawberry.field
@@ -1521,7 +1376,7 @@ class Query:
1521
1376
  async with info.context.db() as session:
1522
1377
  session_row = await session.scalar(stmt)
1523
1378
  if session_row:
1524
- return to_gql_project_session(session_row)
1379
+ return ProjectSession(id=session_row.id, db_record=session_row)
1525
1380
  return None
1526
1381
 
1527
1382
 
@@ -64,7 +64,7 @@ from phoenix.server.api.types.Dataset import Dataset
64
64
  from phoenix.server.api.types.DatasetExample import DatasetExample
65
65
  from phoenix.server.api.types.DatasetVersion import DatasetVersion
66
66
  from phoenix.server.api.types.Experiment import to_gql_experiment
67
- from phoenix.server.api.types.ExperimentRun import to_gql_experiment_run
67
+ from phoenix.server.api.types.ExperimentRun import ExperimentRun
68
68
  from phoenix.server.api.types.node import from_global_id_with_expected_type
69
69
  from phoenix.server.api.types.Span import Span
70
70
  from phoenix.server.daemons.span_cost_calculator import SpanCostCalculator
@@ -194,7 +194,7 @@ class Subscription:
194
194
  session.add(span_cost)
195
195
 
196
196
  info.context.event_queue.put(SpanInsertEvent(ids=(playground_project_id,)))
197
- yield ChatCompletionSubscriptionResult(span=Span(span_rowid=db_span.id, db_span=db_span))
197
+ yield ChatCompletionSubscriptionResult(span=Span(id=db_span.id, db_record=db_span))
198
198
 
199
199
  @strawberry.subscription(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
200
200
  async def chat_completion_over_dataset(
@@ -528,8 +528,8 @@ async def _chat_completion_result_payloads(
528
528
  await session.flush()
529
529
  for example_id, span, run in results:
530
530
  yield ChatCompletionSubscriptionResult(
531
- span=Span(span_rowid=span.id, db_span=span) if span else None,
532
- experiment_run=to_gql_experiment_run(run),
531
+ span=Span(id=span.id, db_record=span) if span else None,
532
+ experiment_run=ExperimentRun(id=run.id, db_record=run),
533
533
  dataset_example_id=example_id,
534
534
  repetition_number=run.repetition_number,
535
535
  )
@@ -1,31 +1,98 @@
1
1
  from datetime import datetime
2
- from typing import Optional
2
+ from typing import TYPE_CHECKING, Annotated, Optional
3
3
 
4
4
  import strawberry
5
+ from strawberry.scalars import JSON
6
+ from strawberry.types import Info
5
7
 
6
- from phoenix.server.api.interceptor import GqlValueMediator
8
+ from phoenix.server.api.context import Context
9
+
10
+ from .AnnotationSource import AnnotationSource
11
+ from .AnnotatorKind import AnnotatorKind
12
+
13
+ if TYPE_CHECKING:
14
+ from .User import User
7
15
 
8
16
 
9
17
  @strawberry.interface
10
18
  class Annotation:
11
- name: str = strawberry.field(
12
- description="Name of the annotation, e.g. 'helpfulness' or 'relevance'."
13
- )
14
- score: Optional[float] = strawberry.field(
15
- description="Value of the annotation in the form of a numeric score.",
16
- default=GqlValueMediator(),
17
- )
18
- label: Optional[str] = strawberry.field(
19
- description="Value of the annotation in the form of a string, e.g. "
20
- "'helpful' or 'not helpful'. Note that the label is not necessarily binary."
21
- )
22
- explanation: Optional[str] = strawberry.field(
23
- description="The annotator's explanation for the annotation result (i.e. "
24
- "score or label, or both) given to the subject."
25
- )
26
- created_at: datetime = strawberry.field(
27
- description="The date and time when the annotation was created."
28
- )
29
- updated_at: datetime = strawberry.field(
30
- description="The date and time when the annotation was last updated."
31
- )
19
+ @strawberry.field(description="Name of the annotation, e.g. 'helpfulness' or 'relevance'.") # type: ignore
20
+ async def name(
21
+ self,
22
+ info: Info[Context, None],
23
+ ) -> str:
24
+ raise NotImplementedError
25
+
26
+ @strawberry.field(description="The kind of annotator that produced the annotation.") # type: ignore
27
+ async def annotator_kind(
28
+ self,
29
+ info: Info[Context, None],
30
+ ) -> AnnotatorKind:
31
+ raise NotImplementedError
32
+
33
+ @strawberry.field(
34
+ description="Value of the annotation in the form of a string, e.g. 'helpful' or 'not helpful'. Note that the label is not necessarily binary." # noqa: E501
35
+ ) # type: ignore
36
+ async def label(
37
+ self,
38
+ info: Info[Context, None],
39
+ ) -> Optional[str]:
40
+ raise NotImplementedError
41
+
42
+ @strawberry.field(description="Value of the annotation in the form of a numeric score.") # type: ignore
43
+ async def score(
44
+ self,
45
+ info: Info[Context, None],
46
+ ) -> Optional[float]:
47
+ raise NotImplementedError
48
+
49
+ @strawberry.field(
50
+ description="The annotator's explanation for the annotation result (i.e. score or label, or both) given to the subject." # noqa: E501
51
+ ) # type: ignore
52
+ async def explanation(
53
+ self,
54
+ info: Info[Context, None],
55
+ ) -> Optional[str]:
56
+ raise NotImplementedError
57
+
58
+ @strawberry.field(description="Metadata about the annotation.") # type: ignore
59
+ async def metadata(
60
+ self,
61
+ info: Info[Context, None],
62
+ ) -> JSON:
63
+ raise NotImplementedError
64
+
65
+ @strawberry.field(description="The source of the annotation.") # type: ignore
66
+ async def source(
67
+ self,
68
+ info: Info[Context, None],
69
+ ) -> AnnotationSource:
70
+ raise NotImplementedError
71
+
72
+ @strawberry.field(description="The identifier of the annotation.") # type: ignore
73
+ async def identifier(
74
+ self,
75
+ info: Info[Context, None],
76
+ ) -> str:
77
+ raise NotImplementedError
78
+
79
+ @strawberry.field(description="The date and time the annotation was created.") # type: ignore
80
+ async def created_at(
81
+ self,
82
+ info: Info[Context, None],
83
+ ) -> datetime:
84
+ raise NotImplementedError
85
+
86
+ @strawberry.field(description="The date and time the annotation was last updated.") # type: ignore
87
+ async def updated_at(
88
+ self,
89
+ info: Info[Context, None],
90
+ ) -> datetime:
91
+ raise NotImplementedError
92
+
93
+ @strawberry.field(description="The user that produced the annotation.") # type: ignore
94
+ async def user(
95
+ self,
96
+ info: Info[Context, None],
97
+ ) -> Optional[Annotated["User", strawberry.lazy(".User")]]:
98
+ raise NotImplementedError
@@ -3,25 +3,21 @@ from typing import Optional
3
3
 
4
4
  import strawberry
5
5
 
6
- from phoenix.db.models import ApiKey as ORMApiKey
7
-
8
6
 
9
7
  @strawberry.interface
10
8
  class ApiKey:
11
- name: str = strawberry.field(description="Name of the API key.")
12
- description: Optional[str] = strawberry.field(description="Description of the API key.")
13
- created_at: datetime = strawberry.field(
14
- description="The date and time the API key was created."
15
- )
16
- expires_at: Optional[datetime] = strawberry.field(
17
- description="The date and time the API key will expire."
18
- )
9
+ @strawberry.field(description="Name of the API key.") # type: ignore
10
+ async def name(self) -> str:
11
+ raise NotImplementedError
12
+
13
+ @strawberry.field(description="Description of the API key.") # type: ignore
14
+ async def description(self) -> Optional[str]:
15
+ raise NotImplementedError
19
16
 
17
+ @strawberry.field(description="The date and time the API key was created.") # type: ignore
18
+ async def created_at(self) -> datetime:
19
+ raise NotImplementedError
20
20
 
21
- def to_gql_api_key(api_key: ORMApiKey) -> ApiKey:
22
- return ApiKey(
23
- name=api_key.name,
24
- description=api_key.description,
25
- created_at=api_key.created_at,
26
- expires_at=api_key.expires_at,
27
- )
21
+ @strawberry.field(description="The date and time the API key will expire.") # type: ignore
22
+ async def expires_at(self) -> Optional[datetime]:
23
+ raise NotImplementedError