arize-phoenix 11.23.1__py3-none-any.whl → 12.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +61 -36
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/RECORD +212 -162
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
- phoenix/__generated__/__init__.py +0 -0
- phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
- phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
- phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
- phoenix/__init__.py +2 -1
- phoenix/auth.py +27 -2
- phoenix/config.py +1594 -81
- phoenix/db/README.md +546 -28
- phoenix/db/bulk_inserter.py +119 -116
- phoenix/db/engines.py +140 -33
- phoenix/db/facilitator.py +22 -1
- phoenix/db/helpers.py +818 -65
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/insertion/dataset.py +133 -1
- phoenix/db/insertion/document_annotation.py +9 -6
- phoenix/db/insertion/evaluation.py +2 -3
- phoenix/db/insertion/helpers.py +2 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/span_annotation.py +3 -4
- phoenix/db/insertion/trace_annotation.py +3 -4
- phoenix/db/insertion/types.py +41 -18
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +364 -56
- phoenix/db/pg_config.py +10 -0
- phoenix/db/types/trace_retention.py +7 -6
- phoenix/experiments/functions.py +69 -19
- phoenix/inferences/inferences.py +1 -2
- phoenix/server/api/auth.py +9 -0
- phoenix/server/api/auth_messages.py +46 -0
- phoenix/server/api/context.py +60 -0
- phoenix/server/api/dataloaders/__init__.py +36 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
- phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
- phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
- phoenix/server/api/dataloaders/document_evaluations.py +6 -9
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
- phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
- phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
- phoenix/server/api/dataloaders/record_counts.py +37 -10
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project.py +28 -14
- phoenix/server/api/dataloaders/span_costs.py +3 -9
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/exceptions.py +5 -1
- phoenix/server/api/helpers/playground_clients.py +263 -83
- phoenix/server/api/helpers/playground_spans.py +2 -1
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
- phoenix/server/api/helpers/prompts/models.py +61 -19
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/ChatCompletionInput.py +3 -0
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/DatasetFilter.py +5 -2
- phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +3 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +158 -1
- phoenix/server/api/input_types/PromptVersionInput.py +47 -1
- phoenix/server/api/input_types/SpanSort.py +3 -2
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/input_types/UserRoleInput.py +1 -0
- phoenix/server/api/mutations/__init__.py +8 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
- phoenix/server/api/mutations/api_key_mutations.py +15 -20
- phoenix/server/api/mutations/chat_mutations.py +106 -37
- phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
- phoenix/server/api/mutations/dataset_mutations.py +21 -16
- phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/model_mutations.py +11 -9
- phoenix/server/api/mutations/project_mutations.py +4 -4
- phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
- phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
- phoenix/server/api/mutations/prompt_mutations.py +65 -129
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
- phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
- phoenix/server/api/mutations/trace_annotations_mutations.py +13 -8
- phoenix/server/api/mutations/trace_mutations.py +3 -3
- phoenix/server/api/mutations/user_mutations.py +55 -26
- phoenix/server/api/queries.py +501 -617
- phoenix/server/api/routers/__init__.py +2 -2
- phoenix/server/api/routers/auth.py +141 -87
- phoenix/server/api/routers/ldap.py +229 -0
- phoenix/server/api/routers/oauth2.py +349 -101
- phoenix/server/api/routers/v1/__init__.py +22 -4
- phoenix/server/api/routers/v1/annotation_configs.py +19 -30
- phoenix/server/api/routers/v1/annotations.py +455 -13
- phoenix/server/api/routers/v1/datasets.py +355 -68
- phoenix/server/api/routers/v1/documents.py +142 -0
- phoenix/server/api/routers/v1/evaluations.py +20 -28
- phoenix/server/api/routers/v1/experiment_evaluations.py +16 -6
- phoenix/server/api/routers/v1/experiment_runs.py +335 -59
- phoenix/server/api/routers/v1/experiments.py +475 -47
- phoenix/server/api/routers/v1/projects.py +16 -50
- phoenix/server/api/routers/v1/prompts.py +50 -39
- phoenix/server/api/routers/v1/sessions.py +108 -0
- phoenix/server/api/routers/v1/spans.py +156 -96
- phoenix/server/api/routers/v1/traces.py +51 -77
- phoenix/server/api/routers/v1/users.py +64 -24
- phoenix/server/api/routers/v1/utils.py +3 -7
- phoenix/server/api/subscriptions.py +257 -93
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/AuthMethod.py +1 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
- phoenix/server/api/types/Dataset.py +199 -72
- phoenix/server/api/types/DatasetExample.py +88 -18
- phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
- phoenix/server/api/types/DatasetLabel.py +57 -0
- phoenix/server/api/types/DatasetSplit.py +98 -0
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +212 -0
- phoenix/server/api/types/Experiment.py +215 -68
- phoenix/server/api/types/ExperimentComparison.py +3 -9
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
- phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
- phoenix/server/api/types/ExperimentRun.py +120 -70
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +95 -42
- phoenix/server/api/types/GenerativeProvider.py +1 -1
- phoenix/server/api/types/ModelInterface.py +7 -2
- phoenix/server/api/types/PlaygroundModel.py +12 -2
- phoenix/server/api/types/Project.py +218 -185
- phoenix/server/api/types/ProjectSession.py +146 -29
- phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +119 -39
- phoenix/server/api/types/PromptLabel.py +42 -25
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/Span.py +130 -123
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/Trace.py +184 -53
- phoenix/server/api/types/TraceAnnotation.py +149 -50
- phoenix/server/api/types/User.py +128 -33
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/api/types/node.py +10 -0
- phoenix/server/api/types/pagination.py +11 -2
- phoenix/server/app.py +154 -36
- phoenix/server/authorization.py +5 -4
- phoenix/server/bearer_auth.py +13 -5
- phoenix/server/cost_tracking/cost_model_lookup.py +42 -14
- phoenix/server/cost_tracking/model_cost_manifest.json +1085 -194
- phoenix/server/daemons/generative_model_store.py +61 -9
- phoenix/server/daemons/span_cost_calculator.py +10 -8
- phoenix/server/dml_event.py +13 -0
- phoenix/server/email/sender.py +29 -2
- phoenix/server/grpc_server.py +9 -9
- phoenix/server/jwt_store.py +8 -6
- phoenix/server/ldap.py +1449 -0
- phoenix/server/main.py +9 -3
- phoenix/server/oauth2.py +330 -12
- phoenix/server/prometheus.py +43 -6
- phoenix/server/rate_limiters.py +4 -9
- phoenix/server/retention.py +33 -20
- phoenix/server/session_filters.py +49 -0
- phoenix/server/static/.vite/manifest.json +51 -53
- phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
- phoenix/server/static/assets/{index-BPCwGQr8.js → index-CTQoemZv.js} +42 -35
- phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
- phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
- phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
- phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
- phoenix/server/static/assets/{vendor-recharts-Bw30oz1A.js → vendor-recharts-V9cwpXsm.js} +7 -7
- phoenix/server/static/assets/{vendor-shiki-DZajAPeq.js → vendor-shiki-Do--csgv.js} +1 -1
- phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
- phoenix/server/templates/index.html +7 -1
- phoenix/server/thread_server.py +1 -2
- phoenix/server/utils.py +74 -0
- phoenix/session/client.py +55 -1
- phoenix/session/data_extractor.py +5 -0
- phoenix/session/evaluation.py +8 -4
- phoenix/session/session.py +44 -8
- phoenix/settings.py +2 -0
- phoenix/trace/attributes.py +80 -13
- phoenix/trace/dsl/query.py +2 -0
- phoenix/trace/projects.py +5 -0
- phoenix/utilities/template_formatters.py +1 -1
- phoenix/version.py +1 -1
- phoenix/server/api/types/Evaluation.py +0 -39
- phoenix/server/static/assets/components-D0DWAf0l.js +0 -5650
- phoenix/server/static/assets/pages-Creyamao.js +0 -8612
- phoenix/server/static/assets/vendor-CU36oj8y.js +0 -905
- phoenix/server/static/assets/vendor-CqDb5u4o.css +0 -1
- phoenix/server/static/assets/vendor-arizeai-Ctgw0e1G.js +0 -168
- phoenix/server/static/assets/vendor-codemirror-Cojjzqb9.js +0 -25
- phoenix/server/static/assets/vendor-three-BLWp5bic.js +0 -2998
- phoenix/utilities/deprecation.py +0 -31
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,10 +1,12 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
import logging
|
|
3
|
+
from collections import deque
|
|
3
4
|
from collections.abc import AsyncIterator, Iterator
|
|
4
5
|
from datetime import datetime, timedelta, timezone
|
|
5
6
|
from typing import (
|
|
6
7
|
Any,
|
|
7
8
|
AsyncGenerator,
|
|
9
|
+
Callable,
|
|
8
10
|
Coroutine,
|
|
9
11
|
Iterable,
|
|
10
12
|
Mapping,
|
|
@@ -17,7 +19,7 @@ from typing import (
|
|
|
17
19
|
import strawberry
|
|
18
20
|
from openinference.instrumentation import safe_json_dumps
|
|
19
21
|
from openinference.semconv.trace import SpanAttributes
|
|
20
|
-
from sqlalchemy import and_,
|
|
22
|
+
from sqlalchemy import and_, insert, select
|
|
21
23
|
from sqlalchemy.orm import load_only
|
|
22
24
|
from strawberry.relay.types import GlobalID
|
|
23
25
|
from strawberry.types import Info
|
|
@@ -26,7 +28,11 @@ from typing_extensions import TypeAlias, assert_never
|
|
|
26
28
|
from phoenix.config import PLAYGROUND_PROJECT_NAME
|
|
27
29
|
from phoenix.datetime_utils import local_now, normalize_datetime
|
|
28
30
|
from phoenix.db import models
|
|
29
|
-
from phoenix.
|
|
31
|
+
from phoenix.db.helpers import (
|
|
32
|
+
get_dataset_example_revisions,
|
|
33
|
+
insert_experiment_with_examples_snapshot,
|
|
34
|
+
)
|
|
35
|
+
from phoenix.server.api.auth import IsLocked, IsNotReadOnly, IsNotViewer
|
|
30
36
|
from phoenix.server.api.context import Context
|
|
31
37
|
from phoenix.server.api.exceptions import BadRequest, CustomGraphQLError, NotFound
|
|
32
38
|
from phoenix.server.api.helpers.playground_clients import (
|
|
@@ -43,6 +49,7 @@ from phoenix.server.api.helpers.playground_spans import (
|
|
|
43
49
|
get_db_trace,
|
|
44
50
|
streaming_llm_span,
|
|
45
51
|
)
|
|
52
|
+
from phoenix.server.api.helpers.playground_users import get_user
|
|
46
53
|
from phoenix.server.api.helpers.prompts.models import PromptTemplateFormat
|
|
47
54
|
from phoenix.server.api.input_types.ChatCompletionInput import (
|
|
48
55
|
ChatCompletionInput,
|
|
@@ -59,7 +66,7 @@ from phoenix.server.api.types.Dataset import Dataset
|
|
|
59
66
|
from phoenix.server.api.types.DatasetExample import DatasetExample
|
|
60
67
|
from phoenix.server.api.types.DatasetVersion import DatasetVersion
|
|
61
68
|
from phoenix.server.api.types.Experiment import to_gql_experiment
|
|
62
|
-
from phoenix.server.api.types.ExperimentRun import
|
|
69
|
+
from phoenix.server.api.types.ExperimentRun import ExperimentRun
|
|
63
70
|
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
64
71
|
from phoenix.server.api.types.Span import Span
|
|
65
72
|
from phoenix.server.daemons.span_cost_calculator import SpanCostCalculator
|
|
@@ -90,9 +97,109 @@ ChatCompletionResult: TypeAlias = tuple[
|
|
|
90
97
|
ChatStream: TypeAlias = AsyncGenerator[ChatCompletionSubscriptionPayload, None]
|
|
91
98
|
|
|
92
99
|
|
|
100
|
+
async def _stream_single_chat_completion(
|
|
101
|
+
*,
|
|
102
|
+
input: ChatCompletionInput,
|
|
103
|
+
llm_client: PlaygroundStreamingClient,
|
|
104
|
+
project_id: int,
|
|
105
|
+
repetition_number: int,
|
|
106
|
+
results: asyncio.Queue[tuple[Optional[models.Span], int]],
|
|
107
|
+
) -> ChatStream:
|
|
108
|
+
messages = [
|
|
109
|
+
(
|
|
110
|
+
message.role,
|
|
111
|
+
message.content,
|
|
112
|
+
message.tool_call_id if isinstance(message.tool_call_id, str) else None,
|
|
113
|
+
message.tool_calls if isinstance(message.tool_calls, list) else None,
|
|
114
|
+
)
|
|
115
|
+
for message in input.messages
|
|
116
|
+
]
|
|
117
|
+
attributes = None
|
|
118
|
+
if template_options := input.template:
|
|
119
|
+
messages = list(
|
|
120
|
+
_formatted_messages(
|
|
121
|
+
messages=messages,
|
|
122
|
+
template_format=template_options.format,
|
|
123
|
+
template_variables=template_options.variables,
|
|
124
|
+
)
|
|
125
|
+
)
|
|
126
|
+
attributes = {PROMPT_TEMPLATE_VARIABLES: safe_json_dumps(template_options.variables)}
|
|
127
|
+
invocation_parameters = llm_client.construct_invocation_parameters(input.invocation_parameters)
|
|
128
|
+
async with streaming_llm_span(
|
|
129
|
+
input=input,
|
|
130
|
+
messages=messages,
|
|
131
|
+
invocation_parameters=invocation_parameters,
|
|
132
|
+
attributes=attributes,
|
|
133
|
+
) as span:
|
|
134
|
+
try:
|
|
135
|
+
async for chunk in llm_client.chat_completion_create(
|
|
136
|
+
messages=messages, tools=input.tools or [], **invocation_parameters
|
|
137
|
+
):
|
|
138
|
+
span.add_response_chunk(chunk)
|
|
139
|
+
chunk.repetition_number = repetition_number
|
|
140
|
+
yield chunk
|
|
141
|
+
finally:
|
|
142
|
+
span.set_attributes(llm_client.attributes)
|
|
143
|
+
if span.status_message is not None:
|
|
144
|
+
yield ChatCompletionSubscriptionError(
|
|
145
|
+
message=span.status_message,
|
|
146
|
+
repetition_number=repetition_number,
|
|
147
|
+
)
|
|
148
|
+
|
|
149
|
+
db_trace = get_db_trace(span, project_id)
|
|
150
|
+
db_span = get_db_span(span, db_trace)
|
|
151
|
+
await results.put((db_span, repetition_number))
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
async def _chat_completion_span_result_payloads(
|
|
155
|
+
*,
|
|
156
|
+
db: DbSessionFactory,
|
|
157
|
+
results: Sequence[tuple[Optional[models.Span], int]],
|
|
158
|
+
span_cost_calculator: SpanCostCalculator,
|
|
159
|
+
on_span_insertion: Callable[[], None],
|
|
160
|
+
) -> ChatStream:
|
|
161
|
+
if not results:
|
|
162
|
+
return
|
|
163
|
+
async with db() as session:
|
|
164
|
+
for span, repetition_number in results:
|
|
165
|
+
if span:
|
|
166
|
+
session.add(span)
|
|
167
|
+
await session.flush()
|
|
168
|
+
try:
|
|
169
|
+
span_cost = span_cost_calculator.calculate_cost(
|
|
170
|
+
start_time=span.start_time,
|
|
171
|
+
attributes=span.attributes,
|
|
172
|
+
)
|
|
173
|
+
except Exception as e:
|
|
174
|
+
logger.exception(f"Failed to calculate cost for span {span.id}: {e}")
|
|
175
|
+
span_cost = None
|
|
176
|
+
if span_cost:
|
|
177
|
+
span_cost.span_rowid = span.id
|
|
178
|
+
span_cost.trace_rowid = span.trace_rowid
|
|
179
|
+
session.add(span_cost)
|
|
180
|
+
await session.flush()
|
|
181
|
+
for span, repetition_number in results:
|
|
182
|
+
if span:
|
|
183
|
+
yield ChatCompletionSubscriptionResult(
|
|
184
|
+
span=Span(id=span.id, db_record=span),
|
|
185
|
+
repetition_number=repetition_number,
|
|
186
|
+
)
|
|
187
|
+
on_span_insertion()
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def _is_span_result_payloads_stream(
|
|
191
|
+
stream: ChatStream,
|
|
192
|
+
) -> bool:
|
|
193
|
+
"""
|
|
194
|
+
Checks if the given generator was instantiated from
|
|
195
|
+
`_chat_completion_span_result_payloads`
|
|
196
|
+
"""
|
|
197
|
+
return stream.ag_code == _chat_completion_span_result_payloads.__code__ # type: ignore
|
|
198
|
+
|
|
199
|
+
|
|
93
200
|
@strawberry.type
|
|
94
201
|
class Subscription:
|
|
95
|
-
@strawberry.subscription(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
|
|
202
|
+
@strawberry.subscription(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
|
|
96
203
|
async def chat_completion(
|
|
97
204
|
self, info: Info[Context, None], input: ChatCompletionInput
|
|
98
205
|
) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
|
|
@@ -121,42 +228,6 @@ class Subscription:
|
|
|
121
228
|
f"{str(error)}"
|
|
122
229
|
)
|
|
123
230
|
|
|
124
|
-
messages = [
|
|
125
|
-
(
|
|
126
|
-
message.role,
|
|
127
|
-
message.content,
|
|
128
|
-
message.tool_call_id if isinstance(message.tool_call_id, str) else None,
|
|
129
|
-
message.tool_calls if isinstance(message.tool_calls, list) else None,
|
|
130
|
-
)
|
|
131
|
-
for message in input.messages
|
|
132
|
-
]
|
|
133
|
-
attributes = None
|
|
134
|
-
if template_options := input.template:
|
|
135
|
-
messages = list(
|
|
136
|
-
_formatted_messages(
|
|
137
|
-
messages=messages,
|
|
138
|
-
template_format=template_options.format,
|
|
139
|
-
template_variables=template_options.variables,
|
|
140
|
-
)
|
|
141
|
-
)
|
|
142
|
-
attributes = {PROMPT_TEMPLATE_VARIABLES: safe_json_dumps(template_options.variables)}
|
|
143
|
-
invocation_parameters = llm_client.construct_invocation_parameters(
|
|
144
|
-
input.invocation_parameters
|
|
145
|
-
)
|
|
146
|
-
async with streaming_llm_span(
|
|
147
|
-
input=input,
|
|
148
|
-
messages=messages,
|
|
149
|
-
invocation_parameters=invocation_parameters,
|
|
150
|
-
attributes=attributes,
|
|
151
|
-
) as span:
|
|
152
|
-
async for chunk in llm_client.chat_completion_create(
|
|
153
|
-
messages=messages, tools=input.tools or [], **invocation_parameters
|
|
154
|
-
):
|
|
155
|
-
span.add_response_chunk(chunk)
|
|
156
|
-
yield chunk
|
|
157
|
-
span.set_attributes(llm_client.attributes)
|
|
158
|
-
if span.status_message is not None:
|
|
159
|
-
yield ChatCompletionSubscriptionError(message=span.status_message)
|
|
160
231
|
async with info.context.db() as session:
|
|
161
232
|
if (
|
|
162
233
|
playground_project_id := await session.scalar(
|
|
@@ -171,27 +242,100 @@ class Subscription:
|
|
|
171
242
|
description="Traces from prompt playground",
|
|
172
243
|
)
|
|
173
244
|
)
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
245
|
+
|
|
246
|
+
results: asyncio.Queue[tuple[Optional[models.Span], int]] = asyncio.Queue()
|
|
247
|
+
not_started: deque[tuple[int, ChatStream]] = deque(
|
|
248
|
+
(
|
|
249
|
+
repetition_number,
|
|
250
|
+
_stream_single_chat_completion(
|
|
251
|
+
input=input,
|
|
252
|
+
llm_client=llm_client,
|
|
253
|
+
project_id=playground_project_id,
|
|
254
|
+
repetition_number=repetition_number,
|
|
255
|
+
results=results,
|
|
256
|
+
),
|
|
257
|
+
)
|
|
258
|
+
for repetition_number in range(1, input.repetitions + 1)
|
|
259
|
+
)
|
|
260
|
+
in_progress: list[
|
|
261
|
+
tuple[
|
|
262
|
+
Optional[int],
|
|
263
|
+
ChatStream,
|
|
264
|
+
asyncio.Task[ChatCompletionSubscriptionPayload],
|
|
265
|
+
]
|
|
266
|
+
] = []
|
|
267
|
+
max_in_progress = 3
|
|
268
|
+
write_batch_size = 10
|
|
269
|
+
write_interval = timedelta(seconds=10)
|
|
270
|
+
last_write_time = datetime.now()
|
|
271
|
+
while not_started or in_progress:
|
|
272
|
+
while not_started and len(in_progress) < max_in_progress:
|
|
273
|
+
rep_num, stream = not_started.popleft()
|
|
274
|
+
task = _create_task_with_timeout(stream)
|
|
275
|
+
in_progress.append((rep_num, stream, task))
|
|
276
|
+
async_tasks_to_run = [task for _, _, task in in_progress]
|
|
277
|
+
completed_tasks, _ = await asyncio.wait(
|
|
278
|
+
async_tasks_to_run, return_when=asyncio.FIRST_COMPLETED
|
|
279
|
+
)
|
|
280
|
+
for completed_task in completed_tasks:
|
|
281
|
+
idx = [task for _, _, task in in_progress].index(completed_task)
|
|
282
|
+
repetition_number, stream, _ = in_progress[idx]
|
|
283
|
+
try:
|
|
284
|
+
yield completed_task.result()
|
|
285
|
+
except StopAsyncIteration:
|
|
286
|
+
del in_progress[idx] # removes exhausted stream
|
|
287
|
+
except asyncio.TimeoutError:
|
|
288
|
+
del in_progress[idx] # removes timed-out stream
|
|
289
|
+
if repetition_number is not None:
|
|
290
|
+
yield ChatCompletionSubscriptionError(
|
|
291
|
+
message="Playground task timed out",
|
|
292
|
+
repetition_number=repetition_number,
|
|
293
|
+
)
|
|
294
|
+
except Exception as error:
|
|
295
|
+
del in_progress[idx] # removes failed stream
|
|
296
|
+
if repetition_number is not None:
|
|
297
|
+
yield ChatCompletionSubscriptionError(
|
|
298
|
+
message="An unexpected error occurred",
|
|
299
|
+
repetition_number=repetition_number,
|
|
300
|
+
)
|
|
301
|
+
logger.exception(error)
|
|
302
|
+
else:
|
|
303
|
+
task = _create_task_with_timeout(stream)
|
|
304
|
+
in_progress[idx] = (repetition_number, stream, task)
|
|
305
|
+
|
|
306
|
+
exceeded_write_batch_size = results.qsize() >= write_batch_size
|
|
307
|
+
exceeded_write_interval = datetime.now() - last_write_time > write_interval
|
|
308
|
+
write_already_in_progress = any(
|
|
309
|
+
_is_span_result_payloads_stream(stream) for _, stream, _ in in_progress
|
|
182
310
|
)
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
311
|
+
if (
|
|
312
|
+
not results.empty()
|
|
313
|
+
and (exceeded_write_batch_size or exceeded_write_interval)
|
|
314
|
+
and not write_already_in_progress
|
|
315
|
+
):
|
|
316
|
+
result_payloads_stream = _chat_completion_span_result_payloads(
|
|
317
|
+
db=info.context.db,
|
|
318
|
+
results=_drain_no_wait(results),
|
|
319
|
+
span_cost_calculator=info.context.span_cost_calculator,
|
|
320
|
+
on_span_insertion=lambda: info.context.event_queue.put(
|
|
321
|
+
SpanInsertEvent(ids=(playground_project_id,))
|
|
322
|
+
),
|
|
323
|
+
)
|
|
324
|
+
task = _create_task_with_timeout(result_payloads_stream)
|
|
325
|
+
in_progress.append((None, result_payloads_stream, task))
|
|
326
|
+
last_write_time = datetime.now()
|
|
327
|
+
if remaining_results := await _drain(results):
|
|
328
|
+
async for result_payload in _chat_completion_span_result_payloads(
|
|
329
|
+
db=info.context.db,
|
|
330
|
+
results=remaining_results,
|
|
331
|
+
span_cost_calculator=info.context.span_cost_calculator,
|
|
332
|
+
on_span_insertion=lambda: info.context.event_queue.put(
|
|
333
|
+
SpanInsertEvent(ids=(playground_project_id,))
|
|
334
|
+
),
|
|
335
|
+
):
|
|
336
|
+
yield result_payload
|
|
337
|
+
|
|
338
|
+
@strawberry.subscription(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
|
|
195
339
|
async def chat_completion_over_dataset(
|
|
196
340
|
self, info: Info[Context, None], input: ChatCompletionOverDatasetInput
|
|
197
341
|
) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
|
|
@@ -255,27 +399,22 @@ class Subscription:
|
|
|
255
399
|
)
|
|
256
400
|
) is None:
|
|
257
401
|
raise NotFound(f"Could not find dataset version with ID {version_id}")
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
.group_by(models.DatasetExampleRevision.dataset_example_id)
|
|
268
|
-
)
|
|
402
|
+
|
|
403
|
+
# Parse split IDs if provided
|
|
404
|
+
resolved_split_ids: Optional[list[int]] = None
|
|
405
|
+
if input.split_ids is not None and len(input.split_ids) > 0:
|
|
406
|
+
resolved_split_ids = [
|
|
407
|
+
from_global_id_with_expected_type(split_id, models.DatasetSplit.__name__)
|
|
408
|
+
for split_id in input.split_ids
|
|
409
|
+
]
|
|
410
|
+
|
|
269
411
|
if not (
|
|
270
412
|
revisions := [
|
|
271
413
|
rev
|
|
272
414
|
async for rev in await session.stream_scalars(
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
models.DatasetExampleRevision.id.in_(revision_ids),
|
|
277
|
-
models.DatasetExampleRevision.revision_kind != "DELETE",
|
|
278
|
-
)
|
|
415
|
+
get_dataset_example_revisions(
|
|
416
|
+
resolved_version_id,
|
|
417
|
+
split_ids=resolved_split_ids,
|
|
279
418
|
)
|
|
280
419
|
.order_by(models.DatasetExampleRevision.dataset_example_id.asc())
|
|
281
420
|
.options(
|
|
@@ -302,18 +441,24 @@ class Subscription:
|
|
|
302
441
|
description="Traces from prompt playground",
|
|
303
442
|
)
|
|
304
443
|
)
|
|
444
|
+
user_id = get_user(info)
|
|
305
445
|
experiment = models.Experiment(
|
|
306
446
|
dataset_id=from_global_id_with_expected_type(input.dataset_id, Dataset.__name__),
|
|
307
447
|
dataset_version_id=resolved_version_id,
|
|
308
448
|
name=input.experiment_name
|
|
309
449
|
or _default_playground_experiment_name(input.prompt_name),
|
|
310
450
|
description=input.experiment_description,
|
|
311
|
-
repetitions=
|
|
451
|
+
repetitions=input.repetitions,
|
|
312
452
|
metadata_=input.experiment_metadata or dict(),
|
|
313
453
|
project_name=project_name,
|
|
454
|
+
user_id=user_id,
|
|
314
455
|
)
|
|
315
|
-
|
|
316
|
-
|
|
456
|
+
if resolved_split_ids:
|
|
457
|
+
experiment.experiment_dataset_splits = [
|
|
458
|
+
models.ExperimentDatasetSplit(dataset_split_id=split_id)
|
|
459
|
+
for split_id in resolved_split_ids
|
|
460
|
+
]
|
|
461
|
+
await insert_experiment_with_examples_snapshot(session, experiment)
|
|
317
462
|
yield ChatCompletionSubscriptionExperiment(
|
|
318
463
|
experiment=to_gql_experiment(experiment)
|
|
319
464
|
) # eagerly yields experiment so it can be linked by consumers of the subscription
|
|
@@ -327,11 +472,15 @@ class Subscription:
|
|
|
327
472
|
llm_client=llm_client,
|
|
328
473
|
revision=revision,
|
|
329
474
|
results=results,
|
|
475
|
+
repetition_number=repetition_number,
|
|
330
476
|
experiment_id=experiment.id,
|
|
331
477
|
project_id=playground_project_id,
|
|
332
478
|
),
|
|
333
479
|
)
|
|
334
480
|
for revision in revisions
|
|
481
|
+
for repetition_number in reversed(
|
|
482
|
+
range(1, input.repetitions + 1)
|
|
483
|
+
) # since we pop right, this runs the repetitions in increasing order
|
|
335
484
|
]
|
|
336
485
|
in_progress: list[
|
|
337
486
|
tuple[
|
|
@@ -409,6 +558,7 @@ async def _stream_chat_completion_over_dataset_example(
|
|
|
409
558
|
input: ChatCompletionOverDatasetInput,
|
|
410
559
|
llm_client: PlaygroundStreamingClient,
|
|
411
560
|
revision: models.DatasetExampleRevision,
|
|
561
|
+
repetition_number: int,
|
|
412
562
|
results: asyncio.Queue[ChatCompletionResult],
|
|
413
563
|
experiment_id: int,
|
|
414
564
|
project_id: int,
|
|
@@ -435,7 +585,11 @@ async def _stream_chat_completion_over_dataset_example(
|
|
|
435
585
|
)
|
|
436
586
|
except TemplateFormatterError as error:
|
|
437
587
|
format_end_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
|
|
438
|
-
yield ChatCompletionSubscriptionError(
|
|
588
|
+
yield ChatCompletionSubscriptionError(
|
|
589
|
+
message=str(error),
|
|
590
|
+
dataset_example_id=example_id,
|
|
591
|
+
repetition_number=repetition_number,
|
|
592
|
+
)
|
|
439
593
|
await results.put(
|
|
440
594
|
(
|
|
441
595
|
example_id,
|
|
@@ -445,7 +599,7 @@ async def _stream_chat_completion_over_dataset_example(
|
|
|
445
599
|
dataset_example_id=revision.dataset_example_id,
|
|
446
600
|
trace_id=None,
|
|
447
601
|
output={},
|
|
448
|
-
repetition_number=
|
|
602
|
+
repetition_number=repetition_number,
|
|
449
603
|
start_time=format_start_time,
|
|
450
604
|
end_time=format_end_time,
|
|
451
605
|
error=str(error),
|
|
@@ -460,22 +614,31 @@ async def _stream_chat_completion_over_dataset_example(
|
|
|
460
614
|
invocation_parameters=invocation_parameters,
|
|
461
615
|
attributes={PROMPT_TEMPLATE_VARIABLES: safe_json_dumps(revision.input)},
|
|
462
616
|
) as span:
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
617
|
+
try:
|
|
618
|
+
async for chunk in llm_client.chat_completion_create(
|
|
619
|
+
messages=messages, tools=input.tools or [], **invocation_parameters
|
|
620
|
+
):
|
|
621
|
+
span.add_response_chunk(chunk)
|
|
622
|
+
chunk.dataset_example_id = example_id
|
|
623
|
+
chunk.repetition_number = repetition_number
|
|
624
|
+
yield chunk
|
|
625
|
+
finally:
|
|
626
|
+
span.set_attributes(llm_client.attributes)
|
|
470
627
|
db_trace = get_db_trace(span, project_id)
|
|
471
628
|
db_span = get_db_span(span, db_trace)
|
|
472
629
|
db_run = get_db_experiment_run(
|
|
473
|
-
db_span,
|
|
630
|
+
db_span,
|
|
631
|
+
db_trace,
|
|
632
|
+
experiment_id=experiment_id,
|
|
633
|
+
example_id=revision.dataset_example_id,
|
|
634
|
+
repetition_number=repetition_number,
|
|
474
635
|
)
|
|
475
636
|
await results.put((example_id, db_span, db_run))
|
|
476
637
|
if span.status_message is not None:
|
|
477
638
|
yield ChatCompletionSubscriptionError(
|
|
478
|
-
message=span.status_message,
|
|
639
|
+
message=span.status_message,
|
|
640
|
+
dataset_example_id=example_id,
|
|
641
|
+
repetition_number=repetition_number,
|
|
479
642
|
)
|
|
480
643
|
|
|
481
644
|
|
|
@@ -508,9 +671,10 @@ async def _chat_completion_result_payloads(
|
|
|
508
671
|
await session.flush()
|
|
509
672
|
for example_id, span, run in results:
|
|
510
673
|
yield ChatCompletionSubscriptionResult(
|
|
511
|
-
span=Span(
|
|
512
|
-
experiment_run=
|
|
674
|
+
span=Span(id=span.id, db_record=span) if span else None,
|
|
675
|
+
experiment_run=ExperimentRun(id=run.id, db_record=run),
|
|
513
676
|
dataset_example_id=example_id,
|
|
677
|
+
repetition_number=run.repetition_number,
|
|
514
678
|
)
|
|
515
679
|
|
|
516
680
|
|
|
@@ -1,31 +1,98 @@
|
|
|
1
1
|
from datetime import datetime
|
|
2
|
-
from typing import Optional
|
|
2
|
+
from typing import TYPE_CHECKING, Annotated, Optional
|
|
3
3
|
|
|
4
4
|
import strawberry
|
|
5
|
+
from strawberry.scalars import JSON
|
|
6
|
+
from strawberry.types import Info
|
|
5
7
|
|
|
6
|
-
from phoenix.server.api.
|
|
8
|
+
from phoenix.server.api.context import Context
|
|
9
|
+
|
|
10
|
+
from .AnnotationSource import AnnotationSource
|
|
11
|
+
from .AnnotatorKind import AnnotatorKind
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from .User import User
|
|
7
15
|
|
|
8
16
|
|
|
9
17
|
@strawberry.interface
|
|
10
18
|
class Annotation:
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
)
|
|
19
|
+
@strawberry.field(description="Name of the annotation, e.g. 'helpfulness' or 'relevance'.") # type: ignore
|
|
20
|
+
async def name(
|
|
21
|
+
self,
|
|
22
|
+
info: Info[Context, None],
|
|
23
|
+
) -> str:
|
|
24
|
+
raise NotImplementedError
|
|
25
|
+
|
|
26
|
+
@strawberry.field(description="The kind of annotator that produced the annotation.") # type: ignore
|
|
27
|
+
async def annotator_kind(
|
|
28
|
+
self,
|
|
29
|
+
info: Info[Context, None],
|
|
30
|
+
) -> AnnotatorKind:
|
|
31
|
+
raise NotImplementedError
|
|
32
|
+
|
|
33
|
+
@strawberry.field(
|
|
34
|
+
description="Value of the annotation in the form of a string, e.g. 'helpful' or 'not helpful'. Note that the label is not necessarily binary." # noqa: E501
|
|
35
|
+
) # type: ignore
|
|
36
|
+
async def label(
|
|
37
|
+
self,
|
|
38
|
+
info: Info[Context, None],
|
|
39
|
+
) -> Optional[str]:
|
|
40
|
+
raise NotImplementedError
|
|
41
|
+
|
|
42
|
+
@strawberry.field(description="Value of the annotation in the form of a numeric score.") # type: ignore
|
|
43
|
+
async def score(
|
|
44
|
+
self,
|
|
45
|
+
info: Info[Context, None],
|
|
46
|
+
) -> Optional[float]:
|
|
47
|
+
raise NotImplementedError
|
|
48
|
+
|
|
49
|
+
@strawberry.field(
|
|
50
|
+
description="The annotator's explanation for the annotation result (i.e. score or label, or both) given to the subject." # noqa: E501
|
|
51
|
+
) # type: ignore
|
|
52
|
+
async def explanation(
|
|
53
|
+
self,
|
|
54
|
+
info: Info[Context, None],
|
|
55
|
+
) -> Optional[str]:
|
|
56
|
+
raise NotImplementedError
|
|
57
|
+
|
|
58
|
+
@strawberry.field(description="Metadata about the annotation.") # type: ignore
|
|
59
|
+
async def metadata(
|
|
60
|
+
self,
|
|
61
|
+
info: Info[Context, None],
|
|
62
|
+
) -> JSON:
|
|
63
|
+
raise NotImplementedError
|
|
64
|
+
|
|
65
|
+
@strawberry.field(description="The source of the annotation.") # type: ignore
|
|
66
|
+
async def source(
|
|
67
|
+
self,
|
|
68
|
+
info: Info[Context, None],
|
|
69
|
+
) -> AnnotationSource:
|
|
70
|
+
raise NotImplementedError
|
|
71
|
+
|
|
72
|
+
@strawberry.field(description="The identifier of the annotation.") # type: ignore
|
|
73
|
+
async def identifier(
|
|
74
|
+
self,
|
|
75
|
+
info: Info[Context, None],
|
|
76
|
+
) -> str:
|
|
77
|
+
raise NotImplementedError
|
|
78
|
+
|
|
79
|
+
@strawberry.field(description="The date and time the annotation was created.") # type: ignore
|
|
80
|
+
async def created_at(
|
|
81
|
+
self,
|
|
82
|
+
info: Info[Context, None],
|
|
83
|
+
) -> datetime:
|
|
84
|
+
raise NotImplementedError
|
|
85
|
+
|
|
86
|
+
@strawberry.field(description="The date and time the annotation was last updated.") # type: ignore
|
|
87
|
+
async def updated_at(
|
|
88
|
+
self,
|
|
89
|
+
info: Info[Context, None],
|
|
90
|
+
) -> datetime:
|
|
91
|
+
raise NotImplementedError
|
|
92
|
+
|
|
93
|
+
@strawberry.field(description="The user that produced the annotation.") # type: ignore
|
|
94
|
+
async def user(
|
|
95
|
+
self,
|
|
96
|
+
info: Info[Context, None],
|
|
97
|
+
) -> Optional[Annotated["User", strawberry.lazy(".User")]]:
|
|
98
|
+
raise NotImplementedError
|
|
@@ -3,25 +3,21 @@ from typing import Optional
|
|
|
3
3
|
|
|
4
4
|
import strawberry
|
|
5
5
|
|
|
6
|
-
from phoenix.db.models import ApiKey as ORMApiKey
|
|
7
|
-
|
|
8
6
|
|
|
9
7
|
@strawberry.interface
|
|
10
8
|
class ApiKey:
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
)
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
)
|
|
9
|
+
@strawberry.field(description="Name of the API key.") # type: ignore
|
|
10
|
+
async def name(self) -> str:
|
|
11
|
+
raise NotImplementedError
|
|
12
|
+
|
|
13
|
+
@strawberry.field(description="Description of the API key.") # type: ignore
|
|
14
|
+
async def description(self) -> Optional[str]:
|
|
15
|
+
raise NotImplementedError
|
|
19
16
|
|
|
17
|
+
@strawberry.field(description="The date and time the API key was created.") # type: ignore
|
|
18
|
+
async def created_at(self) -> datetime:
|
|
19
|
+
raise NotImplementedError
|
|
20
20
|
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
description=api_key.description,
|
|
25
|
-
created_at=api_key.created_at,
|
|
26
|
-
expires_at=api_key.expires_at,
|
|
27
|
-
)
|
|
21
|
+
@strawberry.field(description="The date and time the API key will expire.") # type: ignore
|
|
22
|
+
async def expires_at(self) -> Optional[datetime]:
|
|
23
|
+
raise NotImplementedError
|