arize-phoenix 12.7.1__py3-none-any.whl → 12.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic.
- {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/METADATA +3 -1
- {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/RECORD +76 -73
- phoenix/config.py +131 -9
- phoenix/db/engines.py +127 -14
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/pg_config.py +10 -0
- phoenix/server/api/context.py +23 -0
- phoenix/server/api/dataloaders/__init__.py +6 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +0 -2
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/span_costs.py +3 -9
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/helpers/playground_clients.py +3 -3
- phoenix/server/api/input_types/PromptVersionInput.py +47 -1
- phoenix/server/api/mutations/annotation_config_mutations.py +2 -2
- phoenix/server/api/mutations/api_key_mutations.py +2 -15
- phoenix/server/api/mutations/chat_mutations.py +3 -2
- phoenix/server/api/mutations/dataset_label_mutations.py +109 -157
- phoenix/server/api/mutations/dataset_mutations.py +8 -8
- phoenix/server/api/mutations/dataset_split_mutations.py +13 -9
- phoenix/server/api/mutations/model_mutations.py +4 -4
- phoenix/server/api/mutations/project_session_annotations_mutations.py +4 -7
- phoenix/server/api/mutations/prompt_label_mutations.py +3 -3
- phoenix/server/api/mutations/prompt_mutations.py +24 -117
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +8 -5
- phoenix/server/api/mutations/span_annotations_mutations.py +10 -5
- phoenix/server/api/mutations/trace_annotations_mutations.py +9 -4
- phoenix/server/api/mutations/user_mutations.py +4 -4
- phoenix/server/api/queries.py +80 -213
- phoenix/server/api/subscriptions.py +4 -4
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/Dataset.py +88 -48
- phoenix/server/api/types/DatasetExample.py +34 -30
- phoenix/server/api/types/DatasetLabel.py +47 -13
- phoenix/server/api/types/DatasetSplit.py +87 -21
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +182 -62
- phoenix/server/api/types/Experiment.py +146 -55
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +10 -1
- phoenix/server/api/types/ExperimentRun.py +118 -61
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +95 -42
- phoenix/server/api/types/ModelInterface.py +7 -2
- phoenix/server/api/types/PlaygroundModel.py +12 -2
- phoenix/server/api/types/Project.py +70 -75
- phoenix/server/api/types/ProjectSession.py +69 -37
- phoenix/server/api/types/ProjectSessionAnnotation.py +166 -47
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +82 -44
- phoenix/server/api/types/PromptLabel.py +47 -13
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/Span.py +116 -115
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/Trace.py +45 -44
- phoenix/server/api/types/TraceAnnotation.py +144 -48
- phoenix/server/api/types/User.py +103 -33
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/app.py +29 -0
- phoenix/server/cost_tracking/model_cost_manifest.json +2 -2
- phoenix/server/static/.vite/manifest.json +43 -43
- phoenix/server/static/assets/{components-BLK5vehh.js → components-v927s3NF.js} +471 -484
- phoenix/server/static/assets/{index-BP0Shd90.js → index-DrD9eSrN.js} +20 -16
- phoenix/server/static/assets/{pages-DIVgyYyy.js → pages-GVybXa_W.js} +754 -753
- phoenix/server/static/assets/{vendor-3BvTzoBp.js → vendor-D-csRHGZ.js} +1 -1
- phoenix/server/static/assets/{vendor-arizeai-C6_oC0y8.js → vendor-arizeai-BJLCG_Gc.js} +1 -1
- phoenix/server/static/assets/{vendor-codemirror-DPnZGAZA.js → vendor-codemirror-Cr963DyP.js} +3 -3
- phoenix/server/static/assets/{vendor-recharts-CjgSbsB0.js → vendor-recharts-DgmPLgIp.js} +1 -1
- phoenix/server/static/assets/{vendor-shiki-CJyhDG0E.js → vendor-shiki-wYOt1s7u.js} +1 -1
- phoenix/version.py +1 -1
- {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/licenses/LICENSE +0 -0
phoenix/db/engines.py
CHANGED
@@ -5,7 +5,7 @@ import logging
 from collections.abc import Callable
 from enum import Enum
 from sqlite3 import Connection
-from typing import Any
+from typing import Any, Optional

 import aiosqlite
 import numpy as np
@@ -168,29 +168,142 @@ def aio_postgresql_engine(
     log_to_stdout: bool = False,
     log_migrations_to_stdout: bool = True,
 ) -> AsyncEngine:
-    asyncpg_url, asyncpg_args = get_pg_config(url, "asyncpg")
-    engine = create_async_engine(
-        url=asyncpg_url,
-        connect_args=asyncpg_args,
-        echo=log_to_stdout,
-        json_serializer=_dumps,
+    from phoenix.config import (
+        get_env_postgres_iam_token_lifetime,
+        get_env_postgres_use_iam_auth,
     )
+
+    use_iam_auth = get_env_postgres_use_iam_auth()
+
+    asyncpg_url, asyncpg_args = get_pg_config(url, "asyncpg", enforce_ssl=use_iam_auth)
+
+    iam_config: Optional[dict[str, Any]] = None
+    token_lifetime: int = 0
+    if use_iam_auth:
+        iam_config = _extract_iam_config_from_url(url)
+        token_lifetime = get_env_postgres_iam_token_lifetime()
+
+        async def iam_async_creator() -> Any:
+            import asyncpg  # type: ignore
+
+            from phoenix.db.iam_auth import generate_aws_rds_token
+
+            assert iam_config is not None
+            token = generate_aws_rds_token(
+                host=iam_config["host"],
+                port=iam_config["port"],
+                user=iam_config["user"],
+            )
+
+            conn_kwargs = {
+                "host": iam_config["host"],
+                "port": iam_config["port"],
+                "user": iam_config["user"],
+                "password": token,
+                "database": iam_config["database"],
+            }
+
+            if asyncpg_args:
+                conn_kwargs.update(asyncpg_args)
+
+            return await asyncpg.connect(**conn_kwargs)
+
+        engine = create_async_engine(
+            url=asyncpg_url,
+            async_creator=iam_async_creator,
+            echo=log_to_stdout,
+            json_serializer=_dumps,
+            pool_recycle=token_lifetime,
+        )
+    else:
+        engine = create_async_engine(
+            url=asyncpg_url,
+            connect_args=asyncpg_args,
+            echo=log_to_stdout,
+            json_serializer=_dumps,
+        )
+
     if not migrate:
         return engine

-    psycopg_url, psycopg_args = get_pg_config(url, "psycopg")
-    sync_engine = sqlalchemy.create_engine(
-        url=psycopg_url,
-        connect_args=psycopg_args,
-        echo=log_migrations_to_stdout,
-        json_serializer=_dumps,
-    )
+    psycopg_url, psycopg_args = get_pg_config(url, "psycopg", enforce_ssl=use_iam_auth)
+
+    if use_iam_auth:
+        assert iam_config is not None
+
+        def iam_sync_creator() -> Any:
+            import psycopg
+
+            from phoenix.db.iam_auth import generate_aws_rds_token
+
+            token = generate_aws_rds_token(
+                host=iam_config["host"],
+                port=iam_config["port"],
+                user=iam_config["user"],
+            )
+
+            conn_kwargs = {
+                "host": iam_config["host"],
+                "port": iam_config["port"],
+                "user": iam_config["user"],
+                "password": token,
+                "dbname": iam_config["database"],
+            }
+
+            if psycopg_args:
+                conn_kwargs.update(psycopg_args)
+
+            return psycopg.connect(**conn_kwargs)
+
+        sync_engine = sqlalchemy.create_engine(
+            url=psycopg_url,
+            creator=iam_sync_creator,
+            echo=log_migrations_to_stdout,
+            json_serializer=_dumps,
+            pool_recycle=token_lifetime,
+        )
+    else:
+        sync_engine = sqlalchemy.create_engine(
+            url=psycopg_url,
+            connect_args=psycopg_args,
+            echo=log_migrations_to_stdout,
+            json_serializer=_dumps,
+        )
+
     if schema := get_env_database_schema():
         event.listen(sync_engine, "connect", set_postgresql_search_path(schema))
     migrate_in_thread(sync_engine)
     return engine


+def _extract_iam_config_from_url(url: URL) -> dict[str, Any]:
+    """Extract connection parameters needed for IAM authentication from a SQLAlchemy URL.
+
+    Args:
+        url: SQLAlchemy database URL
+
+    Returns:
+        Dictionary with host, port, user, and database
+    """
+    host = url.host
+    if not host:
+        raise ValueError("Database host is required for IAM authentication")
+
+    port = url.port or 5432
+    user = url.username
+    if not user:
+        raise ValueError("Database user is required for IAM authentication")
+
+    database = url.database or "postgres"
+
+    return {
+        "host": host,
+        "port": port,
+        "user": user,
+        "database": database,
+    }
+
+
 def _dumps(obj: Any) -> str:
     return orjson.dumps(obj, default=_default).decode()
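The load-bearing pieces of this change are SQLAlchemy's async_creator hook, which lets every new pooled connection mint a fresh IAM token, and pool_recycle=token_lifetime, which retires each connection before the token it authenticated with expires. A minimal standalone sketch of the same pattern, with a placeholder endpoint, user, and database rather than Phoenix's actual configuration:

import boto3
from sqlalchemy.ext.asyncio import create_async_engine

HOST, PORT, USER, DB = "mydb.us-west-2.rds.amazonaws.com", 5432, "myuser", "postgres"

async def _connect():
    import asyncpg

    # Each new pooled connection mints a fresh 15-minute RDS IAM token.
    token = boto3.client("rds").generate_db_auth_token(
        DBHostname=HOST, Port=PORT, DBUsername=USER
    )
    return await asyncpg.connect(
        host=HOST, port=PORT, user=USER, password=token, database=DB, ssl="require"
    )

engine = create_async_engine(
    "postgresql+asyncpg://",   # dialect only; _connect supplies the real arguments
    async_creator=_connect,
    pool_recycle=900,          # recycle before the token would have expired
)

Because the creator supplies all connection arguments, the engine URL carries no long-lived credentials at all.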
phoenix/db/iam_auth.py
ADDED
@@ -0,0 +1,64 @@
+from __future__ import annotations
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def generate_aws_rds_token(
+    host: str,
+    port: int,
+    user: str,
+) -> str:
+    """Generate an AWS RDS IAM authentication token.
+
+    This function creates a short-lived (15 minutes) authentication token for connecting
+    to AWS RDS/Aurora PostgreSQL instances using IAM database authentication.
+
+    The AWS region is automatically resolved using boto3.
+
+    Args:
+        host: The database hostname (e.g., 'mydb.abc123.us-west-2.rds.amazonaws.com')
+        port: The database port (typically 5432 for PostgreSQL)
+        user: The database username (must match an IAM-enabled database user)
+
+    Returns:
+        A temporary authentication token string to use as the database password
+
+    Raises:
+        ImportError: If boto3 is not installed
+        Exception: If AWS credentials/region are not configured or token generation fails
+
+    Example:
+        >>> token = generate_aws_rds_token(
+        ...     host='mydb.us-west-2.rds.amazonaws.com',
+        ...     port=5432,
+        ...     user='myuser'
+        ... )
+    """
+    try:
+        import boto3  # type: ignore
+    except ImportError as e:
+        raise ImportError(
+            "boto3 is required for AWS RDS IAM authentication. "
+            "Install it with: pip install 'arize-phoenix[aws]'"
+        ) from e
+
+    try:
+        client = boto3.client("rds")
+
+        logger.debug(f"Generating AWS RDS IAM auth token for user '{user}' at {host}:{port}")
+        token = client.generate_db_auth_token(  # pyright: ignore
+            DBHostname=host,
+            Port=port,
+            DBUsername=user,
+        )
+
+        return str(token)  # pyright: ignore
+
+    except Exception as e:
+        logger.error(
+            f"Failed to generate AWS RDS IAM authentication token: {e}. "
+            "Ensure AWS credentials are configured and have 'rds-db:connect' permission."
+        )
+        raise
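Since the token is consumed as an ordinary password, using it looks like any other connect call. A hypothetical sketch with placeholder host, user, and dbname (TLS is mandatory, which is why engines.py passes enforce_ssl=use_iam_auth):

import psycopg

from phoenix.db.iam_auth import generate_aws_rds_token

HOST, PORT, USER = "mydb.us-west-2.rds.amazonaws.com", 5432, "myuser"

token = generate_aws_rds_token(host=HOST, port=PORT, user=USER)
conn = psycopg.connect(
    host=HOST,
    port=PORT,
    user=USER,
    password=token,      # the short-lived token stands in for a password
    dbname="postgres",
    sslmode="require",   # RDS accepts IAM tokens only over TLS
)

On the AWS side, the caller needs the rds-db:connect permission named in the error message, and on RDS PostgreSQL the database user must have been granted the rds_iam role.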
phoenix/db/pg_config.py
CHANGED
@@ -10,12 +10,14 @@ from typing_extensions import assert_never
 def get_pg_config(
     url: URL,
     driver: Literal["psycopg", "asyncpg"],
+    enforce_ssl: bool = False,
 ) -> tuple[URL, dict[str, Any]]:
     """Convert SQLAlchemy URL to driver-specific configuration.

     Args:
         url: SQLAlchemy URL
         driver: "psycopg" or "asyncpg"
+        enforce_ssl: If True, ensure SSL is enabled (required for AWS RDS IAM auth)

     Returns:
         Tuple of (base_url, connect_args):
@@ -26,6 +28,14 @@ def get_pg_config(
     query = url.query
     ssl_args = _get_ssl_args(query)

+    if enforce_ssl and not ssl_args:
+        ssl_args = {"sslmode": "require"}
+    elif enforce_ssl and ssl_args.get("sslmode") == "disable":
+        raise ValueError(
+            "SSL cannot be disabled when using AWS RDS IAM authentication. "
+            "Remove 'sslmode=disable' from the connection string."
+        )
+
     # Create base URL without SSL parameters
     base_url = url.set(
         drivername=f"postgresql+{driver}",
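Assuming the returned connect_args carry the SSL settings, as the sslmode handling above suggests, the new flag behaves roughly like this illustrative sketch:

from sqlalchemy import make_url

from phoenix.db.pg_config import get_pg_config

# No SSL parameters in the URL: enforce_ssl upgrades to sslmode=require.
url = make_url("postgresql://user@db.example.com/phoenix")
base_url, connect_args = get_pg_config(url, "psycopg", enforce_ssl=True)

# An explicit opt-out is a hard error rather than a silent downgrade.
bad = make_url("postgresql://user@db.example.com/phoenix?sslmode=disable")
get_pg_config(bad, "psycopg", enforce_ssl=True)  # raises ValueError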
phoenix/server/api/context.py
CHANGED
@@ -35,6 +35,7 @@ from phoenix.server.api.dataloaders import (
     ExperimentRepeatedRunGroupsDataLoader,
     ExperimentRunAnnotations,
     ExperimentRunCountsDataLoader,
+    ExperimentRunsByExperimentAndExampleDataLoader,
     ExperimentSequenceNumberDataLoader,
     LastUsedTimesByGenerativeModelIdDataLoader,
     LatencyMsQuantileDataLoader,
@@ -71,6 +72,7 @@ from phoenix.server.api.dataloaders import (
     SpanProjectsDataLoader,
     TableFieldsDataLoader,
     TokenCountDataLoader,
+    TokenPricesByModelDataLoader,
     TraceAnnotationsByTraceDataLoader,
     TraceByTraceIdsDataLoader,
     TraceRetentionPolicyIdByProjectIdDataLoader,
@@ -100,27 +102,38 @@ class DataLoaders:
         AverageExperimentRepeatedRunGroupLatencyDataLoader
     )
     average_experiment_run_latency: AverageExperimentRunLatencyDataLoader
+    dataset_example_fields: TableFieldsDataLoader
     dataset_example_revisions: DatasetExampleRevisionsDataLoader
     dataset_example_spans: DatasetExampleSpansDataLoader
     dataset_labels: DatasetLabelsDataLoader
+    dataset_label_fields: TableFieldsDataLoader
     dataset_dataset_splits: DatasetDatasetSplitsDataLoader
     dataset_examples_and_versions_by_experiment_run: (
         DatasetExamplesAndVersionsByExperimentRunDataLoader
     )
     dataset_example_splits: DatasetExampleSplitsDataLoader
+    dataset_fields: TableFieldsDataLoader
+    dataset_split_fields: TableFieldsDataLoader
+    dataset_version_fields: TableFieldsDataLoader
+    document_annotation_fields: TableFieldsDataLoader
     document_evaluation_summaries: DocumentEvaluationSummaryDataLoader
     document_evaluations: DocumentEvaluationsDataLoader
     document_retrieval_metrics: DocumentRetrievalMetricsDataLoader
     experiment_annotation_summaries: ExperimentAnnotationSummaryDataLoader
     experiment_dataset_splits: ExperimentDatasetSplitsDataLoader
     experiment_error_rates: ExperimentErrorRatesDataLoader
+    experiment_fields: TableFieldsDataLoader
     experiment_repeated_run_group_annotation_summaries: (
         ExperimentRepeatedRunGroupAnnotationSummariesDataLoader
     )
     experiment_repeated_run_groups: ExperimentRepeatedRunGroupsDataLoader
+    experiment_run_annotation_fields: TableFieldsDataLoader
     experiment_run_annotations: ExperimentRunAnnotations
     experiment_run_counts: ExperimentRunCountsDataLoader
+    experiment_run_fields: TableFieldsDataLoader
+    experiment_runs_by_experiment_and_example: ExperimentRunsByExperimentAndExampleDataLoader
     experiment_sequence_number: ExperimentSequenceNumberDataLoader
+    generative_model_fields: TableFieldsDataLoader
     last_used_times_by_generative_model_id: LastUsedTimesByGenerativeModelIdDataLoader
     latency_ms_quantile: LatencyMsQuantileDataLoader
     min_start_or_max_end_times: MinStartOrMaxEndTimeDataLoader
@@ -130,7 +143,12 @@ class DataLoaders:
     project_fields: TableFieldsDataLoader
     project_trace_retention_policy_fields: TableFieldsDataLoader
     projects_by_trace_retention_policy_id: ProjectIdsByTraceRetentionPolicyIdDataLoader
+    prompt_fields: TableFieldsDataLoader
+    prompt_label_fields: TableFieldsDataLoader
     prompt_version_sequence_number: PromptVersionSequenceNumberDataLoader
+    prompt_version_tag_fields: TableFieldsDataLoader
+    project_session_annotation_fields: TableFieldsDataLoader
+    project_session_fields: TableFieldsDataLoader
     record_counts: RecordCountDataLoader
     session_annotations_by_session: SessionAnnotationsBySessionDataLoader
     session_first_inputs: SessionIODataLoader
@@ -139,6 +157,7 @@ class DataLoaders:
     session_num_traces_with_error: SessionNumTracesWithErrorDataLoader
     session_token_usages: SessionTokenUsagesDataLoader
     session_trace_latency_ms_quantile: SessionTraceLatencyMsQuantileDataLoader
+    span_annotation_fields: TableFieldsDataLoader
     span_annotations: SpanAnnotationsDataLoader
     span_by_id: SpanByIdDataLoader
     span_cost_by_span: SpanCostBySpanDataLoader
@@ -167,12 +186,16 @@ class DataLoaders:
     span_fields: TableFieldsDataLoader
     span_projects: SpanProjectsDataLoader
     token_counts: TokenCountDataLoader
+    token_prices_by_model: TokenPricesByModelDataLoader
+    trace_annotation_fields: TableFieldsDataLoader
     trace_annotations_by_trace: TraceAnnotationsByTraceDataLoader
     trace_by_trace_ids: TraceByTraceIdsDataLoader
     trace_fields: TableFieldsDataLoader
     trace_retention_policy_id_by_project_id: TraceRetentionPolicyIdByProjectIdDataLoader
     trace_root_spans: TraceRootSpansDataLoader
     user_roles: UserRolesDataLoader
+    user_api_key_fields: TableFieldsDataLoader
+    user_fields: TableFieldsDataLoader
     users: UsersDataLoader

phoenix/server/api/dataloaders/__init__.py
CHANGED

@@ -33,6 +33,9 @@ from .experiment_repeated_run_group_annotation_summaries import (
 from .experiment_repeated_run_groups import ExperimentRepeatedRunGroupsDataLoader
 from .experiment_run_annotations import ExperimentRunAnnotations
 from .experiment_run_counts import ExperimentRunCountsDataLoader
+from .experiment_runs_by_experiment_and_example import (
+    ExperimentRunsByExperimentAndExampleDataLoader,
+)
 from .experiment_sequence_number import ExperimentSequenceNumberDataLoader
 from .last_used_times_by_generative_model_id import LastUsedTimesByGenerativeModelIdDataLoader
 from .latency_ms_quantile import LatencyMsQuantileCache, LatencyMsQuantileDataLoader
@@ -73,6 +76,7 @@ from .span_descendants import SpanDescendantsDataLoader
 from .span_projects import SpanProjectsDataLoader
 from .table_fields import TableFieldsDataLoader
 from .token_counts import TokenCountCache, TokenCountDataLoader
+from .token_prices_by_model import TokenPricesByModelDataLoader
 from .trace_annotations_by_trace import TraceAnnotationsByTraceDataLoader
 from .trace_by_trace_ids import TraceByTraceIdsDataLoader
 from .trace_retention_policy_id_by_project_id import TraceRetentionPolicyIdByProjectIdDataLoader
@@ -102,6 +106,7 @@ __all__ = [
     "ExperimentRepeatedRunGroupAnnotationSummariesDataLoader",
     "ExperimentRunAnnotations",
     "ExperimentRunCountsDataLoader",
+    "ExperimentRunsByExperimentAndExampleDataLoader",
     "ExperimentSequenceNumberDataLoader",
     "LastUsedTimesByGenerativeModelIdDataLoader",
     "LatencyMsQuantileDataLoader",
@@ -139,6 +144,7 @@ __all__ = [
     "SpanProjectsDataLoader",
     "TableFieldsDataLoader",
     "TokenCountDataLoader",
+    "TokenPricesByModelDataLoader",
     "TraceAnnotationsByTraceDataLoader",
     "TraceByTraceIdsDataLoader",
     "TraceRetentionPolicyIdByProjectIdDataLoader",
phoenix/server/api/dataloaders/experiment_repeated_run_groups.py
CHANGED

@@ -1,7 +1,6 @@
 from dataclasses import dataclass

 from sqlalchemy import select, tuple_
-from sqlalchemy.orm import joinedload
 from strawberry.dataloader import DataLoader
 from typing_extensions import TypeAlias

@@ -38,7 +37,6 @@ class ExperimentRepeatedRunGroupsDataLoader(DataLoader[Key, Result]):
                 ).in_(set(keys))
             )
             .order_by(models.ExperimentRun.repetition_number)
-            .options(joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id))
         )

         async with self._db() as session:
phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py
ADDED

@@ -0,0 +1,44 @@
+from collections import defaultdict
+
+from sqlalchemy import select, tuple_
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+from phoenix.server.types import DbSessionFactory
+
+ExperimentId: TypeAlias = int
+DatasetExampleId: TypeAlias = int
+Key: TypeAlias = tuple[ExperimentId, DatasetExampleId]
+Result: TypeAlias = list[models.ExperimentRun]
+
+
+class ExperimentRunsByExperimentAndExampleDataLoader(DataLoader[Key, Result]):
+    def __init__(self, db: DbSessionFactory) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: list[Key]) -> list[Result]:
+        runs_by_key: defaultdict[Key, Result] = defaultdict(list)
+
+        async with self._db() as session:
+            stmt = (
+                select(models.ExperimentRun)
+                .where(
+                    tuple_(
+                        models.ExperimentRun.experiment_id,
+                        models.ExperimentRun.dataset_example_id,
+                    ).in_(keys)
+                )
+                .order_by(
+                    models.ExperimentRun.experiment_id,
+                    models.ExperimentRun.dataset_example_id,
+                    models.ExperimentRun.repetition_number,
+                )
+            )
+            result = await session.stream_scalars(stmt)
+            async for run in result:
+                key = (run.experiment_id, run.dataset_example_id)
+                runs_by_key[key].append(run)
+
+        return [runs_by_key[key] for key in keys]
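A hypothetical resolver-side sketch, assuming the request context exposes the DataLoaders instance as data_loaders: all (experiment_id, dataset_example_id) pairs requested in the same tick coalesce into one SELECT keyed on the composite tuple.

# Hypothetical usage inside a resolver; `info` is the strawberry Info object.
loader = info.context.data_loaders.experiment_runs_by_experiment_and_example
runs = await loader.load((experiment_id, dataset_example_id))
# `runs` is ordered by repetition_number; because the loader fills a
# defaultdict, a pair with no runs resolves to [] rather than raising.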
phoenix/server/api/dataloaders/span_costs.py
CHANGED

@@ -1,7 +1,6 @@
 from typing import Optional

 from sqlalchemy import select
-from sqlalchemy.orm import joinedload, load_only
 from strawberry.dataloader import DataLoader
 from typing_extensions import TypeAlias

@@ -22,14 +21,9 @@ class SpanCostsDataLoader(DataLoader[Key, Result]):
         span_ids = list(set(keys))
         async with self._db() as session:
             costs = {
-
-                async for
-                    select(models.
-                    .where(models.Span.id.in_(span_ids))
-                    .options(
-                        load_only(models.Span.id),
-                        joinedload(models.Span.span_cost),
-                    )
+                span_cost.span_rowid: span_cost
+                async for span_cost in await session.stream_scalars(
+                    select(models.SpanCost).where(models.SpanCost.span_rowid.in_(span_ids))
                 )
             }
             return [costs.get(span_id) for span_id in keys]
phoenix/server/api/dataloaders/token_prices_by_model.py
ADDED

@@ -0,0 +1,30 @@
+from collections import defaultdict
+
+from sqlalchemy import select
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+from phoenix.server.types import DbSessionFactory
+
+ModelId: TypeAlias = int
+Key: TypeAlias = ModelId
+Result: TypeAlias = list[models.TokenPrice]
+
+
+class TokenPricesByModelDataLoader(DataLoader[Key, Result]):
+    def __init__(self, db: DbSessionFactory) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: list[Key]) -> list[Result]:
+        model_ids = keys
+        token_prices: defaultdict[Key, Result] = defaultdict(list)
+
+        async with self._db() as session:
+            async for token_price in await session.stream_scalars(
+                select(models.TokenPrice).where(models.TokenPrice.model_id.in_(model_ids))
+            ):
+                token_prices[token_price.model_id].append(token_price)
+
+        return [token_prices[model_id] for model_id in keys]
phoenix/server/api/helpers/playground_clients.py
CHANGED

@@ -699,7 +699,7 @@ class BedrockStreamingClient(PlaygroundStreamingClient):
         self.model_name = model.name
         self.client = boto3.client(
             service_name="bedrock-runtime",
-            region_name=
+            region_name=self.region,
             aws_access_key_id=self.aws_access_key_id,
             aws_secret_access_key=self.aws_secret_access_key,
             aws_session_token=self.aws_session_token,
@@ -805,7 +805,7 @@ class BedrockStreamingClient(PlaygroundStreamingClient):

         # Build the request parameters for Converse API
         converse_params: dict[str, Any] = {
-            "modelId":
+            "modelId": self.model_name,
             "messages": converse_messages,
             "inferenceConfig": {
                 "maxTokens": invocation_parameters["max_tokens"],
@@ -953,7 +953,7 @@ class BedrockStreamingClient(PlaygroundStreamingClient):
         }

         response = self.client.invoke_model_with_response_stream(
-            modelId=
+            modelId=self.model_name,
             contentType="application/json",
             accept="application/json",
             body=json.dumps(bedrock_params),
phoenix/server/api/input_types/PromptVersionInput.py
CHANGED

@@ -1,10 +1,11 @@
 import json
-from typing import Optional, cast
+from typing import Any, Optional, Union, cast

 import strawberry
 from strawberry import UNSET
 from strawberry.scalars import JSON

+from phoenix.db import models
 from phoenix.db.types.model_provider import ModelProvider
 from phoenix.server.api.helpers.prompts.models import (
     ContentPart,
@@ -12,11 +13,15 @@ from phoenix.server.api.helpers.prompts.models import (
     PromptMessage,
     PromptMessageRole,
     PromptTemplateFormat,
+    PromptTemplateType,
     RoleConversion,
     TextContentPart,
     ToolCallContentPart,
     ToolCallFunction,
     ToolResultContentPart,
+    normalize_response_format,
+    normalize_tools,
+    validate_invocation_parameters,
 )

@@ -88,6 +93,47 @@ class ChatPromptVersionInput:
             k: v for k, v in self.invocation_parameters.items() if v is not None
         }

+    def to_orm_prompt_version(
+        self,
+        user_id: Optional[int],
+    ) -> models.PromptVersion:
+        tool_definitions = [tool.definition for tool in self.tools]
+        tool_choice = cast(
+            Optional[Union[str, dict[str, Any]]],
+            cast(dict[str, Any], self.invocation_parameters).pop("tool_choice", None),
+        )
+        model_provider = ModelProvider(self.model_provider)
+        tools = (
+            normalize_tools(tool_definitions, model_provider, tool_choice)
+            if tool_definitions
+            else None
+        )
+        template = to_pydantic_prompt_chat_template_v1(self.template)
+        response_format = (
+            normalize_response_format(
+                self.response_format.definition,
+                model_provider,
+            )
+            if self.response_format
+            else None
+        )
+        invocation_parameters = validate_invocation_parameters(
+            self.invocation_parameters,
+            model_provider,
+        )
+        return models.PromptVersion(
+            description=self.description,
+            user_id=user_id,
+            template_type=PromptTemplateType.CHAT,
+            template_format=self.template_format,
+            template=template,
+            invocation_parameters=invocation_parameters,
+            tools=tools,
+            response_format=response_format,
+            model_provider=ModelProvider(self.model_provider),
+            model_name=self.model_name,
+        )
+

 def to_pydantic_prompt_chat_template_v1(
     prompt_chat_template_input: PromptChatTemplateInput,
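Moving the input-to-ORM conversion onto the input type itself is presumably what lets prompt_mutations.py shed over a hundred lines (see the file list above). A hypothetical call site:

# Hypothetical mutation sketch: the input now builds its own ORM row.
prompt_version = chat_prompt_version_input.to_orm_prompt_version(user_id=user_id)
session.add(prompt_version)
await session.flush()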
phoenix/server/api/mutations/annotation_config_mutations.py
CHANGED

@@ -374,7 +374,7 @@ class AnnotationConfigMutationMixin:
         )
         return AddAnnotationConfigToProjectPayload(
             query=Query(),
-            project=Project(
+            project=Project(id=project_id),
         )

     @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer])  # type: ignore[misc]
@@ -409,5 +409,5 @@ class AnnotationConfigMutationMixin:
             raise NotFound("Could not find one or more input project annotation configs")
         return RemoveAnnotationConfigFromProjectPayload(
             query=Query(),
-            project=Project(
+            project=Project(id=project_id),
         )
phoenix/server/api/mutations/api_key_mutations.py
CHANGED

@@ -92,13 +92,7 @@ class ApiKeyMutationMixin:
         token, token_id = await token_store.create_api_key(claims)
         return CreateSystemApiKeyMutationPayload(
             jwt=token,
-            api_key=SystemApiKey(
-                id_attr=int(token_id),
-                name=input.name,
-                description=input.description or None,
-                created_at=issued_at,
-                expires_at=input.expires_at or None,
-            ),
+            api_key=SystemApiKey(id=int(token_id)),
             query=Query(),
         )

@@ -134,14 +128,7 @@ class ApiKeyMutationMixin:
         token, token_id = await token_store.create_api_key(claims)
         return CreateUserApiKeyMutationPayload(
             jwt=token,
-            api_key=UserApiKey(
-                id_attr=int(token_id),
-                name=input.name,
-                description=input.description or None,
-                created_at=issued_at,
-                expires_at=input.expires_at or None,
-                user_id=int(user.identity),
-            ),
+            api_key=UserApiKey(id=int(token_id)),
             query=Query(),
         )
phoenix/server/api/mutations/chat_mutations.py
CHANGED

@@ -19,6 +19,7 @@ from openinference.semconv.trace import (
 from opentelemetry.sdk.trace.id_generator import RandomIdGenerator as DefaultOTelIDGenerator
 from opentelemetry.trace import StatusCode
 from sqlalchemy import insert, select
+from strawberry import Private
 from strawberry.relay import GlobalID
 from strawberry.types import Info
 from typing_extensions import assert_never
@@ -101,7 +102,7 @@ class ChatCompletionToolCall:

 @strawberry.type
 class ChatCompletionMutationPayload:
-    db_span:
+    db_span: Private[models.Span]
     content: Optional[str]
     tool_calls: List[ChatCompletionToolCall]
     span: Span
@@ -502,7 +503,7 @@ class ChatCompletionMutationMixin:
             session.add(span_cost)
             await session.flush()

-        gql_span = Span(
+        gql_span = Span(id=span.id, db_record=span)

         info.context.event_queue.put(SpanInsertEvent(ids=(project_id,)))

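strawberry's Private annotation marks a field that is carried on the payload instance for server-side use but excluded from the GraphQL schema, so the raw ORM span never reaches clients. A minimal illustration:

import strawberry

@strawberry.type
class Payload:
    content: str  # exposed as a GraphQL field
    db_span: strawberry.Private[object]  # present on the instance, never queryable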