arize-phoenix 4.14.1__py3-none-any.whl → 4.16.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of arize-phoenix has been flagged as possibly problematic.

Files changed (85)
  1. {arize_phoenix-4.14.1.dist-info → arize_phoenix-4.16.0.dist-info}/METADATA +5 -3
  2. {arize_phoenix-4.14.1.dist-info → arize_phoenix-4.16.0.dist-info}/RECORD +81 -71
  3. phoenix/db/bulk_inserter.py +131 -5
  4. phoenix/db/engines.py +2 -1
  5. phoenix/db/helpers.py +23 -1
  6. phoenix/db/insertion/constants.py +2 -0
  7. phoenix/db/insertion/document_annotation.py +157 -0
  8. phoenix/db/insertion/helpers.py +13 -0
  9. phoenix/db/insertion/span_annotation.py +144 -0
  10. phoenix/db/insertion/trace_annotation.py +144 -0
  11. phoenix/db/insertion/types.py +261 -0
  12. phoenix/experiments/functions.py +3 -2
  13. phoenix/experiments/types.py +3 -3
  14. phoenix/server/api/context.py +7 -9
  15. phoenix/server/api/dataloaders/__init__.py +2 -0
  16. phoenix/server/api/dataloaders/average_experiment_run_latency.py +3 -3
  17. phoenix/server/api/dataloaders/dataset_example_revisions.py +2 -4
  18. phoenix/server/api/dataloaders/dataset_example_spans.py +2 -4
  19. phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -4
  20. phoenix/server/api/dataloaders/document_evaluations.py +2 -4
  21. phoenix/server/api/dataloaders/document_retrieval_metrics.py +2 -4
  22. phoenix/server/api/dataloaders/evaluation_summaries.py +2 -4
  23. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +2 -4
  24. phoenix/server/api/dataloaders/experiment_error_rates.py +2 -4
  25. phoenix/server/api/dataloaders/experiment_run_counts.py +2 -4
  26. phoenix/server/api/dataloaders/experiment_sequence_number.py +2 -4
  27. phoenix/server/api/dataloaders/latency_ms_quantile.py +2 -3
  28. phoenix/server/api/dataloaders/min_start_or_max_end_times.py +2 -4
  29. phoenix/server/api/dataloaders/project_by_name.py +3 -3
  30. phoenix/server/api/dataloaders/record_counts.py +2 -4
  31. phoenix/server/api/dataloaders/span_annotations.py +2 -4
  32. phoenix/server/api/dataloaders/span_dataset_examples.py +36 -0
  33. phoenix/server/api/dataloaders/span_descendants.py +2 -4
  34. phoenix/server/api/dataloaders/span_evaluations.py +2 -4
  35. phoenix/server/api/dataloaders/span_projects.py +3 -3
  36. phoenix/server/api/dataloaders/token_counts.py +2 -4
  37. phoenix/server/api/dataloaders/trace_evaluations.py +2 -4
  38. phoenix/server/api/dataloaders/trace_row_ids.py +2 -4
  39. phoenix/server/api/input_types/SpanAnnotationSort.py +17 -0
  40. phoenix/server/api/input_types/TraceAnnotationSort.py +17 -0
  41. phoenix/server/api/mutations/span_annotations_mutations.py +8 -3
  42. phoenix/server/api/mutations/trace_annotations_mutations.py +8 -3
  43. phoenix/server/api/openapi/main.py +18 -2
  44. phoenix/server/api/openapi/schema.py +12 -12
  45. phoenix/server/api/routers/v1/__init__.py +36 -83
  46. phoenix/server/api/routers/v1/datasets.py +515 -509
  47. phoenix/server/api/routers/v1/evaluations.py +164 -73
  48. phoenix/server/api/routers/v1/experiment_evaluations.py +68 -91
  49. phoenix/server/api/routers/v1/experiment_runs.py +98 -155
  50. phoenix/server/api/routers/v1/experiments.py +132 -181
  51. phoenix/server/api/routers/v1/pydantic_compat.py +78 -0
  52. phoenix/server/api/routers/v1/spans.py +164 -203
  53. phoenix/server/api/routers/v1/traces.py +134 -159
  54. phoenix/server/api/routers/v1/utils.py +95 -0
  55. phoenix/server/api/types/Span.py +27 -3
  56. phoenix/server/api/types/Trace.py +21 -4
  57. phoenix/server/api/utils.py +4 -4
  58. phoenix/server/app.py +172 -192
  59. phoenix/server/grpc_server.py +2 -2
  60. phoenix/server/main.py +5 -9
  61. phoenix/server/static/.vite/manifest.json +31 -31
  62. phoenix/server/static/assets/components-Ci5kMOk5.js +1175 -0
  63. phoenix/server/static/assets/{index-CQgXRwU0.js → index-BQG5WVX7.js} +2 -2
  64. phoenix/server/static/assets/{pages-hdjlFZhO.js → pages-BrevprVW.js} +451 -275
  65. phoenix/server/static/assets/{vendor-DPvSDRn3.js → vendor-CP0b0YG0.js} +2 -2
  66. phoenix/server/static/assets/{vendor-arizeai-CkvPT67c.js → vendor-arizeai-DTbiPGp6.js} +27 -27
  67. phoenix/server/static/assets/vendor-codemirror-DtdPDzrv.js +15 -0
  68. phoenix/server/static/assets/{vendor-recharts-5jlNaZuF.js → vendor-recharts-A0DA1O99.js} +1 -1
  69. phoenix/server/thread_server.py +2 -2
  70. phoenix/server/types.py +18 -0
  71. phoenix/session/client.py +5 -3
  72. phoenix/session/session.py +2 -2
  73. phoenix/trace/dsl/filter.py +2 -6
  74. phoenix/trace/fixtures.py +17 -23
  75. phoenix/trace/utils.py +23 -0
  76. phoenix/utilities/client.py +116 -0
  77. phoenix/utilities/project.py +1 -1
  78. phoenix/version.py +1 -1
  79. phoenix/server/api/routers/v1/dataset_examples.py +0 -178
  80. phoenix/server/openapi/docs.py +0 -221
  81. phoenix/server/static/assets/components-DeS0YEmv.js +0 -1142
  82. phoenix/server/static/assets/vendor-codemirror-Cqwpwlua.js +0 -12
  83. {arize_phoenix-4.14.1.dist-info → arize_phoenix-4.16.0.dist-info}/WHEEL +0 -0
  84. {arize_phoenix-4.14.1.dist-info → arize_phoenix-4.16.0.dist-info}/licenses/IP_NOTICE +0 -0
  85. {arize_phoenix-4.14.1.dist-info → arize_phoenix-4.16.0.dist-info}/licenses/LICENSE +0 -0
phoenix/db/insertion/types.py (new file)

@@ -0,0 +1,261 @@
+ from __future__ import annotations
+
+ import asyncio
+ import logging
+ from abc import ABC, abstractmethod
+ from copy import copy
+ from dataclasses import dataclass, field
+ from datetime import datetime, timezone
+ from typing import (
+     Any,
+     Generic,
+     List,
+     Mapping,
+     Optional,
+     Protocol,
+     Sequence,
+     Tuple,
+     Type,
+     TypeVar,
+     cast,
+ )
+
+ from sqlalchemy.ext.asyncio import AsyncSession
+ from sqlalchemy.sql.dml import ReturningInsert
+
+ from phoenix.db import models
+ from phoenix.db.insertion.constants import DEFAULT_RETRY_ALLOWANCE, DEFAULT_RETRY_DELAY_SEC
+ from phoenix.db.insertion.helpers import as_kv, insert_on_conflict
+ from phoenix.server.types import DbSessionFactory
+
+ logger = logging.getLogger(__name__)
+
+
+ class Insertable(Protocol):
+     @property
+     def row(self) -> models.Base: ...
+
+
+ _AnyT = TypeVar("_AnyT")
+ _PrecursorT = TypeVar("_PrecursorT")
+ _InsertableT = TypeVar("_InsertableT", bound=Insertable)
+ _RowT = TypeVar("_RowT", bound=models.Base)
+
+
+ @dataclass(frozen=True)
+ class Received(Generic[_AnyT]):
+     item: _AnyT
+     received_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
+
+     def postpone(self, retries_left: int = DEFAULT_RETRY_ALLOWANCE) -> Postponed[_AnyT]:
+         return Postponed(item=self.item, received_at=self.received_at, retries_left=retries_left)
+
+
+ @dataclass(frozen=True)
+ class Postponed(Received[_AnyT]):
+     retries_left: int = field(default=DEFAULT_RETRY_ALLOWANCE)
+
+
+ class QueueInserter(ABC, Generic[_PrecursorT, _InsertableT, _RowT]):
+     table: Type[_RowT]
+     unique_by: Sequence[str]
+
+     def __init_subclass__(
+         cls,
+         table: Type[_RowT],
+         unique_by: Sequence[str],
+     ) -> None:
+         cls.table = table
+         cls.unique_by = unique_by
+
+     def __init__(
+         self,
+         db: DbSessionFactory,
+         retry_delay_sec: float = DEFAULT_RETRY_DELAY_SEC,
+         retry_allowance: int = DEFAULT_RETRY_ALLOWANCE,
+     ) -> None:
+         self._queue: List[Received[_PrecursorT]] = []
+         self._db = db
+         self._retry_delay_sec = retry_delay_sec
+         self._retry_allowance = retry_allowance
+
+     @property
+     def empty(self) -> bool:
+         return not bool(self._queue)
+
+     async def enqueue(self, *items: _PrecursorT) -> None:
+         self._queue.extend([Received(item) for item in items])
+
+     @abstractmethod
+     async def _partition(
+         self,
+         session: AsyncSession,
+         *parcels: Received[_PrecursorT],
+     ) -> Tuple[
+         List[Received[_InsertableT]],
+         List[Postponed[_PrecursorT]],
+         List[Received[_PrecursorT]],
+     ]: ...
+
+     async def insert(self) -> Tuple[Type[_RowT], List[int]]:
+         if not self._queue:
+             return self.table, []
+         parcels = self._queue
+         self._queue = []
+         inserted_ids: List[int] = []
+         async with self._db() as session:
+             to_insert, to_postpone, _ = await self._partition(session, *parcels)
+             if to_insert:
+                 inserted_ids, to_retry, _ = await self._insert(session, *to_insert)
+                 to_postpone.extend(to_retry)
+         if to_postpone:
+             loop = asyncio.get_running_loop()
+             loop.call_later(self._retry_delay_sec, self._queue.extend, to_postpone)
+         return self.table, inserted_ids
+
+     def _stmt(self, *records: Mapping[str, Any]) -> ReturningInsert[Tuple[int]]:
+         pk = next(c for c in self.table.__table__.c if c.primary_key)
+         return insert_on_conflict(
+             *records,
+             table=self.table,
+             unique_by=self.unique_by,
+             dialect=self._db.dialect,
+         ).returning(pk)
+
+     async def _insert(
+         self,
+         session: AsyncSession,
+         *insertions: Received[_InsertableT],
+     ) -> Tuple[List[int], List[Postponed[_PrecursorT]], List[Received[_InsertableT]]]:
+         records = [dict(as_kv(ins.item.row)) for ins in insertions]
+         inserted_ids: List[int] = []
+         to_retry: List[Postponed[_PrecursorT]] = []
+         failures: List[Received[_InsertableT]] = []
+         stmt = self._stmt(*records)
+         try:
+             async with session.begin_nested():
+                 ids = [id_ async for id_ in await session.stream_scalars(stmt)]
+                 inserted_ids.extend(ids)
+         except BaseException:
+             logger.exception(
+                 f"Failed to bulk insert for {self.table.__name__}. "
+                 f"Will try to insert ({len(records)} records) individually instead."
+             )
+             for i, record in enumerate(records):
+                 stmt = self._stmt(record)
+                 try:
+                     async with session.begin_nested():
+                         ids = [id_ async for id_ in await session.stream_scalars(stmt)]
+                         inserted_ids.extend(ids)
+                 except BaseException:
+                     logger.exception(f"Failed to insert for {self.table.__name__}.")
+                     p = insertions[i]
+                     if isinstance(p, Postponed) and p.retries_left == 1:
+                         failures.append(p)
+                     else:
+                         to_retry.append(
+                             Postponed(
+                                 item=cast(_PrecursorT, p.item),
+                                 received_at=p.received_at,
+                                 retries_left=(p.retries_left - 1)
+                                 if isinstance(p, Postponed)
+                                 else self._retry_allowance,
+                             )
+                         )
+         return inserted_ids, to_retry, failures
+
+
+ class Precursors(ABC):
+     @dataclass(frozen=True)
+     class SpanAnnotation:
+         span_id: str
+         obj: models.SpanAnnotation
+
+         def as_insertable(
+             self,
+             span_rowid: int,
+             id_: Optional[int] = None,
+         ) -> Insertables.SpanAnnotation:
+             return Insertables.SpanAnnotation(
+                 span_id=self.span_id,
+                 obj=self.obj,
+                 span_rowid=span_rowid,
+                 id_=id_,
+             )
+
+     @dataclass(frozen=True)
+     class TraceAnnotation:
+         trace_id: str
+         obj: models.TraceAnnotation
+
+         def as_insertable(
+             self,
+             trace_rowid: int,
+             id_: Optional[int] = None,
+         ) -> Insertables.TraceAnnotation:
+             return Insertables.TraceAnnotation(
+                 trace_id=self.trace_id,
+                 obj=self.obj,
+                 trace_rowid=trace_rowid,
+                 id_=id_,
+             )
+
+     @dataclass(frozen=True)
+     class DocumentAnnotation:
+         span_id: str
+         document_position: int
+         obj: models.DocumentAnnotation
+
+         def as_insertable(
+             self,
+             span_rowid: int,
+             id_: Optional[int] = None,
+         ) -> Insertables.DocumentAnnotation:
+             return Insertables.DocumentAnnotation(
+                 span_id=self.span_id,
+                 document_position=self.document_position,
+                 obj=self.obj,
+                 span_rowid=span_rowid,
+                 id_=id_,
+             )
+
+
+ class Insertables(ABC):
+     @dataclass(frozen=True)
+     class SpanAnnotation(Precursors.SpanAnnotation):
+         span_rowid: int
+         id_: Optional[int] = None
+
+         @property
+         def row(self) -> models.SpanAnnotation:
+             obj = copy(self.obj)
+             obj.span_rowid = self.span_rowid
+             if self.id_ is not None:
+                 obj.id = self.id_
+             return obj
+
+     @dataclass(frozen=True)
+     class TraceAnnotation(Precursors.TraceAnnotation):
+         trace_rowid: int
+         id_: Optional[int] = None
+
+         @property
+         def row(self) -> models.TraceAnnotation:
+             obj = copy(self.obj)
+             obj.trace_rowid = self.trace_rowid
+             if self.id_ is not None:
+                 obj.id = self.id_
+             return obj
+
+     @dataclass(frozen=True)
+     class DocumentAnnotation(Precursors.DocumentAnnotation):
+         span_rowid: int
+         id_: Optional[int] = None
+
+         @property
+         def row(self) -> models.DocumentAnnotation:
+             obj = copy(self.obj)
+             obj.span_rowid = self.span_rowid
+             if self.id_ is not None:
+                 obj.id = self.id_
+             return obj
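
The new module above is a retry-aware queue for bulk inserts: precursors wait as Received/Postponed parcels until their foreign keys can be resolved, then insert() writes them via insert_on_conflict, falling back to row-by-row inserts when a batch fails. The concrete subclasses live in the also-added span_annotation.py, trace_annotation.py, and document_annotation.py. As a rough illustration of how a subclass plugs in — the unique_by columns and the one-row-at-a-time lookup below are assumptions for clarity, not the actual phoenix code:

from typing import List, Tuple

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from phoenix.db import models
from phoenix.db.insertion.types import (
    Insertables,
    Postponed,
    Precursors,
    QueueInserter,
    Received,
)


class SpanAnnotationQueueInserter(
    QueueInserter[
        Precursors.SpanAnnotation,
        Insertables.SpanAnnotation,
        models.SpanAnnotation,
    ],
    table=models.SpanAnnotation,
    unique_by=("name", "span_rowid"),  # assumed unique constraint
):
    async def _partition(
        self,
        session: AsyncSession,
        *parcels: Received[Precursors.SpanAnnotation],
    ) -> Tuple[
        List[Received[Insertables.SpanAnnotation]],
        List[Postponed[Precursors.SpanAnnotation]],
        List[Received[Precursors.SpanAnnotation]],
    ]:
        to_insert: List[Received[Insertables.SpanAnnotation]] = []
        to_postpone: List[Postponed[Precursors.SpanAnnotation]] = []
        failures: List[Received[Precursors.SpanAnnotation]] = []
        for parcel in parcels:
            # Resolve the OTLP span id to its database rowid; an annotation
            # may arrive before the span it annotates does.
            span_rowid = await session.scalar(
                select(models.Span.id).where(models.Span.span_id == parcel.item.span_id)
            )
            if span_rowid is not None:
                to_insert.append(
                    Received(parcel.item.as_insertable(span_rowid), parcel.received_at)
                )
            elif isinstance(parcel, Postponed) and parcel.retries_left <= 1:
                failures.append(parcel)  # the span never showed up; give up
            else:
                retries = (
                    parcel.retries_left - 1
                    if isinstance(parcel, Postponed)
                    else self._retry_allowance
                )
                to_postpone.append(parcel.postpone(retries))
        return to_insert, to_postpone, failures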
phoenix/experiments/functions.py

@@ -72,15 +72,16 @@ from phoenix.experiments.types import (
  )
  from phoenix.experiments.utils import get_dataset_experiments_url, get_experiment_url, get_func_name
  from phoenix.trace.attributes import flatten
+ from phoenix.utilities.client import VersionedAsyncClient, VersionedClient
  from phoenix.utilities.json import jsonify


  def _phoenix_clients() -> Tuple[httpx.Client, httpx.AsyncClient]:
      headers = get_env_client_headers()
-     return httpx.Client(
+     return VersionedClient(
          base_url=get_base_url(),
          headers=headers,
-     ), httpx.AsyncClient(
+     ), VersionedAsyncClient(
          base_url=get_base_url(),
          headers=headers,
      )
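
_phoenix_clients now returns the versioned client classes from the new phoenix/utilities/client.py (+116 lines); the return annotation stays valid because they evidently subclass httpx.Client and httpx.AsyncClient. What "versioned" means is not visible in this hunk; a minimal sketch, assuming the clients simply stamp each request with the installed package version (the header name here is hypothetical):

import httpx

from phoenix.version import __version__


class VersionedClient(httpx.Client):
    # Hypothetical sketch: advertise the caller's version so the server can
    # detect client/server mismatches. The real implementation is in
    # phoenix/utilities/client.py.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.headers["x-phoenix-client-version"] = __version__  # header name assumed


class VersionedAsyncClient(httpx.AsyncClient):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.headers["x-phoenix-client-version"] = __version__  # header name assumed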
phoenix/experiments/types.py

@@ -226,7 +226,7 @@ class ExperimentRun:

      def __post_init__(self) -> None:
          if bool(self.output) == bool(self.error):
-             ValueError("Must specify exactly one of experiment_run_output or error")
+             raise ValueError("Must specify exactly one of experiment_run_output or error")


  @dataclass(frozen=True)
@@ -249,7 +249,7 @@ class EvaluationResult:

      def __post_init__(self) -> None:
          if self.score is None and not self.label:
-             ValueError("Must specify score or label, or both")
+             raise ValueError("Must specify score or label, or both")
          if self.score is None and not self.label:
              object.__setattr__(self, "score", 0)
          for k in ("label", "explanation"):
@@ -285,7 +285,7 @@ class ExperimentEvaluationRun:

      def __post_init__(self) -> None:
          if bool(self.result) == bool(self.error):
-             ValueError("Must specify either result or error")
+             raise ValueError("Must specify either result or error")


  ExperimentTask: TypeAlias = Union[
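
All three __post_init__ hunks fix the same bug class: the exception was constructed but never raised, so the statement was a no-op and invalid objects passed validation silently. In miniature:

from dataclasses import dataclass
from typing import Any, Optional


@dataclass(frozen=True)
class Run:
    output: Optional[Any] = None
    error: Optional[str] = None

    def __post_init__(self) -> None:
        if bool(self.output) == bool(self.error):
            # 4.14.1 behavior: the line below built a ValueError and threw it
            # away, so an invalid Run was silently accepted.
            # ValueError("Must specify exactly one of output or error")
            # 4.16.0 behavior: the exception is actually raised.
            raise ValueError("Must specify exactly one of output or error")


Run(output={"answer": 42})  # ok
# Run() would now raise ValueError instead of passing validation.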
phoenix/server/api/context.py

@@ -1,12 +1,9 @@
  from dataclasses import dataclass
  from datetime import datetime
  from pathlib import Path
- from typing import AsyncContextManager, Callable, Optional, Union
+ from typing import Callable, Optional

- from sqlalchemy.ext.asyncio import AsyncSession
- from starlette.requests import Request
- from starlette.responses import Response
- from starlette.websockets import WebSocket
+ from strawberry.fastapi import BaseContext
  from typing_extensions import TypeAlias

  from phoenix.core.model_schema import Model
@@ -28,6 +25,7 @@ from phoenix.server.api.dataloaders import (
      ProjectByNameDataLoader,
      RecordCountDataLoader,
      SpanAnnotationsDataLoader,
+     SpanDatasetExamplesDataLoader,
      SpanDescendantsDataLoader,
      SpanEvaluationsDataLoader,
      SpanProjectsDataLoader,
@@ -35,6 +33,7 @@ from phoenix.server.api.dataloaders import (
      TraceEvaluationsDataLoader,
      TraceRowIdsDataLoader,
  )
+ from phoenix.server.types import DbSessionFactory


  @dataclass
@@ -53,6 +52,7 @@ class DataLoaders:
      latency_ms_quantile: LatencyMsQuantileDataLoader
      min_start_or_max_end_times: MinStartOrMaxEndTimeDataLoader
      record_counts: RecordCountDataLoader
+     span_dataset_examples: SpanDatasetExamplesDataLoader
      span_descendants: SpanDescendantsDataLoader
      span_evaluations: SpanEvaluationsDataLoader
      span_projects: SpanProjectsDataLoader
@@ -67,10 +67,8 @@ ProjectRowId: TypeAlias = int


  @dataclass
- class Context:
-     request: Union[Request, WebSocket]
-     response: Optional[Response]
-     db: Callable[[], AsyncContextManager[AsyncSession]]
+ class Context(BaseContext):
+     db: DbSessionFactory
      data_loaders: DataLoaders
      cache_for_dataloaders: Optional[CacheForDataLoaders]
      model: Model
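
Context now extends strawberry's BaseContext (which carries request/response itself) instead of storing them, and the ad-hoc Callable[[], AsyncContextManager[AsyncSession]] annotation is replaced throughout by the new DbSessionFactory from phoenix/server/types.py (+18 lines). That file's hunk is not shown here; judging from the call sites in this diff, the factory must be callable like the old type and also expose the SQL dialect (QueueInserter._stmt reads self._db.dialect). A sketch under those assumptions, not the actual definition:

from typing import AsyncContextManager, Callable

from sqlalchemy.ext.asyncio import AsyncSession


class DbSessionFactory:
    # Assumed shape, inferred from call sites in this diff; the real
    # definition lives in phoenix/server/types.py.
    def __init__(
        self,
        db: Callable[[], AsyncContextManager[AsyncSession]],
        dialect: str,  # e.g. "sqlite" or "postgresql" (representation assumed)
    ) -> None:
        self._db = db
        self.dialect = dialect

    def __call__(self) -> AsyncContextManager[AsyncSession]:
        return self._db()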
phoenix/server/api/dataloaders/__init__.py

@@ -27,6 +27,7 @@ from .min_start_or_max_end_times import MinStartOrMaxEndTimeCache, MinStartOrMax
  from .project_by_name import ProjectByNameDataLoader
  from .record_counts import RecordCountCache, RecordCountDataLoader
  from .span_annotations import SpanAnnotationsDataLoader
+ from .span_dataset_examples import SpanDatasetExamplesDataLoader
  from .span_descendants import SpanDescendantsDataLoader
  from .span_evaluations import SpanEvaluationsDataLoader
  from .span_projects import SpanProjectsDataLoader
@@ -50,6 +51,7 @@ __all__ = [
      "LatencyMsQuantileDataLoader",
      "MinStartOrMaxEndTimeDataLoader",
      "RecordCountDataLoader",
+     "SpanDatasetExamplesDataLoader",
      "SpanDescendantsDataLoader",
      "SpanEvaluationsDataLoader",
      "SpanProjectsDataLoader",
phoenix/server/api/dataloaders/average_experiment_run_latency.py

@@ -1,11 +1,11 @@
- from typing import AsyncContextManager, Callable, List, Optional
+ from typing import List, Optional

  from sqlalchemy import func, select
- from sqlalchemy.ext.asyncio import AsyncSession
  from strawberry.dataloader import DataLoader
  from typing_extensions import TypeAlias

  from phoenix.db import models
+ from phoenix.server.types import DbSessionFactory

  ExperimentID: TypeAlias = int
  RunLatency: TypeAlias = Optional[float]
@@ -16,7 +16,7 @@ Result: TypeAlias = RunLatency
  class AverageExperimentRunLatencyDataLoader(DataLoader[Key, Result]):
      def __init__(
          self,
-         db: Callable[[], AsyncContextManager[AsyncSession]],
+         db: DbSessionFactory,
      ) -> None:
          super().__init__(load_fn=self._load_fn)
          self._db = db
phoenix/server/api/dataloaders/dataset_example_revisions.py

@@ -1,6 +1,4 @@
  from typing import (
-     AsyncContextManager,
-     Callable,
      List,
      Optional,
      Tuple,
@@ -8,12 +6,12 @@ from typing import (
  )

  from sqlalchemy import Integer, case, func, literal, or_, select, union
- from sqlalchemy.ext.asyncio import AsyncSession
  from strawberry.dataloader import DataLoader
  from typing_extensions import TypeAlias

  from phoenix.db import models
  from phoenix.server.api.types.DatasetExampleRevision import DatasetExampleRevision
+ from phoenix.server.types import DbSessionFactory

  ExampleID: TypeAlias = int
  VersionID: TypeAlias = Optional[int]
@@ -22,7 +20,7 @@ Result: TypeAlias = DatasetExampleRevision


  class DatasetExampleRevisionsDataLoader(DataLoader[Key, Result]):
-     def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
+     def __init__(self, db: DbSessionFactory) -> None:
          super().__init__(load_fn=self._load_fn)
          self._db = db

phoenix/server/api/dataloaders/dataset_example_spans.py

@@ -1,17 +1,15 @@
  from typing import (
-     AsyncContextManager,
-     Callable,
      List,
      Optional,
  )

  from sqlalchemy import select
- from sqlalchemy.ext.asyncio import AsyncSession
  from sqlalchemy.orm import joinedload
  from strawberry.dataloader import DataLoader
  from typing_extensions import TypeAlias

  from phoenix.db import models
+ from phoenix.server.types import DbSessionFactory

  ExampleID: TypeAlias = int
  Key: TypeAlias = ExampleID
@@ -19,7 +17,7 @@ Result: TypeAlias = Optional[models.Span]


  class DatasetExampleSpansDataLoader(DataLoader[Key, Result]):
-     def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
+     def __init__(self, db: DbSessionFactory) -> None:
          super().__init__(load_fn=self._load_fn)
          self._db = db

phoenix/server/api/dataloaders/document_evaluation_summaries.py

@@ -2,8 +2,6 @@ from collections import defaultdict
  from datetime import datetime
  from typing import (
      Any,
-     AsyncContextManager,
-     Callable,
      DefaultDict,
      List,
      Optional,
@@ -14,7 +12,6 @@ import numpy as np
  from aioitertools.itertools import groupby
  from cachetools import LFUCache, TTLCache
  from sqlalchemy import Select, select
- from sqlalchemy.ext.asyncio import AsyncSession
  from strawberry.dataloader import AbstractCache, DataLoader
  from typing_extensions import TypeAlias

@@ -24,6 +21,7 @@ from phoenix.metrics.retrieval_metrics import RetrievalMetrics
  from phoenix.server.api.dataloaders.cache import TwoTierCache
  from phoenix.server.api.input_types.TimeRange import TimeRange
  from phoenix.server.api.types.DocumentEvaluationSummary import DocumentEvaluationSummary
+ from phoenix.server.types import DbSessionFactory
  from phoenix.trace.dsl import SpanFilter

  ProjectRowId: TypeAlias = int
@@ -77,7 +75,7 @@ class DocumentEvaluationSummaryCache(
  class DocumentEvaluationSummaryDataLoader(DataLoader[Key, Result]):
      def __init__(
          self,
-         db: Callable[[], AsyncContextManager[AsyncSession]],
+         db: DbSessionFactory,
          cache_map: Optional[AbstractCache[Key, Result]] = None,
      ) -> None:
          super().__init__(
phoenix/server/api/dataloaders/document_evaluations.py

@@ -1,25 +1,23 @@
  from collections import defaultdict
  from typing import (
-     AsyncContextManager,
-     Callable,
      DefaultDict,
      List,
  )

  from sqlalchemy import select
- from sqlalchemy.ext.asyncio import AsyncSession
  from strawberry.dataloader import DataLoader
  from typing_extensions import TypeAlias

  from phoenix.db import models
  from phoenix.server.api.types.Evaluation import DocumentEvaluation
+ from phoenix.server.types import DbSessionFactory

  Key: TypeAlias = int
  Result: TypeAlias = List[DocumentEvaluation]


  class DocumentEvaluationsDataLoader(DataLoader[Key, Result]):
-     def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
+     def __init__(self, db: DbSessionFactory) -> None:
          super().__init__(load_fn=self._load_fn)
          self._db = db

phoenix/server/api/dataloaders/document_retrieval_metrics.py

@@ -1,7 +1,5 @@
  from collections import defaultdict
  from typing import (
-     AsyncContextManager,
-     Callable,
      DefaultDict,
      Dict,
      List,
@@ -13,13 +11,13 @@ from typing import (
  import numpy as np
  from aioitertools.itertools import groupby
  from sqlalchemy import select
- from sqlalchemy.ext.asyncio import AsyncSession
  from strawberry.dataloader import DataLoader
  from typing_extensions import TypeAlias

  from phoenix.db import models
  from phoenix.metrics.retrieval_metrics import RetrievalMetrics
  from phoenix.server.api.types.DocumentRetrievalMetrics import DocumentRetrievalMetrics
+ from phoenix.server.types import DbSessionFactory

  RowId: TypeAlias = int
  NumDocs: TypeAlias = int
@@ -30,7 +28,7 @@ Result: TypeAlias = List[DocumentRetrievalMetrics]


  class DocumentRetrievalMetricsDataLoader(DataLoader[Key, Result]):
-     def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
+     def __init__(self, db: DbSessionFactory) -> None:
          super().__init__(load_fn=self._load_fn)
          self._db = db

phoenix/server/api/dataloaders/evaluation_summaries.py

@@ -2,8 +2,6 @@ from collections import defaultdict
  from datetime import datetime
  from typing import (
      Any,
-     AsyncContextManager,
-     Callable,
      DefaultDict,
      List,
      Literal,
@@ -15,7 +13,6 @@ import pandas as pd
  from aioitertools.itertools import groupby
  from cachetools import LFUCache, TTLCache
  from sqlalchemy import Select, func, or_, select
- from sqlalchemy.ext.asyncio import AsyncSession
  from strawberry.dataloader import AbstractCache, DataLoader
  from typing_extensions import TypeAlias, assert_never

@@ -23,6 +20,7 @@ from phoenix.db import models
  from phoenix.server.api.dataloaders.cache import TwoTierCache
  from phoenix.server.api.input_types.TimeRange import TimeRange
  from phoenix.server.api.types.EvaluationSummary import EvaluationSummary
+ from phoenix.server.types import DbSessionFactory
  from phoenix.trace.dsl import SpanFilter

  Kind: TypeAlias = Literal["span", "trace"]
@@ -77,7 +75,7 @@ class EvaluationSummaryCache(
  class EvaluationSummaryDataLoader(DataLoader[Key, Result]):
      def __init__(
          self,
-         db: Callable[[], AsyncContextManager[AsyncSession]],
+         db: DbSessionFactory,
          cache_map: Optional[AbstractCache[Key, Result]] = None,
      ) -> None:
          super().__init__(
phoenix/server/api/dataloaders/experiment_annotation_summaries.py

@@ -1,19 +1,17 @@
  from collections import defaultdict
  from dataclasses import dataclass
  from typing import (
-     AsyncContextManager,
-     Callable,
      DefaultDict,
      List,
      Optional,
  )

  from sqlalchemy import func, select
- from sqlalchemy.ext.asyncio import AsyncSession
  from strawberry.dataloader import AbstractCache, DataLoader
  from typing_extensions import TypeAlias

  from phoenix.db import models
+ from phoenix.server.types import DbSessionFactory


  @dataclass
@@ -34,7 +32,7 @@ Result: TypeAlias = List[ExperimentAnnotationSummary]
  class ExperimentAnnotationSummaryDataLoader(DataLoader[Key, Result]):
      def __init__(
          self,
-         db: Callable[[], AsyncContextManager[AsyncSession]],
+         db: DbSessionFactory,
          cache_map: Optional[AbstractCache[Key, Result]] = None,
      ) -> None:
          super().__init__(load_fn=self._load_fn)
phoenix/server/api/dataloaders/experiment_error_rates.py

@@ -1,16 +1,14 @@
  from typing import (
-     AsyncContextManager,
-     Callable,
      List,
      Optional,
  )

  from sqlalchemy import case, func, select
- from sqlalchemy.ext.asyncio import AsyncSession
  from strawberry.dataloader import DataLoader
  from typing_extensions import TypeAlias

  from phoenix.db import models
+ from phoenix.server.types import DbSessionFactory

  ExperimentID: TypeAlias = int
  ErrorRate: TypeAlias = float
@@ -21,7 +19,7 @@ Result: TypeAlias = Optional[ErrorRate]
  class ExperimentErrorRatesDataLoader(DataLoader[Key, Result]):
      def __init__(
          self,
-         db: Callable[[], AsyncContextManager[AsyncSession]],
+         db: DbSessionFactory,
      ) -> None:
          super().__init__(load_fn=self._load_fn)
          self._db = db
phoenix/server/api/dataloaders/experiment_run_counts.py

@@ -1,15 +1,13 @@
  from typing import (
-     AsyncContextManager,
-     Callable,
      List,
  )

  from sqlalchemy import func, select
- from sqlalchemy.ext.asyncio import AsyncSession
  from strawberry.dataloader import DataLoader
  from typing_extensions import TypeAlias

  from phoenix.db import models
+ from phoenix.server.types import DbSessionFactory

  ExperimentID: TypeAlias = int
  RunCount: TypeAlias = int
@@ -20,7 +18,7 @@ Result: TypeAlias = RunCount
  class ExperimentRunCountsDataLoader(DataLoader[Key, Result]):
      def __init__(
          self,
-         db: Callable[[], AsyncContextManager[AsyncSession]],
+         db: DbSessionFactory,
      ) -> None:
          super().__init__(load_fn=self._load_fn)
          self._db = db
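
The dataloader hunks in this section are all the same mechanical change: the Callable[[], AsyncContextManager[AsyncSession]] annotation gives way to DbSessionFactory, and the now-unused imports are dropped. The _load_fn bodies are untouched by this diff; for context, a loader in this family batches keys roughly like the following sketch (table and column names assumed, not taken from this diff):

from typing import List

from sqlalchemy import func, select
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

from phoenix.db import models
from phoenix.server.types import DbSessionFactory

Key: TypeAlias = int  # experiment rowid
Result: TypeAlias = int  # run count


class ExperimentRunCountsDataLoader(DataLoader[Key, Result]):
    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: List[Key]) -> List[Result]:
        # One query for the whole batch of keys, then fan the results back
        # out in the order the keys were requested.
        stmt = (
            select(models.ExperimentRun.experiment_id, func.count())
            .where(models.ExperimentRun.experiment_id.in_(keys))
            .group_by(models.ExperimentRun.experiment_id)
        )
        async with self._db() as session:
            counts = dict((await session.execute(stmt)).all())
        return [counts.get(key, 0) for key in keys]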