arize-phoenix 10.0.4__py3-none-any.whl → 12.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +124 -72
- arize_phoenix-12.28.1.dist-info/RECORD +499 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
- phoenix/__generated__/__init__.py +0 -0
- phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
- phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
- phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
- phoenix/__init__.py +5 -4
- phoenix/auth.py +39 -2
- phoenix/config.py +1763 -91
- phoenix/datetime_utils.py +120 -2
- phoenix/db/README.md +595 -25
- phoenix/db/bulk_inserter.py +145 -103
- phoenix/db/engines.py +140 -33
- phoenix/db/enums.py +3 -12
- phoenix/db/facilitator.py +302 -35
- phoenix/db/helpers.py +1000 -65
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/insertion/dataset.py +135 -2
- phoenix/db/insertion/document_annotation.py +9 -6
- phoenix/db/insertion/evaluation.py +2 -3
- phoenix/db/insertion/helpers.py +17 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/span.py +15 -11
- phoenix/db/insertion/span_annotation.py +3 -4
- phoenix/db/insertion/trace_annotation.py +3 -4
- phoenix/db/insertion/types.py +50 -20
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +669 -56
- phoenix/db/pg_config.py +10 -0
- phoenix/db/types/model_provider.py +4 -0
- phoenix/db/types/token_price_customization.py +29 -0
- phoenix/db/types/trace_retention.py +23 -15
- phoenix/experiments/evaluators/utils.py +3 -3
- phoenix/experiments/functions.py +160 -52
- phoenix/experiments/tracing.py +2 -2
- phoenix/experiments/types.py +1 -1
- phoenix/inferences/inferences.py +1 -2
- phoenix/server/api/auth.py +38 -7
- phoenix/server/api/auth_messages.py +46 -0
- phoenix/server/api/context.py +100 -4
- phoenix/server/api/dataloaders/__init__.py +79 -5
- phoenix/server/api/dataloaders/annotation_configs_by_project.py +31 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
- phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
- phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
- phoenix/server/api/dataloaders/document_evaluations.py +6 -9
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
- phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
- phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
- phoenix/server/api/dataloaders/record_counts.py +37 -10
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
- phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +57 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project.py +152 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
- phoenix/server/api/dataloaders/span_costs.py +29 -0
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/dataloaders/types.py +29 -0
- phoenix/server/api/exceptions.py +11 -1
- phoenix/server/api/helpers/dataset_helpers.py +5 -1
- phoenix/server/api/helpers/playground_clients.py +1243 -292
- phoenix/server/api/helpers/playground_registry.py +2 -2
- phoenix/server/api/helpers/playground_spans.py +8 -4
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/helpers/prompts/conversions/aws.py +83 -0
- phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
- phoenix/server/api/helpers/prompts/models.py +205 -22
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/ChatCompletionInput.py +6 -2
- phoenix/server/api/input_types/CreateProjectInput.py +27 -0
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/DatasetFilter.py +17 -0
- phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
- phoenix/server/api/input_types/GenerativeCredentialInput.py +9 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +5 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +161 -1
- phoenix/server/api/input_types/PromptFilter.py +14 -0
- phoenix/server/api/input_types/PromptVersionInput.py +52 -1
- phoenix/server/api/input_types/SpanSort.py +44 -7
- phoenix/server/api/input_types/TimeBinConfig.py +23 -0
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/input_types/UserRoleInput.py +1 -0
- phoenix/server/api/mutations/__init__.py +10 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
- phoenix/server/api/mutations/api_key_mutations.py +19 -23
- phoenix/server/api/mutations/chat_mutations.py +154 -47
- phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
- phoenix/server/api/mutations/dataset_mutations.py +21 -16
- phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/model_mutations.py +210 -0
- phoenix/server/api/mutations/project_mutations.py +49 -10
- phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
- phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
- phoenix/server/api/mutations/prompt_mutations.py +65 -129
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
- phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
- phoenix/server/api/mutations/trace_annotations_mutations.py +14 -10
- phoenix/server/api/mutations/trace_mutations.py +47 -3
- phoenix/server/api/mutations/user_mutations.py +66 -41
- phoenix/server/api/queries.py +768 -293
- phoenix/server/api/routers/__init__.py +2 -2
- phoenix/server/api/routers/auth.py +154 -88
- phoenix/server/api/routers/ldap.py +229 -0
- phoenix/server/api/routers/oauth2.py +369 -106
- phoenix/server/api/routers/v1/__init__.py +24 -4
- phoenix/server/api/routers/v1/annotation_configs.py +23 -31
- phoenix/server/api/routers/v1/annotations.py +481 -17
- phoenix/server/api/routers/v1/datasets.py +395 -81
- phoenix/server/api/routers/v1/documents.py +142 -0
- phoenix/server/api/routers/v1/evaluations.py +24 -31
- phoenix/server/api/routers/v1/experiment_evaluations.py +19 -8
- phoenix/server/api/routers/v1/experiment_runs.py +337 -59
- phoenix/server/api/routers/v1/experiments.py +479 -48
- phoenix/server/api/routers/v1/models.py +7 -0
- phoenix/server/api/routers/v1/projects.py +18 -49
- phoenix/server/api/routers/v1/prompts.py +54 -40
- phoenix/server/api/routers/v1/sessions.py +108 -0
- phoenix/server/api/routers/v1/spans.py +1091 -81
- phoenix/server/api/routers/v1/traces.py +132 -78
- phoenix/server/api/routers/v1/users.py +389 -0
- phoenix/server/api/routers/v1/utils.py +3 -7
- phoenix/server/api/subscriptions.py +305 -88
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/AuthMethod.py +1 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
- phoenix/server/api/types/CostBreakdown.py +12 -0
- phoenix/server/api/types/Dataset.py +226 -72
- phoenix/server/api/types/DatasetExample.py +88 -18
- phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
- phoenix/server/api/types/DatasetLabel.py +57 -0
- phoenix/server/api/types/DatasetSplit.py +98 -0
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +212 -0
- phoenix/server/api/types/Experiment.py +264 -59
- phoenix/server/api/types/ExperimentComparison.py +5 -10
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
- phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
- phoenix/server/api/types/ExperimentRun.py +169 -65
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +245 -3
- phoenix/server/api/types/GenerativeProvider.py +70 -11
- phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
- phoenix/server/api/types/ModelInterface.py +16 -0
- phoenix/server/api/types/PlaygroundModel.py +20 -0
- phoenix/server/api/types/Project.py +1278 -216
- phoenix/server/api/types/ProjectSession.py +188 -28
- phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +119 -39
- phoenix/server/api/types/PromptLabel.py +42 -25
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/ServerStatus.py +6 -0
- phoenix/server/api/types/Span.py +167 -123
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
- phoenix/server/api/types/SpanCostSummary.py +10 -0
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/TokenPrice.py +16 -0
- phoenix/server/api/types/TokenUsage.py +3 -3
- phoenix/server/api/types/Trace.py +223 -51
- phoenix/server/api/types/TraceAnnotation.py +149 -50
- phoenix/server/api/types/User.py +137 -32
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/api/types/node.py +10 -0
- phoenix/server/api/types/pagination.py +11 -2
- phoenix/server/app.py +290 -45
- phoenix/server/authorization.py +38 -3
- phoenix/server/bearer_auth.py +34 -24
- phoenix/server/cost_tracking/cost_details_calculator.py +196 -0
- phoenix/server/cost_tracking/cost_model_lookup.py +179 -0
- phoenix/server/cost_tracking/helpers.py +68 -0
- phoenix/server/cost_tracking/model_cost_manifest.json +3657 -830
- phoenix/server/cost_tracking/regex_specificity.py +397 -0
- phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
- phoenix/server/daemons/__init__.py +0 -0
- phoenix/server/daemons/db_disk_usage_monitor.py +214 -0
- phoenix/server/daemons/generative_model_store.py +103 -0
- phoenix/server/daemons/span_cost_calculator.py +99 -0
- phoenix/server/dml_event.py +17 -0
- phoenix/server/dml_event_handler.py +5 -0
- phoenix/server/email/sender.py +56 -3
- phoenix/server/email/templates/db_disk_usage_notification.html +19 -0
- phoenix/server/email/types.py +11 -0
- phoenix/server/experiments/__init__.py +0 -0
- phoenix/server/experiments/utils.py +14 -0
- phoenix/server/grpc_server.py +11 -11
- phoenix/server/jwt_store.py +17 -15
- phoenix/server/ldap.py +1449 -0
- phoenix/server/main.py +26 -10
- phoenix/server/oauth2.py +330 -12
- phoenix/server/prometheus.py +66 -6
- phoenix/server/rate_limiters.py +4 -9
- phoenix/server/retention.py +33 -20
- phoenix/server/session_filters.py +49 -0
- phoenix/server/static/.vite/manifest.json +55 -51
- phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
- phoenix/server/static/assets/{index-E0M82BdE.js → index-CTQoemZv.js} +140 -56
- phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
- phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
- phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
- phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
- phoenix/server/static/assets/vendor-recharts-V9cwpXsm.js +37 -0
- phoenix/server/static/assets/vendor-shiki-Do--csgv.js +5 -0
- phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
- phoenix/server/templates/index.html +40 -6
- phoenix/server/thread_server.py +1 -2
- phoenix/server/types.py +14 -4
- phoenix/server/utils.py +74 -0
- phoenix/session/client.py +56 -3
- phoenix/session/data_extractor.py +5 -0
- phoenix/session/evaluation.py +14 -5
- phoenix/session/session.py +45 -9
- phoenix/settings.py +5 -0
- phoenix/trace/attributes.py +80 -13
- phoenix/trace/dsl/helpers.py +90 -1
- phoenix/trace/dsl/query.py +8 -6
- phoenix/trace/projects.py +5 -0
- phoenix/utilities/template_formatters.py +1 -1
- phoenix/version.py +1 -1
- arize_phoenix-10.0.4.dist-info/RECORD +0 -405
- phoenix/server/api/types/Evaluation.py +0 -39
- phoenix/server/cost_tracking/cost_lookup.py +0 -255
- phoenix/server/static/assets/components-DULKeDfL.js +0 -4365
- phoenix/server/static/assets/pages-Cl0A-0U2.js +0 -7430
- phoenix/server/static/assets/vendor-WIZid84E.css +0 -1
- phoenix/server/static/assets/vendor-arizeai-Dy-0mSNw.js +0 -649
- phoenix/server/static/assets/vendor-codemirror-DBtifKNr.js +0 -33
- phoenix/server/static/assets/vendor-oB4u9zuV.js +0 -905
- phoenix/server/static/assets/vendor-recharts-D-T4KPz2.js +0 -59
- phoenix/server/static/assets/vendor-shiki-BMn4O_9F.js +0 -5
- phoenix/server/static/assets/vendor-three-C5WAXd5r.js +0 -2998
- phoenix/utilities/deprecation.py +0 -31
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,10 +1,12 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
import logging
|
|
3
|
+
from collections import deque
|
|
3
4
|
from collections.abc import AsyncIterator, Iterator
|
|
4
5
|
from datetime import datetime, timedelta, timezone
|
|
5
6
|
from typing import (
|
|
6
7
|
Any,
|
|
7
8
|
AsyncGenerator,
|
|
9
|
+
Callable,
|
|
8
10
|
Coroutine,
|
|
9
11
|
Iterable,
|
|
10
12
|
Mapping,
|
|
@@ -17,7 +19,7 @@ from typing import (
|
|
|
17
19
|
import strawberry
|
|
18
20
|
from openinference.instrumentation import safe_json_dumps
|
|
19
21
|
from openinference.semconv.trace import SpanAttributes
|
|
20
|
-
from sqlalchemy import and_,
|
|
22
|
+
from sqlalchemy import and_, insert, select
|
|
21
23
|
from sqlalchemy.orm import load_only
|
|
22
24
|
from strawberry.relay.types import GlobalID
|
|
23
25
|
from strawberry.types import Info
|
|
@@ -26,10 +28,15 @@ from typing_extensions import TypeAlias, assert_never
|
|
|
26
28
|
from phoenix.config import PLAYGROUND_PROJECT_NAME
|
|
27
29
|
from phoenix.datetime_utils import local_now, normalize_datetime
|
|
28
30
|
from phoenix.db import models
|
|
29
|
-
from phoenix.
|
|
31
|
+
from phoenix.db.helpers import (
|
|
32
|
+
get_dataset_example_revisions,
|
|
33
|
+
insert_experiment_with_examples_snapshot,
|
|
34
|
+
)
|
|
35
|
+
from phoenix.server.api.auth import IsLocked, IsNotReadOnly, IsNotViewer
|
|
30
36
|
from phoenix.server.api.context import Context
|
|
31
37
|
from phoenix.server.api.exceptions import BadRequest, CustomGraphQLError, NotFound
|
|
32
38
|
from phoenix.server.api.helpers.playground_clients import (
|
|
39
|
+
PlaygroundClientCredential,
|
|
33
40
|
PlaygroundStreamingClient,
|
|
34
41
|
initialize_playground_clients,
|
|
35
42
|
)
|
|
@@ -42,6 +49,7 @@ from phoenix.server.api.helpers.playground_spans import (
|
|
|
42
49
|
get_db_trace,
|
|
43
50
|
streaming_llm_span,
|
|
44
51
|
)
|
|
52
|
+
from phoenix.server.api.helpers.playground_users import get_user
|
|
45
53
|
from phoenix.server.api.helpers.prompts.models import PromptTemplateFormat
|
|
46
54
|
from phoenix.server.api.input_types.ChatCompletionInput import (
|
|
47
55
|
ChatCompletionInput,
|
|
@@ -58,10 +66,12 @@ from phoenix.server.api.types.Dataset import Dataset
|
|
|
58
66
|
from phoenix.server.api.types.DatasetExample import DatasetExample
|
|
59
67
|
from phoenix.server.api.types.DatasetVersion import DatasetVersion
|
|
60
68
|
from phoenix.server.api.types.Experiment import to_gql_experiment
|
|
61
|
-
from phoenix.server.api.types.ExperimentRun import
|
|
69
|
+
from phoenix.server.api.types.ExperimentRun import ExperimentRun
|
|
62
70
|
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
63
71
|
from phoenix.server.api.types.Span import Span
|
|
72
|
+
from phoenix.server.daemons.span_cost_calculator import SpanCostCalculator
|
|
64
73
|
from phoenix.server.dml_event import SpanInsertEvent
|
|
74
|
+
from phoenix.server.experiments.utils import generate_experiment_project_name
|
|
65
75
|
from phoenix.server.types import DbSessionFactory
|
|
66
76
|
from phoenix.utilities.template_formatters import (
|
|
67
77
|
FStringTemplateFormatter,
|
|
@@ -87,9 +97,109 @@ ChatCompletionResult: TypeAlias = tuple[
|
|
|
87
97
|
ChatStream: TypeAlias = AsyncGenerator[ChatCompletionSubscriptionPayload, None]
|
|
88
98
|
|
|
89
99
|
|
|
100
|
+
async def _stream_single_chat_completion(
    *,
    input: ChatCompletionInput,
    llm_client: PlaygroundStreamingClient,
    project_id: int,
    repetition_number: int,
    results: asyncio.Queue[tuple[Optional[models.Span], int]],
) -> ChatStream:
    """
    Stream a single repetition of a playground chat completion.

    Every chunk produced by ``llm_client.chat_completion_create`` is recorded on the
    LLM span, stamped with ``repetition_number`` and yielded. If the span ends up with
    a status message, a ``ChatCompletionSubscriptionError`` for this repetition is
    yielded afterwards. Finally the span is converted with ``get_db_trace`` /
    ``get_db_span`` and ``(db_span, repetition_number)`` is put on ``results`` so the
    caller can persist it in a batch; this function performs no database I/O itself.

    Args:
        input: The chat completion request (messages, template, tools, parameters).
        llm_client: Streaming client used to produce the completion.
        project_id: Row id of the project the resulting trace is attached to.
        repetition_number: Index of this repetition, attached to every payload.
        results: Queue that receives the span record for deferred insertion.
    """
    # Normalize each input message into a (role, content, tool_call_id, tool_calls)
    # tuple, discarding tool fields that do not have the expected type.
    chat_messages = [
        (
            msg.role,
            msg.content,
            msg.tool_call_id if isinstance(msg.tool_call_id, str) else None,
            msg.tool_calls if isinstance(msg.tool_calls, list) else None,
        )
        for msg in input.messages
    ]

    span_attributes = None
    template = input.template
    if template:
        # Render the template variables into the messages and record them on the span.
        chat_messages = list(
            _formatted_messages(
                messages=chat_messages,
                template_format=template.format,
                template_variables=template.variables,
            )
        )
        span_attributes = {PROMPT_TEMPLATE_VARIABLES: safe_json_dumps(template.variables)}

    invocation_parameters = llm_client.construct_invocation_parameters(
        input.invocation_parameters
    )

    async with streaming_llm_span(
        input=input,
        messages=chat_messages,
        invocation_parameters=invocation_parameters,
        attributes=span_attributes,
    ) as llm_span:
        try:
            async for chunk in llm_client.chat_completion_create(
                messages=chat_messages,
                tools=input.tools or [],
                **invocation_parameters,
            ):
                llm_span.add_response_chunk(chunk)
                chunk.repetition_number = repetition_number
                yield chunk
        finally:
            # Runs even if the consumer closes this generator mid-stream.
            # NOTE(review): the subscription passes the same llm_client to every
            # repetition, so these attributes may reflect a sibling stream — confirm.
            llm_span.set_attributes(llm_client.attributes)

    status_message = llm_span.status_message
    if status_message is not None:
        yield ChatCompletionSubscriptionError(
            message=status_message,
            repetition_number=repetition_number,
        )

    # Only reached when the stream ran to completion; an abandoned generator
    # (e.g. after a timeout) never enqueues its span.
    db_span = get_db_span(llm_span, get_db_trace(llm_span, project_id))
    await results.put((db_span, repetition_number))
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
async def _chat_completion_span_result_payloads(
    *,
    db: DbSessionFactory,
    results: Sequence[tuple[Optional[models.Span], int]],
    span_cost_calculator: SpanCostCalculator,
    on_span_insertion: Callable[[], None],
) -> ChatStream:
    """
    Insert buffered playground spans (plus their costs) and stream one result per span.

    Entries whose span is ``None`` are skipped. All spans are added to a single session
    and flushed together so their ids are available for the cost rows. A cost is
    computed for each span with ``span_cost_calculator``; failures are logged rather
    than raised, so one bad span does not prevent the others from being stored.

    ``on_span_insertion`` is invoked exactly once, after the session context has exited
    and before any payload is yielded, and only if at least one span was inserted.

    Args:
        db: Session factory, used as an async context manager.
        results: ``(span, repetition_number)`` pairs collected from finished streams.
        span_cost_calculator: Computes an optional cost record for a span.
        on_span_insertion: Callback signalling that spans were written.

    Yields:
        One ``ChatCompletionSubscriptionResult`` per inserted span, in input order.
    """
    inserted = [(span, rep) for span, rep in results if span is not None]
    if not inserted:
        # Nothing to write: avoid opening a session and emitting a spurious insert event.
        return

    async with db() as session:
        session.add_all([span for span, _ in inserted])
        # One flush assigns ids to every span; the cost rows below reference them.
        await session.flush()
        has_costs = False
        for span, _ in inserted:
            try:
                span_cost = span_cost_calculator.calculate_cost(
                    start_time=span.start_time,
                    attributes=span.attributes,
                )
            except Exception as e:
                # Cost tracking is best-effort; the span itself is still stored.
                logger.exception("Failed to calculate cost for span %s: %s", span.id, e)
                continue
            if span_cost:
                span_cost.span_rowid = span.id
                span_cost.trace_rowid = span.trace_rowid
                session.add(span_cost)
                has_costs = True
        if has_costs:
            await session.flush()

    # Notify before yielding: the caller may stop consuming this generator early
    # (e.g. when a per-item timeout fires), which would otherwise skip the event
    # even though the spans have already been written.
    on_span_insertion()

    for span, repetition_number in inserted:
        yield ChatCompletionSubscriptionResult(
            span=Span(id=span.id, db_record=span),
            repetition_number=repetition_number,
        )
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def _is_span_result_payloads_stream(
    stream: ChatStream,
) -> bool:
    """
    Return True if ``stream`` was created by calling ``_chat_completion_span_result_payloads``.

    Native async generator objects expose the code object they were instantiated from
    as ``ag_code``. Comparing it by identity with the function's ``__code__`` pins the
    exact producer; ``==`` on code objects would be a structural comparison instead.
    Objects without ``ag_code`` (e.g. non-native async iterators) yield False rather
    than raising ``AttributeError``.
    """
    return getattr(stream, "ag_code", None) is _chat_completion_span_result_payloads.__code__
|
|
198
|
+
|
|
199
|
+
|
|
90
200
|
@strawberry.type
|
|
91
201
|
class Subscription:
|
|
92
|
-
@strawberry.subscription(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
|
|
202
|
+
@strawberry.subscription(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
|
|
93
203
|
async def chat_completion(
|
|
94
204
|
self, info: Info[Context, None], input: ChatCompletionInput
|
|
95
205
|
) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
|
|
@@ -98,9 +208,17 @@ class Subscription:
|
|
|
98
208
|
if llm_client_class is None:
|
|
99
209
|
raise BadRequest(f"Unknown LLM provider: '{provider_key.value}'")
|
|
100
210
|
try:
|
|
211
|
+
# Convert GraphQL credentials to PlaygroundCredential objects
|
|
212
|
+
playground_credentials = None
|
|
213
|
+
if input.credentials:
|
|
214
|
+
playground_credentials = [
|
|
215
|
+
PlaygroundClientCredential(env_var_name=cred.env_var_name, value=cred.value)
|
|
216
|
+
for cred in input.credentials
|
|
217
|
+
]
|
|
218
|
+
|
|
101
219
|
llm_client = llm_client_class(
|
|
102
220
|
model=input.model,
|
|
103
|
-
|
|
221
|
+
credentials=playground_credentials,
|
|
104
222
|
)
|
|
105
223
|
except CustomGraphQLError:
|
|
106
224
|
raise
|
|
@@ -110,42 +228,6 @@ class Subscription:
|
|
|
110
228
|
f"{str(error)}"
|
|
111
229
|
)
|
|
112
230
|
|
|
113
|
-
messages = [
|
|
114
|
-
(
|
|
115
|
-
message.role,
|
|
116
|
-
message.content,
|
|
117
|
-
message.tool_call_id if isinstance(message.tool_call_id, str) else None,
|
|
118
|
-
message.tool_calls if isinstance(message.tool_calls, list) else None,
|
|
119
|
-
)
|
|
120
|
-
for message in input.messages
|
|
121
|
-
]
|
|
122
|
-
attributes = None
|
|
123
|
-
if template_options := input.template:
|
|
124
|
-
messages = list(
|
|
125
|
-
_formatted_messages(
|
|
126
|
-
messages=messages,
|
|
127
|
-
template_format=template_options.format,
|
|
128
|
-
template_variables=template_options.variables,
|
|
129
|
-
)
|
|
130
|
-
)
|
|
131
|
-
attributes = {PROMPT_TEMPLATE_VARIABLES: safe_json_dumps(template_options.variables)}
|
|
132
|
-
invocation_parameters = llm_client.construct_invocation_parameters(
|
|
133
|
-
input.invocation_parameters
|
|
134
|
-
)
|
|
135
|
-
async with streaming_llm_span(
|
|
136
|
-
input=input,
|
|
137
|
-
messages=messages,
|
|
138
|
-
invocation_parameters=invocation_parameters,
|
|
139
|
-
attributes=attributes,
|
|
140
|
-
) as span:
|
|
141
|
-
async for chunk in llm_client.chat_completion_create(
|
|
142
|
-
messages=messages, tools=input.tools or [], **invocation_parameters
|
|
143
|
-
):
|
|
144
|
-
span.add_response_chunk(chunk)
|
|
145
|
-
yield chunk
|
|
146
|
-
span.set_attributes(llm_client.attributes)
|
|
147
|
-
if span.status_message is not None:
|
|
148
|
-
yield ChatCompletionSubscriptionError(message=span.status_message)
|
|
149
231
|
async with info.context.db() as session:
|
|
150
232
|
if (
|
|
151
233
|
playground_project_id := await session.scalar(
|
|
@@ -160,14 +242,100 @@ class Subscription:
|
|
|
160
242
|
description="Traces from prompt playground",
|
|
161
243
|
)
|
|
162
244
|
)
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
245
|
+
|
|
246
|
+
results: asyncio.Queue[tuple[Optional[models.Span], int]] = asyncio.Queue()
|
|
247
|
+
not_started: deque[tuple[int, ChatStream]] = deque(
|
|
248
|
+
(
|
|
249
|
+
repetition_number,
|
|
250
|
+
_stream_single_chat_completion(
|
|
251
|
+
input=input,
|
|
252
|
+
llm_client=llm_client,
|
|
253
|
+
project_id=playground_project_id,
|
|
254
|
+
repetition_number=repetition_number,
|
|
255
|
+
results=results,
|
|
256
|
+
),
|
|
257
|
+
)
|
|
258
|
+
for repetition_number in range(1, input.repetitions + 1)
|
|
259
|
+
)
|
|
260
|
+
in_progress: list[
|
|
261
|
+
tuple[
|
|
262
|
+
Optional[int],
|
|
263
|
+
ChatStream,
|
|
264
|
+
asyncio.Task[ChatCompletionSubscriptionPayload],
|
|
265
|
+
]
|
|
266
|
+
] = []
|
|
267
|
+
max_in_progress = 3
|
|
268
|
+
write_batch_size = 10
|
|
269
|
+
write_interval = timedelta(seconds=10)
|
|
270
|
+
last_write_time = datetime.now()
|
|
271
|
+
while not_started or in_progress:
|
|
272
|
+
while not_started and len(in_progress) < max_in_progress:
|
|
273
|
+
rep_num, stream = not_started.popleft()
|
|
274
|
+
task = _create_task_with_timeout(stream)
|
|
275
|
+
in_progress.append((rep_num, stream, task))
|
|
276
|
+
async_tasks_to_run = [task for _, _, task in in_progress]
|
|
277
|
+
completed_tasks, _ = await asyncio.wait(
|
|
278
|
+
async_tasks_to_run, return_when=asyncio.FIRST_COMPLETED
|
|
279
|
+
)
|
|
280
|
+
for completed_task in completed_tasks:
|
|
281
|
+
idx = [task for _, _, task in in_progress].index(completed_task)
|
|
282
|
+
repetition_number, stream, _ = in_progress[idx]
|
|
283
|
+
try:
|
|
284
|
+
yield completed_task.result()
|
|
285
|
+
except StopAsyncIteration:
|
|
286
|
+
del in_progress[idx] # removes exhausted stream
|
|
287
|
+
except asyncio.TimeoutError:
|
|
288
|
+
del in_progress[idx] # removes timed-out stream
|
|
289
|
+
if repetition_number is not None:
|
|
290
|
+
yield ChatCompletionSubscriptionError(
|
|
291
|
+
message="Playground task timed out",
|
|
292
|
+
repetition_number=repetition_number,
|
|
293
|
+
)
|
|
294
|
+
except Exception as error:
|
|
295
|
+
del in_progress[idx] # removes failed stream
|
|
296
|
+
if repetition_number is not None:
|
|
297
|
+
yield ChatCompletionSubscriptionError(
|
|
298
|
+
message="An unexpected error occurred",
|
|
299
|
+
repetition_number=repetition_number,
|
|
300
|
+
)
|
|
301
|
+
logger.exception(error)
|
|
302
|
+
else:
|
|
303
|
+
task = _create_task_with_timeout(stream)
|
|
304
|
+
in_progress[idx] = (repetition_number, stream, task)
|
|
305
|
+
|
|
306
|
+
exceeded_write_batch_size = results.qsize() >= write_batch_size
|
|
307
|
+
exceeded_write_interval = datetime.now() - last_write_time > write_interval
|
|
308
|
+
write_already_in_progress = any(
|
|
309
|
+
_is_span_result_payloads_stream(stream) for _, stream, _ in in_progress
|
|
310
|
+
)
|
|
311
|
+
if (
|
|
312
|
+
not results.empty()
|
|
313
|
+
and (exceeded_write_batch_size or exceeded_write_interval)
|
|
314
|
+
and not write_already_in_progress
|
|
315
|
+
):
|
|
316
|
+
result_payloads_stream = _chat_completion_span_result_payloads(
|
|
317
|
+
db=info.context.db,
|
|
318
|
+
results=_drain_no_wait(results),
|
|
319
|
+
span_cost_calculator=info.context.span_cost_calculator,
|
|
320
|
+
on_span_insertion=lambda: info.context.event_queue.put(
|
|
321
|
+
SpanInsertEvent(ids=(playground_project_id,))
|
|
322
|
+
),
|
|
323
|
+
)
|
|
324
|
+
task = _create_task_with_timeout(result_payloads_stream)
|
|
325
|
+
in_progress.append((None, result_payloads_stream, task))
|
|
326
|
+
last_write_time = datetime.now()
|
|
327
|
+
if remaining_results := await _drain(results):
|
|
328
|
+
async for result_payload in _chat_completion_span_result_payloads(
|
|
329
|
+
db=info.context.db,
|
|
330
|
+
results=remaining_results,
|
|
331
|
+
span_cost_calculator=info.context.span_cost_calculator,
|
|
332
|
+
on_span_insertion=lambda: info.context.event_queue.put(
|
|
333
|
+
SpanInsertEvent(ids=(playground_project_id,))
|
|
334
|
+
),
|
|
335
|
+
):
|
|
336
|
+
yield result_payload
|
|
337
|
+
|
|
338
|
+
@strawberry.subscription(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
|
|
171
339
|
async def chat_completion_over_dataset(
|
|
172
340
|
self, info: Info[Context, None], input: ChatCompletionOverDatasetInput
|
|
173
341
|
) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
|
|
@@ -176,9 +344,17 @@ class Subscription:
|
|
|
176
344
|
if llm_client_class is None:
|
|
177
345
|
raise BadRequest(f"Unknown LLM provider: '{provider_key.value}'")
|
|
178
346
|
try:
|
|
347
|
+
# Convert GraphQL credentials to PlaygroundCredential objects
|
|
348
|
+
playground_credentials = None
|
|
349
|
+
if input.credentials:
|
|
350
|
+
playground_credentials = [
|
|
351
|
+
PlaygroundClientCredential(env_var_name=cred.env_var_name, value=cred.value)
|
|
352
|
+
for cred in input.credentials
|
|
353
|
+
]
|
|
354
|
+
|
|
179
355
|
llm_client = llm_client_class(
|
|
180
356
|
model=input.model,
|
|
181
|
-
|
|
357
|
+
credentials=playground_credentials,
|
|
182
358
|
)
|
|
183
359
|
except CustomGraphQLError:
|
|
184
360
|
raise
|
|
@@ -223,27 +399,22 @@ class Subscription:
|
|
|
223
399
|
)
|
|
224
400
|
) is None:
|
|
225
401
|
raise NotFound(f"Could not find dataset version with ID {version_id}")
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
.group_by(models.DatasetExampleRevision.dataset_example_id)
|
|
236
|
-
)
|
|
402
|
+
|
|
403
|
+
# Parse split IDs if provided
|
|
404
|
+
resolved_split_ids: Optional[list[int]] = None
|
|
405
|
+
if input.split_ids is not None and len(input.split_ids) > 0:
|
|
406
|
+
resolved_split_ids = [
|
|
407
|
+
from_global_id_with_expected_type(split_id, models.DatasetSplit.__name__)
|
|
408
|
+
for split_id in input.split_ids
|
|
409
|
+
]
|
|
410
|
+
|
|
237
411
|
if not (
|
|
238
412
|
revisions := [
|
|
239
413
|
rev
|
|
240
414
|
async for rev in await session.stream_scalars(
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
models.DatasetExampleRevision.id.in_(revision_ids),
|
|
245
|
-
models.DatasetExampleRevision.revision_kind != "DELETE",
|
|
246
|
-
)
|
|
415
|
+
get_dataset_example_revisions(
|
|
416
|
+
resolved_version_id,
|
|
417
|
+
split_ids=resolved_split_ids,
|
|
247
418
|
)
|
|
248
419
|
.order_by(models.DatasetExampleRevision.dataset_example_id.asc())
|
|
249
420
|
.options(
|
|
@@ -256,31 +427,38 @@ class Subscription:
|
|
|
256
427
|
]
|
|
257
428
|
):
|
|
258
429
|
raise NotFound("No examples found for the given dataset and version")
|
|
430
|
+
project_name = generate_experiment_project_name()
|
|
259
431
|
if (
|
|
260
432
|
playground_project_id := await session.scalar(
|
|
261
|
-
select(models.Project.id).where(models.Project.name ==
|
|
433
|
+
select(models.Project.id).where(models.Project.name == project_name)
|
|
262
434
|
)
|
|
263
435
|
) is None:
|
|
264
436
|
playground_project_id = await session.scalar(
|
|
265
437
|
insert(models.Project)
|
|
266
438
|
.returning(models.Project.id)
|
|
267
439
|
.values(
|
|
268
|
-
name=
|
|
440
|
+
name=project_name,
|
|
269
441
|
description="Traces from prompt playground",
|
|
270
442
|
)
|
|
271
443
|
)
|
|
444
|
+
user_id = get_user(info)
|
|
272
445
|
experiment = models.Experiment(
|
|
273
446
|
dataset_id=from_global_id_with_expected_type(input.dataset_id, Dataset.__name__),
|
|
274
447
|
dataset_version_id=resolved_version_id,
|
|
275
448
|
name=input.experiment_name
|
|
276
449
|
or _default_playground_experiment_name(input.prompt_name),
|
|
277
450
|
description=input.experiment_description,
|
|
278
|
-
repetitions=
|
|
451
|
+
repetitions=input.repetitions,
|
|
279
452
|
metadata_=input.experiment_metadata or dict(),
|
|
280
|
-
project_name=
|
|
453
|
+
project_name=project_name,
|
|
454
|
+
user_id=user_id,
|
|
281
455
|
)
|
|
282
|
-
|
|
283
|
-
|
|
456
|
+
if resolved_split_ids:
|
|
457
|
+
experiment.experiment_dataset_splits = [
|
|
458
|
+
models.ExperimentDatasetSplit(dataset_split_id=split_id)
|
|
459
|
+
for split_id in resolved_split_ids
|
|
460
|
+
]
|
|
461
|
+
await insert_experiment_with_examples_snapshot(session, experiment)
|
|
284
462
|
yield ChatCompletionSubscriptionExperiment(
|
|
285
463
|
experiment=to_gql_experiment(experiment)
|
|
286
464
|
) # eagerly yields experiment so it can be linked by consumers of the subscription
|
|
@@ -294,11 +472,15 @@ class Subscription:
|
|
|
294
472
|
llm_client=llm_client,
|
|
295
473
|
revision=revision,
|
|
296
474
|
results=results,
|
|
475
|
+
repetition_number=repetition_number,
|
|
297
476
|
experiment_id=experiment.id,
|
|
298
477
|
project_id=playground_project_id,
|
|
299
478
|
),
|
|
300
479
|
)
|
|
301
480
|
for revision in revisions
|
|
481
|
+
for repetition_number in reversed(
|
|
482
|
+
range(1, input.repetitions + 1)
|
|
483
|
+
) # since we pop right, this runs the repetitions in increasing order
|
|
302
484
|
]
|
|
303
485
|
in_progress: list[
|
|
304
486
|
tuple[
|
|
@@ -355,14 +537,18 @@ class Subscription:
|
|
|
355
537
|
and not write_already_in_progress
|
|
356
538
|
):
|
|
357
539
|
result_payloads_stream = _chat_completion_result_payloads(
|
|
358
|
-
db=info.context.db,
|
|
540
|
+
db=info.context.db,
|
|
541
|
+
results=_drain_no_wait(results),
|
|
542
|
+
span_cost_calculator=info.context.span_cost_calculator,
|
|
359
543
|
)
|
|
360
544
|
task = _create_task_with_timeout(result_payloads_stream)
|
|
361
545
|
in_progress.append((None, result_payloads_stream, task))
|
|
362
546
|
last_write_time = datetime.now()
|
|
363
547
|
if remaining_results := await _drain(results):
|
|
364
548
|
async for result_payload in _chat_completion_result_payloads(
|
|
365
|
-
db=info.context.db,
|
|
549
|
+
db=info.context.db,
|
|
550
|
+
results=remaining_results,
|
|
551
|
+
span_cost_calculator=info.context.span_cost_calculator,
|
|
366
552
|
):
|
|
367
553
|
yield result_payload
|
|
368
554
|
|
|
@@ -372,6 +558,7 @@ async def _stream_chat_completion_over_dataset_example(
|
|
|
372
558
|
input: ChatCompletionOverDatasetInput,
|
|
373
559
|
llm_client: PlaygroundStreamingClient,
|
|
374
560
|
revision: models.DatasetExampleRevision,
|
|
561
|
+
repetition_number: int,
|
|
375
562
|
results: asyncio.Queue[ChatCompletionResult],
|
|
376
563
|
experiment_id: int,
|
|
377
564
|
project_id: int,
|
|
@@ -398,7 +585,11 @@ async def _stream_chat_completion_over_dataset_example(
|
|
|
398
585
|
)
|
|
399
586
|
except TemplateFormatterError as error:
|
|
400
587
|
format_end_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
|
|
401
|
-
yield ChatCompletionSubscriptionError(
|
|
588
|
+
yield ChatCompletionSubscriptionError(
|
|
589
|
+
message=str(error),
|
|
590
|
+
dataset_example_id=example_id,
|
|
591
|
+
repetition_number=repetition_number,
|
|
592
|
+
)
|
|
402
593
|
await results.put(
|
|
403
594
|
(
|
|
404
595
|
example_id,
|
|
@@ -408,7 +599,7 @@ async def _stream_chat_completion_over_dataset_example(
|
|
|
408
599
|
dataset_example_id=revision.dataset_example_id,
|
|
409
600
|
trace_id=None,
|
|
410
601
|
output={},
|
|
411
|
-
repetition_number=
|
|
602
|
+
repetition_number=repetition_number,
|
|
412
603
|
start_time=format_start_time,
|
|
413
604
|
end_time=format_end_time,
|
|
414
605
|
error=str(error),
|
|
@@ -423,22 +614,31 @@ async def _stream_chat_completion_over_dataset_example(
|
|
|
423
614
|
invocation_parameters=invocation_parameters,
|
|
424
615
|
attributes={PROMPT_TEMPLATE_VARIABLES: safe_json_dumps(revision.input)},
|
|
425
616
|
) as span:
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
617
|
+
try:
|
|
618
|
+
async for chunk in llm_client.chat_completion_create(
|
|
619
|
+
messages=messages, tools=input.tools or [], **invocation_parameters
|
|
620
|
+
):
|
|
621
|
+
span.add_response_chunk(chunk)
|
|
622
|
+
chunk.dataset_example_id = example_id
|
|
623
|
+
chunk.repetition_number = repetition_number
|
|
624
|
+
yield chunk
|
|
625
|
+
finally:
|
|
626
|
+
span.set_attributes(llm_client.attributes)
|
|
433
627
|
db_trace = get_db_trace(span, project_id)
|
|
434
628
|
db_span = get_db_span(span, db_trace)
|
|
435
629
|
db_run = get_db_experiment_run(
|
|
436
|
-
db_span,
|
|
630
|
+
db_span,
|
|
631
|
+
db_trace,
|
|
632
|
+
experiment_id=experiment_id,
|
|
633
|
+
example_id=revision.dataset_example_id,
|
|
634
|
+
repetition_number=repetition_number,
|
|
437
635
|
)
|
|
438
636
|
await results.put((example_id, db_span, db_run))
|
|
439
637
|
if span.status_message is not None:
|
|
440
638
|
yield ChatCompletionSubscriptionError(
|
|
441
|
-
message=span.status_message,
|
|
639
|
+
message=span.status_message,
|
|
640
|
+
dataset_example_id=example_id,
|
|
641
|
+
repetition_number=repetition_number,
|
|
442
642
|
)
|
|
443
643
|
|
|
444
644
|
|
|
@@ -446,6 +646,7 @@ async def _chat_completion_result_payloads(
|
|
|
446
646
|
*,
|
|
447
647
|
db: DbSessionFactory,
|
|
448
648
|
results: Sequence[ChatCompletionResult],
|
|
649
|
+
span_cost_calculator: SpanCostCalculator,
|
|
449
650
|
) -> ChatStream:
|
|
450
651
|
if not results:
|
|
451
652
|
return
|
|
@@ -453,13 +654,27 @@ async def _chat_completion_result_payloads(
|
|
|
453
654
|
for _, span, run in results:
|
|
454
655
|
if span:
|
|
455
656
|
session.add(span)
|
|
657
|
+
await session.flush()
|
|
658
|
+
try:
|
|
659
|
+
span_cost = span_cost_calculator.calculate_cost(
|
|
660
|
+
start_time=span.start_time,
|
|
661
|
+
attributes=span.attributes,
|
|
662
|
+
)
|
|
663
|
+
except Exception as e:
|
|
664
|
+
logger.exception(f"Failed to calculate cost for span {span.id}: {e}")
|
|
665
|
+
span_cost = None
|
|
666
|
+
if span_cost:
|
|
667
|
+
span_cost.span_rowid = span.id
|
|
668
|
+
span_cost.trace_rowid = span.trace_rowid
|
|
669
|
+
session.add(span_cost)
|
|
456
670
|
session.add(run)
|
|
457
671
|
await session.flush()
|
|
458
672
|
for example_id, span, run in results:
|
|
459
673
|
yield ChatCompletionSubscriptionResult(
|
|
460
|
-
span=Span(
|
|
461
|
-
experiment_run=
|
|
674
|
+
span=Span(id=span.id, db_record=span) if span else None,
|
|
675
|
+
experiment_run=ExperimentRun(id=run.id, db_record=run),
|
|
462
676
|
dataset_example_id=example_id,
|
|
677
|
+
repetition_number=run.repetition_number,
|
|
463
678
|
)
|
|
464
679
|
|
|
465
680
|
|
|
@@ -577,3 +792,5 @@ LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
|
|
|
577
792
|
LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
|
|
578
793
|
LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
|
|
579
794
|
PROMPT_TEMPLATE_VARIABLES = SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES
|
|
795
|
+
LLM_MODEL_NAME = SpanAttributes.LLM_MODEL_NAME
|
|
796
|
+
LLM_PROVIDER = SpanAttributes.LLM_PROVIDER
|