arize-phoenix 12.7.1__py3-none-any.whl → 12.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of arize-phoenix might be problematic.

Files changed (76)
  1. {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/METADATA +3 -1
  2. {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/RECORD +76 -73
  3. phoenix/config.py +131 -9
  4. phoenix/db/engines.py +127 -14
  5. phoenix/db/iam_auth.py +64 -0
  6. phoenix/db/pg_config.py +10 -0
  7. phoenix/server/api/context.py +23 -0
  8. phoenix/server/api/dataloaders/__init__.py +6 -0
  9. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +0 -2
  10. phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
  11. phoenix/server/api/dataloaders/span_costs.py +3 -9
  12. phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
  13. phoenix/server/api/helpers/playground_clients.py +3 -3
  14. phoenix/server/api/input_types/PromptVersionInput.py +47 -1
  15. phoenix/server/api/mutations/annotation_config_mutations.py +2 -2
  16. phoenix/server/api/mutations/api_key_mutations.py +2 -15
  17. phoenix/server/api/mutations/chat_mutations.py +3 -2
  18. phoenix/server/api/mutations/dataset_label_mutations.py +109 -157
  19. phoenix/server/api/mutations/dataset_mutations.py +8 -8
  20. phoenix/server/api/mutations/dataset_split_mutations.py +13 -9
  21. phoenix/server/api/mutations/model_mutations.py +4 -4
  22. phoenix/server/api/mutations/project_session_annotations_mutations.py +4 -7
  23. phoenix/server/api/mutations/prompt_label_mutations.py +3 -3
  24. phoenix/server/api/mutations/prompt_mutations.py +24 -117
  25. phoenix/server/api/mutations/prompt_version_tag_mutations.py +8 -5
  26. phoenix/server/api/mutations/span_annotations_mutations.py +10 -5
  27. phoenix/server/api/mutations/trace_annotations_mutations.py +9 -4
  28. phoenix/server/api/mutations/user_mutations.py +4 -4
  29. phoenix/server/api/queries.py +80 -213
  30. phoenix/server/api/subscriptions.py +4 -4
  31. phoenix/server/api/types/Annotation.py +90 -23
  32. phoenix/server/api/types/ApiKey.py +13 -17
  33. phoenix/server/api/types/Dataset.py +88 -48
  34. phoenix/server/api/types/DatasetExample.py +34 -30
  35. phoenix/server/api/types/DatasetLabel.py +47 -13
  36. phoenix/server/api/types/DatasetSplit.py +87 -21
  37. phoenix/server/api/types/DatasetVersion.py +49 -4
  38. phoenix/server/api/types/DocumentAnnotation.py +182 -62
  39. phoenix/server/api/types/Experiment.py +146 -55
  40. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +10 -1
  41. phoenix/server/api/types/ExperimentRun.py +118 -61
  42. phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
  43. phoenix/server/api/types/GenerativeModel.py +95 -42
  44. phoenix/server/api/types/ModelInterface.py +7 -2
  45. phoenix/server/api/types/PlaygroundModel.py +12 -2
  46. phoenix/server/api/types/Project.py +70 -75
  47. phoenix/server/api/types/ProjectSession.py +69 -37
  48. phoenix/server/api/types/ProjectSessionAnnotation.py +166 -47
  49. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
  50. phoenix/server/api/types/Prompt.py +82 -44
  51. phoenix/server/api/types/PromptLabel.py +47 -13
  52. phoenix/server/api/types/PromptVersion.py +11 -8
  53. phoenix/server/api/types/PromptVersionTag.py +65 -25
  54. phoenix/server/api/types/Span.py +116 -115
  55. phoenix/server/api/types/SpanAnnotation.py +189 -42
  56. phoenix/server/api/types/SystemApiKey.py +65 -1
  57. phoenix/server/api/types/Trace.py +45 -44
  58. phoenix/server/api/types/TraceAnnotation.py +144 -48
  59. phoenix/server/api/types/User.py +103 -33
  60. phoenix/server/api/types/UserApiKey.py +73 -26
  61. phoenix/server/app.py +29 -0
  62. phoenix/server/cost_tracking/model_cost_manifest.json +2 -2
  63. phoenix/server/static/.vite/manifest.json +43 -43
  64. phoenix/server/static/assets/{components-BLK5vehh.js → components-v927s3NF.js} +471 -484
  65. phoenix/server/static/assets/{index-BP0Shd90.js → index-DrD9eSrN.js} +20 -16
  66. phoenix/server/static/assets/{pages-DIVgyYyy.js → pages-GVybXa_W.js} +754 -753
  67. phoenix/server/static/assets/{vendor-3BvTzoBp.js → vendor-D-csRHGZ.js} +1 -1
  68. phoenix/server/static/assets/{vendor-arizeai-C6_oC0y8.js → vendor-arizeai-BJLCG_Gc.js} +1 -1
  69. phoenix/server/static/assets/{vendor-codemirror-DPnZGAZA.js → vendor-codemirror-Cr963DyP.js} +3 -3
  70. phoenix/server/static/assets/{vendor-recharts-CjgSbsB0.js → vendor-recharts-DgmPLgIp.js} +1 -1
  71. phoenix/server/static/assets/{vendor-shiki-CJyhDG0E.js → vendor-shiki-wYOt1s7u.js} +1 -1
  72. phoenix/version.py +1 -1
  73. {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/WHEEL +0 -0
  74. {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/entry_points.txt +0 -0
  75. {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/licenses/IP_NOTICE +0 -0
  76. {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/licenses/LICENSE +0 -0
phoenix/db/engines.py CHANGED
@@ -5,7 +5,7 @@ import logging
 from collections.abc import Callable
 from enum import Enum
 from sqlite3 import Connection
-from typing import Any
+from typing import Any, Optional
 
 import aiosqlite
 import numpy as np
@@ -168,29 +168,142 @@ def aio_postgresql_engine(
     log_to_stdout: bool = False,
     log_migrations_to_stdout: bool = True,
 ) -> AsyncEngine:
-    asyncpg_url, asyncpg_args = get_pg_config(url, "asyncpg")
-    engine = create_async_engine(
-        url=asyncpg_url,
-        connect_args=asyncpg_args,
-        echo=log_to_stdout,
-        json_serializer=_dumps,
+    from phoenix.config import (
+        get_env_postgres_iam_token_lifetime,
+        get_env_postgres_use_iam_auth,
     )
+
+    use_iam_auth = get_env_postgres_use_iam_auth()
+
+    asyncpg_url, asyncpg_args = get_pg_config(url, "asyncpg", enforce_ssl=use_iam_auth)
+
+    iam_config: Optional[dict[str, Any]] = None
+    token_lifetime: int = 0
+    if use_iam_auth:
+        iam_config = _extract_iam_config_from_url(url)
+        token_lifetime = get_env_postgres_iam_token_lifetime()
+
+        async def iam_async_creator() -> Any:
+            import asyncpg  # type: ignore
+
+            from phoenix.db.iam_auth import generate_aws_rds_token
+
+            assert iam_config is not None
+            token = generate_aws_rds_token(
+                host=iam_config["host"],
+                port=iam_config["port"],
+                user=iam_config["user"],
+            )
+
+            conn_kwargs = {
+                "host": iam_config["host"],
+                "port": iam_config["port"],
+                "user": iam_config["user"],
+                "password": token,
+                "database": iam_config["database"],
+            }
+
+            if asyncpg_args:
+                conn_kwargs.update(asyncpg_args)
+
+            return await asyncpg.connect(**conn_kwargs)
+
+        engine = create_async_engine(
+            url=asyncpg_url,
+            async_creator=iam_async_creator,
+            echo=log_to_stdout,
+            json_serializer=_dumps,
+            pool_recycle=token_lifetime,
+        )
+    else:
+        engine = create_async_engine(
+            url=asyncpg_url,
+            connect_args=asyncpg_args,
+            echo=log_to_stdout,
+            json_serializer=_dumps,
+        )
+
     if not migrate:
         return engine
 
-    psycopg_url, psycopg_args = get_pg_config(url, "psycopg")
-    sync_engine = sqlalchemy.create_engine(
-        url=psycopg_url,
-        connect_args=psycopg_args,
-        echo=log_migrations_to_stdout,
-        json_serializer=_dumps,
-    )
+    psycopg_url, psycopg_args = get_pg_config(url, "psycopg", enforce_ssl=use_iam_auth)
+
+    if use_iam_auth:
+        assert iam_config is not None
+
+        def iam_sync_creator() -> Any:
+            import psycopg
+
+            from phoenix.db.iam_auth import generate_aws_rds_token
+
+            token = generate_aws_rds_token(
+                host=iam_config["host"],
+                port=iam_config["port"],
+                user=iam_config["user"],
+            )
+
+            conn_kwargs = {
+                "host": iam_config["host"],
+                "port": iam_config["port"],
+                "user": iam_config["user"],
+                "password": token,
+                "dbname": iam_config["database"],
+            }
+
+            if psycopg_args:
+                conn_kwargs.update(psycopg_args)
+
+            return psycopg.connect(**conn_kwargs)
+
+        sync_engine = sqlalchemy.create_engine(
+            url=psycopg_url,
+            creator=iam_sync_creator,
+            echo=log_migrations_to_stdout,
+            json_serializer=_dumps,
+            pool_recycle=token_lifetime,
+        )
+    else:
+        sync_engine = sqlalchemy.create_engine(
+            url=psycopg_url,
+            connect_args=psycopg_args,
+            echo=log_migrations_to_stdout,
+            json_serializer=_dumps,
+        )
+
     if schema := get_env_database_schema():
         event.listen(sync_engine, "connect", set_postgresql_search_path(schema))
     migrate_in_thread(sync_engine)
     return engine
 
 
+def _extract_iam_config_from_url(url: URL) -> dict[str, Any]:
+    """Extract connection parameters needed for IAM authentication from a SQLAlchemy URL.
+
+    Args:
+        url: SQLAlchemy database URL
+
+    Returns:
+        Dictionary with host, port, user, and database
+    """
+    host = url.host
+    if not host:
+        raise ValueError("Database host is required for IAM authentication")
+
+    port = url.port or 5432
+    user = url.username
+    if not user:
+        raise ValueError("Database user is required for IAM authentication")
+
+    database = url.database or "postgres"
+
+    return {
+        "host": host,
+        "port": port,
+        "user": user,
+        "database": database,
+    }
+
+
 def _dumps(obj: Any) -> str:
     return orjson.dumps(obj, default=_default).decode()
 
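Note: with SQLAlchemy 2.x, supplying async_creator (or creator for the sync engine) makes the pool build every new connection through that callable, so each pooled connection is minted with a fresh IAM token, while pool_recycle=token_lifetime retires connections before the 15-minute token expires. A minimal standalone sketch of the same pattern (the endpoint, user, and 840-second recycle value are illustrative placeholders, not values from this package):

```python
# Sketch of the async_creator + pool_recycle pattern used above.
# HOST/PORT/USER/DB are hypothetical; 840s keeps pooled connections
# younger than the 15-minute RDS token lifetime.
import asyncpg
import boto3
from sqlalchemy.ext.asyncio import create_async_engine

HOST, PORT, USER, DB = "mydb.us-west-2.rds.amazonaws.com", 5432, "phoenix", "postgres"


async def connect_with_fresh_token() -> asyncpg.Connection:
    # Each new pooled connection mints its own short-lived token.
    token = boto3.client("rds").generate_db_auth_token(
        DBHostname=HOST, Port=PORT, DBUsername=USER
    )
    return await asyncpg.connect(
        host=HOST, port=PORT, user=USER, password=token, database=DB, ssl="require"
    )


engine = create_async_engine(
    "postgresql+asyncpg://",  # credentials come from the creator, not the URL
    async_creator=connect_with_fresh_token,
    pool_recycle=840,  # recycle before the token would expire
)
```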
phoenix/db/iam_auth.py ADDED
@@ -0,0 +1,64 @@
+from __future__ import annotations
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def generate_aws_rds_token(
+    host: str,
+    port: int,
+    user: str,
+) -> str:
+    """Generate an AWS RDS IAM authentication token.
+
+    This function creates a short-lived (15 minutes) authentication token for connecting
+    to AWS RDS/Aurora PostgreSQL instances using IAM database authentication.
+
+    The AWS region is automatically resolved using boto3.
+
+    Args:
+        host: The database hostname (e.g., 'mydb.abc123.us-west-2.rds.amazonaws.com')
+        port: The database port (typically 5432 for PostgreSQL)
+        user: The database username (must match an IAM-enabled database user)
+
+    Returns:
+        A temporary authentication token string to use as the database password
+
+    Raises:
+        ImportError: If boto3 is not installed
+        Exception: If AWS credentials/region are not configured or token generation fails
+
+    Example:
+        >>> token = generate_aws_rds_token(
+        ...     host='mydb.us-west-2.rds.amazonaws.com',
+        ...     port=5432,
+        ...     user='myuser'
+        ... )
+    """
+    try:
+        import boto3  # type: ignore
+    except ImportError as e:
+        raise ImportError(
+            "boto3 is required for AWS RDS IAM authentication. "
+            "Install it with: pip install 'arize-phoenix[aws]'"
+        ) from e
+
+    try:
+        client = boto3.client("rds")
+
+        logger.debug(f"Generating AWS RDS IAM auth token for user '{user}' at {host}:{port}")
+        token = client.generate_db_auth_token(  # pyright: ignore
+            DBHostname=host,
+            Port=port,
+            DBUsername=user,
+        )
+
+        return str(token)  # pyright: ignore
+
+    except Exception as e:
+        logger.error(
+            f"Failed to generate AWS RDS IAM authentication token: {e}. "
+            "Ensure AWS credentials are configured and have 'rds-db:connect' permission."
+        )
+        raise
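Note: the token returned above is used verbatim as the database password. A short smoke-test sketch (host and user are hypothetical placeholders; the caller's AWS identity needs rds-db:connect permission for that database user, and RDS requires SSL for token-based logins):

```python
# Hypothetical smoke test: mint a token and open one psycopg connection.
import psycopg

from phoenix.db.iam_auth import generate_aws_rds_token

HOST, PORT, USER = "mydb.us-west-2.rds.amazonaws.com", 5432, "phoenix"

token = generate_aws_rds_token(host=HOST, port=PORT, user=USER)
with psycopg.connect(
    host=HOST,
    port=PORT,
    user=USER,
    password=token,  # the IAM token stands in for the password
    dbname="postgres",
    sslmode="require",  # RDS rejects IAM tokens over non-SSL connections
) as conn:
    print(conn.execute("SELECT 1").fetchone())
```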
phoenix/db/pg_config.py CHANGED
@@ -10,12 +10,14 @@ from typing_extensions import assert_never
 def get_pg_config(
     url: URL,
     driver: Literal["psycopg", "asyncpg"],
+    enforce_ssl: bool = False,
 ) -> tuple[URL, dict[str, Any]]:
     """Convert SQLAlchemy URL to driver-specific configuration.
 
     Args:
         url: SQLAlchemy URL
         driver: "psycopg" or "asyncpg"
+        enforce_ssl: If True, ensure SSL is enabled (required for AWS RDS IAM auth)
 
     Returns:
         Tuple of (base_url, connect_args):
@@ -26,6 +28,14 @@ def get_pg_config(
     query = url.query
     ssl_args = _get_ssl_args(query)
 
+    if enforce_ssl and not ssl_args:
+        ssl_args = {"sslmode": "require"}
+    elif enforce_ssl and ssl_args.get("sslmode") == "disable":
+        raise ValueError(
+            "SSL cannot be disabled when using AWS RDS IAM authentication. "
+            "Remove 'sslmode=disable' from the connection string."
+        )
+
     # Create base URL without SSL parameters
     base_url = url.set(
         drivername=f"postgresql+{driver}",
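Note: the sslmode=disable guard is directly observable in this hunk; a sketch of it (the URL is made up, and the happy-path shape of connect_args depends on _get_ssl_args, which is outside this hunk):

```python
# Illustrative check of the new enforce_ssl guard.
from sqlalchemy.engine import make_url

from phoenix.db.pg_config import get_pg_config

url = make_url("postgresql://user@host:5432/db?sslmode=disable")
try:
    get_pg_config(url, "psycopg", enforce_ssl=True)
except ValueError as e:
    print(e)  # SSL cannot be disabled when using AWS RDS IAM authentication. ...
```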
phoenix/server/api/context.py CHANGED
@@ -35,6 +35,7 @@ from phoenix.server.api.dataloaders import (
     ExperimentRepeatedRunGroupsDataLoader,
     ExperimentRunAnnotations,
     ExperimentRunCountsDataLoader,
+    ExperimentRunsByExperimentAndExampleDataLoader,
     ExperimentSequenceNumberDataLoader,
     LastUsedTimesByGenerativeModelIdDataLoader,
     LatencyMsQuantileDataLoader,
@@ -71,6 +72,7 @@ from phoenix.server.api.dataloaders import (
     SpanProjectsDataLoader,
     TableFieldsDataLoader,
     TokenCountDataLoader,
+    TokenPricesByModelDataLoader,
     TraceAnnotationsByTraceDataLoader,
     TraceByTraceIdsDataLoader,
     TraceRetentionPolicyIdByProjectIdDataLoader,
@@ -100,27 +102,38 @@ class DataLoaders:
         AverageExperimentRepeatedRunGroupLatencyDataLoader
     )
     average_experiment_run_latency: AverageExperimentRunLatencyDataLoader
+    dataset_example_fields: TableFieldsDataLoader
     dataset_example_revisions: DatasetExampleRevisionsDataLoader
     dataset_example_spans: DatasetExampleSpansDataLoader
     dataset_labels: DatasetLabelsDataLoader
+    dataset_label_fields: TableFieldsDataLoader
     dataset_dataset_splits: DatasetDatasetSplitsDataLoader
     dataset_examples_and_versions_by_experiment_run: (
         DatasetExamplesAndVersionsByExperimentRunDataLoader
     )
     dataset_example_splits: DatasetExampleSplitsDataLoader
+    dataset_fields: TableFieldsDataLoader
+    dataset_split_fields: TableFieldsDataLoader
+    dataset_version_fields: TableFieldsDataLoader
+    document_annotation_fields: TableFieldsDataLoader
     document_evaluation_summaries: DocumentEvaluationSummaryDataLoader
     document_evaluations: DocumentEvaluationsDataLoader
     document_retrieval_metrics: DocumentRetrievalMetricsDataLoader
     experiment_annotation_summaries: ExperimentAnnotationSummaryDataLoader
     experiment_dataset_splits: ExperimentDatasetSplitsDataLoader
     experiment_error_rates: ExperimentErrorRatesDataLoader
+    experiment_fields: TableFieldsDataLoader
     experiment_repeated_run_group_annotation_summaries: (
         ExperimentRepeatedRunGroupAnnotationSummariesDataLoader
     )
     experiment_repeated_run_groups: ExperimentRepeatedRunGroupsDataLoader
+    experiment_run_annotation_fields: TableFieldsDataLoader
     experiment_run_annotations: ExperimentRunAnnotations
     experiment_run_counts: ExperimentRunCountsDataLoader
+    experiment_run_fields: TableFieldsDataLoader
+    experiment_runs_by_experiment_and_example: ExperimentRunsByExperimentAndExampleDataLoader
     experiment_sequence_number: ExperimentSequenceNumberDataLoader
+    generative_model_fields: TableFieldsDataLoader
     last_used_times_by_generative_model_id: LastUsedTimesByGenerativeModelIdDataLoader
     latency_ms_quantile: LatencyMsQuantileDataLoader
     min_start_or_max_end_times: MinStartOrMaxEndTimeDataLoader
@@ -130,7 +143,12 @@ class DataLoaders:
     project_fields: TableFieldsDataLoader
     project_trace_retention_policy_fields: TableFieldsDataLoader
     projects_by_trace_retention_policy_id: ProjectIdsByTraceRetentionPolicyIdDataLoader
+    prompt_fields: TableFieldsDataLoader
+    prompt_label_fields: TableFieldsDataLoader
     prompt_version_sequence_number: PromptVersionSequenceNumberDataLoader
+    prompt_version_tag_fields: TableFieldsDataLoader
+    project_session_annotation_fields: TableFieldsDataLoader
+    project_session_fields: TableFieldsDataLoader
     record_counts: RecordCountDataLoader
     session_annotations_by_session: SessionAnnotationsBySessionDataLoader
     session_first_inputs: SessionIODataLoader
@@ -139,6 +157,7 @@ class DataLoaders:
     session_num_traces_with_error: SessionNumTracesWithErrorDataLoader
     session_token_usages: SessionTokenUsagesDataLoader
     session_trace_latency_ms_quantile: SessionTraceLatencyMsQuantileDataLoader
+    span_annotation_fields: TableFieldsDataLoader
     span_annotations: SpanAnnotationsDataLoader
     span_by_id: SpanByIdDataLoader
     span_cost_by_span: SpanCostBySpanDataLoader
@@ -167,12 +186,16 @@ class DataLoaders:
     span_fields: TableFieldsDataLoader
     span_projects: SpanProjectsDataLoader
     token_counts: TokenCountDataLoader
+    token_prices_by_model: TokenPricesByModelDataLoader
+    trace_annotation_fields: TableFieldsDataLoader
     trace_annotations_by_trace: TraceAnnotationsByTraceDataLoader
     trace_by_trace_ids: TraceByTraceIdsDataLoader
     trace_fields: TableFieldsDataLoader
     trace_retention_policy_id_by_project_id: TraceRetentionPolicyIdByProjectIdDataLoader
     trace_root_spans: TraceRootSpansDataLoader
     user_roles: UserRolesDataLoader
+    user_api_key_fields: TableFieldsDataLoader
+    user_fields: TableFieldsDataLoader
     users: UsersDataLoader
 
 
phoenix/server/api/dataloaders/__init__.py CHANGED
@@ -33,6 +33,9 @@ from .experiment_repeated_run_group_annotation_summaries import (
 from .experiment_repeated_run_groups import ExperimentRepeatedRunGroupsDataLoader
 from .experiment_run_annotations import ExperimentRunAnnotations
 from .experiment_run_counts import ExperimentRunCountsDataLoader
+from .experiment_runs_by_experiment_and_example import (
+    ExperimentRunsByExperimentAndExampleDataLoader,
+)
 from .experiment_sequence_number import ExperimentSequenceNumberDataLoader
 from .last_used_times_by_generative_model_id import LastUsedTimesByGenerativeModelIdDataLoader
 from .latency_ms_quantile import LatencyMsQuantileCache, LatencyMsQuantileDataLoader
@@ -73,6 +76,7 @@ from .span_descendants import SpanDescendantsDataLoader
 from .span_projects import SpanProjectsDataLoader
 from .table_fields import TableFieldsDataLoader
 from .token_counts import TokenCountCache, TokenCountDataLoader
+from .token_prices_by_model import TokenPricesByModelDataLoader
 from .trace_annotations_by_trace import TraceAnnotationsByTraceDataLoader
 from .trace_by_trace_ids import TraceByTraceIdsDataLoader
 from .trace_retention_policy_id_by_project_id import TraceRetentionPolicyIdByProjectIdDataLoader
@@ -102,6 +106,7 @@ __all__ = [
     "ExperimentRepeatedRunGroupAnnotationSummariesDataLoader",
     "ExperimentRunAnnotations",
     "ExperimentRunCountsDataLoader",
+    "ExperimentRunsByExperimentAndExampleDataLoader",
     "ExperimentSequenceNumberDataLoader",
     "LastUsedTimesByGenerativeModelIdDataLoader",
     "LatencyMsQuantileDataLoader",
@@ -139,6 +144,7 @@ __all__ = [
     "SpanProjectsDataLoader",
     "TableFieldsDataLoader",
     "TokenCountDataLoader",
+    "TokenPricesByModelDataLoader",
     "TraceAnnotationsByTraceDataLoader",
     "TraceByTraceIdsDataLoader",
     "TraceRetentionPolicyIdByProjectIdDataLoader",
phoenix/server/api/dataloaders/experiment_repeated_run_groups.py CHANGED
@@ -1,7 +1,6 @@
 from dataclasses import dataclass
 
 from sqlalchemy import select, tuple_
-from sqlalchemy.orm import joinedload
 from strawberry.dataloader import DataLoader
 from typing_extensions import TypeAlias
 
@@ -38,7 +37,6 @@ class ExperimentRepeatedRunGroupsDataLoader(DataLoader[Key, Result]):
                 ).in_(set(keys))
             )
             .order_by(models.ExperimentRun.repetition_number)
-            .options(joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id))
         )
 
         async with self._db() as session:
phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py ADDED
@@ -0,0 +1,44 @@
+from collections import defaultdict
+
+from sqlalchemy import select, tuple_
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+from phoenix.server.types import DbSessionFactory
+
+ExperimentId: TypeAlias = int
+DatasetExampleId: TypeAlias = int
+Key: TypeAlias = tuple[ExperimentId, DatasetExampleId]
+Result: TypeAlias = list[models.ExperimentRun]
+
+
+class ExperimentRunsByExperimentAndExampleDataLoader(DataLoader[Key, Result]):
+    def __init__(self, db: DbSessionFactory) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: list[Key]) -> list[Result]:
+        runs_by_key: defaultdict[Key, Result] = defaultdict(list)
+
+        async with self._db() as session:
+            stmt = (
+                select(models.ExperimentRun)
+                .where(
+                    tuple_(
+                        models.ExperimentRun.experiment_id,
+                        models.ExperimentRun.dataset_example_id,
+                    ).in_(keys)
+                )
+                .order_by(
+                    models.ExperimentRun.experiment_id,
+                    models.ExperimentRun.dataset_example_id,
+                    models.ExperimentRun.repetition_number,
+                )
+            )
+            result = await session.stream_scalars(stmt)
+            async for run in result:
+                key = (run.experiment_id, run.dataset_example_id)
+                runs_by_key[key].append(run)
+
+        return [runs_by_key[key] for key in keys]
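Note: like the other strawberry DataLoaders here, concurrent load() calls are coalesced into a single SELECT via the composite tuple_ key, and keys with no matching runs resolve to an empty list. A usage sketch (db is a DbSessionFactory; the integer IDs are placeholders):

```python
# Usage sketch: two loads in the same tick share one batched query.
import asyncio


async def fetch(db) -> None:
    loader = ExperimentRunsByExperimentAndExampleDataLoader(db)
    runs_a, runs_b = await asyncio.gather(
        loader.load((1, 10)),  # (experiment_id, dataset_example_id)
        loader.load((1, 11)),  # batched into the same SELECT ... IN (...)
    )
    # Each result is a list ordered by repetition_number ([] if no runs).
```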
phoenix/server/api/dataloaders/span_costs.py CHANGED
@@ -1,7 +1,6 @@
 from typing import Optional
 
 from sqlalchemy import select
-from sqlalchemy.orm import joinedload, load_only
 from strawberry.dataloader import DataLoader
 from typing_extensions import TypeAlias
 
@@ -22,14 +21,9 @@ class SpanCostsDataLoader(DataLoader[Key, Result]):
         span_ids = list(set(keys))
         async with self._db() as session:
             costs = {
-                span.id: span.span_cost
-                async for span in await session.stream_scalars(
-                    select(models.Span)
-                    .where(models.Span.id.in_(span_ids))
-                    .options(
-                        load_only(models.Span.id),
-                        joinedload(models.Span.span_cost),
-                    )
+                span_cost.span_rowid: span_cost
+                async for span_cost in await session.stream_scalars(
+                    select(models.SpanCost).where(models.SpanCost.span_rowid.in_(span_ids))
                 )
             }
         return [costs.get(span_id) for span_id in keys]
phoenix/server/api/dataloaders/token_prices_by_model.py ADDED
@@ -0,0 +1,30 @@
+from collections import defaultdict
+
+from sqlalchemy import select
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+from phoenix.server.types import DbSessionFactory
+
+ModelId: TypeAlias = int
+Key: TypeAlias = ModelId
+Result: TypeAlias = list[models.TokenPrice]
+
+
+class TokenPricesByModelDataLoader(DataLoader[Key, Result]):
+    def __init__(self, db: DbSessionFactory) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: list[Key]) -> list[Result]:
+        model_ids = keys
+        token_prices: defaultdict[Key, Result] = defaultdict(list)
+
+        async with self._db() as session:
+            async for token_price in await session.stream_scalars(
+                select(models.TokenPrice).where(models.TokenPrice.model_id.in_(model_ids))
+            ):
+                token_prices[token_price.model_id].append(token_price)
+
+        return [token_prices[model_id] for model_id in keys]
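Note: this loader is registered on the DataLoaders dataclass above as token_prices_by_model, so a resolver would typically reach it through the request context; a sketch assuming the context exposes the dataclass as data_loaders (as in phoenix.server.api.context):

```python
# Resolver-side sketch: prices for all models touched by one GraphQL
# response are fetched with a single SELECT; a model with no price rows
# resolves to an empty list.
from strawberry.types import Info


async def resolve_token_prices(model_rowid: int, info: Info) -> list:
    return await info.context.data_loaders.token_prices_by_model.load(model_rowid)
```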
phoenix/server/api/helpers/playground_clients.py CHANGED
@@ -699,7 +699,7 @@ class BedrockStreamingClient(PlaygroundStreamingClient):
         self.model_name = model.name
         self.client = boto3.client(
             service_name="bedrock-runtime",
-            region_name="us-east-1",  # match the default region in the UI
+            region_name=self.region,
             aws_access_key_id=self.aws_access_key_id,
             aws_secret_access_key=self.aws_secret_access_key,
             aws_session_token=self.aws_session_token,
@@ -805,7 +805,7 @@ class BedrockStreamingClient(PlaygroundStreamingClient):
 
         # Build the request parameters for Converse API
         converse_params: dict[str, Any] = {
-            "modelId": f"us.{self.model_name}",
+            "modelId": self.model_name,
             "messages": converse_messages,
             "inferenceConfig": {
                 "maxTokens": invocation_parameters["max_tokens"],
@@ -953,7 +953,7 @@ class BedrockStreamingClient(PlaygroundStreamingClient):
         }
 
         response = self.client.invoke_model_with_response_stream(
-            modelId=f"us.{self.model_name}",  # or another Claude model
+            modelId=self.model_name,
             contentType="application/json",
             accept="application/json",
             body=json.dumps(bedrock_params),
phoenix/server/api/input_types/PromptVersionInput.py CHANGED
@@ -1,10 +1,11 @@
 import json
-from typing import Optional, cast
+from typing import Any, Optional, Union, cast
 
 import strawberry
 from strawberry import UNSET
 from strawberry.scalars import JSON
 
+from phoenix.db import models
 from phoenix.db.types.model_provider import ModelProvider
 from phoenix.server.api.helpers.prompts.models import (
     ContentPart,
@@ -12,11 +13,15 @@ from phoenix.server.api.helpers.prompts.models import (
     PromptMessage,
     PromptMessageRole,
     PromptTemplateFormat,
+    PromptTemplateType,
     RoleConversion,
     TextContentPart,
     ToolCallContentPart,
     ToolCallFunction,
     ToolResultContentPart,
+    normalize_response_format,
+    normalize_tools,
+    validate_invocation_parameters,
 )
 
 
@@ -88,6 +93,47 @@ class ChatPromptVersionInput:
             k: v for k, v in self.invocation_parameters.items() if v is not None
         }
 
+    def to_orm_prompt_version(
+        self,
+        user_id: Optional[int],
+    ) -> models.PromptVersion:
+        tool_definitions = [tool.definition for tool in self.tools]
+        tool_choice = cast(
+            Optional[Union[str, dict[str, Any]]],
+            cast(dict[str, Any], self.invocation_parameters).pop("tool_choice", None),
+        )
+        model_provider = ModelProvider(self.model_provider)
+        tools = (
+            normalize_tools(tool_definitions, model_provider, tool_choice)
+            if tool_definitions
+            else None
+        )
+        template = to_pydantic_prompt_chat_template_v1(self.template)
+        response_format = (
+            normalize_response_format(
+                self.response_format.definition,
+                model_provider,
+            )
+            if self.response_format
+            else None
+        )
+        invocation_parameters = validate_invocation_parameters(
+            self.invocation_parameters,
+            model_provider,
+        )
+        return models.PromptVersion(
+            description=self.description,
+            user_id=user_id,
+            template_type=PromptTemplateType.CHAT,
+            template_format=self.template_format,
+            template=template,
+            invocation_parameters=invocation_parameters,
+            tools=tools,
+            response_format=response_format,
+            model_provider=ModelProvider(self.model_provider),
+            model_name=self.model_name,
+        )
+
 
 def to_pydantic_prompt_chat_template_v1(
phoenix/server/api/mutations/annotation_config_mutations.py CHANGED
@@ -374,7 +374,7 @@ class AnnotationConfigMutationMixin:
         )
         return AddAnnotationConfigToProjectPayload(
             query=Query(),
-            project=Project(project_rowid=project_id),
+            project=Project(id=project_id),
         )
 
     @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer])  # type: ignore[misc]
@@ -409,5 +409,5 @@
             raise NotFound("Could not find one or more input project annotation configs")
         return RemoveAnnotationConfigFromProjectPayload(
             query=Query(),
-            project=Project(project_rowid=project_id),
+            project=Project(id=project_id),
         )
phoenix/server/api/mutations/api_key_mutations.py CHANGED
@@ -92,13 +92,7 @@ class ApiKeyMutationMixin:
         token, token_id = await token_store.create_api_key(claims)
         return CreateSystemApiKeyMutationPayload(
             jwt=token,
-            api_key=SystemApiKey(
-                id_attr=int(token_id),
-                name=input.name,
-                description=input.description or None,
-                created_at=issued_at,
-                expires_at=input.expires_at or None,
-            ),
+            api_key=SystemApiKey(id=int(token_id)),
             query=Query(),
         )
 
@@ -134,14 +128,7 @@
         token, token_id = await token_store.create_api_key(claims)
         return CreateUserApiKeyMutationPayload(
             jwt=token,
-            api_key=UserApiKey(
-                id_attr=int(token_id),
-                name=input.name,
-                description=input.description or None,
-                created_at=issued_at,
-                expires_at=input.expires_at or None,
-                user_id=int(user.identity),
-            ),
+            api_key=UserApiKey(id=int(token_id)),
             query=Query(),
         )
 
phoenix/server/api/mutations/chat_mutations.py CHANGED
@@ -19,6 +19,7 @@ from openinference.semconv.trace import (
 from opentelemetry.sdk.trace.id_generator import RandomIdGenerator as DefaultOTelIDGenerator
 from opentelemetry.trace import StatusCode
 from sqlalchemy import insert, select
+from strawberry import Private
 from strawberry.relay import GlobalID
 from strawberry.types import Info
 from typing_extensions import assert_never
@@ -101,7 +102,7 @@ class ChatCompletionToolCall:
 
 @strawberry.type
 class ChatCompletionMutationPayload:
-    db_span: strawberry.Private[models.Span]
+    db_span: Private[models.Span]
     content: Optional[str]
     tool_calls: List[ChatCompletionToolCall]
     span: Span
@@ -502,7 +503,7 @@ class ChatCompletionMutationMixin:
         session.add(span_cost)
         await session.flush()
 
-        gql_span = Span(span_rowid=span.id, db_span=span)
+        gql_span = Span(id=span.id, db_record=span)
 
         info.context.event_queue.put(SpanInsertEvent(ids=(project_id,)))