arize-phoenix 5.5.2__py3-none-any.whl → 5.7.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of arize-phoenix might be problematic. More details are linked from the original page.

Files changed (186)
  1. {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.7.0.dist-info}/METADATA +4 -7
  2. arize_phoenix-5.7.0.dist-info/RECORD +330 -0
  3. phoenix/config.py +50 -8
  4. phoenix/core/model.py +3 -3
  5. phoenix/core/model_schema.py +41 -50
  6. phoenix/core/model_schema_adapter.py +17 -16
  7. phoenix/datetime_utils.py +2 -2
  8. phoenix/db/bulk_inserter.py +10 -20
  9. phoenix/db/engines.py +2 -1
  10. phoenix/db/enums.py +2 -2
  11. phoenix/db/helpers.py +8 -7
  12. phoenix/db/insertion/dataset.py +9 -19
  13. phoenix/db/insertion/document_annotation.py +14 -13
  14. phoenix/db/insertion/helpers.py +6 -16
  15. phoenix/db/insertion/span_annotation.py +14 -13
  16. phoenix/db/insertion/trace_annotation.py +14 -13
  17. phoenix/db/insertion/types.py +19 -30
  18. phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +8 -8
  19. phoenix/db/models.py +28 -28
  20. phoenix/experiments/evaluators/base.py +2 -1
  21. phoenix/experiments/evaluators/code_evaluators.py +4 -5
  22. phoenix/experiments/evaluators/llm_evaluators.py +157 -4
  23. phoenix/experiments/evaluators/utils.py +3 -2
  24. phoenix/experiments/functions.py +10 -21
  25. phoenix/experiments/tracing.py +2 -1
  26. phoenix/experiments/types.py +20 -29
  27. phoenix/experiments/utils.py +2 -1
  28. phoenix/inferences/errors.py +6 -5
  29. phoenix/inferences/fixtures.py +6 -5
  30. phoenix/inferences/inferences.py +37 -37
  31. phoenix/inferences/schema.py +11 -10
  32. phoenix/inferences/validation.py +13 -14
  33. phoenix/logging/_formatter.py +3 -3
  34. phoenix/metrics/__init__.py +5 -4
  35. phoenix/metrics/binning.py +2 -1
  36. phoenix/metrics/metrics.py +2 -1
  37. phoenix/metrics/mixins.py +7 -6
  38. phoenix/metrics/retrieval_metrics.py +2 -1
  39. phoenix/metrics/timeseries.py +5 -4
  40. phoenix/metrics/wrappers.py +2 -2
  41. phoenix/pointcloud/clustering.py +3 -4
  42. phoenix/pointcloud/pointcloud.py +7 -5
  43. phoenix/pointcloud/umap_parameters.py +2 -1
  44. phoenix/server/api/dataloaders/annotation_summaries.py +12 -19
  45. phoenix/server/api/dataloaders/average_experiment_run_latency.py +2 -2
  46. phoenix/server/api/dataloaders/cache/two_tier_cache.py +3 -2
  47. phoenix/server/api/dataloaders/dataset_example_revisions.py +3 -8
  48. phoenix/server/api/dataloaders/dataset_example_spans.py +2 -5
  49. phoenix/server/api/dataloaders/document_evaluation_summaries.py +12 -18
  50. phoenix/server/api/dataloaders/document_evaluations.py +3 -7
  51. phoenix/server/api/dataloaders/document_retrieval_metrics.py +6 -13
  52. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +4 -8
  53. phoenix/server/api/dataloaders/experiment_error_rates.py +2 -5
  54. phoenix/server/api/dataloaders/experiment_run_annotations.py +3 -7
  55. phoenix/server/api/dataloaders/experiment_run_counts.py +1 -5
  56. phoenix/server/api/dataloaders/experiment_sequence_number.py +2 -5
  57. phoenix/server/api/dataloaders/latency_ms_quantile.py +21 -30
  58. phoenix/server/api/dataloaders/min_start_or_max_end_times.py +7 -13
  59. phoenix/server/api/dataloaders/project_by_name.py +3 -3
  60. phoenix/server/api/dataloaders/record_counts.py +11 -18
  61. phoenix/server/api/dataloaders/span_annotations.py +3 -7
  62. phoenix/server/api/dataloaders/span_dataset_examples.py +3 -8
  63. phoenix/server/api/dataloaders/span_descendants.py +3 -7
  64. phoenix/server/api/dataloaders/span_projects.py +2 -2
  65. phoenix/server/api/dataloaders/token_counts.py +12 -19
  66. phoenix/server/api/dataloaders/trace_row_ids.py +3 -7
  67. phoenix/server/api/dataloaders/user_roles.py +3 -3
  68. phoenix/server/api/dataloaders/users.py +3 -3
  69. phoenix/server/api/helpers/__init__.py +4 -3
  70. phoenix/server/api/helpers/dataset_helpers.py +10 -9
  71. phoenix/server/api/helpers/playground_clients.py +671 -0
  72. phoenix/server/api/helpers/playground_registry.py +70 -0
  73. phoenix/server/api/helpers/playground_spans.py +325 -0
  74. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +2 -2
  75. phoenix/server/api/input_types/AddSpansToDatasetInput.py +2 -2
  76. phoenix/server/api/input_types/ChatCompletionInput.py +38 -0
  77. phoenix/server/api/input_types/ChatCompletionMessageInput.py +13 -1
  78. phoenix/server/api/input_types/ClusterInput.py +2 -2
  79. phoenix/server/api/input_types/DeleteAnnotationsInput.py +1 -3
  80. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +2 -2
  81. phoenix/server/api/input_types/DeleteExperimentsInput.py +1 -3
  82. phoenix/server/api/input_types/DimensionFilter.py +4 -4
  83. phoenix/server/api/input_types/GenerativeModelInput.py +17 -0
  84. phoenix/server/api/input_types/Granularity.py +1 -1
  85. phoenix/server/api/input_types/InvocationParameters.py +156 -13
  86. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +2 -2
  87. phoenix/server/api/input_types/TemplateOptions.py +10 -0
  88. phoenix/server/api/mutations/__init__.py +4 -0
  89. phoenix/server/api/mutations/chat_mutations.py +374 -0
  90. phoenix/server/api/mutations/dataset_mutations.py +4 -4
  91. phoenix/server/api/mutations/experiment_mutations.py +1 -2
  92. phoenix/server/api/mutations/export_events_mutations.py +7 -7
  93. phoenix/server/api/mutations/span_annotations_mutations.py +4 -4
  94. phoenix/server/api/mutations/trace_annotations_mutations.py +4 -4
  95. phoenix/server/api/mutations/user_mutations.py +4 -4
  96. phoenix/server/api/openapi/schema.py +2 -2
  97. phoenix/server/api/queries.py +61 -72
  98. phoenix/server/api/routers/oauth2.py +4 -4
  99. phoenix/server/api/routers/v1/datasets.py +22 -36
  100. phoenix/server/api/routers/v1/evaluations.py +6 -5
  101. phoenix/server/api/routers/v1/experiment_evaluations.py +2 -2
  102. phoenix/server/api/routers/v1/experiment_runs.py +2 -2
  103. phoenix/server/api/routers/v1/experiments.py +4 -4
  104. phoenix/server/api/routers/v1/spans.py +13 -12
  105. phoenix/server/api/routers/v1/traces.py +5 -5
  106. phoenix/server/api/routers/v1/utils.py +5 -5
  107. phoenix/server/api/schema.py +42 -10
  108. phoenix/server/api/subscriptions.py +347 -494
  109. phoenix/server/api/types/AnnotationSummary.py +3 -3
  110. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +44 -0
  111. phoenix/server/api/types/Cluster.py +8 -7
  112. phoenix/server/api/types/Dataset.py +5 -4
  113. phoenix/server/api/types/Dimension.py +3 -3
  114. phoenix/server/api/types/DocumentEvaluationSummary.py +8 -7
  115. phoenix/server/api/types/EmbeddingDimension.py +6 -5
  116. phoenix/server/api/types/EvaluationSummary.py +3 -3
  117. phoenix/server/api/types/Event.py +7 -7
  118. phoenix/server/api/types/Experiment.py +3 -3
  119. phoenix/server/api/types/ExperimentComparison.py +2 -4
  120. phoenix/server/api/types/GenerativeProvider.py +27 -3
  121. phoenix/server/api/types/Inferences.py +9 -8
  122. phoenix/server/api/types/InferencesRole.py +2 -2
  123. phoenix/server/api/types/Model.py +2 -2
  124. phoenix/server/api/types/Project.py +11 -18
  125. phoenix/server/api/types/Segments.py +3 -3
  126. phoenix/server/api/types/Span.py +45 -7
  127. phoenix/server/api/types/TemplateLanguage.py +9 -0
  128. phoenix/server/api/types/TimeSeries.py +8 -7
  129. phoenix/server/api/types/Trace.py +2 -2
  130. phoenix/server/api/types/UMAPPoints.py +6 -6
  131. phoenix/server/api/types/User.py +3 -3
  132. phoenix/server/api/types/node.py +1 -3
  133. phoenix/server/api/types/pagination.py +4 -4
  134. phoenix/server/api/utils.py +2 -4
  135. phoenix/server/app.py +76 -37
  136. phoenix/server/bearer_auth.py +4 -10
  137. phoenix/server/dml_event.py +3 -3
  138. phoenix/server/dml_event_handler.py +10 -24
  139. phoenix/server/grpc_server.py +3 -2
  140. phoenix/server/jwt_store.py +22 -21
  141. phoenix/server/main.py +17 -4
  142. phoenix/server/oauth2.py +3 -2
  143. phoenix/server/rate_limiters.py +5 -8
  144. phoenix/server/static/.vite/manifest.json +31 -31
  145. phoenix/server/static/assets/components-Csu8UKOs.js +1612 -0
  146. phoenix/server/static/assets/{index-DCzakdJq.js → index-Bk5C9EA7.js} +2 -2
  147. phoenix/server/static/assets/{pages-CAL1FDMt.js → pages-UeWaKXNs.js} +337 -442
  148. phoenix/server/static/assets/{vendor-6IcPAw_j.js → vendor-CtqfhlbC.js} +6 -6
  149. phoenix/server/static/assets/{vendor-arizeai-DRZuoyuF.js → vendor-arizeai-C_3SBz56.js} +2 -2
  150. phoenix/server/static/assets/{vendor-codemirror-DVE2_WBr.js → vendor-codemirror-wfdk9cjp.js} +1 -1
  151. phoenix/server/static/assets/{vendor-recharts-DwrexFA4.js → vendor-recharts-BiVnSv90.js} +1 -1
  152. phoenix/server/templates/index.html +1 -0
  153. phoenix/server/thread_server.py +1 -1
  154. phoenix/server/types.py +17 -29
  155. phoenix/services.py +8 -3
  156. phoenix/session/client.py +12 -24
  157. phoenix/session/data_extractor.py +3 -3
  158. phoenix/session/evaluation.py +1 -2
  159. phoenix/session/session.py +26 -21
  160. phoenix/trace/attributes.py +16 -28
  161. phoenix/trace/dsl/filter.py +17 -21
  162. phoenix/trace/dsl/helpers.py +3 -3
  163. phoenix/trace/dsl/query.py +13 -22
  164. phoenix/trace/fixtures.py +11 -17
  165. phoenix/trace/otel.py +5 -15
  166. phoenix/trace/projects.py +3 -2
  167. phoenix/trace/schemas.py +2 -2
  168. phoenix/trace/span_evaluations.py +9 -8
  169. phoenix/trace/span_json_decoder.py +3 -3
  170. phoenix/trace/span_json_encoder.py +2 -2
  171. phoenix/trace/trace_dataset.py +6 -5
  172. phoenix/trace/utils.py +6 -6
  173. phoenix/utilities/deprecation.py +3 -2
  174. phoenix/utilities/error_handling.py +3 -2
  175. phoenix/utilities/json.py +2 -1
  176. phoenix/utilities/logging.py +2 -2
  177. phoenix/utilities/project.py +1 -1
  178. phoenix/utilities/re.py +3 -4
  179. phoenix/utilities/template_formatters.py +16 -5
  180. phoenix/version.py +1 -1
  181. arize_phoenix-5.5.2.dist-info/RECORD +0 -321
  182. phoenix/server/static/assets/components-hX0LgYz3.js +0 -1428
  183. {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.7.0.dist-info}/WHEEL +0 -0
  184. {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.7.0.dist-info}/entry_points.txt +0 -0
  185. {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.7.0.dist-info}/licenses/IP_NOTICE +0 -0
  186. {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.7.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/input_types/InvocationParameters.py +156 -13
@@ -1,20 +1,163 @@
1
- from typing import List, Optional
1
+ from enum import Enum
2
+ from typing import Annotated, Any, Mapping, Optional, Union
2
3
 
3
4
  import strawberry
4
5
  from strawberry import UNSET
5
6
  from strawberry.scalars import JSON
6
7
 
7
8
 
9
+ @strawberry.enum
10
+ class CanonicalParameterName(str, Enum):
11
+ TEMPERATURE = "temperature"
12
+ MAX_COMPLETION_TOKENS = "max_completion_tokens"
13
+ STOP_SEQUENCES = "stop_sequences"
14
+ TOP_P = "top_p"
15
+ RANDOM_SEED = "random_seed"
16
+ TOOL_CHOICE = "tool_choice"
17
+ RESPONSE_FORMAT = "response_format"
18
+
19
+
20
+ @strawberry.enum
21
+ class InvocationInputField(str, Enum):
22
+ value_int = "value_int"
23
+ value_float = "value_float"
24
+ value_bool = "value_bool"
25
+ value_string = "value_string"
26
+ value_json = "value_json"
27
+ value_string_list = "value_string_list"
28
+ value_boolean = "value_boolean"
29
+
30
+
8
31
  @strawberry.input
9
- class InvocationParameters:
10
- """
11
- Invocation parameters interface shared between different providers.
12
- """
13
-
14
- temperature: Optional[float] = UNSET
15
- max_completion_tokens: Optional[int] = UNSET
16
- max_tokens: Optional[int] = UNSET
17
- top_p: Optional[float] = UNSET
18
- stop: Optional[List[str]] = UNSET
19
- seed: Optional[int] = UNSET
20
- tool_choice: Optional[JSON] = UNSET
32
+ class InvocationParameterInput:
33
+ invocation_name: str
34
+ canonical_name: Optional[CanonicalParameterName] = None
35
+ value_int: Optional[int] = UNSET
36
+ value_float: Optional[float] = UNSET
37
+ value_bool: Optional[bool] = UNSET
38
+ value_string: Optional[str] = UNSET
39
+ value_json: Optional[JSON] = UNSET
40
+ value_string_list: Optional[list[str]] = UNSET
41
+ value_boolean: Optional[bool] = UNSET
42
+
43
+
44
+ @strawberry.interface
45
+ class InvocationParameterBase:
46
+ invocation_name: str
47
+ canonical_name: Optional[CanonicalParameterName] = None
48
+ label: str
49
+ required: bool = False
50
+ hidden: bool = False
51
+
52
+
53
+ @strawberry.type
54
+ class IntInvocationParameter(InvocationParameterBase):
55
+ invocation_input_field: InvocationInputField = InvocationInputField.value_int
56
+ default_value: Optional[int] = UNSET
57
+
58
+
59
+ @strawberry.type
60
+ class FloatInvocationParameter(InvocationParameterBase):
61
+ invocation_input_field: InvocationInputField = InvocationInputField.value_float
62
+ default_value: Optional[float] = UNSET
63
+
64
+
65
+ @strawberry.type
66
+ class BoundedFloatInvocationParameter(InvocationParameterBase):
67
+ invocation_input_field: InvocationInputField = InvocationInputField.value_float
68
+ default_value: Optional[float] = UNSET
69
+ min_value: float
70
+ max_value: float
71
+
72
+
73
+ @strawberry.type
74
+ class StringInvocationParameter(InvocationParameterBase):
75
+ invocation_input_field: InvocationInputField = InvocationInputField.value_string
76
+ default_value: Optional[str] = UNSET
77
+
78
+
79
+ @strawberry.type
80
+ class JSONInvocationParameter(InvocationParameterBase):
81
+ invocation_input_field: InvocationInputField = InvocationInputField.value_json
82
+ default_value: Optional[JSON] = UNSET
83
+
84
+
85
+ @strawberry.type
86
+ class StringListInvocationParameter(InvocationParameterBase):
87
+ invocation_input_field: InvocationInputField = InvocationInputField.value_string_list
88
+ default_value: Optional[list[str]] = UNSET
89
+
90
+
91
+ @strawberry.type
92
+ class BooleanInvocationParameter(InvocationParameterBase):
93
+ invocation_input_field: InvocationInputField = InvocationInputField.value_bool
94
+ default_value: Optional[bool] = UNSET
95
+
96
+
97
+ def extract_parameter(
98
+ param_def: InvocationParameterBase, param_input: InvocationParameterInput
99
+ ) -> Any:
100
+ if isinstance(param_def, IntInvocationParameter):
101
+ return (
102
+ param_input.value_int if param_input.value_int is not UNSET else param_def.default_value
103
+ )
104
+ elif isinstance(param_def, FloatInvocationParameter):
105
+ return (
106
+ param_input.value_float
107
+ if param_input.value_float is not UNSET
108
+ else param_def.default_value
109
+ )
110
+ elif isinstance(param_def, BoundedFloatInvocationParameter):
111
+ return (
112
+ param_input.value_float
113
+ if param_input.value_float is not UNSET
114
+ else param_def.default_value
115
+ )
116
+ elif isinstance(param_def, StringInvocationParameter):
117
+ return (
118
+ param_input.value_string
119
+ if param_input.value_string is not UNSET
120
+ else param_def.default_value
121
+ )
122
+ elif isinstance(param_def, JSONInvocationParameter):
123
+ return (
124
+ param_input.value_json
125
+ if param_input.value_json is not UNSET
126
+ else param_def.default_value
127
+ )
128
+ elif isinstance(param_def, StringListInvocationParameter):
129
+ return (
130
+ param_input.value_string_list
131
+ if param_input.value_string_list is not UNSET
132
+ else param_def.default_value
133
+ )
134
+ elif isinstance(param_def, BooleanInvocationParameter):
135
+ return (
136
+ param_input.value_bool
137
+ if param_input.value_bool is not UNSET
138
+ else param_def.default_value
139
+ )
140
+
141
+
142
+ def validate_invocation_parameters(
143
+ parameters: list["InvocationParameter"],
144
+ input: Mapping[str, Any],
145
+ ) -> None:
146
+ for param_def in parameters:
147
+ if param_def.required and param_def.invocation_name not in input:
148
+ raise ValueError(f"Required parameter {param_def.invocation_name} not provided")
149
+
150
+
151
+ # Create the union for output types
152
+ InvocationParameter = Annotated[
153
+ Union[
154
+ IntInvocationParameter,
155
+ FloatInvocationParameter,
156
+ BoundedFloatInvocationParameter,
157
+ StringInvocationParameter,
158
+ JSONInvocationParameter,
159
+ StringListInvocationParameter,
160
+ BooleanInvocationParameter,
161
+ ],
162
+ strawberry.union("InvocationParameter"),
163
+ ]
phoenix/server/api/input_types/PatchDatasetExamplesInput.py +2 -2
@@ -1,4 +1,4 @@
1
- from typing import List, Optional
1
+ from typing import Optional
2
2
 
3
3
  import strawberry
4
4
  from strawberry import UNSET
@@ -30,6 +30,6 @@ class PatchDatasetExamplesInput:
30
30
  Input type to the patchDatasetExamples mutation.
31
31
  """
32
32
 
33
- patches: List[DatasetExamplePatch]
33
+ patches: list[DatasetExamplePatch]
34
34
  version_description: Optional[str] = UNSET
35
35
  version_metadata: Optional[JSON] = UNSET
phoenix/server/api/input_types/TemplateOptions.py +10 -0
@@ -0,0 +1,10 @@
1
+ import strawberry
2
+ from strawberry.scalars import JSON
3
+
4
+ from phoenix.server.api.types.TemplateLanguage import TemplateLanguage
5
+
6
+
7
+ @strawberry.input
8
+ class TemplateOptions:
9
+ variables: JSON
10
+ language: TemplateLanguage
phoenix/server/api/mutations/__init__.py +4 -0
@@ -1,6 +1,9 @@
1
1
  import strawberry
2
2
 
3
3
  from phoenix.server.api.mutations.api_key_mutations import ApiKeyMutationMixin
4
+ from phoenix.server.api.mutations.chat_mutations import (
5
+ ChatCompletionMutationMixin,
6
+ )
4
7
  from phoenix.server.api.mutations.dataset_mutations import DatasetMutationMixin
5
8
  from phoenix.server.api.mutations.experiment_mutations import ExperimentMutationMixin
6
9
  from phoenix.server.api.mutations.export_events_mutations import ExportEventsMutationMixin
@@ -20,5 +23,6 @@ class Mutation(
20
23
  SpanAnnotationMutationMixin,
21
24
  TraceAnnotationMutationMixin,
22
25
  UserMutationMixin,
26
+ ChatCompletionMutationMixin,
23
27
  ):
24
28
  pass
phoenix/server/api/mutations/chat_mutations.py +374 -0
@@ -0,0 +1,374 @@
1
+ import json
2
+ from dataclasses import asdict
3
+ from datetime import datetime, timezone
4
+ from itertools import chain
5
+ from traceback import format_exc
6
+ from typing import Any, Iterable, Iterator, List, Optional
7
+
8
+ import strawberry
9
+ from openinference.semconv.trace import (
10
+ MessageAttributes,
11
+ OpenInferenceMimeTypeValues,
12
+ OpenInferenceSpanKindValues,
13
+ SpanAttributes,
14
+ ToolAttributes,
15
+ ToolCallAttributes,
16
+ )
17
+ from opentelemetry.sdk.trace.id_generator import RandomIdGenerator as DefaultOTelIDGenerator
18
+ from opentelemetry.trace import StatusCode
19
+ from sqlalchemy import insert, select
20
+ from strawberry.types import Info
21
+ from typing_extensions import assert_never
22
+
23
+ from phoenix.datetime_utils import local_now, normalize_datetime
24
+ from phoenix.db import models
25
+ from phoenix.server.api.context import Context
26
+ from phoenix.server.api.exceptions import BadRequest
27
+ from phoenix.server.api.helpers.playground_clients import initialize_playground_clients
28
+ from phoenix.server.api.helpers.playground_registry import PLAYGROUND_CLIENT_REGISTRY
29
+ from phoenix.server.api.input_types.ChatCompletionInput import ChatCompletionInput
30
+ from phoenix.server.api.input_types.TemplateOptions import TemplateOptions
31
+ from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
32
+ from phoenix.server.api.types.ChatCompletionSubscriptionPayload import (
33
+ TextChunk,
34
+ ToolCallChunk,
35
+ )
36
+ from phoenix.server.api.types.Span import Span, to_gql_span
37
+ from phoenix.server.api.types.TemplateLanguage import TemplateLanguage
38
+ from phoenix.server.dml_event import SpanInsertEvent
39
+ from phoenix.trace.attributes import unflatten
40
+ from phoenix.trace.schemas import SpanException
41
+ from phoenix.utilities.template_formatters import (
42
+ FStringTemplateFormatter,
43
+ MustacheTemplateFormatter,
44
+ TemplateFormatter,
45
+ )
46
+
47
+ initialize_playground_clients()
48
+
49
+ ChatCompletionMessage = tuple[ChatCompletionMessageRole, str, Optional[str], Optional[List[Any]]]
50
+
51
+
52
+ @strawberry.type
53
+ class ChatCompletionFunctionCall:
54
+ name: str
55
+ arguments: str
56
+
57
+
58
+ @strawberry.type
59
+ class ChatCompletionToolCall:
60
+ id: str
61
+ function: ChatCompletionFunctionCall
62
+
63
+
64
+ @strawberry.type
65
+ class ChatCompletionMutationPayload:
66
+ content: Optional[str]
67
+ tool_calls: List[ChatCompletionToolCall]
68
+ span: Span
69
+ error_message: Optional[str]
70
+
71
+
72
+ @strawberry.type
73
+ class ChatCompletionMutationMixin:
74
+ @strawberry.mutation
75
+ async def chat_completion(
76
+ self, info: Info[Context, None], input: ChatCompletionInput
77
+ ) -> ChatCompletionMutationPayload:
78
+ provider_key = input.model.provider_key
79
+ llm_client_class = PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, input.model.name)
80
+ if llm_client_class is None:
81
+ raise BadRequest(f"No LLM client registered for provider '{provider_key}'")
82
+ attributes: dict[str, Any] = {}
83
+ llm_client = llm_client_class(
84
+ model=input.model,
85
+ api_key=input.api_key,
86
+ )
87
+
88
+ messages = [
89
+ (
90
+ message.role,
91
+ message.content,
92
+ message.tool_call_id if isinstance(message.tool_call_id, str) else None,
93
+ message.tool_calls if isinstance(message.tool_calls, list) else None,
94
+ )
95
+ for message in input.messages
96
+ ]
97
+
98
+ if template_options := input.template:
99
+ messages = list(_formatted_messages(messages, template_options))
100
+
101
+ invocation_parameters = llm_client.construct_invocation_parameters(
102
+ input.invocation_parameters
103
+ )
104
+
105
+ text_content = ""
106
+ tool_calls = []
107
+ events = []
108
+ attributes.update(
109
+ chain(
110
+ _llm_span_kind(),
111
+ _llm_model_name(input.model.name),
112
+ _llm_tools(input.tools or []),
113
+ _llm_input_messages(messages),
114
+ _llm_invocation_parameters(invocation_parameters),
115
+ _input_value_and_mime_type(input),
116
+ **llm_client.attributes,
117
+ )
118
+ )
119
+
120
+ start_time = normalize_datetime(dt=local_now(), tz=timezone.utc)
121
+ status_code = StatusCode.OK
122
+ status_message = ""
123
+ try:
124
+ async for chunk in llm_client.chat_completion_create(
125
+ messages=messages, tools=input.tools or [], **invocation_parameters
126
+ ):
127
+ # Process the chunk
128
+ if isinstance(chunk, TextChunk):
129
+ text_content += chunk.content
130
+ elif isinstance(chunk, ToolCallChunk):
131
+ tool_call = ChatCompletionToolCall(
132
+ id=chunk.id,
133
+ function=ChatCompletionFunctionCall(
134
+ name=chunk.function.name,
135
+ arguments=chunk.function.arguments,
136
+ ),
137
+ )
138
+ tool_calls.append(tool_call)
139
+ else:
140
+ assert_never(chunk)
141
+ except Exception as e:
142
+ # Handle exceptions and record exception event
143
+ status_code = StatusCode.ERROR
144
+ status_message = str(e)
145
+ end_time = normalize_datetime(dt=local_now(), tz=timezone.utc)
146
+ assert end_time is not None
147
+ events.append(
148
+ SpanException(
149
+ timestamp=end_time,
150
+ message=status_message,
151
+ exception_type=type(e).__name__,
152
+ exception_escaped=False,
153
+ exception_stacktrace=format_exc(),
154
+ )
155
+ )
156
+ else:
157
+ end_time = normalize_datetime(dt=local_now(), tz=timezone.utc)
158
+
159
+ if text_content or tool_calls:
160
+ attributes.update(
161
+ chain(
162
+ _output_value_and_mime_type({"text": text_content, "tool_calls": tool_calls}),
163
+ _llm_output_messages(text_content, tool_calls),
164
+ )
165
+ )
166
+
167
+ # Now write the span to the database
168
+ trace_id = _generate_trace_id()
169
+ span_id = _generate_span_id()
170
+ async with info.context.db() as session:
171
+ # Get or create the project ID
172
+ if (
173
+ project_id := await session.scalar(
174
+ select(models.Project.id).where(models.Project.name == PLAYGROUND_PROJECT_NAME)
175
+ )
176
+ ) is None:
177
+ project_id = await session.scalar(
178
+ insert(models.Project)
179
+ .returning(models.Project.id)
180
+ .values(
181
+ name=PLAYGROUND_PROJECT_NAME,
182
+ description="Traces from prompt playground",
183
+ )
184
+ )
185
+ trace = models.Trace(
186
+ project_rowid=project_id,
187
+ trace_id=trace_id,
188
+ start_time=start_time,
189
+ end_time=end_time,
190
+ )
191
+ span = models.Span(
192
+ trace_rowid=trace.id,
193
+ span_id=span_id,
194
+ parent_id=None,
195
+ name="ChatCompletion",
196
+ span_kind=LLM,
197
+ start_time=start_time,
198
+ end_time=end_time,
199
+ attributes=unflatten(attributes.items()),
200
+ events=[_serialize_event(event) for event in events],
201
+ status_code=status_code.name,
202
+ status_message=status_message,
203
+ cumulative_error_count=int(status_code is StatusCode.ERROR),
204
+ cumulative_llm_token_count_prompt=attributes.get(LLM_TOKEN_COUNT_PROMPT, 0),
205
+ cumulative_llm_token_count_completion=attributes.get(LLM_TOKEN_COUNT_COMPLETION, 0),
206
+ llm_token_count_prompt=attributes.get(LLM_TOKEN_COUNT_PROMPT, 0),
207
+ llm_token_count_completion=attributes.get(LLM_TOKEN_COUNT_COMPLETION, 0),
208
+ trace=trace,
209
+ )
210
+ session.add(trace)
211
+ session.add(span)
212
+ await session.flush()
213
+
214
+ gql_span = to_gql_span(span)
215
+
216
+ info.context.event_queue.put(SpanInsertEvent(ids=(project_id,)))
217
+
218
+ if status_code is StatusCode.ERROR:
219
+ return ChatCompletionMutationPayload(
220
+ content=None,
221
+ tool_calls=[],
222
+ span=gql_span,
223
+ error_message=status_message,
224
+ )
225
+ else:
226
+ return ChatCompletionMutationPayload(
227
+ content=text_content if text_content else None,
228
+ tool_calls=tool_calls,
229
+ span=gql_span,
230
+ error_message=None,
231
+ )
232
+
233
+
234
+ def _formatted_messages(
235
+ messages: Iterable[ChatCompletionMessage],
236
+ template_options: TemplateOptions,
237
+ ) -> Iterator[ChatCompletionMessage]:
238
+ """
239
+ Formats the messages using the given template options.
240
+ """
241
+ template_formatter = _template_formatter(template_language=template_options.language)
242
+ (
243
+ roles,
244
+ templates,
245
+ tool_call_id,
246
+ tool_calls,
247
+ ) = zip(*messages)
248
+ formatted_templates = map(
249
+ lambda template: template_formatter.format(template, **template_options.variables),
250
+ templates,
251
+ )
252
+ formatted_messages = zip(roles, formatted_templates, tool_call_id, tool_calls)
253
+ return formatted_messages
254
+
255
+
256
+ def _template_formatter(template_language: TemplateLanguage) -> TemplateFormatter:
257
+ """
258
+ Instantiates the appropriate template formatter for the template language.
259
+ """
260
+ if template_language is TemplateLanguage.MUSTACHE:
261
+ return MustacheTemplateFormatter()
262
+ if template_language is TemplateLanguage.F_STRING:
263
+ return FStringTemplateFormatter()
264
+ assert_never(template_language)
265
+
266
+
267
+ def _llm_span_kind() -> Iterator[tuple[str, Any]]:
268
+ yield OPENINFERENCE_SPAN_KIND, LLM
269
+
270
+
271
+ def _llm_model_name(model_name: str) -> Iterator[tuple[str, Any]]:
272
+ yield LLM_MODEL_NAME, model_name
273
+
274
+
275
+ def _llm_invocation_parameters(invocation_parameters: dict[str, Any]) -> Iterator[tuple[str, Any]]:
276
+ yield LLM_INVOCATION_PARAMETERS, json.dumps(invocation_parameters)
277
+
278
+
279
+ def _llm_tools(tools: List[Any]) -> Iterator[tuple[str, Any]]:
280
+ for tool_index, tool in enumerate(tools):
281
+ yield f"{LLM_TOOLS}.{tool_index}.{TOOL_JSON_SCHEMA}", json.dumps(tool)
282
+
283
+
284
+ def _input_value_and_mime_type(input: ChatCompletionInput) -> Iterator[tuple[str, Any]]:
285
+ input_data = input.__dict__.copy()
286
+ input_data.pop("api_key", None)
287
+ yield INPUT_MIME_TYPE, JSON
288
+ yield INPUT_VALUE, json.dumps(input_data)
289
+
290
+
291
+ def _output_value_and_mime_type(output: Any) -> Iterator[tuple[str, Any]]:
292
+ yield OUTPUT_MIME_TYPE, JSON
293
+ yield OUTPUT_VALUE, json.dumps(output)
294
+
295
+
296
+ def _llm_input_messages(
297
+ messages: Iterable[ChatCompletionMessage],
298
+ ) -> Iterator[tuple[str, Any]]:
299
+ for i, (role, content, _tool_call_id, tool_calls) in enumerate(messages):
300
+ yield f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_ROLE}", role.value.lower()
301
+ yield f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_CONTENT}", content
302
+ if tool_calls:
303
+ for tool_call_index, tool_call in enumerate(tool_calls):
304
+ yield (
305
+ f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_NAME}",
306
+ tool_call["function"]["name"],
307
+ )
308
+ if arguments := tool_call["function"]["arguments"]:
309
+ yield (
310
+ f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_ARGUMENTS_JSON}",
311
+ json.dumps(arguments),
312
+ )
313
+
314
+
315
+ def _llm_output_messages(
316
+ text_content: str, tool_calls: List[ChatCompletionToolCall]
317
+ ) -> Iterator[tuple[str, Any]]:
318
+ yield f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_ROLE}", "assistant"
319
+ if text_content:
320
+ yield f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_CONTENT}", text_content
321
+ for tool_call_index, tool_call in enumerate(tool_calls):
322
+ yield (
323
+ f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_NAME}",
324
+ tool_call.function.name,
325
+ )
326
+ if arguments := tool_call.function.arguments:
327
+ yield (
328
+ f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_ARGUMENTS_JSON}",
329
+ json.dumps(arguments),
330
+ )
331
+
332
+
333
+ def _generate_trace_id() -> str:
334
+ return _hex(DefaultOTelIDGenerator().generate_trace_id())
335
+
336
+
337
+ def _generate_span_id() -> str:
338
+ return _hex(DefaultOTelIDGenerator().generate_span_id())
339
+
340
+
341
+ def _hex(number: int) -> str:
342
+ return hex(number)[2:]
343
+
344
+
345
+ def _serialize_event(event: SpanException) -> dict[str, Any]:
346
+ return {k: (v.isoformat() if isinstance(v, datetime) else v) for k, v in asdict(event).items()}
347
+
348
+
349
+ JSON = OpenInferenceMimeTypeValues.JSON.value
350
+ LLM = OpenInferenceSpanKindValues.LLM.value
351
+
352
+ OPENINFERENCE_SPAN_KIND = SpanAttributes.OPENINFERENCE_SPAN_KIND
353
+ INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE
354
+ INPUT_VALUE = SpanAttributes.INPUT_VALUE
355
+ OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE
356
+ OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE
357
+ LLM_INPUT_MESSAGES = SpanAttributes.LLM_INPUT_MESSAGES
358
+ LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
359
+ LLM_MODEL_NAME = SpanAttributes.LLM_MODEL_NAME
360
+ LLM_INVOCATION_PARAMETERS = SpanAttributes.LLM_INVOCATION_PARAMETERS
361
+ LLM_TOOLS = SpanAttributes.LLM_TOOLS
362
+ LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
363
+ LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
364
+
365
+ MESSAGE_CONTENT = MessageAttributes.MESSAGE_CONTENT
366
+ MESSAGE_ROLE = MessageAttributes.MESSAGE_ROLE
367
+ MESSAGE_TOOL_CALLS = MessageAttributes.MESSAGE_TOOL_CALLS
368
+
369
+ TOOL_CALL_FUNCTION_NAME = ToolCallAttributes.TOOL_CALL_FUNCTION_NAME
370
+ TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON
371
+
372
+ TOOL_JSON_SCHEMA = ToolAttributes.TOOL_JSON_SCHEMA
373
+
374
+ PLAYGROUND_PROJECT_NAME = "playground"
phoenix/server/api/mutations/dataset_mutations.py +4 -4
@@ -1,6 +1,6 @@
1
1
  import asyncio
2
2
  from datetime import datetime
3
- from typing import Any, Dict
3
+ from typing import Any
4
4
 
5
5
  import strawberry
6
6
  from openinference.semconv.trace import (
@@ -175,7 +175,7 @@ class DatasetMutationMixin:
175
175
  )
176
176
  ).all()
177
177
 
178
- span_annotations_by_span: Dict[int, Dict[Any, Any]] = {span.id: {} for span in spans}
178
+ span_annotations_by_span: dict[int, dict[Any, Any]] = {span.id: {} for span in spans}
179
179
  for annotation in span_annotations:
180
180
  span_id = annotation.span_rowid
181
181
  if span_id not in span_annotations_by_span:
@@ -287,7 +287,7 @@ class DatasetMutationMixin:
287
287
  )
288
288
  ).all()
289
289
 
290
- span_annotations_by_span: Dict[int, Dict[Any, Any]] = {span.id: {} for span in spans}
290
+ span_annotations_by_span: dict[int, dict[Any, Any]] = {span.id: {} for span in spans}
291
291
  for annotation in span_annotations:
292
292
  span_id = annotation.span_rowid
293
293
  if span_id not in span_annotations_by_span:
@@ -577,7 +577,7 @@ def _to_orm_revision(
577
577
  patch: DatasetExamplePatch,
578
578
  example_id: int,
579
579
  version_id: int,
580
- ) -> Dict[str, Any]:
580
+ ) -> dict[str, Any]:
581
581
  """
582
582
  Creates a new revision from an existing revision and a patch. The output is a
583
583
  dictionary suitable for insertion into the database using the sqlalchemy
phoenix/server/api/mutations/experiment_mutations.py +1 -2
@@ -1,5 +1,4 @@
1
1
  import asyncio
2
- from typing import List
3
2
 
4
3
  import strawberry
5
4
  from sqlalchemy import delete
@@ -20,7 +19,7 @@ from phoenix.server.dml_event import ExperimentDeleteEvent
20
19
 
21
20
  @strawberry.type
22
21
  class ExperimentMutationPayload:
23
- experiments: List[Experiment]
22
+ experiments: list[Experiment]
24
23
 
25
24
 
26
25
  @strawberry.type
phoenix/server/api/mutations/export_events_mutations.py +7 -7
@@ -1,7 +1,7 @@
1
1
  import asyncio
2
2
  from collections import defaultdict
3
3
  from datetime import datetime
4
- from typing import Dict, List, Optional, Tuple
4
+ from typing import Optional
5
5
 
6
6
  import strawberry
7
7
  from strawberry import ID, UNSET
@@ -29,7 +29,7 @@ class ExportEventsMutationMixin:
29
29
  async def export_events(
30
30
  self,
31
31
  info: Info[Context, None],
32
- event_ids: List[ID],
32
+ event_ids: list[ID],
33
33
  file_name: Optional[str] = UNSET,
34
34
  ) -> ExportedFile:
35
35
  if not isinstance(file_name, str):
@@ -61,7 +61,7 @@ class ExportEventsMutationMixin:
61
61
  async def export_clusters(
62
62
  self,
63
63
  info: Info[Context, None],
64
- clusters: List[ClusterInput],
64
+ clusters: list[ClusterInput],
65
65
  file_name: Optional[str] = UNSET,
66
66
  ) -> ExportedFile:
67
67
  if not isinstance(file_name, str):
@@ -81,10 +81,10 @@ class ExportEventsMutationMixin:
81
81
 
82
82
 
83
83
  def _unpack_clusters(
84
- clusters: List[ClusterInput],
85
- ) -> Tuple[Dict[ms.InferencesRole, List[int]], Dict[ms.InferencesRole, Dict[int, str]]]:
86
- row_numbers: Dict[ms.InferencesRole, List[int]] = defaultdict(list)
87
- cluster_ids: Dict[ms.InferencesRole, Dict[int, str]] = defaultdict(dict)
84
+ clusters: list[ClusterInput],
85
+ ) -> tuple[dict[ms.InferencesRole, list[int]], dict[ms.InferencesRole, dict[int, str]]]:
86
+ row_numbers: dict[ms.InferencesRole, list[int]] = defaultdict(list)
87
+ cluster_ids: dict[ms.InferencesRole, dict[int, str]] = defaultdict(dict)
88
88
  for i, cluster in enumerate(clusters):
89
89
  for row_number, inferences_role in map(unpack_event_id, cluster.event_ids):
90
90
  if isinstance(inferences_role, AncillaryInferencesRole):