arize-phoenix 12.7.1-py3-none-any.whl → 12.9.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (76)
  1. {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/METADATA +3 -1
  2. {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/RECORD +76 -73
  3. phoenix/config.py +131 -9
  4. phoenix/db/engines.py +127 -14
  5. phoenix/db/iam_auth.py +64 -0
  6. phoenix/db/pg_config.py +10 -0
  7. phoenix/server/api/context.py +23 -0
  8. phoenix/server/api/dataloaders/__init__.py +6 -0
  9. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +0 -2
  10. phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
  11. phoenix/server/api/dataloaders/span_costs.py +3 -9
  12. phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
  13. phoenix/server/api/helpers/playground_clients.py +3 -3
  14. phoenix/server/api/input_types/PromptVersionInput.py +47 -1
  15. phoenix/server/api/mutations/annotation_config_mutations.py +2 -2
  16. phoenix/server/api/mutations/api_key_mutations.py +2 -15
  17. phoenix/server/api/mutations/chat_mutations.py +3 -2
  18. phoenix/server/api/mutations/dataset_label_mutations.py +109 -157
  19. phoenix/server/api/mutations/dataset_mutations.py +8 -8
  20. phoenix/server/api/mutations/dataset_split_mutations.py +13 -9
  21. phoenix/server/api/mutations/model_mutations.py +4 -4
  22. phoenix/server/api/mutations/project_session_annotations_mutations.py +4 -7
  23. phoenix/server/api/mutations/prompt_label_mutations.py +3 -3
  24. phoenix/server/api/mutations/prompt_mutations.py +24 -117
  25. phoenix/server/api/mutations/prompt_version_tag_mutations.py +8 -5
  26. phoenix/server/api/mutations/span_annotations_mutations.py +10 -5
  27. phoenix/server/api/mutations/trace_annotations_mutations.py +9 -4
  28. phoenix/server/api/mutations/user_mutations.py +4 -4
  29. phoenix/server/api/queries.py +80 -213
  30. phoenix/server/api/subscriptions.py +4 -4
  31. phoenix/server/api/types/Annotation.py +90 -23
  32. phoenix/server/api/types/ApiKey.py +13 -17
  33. phoenix/server/api/types/Dataset.py +88 -48
  34. phoenix/server/api/types/DatasetExample.py +34 -30
  35. phoenix/server/api/types/DatasetLabel.py +47 -13
  36. phoenix/server/api/types/DatasetSplit.py +87 -21
  37. phoenix/server/api/types/DatasetVersion.py +49 -4
  38. phoenix/server/api/types/DocumentAnnotation.py +182 -62
  39. phoenix/server/api/types/Experiment.py +146 -55
  40. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +10 -1
  41. phoenix/server/api/types/ExperimentRun.py +118 -61
  42. phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
  43. phoenix/server/api/types/GenerativeModel.py +95 -42
  44. phoenix/server/api/types/ModelInterface.py +7 -2
  45. phoenix/server/api/types/PlaygroundModel.py +12 -2
  46. phoenix/server/api/types/Project.py +70 -75
  47. phoenix/server/api/types/ProjectSession.py +69 -37
  48. phoenix/server/api/types/ProjectSessionAnnotation.py +166 -47
  49. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
  50. phoenix/server/api/types/Prompt.py +82 -44
  51. phoenix/server/api/types/PromptLabel.py +47 -13
  52. phoenix/server/api/types/PromptVersion.py +11 -8
  53. phoenix/server/api/types/PromptVersionTag.py +65 -25
  54. phoenix/server/api/types/Span.py +116 -115
  55. phoenix/server/api/types/SpanAnnotation.py +189 -42
  56. phoenix/server/api/types/SystemApiKey.py +65 -1
  57. phoenix/server/api/types/Trace.py +45 -44
  58. phoenix/server/api/types/TraceAnnotation.py +144 -48
  59. phoenix/server/api/types/User.py +103 -33
  60. phoenix/server/api/types/UserApiKey.py +73 -26
  61. phoenix/server/app.py +29 -0
  62. phoenix/server/cost_tracking/model_cost_manifest.json +2 -2
  63. phoenix/server/static/.vite/manifest.json +43 -43
  64. phoenix/server/static/assets/{components-BLK5vehh.js → components-v927s3NF.js} +471 -484
  65. phoenix/server/static/assets/{index-BP0Shd90.js → index-DrD9eSrN.js} +20 -16
  66. phoenix/server/static/assets/{pages-DIVgyYyy.js → pages-GVybXa_W.js} +754 -753
  67. phoenix/server/static/assets/{vendor-3BvTzoBp.js → vendor-D-csRHGZ.js} +1 -1
  68. phoenix/server/static/assets/{vendor-arizeai-C6_oC0y8.js → vendor-arizeai-BJLCG_Gc.js} +1 -1
  69. phoenix/server/static/assets/{vendor-codemirror-DPnZGAZA.js → vendor-codemirror-Cr963DyP.js} +3 -3
  70. phoenix/server/static/assets/{vendor-recharts-CjgSbsB0.js → vendor-recharts-DgmPLgIp.js} +1 -1
  71. phoenix/server/static/assets/{vendor-shiki-CJyhDG0E.js → vendor-shiki-wYOt1s7u.js} +1 -1
  72. phoenix/version.py +1 -1
  73. {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/WHEEL +0 -0
  74. {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/entry_points.txt +0 -0
  75. {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/licenses/IP_NOTICE +0 -0
  76. {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/licenses/LICENSE +0 -0

phoenix/server/api/queries.py

@@ -49,10 +49,10 @@ from phoenix.server.api.input_types.ProjectSort import ProjectColumn, ProjectSor
  from phoenix.server.api.input_types.PromptFilter import PromptFilter
  from phoenix.server.api.types.AnnotationConfig import AnnotationConfig, to_gql_annotation_config
  from phoenix.server.api.types.Cluster import Cluster, to_gql_clusters
- from phoenix.server.api.types.Dataset import Dataset, to_gql_dataset
+ from phoenix.server.api.types.Dataset import Dataset
  from phoenix.server.api.types.DatasetExample import DatasetExample
- from phoenix.server.api.types.DatasetLabel import DatasetLabel, to_gql_dataset_label
- from phoenix.server.api.types.DatasetSplit import DatasetSplit, to_gql_dataset_split
+ from phoenix.server.api.types.DatasetLabel import DatasetLabel
+ from phoenix.server.api.types.DatasetSplit import DatasetSplit
  from phoenix.server.api.types.Dimension import to_gql_dimension
  from phoenix.server.api.types.EmbeddingDimension import (
  DEFAULT_CLUSTER_SELECTION_EPSILON,
@@ -61,7 +61,7 @@ from phoenix.server.api.types.EmbeddingDimension import (
  to_gql_embedding_dimension,
  )
  from phoenix.server.api.types.Event import create_event_id, unpack_event_id
- from phoenix.server.api.types.Experiment import Experiment, to_gql_experiment
+ from phoenix.server.api.types.Experiment import Experiment
  from phoenix.server.api.types.ExperimentComparison import (
  ExperimentComparison,
  )
@@ -69,9 +69,9 @@ from phoenix.server.api.types.ExperimentRepeatedRunGroup import (
  ExperimentRepeatedRunGroup,
  parse_experiment_repeated_run_group_node_id,
  )
- from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
+ from phoenix.server.api.types.ExperimentRun import ExperimentRun
  from phoenix.server.api.types.Functionality import Functionality
- from phoenix.server.api.types.GenerativeModel import GenerativeModel, to_gql_generative_model
+ from phoenix.server.api.types.GenerativeModel import GenerativeModel
  from phoenix.server.api.types.GenerativeProvider import GenerativeProvider, GenerativeProviderKey
  from phoenix.server.api.types.InferenceModel import InferenceModel
  from phoenix.server.api.types.InferencesRole import AncillaryInferencesRole, InferencesRole
@@ -89,21 +89,21 @@ from phoenix.server.api.types.pagination import (
  )
  from phoenix.server.api.types.PlaygroundModel import PlaygroundModel
  from phoenix.server.api.types.Project import Project
- from phoenix.server.api.types.ProjectSession import ProjectSession, to_gql_project_session
+ from phoenix.server.api.types.ProjectSession import ProjectSession
  from phoenix.server.api.types.ProjectTraceRetentionPolicy import ProjectTraceRetentionPolicy
- from phoenix.server.api.types.Prompt import Prompt, to_gql_prompt_from_orm
- from phoenix.server.api.types.PromptLabel import PromptLabel, to_gql_prompt_label
+ from phoenix.server.api.types.Prompt import Prompt
+ from phoenix.server.api.types.PromptLabel import PromptLabel
  from phoenix.server.api.types.PromptVersion import PromptVersion, to_gql_prompt_version
- from phoenix.server.api.types.PromptVersionTag import PromptVersionTag, to_gql_prompt_version_tag
+ from phoenix.server.api.types.PromptVersionTag import PromptVersionTag
  from phoenix.server.api.types.ServerStatus import ServerStatus
  from phoenix.server.api.types.SortDir import SortDir
  from phoenix.server.api.types.Span import Span
- from phoenix.server.api.types.SpanAnnotation import SpanAnnotation, to_gql_span_annotation
+ from phoenix.server.api.types.SpanAnnotation import SpanAnnotation
  from phoenix.server.api.types.SystemApiKey import SystemApiKey
  from phoenix.server.api.types.Trace import Trace
- from phoenix.server.api.types.TraceAnnotation import TraceAnnotation, to_gql_trace_annotation
- from phoenix.server.api.types.User import User, to_gql_user
- from phoenix.server.api.types.UserApiKey import UserApiKey, to_gql_api_key
+ from phoenix.server.api.types.TraceAnnotation import TraceAnnotation
+ from phoenix.server.api.types.User import User
+ from phoenix.server.api.types.UserApiKey import UserApiKey
  from phoenix.server.api.types.UserRole import UserRole
  from phoenix.server.api.types.ValidationResult import ValidationResult

@@ -188,7 +188,17 @@ class Query:
  async def generative_models(
  self,
  info: Info[Context, None],
- ) -> list[GenerativeModel]:
+ first: Optional[int] = 50,
+ last: Optional[int] = UNSET,
+ after: Optional[CursorString] = UNSET,
+ before: Optional[CursorString] = UNSET,
+ ) -> Connection[GenerativeModel]:
+ args = ConnectionArgs(
+ first=first,
+ after=after if isinstance(after, CursorString) else None,
+ last=last,
+ before=before if isinstance(before, CursorString) else None,
+ )
  async with info.context.db() as session:
  result = await session.scalars(
  select(models.GenerativeModel)
@@ -198,17 +208,16 @@ class Query:
  models.GenerativeModel.provider.nullslast(),
  models.GenerativeModel.name,
  )
- .options(joinedload(models.GenerativeModel.token_prices))
  )
-
- return [to_gql_generative_model(model) for model in result.unique()]
+ data = [GenerativeModel(id=model.id, db_record=model) for model in result.unique()]
+ return connection_from_list(data=data, args=args)

  @strawberry.field
  async def playground_models(self, input: Optional[ModelsInput] = None) -> list[PlaygroundModel]:
  if input is not None and input.provider_key is not None:
  supported_model_names = PLAYGROUND_CLIENT_REGISTRY.list_models(input.provider_key)
  supported_models = [
- PlaygroundModel(name=model_name, provider_key=input.provider_key)
+ PlaygroundModel(name_value=model_name, provider_key_value=input.provider_key)
  for model_name in supported_model_names
  ]
  return supported_models
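
With this change the generative_models field returns a relay-style Connection[GenerativeModel] instead of a flat list, which is a breaking shape change for GraphQL clients that read the old list directly. A minimal illustrative query, embedded as a Python string; the camelCase field name follows strawberry's default conversion, and the selected node sub-fields are assumptions rather than something shown in this diff:

# Illustrative only: shows the paginated shape of the new connection field;
# the node sub-fields selected here ("name") are assumed, not taken from the diff.
GENERATIVE_MODELS_QUERY = """
query {
  generativeModels(first: 10) {
    edges {
      node {
        id
        name
      }
    }
    pageInfo {
      endCursor
      hasNextPage
    }
  }
}
"""
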
@@ -217,7 +226,9 @@ class Query:
  all_models: list[PlaygroundModel] = []
  for provider_key, model_name in registered_models:
  if model_name is not None and provider_key is not None:
- all_models.append(PlaygroundModel(name=model_name, provider_key=provider_key))
+ all_models.append(
+ PlaygroundModel(name_value=model_name, provider_key_value=provider_key)
+ )
  return all_models

  @strawberry.field
@@ -261,7 +272,7 @@ class Query:
  )
  async with info.context.db() as session:
  users = await session.stream_scalars(stmt)
- data = [to_gql_user(user) async for user in users]
+ data = [User(id=user.id, db_record=user) async for user in users]
  return connection_from_list(data=data, args=args)

  @strawberry.field
@@ -291,7 +302,7 @@ class Query:
  )
  async with info.context.db() as session:
  api_keys = await session.scalars(stmt)
- return [to_gql_api_key(api_key) for api_key in api_keys]
+ return [UserApiKey(id=api_key.id, db_record=api_key) for api_key in api_keys]

  @strawberry.field(permission_classes=[IsAdmin]) # type: ignore
  async def system_api_keys(self, info: Info[Context, None]) -> list[SystemApiKey]:
@@ -303,16 +314,7 @@ class Query:
  )
  async with info.context.db() as session:
  api_keys = await session.scalars(stmt)
- return [
- SystemApiKey(
- id_attr=api_key.id,
- name=api_key.name,
- description=api_key.description,
- created_at=api_key.created_at,
- expires_at=api_key.expires_at,
- )
- for api_key in api_keys
- ]
+ return [SystemApiKey(id=api_key.id, db_record=api_key) for api_key in api_keys]

  @strawberry.field
  async def projects(
@@ -353,13 +355,7 @@ class Query:
  stmt = exclude_experiment_projects(stmt)
  async with info.context.db() as session:
  projects = await session.stream_scalars(stmt)
- data = [
- Project(
- project_rowid=project.id,
- db_project=project,
- )
- async for project in projects
- ]
+ data = [Project(id=project.id, db_record=project) async for project in projects]
  return connection_from_list(data=data, args=args)

  @strawberry.field
@@ -420,7 +416,7 @@ class Query:
  async with info.context.db() as session:
  datasets = await session.scalars(stmt)
  return connection_from_list(
- data=[to_gql_dataset(dataset) for dataset in datasets], args=args
+ data=[Dataset(id=dataset.id, db_record=dataset) for dataset in datasets], args=args
  )

  @strawberry.field
@@ -545,10 +541,11 @@ class Query:
  ExperimentRepeatedRunGroup(
  experiment_rowid=experiment_id,
  dataset_example_rowid=example.id,
- runs=[
- to_gql_experiment_run(run)
+ cached_runs=[
+ ExperimentRun(id=run.id, db_record=run)
  for run in sorted(
- runs[example.id][experiment_id], key=lambda run: run.id
+ runs[example.id][experiment_id],
+ key=lambda run: run.repetition_number,
  )
  ],
  )
@@ -556,8 +553,8 @@ class Query:
  experiment_comparison = ExperimentComparison(
  id_attr=example.id,
  example=DatasetExample(
- id_attr=example.id,
- created_at=example.created_at,
+ id=example.id,
+ db_record=example,
  version_id=base_experiment.dataset_version_id,
  ),
  repeated_run_groups=repeated_run_groups,
@@ -898,25 +895,9 @@ class Query:
  )
  except Exception:
  raise NotFound(f"Unknown node: {id}")
-
- async with info.context.db() as session:
- runs = (
- await session.scalars(
- select(models.ExperimentRun)
- .where(models.ExperimentRun.experiment_id == experiment_rowid)
- .where(models.ExperimentRun.dataset_example_id == dataset_example_rowid)
- .order_by(models.ExperimentRun.repetition_number.asc())
- .options(
- joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id)
- )
- )
- ).all()
- if not runs:
- raise NotFound(f"Unknown experiment or dataset example: {id}")
  return ExperimentRepeatedRunGroup(
  experiment_rowid=experiment_rowid,
  dataset_example_rowid=dataset_example_rowid,
- runs=[to_gql_experiment_run(run) for run in runs],
  )

  global_id = GlobalID.from_id(id)
@@ -927,111 +908,30 @@ class Query:
  elif type_name == "EmbeddingDimension":
  embedding_dimension = info.context.model.embedding_dimensions[node_id]
  return to_gql_embedding_dimension(node_id, embedding_dimension)
- elif type_name == "Project":
- project_stmt = select(models.Project).filter_by(id=node_id)
- async with info.context.db() as session:
- project = await session.scalar(project_stmt)
- if project is None:
- raise NotFound(f"Unknown project: {id}")
- return Project(
- project_rowid=project.id,
- db_project=project,
- )
- elif type_name == "Trace":
- trace_stmt = select(models.Trace).filter_by(id=node_id)
- async with info.context.db() as session:
- trace = await session.scalar(trace_stmt)
- if trace is None:
- raise NotFound(f"Unknown trace: {id}")
- return Trace(trace_rowid=trace.id, db_trace=trace)
+ elif type_name == Project.__name__:
+ return Project(id=node_id)
+ elif type_name == Trace.__name__:
+ return Trace(id=node_id)
  elif type_name == Span.__name__:
- span_stmt = (
- select(models.Span)
- .options(
- joinedload(models.Span.trace, innerjoin=True).load_only(models.Trace.trace_id)
- )
- .where(models.Span.id == node_id)
- )
- async with info.context.db() as session:
- span = await session.scalar(span_stmt)
- if span is None:
- raise NotFound(f"Unknown span: {id}")
- return Span(span_rowid=span.id, db_span=span)
+ return Span(id=node_id)
  elif type_name == Dataset.__name__:
- dataset_stmt = select(models.Dataset).where(models.Dataset.id == node_id)
- async with info.context.db() as session:
- if (dataset := await session.scalar(dataset_stmt)) is None:
- raise NotFound(f"Unknown dataset: {id}")
- return to_gql_dataset(dataset)
+ return Dataset(id=node_id)
  elif type_name == DatasetExample.__name__:
- example_id = node_id
- async with info.context.db() as session:
- example = await session.scalar(
- select(models.DatasetExample).where(models.DatasetExample.id == example_id)
- )
- if not example:
- raise NotFound(f"Unknown dataset example: {id}")
- return DatasetExample(
- id_attr=example.id,
- created_at=example.created_at,
- )
+ return DatasetExample(id=node_id)
  elif type_name == DatasetSplit.__name__:
- async with info.context.db() as session:
- dataset_split = await session.scalar(
- select(models.DatasetSplit).where(models.DatasetSplit.id == node_id)
- )
- if not dataset_split:
- raise NotFound(f"Unknown dataset split: {id}")
- return to_gql_dataset_split(dataset_split)
+ return DatasetSplit(id=node_id)
  elif type_name == Experiment.__name__:
- async with info.context.db() as session:
- experiment = await session.scalar(
- select(models.Experiment).where(models.Experiment.id == node_id)
- )
- if not experiment:
- raise NotFound(f"Unknown experiment: {id}")
- return to_gql_experiment(experiment)
+ return Experiment(id=node_id)
  elif type_name == ExperimentRun.__name__:
- async with info.context.db() as session:
- if not (
- run := await session.scalar(
- select(models.ExperimentRun)
- .where(models.ExperimentRun.id == node_id)
- .options(
- joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id)
- )
- )
- ):
- raise NotFound(f"Unknown experiment run: {id}")
- return to_gql_experiment_run(run)
+ return ExperimentRun(id=node_id)
  elif type_name == User.__name__:
  if int((user := info.context.user).identity) != node_id and not user.is_admin:
  raise Unauthorized(MSG_ADMIN_ONLY)
- async with info.context.db() as session:
- if not (
- user := await session.scalar(
- select(models.User).where(models.User.id == node_id)
- )
- ):
- raise NotFound(f"Unknown user: {id}")
- return to_gql_user(user)
+ return User(id=node_id)
  elif type_name == ProjectSession.__name__:
- async with info.context.db() as session:
- if not (
- project_session := await session.scalar(
- select(models.ProjectSession).filter_by(id=node_id)
- )
- ):
- raise NotFound(f"Unknown user: {id}")
- return to_gql_project_session(project_session)
+ return ProjectSession(id=node_id)
  elif type_name == Prompt.__name__:
- async with info.context.db() as session:
- if orm_prompt := await session.scalar(
- select(models.Prompt).where(models.Prompt.id == node_id)
- ):
- return to_gql_prompt_from_orm(orm_prompt)
- else:
- raise NotFound(f"Unknown prompt: {id}")
+ return Prompt(id=node_id)
  elif type_name == PromptVersion.__name__:
  async with info.context.db() as session:
  if orm_prompt_version := await session.scalar(
@@ -1041,51 +941,17 @@ class Query:
  else:
  raise NotFound(f"Unknown prompt version: {id}")
  elif type_name == PromptLabel.__name__:
- async with info.context.db() as session:
- if not (
- prompt_label := await session.scalar(
- select(models.PromptLabel).where(models.PromptLabel.id == node_id)
- )
- ):
- raise NotFound(f"Unknown prompt label: {id}")
- return to_gql_prompt_label(prompt_label)
+ return PromptLabel(id=node_id)
  elif type_name == PromptVersionTag.__name__:
- async with info.context.db() as session:
- if not (prompt_version_tag := await session.get(models.PromptVersionTag, node_id)):
- raise NotFound(f"Unknown prompt version tag: {id}")
- return to_gql_prompt_version_tag(prompt_version_tag)
+ return PromptVersionTag(id=node_id)
  elif type_name == ProjectTraceRetentionPolicy.__name__:
- async with info.context.db() as session:
- db_policy = await session.scalar(
- select(models.ProjectTraceRetentionPolicy).filter_by(id=node_id)
- )
- if not db_policy:
- raise NotFound(f"Unknown project trace retention policy: {id}")
- return ProjectTraceRetentionPolicy(id=db_policy.id, db_policy=db_policy)
+ return ProjectTraceRetentionPolicy(id=node_id)
  elif type_name == SpanAnnotation.__name__:
- async with info.context.db() as session:
- span_annotation = await session.get(models.SpanAnnotation, node_id)
- if not span_annotation:
- raise NotFound(f"Unknown span annotation: {id}")
- return to_gql_span_annotation(span_annotation)
+ return SpanAnnotation(id=node_id)
  elif type_name == TraceAnnotation.__name__:
- async with info.context.db() as session:
- trace_annotation = await session.get(models.TraceAnnotation, node_id)
- if not trace_annotation:
- raise NotFound(f"Unknown trace annotation: {id}")
- return to_gql_trace_annotation(trace_annotation)
+ return TraceAnnotation(id=node_id)
  elif type_name == GenerativeModel.__name__:
- async with info.context.db() as session:
- stmt = (
- select(models.GenerativeModel)
- .where(models.GenerativeModel.deleted_at.is_(None))
- .where(models.GenerativeModel.id == node_id)
- .options(joinedload(models.GenerativeModel.token_prices))
- )
- model = await session.scalar(stmt)
- if not model:
- raise NotFound(f"Unknown model: {id}")
- return to_gql_generative_model(model)
+ return GenerativeModel(id=node_id)
  raise NotFound(f"Unknown node type: {type_name}")

  @strawberry.field
@@ -1097,16 +963,7 @@ class Query:
  return None
  if isinstance(user, UnauthenticatedUser):
  return None
- async with info.context.db() as session:
- if (
- user := await session.scalar(
- select(models.User)
- .where(models.User.id == int(user.identity))
- .options(joinedload(models.User.role))
- )
- ) is None:
- return None
- return to_gql_user(user)
+ return User(id=int(user.identity))

  @strawberry.field
  async def prompts(
@@ -1146,7 +1003,9 @@ class Query:
  stmt = stmt.distinct()
  async with info.context.db() as session:
  orm_prompts = await session.stream_scalars(stmt)
- data = [to_gql_prompt_from_orm(orm_prompt) async for orm_prompt in orm_prompts]
+ data = [
+ Prompt(id=orm_prompt.id, db_record=orm_prompt) async for orm_prompt in orm_prompts
+ ]
  return connection_from_list(
  data=data,
  args=args,
@@ -1169,7 +1028,10 @@ class Query:
  )
  async with info.context.db() as session:
  prompt_labels = await session.stream_scalars(select(models.PromptLabel))
- data = [to_gql_prompt_label(prompt_label) async for prompt_label in prompt_labels]
+ data = [
+ PromptLabel(id=prompt_label.id, db_record=prompt_label)
+ async for prompt_label in prompt_labels
+ ]
  return connection_from_list(
  data=data,
  args=args,
@@ -1191,8 +1053,13 @@ class Query:
  before=before if isinstance(before, CursorString) else None,
  )
  async with info.context.db() as session:
- dataset_labels = await session.scalars(select(models.DatasetLabel))
- data = [to_gql_dataset_label(dataset_label) for dataset_label in dataset_labels]
+ dataset_labels = await session.scalars(
+ select(models.DatasetLabel).order_by(models.DatasetLabel.name.asc())
+ )
+ data = [
+ DatasetLabel(id=dataset_label.id, db_record=dataset_label)
+ for dataset_label in dataset_labels
+ ]
  return connection_from_list(data=data, args=args)

  @strawberry.field
@@ -1212,7 +1079,7 @@ class Query:
  )
  async with info.context.db() as session:
  splits = await session.stream_scalars(select(models.DatasetSplit))
- data = [to_gql_dataset_split(split) async for split in splits]
+ data = [DatasetSplit(id=split.id, db_record=split) async for split in splits]
  return connection_from_list(
  data=data,
  args=args,
@@ -1483,7 +1350,7 @@ class Query:
  async with info.context.db() as session:
  span_rowid = await session.scalar(stmt)
  if span_rowid:
- return Span(span_rowid=span_rowid)
+ return Span(id=span_rowid)
  return None

  @strawberry.field
@@ -1496,7 +1363,7 @@ class Query:
  async with info.context.db() as session:
  trace_rowid = await session.scalar(stmt)
  if trace_rowid:
- return Trace(trace_rowid=trace_rowid)
+ return Trace(id=trace_rowid)
  return None

  @strawberry.field
@@ -1509,7 +1376,7 @@ class Query:
  async with info.context.db() as session:
  session_row = await session.scalar(stmt)
  if session_row:
- return to_gql_project_session(session_row)
+ return ProjectSession(id=session_row.id, db_record=session_row)
  return None

phoenix/server/api/subscriptions.py

@@ -64,7 +64,7 @@ from phoenix.server.api.types.Dataset import Dataset
  from phoenix.server.api.types.DatasetExample import DatasetExample
  from phoenix.server.api.types.DatasetVersion import DatasetVersion
  from phoenix.server.api.types.Experiment import to_gql_experiment
- from phoenix.server.api.types.ExperimentRun import to_gql_experiment_run
+ from phoenix.server.api.types.ExperimentRun import ExperimentRun
  from phoenix.server.api.types.node import from_global_id_with_expected_type
  from phoenix.server.api.types.Span import Span
  from phoenix.server.daemons.span_cost_calculator import SpanCostCalculator
@@ -194,7 +194,7 @@ class Subscription:
  session.add(span_cost)

  info.context.event_queue.put(SpanInsertEvent(ids=(playground_project_id,)))
- yield ChatCompletionSubscriptionResult(span=Span(span_rowid=db_span.id, db_span=db_span))
+ yield ChatCompletionSubscriptionResult(span=Span(id=db_span.id, db_record=db_span))

  @strawberry.subscription(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
  async def chat_completion_over_dataset(
@@ -528,8 +528,8 @@ async def _chat_completion_result_payloads(
  await session.flush()
  for example_id, span, run in results:
  yield ChatCompletionSubscriptionResult(
- span=Span(span_rowid=span.id, db_span=span) if span else None,
- experiment_run=to_gql_experiment_run(run),
+ span=Span(id=span.id, db_record=span) if span else None,
+ experiment_run=ExperimentRun(id=run.id, db_record=run),
  dataset_example_id=example_id,
  repetition_number=run.repetition_number,
  )
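
Taken together, the queries.py and subscriptions.py hunks show one refactor applied everywhere: resolvers stop materializing GraphQL objects eagerly (the to_gql_* converters and the per-node SELECTs in node() are gone, as is the joinedload of token_prices), and each type is instead constructed from its primary key, with the ORM row passed along as db_record only when it is already in hand. A rough sketch of what that pattern implies for a node type; the dataloader attribute and the field shown here are illustrative assumptions, not the actual code in phoenix/server/api/types/:

# Sketch only, assuming a per-request batching loader on the GraphQL context.
from typing import Any, Optional

import strawberry
from strawberry.types import Info


@strawberry.type
class Project:
    id: strawberry.ID
    # Optional pre-fetched ORM row: list resolvers pass db_record=project, while
    # Query.node constructs Project(id=node_id) and lets fields load lazily.
    db_record: strawberry.Private[Optional[Any]] = None

    @strawberry.field
    async def name(self, info: Info) -> str:
        record = self.db_record
        if record is None:
            # Hypothetical loader name; batching keeps N node lookups to one query.
            record = await info.context.data_loaders.projects.load(int(self.id))
        return str(record.name)

The visible payoff in the diff is that a single node(id:) lookup no longer costs a dedicated SELECT per type, and list fields no longer need eager joins just to build their return values.
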

phoenix/server/api/types/Annotation.py

@@ -1,31 +1,98 @@
  from datetime import datetime
- from typing import Optional
+ from typing import TYPE_CHECKING, Annotated, Optional

  import strawberry
+ from strawberry.scalars import JSON
+ from strawberry.types import Info

- from phoenix.server.api.interceptor import GqlValueMediator
+ from phoenix.server.api.context import Context
+
+ from .AnnotationSource import AnnotationSource
+ from .AnnotatorKind import AnnotatorKind
+
+ if TYPE_CHECKING:
+ from .User import User


  @strawberry.interface
  class Annotation:
- name: str = strawberry.field(
- description="Name of the annotation, e.g. 'helpfulness' or 'relevance'."
- )
- score: Optional[float] = strawberry.field(
- description="Value of the annotation in the form of a numeric score.",
- default=GqlValueMediator(),
- )
- label: Optional[str] = strawberry.field(
- description="Value of the annotation in the form of a string, e.g. "
- "'helpful' or 'not helpful'. Note that the label is not necessarily binary."
- )
- explanation: Optional[str] = strawberry.field(
- description="The annotator's explanation for the annotation result (i.e. "
- "score or label, or both) given to the subject."
- )
- created_at: datetime = strawberry.field(
- description="The date and time when the annotation was created."
- )
- updated_at: datetime = strawberry.field(
- description="The date and time when the annotation was last updated."
- )
+ @strawberry.field(description="Name of the annotation, e.g. 'helpfulness' or 'relevance'.") # type: ignore
+ async def name(
+ self,
+ info: Info[Context, None],
+ ) -> str:
+ raise NotImplementedError
+
+ @strawberry.field(description="The kind of annotator that produced the annotation.") # type: ignore
+ async def annotator_kind(
+ self,
+ info: Info[Context, None],
+ ) -> AnnotatorKind:
+ raise NotImplementedError
+
+ @strawberry.field(
+ description="Value of the annotation in the form of a string, e.g. 'helpful' or 'not helpful'. Note that the label is not necessarily binary." # noqa: E501
+ ) # type: ignore
+ async def label(
+ self,
+ info: Info[Context, None],
+ ) -> Optional[str]:
+ raise NotImplementedError
+
+ @strawberry.field(description="Value of the annotation in the form of a numeric score.") # type: ignore
+ async def score(
+ self,
+ info: Info[Context, None],
+ ) -> Optional[float]:
+ raise NotImplementedError
+
+ @strawberry.field(
+ description="The annotator's explanation for the annotation result (i.e. score or label, or both) given to the subject." # noqa: E501
+ ) # type: ignore
+ async def explanation(
+ self,
+ info: Info[Context, None],
+ ) -> Optional[str]:
+ raise NotImplementedError
+
+ @strawberry.field(description="Metadata about the annotation.") # type: ignore
+ async def metadata(
+ self,
+ info: Info[Context, None],
+ ) -> JSON:
+ raise NotImplementedError
+
+ @strawberry.field(description="The source of the annotation.") # type: ignore
+ async def source(
+ self,
+ info: Info[Context, None],
+ ) -> AnnotationSource:
+ raise NotImplementedError
+
+ @strawberry.field(description="The identifier of the annotation.") # type: ignore
+ async def identifier(
+ self,
+ info: Info[Context, None],
+ ) -> str:
+ raise NotImplementedError
+
+ @strawberry.field(description="The date and time the annotation was created.") # type: ignore
+ async def created_at(
+ self,
+ info: Info[Context, None],
+ ) -> datetime:
+ raise NotImplementedError
+
+ @strawberry.field(description="The date and time the annotation was last updated.") # type: ignore
+ async def updated_at(
+ self,
+ info: Info[Context, None],
+ ) -> datetime:
+ raise NotImplementedError
+
+ @strawberry.field(description="The user that produced the annotation.") # type: ignore
+ async def user(
+ self,
+ info: Info[Context, None],
+ ) -> Optional[Annotated["User", strawberry.lazy(".User")]]:
+ raise NotImplementedError
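
Because Annotation is now an interface of abstract resolver methods rather than plain strawberry fields, every concrete annotation type has to supply its own implementations, which is consistent with the large per-type changes listed above for SpanAnnotation.py, TraceAnnotation.py, and the other annotation modules. A toy sketch of a conforming implementation; the class name and the *_value attributes are hypothetical, loosely echoing the name_value/provider_key_value style seen in the PlaygroundModel change earlier in this diff:

# Sketch only: a hypothetical concrete type overriding two of the interface's
# resolvers; the real concrete types are database-backed and override all of them.
from typing import Optional

import strawberry
from strawberry.types import Info

from phoenix.server.api.types.Annotation import Annotation


@strawberry.type
class StaticAnnotation(Annotation):
    # Private attributes hold the values; the inherited resolver fields expose them.
    name_value: strawberry.Private[str]
    score_value: strawberry.Private[Optional[float]] = None

    @strawberry.field
    async def name(self, info: Info) -> str:
        return self.name_value

    @strawberry.field
    async def score(self, info: Info) -> Optional[float]:
        return self.score_value
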