arize-phoenix 7.12.3__py3-none-any.whl → 8.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (80)
  1. {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.0.dist-info}/METADATA +31 -28
  2. {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.0.dist-info}/RECORD +70 -47
  3. phoenix/db/migrations/versions/bc8fea3c2bc8_add_prompt_tables.py +197 -0
  4. phoenix/db/models.py +307 -0
  5. phoenix/db/types/__init__.py +0 -0
  6. phoenix/db/types/identifier.py +7 -0
  7. phoenix/db/types/model_provider.py +8 -0
  8. phoenix/server/api/context.py +2 -0
  9. phoenix/server/api/dataloaders/__init__.py +2 -0
  10. phoenix/server/api/dataloaders/prompt_version_sequence_number.py +35 -0
  11. phoenix/server/api/helpers/jsonschema.py +135 -0
  12. phoenix/server/api/helpers/playground_clients.py +15 -15
  13. phoenix/server/api/helpers/playground_spans.py +9 -0
  14. phoenix/server/api/helpers/prompts/__init__.py +0 -0
  15. phoenix/server/api/helpers/prompts/conversions/__init__.py +0 -0
  16. phoenix/server/api/helpers/prompts/conversions/anthropic.py +87 -0
  17. phoenix/server/api/helpers/prompts/conversions/openai.py +78 -0
  18. phoenix/server/api/helpers/prompts/models.py +575 -0
  19. phoenix/server/api/input_types/ChatCompletionInput.py +9 -4
  20. phoenix/server/api/input_types/PromptTemplateOptions.py +10 -0
  21. phoenix/server/api/input_types/PromptVersionInput.py +133 -0
  22. phoenix/server/api/mutations/__init__.py +6 -0
  23. phoenix/server/api/mutations/chat_mutations.py +18 -16
  24. phoenix/server/api/mutations/prompt_label_mutations.py +191 -0
  25. phoenix/server/api/mutations/prompt_mutations.py +312 -0
  26. phoenix/server/api/mutations/prompt_version_tag_mutations.py +148 -0
  27. phoenix/server/api/mutations/user_mutations.py +7 -6
  28. phoenix/server/api/openapi/schema.py +1 -0
  29. phoenix/server/api/queries.py +84 -31
  30. phoenix/server/api/routers/oauth2.py +3 -2
  31. phoenix/server/api/routers/v1/__init__.py +2 -0
  32. phoenix/server/api/routers/v1/datasets.py +1 -1
  33. phoenix/server/api/routers/v1/experiment_evaluations.py +1 -1
  34. phoenix/server/api/routers/v1/experiment_runs.py +1 -1
  35. phoenix/server/api/routers/v1/experiments.py +1 -1
  36. phoenix/server/api/routers/v1/models.py +45 -0
  37. phoenix/server/api/routers/v1/prompts.py +412 -0
  38. phoenix/server/api/routers/v1/spans.py +1 -1
  39. phoenix/server/api/routers/v1/traces.py +1 -1
  40. phoenix/server/api/routers/v1/utils.py +1 -1
  41. phoenix/server/api/subscriptions.py +21 -24
  42. phoenix/server/api/types/GenerativeProvider.py +4 -4
  43. phoenix/server/api/types/Identifier.py +15 -0
  44. phoenix/server/api/types/Project.py +5 -7
  45. phoenix/server/api/types/Prompt.py +134 -0
  46. phoenix/server/api/types/PromptLabel.py +41 -0
  47. phoenix/server/api/types/PromptVersion.py +148 -0
  48. phoenix/server/api/types/PromptVersionTag.py +27 -0
  49. phoenix/server/api/types/PromptVersionTemplate.py +148 -0
  50. phoenix/server/api/types/ResponseFormat.py +9 -0
  51. phoenix/server/api/types/ToolDefinition.py +9 -0
  52. phoenix/server/app.py +3 -0
  53. phoenix/server/static/.vite/manifest.json +45 -45
  54. phoenix/server/static/assets/components-B-qgPyHv.js +2699 -0
  55. phoenix/server/static/assets/index-D4KO1IcF.js +1125 -0
  56. phoenix/server/static/assets/pages-DdcuL3Rh.js +5634 -0
  57. phoenix/server/static/assets/vendor-DQp7CrDA.js +894 -0
  58. phoenix/server/static/assets/vendor-arizeai-C1nEIEQq.js +657 -0
  59. phoenix/server/static/assets/vendor-codemirror-BZXYUIkP.js +24 -0
  60. phoenix/server/static/assets/vendor-recharts-BUFpwCVD.js +59 -0
  61. phoenix/server/static/assets/{vendor-shiki-Cl9QBraO.js → vendor-shiki-C8L-c9jT.js} +2 -2
  62. phoenix/server/static/assets/{vendor-three-DwGkEfCM.js → vendor-three-C-AGeJYv.js} +1 -1
  63. phoenix/session/client.py +25 -21
  64. phoenix/utilities/client.py +6 -0
  65. phoenix/version.py +1 -1
  66. phoenix/server/api/input_types/TemplateOptions.py +0 -10
  67. phoenix/server/api/routers/v1/pydantic_compat.py +0 -78
  68. phoenix/server/api/types/TemplateLanguage.py +0 -10
  69. phoenix/server/static/assets/components-DckIzNmE.js +0 -2125
  70. phoenix/server/static/assets/index-Bf25Ogon.js +0 -113
  71. phoenix/server/static/assets/pages-DL7J9q9w.js +0 -4463
  72. phoenix/server/static/assets/vendor-DvC8cT4X.js +0 -894
  73. phoenix/server/static/assets/vendor-arizeai-Do1793cv.js +0 -662
  74. phoenix/server/static/assets/vendor-codemirror-BzwZPyJM.js +0 -24
  75. phoenix/server/static/assets/vendor-recharts-_Jb7JjhG.js +0 -59
  76. {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.0.dist-info}/WHEEL +0 -0
  77. {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.0.dist-info}/entry_points.txt +0 -0
  78. {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.0.dist-info}/licenses/IP_NOTICE +0 -0
  79. {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.0.dist-info}/licenses/LICENSE +0 -0
  80. /phoenix/server/static/assets/{vendor-DxkFTwjz.css → vendor-Cg6lcjUC.css} +0 -0
@@ -13,7 +13,7 @@ class GenerativeProviderKey(Enum):
13
13
  OPENAI = "OpenAI"
14
14
  ANTHROPIC = "Anthropic"
15
15
  AZURE_OPENAI = "Azure OpenAI"
16
- GEMINI = "Google AI Studio"
16
+ GOOGLE = "Google AI Studio"
17
17
 
18
18
 
19
19
  @strawberry.type
@@ -25,21 +25,21 @@ class GenerativeProvider:
25
25
  GenerativeProviderKey.AZURE_OPENAI: [],
26
26
  GenerativeProviderKey.ANTHROPIC: ["claude"],
27
27
  GenerativeProviderKey.OPENAI: ["gpt", "o1"],
28
- GenerativeProviderKey.GEMINI: ["gemini"],
28
+ GenerativeProviderKey.GOOGLE: ["gemini"],
29
29
  }
30
30
 
31
31
  attribute_provider_to_generative_provider_map: ClassVar[dict[str, GenerativeProviderKey]] = {
32
32
  OpenInferenceLLMProviderValues.OPENAI.value: GenerativeProviderKey.OPENAI,
33
33
  OpenInferenceLLMProviderValues.ANTHROPIC.value: GenerativeProviderKey.ANTHROPIC,
34
34
  OpenInferenceLLMProviderValues.AZURE.value: GenerativeProviderKey.AZURE_OPENAI,
35
- OpenInferenceLLMProviderValues.GOOGLE.value: GenerativeProviderKey.GEMINI,
35
+ OpenInferenceLLMProviderValues.GOOGLE.value: GenerativeProviderKey.GOOGLE,
36
36
  }
37
37
 
38
38
  model_provider_to_api_key_env_var_map: ClassVar[dict[GenerativeProviderKey, str]] = {
39
39
  GenerativeProviderKey.AZURE_OPENAI: "AZURE_OPENAI_API_KEY",
40
40
  GenerativeProviderKey.ANTHROPIC: "ANTHROPIC_API_KEY",
41
41
  GenerativeProviderKey.OPENAI: "OPENAI_API_KEY",
42
- GenerativeProviderKey.GEMINI: "GEMINI_API_KEY",
42
+ GenerativeProviderKey.GOOGLE: "GEMINI_API_KEY",
43
43
  }
44
44
 
45
45
  @strawberry.field
@@ -0,0 +1,15 @@
1
+ from typing import NewType
2
+
3
+ import strawberry
4
+
5
+ from phoenix.db.types.identifier import Identifier as IdentifierModel
6
+
7
+
8
def parse_value(value: str) -> str:
    """Validate an incoming GraphQL ``Identifier`` value.

    The raw string is checked with the ``Identifier`` pydantic model imported
    from ``phoenix.db.types.identifier``; any validation error propagates to the
    caller. The validated string (the model's ``root``) is returned.
    """
    validated = IdentifierModel.model_validate(value)
    return validated.root
10
+
11
+
12
# Custom GraphQL scalar backed by `str`. Incoming values are validated by
# `parse_value`; no custom serializer is supplied.
Identifier = strawberry.scalar(NewType("Identifier", str), parse_value=parse_value)
@@ -5,7 +5,7 @@ from typing import Any, ClassVar, Optional
5
5
  import strawberry
6
6
  from aioitertools.itertools import islice
7
7
  from openinference.semconv.trace import SpanAttributes
8
- from sqlalchemy import and_, desc, distinct, func, or_, select
8
+ from sqlalchemy import desc, distinct, func, or_, select
9
9
  from sqlalchemy.orm import contains_eager
10
10
  from sqlalchemy.sql.elements import ColumnElement
11
11
  from sqlalchemy.sql.expression import tuple_
@@ -190,12 +190,10 @@ class Project(Node):
190
190
  .options(contains_eager(models.Span.trace).load_only(models.Trace.trace_id))
191
191
  )
192
192
  if time_range:
193
- stmt = stmt.where(
194
- and_(
195
- time_range.start <= models.Span.start_time,
196
- models.Span.start_time < time_range.end,
197
- )
198
- )
193
+ if time_range.start:
194
+ stmt = stmt.where(time_range.start <= models.Span.start_time)
195
+ if time_range.end:
196
+ stmt = stmt.where(models.Span.start_time < time_range.end)
199
197
  if root_spans_only:
200
198
  # A root span is any span whose parent span is missing in the
201
199
  # database, even if its `parent_span_id` may not be NULL.
@@ -0,0 +1,134 @@
1
+ # Part of the Phoenix PromptHub feature set
2
+ from datetime import datetime
3
+ from typing import Optional
4
+
5
+ import strawberry
6
+ from sqlalchemy import func, select
7
+ from strawberry import UNSET
8
+ from strawberry.relay import Connection, GlobalID, Node, NodeID
9
+ from strawberry.types import Info
10
+
11
+ from phoenix.db import models
12
+ from phoenix.server.api.context import Context
13
+ from phoenix.server.api.exceptions import NotFound
14
+ from phoenix.server.api.types.Identifier import Identifier
15
+ from phoenix.server.api.types.node import from_global_id_with_expected_type
16
+ from phoenix.server.api.types.pagination import (
17
+ ConnectionArgs,
18
+ CursorString,
19
+ connection_from_list,
20
+ )
21
+
22
+ from .PromptVersion import (
23
+ PromptVersion,
24
+ to_gql_prompt_version,
25
+ )
26
+ from .PromptVersionTag import PromptVersionTag, to_gql_prompt_version_tag
27
+
28
+
29
# GraphQL node for a prompt. Versions, tags and the source prompt are resolved
# lazily from the database.
@strawberry.type
class Prompt(Node):
    id_attr: NodeID[int]
    source_prompt_id: Optional[GlobalID]
    name: Identifier
    description: Optional[str]
    created_at: datetime

    @strawberry.field
    async def version(
        self, info: Info[Context, None], version_id: Optional[GlobalID] = None
    ) -> PromptVersion:
        """Return one version of this prompt: the requested one, or else the latest.

        Args:
            version_id: optional ``PromptVersion`` global ID that must belong to
                this prompt. When omitted, the version with the highest id is used.

        Raises:
            NotFound: if ``version_id`` does not match a version of this prompt,
                or if the prompt has no versions at all.
        """
        async with info.context.db() as session:
            if not version_id:
                latest = await session.scalar(
                    select(models.PromptVersion)
                    .where(models.PromptVersion.prompt_id == self.id_attr)
                    .order_by(models.PromptVersion.id.desc())
                    .limit(1)
                )
                if not latest:
                    raise NotFound("This prompt has no associated versions")
                return to_gql_prompt_version(latest)

            requested_id = from_global_id_with_expected_type(version_id, PromptVersion.__name__)
            # Scope the lookup to this prompt so a foreign version id is rejected.
            requested = await session.scalar(
                select(models.PromptVersion).where(
                    models.PromptVersion.id == requested_id,
                    models.PromptVersion.prompt_id == self.id_attr,
                )
            )
            if not requested:
                raise NotFound(f"Prompt version not found: {version_id}")
            return to_gql_prompt_version(requested)

    @strawberry.field
    async def version_tags(self, info: Info[Context, None]) -> list[PromptVersionTag]:
        """Return every version tag attached to this prompt."""
        query = select(models.PromptVersionTag).where(
            models.PromptVersionTag.prompt_id == self.id_attr
        )
        async with info.context.db() as session:
            tag_rows = await session.stream_scalars(query)
            return [to_gql_prompt_version_tag(row) async for row in tag_rows]

    @strawberry.field
    async def prompt_versions(
        self,
        info: Info[Context, None],
        first: Optional[int] = 50,
        last: Optional[int] = UNSET,
        after: Optional[CursorString] = UNSET,
        before: Optional[CursorString] = UNSET,
    ) -> Connection[PromptVersion]:
        """Return a page of this prompt's versions, newest first.

        Each version carries its 1-based sequence number, counted in ascending
        id order within this prompt.
        """
        connection_args = ConnectionArgs(
            first=first,
            after=after if isinstance(after, CursorString) else None,
            last=last,
            before=before if isinstance(before, CursorString) else None,
        )
        # The window numbers rows by ascending id; the result itself is ordered
        # newest-first.
        seq = func.row_number().over(order_by=models.PromptVersion.id).label("row_number")
        query = (
            select(models.PromptVersion, seq)
            .where(models.PromptVersion.prompt_id == self.id_attr)
            .order_by(models.PromptVersion.id.desc())
        )
        async with info.context.db() as session:
            rows = await session.stream(query)
            versions = [to_gql_prompt_version(pv, seq_no) async for pv, seq_no in rows]
        return connection_from_list(data=versions, args=connection_args)

    @strawberry.field
    async def source_prompt(self, info: Info[Context, None]) -> Optional["Prompt"]:
        """Return the prompt this one was derived from, or None if it has no source.

        Raises:
            NotFound: if a source id is recorded but no such prompt exists.
        """
        if not self.source_prompt_id:
            return None
        parent_id = from_global_id_with_expected_type(
            global_id=self.source_prompt_id, expected_type_name=Prompt.__name__
        )
        async with info.context.db() as session:
            parent = await session.scalar(
                select(models.Prompt).where(models.Prompt.id == parent_id)
            )
            if not parent:
                raise NotFound(f"Source prompt not found: {self.source_prompt_id}")
            return to_gql_prompt_from_orm(parent)
118
+
119
+
120
def to_gql_prompt_from_orm(orm_model: "models.Prompt") -> Prompt:
    """Convert a ``models.Prompt`` ORM row into its GraphQL ``Prompt`` node.

    A falsy ``source_prompt_id`` maps to ``None``. Otherwise it is wrapped in a
    ``Prompt`` global ID. The name is unwrapped from its ``root`` attribute.
    """
    parent_gid = (
        GlobalID(Prompt.__name__, str(orm_model.source_prompt_id))
        if orm_model.source_prompt_id
        else None
    )
    return Prompt(
        id_attr=orm_model.id,
        source_prompt_id=parent_gid,
        name=Identifier(orm_model.name.root),
        description=orm_model.description,
        created_at=orm_model.created_at,
    )
@@ -0,0 +1,41 @@
1
+ from typing import Optional
2
+
3
+ import strawberry
4
+ from sqlalchemy import select
5
+ from strawberry.relay import Node, NodeID
6
+ from strawberry.types import Info
7
+
8
+ from phoenix.db import models
9
+ from phoenix.server.api.context import Context
10
+ from phoenix.server.api.types.Identifier import Identifier
11
+ from phoenix.server.api.types.Prompt import Prompt, to_gql_prompt_from_orm
12
+
13
+
14
# GraphQL node for a label that can be attached to prompts.
@strawberry.type
class PromptLabel(Node):
    id_attr: NodeID[int]
    name: Identifier
    description: Optional[str] = None

    @strawberry.field
    async def prompts(self, info: Info[Context, None]) -> list[Prompt]:
        """Return every prompt linked to this label through ``PromptPromptLabel``."""
        labeled_prompts = (
            select(models.Prompt)
            .join(
                models.PromptPromptLabel,
                models.Prompt.id == models.PromptPromptLabel.prompt_id,
            )
            .where(models.PromptPromptLabel.prompt_label_id == self.id_attr)
        )
        async with info.context.db() as session:
            result = await session.stream_scalars(labeled_prompts)
            return [to_gql_prompt_from_orm(row) async for row in result]
34
+
35
+
36
def to_gql_prompt_label(label_orm: models.PromptLabel) -> PromptLabel:
    """Convert a ``models.PromptLabel`` ORM row into a GraphQL ``PromptLabel``."""
    # NOTE(review): the Prompt and PromptVersionTag converters unwrap `name.root`,
    # but the label name is passed through as-is. Presumably the label column
    # stores a plain str rather than an Identifier model; confirm against
    # phoenix.db.models.
    label_name = label_orm.name
    return PromptLabel(
        id_attr=label_orm.id,
        name=Identifier(label_name),
        description=label_orm.description,
    )
@@ -0,0 +1,148 @@
1
+ from datetime import datetime
2
+ from typing import Optional
3
+
4
+ import strawberry
5
+ from sqlalchemy import select
6
+ from strawberry import Private
7
+ from strawberry.relay import Node, NodeID
8
+ from strawberry.scalars import JSON
9
+ from strawberry.types import Info
10
+
11
+ from phoenix.db import models
12
+ from phoenix.db.types.model_provider import ModelProvider
13
+ from phoenix.server.api.context import Context
14
+ from phoenix.server.api.helpers.prompts.models import (
15
+ PromptTemplateFormat,
16
+ PromptTemplateType,
17
+ denormalize_response_format,
18
+ denormalize_tools,
19
+ get_raw_invocation_parameters,
20
+ )
21
+ from phoenix.server.api.types.PromptVersionTag import PromptVersionTag, to_gql_prompt_version_tag
22
+ from phoenix.server.api.types.PromptVersionTemplate import (
23
+ PromptTemplate,
24
+ to_gql_template_from_orm,
25
+ )
26
+
27
+ from .ResponseFormat import ResponseFormat
28
+ from .ToolDefinition import ToolDefinition
29
+ from .User import User, to_gql_user
30
+
31
+
32
# GraphQL node for a single prompt version. Scalar fields are filled in by
# `to_gql_prompt_version`; tags, user, previous_version and sequence_number are
# resolved on demand.
@strawberry.type
class PromptVersion(Node):
    id_attr: NodeID[int]
    # Private (not exposed in the schema); read by the `user` resolver.
    user_id: strawberry.Private[Optional[int]]
    description: Optional[str]
    template_type: PromptTemplateType
    template_format: PromptTemplateFormat
    template: PromptTemplate
    invocation_parameters: Optional[JSON] = None
    tools: list[ToolDefinition]
    response_format: Optional[ResponseFormat] = None
    model_name: str
    model_provider: ModelProvider
    metadata: JSON
    created_at: datetime
    # Private: sequence number supplied by the caller when already known. Lets
    # `sequence_number` skip the data loader; that resolver also fills it in.
    cached_sequence_number: Private[Optional[int]] = None

    @strawberry.field
    async def tags(self, info: Info[Context, None]) -> list[PromptVersionTag]:
        """Return every tag that points at this prompt version."""
        async with info.context.db() as session:
            stmt = select(models.PromptVersionTag).where(
                models.PromptVersionTag.prompt_version_id == self.id_attr
            )
            return [
                to_gql_prompt_version_tag(tag) async for tag in await session.stream_scalars(stmt)
            ]

    @strawberry.field
    async def user(self, info: Info[Context, None]) -> Optional[User]:
        """Return the user recorded on this version, or None if none is recorded/found."""
        if self.user_id is None:
            return None
        async with info.context.db() as session:
            user = await session.get(models.User, self.user_id)
            return to_gql_user(user) if user is not None else None

    @strawberry.field
    async def previous_version(self, info: Info[Context, None]) -> Optional["PromptVersion"]:
        """Return the version of the same prompt that immediately precedes this one.

        "Immediately precedes" means the highest id below this version's id. The
        query orders by ``id`` rather than ``created_at``, which has two benefits.
        It agrees with the id filter and with the id-based ordering used for the
        latest version and for sequence numbers. It also stays unambiguous when
        two versions share a timestamp.

        Returns None if this version no longer exists or is the first version.
        """
        async with info.context.db() as session:
            current_version = await session.get(models.PromptVersion, self.id_attr)
            if current_version is None:
                return None

            stmt = (
                select(models.PromptVersion)
                .where(models.PromptVersion.prompt_id == current_version.prompt_id)
                .where(models.PromptVersion.id < self.id_attr)
                .order_by(models.PromptVersion.id.desc())
                .limit(1)
            )
            previous_version = await session.scalar(stmt)
            if previous_version is None:
                return None
            return to_gql_prompt_version(prompt_version=previous_version)

    @strawberry.field(
        description="Sequence number (1-based) of prompt versions belonging to the same prompt"
    )  # type: ignore
    async def sequence_number(
        self,
        info: Info[Context, None],
    ) -> int:
        """Return the cached sequence number, loading and caching it on first use.

        Raises:
            ValueError: if the data loader has no sequence number for this version.
        """
        if self.cached_sequence_number is None:
            seq_num = await info.context.data_loaders.prompt_version_sequence_number.load(
                self.id_attr
            )
            if seq_num is None:
                raise ValueError(f"invalid prompt version: id={self.id_attr}")
            self.cached_sequence_number = seq_num
        return self.cached_sequence_number
104
+
105
+
106
def to_gql_prompt_version(
    prompt_version: models.PromptVersion, sequence_number: Optional[int] = None
) -> PromptVersion:
    """Convert a ``models.PromptVersion`` row into a GraphQL ``PromptVersion``.

    Stored tools and the response format are denormalized for the row's model
    provider. If tool denormalization yields a tool choice, it is written into
    the returned invocation parameters under ``"tool_choice"``.

    Args:
        prompt_version: the ORM row to convert.
        sequence_number: optional precomputed 1-based sequence number, cached on
            the result so the ``sequence_number`` resolver can skip its lookup.
    """
    provider = prompt_version.model_provider
    template_type = PromptTemplateType(prompt_version.template_type)
    template = to_gql_template_from_orm(prompt_version)
    template_format = PromptTemplateFormat(prompt_version.template_format)

    tools: list[ToolDefinition] = []
    tool_choice = None
    if prompt_version.tools is not None:
        schemas, tool_choice = denormalize_tools(prompt_version.tools, provider)
        tools = [ToolDefinition(definition=schema) for schema in schemas]

    response_format: Optional[ResponseFormat] = None
    if prompt_version.response_format is not None:
        response_format = ResponseFormat(
            definition=denormalize_response_format(prompt_version.response_format, provider)
        )

    # NOTE(review): the dict returned here is mutated below; this assumes the
    # helper returns a fresh dict rather than one shared with the ORM row.
    invocation_parameters = get_raw_invocation_parameters(prompt_version.invocation_parameters)
    if tool_choice is not None:
        invocation_parameters["tool_choice"] = tool_choice

    return PromptVersion(
        id_attr=prompt_version.id,
        user_id=prompt_version.user_id,
        description=prompt_version.description,
        template_type=template_type,
        template_format=template_format,
        template=template,
        invocation_parameters=invocation_parameters,
        tools=tools,
        response_format=response_format,
        model_name=prompt_version.model_name,
        model_provider=prompt_version.model_provider,
        metadata=prompt_version.metadata_,
        created_at=prompt_version.created_at,
        cached_sequence_number=sequence_number,
    )
@@ -0,0 +1,27 @@
1
+ from typing import Optional
2
+
3
+ import strawberry
4
+ from strawberry.relay import GlobalID, Node, NodeID
5
+
6
+ from phoenix.db import models
7
+ from phoenix.server.api.types.Identifier import Identifier
8
+
9
+
10
# GraphQL node for a named tag attached to one prompt version.
@strawberry.type
class PromptVersionTag(Node):
    id_attr: NodeID[int]
    # Global ID of the tagged PromptVersion.
    prompt_version_id: GlobalID
    name: Identifier
    description: Optional[str] = None
16
+
17
+
18
def to_gql_prompt_version_tag(prompt_version_tag: models.PromptVersionTag) -> PromptVersionTag:
    """Convert a ``models.PromptVersionTag`` row into a GraphQL ``PromptVersionTag``.

    The referenced prompt version id is exposed as a ``PromptVersion`` global ID,
    and the tag name is unwrapped from its ``root`` attribute.
    """
    # Imported at call time: the PromptVersion module imports this module, so a
    # top-level import here would be circular.
    from phoenix.server.api.types.PromptVersion import PromptVersion

    tagged_version = GlobalID(PromptVersion.__name__, str(prompt_version_tag.prompt_version_id))
    return PromptVersionTag(
        id_attr=prompt_version_tag.id,
        prompt_version_id=tagged_version,
        name=Identifier(prompt_version_tag.name.root),
        description=prompt_version_tag.description,
    )
@@ -0,0 +1,148 @@
1
+ # Part of the Phoenix PromptHub feature set
2
+ import json
3
+ from typing import Annotated, Union
4
+
5
+ import strawberry
6
+ from strawberry.scalars import JSON
7
+ from typing_extensions import TypeAlias, assert_never
8
+
9
+ from phoenix.db.models import PromptVersion as ORMPromptVersion
10
+ from phoenix.server.api.helpers.prompts.models import (
11
+ PromptChatTemplate as PromptChatTemplateModel,
12
+ )
13
+ from phoenix.server.api.helpers.prompts.models import PromptMessage as PromptMessageModel
14
+ from phoenix.server.api.helpers.prompts.models import (
15
+ PromptMessageRole,
16
+ PromptTemplateType,
17
+ RoleConversion,
18
+ )
19
+ from phoenix.server.api.helpers.prompts.models import (
20
+ PromptStringTemplate as PromptStringTemplateModel,
21
+ )
22
+
23
+
24
# GraphQL content-part types for chat prompt messages. Each `*ContentPart`
# wraps a `*ContentValue` payload; `ContentPart` is the union of the part kinds.


# Payload of a plain-text part.
@strawberry.type
class TextContentValue:
    text: str


# Message part holding text.
@strawberry.type
class TextContentPart:
    text: TextContentValue


# Function name and argument string of a tool call.
@strawberry.type
class ToolCallFunction:
    name: str
    arguments: str


# Payload of a tool-call part: the call id plus the invoked function.
@strawberry.type
class ToolCallContentValue:
    tool_call_id: str
    tool_call: ToolCallFunction


# Message part describing a tool call.
@strawberry.type
class ToolCallContentPart:
    tool_call: ToolCallContentValue


# Payload of a tool-result part: the id of the originating call and its result.
@strawberry.type
class ToolResultContentValue:
    tool_call_id: str
    result: JSON


# Message part carrying a tool result.
@strawberry.type
class ToolResultContentPart:
    tool_result: ToolResultContentValue


# GraphQL union of every content-part kind a message may contain.
ContentPart: TypeAlias = Annotated[
    Union[TextContentPart, ToolCallContentPart, ToolResultContentPart],
    strawberry.union("ContentPart"),
]
+ ]
66
+
67
+
68
# One chat-template message: a role plus an ordered list of content parts.
@strawberry.type
class PromptMessage:
    role: PromptMessageRole
    content: list[ContentPart]
72
+
73
+
74
# GraphQL counterpart of the pydantic chat-template model; its messages are
# exposed using the GraphQL `PromptMessage` type.
@strawberry.experimental.pydantic.type(PromptChatTemplateModel)
class PromptChatTemplate:
    messages: list[PromptMessage]
77
+
78
+
79
def to_gql_prompt_chat_template_from_orm(orm_model: "ORMPromptVersion") -> "PromptChatTemplate":
    """Build a GraphQL ``PromptChatTemplate`` from a chat-type ORM prompt version.

    The stored ``template`` payload is validated as a ``PromptChatTemplateModel``.
    Each message's role is converted with ``RoleConversion.to_gql``. Message
    content is mapped part by part onto the GraphQL content-part types, and a
    bare-string content becomes a single text part.

    Raises:
        AssertionError: via ``assert_never`` for an unrecognized message or part.
    """

    def convert_part(part) -> ContentPart:
        # Map one pydantic content part onto its GraphQL counterpart.
        if part.type == "text":
            return TextContentPart(text=TextContentValue(text=part.text))
        if part.type == "tool_call":
            call = ToolCallFunction(
                name=part.tool_call.name,
                arguments=part.tool_call.arguments,
            )
            return ToolCallContentPart(
                tool_call=ToolCallContentValue(tool_call_id=part.tool_call_id, tool_call=call)
            )
        if part.type == "tool_result":
            # NOTE(review): `result` is typed JSON but receives a JSON-encoded
            # string, so clients see a string value. Confirm the frontend expects
            # this before changing it.
            return ToolResultContentPart(
                tool_result=ToolResultContentValue(
                    tool_call_id=part.tool_call_id,
                    result=json.dumps(part.tool_result),
                )
            )
        assert_never(part)

    chat_template = PromptChatTemplateModel.model_validate(orm_model.template)
    gql_messages: list[PromptMessage] = []
    for message in chat_template.messages:
        gql_role = RoleConversion.to_gql(message.role)
        if not isinstance(message, PromptMessageModel):
            assert_never(message)
        if isinstance(message.content, str):
            parts: list[ContentPart] = [
                TextContentPart(text=TextContentValue(text=message.content))
            ]
        else:
            parts = [convert_part(p) for p in message.content]
        gql_messages.append(PromptMessage(role=gql_role, content=parts))
    return PromptChatTemplate(messages=gql_messages)
124
+
125
+
126
# GraphQL counterpart of the pydantic string-template model; the type of
# `template` is taken from the model via `strawberry.auto`.
@strawberry.experimental.pydantic.type(PromptStringTemplateModel)
class PromptStringTemplate:
    template: strawberry.auto
129
+
130
+
131
def to_gql_prompt_string_template_from_orm(orm_model: "ORMPromptVersion") -> "PromptStringTemplate":
    """Build a GraphQL ``PromptStringTemplate`` from a string-type ORM prompt version.

    The stored ``template`` payload is validated as a ``PromptStringTemplateModel``
    and its ``template`` string is copied into the GraphQL type.
    """
    validated = PromptStringTemplateModel.model_validate(orm_model.template)
    template_text = validated.template
    return PromptStringTemplate(template=template_text)
134
+
135
+
136
def to_gql_template_from_orm(orm_prompt_version: "ORMPromptVersion") -> "PromptTemplate":
    """Build the GraphQL template matching the row's ``template_type``.

    Chat templates become ``PromptChatTemplate`` and string templates become
    ``PromptStringTemplate``. Any other value fails in ``assert_never``.
    """
    kind = PromptTemplateType(orm_prompt_version.template_type)
    if kind is PromptTemplateType.CHAT:
        return to_gql_prompt_chat_template_from_orm(orm_prompt_version)
    if kind is PromptTemplateType.STRING:
        return to_gql_prompt_string_template_from_orm(orm_prompt_version)
    assert_never(kind)
143
+
144
+
145
# GraphQL union of the two template shapes a prompt version can hold.
PromptTemplate: TypeAlias = Annotated[
    Union[PromptStringTemplate, PromptChatTemplate],
    strawberry.union("PromptTemplate"),
]
@@ -0,0 +1,9 @@
1
+ import strawberry
2
+ from strawberry.scalars import JSON
3
+
4
+
5
@strawberry.type
class ResponseFormat:
    """A JSON schema definition used to guide an LLM's output"""

    # Raw JSON payload of the response-format definition.
    definition: JSON
@@ -0,0 +1,9 @@
1
+ import strawberry
2
+ from strawberry.scalars import JSON
3
+
4
+
5
@strawberry.type
class ToolDefinition:
    """The definition of a tool that a generative model can invoke."""

    # Raw JSON payload of the tool definition.
    definition: JSON
phoenix/server/app.py CHANGED
@@ -86,6 +86,7 @@ from phoenix.server.api.dataloaders import (
86
86
  LatencyMsQuantileDataLoader,
87
87
  MinStartOrMaxEndTimeDataLoader,
88
88
  ProjectByNameDataLoader,
89
+ PromptVersionSequenceNumberDataLoader,
89
90
  RecordCountDataLoader,
90
91
  SessionIODataLoader,
91
92
  SessionNumTracesDataLoader,
@@ -611,6 +612,7 @@ def create_graphql_router(
611
612
  else None
612
613
  ),
613
614
  ),
615
+ prompt_version_sequence_number=PromptVersionSequenceNumberDataLoader(db),
614
616
  record_counts=RecordCountDataLoader(
615
617
  db,
616
618
  cache_map=cache_for_dataloaders.record_count if cache_for_dataloaders else None,
@@ -914,6 +916,7 @@ def create_app(
914
916
  ),
915
917
  name="static",
916
918
  )
919
+ app.state.authentication_enabled = authentication_enabled
917
920
  app.state.read_only = read_only
918
921
  app.state.export_path = export_path
919
922
  app.state.password_reset_token_expiry = password_reset_token_expiry