arize-phoenix 10.14.0__py3-none-any.whl → 11.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/METADATA +3 -2
- {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/RECORD +82 -50
- phoenix/config.py +5 -2
- phoenix/datetime_utils.py +8 -1
- phoenix/db/bulk_inserter.py +40 -1
- phoenix/db/facilitator.py +263 -4
- phoenix/db/insertion/helpers.py +15 -0
- phoenix/db/insertion/span.py +3 -1
- phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
- phoenix/db/models.py +267 -9
- phoenix/db/types/model_provider.py +1 -0
- phoenix/db/types/token_price_customization.py +29 -0
- phoenix/server/api/context.py +38 -4
- phoenix/server/api/dataloaders/__init__.py +41 -5
- phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
- phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
- phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +58 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project.py +140 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
- phoenix/server/api/dataloaders/span_costs.py +35 -0
- phoenix/server/api/dataloaders/types.py +29 -0
- phoenix/server/api/helpers/playground_clients.py +562 -12
- phoenix/server/api/helpers/prompts/conversions/aws.py +83 -0
- phoenix/server/api/helpers/prompts/models.py +67 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +2 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +3 -0
- phoenix/server/api/input_types/SpanSort.py +17 -0
- phoenix/server/api/mutations/__init__.py +2 -0
- phoenix/server/api/mutations/chat_mutations.py +17 -0
- phoenix/server/api/mutations/model_mutations.py +208 -0
- phoenix/server/api/queries.py +82 -41
- phoenix/server/api/routers/v1/traces.py +11 -4
- phoenix/server/api/subscriptions.py +36 -2
- phoenix/server/api/types/CostBreakdown.py +15 -0
- phoenix/server/api/types/Experiment.py +59 -1
- phoenix/server/api/types/ExperimentRun.py +58 -4
- phoenix/server/api/types/GenerativeModel.py +143 -2
- phoenix/server/api/types/GenerativeProvider.py +33 -20
- phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
- phoenix/server/api/types/ModelInterface.py +11 -0
- phoenix/server/api/types/PlaygroundModel.py +10 -0
- phoenix/server/api/types/Project.py +42 -0
- phoenix/server/api/types/ProjectSession.py +44 -0
- phoenix/server/api/types/Span.py +137 -0
- phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
- phoenix/server/api/types/SpanCostSummary.py +10 -0
- phoenix/server/api/types/TokenPrice.py +16 -0
- phoenix/server/api/types/TokenUsage.py +3 -3
- phoenix/server/api/types/Trace.py +41 -0
- phoenix/server/app.py +59 -0
- phoenix/server/cost_tracking/cost_details_calculator.py +190 -0
- phoenix/server/cost_tracking/cost_model_lookup.py +151 -0
- phoenix/server/cost_tracking/helpers.py +68 -0
- phoenix/server/cost_tracking/model_cost_manifest.json +59 -329
- phoenix/server/cost_tracking/regex_specificity.py +397 -0
- phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
- phoenix/server/daemons/__init__.py +0 -0
- phoenix/server/daemons/generative_model_store.py +51 -0
- phoenix/server/daemons/span_cost_calculator.py +103 -0
- phoenix/server/dml_event_handler.py +1 -0
- phoenix/server/static/.vite/manifest.json +36 -36
- phoenix/server/static/assets/components-BnK9kodr.js +5055 -0
- phoenix/server/static/assets/{index-qiubV_74.js → index-S3YKLmbo.js} +13 -13
- phoenix/server/static/assets/{pages-C4V07ozl.js → pages-BW6PBHZb.js} +809 -417
- phoenix/server/static/assets/{vendor-Bfsiga8H.js → vendor-DqQvHbPa.js} +147 -147
- phoenix/server/static/assets/{vendor-arizeai-CQOWsrzm.js → vendor-arizeai-CLX44PFA.js} +1 -1
- phoenix/server/static/assets/{vendor-codemirror-CrcGVhB2.js → vendor-codemirror-Du3XyJnB.js} +1 -1
- phoenix/server/static/assets/{vendor-recharts-Yyg3G-Rq.js → vendor-recharts-B2PJDrnX.js} +25 -25
- phoenix/server/static/assets/{vendor-shiki-OPjag7Hm.js → vendor-shiki-CNbrFjf9.js} +1 -1
- phoenix/version.py +1 -1
- phoenix/server/cost_tracking/cost_lookup.py +0 -255
- phoenix/server/static/assets/components-CUUWyAMo.js +0 -4509
- {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, Optional, Union
|
|
4
|
+
|
|
5
|
+
from typing_extensions import assert_never
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from anthropic.types import (
|
|
9
|
+
ToolChoiceAnyParam,
|
|
10
|
+
ToolChoiceAutoParam,
|
|
11
|
+
ToolChoiceParam,
|
|
12
|
+
ToolChoiceToolParam,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
from phoenix.server.api.helpers.prompts.models import (
|
|
16
|
+
PromptToolChoiceNone,
|
|
17
|
+
PromptToolChoiceOneOrMore,
|
|
18
|
+
PromptToolChoiceSpecificFunctionTool,
|
|
19
|
+
PromptToolChoiceZeroOrMore,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class AwsToolChoiceConversion:
    """
    Converts between the normalized prompt tool-choice models and the
    ``ToolChoiceParam`` dictionaries used for the AWS provider.

    The dictionary shapes are annotated with the tool-choice TypedDicts from
    ``anthropic.types``, which this module imports for type checking only.
    """

    @staticmethod
    def to_aws(
        obj: Union[
            PromptToolChoiceNone,
            PromptToolChoiceZeroOrMore,
            PromptToolChoiceOneOrMore,
            PromptToolChoiceSpecificFunctionTool,
        ],
        disable_parallel_tool_use: Optional[bool] = None,
    ) -> ToolChoiceParam:
        """
        Build the provider tool-choice dictionary for a normalized tool choice.

        Mapping of ``obj.type`` to the returned ``"type"`` value:
        ``"zero_or_more"`` -> ``"auto"``, ``"one_or_more"`` -> ``"any"``,
        ``"specific_function"`` -> ``"tool"`` (with ``"name"`` taken from
        ``obj.function_name``), ``"none"`` -> ``"none"``.

        ``disable_parallel_tool_use`` is copied into the result only when it is
        not ``None``, and never for the ``"none"`` choice.
        """
        if obj.type == "none":
            # The "none" choice carries no extra keys, whatever the flag says.
            return {"type": "none"}
        elif obj.type == "zero_or_more":
            auto_param: ToolChoiceAutoParam = {"type": "auto"}
            if disable_parallel_tool_use is not None:
                auto_param["disable_parallel_tool_use"] = disable_parallel_tool_use
            return auto_param
        elif obj.type == "one_or_more":
            any_param: ToolChoiceAnyParam = {"type": "any"}
            if disable_parallel_tool_use is not None:
                any_param["disable_parallel_tool_use"] = disable_parallel_tool_use
            return any_param
        elif obj.type == "specific_function":
            tool_param: ToolChoiceToolParam = {"type": "tool", "name": obj.function_name}
            if disable_parallel_tool_use is not None:
                tool_param["disable_parallel_tool_use"] = disable_parallel_tool_use
            return tool_param
        # Exhaustiveness guard: every declared choice type is handled above.
        assert_never(obj.type)

    @staticmethod
    def from_aws(
        obj: ToolChoiceParam,
    ) -> Union[
        PromptToolChoiceNone,
        PromptToolChoiceZeroOrMore,
        PromptToolChoiceOneOrMore,
        PromptToolChoiceSpecificFunctionTool,
    ]:
        """
        Build the normalized tool-choice model for a provider dictionary.

        Inverse of the ``"type"`` mapping applied by ``to_aws``. Any
        ``disable_parallel_tool_use`` key present in ``obj`` is not read.
        """
        # Imported at call time rather than at module import time —
        # presumably to avoid a circular import with the models module.
        from phoenix.server.api.helpers.prompts.models import (
            PromptToolChoiceNone,
            PromptToolChoiceOneOrMore,
            PromptToolChoiceSpecificFunctionTool,
            PromptToolChoiceZeroOrMore,
        )

        if obj["type"] == "none":
            return PromptToolChoiceNone(type="none")
        elif obj["type"] == "auto":
            return PromptToolChoiceZeroOrMore(type="zero_or_more")
        elif obj["type"] == "any":
            return PromptToolChoiceOneOrMore(type="one_or_more")
        elif obj["type"] == "tool":
            return PromptToolChoiceSpecificFunctionTool(
                type="specific_function",
                function_name=obj["name"],
            )
        # Exhaustiveness guard: every declared dictionary shape is handled above.
        assert_never(obj)
|
|
@@ -9,6 +9,7 @@ from typing_extensions import Annotated, Self, TypeAlias, TypeGuard, assert_neve
|
|
|
9
9
|
from phoenix.db.types.db_models import UNDEFINED, DBBaseModel
|
|
10
10
|
from phoenix.db.types.model_provider import ModelProvider
|
|
11
11
|
from phoenix.server.api.helpers.prompts.conversions.anthropic import AnthropicToolChoiceConversion
|
|
12
|
+
from phoenix.server.api.helpers.prompts.conversions.aws import AwsToolChoiceConversion
|
|
12
13
|
from phoenix.server.api.helpers.prompts.conversions.openai import OpenAIToolChoiceConversion
|
|
13
14
|
|
|
14
15
|
JSONSerializable = Union[None, bool, int, float, str, dict[str, Any], list[Any]]
|
|
@@ -312,6 +313,14 @@ class AnthropicToolDefinition(DBBaseModel):
|
|
|
312
313
|
description: str = UNDEFINED
|
|
313
314
|
|
|
314
315
|
|
|
316
|
+
class BedrockToolDefinition(DBBaseModel):
    """
    Tool definition for the AWS provider, wrapping a single ``toolSpec`` mapping.

    Based on https://github.com/aws/amazon-bedrock-sdk-python/blob/main/src/bedrock/types/tool_param.py#L12
    """

    # Only the outer mapping type is declared; the keys it must contain are not.
    # The conversion to the normalized prompt tool reads "name", "description"
    # and "inputSchema" -> "json" from this mapping.
    toolSpec: dict[str, Any]
|
|
322
|
+
|
|
323
|
+
|
|
315
324
|
class PromptOpenAIInvocationParametersContent(DBBaseModel):
|
|
316
325
|
temperature: float = UNDEFINED
|
|
317
326
|
max_tokens: int = UNDEFINED
|
|
@@ -397,6 +406,17 @@ class PromptAnthropicInvocationParameters(DBBaseModel):
|
|
|
397
406
|
anthropic: PromptAnthropicInvocationParametersContent
|
|
398
407
|
|
|
399
408
|
|
|
409
|
+
class PromptAwsInvocationParametersContent(DBBaseModel):
    # Invocation parameters stored for the AWS provider. Every field defaults
    # to the UNDEFINED sentinel, i.e. none of them has to be supplied.
    max_tokens: int = UNDEFINED
    temperature: float = UNDEFINED
    top_p: float = UNDEFINED
|
|
413
|
+
|
|
414
|
+
|
|
415
|
+
class PromptAwsInvocationParameters(DBBaseModel):
    # AWS member of the PromptInvocationParameters union.
    # "type" is the discriminator value that selects this member.
    type: Literal["aws"]
    # The provider-specific parameter values.
    aws: PromptAwsInvocationParametersContent
|
|
418
|
+
|
|
419
|
+
|
|
400
420
|
class PromptGoogleInvocationParametersContent(DBBaseModel):
|
|
401
421
|
temperature: float = UNDEFINED
|
|
402
422
|
max_output_tokens: int = UNDEFINED
|
|
@@ -421,6 +441,7 @@ PromptInvocationParameters: TypeAlias = Annotated[
|
|
|
421
441
|
PromptDeepSeekInvocationParameters,
|
|
422
442
|
PromptXAIInvocationParameters,
|
|
423
443
|
PromptOllamaInvocationParameters,
|
|
444
|
+
PromptAwsInvocationParameters,
|
|
424
445
|
],
|
|
425
446
|
Field(..., discriminator="type"),
|
|
426
447
|
]
|
|
@@ -443,6 +464,8 @@ def get_raw_invocation_parameters(
|
|
|
443
464
|
return invocation_parameters.xai.model_dump()
|
|
444
465
|
if isinstance(invocation_parameters, PromptOllamaInvocationParameters):
|
|
445
466
|
return invocation_parameters.ollama.model_dump()
|
|
467
|
+
if isinstance(invocation_parameters, PromptAwsInvocationParameters):
|
|
468
|
+
return invocation_parameters.aws.model_dump()
|
|
446
469
|
assert_never(invocation_parameters)
|
|
447
470
|
|
|
448
471
|
|
|
@@ -459,6 +482,7 @@ def is_prompt_invocation_parameters(
|
|
|
459
482
|
PromptDeepSeekInvocationParameters,
|
|
460
483
|
PromptXAIInvocationParameters,
|
|
461
484
|
PromptOllamaInvocationParameters,
|
|
485
|
+
PromptAwsInvocationParameters,
|
|
462
486
|
),
|
|
463
487
|
)
|
|
464
488
|
|
|
@@ -512,6 +536,11 @@ def validate_invocation_parameters(
|
|
|
512
536
|
type="ollama",
|
|
513
537
|
ollama=PromptOllamaInvocationParametersContent.model_validate(invocation_parameters),
|
|
514
538
|
)
|
|
539
|
+
elif model_provider is ModelProvider.AWS:
|
|
540
|
+
return PromptAwsInvocationParameters(
|
|
541
|
+
type="aws",
|
|
542
|
+
aws=PromptAwsInvocationParametersContent.model_validate(invocation_parameters),
|
|
543
|
+
)
|
|
515
544
|
assert_never(model_provider)
|
|
516
545
|
|
|
517
546
|
|
|
@@ -530,12 +559,16 @@ def normalize_tools(
|
|
|
530
559
|
):
|
|
531
560
|
openai_tools = [OpenAIToolDefinition.model_validate(schema) for schema in schemas]
|
|
532
561
|
tools = [_openai_to_prompt_tool(openai_tool) for openai_tool in openai_tools]
|
|
562
|
+
elif model_provider is ModelProvider.AWS:
|
|
563
|
+
bedrock_tools = [BedrockToolDefinition.model_validate(schema) for schema in schemas]
|
|
564
|
+
tools = [_bedrock_to_prompt_tool(bedrock_tool) for bedrock_tool in bedrock_tools]
|
|
533
565
|
elif model_provider is ModelProvider.ANTHROPIC:
|
|
534
566
|
anthropic_tools = [AnthropicToolDefinition.model_validate(schema) for schema in schemas]
|
|
535
567
|
tools = [_anthropic_to_prompt_tool(anthropic_tool) for anthropic_tool in anthropic_tools]
|
|
536
568
|
else:
|
|
537
569
|
raise ValueError(f"Unsupported model provider: {model_provider}")
|
|
538
570
|
ans = PromptTools(type="tools", tools=tools)
|
|
571
|
+
|
|
539
572
|
if tool_choice is not None:
|
|
540
573
|
if (
|
|
541
574
|
model_provider is ModelProvider.OPENAI
|
|
@@ -545,6 +578,8 @@ def normalize_tools(
|
|
|
545
578
|
or model_provider is ModelProvider.OLLAMA
|
|
546
579
|
):
|
|
547
580
|
ans.tool_choice = OpenAIToolChoiceConversion.from_openai(tool_choice) # type: ignore[arg-type]
|
|
581
|
+
elif model_provider is ModelProvider.AWS:
|
|
582
|
+
ans.tool_choice = AwsToolChoiceConversion.from_aws(tool_choice) # type: ignore[arg-type]
|
|
548
583
|
elif model_provider is ModelProvider.ANTHROPIC:
|
|
549
584
|
choice, disable_parallel_tool_calls = AnthropicToolChoiceConversion.from_anthropic(
|
|
550
585
|
tool_choice # type: ignore[arg-type]
|
|
@@ -571,6 +606,10 @@ def denormalize_tools(
|
|
|
571
606
|
denormalized_tools = [_prompt_to_openai_tool(tool) for tool in tools.tools]
|
|
572
607
|
if tools.tool_choice:
|
|
573
608
|
tool_choice = OpenAIToolChoiceConversion.to_openai(tools.tool_choice)
|
|
609
|
+
elif model_provider is ModelProvider.AWS:
|
|
610
|
+
denormalized_tools = [_prompt_to_bedrock_tool(tool) for tool in tools.tools]
|
|
611
|
+
if tools.tool_choice:
|
|
612
|
+
tool_choice = OpenAIToolChoiceConversion.to_openai(tools.tool_choice)
|
|
574
613
|
elif model_provider is ModelProvider.ANTHROPIC:
|
|
575
614
|
denormalized_tools = [_prompt_to_anthropic_tool(tool) for tool in tools.tools]
|
|
576
615
|
if tools.tool_choice and tools.tool_choice.type != "none":
|
|
@@ -614,6 +653,19 @@ def _prompt_to_openai_tool(
|
|
|
614
653
|
)
|
|
615
654
|
|
|
616
655
|
|
|
656
|
+
def _bedrock_to_prompt_tool(
    tool: BedrockToolDefinition,
) -> PromptToolFunction:
    """
    Convert a Bedrock tool definition into the normalized prompt tool.

    Reads ``name`` and ``inputSchema["json"]`` from ``tool.toolSpec``; both are
    required and a missing key raises ``KeyError``. ``description`` is copied
    only when present.
    """
    tool_spec = tool.toolSpec
    definition_fields: dict[str, Any] = {
        "name": tool_spec["name"],
        "parameters": tool_spec["inputSchema"]["json"],
    }
    # A Bedrock tool spec may omit "description"; indexing it unconditionally
    # rejected such tools with a KeyError. When absent, the field is left unset
    # so the definition model's own default applies.
    # NOTE(review): assumes PromptToolFunctionDefinition.description has a
    # default — confirm against the model declaration.
    if "description" in tool_spec:
        definition_fields["description"] = tool_spec["description"]
    return PromptToolFunction(
        type="function",
        function=PromptToolFunctionDefinition(**definition_fields),
    )
|
|
667
|
+
|
|
668
|
+
|
|
617
669
|
def _anthropic_to_prompt_tool(
|
|
618
670
|
tool: AnthropicToolDefinition,
|
|
619
671
|
) -> PromptToolFunction:
|
|
@@ -636,3 +688,18 @@ def _prompt_to_anthropic_tool(
|
|
|
636
688
|
name=function.name,
|
|
637
689
|
description=function.description,
|
|
638
690
|
)
|
|
691
|
+
|
|
692
|
+
|
|
693
|
+
def _prompt_to_bedrock_tool(
    tool: PromptToolFunction,
) -> BedrockToolDefinition:
    """
    Convert a normalized prompt tool into a Bedrock tool definition.

    The function's name, description and parameters are placed under
    ``toolSpec`` as ``name``, ``description`` and ``inputSchema["json"]``.
    """
    definition = tool.function
    tool_spec: dict[str, Any] = {
        "name": definition.name,
        "description": definition.description,
        "inputSchema": {"json": definition.parameters},
    }
    return BedrockToolDefinition(toolSpec=tool_spec)
|
|
@@ -17,3 +17,5 @@ class GenerativeModelInput:
|
|
|
17
17
|
""" The endpoint to use for the model. Only required for Azure OpenAI models. """
|
|
18
18
|
api_version: Optional[str] = UNSET
|
|
19
19
|
""" The API version to use for the model. """
|
|
20
|
+
region: Optional[str] = UNSET
|
|
21
|
+
""" The region to use for the model. """
|
|
@@ -13,6 +13,7 @@ class ProjectSessionColumn(Enum):
|
|
|
13
13
|
endTime = auto()
|
|
14
14
|
tokenCountTotal = auto()
|
|
15
15
|
numTraces = auto()
|
|
16
|
+
costTotal = auto()
|
|
16
17
|
|
|
17
18
|
@property
|
|
18
19
|
def data_type(self) -> CursorSortColumnDataType:
|
|
@@ -20,6 +21,8 @@ class ProjectSessionColumn(Enum):
|
|
|
20
21
|
return CursorSortColumnDataType.INT
|
|
21
22
|
if self is ProjectSessionColumn.startTime or self is ProjectSessionColumn.endTime:
|
|
22
23
|
return CursorSortColumnDataType.DATETIME
|
|
24
|
+
if self is ProjectSessionColumn.costTotal:
|
|
25
|
+
return CursorSortColumnDataType.FLOAT
|
|
23
26
|
assert_never(self)
|
|
24
27
|
|
|
25
28
|
|
|
@@ -27,6 +27,7 @@ class SpanColumn(Enum):
|
|
|
27
27
|
cumulativeTokenCountTotal = auto()
|
|
28
28
|
cumulativeTokenCountPrompt = auto()
|
|
29
29
|
cumulativeTokenCountCompletion = auto()
|
|
30
|
+
tokenCostTotal = auto()
|
|
30
31
|
|
|
31
32
|
@property
|
|
32
33
|
def column_name(self) -> str:
|
|
@@ -56,6 +57,8 @@ class SpanColumn(Enum):
|
|
|
56
57
|
expr = models.Span.cumulative_llm_token_count_prompt
|
|
57
58
|
elif self is SpanColumn.cumulativeTokenCountCompletion:
|
|
58
59
|
expr = models.Span.cumulative_llm_token_count_completion
|
|
60
|
+
elif self is SpanColumn.tokenCostTotal:
|
|
61
|
+
expr = models.SpanCost.total_cost
|
|
59
62
|
else:
|
|
60
63
|
assert_never(self)
|
|
61
64
|
return expr.label(self.column_name)
|
|
@@ -73,12 +76,25 @@ class SpanColumn(Enum):
|
|
|
73
76
|
or self is SpanColumn.tokenCountTotal
|
|
74
77
|
or self is SpanColumn.tokenCountPrompt
|
|
75
78
|
or self is SpanColumn.tokenCountCompletion
|
|
79
|
+
or self is SpanColumn.tokenCostTotal
|
|
76
80
|
):
|
|
77
81
|
return CursorSortColumnDataType.FLOAT
|
|
78
82
|
if self is SpanColumn.startTime or self is SpanColumn.endTime:
|
|
79
83
|
return CursorSortColumnDataType.DATETIME
|
|
80
84
|
assert_never(self)
|
|
81
85
|
|
|
86
|
+
def join_tables(self, stmt: Select[Any]) -> Select[Any]:
    """
    Return ``stmt`` with any table joined that this sort column needs.

    Only ``tokenCostTotal`` needs one: ``SpanCost`` is joined to ``Span`` on
    ``SpanCost.span_rowid == Span.id``. Every other column returns ``stmt``
    unchanged.
    """
    if self is not SpanColumn.tokenCostTotal:
        return stmt
    # NOTE(review): this is an inner join, so spans with no SpanCost row would
    # drop out of the result when sorting by cost — confirm that is intended.
    cost_belongs_to_span = models.SpanCost.span_rowid == models.Span.id
    return stmt.join_from(
        models.Span,
        models.SpanCost,
        onclause=cost_belongs_to_span,
    )
|
|
97
|
+
|
|
82
98
|
|
|
83
99
|
@strawberry.enum
|
|
84
100
|
class EvalAttr(Enum):
|
|
@@ -140,6 +156,7 @@ class SpanSort:
|
|
|
140
156
|
def update_orm_expr(self, stmt: Select[Any]) -> SpanSortConfig:
|
|
141
157
|
if (col := self.col) and not self.eval_result_key:
|
|
142
158
|
expr = col.orm_expression
|
|
159
|
+
stmt = col.join_tables(stmt)
|
|
143
160
|
stmt = stmt.add_columns(expr)
|
|
144
161
|
if self.dir == SortDir.desc:
|
|
145
162
|
expr = desc(expr)
|
|
@@ -8,6 +8,7 @@ from phoenix.server.api.mutations.chat_mutations import (
|
|
|
8
8
|
from phoenix.server.api.mutations.dataset_mutations import DatasetMutationMixin
|
|
9
9
|
from phoenix.server.api.mutations.experiment_mutations import ExperimentMutationMixin
|
|
10
10
|
from phoenix.server.api.mutations.export_events_mutations import ExportEventsMutationMixin
|
|
11
|
+
from phoenix.server.api.mutations.model_mutations import ModelMutationMixin
|
|
11
12
|
from phoenix.server.api.mutations.project_mutations import ProjectMutationMixin
|
|
12
13
|
from phoenix.server.api.mutations.project_trace_retention_policy_mutations import (
|
|
13
14
|
ProjectTraceRetentionPolicyMutationMixin,
|
|
@@ -29,6 +30,7 @@ class Mutation(
|
|
|
29
30
|
DatasetMutationMixin,
|
|
30
31
|
ExperimentMutationMixin,
|
|
31
32
|
ExportEventsMutationMixin,
|
|
33
|
+
ModelMutationMixin,
|
|
32
34
|
ProjectMutationMixin,
|
|
33
35
|
ProjectTraceRetentionPolicyMutationMixin,
|
|
34
36
|
PromptMutationMixin,
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import asyncio
|
|
2
|
+
import logging
|
|
2
3
|
from dataclasses import asdict, field
|
|
3
4
|
from datetime import datetime, timezone
|
|
4
5
|
from itertools import chain, islice
|
|
@@ -73,6 +74,8 @@ from phoenix.utilities.template_formatters import (
|
|
|
73
74
|
TemplateFormatter,
|
|
74
75
|
)
|
|
75
76
|
|
|
77
|
+
logger = logging.getLogger(__name__)
|
|
78
|
+
|
|
76
79
|
initialize_playground_clients()
|
|
77
80
|
|
|
78
81
|
ChatCompletionMessage = tuple[ChatCompletionMessageRole, str, Optional[str], Optional[List[Any]]]
|
|
@@ -450,6 +453,19 @@ class ChatCompletionMutationMixin:
|
|
|
450
453
|
session.add(trace)
|
|
451
454
|
session.add(span)
|
|
452
455
|
await session.flush()
|
|
456
|
+
try:
|
|
457
|
+
span_cost = info.context.span_cost_calculator.calculate_cost(
|
|
458
|
+
start_time=span.start_time,
|
|
459
|
+
attributes=span.attributes,
|
|
460
|
+
)
|
|
461
|
+
except Exception as e:
|
|
462
|
+
logger.exception(f"Failed to calculate cost for span {span.id}: {e}")
|
|
463
|
+
span_cost = None
|
|
464
|
+
if span_cost:
|
|
465
|
+
span_cost.span_rowid = span.id
|
|
466
|
+
span_cost.trace_rowid = trace.id
|
|
467
|
+
session.add(span_cost)
|
|
468
|
+
await session.flush()
|
|
453
469
|
|
|
454
470
|
gql_span = Span(span_rowid=span.id, db_span=span)
|
|
455
471
|
|
|
@@ -605,5 +621,6 @@ TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUME
|
|
|
605
621
|
TOOL_JSON_SCHEMA = ToolAttributes.TOOL_JSON_SCHEMA
|
|
606
622
|
PROMPT_TEMPLATE_VARIABLES = SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES
|
|
607
623
|
|
|
624
|
+
LLM_PROVIDER = SpanAttributes.LLM_PROVIDER
|
|
608
625
|
|
|
609
626
|
PLAYGROUND_PROJECT_NAME = "playground"
|
|
@@ -0,0 +1,208 @@
|
|
|
1
|
+
import re
|
|
2
|
+
from datetime import datetime, timezone
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
import sqlalchemy as sa
|
|
6
|
+
import strawberry
|
|
7
|
+
from sqlalchemy import delete
|
|
8
|
+
from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
|
|
9
|
+
from sqlalchemy.orm import joinedload
|
|
10
|
+
from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore[import-untyped]
|
|
11
|
+
from strawberry.relay import GlobalID
|
|
12
|
+
from strawberry.types import Info
|
|
13
|
+
|
|
14
|
+
from phoenix.db import models
|
|
15
|
+
from phoenix.server.api.auth import IsNotReadOnly
|
|
16
|
+
from phoenix.server.api.context import Context
|
|
17
|
+
from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
|
|
18
|
+
from phoenix.server.api.queries import Query
|
|
19
|
+
from phoenix.server.api.types.GenerativeModel import GenerativeModel, to_gql_generative_model
|
|
20
|
+
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
21
|
+
from phoenix.server.api.types.TokenPrice import TokenKind
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@strawberry.input
class TokenPriceInput:
    # One price entry of a generative model: the token type it applies to, the
    # price per million tokens, and whether it is a prompt or completion price.
    token_type: str
    cost_per_million_tokens: float
    kind: TokenKind

    @property
    def token_prices(self) -> models.TokenPrice:
        """
        Build the single ``models.TokenPrice`` row described by this input.

        ``base_rate`` is the per-token price (``cost_per_million_tokens``
        divided by one million); ``is_prompt`` is true when ``kind`` is
        ``TokenKind.PROMPT``.
        """
        per_token_rate = self.cost_per_million_tokens / 1_000_000
        counts_as_prompt = self.kind == TokenKind.PROMPT
        return models.TokenPrice(
            token_type=self.token_type,
            is_prompt=counts_as_prompt,
            base_rate=per_token_rate,
        )
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@strawberry.input
class CreateModelMutationInput:
    # Arguments accepted by the create_model mutation.
    name: str
    provider: Optional[str] = None
    # Compiled as a regular expression by create_model; a pattern that does not
    # compile is rejected with BadRequest.
    name_pattern: str
    # create_model requires at least one entry with token_type "input" and one
    # with token_type "output".
    costs: list[TokenPriceInput]
    start_time: Optional[datetime] = None
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@strawberry.type
class CreateModelMutationPayload:
    # Result of create_model.
    # The newly created model, converted with to_gql_generative_model.
    model: GenerativeModel
    # A fresh Query root object.
    query: Query
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@strawberry.input
class UpdateModelMutationInput:
    # Arguments accepted by the update_model mutation. All listed values
    # overwrite the stored model; the existing token prices are replaced.
    # Global ID of the model; it must resolve to the GenerativeModel node type.
    id: GlobalID
    name: str
    # Declared without a default here, unlike the create input.
    # A falsy value is stored as "" by update_model.
    provider: Optional[str]
    # Compiled as a regular expression by update_model; a pattern that does not
    # compile is rejected with BadRequest.
    name_pattern: str
    # update_model requires at least one entry with token_type "input" and one
    # with token_type "output".
    costs: list[TokenPriceInput]
    start_time: Optional[datetime] = None
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
@strawberry.type
class UpdateModelMutationPayload:
    # Result of update_model.
    # The model after the update, converted with to_gql_generative_model.
    model: GenerativeModel
    # A fresh Query root object.
    query: Query
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@strawberry.input
class DeleteModelMutationInput:
    # Arguments accepted by the delete_model mutation.
    # Global ID of the model; it must resolve to the GenerativeModel node type.
    id: GlobalID
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
@strawberry.type
class DeleteModelMutationPayload:
    # Result of delete_model.
    # The model row returned by the soft-delete statement, converted with
    # to_gql_generative_model.
    model: GenerativeModel
    # A fresh Query root object.
    query: Query
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
@strawberry.type
class ModelMutationMixin:
    # GraphQL mutations that create, update and soft-delete generative models
    # together with their token prices. Built-in models can be neither updated
    # nor deleted through these mutations.

    @strawberry.mutation(permission_classes=[IsNotReadOnly])  # type: ignore
    async def create_model(
        self,
        info: Info[Context, None],
        input: CreateModelMutationInput,
    ) -> CreateModelMutationPayload:
        # Creates a non-built-in generative model.
        # Raises BadRequest when the "input" or "output" cost is missing or the
        # name pattern is not a valid regex, and Conflict when the flush hits
        # an integrity error.
        cost_types = set(cost.token_type for cost in input.costs)
        if "input" not in cost_types:
            raise BadRequest("input cost is required")
        if "output" not in cost_types:
            raise BadRequest("output cost is required")
        name_pattern = _compile_regular_expression(input.name_pattern)
        token_prices = [cost.token_prices for cost in input.costs]
        model = models.GenerativeModel(
            name=input.name,
            # Store a missing provider as "" exactly as update_model does, so a
            # model created without a provider and one updated to have none end
            # up with the same stored value instead of None in one case.
            provider=input.provider or "",
            name_pattern=name_pattern,
            is_built_in=False,
            token_prices=token_prices,
            start_time=input.start_time,
        )
        async with info.context.db() as session:
            session.add(model)
            try:
                await session.flush()
            except (PostgreSQLIntegrityError, SQLiteIntegrityError):
                raise Conflict(f"Model with name '{input.name}' already exists")

        return CreateModelMutationPayload(
            model=to_gql_generative_model(model),
            query=Query(),
        )

    @strawberry.mutation(permission_classes=[IsNotReadOnly])  # type: ignore
    async def update_model(
        self,
        info: Info[Context, None],
        input: UpdateModelMutationInput,
    ) -> UpdateModelMutationPayload:
        # Overwrites a non-deleted, non-built-in model and replaces all of its
        # token prices.
        # Raises BadRequest for an invalid id, missing "input"/"output" cost,
        # invalid regex or a built-in model; NotFound when no non-deleted model
        # has the id; Conflict when the flush hits an integrity error.
        try:
            model_id = from_global_id_with_expected_type(input.id, GenerativeModel.__name__)
        except ValueError:
            raise BadRequest(f'Invalid model id: "{input.id}"')

        cost_types = set(cost.token_type for cost in input.costs)
        if "input" not in cost_types:
            raise BadRequest("input cost is required")
        if "output" not in cost_types:
            raise BadRequest("output cost is required")
        name_pattern = _compile_regular_expression(input.name_pattern)
        token_prices = [cost.token_prices for cost in input.costs]
        async with info.context.db() as session:
            model = await session.scalar(
                sa.select(models.GenerativeModel)
                .where(models.GenerativeModel.deleted_at.is_(None))
                .where(models.GenerativeModel.id == model_id)
                .options(joinedload(models.GenerativeModel.token_prices))
            )
            if model is None:
                raise NotFound(f'Model "{input.id}" not found')
            if model.is_built_in:
                raise BadRequest("Cannot update built-in model")

            # Remove the stored prices with a bulk DELETE, then reload the
            # model — presumably so the already-loaded token_prices collection
            # does not still hold the deleted rows.
            await session.execute(
                delete(models.TokenPrice).where(models.TokenPrice.model_id == model.id)
            )

            await session.refresh(model)

            model.name = input.name
            model.provider = input.provider or ""
            model.name_pattern = name_pattern
            model.token_prices = token_prices
            model.start_time = input.start_time
            session.add(model)
            try:
                await session.flush()
            except (PostgreSQLIntegrityError, SQLiteIntegrityError):
                raise Conflict(f"Model with name '{input.name}' already exists")
            await session.refresh(model)

        return UpdateModelMutationPayload(
            model=to_gql_generative_model(model),
            query=Query(),
        )

    @strawberry.mutation(permission_classes=[IsNotReadOnly])  # type: ignore
    async def delete_model(
        self,
        info: Info[Context, None],
        input: DeleteModelMutationInput,
    ) -> DeleteModelMutationPayload:
        # Soft-deletes a model by setting deleted_at to the current UTC time.
        # Raises BadRequest for an invalid id or a built-in model, and NotFound
        # when no non-deleted model has the id.
        try:
            model_id = from_global_id_with_expected_type(input.id, GenerativeModel.__name__)
        except ValueError:
            raise BadRequest(f'Invalid model id: "{input.id}"')

        async with info.context.db() as session:
            model = await session.scalar(
                sa.update(models.GenerativeModel)
                .values(deleted_at=datetime.now(timezone.utc))
                .where(models.GenerativeModel.deleted_at.is_(None))
                .where(models.GenerativeModel.id == model_id)
                .returning(models.GenerativeModel)
            )
            if model is None:
                raise NotFound(f'Model "{input.id}" not found')
            if model.is_built_in:
                # The UPDATE above has already marked the row deleted before
                # is_built_in could be inspected, so undo it explicitly.
                await session.rollback()
                raise BadRequest("Cannot delete built-in model")
        return DeleteModelMutationPayload(
            model=to_gql_generative_model(model),
            query=Query(),
        )
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def _compile_regular_expression(maybe_regex: str) -> re.Pattern[str]:
|
|
201
|
+
"""
|
|
202
|
+
Compile the given string as a regular expression.
|
|
203
|
+
Raises a BadRequest error if the given string is not a valid regex.
|
|
204
|
+
"""
|
|
205
|
+
try:
|
|
206
|
+
return re.compile(maybe_regex)
|
|
207
|
+
except re.error as error:
|
|
208
|
+
raise BadRequest(f"Invalid regex: {str(error)}")
|