arize-phoenix 12.0.0__py3-none-any.whl → 12.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (39)
  1. {arize_phoenix-12.0.0.dist-info → arize_phoenix-12.2.0.dist-info}/METADATA +1 -1
  2. {arize_phoenix-12.0.0.dist-info → arize_phoenix-12.2.0.dist-info}/RECORD +39 -35
  3. phoenix/db/insertion/document_annotation.py +1 -1
  4. phoenix/db/insertion/session_annotation.py +1 -1
  5. phoenix/db/insertion/span_annotation.py +1 -1
  6. phoenix/db/insertion/trace_annotation.py +1 -1
  7. phoenix/db/insertion/types.py +0 -4
  8. phoenix/db/models.py +22 -1
  9. phoenix/server/api/context.py +2 -0
  10. phoenix/server/api/dataloaders/__init__.py +2 -0
  11. phoenix/server/api/dataloaders/dataset_labels.py +36 -0
  12. phoenix/server/api/helpers/playground_clients.py +1 -0
  13. phoenix/server/api/mutations/__init__.py +2 -0
  14. phoenix/server/api/mutations/dataset_label_mutations.py +291 -0
  15. phoenix/server/api/mutations/dataset_split_mutations.py +38 -2
  16. phoenix/server/api/queries.py +21 -0
  17. phoenix/server/api/routers/auth.py +5 -5
  18. phoenix/server/api/routers/oauth2.py +5 -23
  19. phoenix/server/api/types/Dataset.py +8 -0
  20. phoenix/server/api/types/DatasetExample.py +7 -0
  21. phoenix/server/api/types/DatasetLabel.py +23 -0
  22. phoenix/server/api/types/Prompt.py +18 -1
  23. phoenix/server/app.py +7 -12
  24. phoenix/server/static/.vite/manifest.json +39 -39
  25. phoenix/server/static/assets/{components-Dl9SUw1U.js → components-BG6v0EM8.js} +665 -389
  26. phoenix/server/static/assets/{index-CqQS0dTo.js → index-CSVcULw1.js} +13 -13
  27. phoenix/server/static/assets/{pages-DKSjVA_E.js → pages-DgaM7kpM.js} +1135 -1182
  28. phoenix/server/static/assets/{vendor-CtbHQYl8.js → vendor-BqTEkGQU.js} +183 -183
  29. phoenix/server/static/assets/{vendor-arizeai-D-lWOwIS.js → vendor-arizeai-DlOj0PQQ.js} +15 -24
  30. phoenix/server/static/assets/{vendor-codemirror-BRBpy3_z.js → vendor-codemirror-B2PHH5yZ.js} +3 -3
  31. phoenix/server/static/assets/{vendor-recharts--KdSwB3m.js → vendor-recharts-CKsi4IjN.js} +1 -1
  32. phoenix/server/static/assets/{vendor-shiki-CvRzZnIo.js → vendor-shiki-DN26BkKE.js} +1 -1
  33. phoenix/server/utils.py +74 -0
  34. phoenix/session/session.py +25 -5
  35. phoenix/version.py +1 -1
  36. {arize_phoenix-12.0.0.dist-info → arize_phoenix-12.2.0.dist-info}/WHEEL +0 -0
  37. {arize_phoenix-12.0.0.dist-info → arize_phoenix-12.2.0.dist-info}/entry_points.txt +0 -0
  38. {arize_phoenix-12.0.0.dist-info → arize_phoenix-12.2.0.dist-info}/licenses/IP_NOTICE +0 -0
  39. {arize_phoenix-12.0.0.dist-info → arize_phoenix-12.2.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/mutations/dataset_label_mutations.py (new file)
@@ -0,0 +1,291 @@
+ from typing import Optional
+
+ import sqlalchemy
+ import strawberry
+ from sqlalchemy import delete, select
+ from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
+ from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError  # type: ignore[import-untyped]
+ from strawberry import UNSET
+ from strawberry.relay.types import GlobalID
+ from strawberry.types import Info
+
+ from phoenix.db import models
+ from phoenix.server.api.auth import IsLocked, IsNotReadOnly
+ from phoenix.server.api.context import Context
+ from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
+ from phoenix.server.api.queries import Query
+ from phoenix.server.api.types.Dataset import Dataset
+ from phoenix.server.api.types.DatasetLabel import DatasetLabel, to_gql_dataset_label
+ from phoenix.server.api.types.node import from_global_id_with_expected_type
+
+
+ @strawberry.input
+ class CreateDatasetLabelInput:
+     name: str
+     description: Optional[str] = UNSET
+     color: str
+
+
+ @strawberry.type
+ class CreateDatasetLabelMutationPayload:
+     dataset_label: DatasetLabel
+
+
+ @strawberry.input
+ class DeleteDatasetLabelsInput:
+     dataset_label_ids: list[GlobalID]
+
+
+ @strawberry.type
+ class DeleteDatasetLabelsMutationPayload:
+     dataset_labels: list[DatasetLabel]
+
+
+ @strawberry.input
+ class UpdateDatasetLabelInput:
+     dataset_label_id: GlobalID
+     name: str
+     description: Optional[str] = None
+     color: str
+
+
+ @strawberry.type
+ class UpdateDatasetLabelMutationPayload:
+     dataset_label: DatasetLabel
+
+
+ @strawberry.input
+ class SetDatasetLabelsInput:
+     dataset_label_ids: list[GlobalID]
+     dataset_ids: list[GlobalID]
+
+
+ @strawberry.type
+ class SetDatasetLabelsMutationPayload:
+     query: "Query"
+
+
+ @strawberry.input
+ class UnsetDatasetLabelsInput:
+     dataset_label_ids: list[GlobalID]
+     dataset_ids: list[GlobalID]
+
+
+ @strawberry.type
+ class UnsetDatasetLabelsMutationPayload:
+     query: "Query"
+
+
+ @strawberry.type
+ class DatasetLabelMutationMixin:
+     @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
+     async def create_dataset_label(
+         self,
+         info: Info[Context, None],
+         input: CreateDatasetLabelInput,
+     ) -> CreateDatasetLabelMutationPayload:
+         name = input.name
+         description = input.description
+         color = input.color
+         async with info.context.db() as session:
+             dataset_label_orm = models.DatasetLabel(name=name, description=description, color=color)
+             session.add(dataset_label_orm)
+             try:
+                 await session.commit()
+             except (PostgreSQLIntegrityError, SQLiteIntegrityError):
+                 raise Conflict(f"A dataset label named '{name}' already exists")
+             except sqlalchemy.exc.StatementError as error:
+                 raise BadRequest(str(error.orig))
+         return CreateDatasetLabelMutationPayload(
+             dataset_label=to_gql_dataset_label(dataset_label_orm)
+         )
+
+     @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
+     async def update_dataset_label(
+         self, info: Info[Context, None], input: UpdateDatasetLabelInput
+     ) -> UpdateDatasetLabelMutationPayload:
+         if not input.name or not input.name.strip():
+             raise BadRequest("Dataset label name cannot be empty")
+
+         try:
+             dataset_label_id = from_global_id_with_expected_type(
+                 input.dataset_label_id, DatasetLabel.__name__
+             )
+         except ValueError:
+             raise BadRequest(f"Invalid dataset label ID: {input.dataset_label_id}")
+
+         async with info.context.db() as session:
+             dataset_label_orm = await session.get(models.DatasetLabel, dataset_label_id)
+             if not dataset_label_orm:
+                 raise NotFound(f"DatasetLabel with ID {input.dataset_label_id} not found")
+
+             dataset_label_orm.name = input.name.strip()
+             dataset_label_orm.description = input.description
+             dataset_label_orm.color = input.color.strip()
+
+             try:
+                 await session.commit()
+             except (PostgreSQLIntegrityError, SQLiteIntegrityError):
+                 raise Conflict(f"A dataset label named '{input.name}' already exists")
+             except sqlalchemy.exc.StatementError as error:
+                 raise BadRequest(str(error.orig))
+         return UpdateDatasetLabelMutationPayload(
+             dataset_label=to_gql_dataset_label(dataset_label_orm)
+         )
+
+     @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
+     async def delete_dataset_labels(
+         self, info: Info[Context, None], input: DeleteDatasetLabelsInput
+     ) -> DeleteDatasetLabelsMutationPayload:
+         dataset_label_row_ids: dict[int, None] = {}
+         for dataset_label_node_id in input.dataset_label_ids:
+             try:
+                 dataset_label_row_id = from_global_id_with_expected_type(
+                     dataset_label_node_id, DatasetLabel.__name__
+                 )
+             except ValueError:
+                 raise BadRequest(f"Unknown dataset label: {dataset_label_node_id}")
+             dataset_label_row_ids[dataset_label_row_id] = None
+         async with info.context.db() as session:
+             stmt = (
+                 delete(models.DatasetLabel)
+                 .where(models.DatasetLabel.id.in_(dataset_label_row_ids.keys()))
+                 .returning(models.DatasetLabel)
+             )
+             deleted_dataset_labels = (await session.scalars(stmt)).all()
+             if len(deleted_dataset_labels) < len(dataset_label_row_ids):
+                 await session.rollback()
+                 raise NotFound("Could not find one or more dataset labels with given IDs")
+             deleted_dataset_labels_by_id = {
+                 dataset_label.id: dataset_label for dataset_label in deleted_dataset_labels
+             }
+         return DeleteDatasetLabelsMutationPayload(
+             dataset_labels=[
+                 to_gql_dataset_label(deleted_dataset_labels_by_id[dataset_label_row_id])
+                 for dataset_label_row_id in dataset_label_row_ids
+             ]
+         )
+
+     @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
+     async def set_dataset_labels(
+         self, info: Info[Context, None], input: SetDatasetLabelsInput
+     ) -> SetDatasetLabelsMutationPayload:
+         if not input.dataset_ids:
+             raise BadRequest("No datasets provided.")
+         if not input.dataset_label_ids:
+             raise BadRequest("No dataset labels provided.")
+
+         unique_dataset_rowids: set[int] = set()
+         for dataset_gid in input.dataset_ids:
+             try:
+                 dataset_rowid = from_global_id_with_expected_type(dataset_gid, Dataset.__name__)
+             except ValueError:
+                 raise BadRequest(f"Invalid dataset ID: {dataset_gid}")
+             unique_dataset_rowids.add(dataset_rowid)
+         dataset_rowids = list(unique_dataset_rowids)
+
+         unique_dataset_label_rowids: set[int] = set()
+         for dataset_label_gid in input.dataset_label_ids:
+             try:
+                 dataset_label_rowid = from_global_id_with_expected_type(
+                     dataset_label_gid, DatasetLabel.__name__
+                 )
+             except ValueError:
+                 raise BadRequest(f"Invalid dataset label ID: {dataset_label_gid}")
+             unique_dataset_label_rowids.add(dataset_label_rowid)
+         dataset_label_rowids = list(unique_dataset_label_rowids)
+
+         async with info.context.db() as session:
+             existing_dataset_ids = (
+                 await session.scalars(
+                     select(models.Dataset.id).where(models.Dataset.id.in_(dataset_rowids))
+                 )
+             ).all()
+             if len(existing_dataset_ids) != len(dataset_rowids):
+                 raise NotFound("One or more datasets not found")
+
+             existing_dataset_label_ids = (
+                 await session.scalars(
+                     select(models.DatasetLabel.id).where(
+                         models.DatasetLabel.id.in_(dataset_label_rowids)
+                     )
+                 )
+             ).all()
+             if len(existing_dataset_label_ids) != len(dataset_label_rowids):
+                 raise NotFound("One or more dataset labels not found")
+
+             existing_dataset_label_keys = await session.execute(
+                 select(
+                     models.DatasetsDatasetLabel.dataset_id,
+                     models.DatasetsDatasetLabel.dataset_label_id,
+                 ).where(
+                     models.DatasetsDatasetLabel.dataset_id.in_(dataset_rowids)
+                     & models.DatasetsDatasetLabel.dataset_label_id.in_(dataset_label_rowids)
+                 )
+             )
+             unique_dataset_label_keys = set(existing_dataset_label_keys.all())
+
+             datasets_dataset_labels = []
+             for dataset_rowid in dataset_rowids:
+                 for dataset_label_rowid in dataset_label_rowids:
+                     if (dataset_rowid, dataset_label_rowid) in unique_dataset_label_keys:
+                         continue
+                     datasets_dataset_labels.append(
+                         models.DatasetsDatasetLabel(
+                             dataset_id=dataset_rowid,
+                             dataset_label_id=dataset_label_rowid,
+                         )
+                     )
+             session.add_all(datasets_dataset_labels)
+
+             if datasets_dataset_labels:
+                 try:
+                     await session.commit()
+                 except (PostgreSQLIntegrityError, SQLiteIntegrityError) as e:
+                     raise Conflict("Failed to add dataset labels to datasets.") from e
+
+         return SetDatasetLabelsMutationPayload(
+             query=Query(),
+         )
+
+     @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
+     async def unset_dataset_labels(
+         self, info: Info[Context, None], input: UnsetDatasetLabelsInput
+     ) -> UnsetDatasetLabelsMutationPayload:
+         if not input.dataset_ids:
+             raise BadRequest("No datasets provided.")
+         if not input.dataset_label_ids:
+             raise BadRequest("No dataset labels provided.")
+
+         unique_dataset_rowids: set[int] = set()
+         for dataset_gid in input.dataset_ids:
+             try:
+                 dataset_rowid = from_global_id_with_expected_type(dataset_gid, Dataset.__name__)
+             except ValueError:
+                 raise BadRequest(f"Invalid dataset ID: {dataset_gid}")
+             unique_dataset_rowids.add(dataset_rowid)
+         dataset_rowids = list(unique_dataset_rowids)
+
+         unique_dataset_label_rowids: set[int] = set()
+         for dataset_label_gid in input.dataset_label_ids:
+             try:
+                 dataset_label_rowid = from_global_id_with_expected_type(
+                     dataset_label_gid, DatasetLabel.__name__
+                 )
+             except ValueError:
+                 raise BadRequest(f"Invalid dataset label ID: {dataset_label_gid}")
+             unique_dataset_label_rowids.add(dataset_label_rowid)
+         dataset_label_rowids = list(unique_dataset_label_rowids)
+
+         async with info.context.db() as session:
+             await session.execute(
+                 delete(models.DatasetsDatasetLabel).where(
+                     models.DatasetsDatasetLabel.dataset_id.in_(dataset_rowids)
+                     & models.DatasetsDatasetLabel.dataset_label_id.in_(dataset_label_rowids)
+                 )
+             )
+             await session.commit()
+
+         return UnsetDatasetLabelsMutationPayload(
+             query=Query(),
+         )
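
The module above wires label CRUD plus set/unset mutations into the GraphQL schema via DatasetLabelMutationMixin. As a quick smoke test, the new createDatasetLabel mutation can be exercised over HTTP. The sketch below is illustrative only: it assumes a Phoenix server at http://localhost:6006 serving GraphQL at /graphql, and the camelCase field names assume strawberry's default snake_case conversion.

import httpx

# Hypothetical smoke test for the new createDatasetLabel mutation.
CREATE_DATASET_LABEL = """
mutation ($input: CreateDatasetLabelInput!) {
  createDatasetLabel(input: $input) {
    datasetLabel { id name color }
  }
}
"""

response = httpx.post(
    "http://localhost:6006/graphql",  # assumed local Phoenix server
    json={
        "query": CREATE_DATASET_LABEL,
        "variables": {"input": {"name": "golden-set", "color": "#4A90E2"}},
    },
)
response.raise_for_status()
print(response.json())

Creating the same name twice should surface the Conflict raised on the unique-constraint violation as a GraphQL error rather than a transport-level failure.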
phoenix/server/api/mutations/dataset_split_mutations.py
@@ -15,6 +15,7 @@ from phoenix.server.api.context import Context
  from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
  from phoenix.server.api.helpers.playground_users import get_user
  from phoenix.server.api.queries import Query
+ from phoenix.server.api.types.DatasetExample import DatasetExample, to_gql_dataset_example
  from phoenix.server.api.types.DatasetSplit import DatasetSplit, to_gql_dataset_split
  from phoenix.server.api.types.node import from_global_id_with_expected_type

@@ -68,6 +69,13 @@ class DatasetSplitMutationPayload:
      query: "Query"


+ @strawberry.type
+ class DatasetSplitMutationPayloadWithExamples:
+     dataset_split: DatasetSplit
+     query: "Query"
+     examples: list[DatasetExample]
+
+
  @strawberry.type
  class DeleteDatasetSplitsMutationPayload:
      dataset_splits: list[DatasetSplit]
@@ -77,11 +85,13 @@ class DeleteDatasetSplitsMutationPayload:
  @strawberry.type
  class AddDatasetExamplesToDatasetSplitsMutationPayload:
      query: "Query"
+     examples: list[DatasetExample]


  @strawberry.type
  class RemoveDatasetExamplesFromDatasetSplitsMutationPayload:
      query: "Query"
+     examples: list[DatasetExample]


  @strawberry.type
@@ -262,8 +272,16 @@ class DatasetSplitMutationMixin:
          except (PostgreSQLIntegrityError, SQLiteIntegrityError) as e:
              raise Conflict("Failed to add examples to dataset splits.") from e

+         examples = (
+             await session.scalars(
+                 select(models.DatasetExample).where(
+                     models.DatasetExample.id.in_(example_rowids)
+                 )
+             )
+         ).all()
      return AddDatasetExamplesToDatasetSplitsMutationPayload(
          query=Query(),
+         examples=[to_gql_dataset_example(example) for example in examples],
      )

  @strawberry.mutation(permission_classes=[IsNotReadOnly])  # type: ignore
@@ -314,14 +332,23 @@ class DatasetSplitMutationMixin:

          await session.execute(stmt)

+         examples = (
+             await session.scalars(
+                 select(models.DatasetExample).where(
+                     models.DatasetExample.id.in_(example_rowids)
+                 )
+             )
+         ).all()
+
      return RemoveDatasetExamplesFromDatasetSplitsMutationPayload(
          query=Query(),
+         examples=[to_gql_dataset_example(example) for example in examples],
      )

  @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
  async def create_dataset_split_with_examples(
      self, info: Info[Context, None], input: CreateDatasetSplitWithExamplesInput
- ) -> DatasetSplitMutationPayload:
+ ) -> DatasetSplitMutationPayloadWithExamples:
      user_id = get_user(info)
      validated_name = _validated_name(input.name)
      unique_example_rowids: set[int] = set()
@@ -374,9 +401,18 @@
                  "Failed to associate examples with the new dataset split."
              ) from e

-     return DatasetSplitMutationPayload(
+         examples = (
+             await session.scalars(
+                 select(models.DatasetExample).where(
+                     models.DatasetExample.id.in_(example_rowids)
+                 )
+             )
+         ).all()
+
+     return DatasetSplitMutationPayloadWithExamples(
          dataset_split=to_gql_dataset_split(dataset_split_orm),
          query=Query(),
+         examples=[to_gql_dataset_example(example) for example in examples],
      )

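The payload changes above let split mutations echo the affected examples back to the caller, sparing clients a follow-up query. A hedged sketch of reading the new examples field from createDatasetSplitWithExamples; the input object's fields are not shown in this diff, so the variables are left as a placeholder, and the camelCase names assume strawberry defaults.

import httpx

CREATE_SPLIT_WITH_EXAMPLES = """
mutation ($input: CreateDatasetSplitWithExamplesInput!) {
  createDatasetSplitWithExamples(input: $input) {
    datasetSplit { id }
    examples { id createdAt }
  }
}
"""

input_payload = {}  # split name and example IDs; fields not shown in this diff
response = httpx.post(
    "http://localhost:6006/graphql",  # assumed local Phoenix server
    json={"query": CREATE_SPLIT_WITH_EXAMPLES, "variables": {"input": input_payload}},
)
response.raise_for_status()
print(response.json()["data"]["createDatasetSplitWithExamples"]["examples"])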
phoenix/server/api/queries.py
@@ -48,6 +48,7 @@ from phoenix.server.api.types.AnnotationConfig import AnnotationConfig, to_gql_a
  from phoenix.server.api.types.Cluster import Cluster, to_gql_clusters
  from phoenix.server.api.types.Dataset import Dataset, to_gql_dataset
  from phoenix.server.api.types.DatasetExample import DatasetExample
+ from phoenix.server.api.types.DatasetLabel import DatasetLabel, to_gql_dataset_label
  from phoenix.server.api.types.DatasetSplit import DatasetSplit, to_gql_dataset_split
  from phoenix.server.api.types.Dimension import to_gql_dimension
  from phoenix.server.api.types.EmbeddingDimension import (
@@ -1149,6 +1150,26 @@ class Query:
              args=args,
          )

+     @strawberry.field
+     async def dataset_labels(
+         self,
+         info: Info[Context, None],
+         first: Optional[int] = 50,
+         last: Optional[int] = UNSET,
+         after: Optional[CursorString] = UNSET,
+         before: Optional[CursorString] = UNSET,
+     ) -> Connection[DatasetLabel]:
+         args = ConnectionArgs(
+             first=first,
+             after=after if isinstance(after, CursorString) else None,
+             last=last,
+             before=before if isinstance(before, CursorString) else None,
+         )
+         async with info.context.db() as session:
+             dataset_labels = await session.scalars(select(models.DatasetLabel))
+         data = [to_gql_dataset_label(dataset_label) for dataset_label in dataset_labels]
+         return connection_from_list(data=data, args=args)
+
      @strawberry.field
      async def dataset_splits(
          self,
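
The new datasetLabels field follows the relay connection pattern used elsewhere in the schema, though note the resolver loads every label and paginates in memory via connection_from_list. A minimal first-page fetch, with the local server URL assumed:

import httpx

LIST_DATASET_LABELS = """
query {
  datasetLabels(first: 50) {
    edges { node { name description color } }
    pageInfo { hasNextPage endCursor }
  }
}
"""

response = httpx.post(
    "http://localhost:6006/graphql",  # assumed local Phoenix server
    json={"query": LIST_DATASET_LABELS},
)
response.raise_for_status()
for edge in response.json()["data"]["datasetLabels"]["edges"]:
    label = edge["node"]
    print(label["name"], label["color"])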
phoenix/server/api/routers/auth.py
@@ -2,7 +2,6 @@ import asyncio
  import secrets
  from datetime import datetime, timedelta, timezone
  from functools import partial
- from pathlib import Path
  from urllib.parse import urlencode, urlparse, urlunparse

  from fastapi import APIRouter, Depends, HTTPException, Request, Response
@@ -38,7 +37,6 @@ from phoenix.config import (
      get_base_url,
      get_env_disable_basic_auth,
      get_env_disable_rate_limit,
-     get_env_host_root_path,
  )
  from phoenix.db import models
  from phoenix.server.bearer_auth import PhoenixUser, create_access_and_refresh_tokens
@@ -52,6 +50,7 @@ from phoenix.server.types import (
      TokenStore,
      UserId,
  )
+ from phoenix.server.utils import prepend_root_path

  rate_limiter = ServerRateLimiter(
      per_second_rate_limit=0.2,
@@ -145,7 +144,8 @@ async def logout(
      user_id = subject
      if user_id:
          await token_store.log_out(user_id)
-     redirect_url = "/logout" if get_env_disable_basic_auth() else "/login"
+     redirect_path = "/logout" if get_env_disable_basic_auth() else "/login"
+     redirect_url = prepend_root_path(request.scope, redirect_path)
      response = Response(status_code=HTTP_302_FOUND, headers={"Location": redirect_url})
      response = delete_access_token_cookie(response)
      response = delete_refresh_token_cookie(response)
@@ -242,9 +242,9 @@ async def initiate_password_reset(request: Request) -> Response:
      )
      token, _ = await token_store.create_password_reset_token(password_reset_token_claims)
      url = urlparse(request.headers.get("referer") or get_base_url())
-     path = Path(get_env_host_root_path()) / "reset-password-with-token"
+     path = prepend_root_path(request.scope, "/reset-password-with-token")
      query_string = urlencode(dict(token=token))
-     components = (url.scheme, url.netloc, path.as_posix(), "", query_string, "")
+     components = (url.scheme, url.netloc, path, "", query_string, "")
      reset_url = urlunparse(components)
      await sender.send_password_reset_email(email, reset_url)
      return Response(status_code=HTTP_204_NO_CONTENT)
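
The logout redirect previously hard-coded paths at the server root, and the password-reset link derived its prefix from an environment variable; both now derive it from the ASGI scope via the new prepend_root_path helper, which matters when Phoenix is mounted behind a reverse proxy under a subpath. The body of phoenix/server/utils.py (+74 lines) is not shown in this diff; the sketch below is a plausible reconstruction inferred from these call sites and from the _prepend_root_path_if_exists helper it replaces in oauth2.py, not the actual implementation.

from typing import Any, Mapping


def prepend_root_path(scope: Mapping[str, Any], path: str) -> str:
    """Prefix ``path`` with the ASGI root_path, if one is configured.

    Sketch only: the real implementation in phoenix/server/utils.py
    is not shown in this diff.
    """
    if not path.startswith("/"):
        raise ValueError("path must start with a forward slash")
    root_path = str(scope.get("root_path", "")).rstrip("/")
    return root_path + path


# With e.g. `uvicorn --root-path /phoenix`:
assert prepend_root_path({"root_path": "/phoenix"}, "/login") == "/phoenix/login"
assert prepend_root_path({}, "/login") == "/login"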
phoenix/server/api/routers/oauth2.py
@@ -46,6 +46,7 @@ from phoenix.server.rate_limiters import (
      fastapi_route_rate_limiter,
  )
  from phoenix.server.types import TokenStore
+ from phoenix.server.utils import get_root_path, prepend_root_path

  _LOWERCASE_ALPHANUMS_AND_UNDERSCORES = r"[a-z0-9_]+"

@@ -191,7 +192,7 @@ async def create_tokens(
          access_token_expiry=access_token_expiry,
          refresh_token_expiry=refresh_token_expiry,
      )
-     redirect_path = _prepend_root_path_if_exists(request=request, path=return_url or "/")
+     redirect_path = prepend_root_path(request.scope, return_url or "/")
      response = RedirectResponse(
          url=redirect_path,
          status_code=HTTP_302_FOUND,
@@ -564,8 +565,8 @@ def _redirect_to_login(*, request: Request, error: str) -> RedirectResponse:
      Creates a RedirectResponse to the login page to display an error message.
      """
      # TODO: this needs some cleanup
-     login_path = _prepend_root_path_if_exists(
-         request=request, path="/login" if not get_env_disable_basic_auth() else "/logout"
+     login_path = prepend_root_path(
+         request.scope, "/login" if not get_env_disable_basic_auth() else "/logout"
      )
      url = URL(login_path).include_query_params(error=error)
      response = RedirectResponse(url=url)
@@ -574,34 +575,15 @@ def _redirect_to_login(*, request: Request, error: str) -> RedirectResponse:
      return response


- def _prepend_root_path_if_exists(*, request: Request, path: str) -> str:
-     """
-     If a root path is configured, prepends it to the input path.
-     """
-     if not path.startswith("/"):
-         raise ValueError("path must start with a forward slash")
-     root_path = _get_root_path(request=request)
-     if root_path.endswith("/"):
-         root_path = root_path.rstrip("/")
-     return root_path + path
-
-
  def _append_root_path_if_exists(*, request: Request, base_url: str) -> str:
      """
      If a root path is configured, appends it to the input base url.
      """
-     if not (root_path := _get_root_path(request=request)):
+     if not (root_path := get_root_path(request.scope)):
          return base_url
      return str(URLPath(root_path).make_absolute_url(base_url=base_url))


- def _get_root_path(*, request: Request) -> str:
-     """
-     Gets the root path from the request.
-     """
-     return str(request.scope.get("root_path", ""))
-
-
  def _get_create_tokens_endpoint(*, request: Request, origin_url: str, idp_name: str) -> str:
      """
      Gets the endpoint for create tokens route.
phoenix/server/api/types/Dataset.py
@@ -18,6 +18,7 @@ from phoenix.server.api.types.DatasetExample import DatasetExample
  from phoenix.server.api.types.DatasetExperimentAnnotationSummary import (
      DatasetExperimentAnnotationSummary,
  )
+ from phoenix.server.api.types.DatasetLabel import DatasetLabel, to_gql_dataset_label
  from phoenix.server.api.types.DatasetVersion import DatasetVersion
  from phoenix.server.api.types.Experiment import Experiment, to_gql_experiment
  from phoenix.server.api.types.node import from_global_id_with_expected_type
@@ -303,6 +304,13 @@ class Dataset(Node):
              async for scores_tuple in await session.stream(query)
          ]

+     @strawberry.field
+     async def labels(self, info: Info[Context, None]) -> list[DatasetLabel]:
+         return [
+             to_gql_dataset_label(label)
+             for label in await info.context.data_loaders.dataset_labels.load(self.id_attr)
+         ]
+
      @strawberry.field
      def last_updated_at(self, info: Info[Context, None]) -> Optional[datetime]:
          return info.context.last_updated_at.get(self._table, self.id_attr)
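
Dataset.labels resolves through the new dataset_labels dataloader (added in phoenix/server/api/dataloaders/dataset_labels.py, per the file list), so label lookups are batched across datasets rather than issued one query per row. A sketch of reading it via the relay node field; the GlobalID value is a placeholder and the server URL is assumed:

import httpx

DATASET_LABELS = """
query ($id: ID!) {
  node(id: $id) {
    ... on Dataset {
      labels { name color }
    }
  }
}
"""

response = httpx.post(
    "http://localhost:6006/graphql",  # assumed local Phoenix server
    json={"query": DATASET_LABELS, "variables": {"id": "RGF0YXNldDox"}},  # placeholder GlobalID
)
response.raise_for_status()
print(response.json()["data"]["node"]["labels"])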
phoenix/server/api/types/DatasetExample.py
@@ -142,3 +142,10 @@ class DatasetExample(Node):
              to_gql_dataset_split(split)
              for split in await info.context.data_loaders.dataset_example_splits.load(self.id_attr)
          ]
+
+
+ def to_gql_dataset_example(example: models.DatasetExample) -> DatasetExample:
+     return DatasetExample(
+         id_attr=example.id,
+         created_at=example.created_at,
+     )
phoenix/server/api/types/DatasetLabel.py (new file)
@@ -0,0 +1,23 @@
+ from typing import Optional
+
+ import strawberry
+ from strawberry.relay import Node, NodeID
+
+ from phoenix.db import models
+
+
+ @strawberry.type
+ class DatasetLabel(Node):
+     id_attr: NodeID[int]
+     name: str
+     description: Optional[str]
+     color: str
+
+
+ def to_gql_dataset_label(dataset_label: models.DatasetLabel) -> DatasetLabel:
+     return DatasetLabel(
+         id_attr=dataset_label.id,
+         name=dataset_label.name,
+         description=dataset_label.description,
+         color=dataset_label.color,
+     )
phoenix/server/api/types/Prompt.py
@@ -9,6 +9,7 @@ from strawberry.relay import Connection, GlobalID, Node, NodeID
  from strawberry.types import Info

  from phoenix.db import models
+ from phoenix.db.types.identifier import Identifier as IdentifierModel
  from phoenix.server.api.context import Context
  from phoenix.server.api.exceptions import NotFound
  from phoenix.server.api.types.Identifier import Identifier
@@ -37,7 +38,10 @@ class Prompt(Node):

      @strawberry.field
      async def version(
-         self, info: Info[Context, None], version_id: Optional[GlobalID] = None
+         self,
+         info: Info[Context, None],
+         version_id: Optional[GlobalID] = None,
+         tag_name: Optional[Identifier] = None,
      ) -> PromptVersion:
          async with info.context.db() as session:
              if version_id:
@@ -50,6 +54,19 @@ class Prompt(Node):
                  )
                  if not version:
                      raise NotFound(f"Prompt version not found: {version_id}")
+             elif tag_name:
+                 try:
+                     name = IdentifierModel(tag_name)
+                 except ValueError:
+                     raise NotFound(f"Prompt version tag not found: {tag_name}")
+                 version = await session.scalar(
+                     select(models.PromptVersion)
+                     .where(models.PromptVersion.prompt_id == self.id_attr)
+                     .join_from(models.PromptVersion, models.PromptVersionTag)
+                     .where(models.PromptVersionTag.name == name)
+                 )
+                 if not version:
+                     raise NotFound(f"This prompt has no associated versions by tag {tag_name}")
              else:
                  stmt = (
                      select(models.PromptVersion)
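
The new tag_name argument gives Prompt.version a second lookup path: by tag (for example a "production" tag) instead of by version GlobalID. A hedged sketch of the corresponding query; the Identifier scalar name and the node lookup are assumptions based on the surrounding schema, the IDs are placeholders, and only id is selected because PromptVersion's fields are not shown in this diff.

import httpx

PROMPT_VERSION_BY_TAG = """
query ($promptId: ID!, $tag: Identifier!) {
  node(id: $promptId) {
    ... on Prompt {
      version(tagName: $tag) { id }
    }
  }
}
"""

response = httpx.post(
    "http://localhost:6006/graphql",  # assumed local Phoenix server
    json={
        "query": PROMPT_VERSION_BY_TAG,
        "variables": {"promptId": "UHJvbXB0OjE=", "tag": "production"},  # placeholders
    },
)
response.raise_for_status()
print(response.json())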