arize-phoenix 10.14.0__py3-none-any.whl → 11.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (84)
  1. {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/METADATA +3 -2
  2. {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/RECORD +82 -50
  3. phoenix/config.py +5 -2
  4. phoenix/datetime_utils.py +8 -1
  5. phoenix/db/bulk_inserter.py +40 -1
  6. phoenix/db/facilitator.py +263 -4
  7. phoenix/db/insertion/helpers.py +15 -0
  8. phoenix/db/insertion/span.py +3 -1
  9. phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
  10. phoenix/db/models.py +267 -9
  11. phoenix/db/types/model_provider.py +1 -0
  12. phoenix/db/types/token_price_customization.py +29 -0
  13. phoenix/server/api/context.py +38 -4
  14. phoenix/server/api/dataloaders/__init__.py +41 -5
  15. phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
  16. phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
  17. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
  18. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
  19. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
  20. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
  21. phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
  22. phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +58 -0
  23. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
  24. phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
  25. phoenix/server/api/dataloaders/span_cost_summary_by_project.py +140 -0
  26. phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
  27. phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
  28. phoenix/server/api/dataloaders/span_costs.py +35 -0
  29. phoenix/server/api/dataloaders/types.py +29 -0
  30. phoenix/server/api/helpers/playground_clients.py +562 -12
  31. phoenix/server/api/helpers/prompts/conversions/aws.py +83 -0
  32. phoenix/server/api/helpers/prompts/models.py +67 -0
  33. phoenix/server/api/input_types/GenerativeModelInput.py +2 -0
  34. phoenix/server/api/input_types/ProjectSessionSort.py +3 -0
  35. phoenix/server/api/input_types/SpanSort.py +17 -0
  36. phoenix/server/api/mutations/__init__.py +2 -0
  37. phoenix/server/api/mutations/chat_mutations.py +17 -0
  38. phoenix/server/api/mutations/model_mutations.py +208 -0
  39. phoenix/server/api/queries.py +82 -41
  40. phoenix/server/api/routers/v1/traces.py +11 -4
  41. phoenix/server/api/subscriptions.py +36 -2
  42. phoenix/server/api/types/CostBreakdown.py +15 -0
  43. phoenix/server/api/types/Experiment.py +59 -1
  44. phoenix/server/api/types/ExperimentRun.py +58 -4
  45. phoenix/server/api/types/GenerativeModel.py +143 -2
  46. phoenix/server/api/types/GenerativeProvider.py +33 -20
  47. phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
  48. phoenix/server/api/types/ModelInterface.py +11 -0
  49. phoenix/server/api/types/PlaygroundModel.py +10 -0
  50. phoenix/server/api/types/Project.py +42 -0
  51. phoenix/server/api/types/ProjectSession.py +44 -0
  52. phoenix/server/api/types/Span.py +137 -0
  53. phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
  54. phoenix/server/api/types/SpanCostSummary.py +10 -0
  55. phoenix/server/api/types/TokenPrice.py +16 -0
  56. phoenix/server/api/types/TokenUsage.py +3 -3
  57. phoenix/server/api/types/Trace.py +41 -0
  58. phoenix/server/app.py +59 -0
  59. phoenix/server/cost_tracking/cost_details_calculator.py +190 -0
  60. phoenix/server/cost_tracking/cost_model_lookup.py +151 -0
  61. phoenix/server/cost_tracking/helpers.py +68 -0
  62. phoenix/server/cost_tracking/model_cost_manifest.json +59 -329
  63. phoenix/server/cost_tracking/regex_specificity.py +397 -0
  64. phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
  65. phoenix/server/daemons/__init__.py +0 -0
  66. phoenix/server/daemons/generative_model_store.py +51 -0
  67. phoenix/server/daemons/span_cost_calculator.py +103 -0
  68. phoenix/server/dml_event_handler.py +1 -0
  69. phoenix/server/static/.vite/manifest.json +36 -36
  70. phoenix/server/static/assets/components-BnK9kodr.js +5055 -0
  71. phoenix/server/static/assets/{index-qiubV_74.js → index-S3YKLmbo.js} +13 -13
  72. phoenix/server/static/assets/{pages-C4V07ozl.js → pages-BW6PBHZb.js} +809 -417
  73. phoenix/server/static/assets/{vendor-Bfsiga8H.js → vendor-DqQvHbPa.js} +147 -147
  74. phoenix/server/static/assets/{vendor-arizeai-CQOWsrzm.js → vendor-arizeai-CLX44PFA.js} +1 -1
  75. phoenix/server/static/assets/{vendor-codemirror-CrcGVhB2.js → vendor-codemirror-Du3XyJnB.js} +1 -1
  76. phoenix/server/static/assets/{vendor-recharts-Yyg3G-Rq.js → vendor-recharts-B2PJDrnX.js} +25 -25
  77. phoenix/server/static/assets/{vendor-shiki-OPjag7Hm.js → vendor-shiki-CNbrFjf9.js} +1 -1
  78. phoenix/version.py +1 -1
  79. phoenix/server/cost_tracking/cost_lookup.py +0 -255
  80. phoenix/server/static/assets/components-CUUWyAMo.js +0 -4509
  81. {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/WHEEL +0 -0
  82. {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/entry_points.txt +0 -0
  83. {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/licenses/IP_NOTICE +0 -0
  84. {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,83 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Optional, Union
4
+
5
+ from typing_extensions import assert_never
6
+
7
+ if TYPE_CHECKING:
8
+ from anthropic.types import (
9
+ ToolChoiceAnyParam,
10
+ ToolChoiceAutoParam,
11
+ ToolChoiceParam,
12
+ ToolChoiceToolParam,
13
+ )
14
+
15
+ from phoenix.server.api.helpers.prompts.models import (
16
+ PromptToolChoiceNone,
17
+ PromptToolChoiceOneOrMore,
18
+ PromptToolChoiceSpecificFunctionTool,
19
+ PromptToolChoiceZeroOrMore,
20
+ )
21
+
22
+
23
+ class AwsToolChoiceConversion:
24
+ @staticmethod
25
+ def to_aws(
26
+ obj: Union[
27
+ PromptToolChoiceNone,
28
+ PromptToolChoiceZeroOrMore,
29
+ PromptToolChoiceOneOrMore,
30
+ PromptToolChoiceSpecificFunctionTool,
31
+ ],
32
+ disable_parallel_tool_use: Optional[bool] = None,
33
+ ) -> ToolChoiceParam:
34
+ if obj.type == "zero_or_more":
35
+ choice_auto: ToolChoiceAutoParam = {"type": "auto"}
36
+ if disable_parallel_tool_use is not None:
37
+ choice_auto["disable_parallel_tool_use"] = disable_parallel_tool_use
38
+ return choice_auto
39
+ if obj.type == "one_or_more":
40
+ choice_any: ToolChoiceAnyParam = {"type": "any"}
41
+ if disable_parallel_tool_use is not None:
42
+ choice_any["disable_parallel_tool_use"] = disable_parallel_tool_use
43
+ return choice_any
44
+ if obj.type == "specific_function":
45
+ choice_tool: ToolChoiceToolParam = {"type": "tool", "name": obj.function_name}
46
+ if disable_parallel_tool_use is not None:
47
+ choice_tool["disable_parallel_tool_use"] = disable_parallel_tool_use
48
+ return choice_tool
49
+ if obj.type == "none":
50
+ return {"type": "none"}
51
+ assert_never(obj.type)
52
+
53
+ @staticmethod
54
+ def from_aws(
55
+ obj: ToolChoiceParam,
56
+ ) -> Union[
57
+ PromptToolChoiceNone,
58
+ PromptToolChoiceZeroOrMore,
59
+ PromptToolChoiceOneOrMore,
60
+ PromptToolChoiceSpecificFunctionTool,
61
+ ]:
62
+ from phoenix.server.api.helpers.prompts.models import (
63
+ PromptToolChoiceNone,
64
+ PromptToolChoiceOneOrMore,
65
+ PromptToolChoiceSpecificFunctionTool,
66
+ PromptToolChoiceZeroOrMore,
67
+ )
68
+
69
+ if obj["type"] == "auto":
70
+ choice_zero_or_more = PromptToolChoiceZeroOrMore(type="zero_or_more")
71
+ return choice_zero_or_more
72
+ if obj["type"] == "any":
73
+ choice_one_or_more = PromptToolChoiceOneOrMore(type="one_or_more")
74
+ return choice_one_or_more
75
+ if obj["type"] == "tool":
76
+ choice_function_tool = PromptToolChoiceSpecificFunctionTool(
77
+ type="specific_function",
78
+ function_name=obj["name"],
79
+ )
80
+ return choice_function_tool
81
+ if obj["type"] == "none":
82
+ return PromptToolChoiceNone(type="none")
83
+ assert_never(obj)
@@ -9,6 +9,7 @@ from typing_extensions import Annotated, Self, TypeAlias, TypeGuard, assert_neve
9
9
  from phoenix.db.types.db_models import UNDEFINED, DBBaseModel
10
10
  from phoenix.db.types.model_provider import ModelProvider
11
11
  from phoenix.server.api.helpers.prompts.conversions.anthropic import AnthropicToolChoiceConversion
12
+ from phoenix.server.api.helpers.prompts.conversions.aws import AwsToolChoiceConversion
12
13
  from phoenix.server.api.helpers.prompts.conversions.openai import OpenAIToolChoiceConversion
13
14
 
14
15
  JSONSerializable = Union[None, bool, int, float, str, dict[str, Any], list[Any]]
@@ -312,6 +313,14 @@ class AnthropicToolDefinition(DBBaseModel):
312
313
  description: str = UNDEFINED
313
314
 
314
315
 
316
class BedrockToolDefinition(DBBaseModel):
    """
    A tool definition wrapped in a Bedrock-style ``toolSpec`` envelope.

    Based on https://github.com/aws/amazon-bedrock-sdk-python/blob/main/src/bedrock/types/tool_param.py#L12
    """

    # Typed as a free-form dict. The converters in this module read and write
    # the keys "name", "description" and "inputSchema" -> "json".
    toolSpec: dict[str, Any]
322
+
323
+
315
324
  class PromptOpenAIInvocationParametersContent(DBBaseModel):
316
325
  temperature: float = UNDEFINED
317
326
  max_tokens: int = UNDEFINED
@@ -397,6 +406,17 @@ class PromptAnthropicInvocationParameters(DBBaseModel):
397
406
  anthropic: PromptAnthropicInvocationParametersContent
398
407
 
399
408
 
409
class PromptAwsInvocationParametersContent(DBBaseModel):
    """Invocation parameters stored for AWS prompts; every field defaults to ``UNDEFINED``."""

    max_tokens: int = UNDEFINED
    temperature: float = UNDEFINED
    top_p: float = UNDEFINED
413
+
414
+
415
class PromptAwsInvocationParameters(DBBaseModel):
    """AWS member of the ``PromptInvocationParameters`` discriminated union."""

    # Discriminator value; the union is discriminated on the "type" field.
    type: Literal["aws"]
    # Provider-specific parameter values, nested under the provider key.
    aws: PromptAwsInvocationParametersContent
418
+
419
+
400
420
  class PromptGoogleInvocationParametersContent(DBBaseModel):
401
421
  temperature: float = UNDEFINED
402
422
  max_output_tokens: int = UNDEFINED
@@ -421,6 +441,7 @@ PromptInvocationParameters: TypeAlias = Annotated[
421
441
  PromptDeepSeekInvocationParameters,
422
442
  PromptXAIInvocationParameters,
423
443
  PromptOllamaInvocationParameters,
444
+ PromptAwsInvocationParameters,
424
445
  ],
425
446
  Field(..., discriminator="type"),
426
447
  ]
@@ -443,6 +464,8 @@ def get_raw_invocation_parameters(
443
464
  return invocation_parameters.xai.model_dump()
444
465
  if isinstance(invocation_parameters, PromptOllamaInvocationParameters):
445
466
  return invocation_parameters.ollama.model_dump()
467
+ if isinstance(invocation_parameters, PromptAwsInvocationParameters):
468
+ return invocation_parameters.aws.model_dump()
446
469
  assert_never(invocation_parameters)
447
470
 
448
471
 
@@ -459,6 +482,7 @@ def is_prompt_invocation_parameters(
459
482
  PromptDeepSeekInvocationParameters,
460
483
  PromptXAIInvocationParameters,
461
484
  PromptOllamaInvocationParameters,
485
+ PromptAwsInvocationParameters,
462
486
  ),
463
487
  )
464
488
 
@@ -512,6 +536,11 @@ def validate_invocation_parameters(
512
536
  type="ollama",
513
537
  ollama=PromptOllamaInvocationParametersContent.model_validate(invocation_parameters),
514
538
  )
539
+ elif model_provider is ModelProvider.AWS:
540
+ return PromptAwsInvocationParameters(
541
+ type="aws",
542
+ aws=PromptAwsInvocationParametersContent.model_validate(invocation_parameters),
543
+ )
515
544
  assert_never(model_provider)
516
545
 
517
546
 
@@ -530,12 +559,16 @@ def normalize_tools(
530
559
  ):
531
560
  openai_tools = [OpenAIToolDefinition.model_validate(schema) for schema in schemas]
532
561
  tools = [_openai_to_prompt_tool(openai_tool) for openai_tool in openai_tools]
562
+ elif model_provider is ModelProvider.AWS:
563
+ bedrock_tools = [BedrockToolDefinition.model_validate(schema) for schema in schemas]
564
+ tools = [_bedrock_to_prompt_tool(bedrock_tool) for bedrock_tool in bedrock_tools]
533
565
  elif model_provider is ModelProvider.ANTHROPIC:
534
566
  anthropic_tools = [AnthropicToolDefinition.model_validate(schema) for schema in schemas]
535
567
  tools = [_anthropic_to_prompt_tool(anthropic_tool) for anthropic_tool in anthropic_tools]
536
568
  else:
537
569
  raise ValueError(f"Unsupported model provider: {model_provider}")
538
570
  ans = PromptTools(type="tools", tools=tools)
571
+
539
572
  if tool_choice is not None:
540
573
  if (
541
574
  model_provider is ModelProvider.OPENAI
@@ -545,6 +578,8 @@ def normalize_tools(
545
578
  or model_provider is ModelProvider.OLLAMA
546
579
  ):
547
580
  ans.tool_choice = OpenAIToolChoiceConversion.from_openai(tool_choice) # type: ignore[arg-type]
581
+ elif model_provider is ModelProvider.AWS:
582
+ ans.tool_choice = AwsToolChoiceConversion.from_aws(tool_choice) # type: ignore[arg-type]
548
583
  elif model_provider is ModelProvider.ANTHROPIC:
549
584
  choice, disable_parallel_tool_calls = AnthropicToolChoiceConversion.from_anthropic(
550
585
  tool_choice # type: ignore[arg-type]
@@ -571,6 +606,10 @@ def denormalize_tools(
571
606
  denormalized_tools = [_prompt_to_openai_tool(tool) for tool in tools.tools]
572
607
  if tools.tool_choice:
573
608
  tool_choice = OpenAIToolChoiceConversion.to_openai(tools.tool_choice)
609
+ elif model_provider is ModelProvider.AWS:
610
+ denormalized_tools = [_prompt_to_bedrock_tool(tool) for tool in tools.tools]
611
+ if tools.tool_choice:
612
+ tool_choice = OpenAIToolChoiceConversion.to_openai(tools.tool_choice)
574
613
  elif model_provider is ModelProvider.ANTHROPIC:
575
614
  denormalized_tools = [_prompt_to_anthropic_tool(tool) for tool in tools.tools]
576
615
  if tools.tool_choice and tools.tool_choice.type != "none":
@@ -614,6 +653,19 @@ def _prompt_to_openai_tool(
614
653
  )
615
654
 
616
655
 
656
def _bedrock_to_prompt_tool(
    tool: BedrockToolDefinition,
) -> PromptToolFunction:
    """
    Convert a Bedrock ``toolSpec`` tool definition into a normalized prompt tool.

    ``name`` and ``inputSchema["json"]`` are read as required keys, so a spec
    missing either raises ``KeyError``. ``description`` may be omitted from a
    tool spec; its absence is mapped to ``UNDEFINED`` instead of raising.
    """
    spec = tool.toolSpec
    return PromptToolFunction(
        type="function",
        function=PromptToolFunctionDefinition(
            name=spec["name"],
            # NOTE(review): assumes PromptToolFunctionDefinition.description
            # accepts UNDEFINED, as AnthropicToolDefinition.description does.
            description=spec.get("description", UNDEFINED),
            parameters=spec["inputSchema"]["json"],
        ),
    )
667
+
668
+
617
669
  def _anthropic_to_prompt_tool(
618
670
  tool: AnthropicToolDefinition,
619
671
  ) -> PromptToolFunction:
@@ -636,3 +688,18 @@ def _prompt_to_anthropic_tool(
636
688
  name=function.name,
637
689
  description=function.description,
638
690
  )
691
+
692
+
693
def _prompt_to_bedrock_tool(
    tool: PromptToolFunction,
) -> BedrockToolDefinition:
    """
    Wrap a normalized prompt tool's function definition in a ``toolSpec``
    envelope: name and description at the top level, parameters under
    ``inputSchema["json"]``.
    """
    definition = tool.function
    tool_spec: dict[str, Any] = {
        "name": definition.name,
        "description": definition.description,
    }
    tool_spec["inputSchema"] = {"json": definition.parameters}
    return BedrockToolDefinition(toolSpec=tool_spec)
@@ -17,3 +17,5 @@ class GenerativeModelInput:
17
17
  """ The endpoint to use for the model. Only required for Azure OpenAI models. """
18
18
  api_version: Optional[str] = UNSET
19
19
  """ The API version to use for the model. """
20
+ region: Optional[str] = UNSET
21
+ """ The region to use for the model. """
@@ -13,6 +13,7 @@ class ProjectSessionColumn(Enum):
13
13
  endTime = auto()
14
14
  tokenCountTotal = auto()
15
15
  numTraces = auto()
16
+ costTotal = auto()
16
17
 
17
18
  @property
18
19
  def data_type(self) -> CursorSortColumnDataType:
@@ -20,6 +21,8 @@ class ProjectSessionColumn(Enum):
20
21
  return CursorSortColumnDataType.INT
21
22
  if self is ProjectSessionColumn.startTime or self is ProjectSessionColumn.endTime:
22
23
  return CursorSortColumnDataType.DATETIME
24
+ if self is ProjectSessionColumn.costTotal:
25
+ return CursorSortColumnDataType.FLOAT
23
26
  assert_never(self)
24
27
 
25
28
 
@@ -27,6 +27,7 @@ class SpanColumn(Enum):
27
27
  cumulativeTokenCountTotal = auto()
28
28
  cumulativeTokenCountPrompt = auto()
29
29
  cumulativeTokenCountCompletion = auto()
30
+ tokenCostTotal = auto()
30
31
 
31
32
  @property
32
33
  def column_name(self) -> str:
@@ -56,6 +57,8 @@ class SpanColumn(Enum):
56
57
  expr = models.Span.cumulative_llm_token_count_prompt
57
58
  elif self is SpanColumn.cumulativeTokenCountCompletion:
58
59
  expr = models.Span.cumulative_llm_token_count_completion
60
+ elif self is SpanColumn.tokenCostTotal:
61
+ expr = models.SpanCost.total_cost
59
62
  else:
60
63
  assert_never(self)
61
64
  return expr.label(self.column_name)
@@ -73,12 +76,25 @@ class SpanColumn(Enum):
73
76
  or self is SpanColumn.tokenCountTotal
74
77
  or self is SpanColumn.tokenCountPrompt
75
78
  or self is SpanColumn.tokenCountCompletion
79
+ or self is SpanColumn.tokenCostTotal
76
80
  ):
77
81
  return CursorSortColumnDataType.FLOAT
78
82
  if self is SpanColumn.startTime or self is SpanColumn.endTime:
79
83
  return CursorSortColumnDataType.DATETIME
80
84
  assert_never(self)
81
85
 
86
def join_tables(self, stmt: Select[Any]) -> Select[Any]:
    """
    Add to ``stmt`` whatever join the sort column requires.

    Only ``tokenCostTotal`` needs one: ``SpanCost`` is joined to ``Span`` on
    ``SpanCost.span_rowid == Span.id``. For every other column ``stmt`` is
    returned unchanged.
    """
    if self is not SpanColumn.tokenCostTotal:
        return stmt
    # NOTE(review): join_from defaults to an inner join, so spans without a
    # SpanCost row drop out of the result when sorting by cost -- confirm
    # that this is intended.
    cost_join = models.SpanCost.span_rowid == models.Span.id
    return stmt.join_from(models.Span, models.SpanCost, onclause=cost_join)
97
+
82
98
 
83
99
  @strawberry.enum
84
100
  class EvalAttr(Enum):
@@ -140,6 +156,7 @@ class SpanSort:
140
156
  def update_orm_expr(self, stmt: Select[Any]) -> SpanSortConfig:
141
157
  if (col := self.col) and not self.eval_result_key:
142
158
  expr = col.orm_expression
159
+ stmt = col.join_tables(stmt)
143
160
  stmt = stmt.add_columns(expr)
144
161
  if self.dir == SortDir.desc:
145
162
  expr = desc(expr)
@@ -8,6 +8,7 @@ from phoenix.server.api.mutations.chat_mutations import (
8
8
  from phoenix.server.api.mutations.dataset_mutations import DatasetMutationMixin
9
9
  from phoenix.server.api.mutations.experiment_mutations import ExperimentMutationMixin
10
10
  from phoenix.server.api.mutations.export_events_mutations import ExportEventsMutationMixin
11
+ from phoenix.server.api.mutations.model_mutations import ModelMutationMixin
11
12
  from phoenix.server.api.mutations.project_mutations import ProjectMutationMixin
12
13
  from phoenix.server.api.mutations.project_trace_retention_policy_mutations import (
13
14
  ProjectTraceRetentionPolicyMutationMixin,
@@ -29,6 +30,7 @@ class Mutation(
29
30
  DatasetMutationMixin,
30
31
  ExperimentMutationMixin,
31
32
  ExportEventsMutationMixin,
33
+ ModelMutationMixin,
32
34
  ProjectMutationMixin,
33
35
  ProjectTraceRetentionPolicyMutationMixin,
34
36
  PromptMutationMixin,
@@ -1,4 +1,5 @@
1
1
  import asyncio
2
+ import logging
2
3
  from dataclasses import asdict, field
3
4
  from datetime import datetime, timezone
4
5
  from itertools import chain, islice
@@ -73,6 +74,8 @@ from phoenix.utilities.template_formatters import (
73
74
  TemplateFormatter,
74
75
  )
75
76
 
77
+ logger = logging.getLogger(__name__)
78
+
76
79
  initialize_playground_clients()
77
80
 
78
81
  ChatCompletionMessage = tuple[ChatCompletionMessageRole, str, Optional[str], Optional[List[Any]]]
@@ -450,6 +453,19 @@ class ChatCompletionMutationMixin:
450
453
  session.add(trace)
451
454
  session.add(span)
452
455
  await session.flush()
456
+ try:
457
+ span_cost = info.context.span_cost_calculator.calculate_cost(
458
+ start_time=span.start_time,
459
+ attributes=span.attributes,
460
+ )
461
+ except Exception as e:
462
+ logger.exception(f"Failed to calculate cost for span {span.id}: {e}")
463
+ span_cost = None
464
+ if span_cost:
465
+ span_cost.span_rowid = span.id
466
+ span_cost.trace_rowid = trace.id
467
+ session.add(span_cost)
468
+ await session.flush()
453
469
 
454
470
  gql_span = Span(span_rowid=span.id, db_span=span)
455
471
 
@@ -605,5 +621,6 @@ TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUME
605
621
  TOOL_JSON_SCHEMA = ToolAttributes.TOOL_JSON_SCHEMA
606
622
  PROMPT_TEMPLATE_VARIABLES = SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES
607
623
 
624
+ LLM_PROVIDER = SpanAttributes.LLM_PROVIDER
608
625
 
609
626
  PLAYGROUND_PROJECT_NAME = "playground"
@@ -0,0 +1,208 @@
1
+ import re
2
+ from datetime import datetime, timezone
3
+ from typing import Optional
4
+
5
+ import sqlalchemy as sa
6
+ import strawberry
7
+ from sqlalchemy import delete
8
+ from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
9
+ from sqlalchemy.orm import joinedload
10
+ from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore[import-untyped]
11
+ from strawberry.relay import GlobalID
12
+ from strawberry.types import Info
13
+
14
+ from phoenix.db import models
15
+ from phoenix.server.api.auth import IsNotReadOnly
16
+ from phoenix.server.api.context import Context
17
+ from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
18
+ from phoenix.server.api.queries import Query
19
+ from phoenix.server.api.types.GenerativeModel import GenerativeModel, to_gql_generative_model
20
+ from phoenix.server.api.types.node import from_global_id_with_expected_type
21
+ from phoenix.server.api.types.TokenPrice import TokenKind
22
+
23
+
24
@strawberry.input
class TokenPriceInput:
    # GraphQL input describing the price of one token type, quoted per
    # million tokens.
    token_type: str
    cost_per_million_tokens: float
    kind: TokenKind

    @property
    def token_prices(self) -> models.TokenPrice:
        """
        Build the ``models.TokenPrice`` for this input.

        Despite the plural name, exactly one instance is returned. The
        per-million price is divided by 1,000,000 to obtain ``base_rate``,
        and ``is_prompt`` is true when ``kind`` is ``TokenKind.PROMPT``.
        """
        counts_as_prompt = self.kind == TokenKind.PROMPT
        per_token_rate = self.cost_per_million_tokens / 1_000_000
        return models.TokenPrice(
            token_type=self.token_type,
            is_prompt=counts_as_prompt,
            base_rate=per_token_rate,
        )
38
+
39
+
40
@strawberry.input
class CreateModelMutationInput:
    # Input for the create_model mutation.
    name: str
    provider: Optional[str] = None
    # Must compile as a regular expression; create_model rejects it with
    # BadRequest otherwise.
    name_pattern: str
    # create_model requires entries with token_type "input" and "output".
    costs: list[TokenPriceInput]
    start_time: Optional[datetime] = None
47
+
48
+
49
@strawberry.type
class CreateModelMutationPayload:
    # Result of create_model: the newly created model plus a root Query object.
    model: GenerativeModel
    query: Query
53
+
54
+
55
@strawberry.input
class UpdateModelMutationInput:
    # Input for the update_model mutation.
    # Global ID of the GenerativeModel to update.
    id: GlobalID
    name: str
    # update_model stores a missing/empty provider as "".
    provider: Optional[str]
    # Must compile as a regular expression; update_model rejects it with
    # BadRequest otherwise.
    name_pattern: str
    # Replaces the model's existing token prices; entries with token_type
    # "input" and "output" are required.
    costs: list[TokenPriceInput]
    start_time: Optional[datetime] = None
63
+
64
+
65
@strawberry.type
class UpdateModelMutationPayload:
    # Result of update_model: the updated model plus a root Query object.
    model: GenerativeModel
    query: Query
69
+
70
+
71
@strawberry.input
class DeleteModelMutationInput:
    # Input for the delete_model mutation.
    # Global ID of the GenerativeModel to delete.
    id: GlobalID
74
+
75
+
76
@strawberry.type
class DeleteModelMutationPayload:
    # Result of delete_model: the model that was marked deleted plus a root
    # Query object.
    model: GenerativeModel
    query: Query
80
+
81
+
82
@strawberry.type
class ModelMutationMixin:
    # Mutations for creating, updating and (soft-)deleting generative models.
    # Each mutation is guarded by the IsNotReadOnly permission class.

    @strawberry.mutation(permission_classes=[IsNotReadOnly])  # type: ignore
    async def create_model(
        self,
        info: Info[Context, None],
        input: CreateModelMutationInput,
    ) -> CreateModelMutationPayload:
        # Creates a non-built-in model with its token prices.
        # Raises BadRequest when the "input" or "output" cost is missing or the
        # name pattern is not a valid regex, and Conflict when the insert hits
        # an integrity error.
        cost_types = set(cost.token_type for cost in input.costs)
        if "input" not in cost_types:
            raise BadRequest("input cost is required")
        if "output" not in cost_types:
            raise BadRequest("output cost is required")
        name_pattern = _compile_regular_expression(input.name_pattern)
        token_prices = [cost.token_prices for cost in input.costs]
        model = models.GenerativeModel(
            name=input.name,
            provider=input.provider,
            name_pattern=name_pattern,
            is_built_in=False,
            token_prices=token_prices,
            start_time=input.start_time,
        )
        async with info.context.db() as session:
            session.add(model)
            try:
                # Flush here so an integrity violation surfaces inside this
                # try block and can be reported as a name conflict.
                await session.flush()
            except (PostgreSQLIntegrityError, SQLiteIntegrityError):
                raise Conflict(f"Model with name '{input.name}' already exists")

        return CreateModelMutationPayload(
            model=to_gql_generative_model(model),
            query=Query(),
        )

    @strawberry.mutation(permission_classes=[IsNotReadOnly])  # type: ignore
    async def update_model(
        self,
        info: Info[Context, None],
        input: UpdateModelMutationInput,
    ) -> UpdateModelMutationPayload:
        # Overwrites the fields of an existing, non-deleted, non-built-in model
        # and replaces all of its token prices.
        # Raises BadRequest for a malformed id, missing "input"/"output" cost,
        # invalid regex or a built-in model; NotFound when no live model has
        # the id; Conflict when the flush hits an integrity error.
        try:
            model_id = from_global_id_with_expected_type(input.id, GenerativeModel.__name__)
        except ValueError:
            raise BadRequest(f'Invalid model id: "{input.id}"')

        cost_types = set(cost.token_type for cost in input.costs)
        if "input" not in cost_types:
            raise BadRequest("input cost is required")
        if "output" not in cost_types:
            raise BadRequest("output cost is required")
        name_pattern = _compile_regular_expression(input.name_pattern)
        token_prices = [cost.token_prices for cost in input.costs]
        async with info.context.db() as session:
            # Soft-deleted models (deleted_at set) are treated as not found.
            model = await session.scalar(
                sa.select(models.GenerativeModel)
                .where(models.GenerativeModel.deleted_at.is_(None))
                .where(models.GenerativeModel.id == model_id)
                .options(joinedload(models.GenerativeModel.token_prices))
            )
            if model is None:
                raise NotFound(f'Model "{input.id}" not found')
            if model.is_built_in:
                raise BadRequest("Cannot update built-in model")

            # Remove the old price rows with a bulk DELETE ...
            await session.execute(
                delete(models.TokenPrice).where(models.TokenPrice.model_id == model.id)
            )

            # ... then reload the model before assigning the new prices
            # (presumably so its loaded token_prices reflect the delete).
            await session.refresh(model)

            model.name = input.name
            # Unlike create_model, a missing provider is stored as "".
            model.provider = input.provider or ""
            model.name_pattern = name_pattern
            model.token_prices = token_prices
            model.start_time = input.start_time
            session.add(model)
            try:
                await session.flush()
            except (PostgreSQLIntegrityError, SQLiteIntegrityError):
                raise Conflict(f"Model with name '{input.name}' already exists")
            await session.refresh(model)

        return UpdateModelMutationPayload(
            model=to_gql_generative_model(model),
            query=Query(),
        )

    @strawberry.mutation(permission_classes=[IsNotReadOnly])  # type: ignore
    async def delete_model(
        self,
        info: Info[Context, None],
        input: DeleteModelMutationInput,
    ) -> DeleteModelMutationPayload:
        # Soft-deletes a model by setting deleted_at to the current UTC time.
        # Raises BadRequest for a malformed id or a built-in model, and
        # NotFound when no live model has the id.
        try:
            model_id = from_global_id_with_expected_type(input.id, GenerativeModel.__name__)
        except ValueError:
            raise BadRequest(f'Invalid model id: "{input.id}"')

        async with info.context.db() as session:
            # One UPDATE ... RETURNING marks the row deleted and fetches it;
            # rows that are already deleted do not match.
            model = await session.scalar(
                sa.update(models.GenerativeModel)
                .values(deleted_at=datetime.now(timezone.utc))
                .where(models.GenerativeModel.deleted_at.is_(None))
                .where(models.GenerativeModel.id == model_id)
                .returning(models.GenerativeModel)
            )
            if model is None:
                raise NotFound(f'Model "{input.id}" not found')
            if model.is_built_in:
                # The UPDATE above has already been issued; undo it before
                # rejecting the request.
                await session.rollback()
                raise BadRequest("Cannot delete built-in model")
        return DeleteModelMutationPayload(
            model=to_gql_generative_model(model),
            query=Query(),
        )
198
+
199
+
200
+ def _compile_regular_expression(maybe_regex: str) -> re.Pattern[str]:
201
+ """
202
+ Compile the given string as a regular expression.
203
+ Raises a BadRequest error if the given string is not a valid regex.
204
+ """
205
+ try:
206
+ return re.compile(maybe_regex)
207
+ except re.error as error:
208
+ raise BadRequest(f"Invalid regex: {str(error)}")