arize-phoenix 7.12.2__py3-none-any.whl → 8.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic; see the registry's page for this release for more details.
- {arize_phoenix-7.12.2.dist-info → arize_phoenix-8.0.0.dist-info}/METADATA +31 -28
- {arize_phoenix-7.12.2.dist-info → arize_phoenix-8.0.0.dist-info}/RECORD +71 -48
- phoenix/config.py +61 -36
- phoenix/db/migrations/versions/bc8fea3c2bc8_add_prompt_tables.py +197 -0
- phoenix/db/models.py +307 -0
- phoenix/db/types/__init__.py +0 -0
- phoenix/db/types/identifier.py +7 -0
- phoenix/db/types/model_provider.py +8 -0
- phoenix/server/api/context.py +2 -0
- phoenix/server/api/dataloaders/__init__.py +2 -0
- phoenix/server/api/dataloaders/prompt_version_sequence_number.py +35 -0
- phoenix/server/api/helpers/jsonschema.py +135 -0
- phoenix/server/api/helpers/playground_clients.py +23 -27
- phoenix/server/api/helpers/playground_spans.py +9 -0
- phoenix/server/api/helpers/prompts/__init__.py +0 -0
- phoenix/server/api/helpers/prompts/conversions/__init__.py +0 -0
- phoenix/server/api/helpers/prompts/conversions/anthropic.py +87 -0
- phoenix/server/api/helpers/prompts/conversions/openai.py +78 -0
- phoenix/server/api/helpers/prompts/models.py +575 -0
- phoenix/server/api/input_types/ChatCompletionInput.py +9 -4
- phoenix/server/api/input_types/PromptTemplateOptions.py +10 -0
- phoenix/server/api/input_types/PromptVersionInput.py +133 -0
- phoenix/server/api/mutations/__init__.py +6 -0
- phoenix/server/api/mutations/chat_mutations.py +18 -16
- phoenix/server/api/mutations/prompt_label_mutations.py +191 -0
- phoenix/server/api/mutations/prompt_mutations.py +312 -0
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +148 -0
- phoenix/server/api/mutations/user_mutations.py +7 -6
- phoenix/server/api/openapi/schema.py +1 -0
- phoenix/server/api/queries.py +84 -31
- phoenix/server/api/routers/oauth2.py +3 -2
- phoenix/server/api/routers/v1/__init__.py +2 -0
- phoenix/server/api/routers/v1/datasets.py +1 -1
- phoenix/server/api/routers/v1/experiment_evaluations.py +1 -1
- phoenix/server/api/routers/v1/experiment_runs.py +1 -1
- phoenix/server/api/routers/v1/experiments.py +1 -1
- phoenix/server/api/routers/v1/models.py +45 -0
- phoenix/server/api/routers/v1/prompts.py +412 -0
- phoenix/server/api/routers/v1/spans.py +1 -1
- phoenix/server/api/routers/v1/traces.py +1 -1
- phoenix/server/api/routers/v1/utils.py +1 -1
- phoenix/server/api/subscriptions.py +21 -24
- phoenix/server/api/types/GenerativeProvider.py +6 -6
- phoenix/server/api/types/Identifier.py +15 -0
- phoenix/server/api/types/Project.py +5 -7
- phoenix/server/api/types/Prompt.py +134 -0
- phoenix/server/api/types/PromptLabel.py +41 -0
- phoenix/server/api/types/PromptVersion.py +148 -0
- phoenix/server/api/types/PromptVersionTag.py +27 -0
- phoenix/server/api/types/PromptVersionTemplate.py +148 -0
- phoenix/server/api/types/ResponseFormat.py +9 -0
- phoenix/server/api/types/ToolDefinition.py +9 -0
- phoenix/server/app.py +3 -0
- phoenix/server/static/.vite/manifest.json +45 -45
- phoenix/server/static/assets/components-B-qgPyHv.js +2699 -0
- phoenix/server/static/assets/index-D4KO1IcF.js +1125 -0
- phoenix/server/static/assets/pages-DdcuL3Rh.js +5634 -0
- phoenix/server/static/assets/vendor-DQp7CrDA.js +894 -0
- phoenix/server/static/assets/vendor-arizeai-C1nEIEQq.js +657 -0
- phoenix/server/static/assets/vendor-codemirror-BZXYUIkP.js +24 -0
- phoenix/server/static/assets/vendor-recharts-BUFpwCVD.js +59 -0
- phoenix/server/static/assets/{vendor-shiki-Cl9QBraO.js → vendor-shiki-C8L-c9jT.js} +2 -2
- phoenix/server/static/assets/{vendor-three-DwGkEfCM.js → vendor-three-C-AGeJYv.js} +1 -1
- phoenix/session/client.py +25 -21
- phoenix/utilities/client.py +6 -0
- phoenix/version.py +1 -1
- phoenix/server/api/input_types/TemplateOptions.py +0 -10
- phoenix/server/api/routers/v1/pydantic_compat.py +0 -78
- phoenix/server/api/types/TemplateLanguage.py +0 -10
- phoenix/server/static/assets/components-DckIzNmE.js +0 -2125
- phoenix/server/static/assets/index-Bf25Ogon.js +0 -113
- phoenix/server/static/assets/pages-DL7J9q9w.js +0 -4463
- phoenix/server/static/assets/vendor-DvC8cT4X.js +0 -894
- phoenix/server/static/assets/vendor-arizeai-Do1793cv.js +0 -662
- phoenix/server/static/assets/vendor-codemirror-BzwZPyJM.js +0 -24
- phoenix/server/static/assets/vendor-recharts-_Jb7JjhG.js +0 -59
- {arize_phoenix-7.12.2.dist-info → arize_phoenix-8.0.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-7.12.2.dist-info → arize_phoenix-8.0.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-7.12.2.dist-info → arize_phoenix-8.0.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-7.12.2.dist-info → arize_phoenix-8.0.0.dist-info}/licenses/LICENSE +0 -0
- /phoenix/server/static/assets/{vendor-DxkFTwjz.css → vendor-Cg6lcjUC.css} +0 -0
|
@@ -4,7 +4,6 @@ import asyncio
|
|
|
4
4
|
import importlib.util
|
|
5
5
|
import inspect
|
|
6
6
|
import json
|
|
7
|
-
import os
|
|
8
7
|
import time
|
|
9
8
|
from abc import ABC, abstractmethod
|
|
10
9
|
from collections.abc import AsyncIterator, Callable, Iterator
|
|
@@ -22,6 +21,7 @@ from strawberry import UNSET
|
|
|
22
21
|
from strawberry.scalars import JSON as JSONScalarType
|
|
23
22
|
from typing_extensions import TypeAlias, assert_never
|
|
24
23
|
|
|
24
|
+
from phoenix.config import getenv
|
|
25
25
|
from phoenix.evals.models.rate_limiters import (
|
|
26
26
|
AsyncCallable,
|
|
27
27
|
GenericType,
|
|
@@ -483,8 +483,8 @@ class OpenAIStreamingClient(OpenAIBaseStreamingClient):
|
|
|
483
483
|
) -> None:
|
|
484
484
|
from openai import AsyncOpenAI
|
|
485
485
|
|
|
486
|
-
base_url = model.base_url or
|
|
487
|
-
if not (api_key := api_key or
|
|
486
|
+
base_url = model.base_url or getenv("OPENAI_BASE_URL")
|
|
487
|
+
if not (api_key := api_key or getenv("OPENAI_API_KEY")):
|
|
488
488
|
if not base_url:
|
|
489
489
|
raise BadRequest("An API key is required for OpenAI models")
|
|
490
490
|
api_key = "sk-fake-api-key"
|
|
@@ -656,11 +656,11 @@ class AzureOpenAIStreamingClient(OpenAIBaseStreamingClient):
|
|
|
656
656
|
):
|
|
657
657
|
from openai import AsyncAzureOpenAI
|
|
658
658
|
|
|
659
|
-
if not (api_key := api_key or
|
|
659
|
+
if not (api_key := api_key or getenv("AZURE_OPENAI_API_KEY")):
|
|
660
660
|
raise BadRequest("An Azure API key is required for Azure OpenAI models")
|
|
661
|
-
if not (endpoint := model.endpoint or
|
|
661
|
+
if not (endpoint := model.endpoint or getenv("AZURE_OPENAI_ENDPOINT")):
|
|
662
662
|
raise BadRequest("An Azure endpoint is required for Azure OpenAI models")
|
|
663
|
-
if not (api_version := model.api_version or
|
|
663
|
+
if not (api_version := model.api_version or getenv("OPENAI_API_VERSION")):
|
|
664
664
|
raise BadRequest("An OpenAI API version is required for Azure OpenAI models")
|
|
665
665
|
client = AsyncAzureOpenAI(
|
|
666
666
|
api_key=api_key,
|
|
@@ -697,7 +697,7 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
|
|
|
697
697
|
super().__init__(model=model, api_key=api_key)
|
|
698
698
|
self._attributes[LLM_PROVIDER] = OpenInferenceLLMProviderValues.ANTHROPIC.value
|
|
699
699
|
self._attributes[LLM_SYSTEM] = OpenInferenceLLMSystemValues.ANTHROPIC.value
|
|
700
|
-
if not (api_key := api_key or
|
|
700
|
+
if not (api_key := api_key or getenv("ANTHROPIC_API_KEY")):
|
|
701
701
|
raise BadRequest("An API key is required for Anthropic models")
|
|
702
702
|
self.client = anthropic.AsyncAnthropic(api_key=api_key)
|
|
703
703
|
self.model_name = model.name
|
|
@@ -856,7 +856,7 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
|
|
|
856
856
|
|
|
857
857
|
|
|
858
858
|
@register_llm_client(
|
|
859
|
-
provider_key=GenerativeProviderKey.
|
|
859
|
+
provider_key=GenerativeProviderKey.GOOGLE,
|
|
860
860
|
model_names=[
|
|
861
861
|
PROVIDER_DEFAULT,
|
|
862
862
|
"gemini-2.0-flash-exp",
|
|
@@ -866,7 +866,7 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
|
|
|
866
866
|
"gemini-1.0-pro",
|
|
867
867
|
],
|
|
868
868
|
)
|
|
869
|
-
class
|
|
869
|
+
class GoogleStreamingClient(PlaygroundStreamingClient):
|
|
870
870
|
def __init__(
|
|
871
871
|
self,
|
|
872
872
|
model: GenerativeModelInput,
|
|
@@ -877,11 +877,7 @@ class GeminiStreamingClient(PlaygroundStreamingClient):
|
|
|
877
877
|
super().__init__(model=model, api_key=api_key)
|
|
878
878
|
self._attributes[LLM_PROVIDER] = OpenInferenceLLMProviderValues.GOOGLE.value
|
|
879
879
|
self._attributes[LLM_SYSTEM] = OpenInferenceLLMSystemValues.VERTEXAI.value
|
|
880
|
-
if not (
|
|
881
|
-
api_key := api_key
|
|
882
|
-
or os.environ.get("GEMINI_API_KEY")
|
|
883
|
-
or os.environ.get("GOOGLE_API_KEY")
|
|
884
|
-
):
|
|
880
|
+
if not (api_key := api_key or getenv("GEMINI_API_KEY") or getenv("GOOGLE_API_KEY")):
|
|
885
881
|
raise BadRequest("An API key is required for Gemini models")
|
|
886
882
|
google_genai.configure(api_key=api_key)
|
|
887
883
|
self.model_name = model.name
|
|
@@ -945,7 +941,7 @@ class GeminiStreamingClient(PlaygroundStreamingClient):
|
|
|
945
941
|
) -> AsyncIterator[ChatCompletionChunk]:
|
|
946
942
|
import google.generativeai as google_genai
|
|
947
943
|
|
|
948
|
-
|
|
944
|
+
google_message_history, current_message, system_prompt = self._build_google_messages(
|
|
949
945
|
messages
|
|
950
946
|
)
|
|
951
947
|
|
|
@@ -954,17 +950,17 @@ class GeminiStreamingClient(PlaygroundStreamingClient):
|
|
|
954
950
|
model_args["system_instruction"] = system_prompt
|
|
955
951
|
client = google_genai.GenerativeModel(**model_args)
|
|
956
952
|
|
|
957
|
-
|
|
953
|
+
google_config = google_genai.GenerationConfig(
|
|
958
954
|
**invocation_parameters,
|
|
959
955
|
)
|
|
960
|
-
|
|
956
|
+
google_params = {
|
|
961
957
|
"content": current_message,
|
|
962
|
-
"generation_config":
|
|
958
|
+
"generation_config": google_config,
|
|
963
959
|
"stream": True,
|
|
964
960
|
}
|
|
965
961
|
|
|
966
|
-
chat = client.start_chat(history=
|
|
967
|
-
stream = await chat.send_message_async(**
|
|
962
|
+
chat = client.start_chat(history=google_message_history)
|
|
963
|
+
stream = await chat.send_message_async(**google_params)
|
|
968
964
|
async for event in stream:
|
|
969
965
|
self._attributes.update(
|
|
970
966
|
{
|
|
@@ -975,29 +971,29 @@ class GeminiStreamingClient(PlaygroundStreamingClient):
|
|
|
975
971
|
)
|
|
976
972
|
yield TextChunk(content=event.text)
|
|
977
973
|
|
|
978
|
-
def
|
|
974
|
+
def _build_google_messages(
    self,
    messages: list[tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]]],
) -> tuple[list["ContentType"], str, str]:
    """Convert playground messages into the pieces the Gemini chat call is built from.

    Returns a ``(history, prompt, system_instruction)`` triple:

    * ``history`` -- user turns as ``{"role": "user", "parts": ...}`` and AI turns as
      ``{"role": "model", "parts": ...}``, in their original order, minus the final entry;
    * ``prompt`` -- the ``parts`` of that final entry, or ``""`` when there are no
      user/AI messages at all;
    * ``system_instruction`` -- every system message joined with newlines.

    The tool-call fields of each message tuple are ignored. Tool messages are not
    supported and raise ``NotImplementedError``.
    """
    history: list["ContentType"] = []
    system_texts = []
    for role, text, _tool_call_id, _tool_calls in messages:
        if role == ChatCompletionMessageRole.SYSTEM:
            # System messages never enter the chat history; they are merged into one instruction.
            system_texts.append(text)
        elif role == ChatCompletionMessageRole.USER:
            history.append({"role": "user", "parts": text})
        elif role == ChatCompletionMessageRole.AI:
            # AI turns are sent with the "model" role.
            history.append({"role": "model", "parts": text})
        elif role == ChatCompletionMessageRole.TOOL:
            raise NotImplementedError
        else:
            assert_never(role)

    # The newest turn is sent as the message itself; everything before it is history.
    # NOTE(review): the last entry is popped whatever its role, so a trailing AI turn
    # would be re-sent as the outgoing message -- confirm that is acceptable.
    prompt = history.pop()["parts"] if history else ""
    return history, prompt, "\n".join(system_texts)
|
|
1001
997
|
|
|
1002
998
|
|
|
1003
999
|
def initialize_playground_clients() -> None:
|
|
@@ -41,6 +41,7 @@ from phoenix.server.api.types.ChatCompletionSubscriptionPayload import (
|
|
|
41
41
|
TextChunk,
|
|
42
42
|
ToolCallChunk,
|
|
43
43
|
)
|
|
44
|
+
from phoenix.server.api.types.Identifier import Identifier
|
|
44
45
|
from phoenix.trace.attributes import get_attribute_value, unflatten
|
|
45
46
|
from phoenix.trace.schemas import (
|
|
46
47
|
SpanEvent,
|
|
@@ -70,6 +71,8 @@ class streaming_llm_span:
|
|
|
70
71
|
) -> None:
|
|
71
72
|
self._input = input
|
|
72
73
|
self._attributes: dict[str, Any] = attributes if attributes is not None else {}
|
|
74
|
+
self._attributes.update(dict(prompt_metadata(input.prompt_name)))
|
|
75
|
+
|
|
73
76
|
self._attributes.update(
|
|
74
77
|
chain(
|
|
75
78
|
llm_span_kind(),
|
|
@@ -264,6 +267,11 @@ def input_value_and_mime_type(
|
|
|
264
267
|
yield INPUT_VALUE, safe_json_dumps(input_data)
|
|
265
268
|
|
|
266
269
|
|
|
270
|
+
def prompt_metadata(prompt_name: Optional[Identifier]) -> Iterator[tuple[str, Any]]:
    """Yield the span ``METADATA`` attribute that links a playground run to its prompt.

    Yields nothing when ``prompt_name`` is falsy (``None`` or empty). Callers can
    therefore always fold the result into the span attributes with ``dict(...)``.
    """
    # NOTE(review): the metadata key says "id" but the value is the prompt's
    # name/identifier -- confirm this naming is intended by consumers of the span.
    if not prompt_name:
        return
    yield METADATA, {"phoenix_prompt_id": prompt_name}
|
|
273
|
+
|
|
274
|
+
|
|
267
275
|
def _merge_tool_call_chunks(
|
|
268
276
|
chunks_by_id: defaultdict[str, list[ToolCallChunk]],
|
|
269
277
|
) -> list[dict[str, Any]]:
|
|
@@ -442,6 +450,7 @@ LLM_INVOCATION_PARAMETERS = SpanAttributes.LLM_INVOCATION_PARAMETERS
|
|
|
442
450
|
LLM_TOOLS = SpanAttributes.LLM_TOOLS
|
|
443
451
|
LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
|
|
444
452
|
LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
|
|
453
|
+
METADATA = SpanAttributes.METADATA
|
|
445
454
|
|
|
446
455
|
MESSAGE_CONTENT = MessageAttributes.MESSAGE_CONTENT
|
|
447
456
|
MESSAGE_ROLE = MessageAttributes.MESSAGE_ROLE
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, Optional, Union
|
|
4
|
+
|
|
5
|
+
from typing_extensions import assert_never
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from anthropic.types import (
|
|
9
|
+
ToolChoiceAnyParam,
|
|
10
|
+
ToolChoiceAutoParam,
|
|
11
|
+
ToolChoiceParam,
|
|
12
|
+
ToolChoiceToolParam,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
from phoenix.server.api.helpers.prompts.models import (
|
|
16
|
+
PromptToolChoiceOneOrMore,
|
|
17
|
+
PromptToolChoiceSpecificFunctionTool,
|
|
18
|
+
PromptToolChoiceZeroOrMore,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class AnthropicToolChoiceConversion:
    """Convert tool-choice settings between Phoenix prompt models and Anthropic params.

    Anthropic's ``tool_choice`` param carries an optional ``disable_parallel_tool_use``
    flag that the Phoenix tool-choice models do not hold. The flag is therefore passed
    alongside them: as an argument to ``to_anthropic``, and as the second element of
    the tuple returned by ``from_anthropic``.
    """

    @staticmethod
    def to_anthropic(
        obj: Union[
            PromptToolChoiceZeroOrMore,
            PromptToolChoiceOneOrMore,
            PromptToolChoiceSpecificFunctionTool,
        ],
        disable_parallel_tool_use: Optional[bool] = None,
    ) -> ToolChoiceParam:
        """Build the Anthropic ``tool_choice`` param for a Phoenix tool choice.

        Args:
            obj: The Phoenix tool choice. ``zero_or_more`` maps to ``{"type": "auto"}``,
                ``one_or_more`` to ``{"type": "any"}``, and ``specific_function`` to
                ``{"type": "tool", "name": <function_name>}``.
            disable_parallel_tool_use: Added to the result under the same key only
                when it is not ``None``.

        Returns:
            The corresponding Anthropic tool-choice dict.
        """
        if obj.type == "zero_or_more":
            choice_auto: ToolChoiceAutoParam = {"type": "auto"}
            if disable_parallel_tool_use is not None:
                choice_auto["disable_parallel_tool_use"] = disable_parallel_tool_use
            return choice_auto
        if obj.type == "one_or_more":
            choice_any: ToolChoiceAnyParam = {"type": "any"}
            if disable_parallel_tool_use is not None:
                choice_any["disable_parallel_tool_use"] = disable_parallel_tool_use
            return choice_any
        if obj.type == "specific_function":
            choice_tool: ToolChoiceToolParam = {"type": "tool", "name": obj.function_name}
            if disable_parallel_tool_use is not None:
                choice_tool["disable_parallel_tool_use"] = disable_parallel_tool_use
            return choice_tool
        assert_never(obj.type)

    @staticmethod
    def from_anthropic(
        obj: ToolChoiceParam,
    ) -> tuple[
        Union[
            PromptToolChoiceZeroOrMore,
            PromptToolChoiceOneOrMore,
            PromptToolChoiceSpecificFunctionTool,
        ],
        Optional[bool],
    ]:
        """Split an Anthropic ``tool_choice`` param into a Phoenix tool choice and its flag.

        Args:
            obj: An Anthropic tool-choice dict whose ``type`` is ``"auto"``, ``"any"``
                or ``"tool"``.

        Returns:
            ``(tool_choice, disable_parallel_tool_use)``. The flag is ``None`` when
            ``obj`` does not contain the ``disable_parallel_tool_use`` key.
        """
        # Imported at call time: the module-level import of these models sits under
        # TYPE_CHECKING only (presumably to avoid an import cycle -- confirm).
        from phoenix.server.api.helpers.prompts.models import (
            PromptToolChoiceOneOrMore,
            PromptToolChoiceSpecificFunctionTool,
            PromptToolChoiceZeroOrMore,
        )

        if obj["type"] == "auto":
            return (
                PromptToolChoiceZeroOrMore(type="zero_or_more"),
                obj.get("disable_parallel_tool_use"),
            )
        if obj["type"] == "any":
            return (
                PromptToolChoiceOneOrMore(type="one_or_more"),
                obj.get("disable_parallel_tool_use"),
            )
        if obj["type"] == "tool":
            return (
                PromptToolChoiceSpecificFunctionTool(
                    type="specific_function",
                    function_name=obj["name"],
                ),
                obj.get("disable_parallel_tool_use"),
            )
        assert_never(obj)
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, Union
|
|
4
|
+
|
|
5
|
+
from typing_extensions import assert_never
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from openai.types.chat import (
|
|
9
|
+
ChatCompletionNamedToolChoiceParam,
|
|
10
|
+
ChatCompletionToolChoiceOptionParam,
|
|
11
|
+
)
|
|
12
|
+
from openai.types.chat.chat_completion_named_tool_choice_param import Function
|
|
13
|
+
|
|
14
|
+
from phoenix.server.api.helpers.prompts.models import (
|
|
15
|
+
PromptToolChoiceNone,
|
|
16
|
+
PromptToolChoiceOneOrMore,
|
|
17
|
+
PromptToolChoiceSpecificFunctionTool,
|
|
18
|
+
PromptToolChoiceZeroOrMore,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class OpenAIToolChoiceConversion:
    """Translate tool-choice settings between Phoenix prompt models and OpenAI params.

    Mapping (Phoenix <-> OpenAI):
        none              <-> "none"
        zero_or_more      <-> "auto"
        one_or_more       <-> "required"
        specific_function <-> {"type": "function", "function": {"name": ...}}
    """

    @staticmethod
    def to_openai(
        obj: Union[
            PromptToolChoiceNone,
            PromptToolChoiceZeroOrMore,
            PromptToolChoiceOneOrMore,
            PromptToolChoiceSpecificFunctionTool,
        ],
    ) -> ChatCompletionToolChoiceOptionParam:
        """Return the OpenAI ``tool_choice`` value for a Phoenix tool choice."""
        # Only the named-function case produces a structured value.
        if obj.type == "specific_function":
            return {"type": "function", "function": {"name": obj.function_name}}
        # The remaining cases map onto OpenAI's plain string options.
        if obj.type == "none":
            return "none"
        if obj.type == "zero_or_more":
            return "auto"
        if obj.type == "one_or_more":
            return "required"
        assert_never(obj)

    @staticmethod
    def from_openai(
        obj: ChatCompletionToolChoiceOptionParam,
    ) -> Union[
        PromptToolChoiceNone,
        PromptToolChoiceZeroOrMore,
        PromptToolChoiceOneOrMore,
        PromptToolChoiceSpecificFunctionTool,
    ]:
        """Return the Phoenix tool choice for an OpenAI ``tool_choice`` value.

        Accepts the string options ``"none"``, ``"auto"`` and ``"required"``, and the
        named form ``{"type": "function", "function": {"name": ...}}``.
        """
        # Runtime import: the module-level import of these models is TYPE_CHECKING-only.
        from phoenix.server.api.helpers.prompts.models import (
            PromptToolChoiceNone,
            PromptToolChoiceOneOrMore,
            PromptToolChoiceSpecificFunctionTool,
            PromptToolChoiceZeroOrMore,
        )

        if obj == "none":
            return PromptToolChoiceNone(type="none")
        if obj == "auto":
            return PromptToolChoiceZeroOrMore(type="zero_or_more")
        if obj == "required":
            return PromptToolChoiceOneOrMore(type="one_or_more")
        if obj["type"] == "function":
            return PromptToolChoiceSpecificFunctionTool(
                type="specific_function",
                function_name=obj["function"]["name"],
            )
        assert_never(obj["type"])
|