arize-phoenix 7.12.3__py3-none-any.whl → 8.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.1.dist-info}/METADATA +31 -28
- {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.1.dist-info}/RECORD +70 -47
- phoenix/db/migrations/versions/bc8fea3c2bc8_add_prompt_tables.py +197 -0
- phoenix/db/models.py +307 -0
- phoenix/db/types/__init__.py +0 -0
- phoenix/db/types/identifier.py +7 -0
- phoenix/db/types/model_provider.py +8 -0
- phoenix/server/api/context.py +2 -0
- phoenix/server/api/dataloaders/__init__.py +2 -0
- phoenix/server/api/dataloaders/prompt_version_sequence_number.py +35 -0
- phoenix/server/api/helpers/jsonschema.py +135 -0
- phoenix/server/api/helpers/playground_clients.py +15 -15
- phoenix/server/api/helpers/playground_spans.py +9 -0
- phoenix/server/api/helpers/prompts/__init__.py +0 -0
- phoenix/server/api/helpers/prompts/conversions/__init__.py +0 -0
- phoenix/server/api/helpers/prompts/conversions/anthropic.py +87 -0
- phoenix/server/api/helpers/prompts/conversions/openai.py +78 -0
- phoenix/server/api/helpers/prompts/models.py +575 -0
- phoenix/server/api/input_types/ChatCompletionInput.py +9 -4
- phoenix/server/api/input_types/PromptTemplateOptions.py +10 -0
- phoenix/server/api/input_types/PromptVersionInput.py +133 -0
- phoenix/server/api/mutations/__init__.py +6 -0
- phoenix/server/api/mutations/chat_mutations.py +18 -16
- phoenix/server/api/mutations/prompt_label_mutations.py +191 -0
- phoenix/server/api/mutations/prompt_mutations.py +312 -0
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +148 -0
- phoenix/server/api/mutations/user_mutations.py +7 -6
- phoenix/server/api/openapi/schema.py +1 -0
- phoenix/server/api/queries.py +84 -31
- phoenix/server/api/routers/oauth2.py +3 -2
- phoenix/server/api/routers/v1/__init__.py +2 -0
- phoenix/server/api/routers/v1/datasets.py +1 -1
- phoenix/server/api/routers/v1/experiment_evaluations.py +1 -1
- phoenix/server/api/routers/v1/experiment_runs.py +1 -1
- phoenix/server/api/routers/v1/experiments.py +1 -1
- phoenix/server/api/routers/v1/models.py +45 -0
- phoenix/server/api/routers/v1/prompts.py +415 -0
- phoenix/server/api/routers/v1/spans.py +1 -1
- phoenix/server/api/routers/v1/traces.py +1 -1
- phoenix/server/api/routers/v1/utils.py +1 -1
- phoenix/server/api/subscriptions.py +21 -24
- phoenix/server/api/types/GenerativeProvider.py +4 -4
- phoenix/server/api/types/Identifier.py +15 -0
- phoenix/server/api/types/Project.py +5 -7
- phoenix/server/api/types/Prompt.py +134 -0
- phoenix/server/api/types/PromptLabel.py +41 -0
- phoenix/server/api/types/PromptVersion.py +148 -0
- phoenix/server/api/types/PromptVersionTag.py +27 -0
- phoenix/server/api/types/PromptVersionTemplate.py +148 -0
- phoenix/server/api/types/ResponseFormat.py +9 -0
- phoenix/server/api/types/ToolDefinition.py +9 -0
- phoenix/server/app.py +3 -0
- phoenix/server/static/.vite/manifest.json +45 -45
- phoenix/server/static/assets/components-B-qgPyHv.js +2699 -0
- phoenix/server/static/assets/index-D4KO1IcF.js +1125 -0
- phoenix/server/static/assets/pages-DdcuL3Rh.js +5634 -0
- phoenix/server/static/assets/vendor-DQp7CrDA.js +894 -0
- phoenix/server/static/assets/vendor-arizeai-C1nEIEQq.js +657 -0
- phoenix/server/static/assets/vendor-codemirror-BZXYUIkP.js +24 -0
- phoenix/server/static/assets/vendor-recharts-BUFpwCVD.js +59 -0
- phoenix/server/static/assets/{vendor-shiki-Cl9QBraO.js → vendor-shiki-C8L-c9jT.js} +2 -2
- phoenix/server/static/assets/{vendor-three-DwGkEfCM.js → vendor-three-C-AGeJYv.js} +1 -1
- phoenix/session/client.py +25 -21
- phoenix/utilities/client.py +6 -0
- phoenix/version.py +1 -1
- phoenix/server/api/input_types/TemplateOptions.py +0 -10
- phoenix/server/api/routers/v1/pydantic_compat.py +0 -78
- phoenix/server/api/types/TemplateLanguage.py +0 -10
- phoenix/server/static/assets/components-DckIzNmE.js +0 -2125
- phoenix/server/static/assets/index-Bf25Ogon.js +0 -113
- phoenix/server/static/assets/pages-DL7J9q9w.js +0 -4463
- phoenix/server/static/assets/vendor-DvC8cT4X.js +0 -894
- phoenix/server/static/assets/vendor-arizeai-Do1793cv.js +0 -662
- phoenix/server/static/assets/vendor-codemirror-BzwZPyJM.js +0 -24
- phoenix/server/static/assets/vendor-recharts-_Jb7JjhG.js +0 -59
- {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.1.dist-info}/WHEEL +0 -0
- {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.1.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.1.dist-info}/licenses/LICENSE +0 -0
- /phoenix/server/static/assets/{vendor-DxkFTwjz.css → vendor-Cg6lcjUC.css} +0 -0
|
@@ -13,7 +13,7 @@ class GenerativeProviderKey(Enum):
|
|
|
13
13
|
OPENAI = "OpenAI"
|
|
14
14
|
ANTHROPIC = "Anthropic"
|
|
15
15
|
AZURE_OPENAI = "Azure OpenAI"
|
|
16
|
-
|
|
16
|
+
GOOGLE = "Google AI Studio"
|
|
17
17
|
|
|
18
18
|
|
|
19
19
|
@strawberry.type
|
|
@@ -25,21 +25,21 @@ class GenerativeProvider:
|
|
|
25
25
|
GenerativeProviderKey.AZURE_OPENAI: [],
|
|
26
26
|
GenerativeProviderKey.ANTHROPIC: ["claude"],
|
|
27
27
|
GenerativeProviderKey.OPENAI: ["gpt", "o1"],
|
|
28
|
-
GenerativeProviderKey.
|
|
28
|
+
GenerativeProviderKey.GOOGLE: ["gemini"],
|
|
29
29
|
}
|
|
30
30
|
|
|
31
31
|
attribute_provider_to_generative_provider_map: ClassVar[dict[str, GenerativeProviderKey]] = {
|
|
32
32
|
OpenInferenceLLMProviderValues.OPENAI.value: GenerativeProviderKey.OPENAI,
|
|
33
33
|
OpenInferenceLLMProviderValues.ANTHROPIC.value: GenerativeProviderKey.ANTHROPIC,
|
|
34
34
|
OpenInferenceLLMProviderValues.AZURE.value: GenerativeProviderKey.AZURE_OPENAI,
|
|
35
|
-
OpenInferenceLLMProviderValues.GOOGLE.value: GenerativeProviderKey.
|
|
35
|
+
OpenInferenceLLMProviderValues.GOOGLE.value: GenerativeProviderKey.GOOGLE,
|
|
36
36
|
}
|
|
37
37
|
|
|
38
38
|
model_provider_to_api_key_env_var_map: ClassVar[dict[GenerativeProviderKey, str]] = {
|
|
39
39
|
GenerativeProviderKey.AZURE_OPENAI: "AZURE_OPENAI_API_KEY",
|
|
40
40
|
GenerativeProviderKey.ANTHROPIC: "ANTHROPIC_API_KEY",
|
|
41
41
|
GenerativeProviderKey.OPENAI: "OPENAI_API_KEY",
|
|
42
|
-
GenerativeProviderKey.
|
|
42
|
+
GenerativeProviderKey.GOOGLE: "GEMINI_API_KEY",
|
|
43
43
|
}
|
|
44
44
|
|
|
45
45
|
@strawberry.field
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from typing import NewType
|
|
2
|
+
|
|
3
|
+
import strawberry
|
|
4
|
+
|
|
5
|
+
from phoenix.db.types.identifier import Identifier as IdentifierModel
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def parse_value(value: str) -> str:
    """Parse an incoming GraphQL ``Identifier`` value.

    The raw string is validated through the database-layer ``Identifier``
    pydantic model (``model_validate``), and the validated root string is
    returned. Invalid input propagates the model's validation error.
    """
    validated = IdentifierModel.model_validate(value)
    return validated.root


# GraphQL scalar backed by ``str``. Incoming values are checked by
# ``parse_value``; no custom ``serialize`` is given, so outgoing values use
# strawberry's default scalar serialization.
Identifier = strawberry.scalar(
    NewType("Identifier", str),
    parse_value=parse_value,
)
|
|
@@ -5,7 +5,7 @@ from typing import Any, ClassVar, Optional
|
|
|
5
5
|
import strawberry
|
|
6
6
|
from aioitertools.itertools import islice
|
|
7
7
|
from openinference.semconv.trace import SpanAttributes
|
|
8
|
-
from sqlalchemy import
|
|
8
|
+
from sqlalchemy import desc, distinct, func, or_, select
|
|
9
9
|
from sqlalchemy.orm import contains_eager
|
|
10
10
|
from sqlalchemy.sql.elements import ColumnElement
|
|
11
11
|
from sqlalchemy.sql.expression import tuple_
|
|
@@ -190,12 +190,10 @@ class Project(Node):
|
|
|
190
190
|
.options(contains_eager(models.Span.trace).load_only(models.Trace.trace_id))
|
|
191
191
|
)
|
|
192
192
|
if time_range:
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
)
|
|
198
|
-
)
|
|
193
|
+
if time_range.start:
|
|
194
|
+
stmt = stmt.where(time_range.start <= models.Span.start_time)
|
|
195
|
+
if time_range.end:
|
|
196
|
+
stmt = stmt.where(models.Span.start_time < time_range.end)
|
|
199
197
|
if root_spans_only:
|
|
200
198
|
# A root span is any span whose parent span is missing in the
|
|
201
199
|
# database, even if its `parent_span_id` may not be NULL.
|
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
# Part of the Phoenix PromptHub feature set
|
|
2
|
+
from datetime import datetime
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
import strawberry
|
|
6
|
+
from sqlalchemy import func, select
|
|
7
|
+
from strawberry import UNSET
|
|
8
|
+
from strawberry.relay import Connection, GlobalID, Node, NodeID
|
|
9
|
+
from strawberry.types import Info
|
|
10
|
+
|
|
11
|
+
from phoenix.db import models
|
|
12
|
+
from phoenix.server.api.context import Context
|
|
13
|
+
from phoenix.server.api.exceptions import NotFound
|
|
14
|
+
from phoenix.server.api.types.Identifier import Identifier
|
|
15
|
+
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
16
|
+
from phoenix.server.api.types.pagination import (
|
|
17
|
+
ConnectionArgs,
|
|
18
|
+
CursorString,
|
|
19
|
+
connection_from_list,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
from .PromptVersion import (
|
|
23
|
+
PromptVersion,
|
|
24
|
+
to_gql_prompt_version,
|
|
25
|
+
)
|
|
26
|
+
from .PromptVersionTag import PromptVersionTag, to_gql_prompt_version_tag
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@strawberry.type
class Prompt(Node):
    """GraphQL node for a stored prompt.

    Scalar attributes are copied from the ORM row when the node is built
    (see ``to_gql_prompt_from_orm``). Versions, version tags and the source
    prompt are fetched from the database by the resolver fields below.
    """

    id_attr: NodeID[int]
    # Relay GlobalID of the prompt this one was derived from, if any.
    source_prompt_id: Optional[GlobalID]
    name: Identifier
    description: Optional[str]
    created_at: datetime

    @strawberry.field
    async def version(
        self, info: Info[Context, None], version_id: Optional[GlobalID] = None
    ) -> PromptVersion:
        """Return a specific version of this prompt, or its latest version.

        When ``version_id`` is supplied, the version must also belong to this
        prompt. Otherwise the version with the highest primary key is returned.

        Raises:
            NotFound: if the requested version is not found for this prompt, or
                if the prompt has no versions at all.
        """
        async with info.context.db() as session:
            if version_id:
                wanted_id = from_global_id_with_expected_type(version_id, PromptVersion.__name__)
                query = select(models.PromptVersion).where(
                    models.PromptVersion.id == wanted_id,
                    models.PromptVersion.prompt_id == self.id_attr,
                )
                not_found_message = f"Prompt version not found: {version_id}"
            else:
                query = (
                    select(models.PromptVersion)
                    .where(models.PromptVersion.prompt_id == self.id_attr)
                    .order_by(models.PromptVersion.id.desc())
                    .limit(1)
                )
                not_found_message = "This prompt has no associated versions"
            orm_version = await session.scalar(query)
            if not orm_version:
                raise NotFound(not_found_message)
            return to_gql_prompt_version(orm_version)

    @strawberry.field
    async def version_tags(self, info: Info[Context, None]) -> list[PromptVersionTag]:
        """Return every version tag whose ``prompt_id`` is this prompt."""
        query = select(models.PromptVersionTag).where(
            models.PromptVersionTag.prompt_id == self.id_attr
        )
        async with info.context.db() as session:
            tag_rows = await session.stream_scalars(query)
            return [to_gql_prompt_version_tag(tag_row) async for tag_row in tag_rows]

    @strawberry.field
    async def prompt_versions(
        self,
        info: Info[Context, None],
        first: Optional[int] = 50,
        last: Optional[int] = UNSET,
        after: Optional[CursorString] = UNSET,
        before: Optional[CursorString] = UNSET,
    ) -> Connection[PromptVersion]:
        """Return a paginated connection over this prompt's versions, newest first.

        Each version is paired with a 1-based number from a ``ROW_NUMBER()``
        window ordered by ascending primary key. The window is evaluated after
        the WHERE filter, so it numbers only this prompt's versions. That number
        is handed to ``to_gql_prompt_version``. All matching versions are loaded
        and then paginated in memory by ``connection_from_list``.
        """
        # Any cursor that is not a CursorString (e.g. UNSET) is treated as absent.
        connection_args = ConnectionArgs(
            first=first,
            after=after if isinstance(after, CursorString) else None,
            last=last,
            before=before if isinstance(before, CursorString) else None,
        )
        row_number_col = (
            func.row_number().over(order_by=models.PromptVersion.id).label("row_number")
        )
        query = (
            select(models.PromptVersion, row_number_col)
            .where(models.PromptVersion.prompt_id == self.id_attr)
            .order_by(models.PromptVersion.id.desc())
        )
        async with info.context.db() as session:
            rows = await session.stream(query)
            gql_versions = [
                to_gql_prompt_version(orm_version, position) async for orm_version, position in rows
            ]
        return connection_from_list(data=gql_versions, args=connection_args)

    @strawberry.field
    async def source_prompt(self, info: Info[Context, None]) -> Optional["Prompt"]:
        """Return the prompt this one was derived from, or None when there is none.

        Raises:
            NotFound: if ``source_prompt_id`` is set but no such prompt row exists.
        """
        if self.source_prompt_id:
            source_row_id = from_global_id_with_expected_type(
                global_id=self.source_prompt_id, expected_type_name=Prompt.__name__
            )
            async with info.context.db() as session:
                source_row = await session.scalar(
                    select(models.Prompt).where(models.Prompt.id == source_row_id)
                )
                if not source_row:
                    raise NotFound(f"Source prompt not found: {self.source_prompt_id}")
                return to_gql_prompt_from_orm(source_row)
        return None
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def to_gql_prompt_from_orm(orm_model: "models.Prompt") -> Prompt:
    """Convert a ``models.Prompt`` ORM row into a GraphQL ``Prompt`` node.

    The row's ``source_prompt_id`` is re-encoded as a relay GlobalID of type
    ``Prompt``; a falsy id becomes ``None``. The ORM ``name`` is read through
    its ``.root`` attribute and wrapped in the GraphQL ``Identifier`` scalar.
    """
    source_prompt_gid = (
        GlobalID(Prompt.__name__, str(orm_model.source_prompt_id))
        if orm_model.source_prompt_id
        else None
    )
    return Prompt(
        id_attr=orm_model.id,
        source_prompt_id=source_prompt_gid,
        name=Identifier(orm_model.name.root),
        description=orm_model.description,
        created_at=orm_model.created_at,
    )
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import strawberry
|
|
4
|
+
from sqlalchemy import select
|
|
5
|
+
from strawberry.relay import Node, NodeID
|
|
6
|
+
from strawberry.types import Info
|
|
7
|
+
|
|
8
|
+
from phoenix.db import models
|
|
9
|
+
from phoenix.server.api.context import Context
|
|
10
|
+
from phoenix.server.api.types.Identifier import Identifier
|
|
11
|
+
from phoenix.server.api.types.Prompt import Prompt, to_gql_prompt_from_orm
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@strawberry.type
class PromptLabel(Node):
    """GraphQL node for a label that can be attached to prompts."""

    id_attr: NodeID[int]
    name: Identifier
    description: Optional[str] = None

    @strawberry.field
    async def prompts(self, info: Info[Context, None]) -> list[Prompt]:
        """Return the prompts carrying this label.

        Prompts are matched by joining ``PromptPromptLabel`` rows on
        ``prompt_id`` and filtering on this label's id.
        """
        query = (
            select(models.Prompt)
            .join(models.PromptPromptLabel, models.Prompt.id == models.PromptPromptLabel.prompt_id)
            .where(models.PromptPromptLabel.prompt_label_id == self.id_attr)
        )
        async with info.context.db() as session:
            labeled_rows = await session.stream_scalars(query)
            return [to_gql_prompt_from_orm(prompt_row) async for prompt_row in labeled_rows]
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def to_gql_prompt_label(label_orm: models.PromptLabel) -> PromptLabel:
    """Convert a ``models.PromptLabel`` ORM row into a GraphQL ``PromptLabel`` node."""
    # NOTE(review): the Prompt and PromptVersionTag converters unwrap the ORM
    # name with ``.root``, but this one passes ``label_orm.name`` directly.
    # Presumably the label name column is a plain string; confirm against the ORM model.
    label_name = Identifier(label_orm.name)
    return PromptLabel(
        id_attr=label_orm.id,
        name=label_name,
        description=label_orm.description,
    )
|
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
import strawberry
|
|
5
|
+
from sqlalchemy import select
|
|
6
|
+
from strawberry import Private
|
|
7
|
+
from strawberry.relay import Node, NodeID
|
|
8
|
+
from strawberry.scalars import JSON
|
|
9
|
+
from strawberry.types import Info
|
|
10
|
+
|
|
11
|
+
from phoenix.db import models
|
|
12
|
+
from phoenix.db.types.model_provider import ModelProvider
|
|
13
|
+
from phoenix.server.api.context import Context
|
|
14
|
+
from phoenix.server.api.helpers.prompts.models import (
|
|
15
|
+
PromptTemplateFormat,
|
|
16
|
+
PromptTemplateType,
|
|
17
|
+
denormalize_response_format,
|
|
18
|
+
denormalize_tools,
|
|
19
|
+
get_raw_invocation_parameters,
|
|
20
|
+
)
|
|
21
|
+
from phoenix.server.api.types.PromptVersionTag import PromptVersionTag, to_gql_prompt_version_tag
|
|
22
|
+
from phoenix.server.api.types.PromptVersionTemplate import (
|
|
23
|
+
PromptTemplate,
|
|
24
|
+
to_gql_template_from_orm,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
from .ResponseFormat import ResponseFormat
|
|
28
|
+
from .ToolDefinition import ToolDefinition
|
|
29
|
+
from .User import User, to_gql_user
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@strawberry.type
class PromptVersion(Node):
    """GraphQL node for one version of a prompt.

    Fields are populated from the ORM row by ``to_gql_prompt_version``. Tags,
    the authoring user, the previous version and the sequence number are
    resolved on demand by the fields below.
    """

    id_attr: NodeID[int]
    # Author's user id; strawberry.Private keeps it out of the GraphQL schema.
    user_id: strawberry.Private[Optional[int]]
    description: Optional[str]
    template_type: PromptTemplateType
    template_format: PromptTemplateFormat
    template: PromptTemplate
    invocation_parameters: Optional[JSON] = None
    tools: list[ToolDefinition]
    response_format: Optional[ResponseFormat] = None
    model_name: str
    model_provider: ModelProvider
    metadata: JSON
    created_at: datetime
    # Precomputed sequence number, if the caller already knows it; when None the
    # ``sequence_number`` resolver falls back to the dataloader.
    cached_sequence_number: Private[Optional[int]] = None

    @strawberry.field
    async def tags(self, info: Info[Context, None]) -> list[PromptVersionTag]:
        """Return the tags that point at this version."""
        async with info.context.db() as session:
            stmt = select(models.PromptVersionTag).where(
                models.PromptVersionTag.prompt_version_id == self.id_attr
            )
            return [
                to_gql_prompt_version_tag(tag) async for tag in await session.stream_scalars(stmt)
            ]

    @strawberry.field
    async def user(self, info: Info[Context, None]) -> Optional[User]:
        """Return the user who created this version, or None if unknown or missing."""
        if self.user_id is None:
            return None
        async with info.context.db() as session:
            user = await session.get(models.User, self.user_id)
        return to_gql_user(user) if user is not None else None

    @strawberry.field
    async def previous_version(self, info: Info[Context, None]) -> Optional["PromptVersion"]:
        """Return the version of the same prompt immediately preceding this one.

        "Preceding" means the highest primary key below this version's id. This
        is the same id ordering ``Prompt.version`` uses for "latest" and
        ``Prompt.prompt_versions`` uses for numbering. Returns None if this is
        the first version or if this version's row no longer exists.
        """
        async with info.context.db() as session:
            current_version = await session.get(models.PromptVersion, self.id_attr)
            if current_version is None:
                return None

            stmt = (
                select(models.PromptVersion)
                .where(models.PromptVersion.prompt_id == current_version.prompt_id)
                .where(models.PromptVersion.id < self.id_attr)
                # Order by id rather than created_at. The `id <` filter above is
                # id-based, and timestamps may tie (e.g. rows written in one
                # transaction), which would make the chosen row nondeterministic.
                .order_by(models.PromptVersion.id.desc())
                .limit(1)
            )
            previous_version = await session.scalar(stmt)

            if previous_version is None:
                return None
            return to_gql_prompt_version(prompt_version=previous_version)

    @strawberry.field(
        description="Sequence number (1-based) of prompt versions belonging to the same prompt"
    )  # type: ignore
    async def sequence_number(
        self,
        info: Info[Context, None],
    ) -> int:
        """Resolve the sequence number, using the dataloader only if not precomputed.

        Raises:
            ValueError: if the dataloader yields None for this version id.
        """
        if self.cached_sequence_number is None:
            seq_num = await info.context.data_loaders.prompt_version_sequence_number.load(
                self.id_attr
            )
            if seq_num is None:
                raise ValueError(f"invalid prompt version: id={self.id_attr}")
            # Memoize on the instance so later resolutions skip the dataloader.
            self.cached_sequence_number = seq_num
        return self.cached_sequence_number
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def to_gql_prompt_version(
    prompt_version: models.PromptVersion, sequence_number: Optional[int] = None
) -> PromptVersion:
    """Build a GraphQL ``PromptVersion`` node from its ORM row.

    Stored tools and response format are denormalized for the row's model
    provider. If denormalizing the tools yields a ``tool_choice``, it is
    written under the ``"tool_choice"`` key of the invocation parameters.

    Args:
        prompt_version: ORM row to convert.
        sequence_number: optional precomputed 1-based sequence number, stored as
            ``cached_sequence_number`` on the result.
    """
    template_type = PromptTemplateType(prompt_version.template_type)
    template = to_gql_template_from_orm(prompt_version)
    template_format = PromptTemplateFormat(prompt_version.template_format)
    provider = prompt_version.model_provider

    gql_tools: list[ToolDefinition] = []
    tool_choice = None
    if prompt_version.tools is not None:
        schemas, tool_choice = denormalize_tools(prompt_version.tools, provider)
        gql_tools = [ToolDefinition(definition=schema) for schema in schemas]

    gql_response_format: Optional[ResponseFormat] = None
    if prompt_version.response_format is not None:
        gql_response_format = ResponseFormat(
            definition=denormalize_response_format(prompt_version.response_format, provider)
        )

    raw_parameters = get_raw_invocation_parameters(prompt_version.invocation_parameters)
    if tool_choice is not None:
        raw_parameters["tool_choice"] = tool_choice

    return PromptVersion(
        id_attr=prompt_version.id,
        user_id=prompt_version.user_id,
        description=prompt_version.description,
        template_type=template_type,
        template_format=template_format,
        template=template,
        invocation_parameters=raw_parameters,
        tools=gql_tools,
        response_format=gql_response_format,
        model_name=prompt_version.model_name,
        model_provider=provider,
        metadata=prompt_version.metadata_,
        created_at=prompt_version.created_at,
        cached_sequence_number=sequence_number,
    )
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import strawberry
|
|
4
|
+
from strawberry.relay import GlobalID, Node, NodeID
|
|
5
|
+
|
|
6
|
+
from phoenix.db import models
|
|
7
|
+
from phoenix.server.api.types.Identifier import Identifier
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@strawberry.type
class PromptVersionTag(Node):
    """GraphQL node for a named tag that references one prompt version."""

    id_attr: NodeID[int]
    # Relay GlobalID (type "PromptVersion") of the tagged version.
    prompt_version_id: GlobalID
    name: Identifier
    description: Optional[str] = None
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def to_gql_prompt_version_tag(prompt_version_tag: models.PromptVersionTag) -> PromptVersionTag:
    """Convert a ``models.PromptVersionTag`` ORM row into a GraphQL node.

    The tagged version's row id is re-encoded as a relay GlobalID of type
    ``PromptVersion``; the ORM name is unwrapped via ``.root``.
    """
    # Local import: the PromptVersion module imports this module, so importing
    # it at module level here would be circular.
    from phoenix.server.api.types.PromptVersion import PromptVersion

    return PromptVersionTag(
        id_attr=prompt_version_tag.id,
        prompt_version_id=GlobalID(
            PromptVersion.__name__, str(prompt_version_tag.prompt_version_id)
        ),
        name=Identifier(prompt_version_tag.name.root),
        description=prompt_version_tag.description,
    )
|
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
# Part of the Phoenix PromptHub feature set
|
|
2
|
+
import json
|
|
3
|
+
from typing import Annotated, Union
|
|
4
|
+
|
|
5
|
+
import strawberry
|
|
6
|
+
from strawberry.scalars import JSON
|
|
7
|
+
from typing_extensions import TypeAlias, assert_never
|
|
8
|
+
|
|
9
|
+
from phoenix.db.models import PromptVersion as ORMPromptVersion
|
|
10
|
+
from phoenix.server.api.helpers.prompts.models import (
|
|
11
|
+
PromptChatTemplate as PromptChatTemplateModel,
|
|
12
|
+
)
|
|
13
|
+
from phoenix.server.api.helpers.prompts.models import PromptMessage as PromptMessageModel
|
|
14
|
+
from phoenix.server.api.helpers.prompts.models import (
|
|
15
|
+
PromptMessageRole,
|
|
16
|
+
PromptTemplateType,
|
|
17
|
+
RoleConversion,
|
|
18
|
+
)
|
|
19
|
+
from phoenix.server.api.helpers.prompts.models import (
|
|
20
|
+
PromptStringTemplate as PromptStringTemplateModel,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@strawberry.type
class TextContentValue:
    """Payload of a plain-text content part."""

    text: str


@strawberry.type
class TextContentPart:
    """Message content part that wraps a ``TextContentValue``."""

    text: TextContentValue


@strawberry.type
class ToolCallFunction:
    """Function name and arguments of a tool call."""

    name: str
    # Arguments are carried as a string (presumably JSON-encoded; confirm with producers).
    arguments: str


@strawberry.type
class ToolCallContentValue:
    """Payload of a tool-call content part: call id plus the invoked function."""

    tool_call_id: str
    tool_call: ToolCallFunction


@strawberry.type
class ToolCallContentPart:
    """Message content part that wraps a ``ToolCallContentValue``."""

    tool_call: ToolCallContentValue


@strawberry.type
class ToolResultContentValue:
    """Payload of a tool-result content part: call id plus the result."""

    tool_call_id: str
    # NOTE(review): the chat-template converter fills this with json.dumps(...)
    # output, i.e. a JSON-encoded string inside a JSON scalar. Confirm clients
    # expect that double encoding.
    result: JSON


@strawberry.type
class ToolResultContentPart:
    """Message content part that wraps a ``ToolResultContentValue``."""

    tool_result: ToolResultContentValue


# GraphQL union "ContentPart" over the three content-part variants.
ContentPart: TypeAlias = Annotated[
    Union[TextContentPart, ToolCallContentPart, ToolResultContentPart],
    strawberry.union("ContentPart"),
]
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@strawberry.type
class PromptMessage:
    """A single chat message: a role plus an ordered list of content parts."""

    role: PromptMessageRole
    content: list[ContentPart]


@strawberry.experimental.pydantic.type(PromptChatTemplateModel)
class PromptChatTemplate:
    """GraphQL type for a chat template, declared against the pydantic model.

    ``messages`` is given an explicit strawberry type rather than being
    inferred with ``strawberry.auto``.
    """

    messages: list[PromptMessage]
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def to_gql_prompt_chat_template_from_orm(orm_model: "ORMPromptVersion") -> "PromptChatTemplate":
    """Convert the chat template stored on an ORM prompt version into GraphQL types.

    ``orm_model.template`` is validated as the pydantic ``PromptChatTemplate``
    model. Each message role is mapped with ``RoleConversion.to_gql``.
    String content becomes a single text part. List content is converted part
    by part according to each part's ``type`` ("text", "tool_call",
    "tool_result").
    """
    chat_template = PromptChatTemplateModel.model_validate(orm_model.template)
    gql_messages: list[PromptMessage] = []
    for message in chat_template.messages:
        gql_role = RoleConversion.to_gql(message.role)
        if not isinstance(message, PromptMessageModel):
            assert_never(message)
        parts: list[ContentPart] = []
        if isinstance(message.content, str):
            # Plain-string content is represented as one text part.
            parts.append(TextContentPart(text=TextContentValue(text=message.content)))
        else:
            for part in message.content:
                converted: ContentPart
                if part.type == "text":
                    converted = TextContentPart(text=TextContentValue(text=part.text))
                elif part.type == "tool_call":
                    converted = ToolCallContentPart(
                        tool_call=ToolCallContentValue(
                            tool_call_id=part.tool_call_id,
                            tool_call=ToolCallFunction(
                                name=part.tool_call.name,
                                arguments=part.tool_call.arguments,
                            ),
                        )
                    )
                elif part.type == "tool_result":
                    # NOTE(review): the result is JSON-encoded to a string before
                    # being placed in a JSON scalar field.
                    converted = ToolResultContentPart(
                        tool_result=ToolResultContentValue(
                            tool_call_id=part.tool_call_id,
                            result=json.dumps(part.tool_result),
                        )
                    )
                else:
                    assert_never(part)
                parts.append(converted)
        gql_messages.append(PromptMessage(role=gql_role, content=parts))
    return PromptChatTemplate(messages=gql_messages)
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
@strawberry.experimental.pydantic.type(PromptStringTemplateModel)
class PromptStringTemplate:
    """GraphQL type mirroring the pydantic ``PromptStringTemplate`` model."""

    # Field type is inferred from the pydantic model's ``template`` field.
    template: strawberry.auto
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def to_gql_prompt_string_template_from_orm(orm_model: "ORMPromptVersion") -> "PromptStringTemplate":
    """Convert the string template stored on an ORM prompt version into GraphQL.

    ``orm_model.template`` is validated against the pydantic
    ``PromptStringTemplate`` model, and its ``template`` text is copied over.
    """
    template_text = PromptStringTemplateModel.model_validate(orm_model.template).template
    return PromptStringTemplate(template=template_text)
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def to_gql_template_from_orm(orm_prompt_version: "ORMPromptVersion") -> "PromptTemplate":
    """Dispatch on the version's ``template_type`` to the matching GraphQL converter.

    ``PromptTemplateType(...)`` raises for an unrecognized stored value. Any
    enum member not handled below reaches ``assert_never``.
    """
    kind = PromptTemplateType(orm_prompt_version.template_type)
    if kind is PromptTemplateType.CHAT:
        return to_gql_prompt_chat_template_from_orm(orm_prompt_version)
    if kind is PromptTemplateType.STRING:
        return to_gql_prompt_string_template_from_orm(orm_prompt_version)
    assert_never(kind)
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
# GraphQL union "PromptTemplate": a prompt version's template is either a
# single string template or a chat (message list) template.
PromptTemplate: TypeAlias = Annotated[
    Union[PromptStringTemplate, PromptChatTemplate],
    strawberry.union("PromptTemplate"),
]
|
phoenix/server/app.py
CHANGED
|
@@ -86,6 +86,7 @@ from phoenix.server.api.dataloaders import (
|
|
|
86
86
|
LatencyMsQuantileDataLoader,
|
|
87
87
|
MinStartOrMaxEndTimeDataLoader,
|
|
88
88
|
ProjectByNameDataLoader,
|
|
89
|
+
PromptVersionSequenceNumberDataLoader,
|
|
89
90
|
RecordCountDataLoader,
|
|
90
91
|
SessionIODataLoader,
|
|
91
92
|
SessionNumTracesDataLoader,
|
|
@@ -611,6 +612,7 @@ def create_graphql_router(
|
|
|
611
612
|
else None
|
|
612
613
|
),
|
|
613
614
|
),
|
|
615
|
+
prompt_version_sequence_number=PromptVersionSequenceNumberDataLoader(db),
|
|
614
616
|
record_counts=RecordCountDataLoader(
|
|
615
617
|
db,
|
|
616
618
|
cache_map=cache_for_dataloaders.record_count if cache_for_dataloaders else None,
|
|
@@ -914,6 +916,7 @@ def create_app(
|
|
|
914
916
|
),
|
|
915
917
|
name="static",
|
|
916
918
|
)
|
|
919
|
+
app.state.authentication_enabled = authentication_enabled
|
|
917
920
|
app.state.read_only = read_only
|
|
918
921
|
app.state.export_path = export_path
|
|
919
922
|
app.state.password_reset_token_expiry = password_reset_token_expiry
|