arize-phoenix 4.19.0__py3-none-any.whl → 4.20.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic; see the details below.

Files changed (35)
  1. {arize_phoenix-4.19.0.dist-info → arize_phoenix-4.20.1.dist-info}/METADATA +2 -1
  2. {arize_phoenix-4.19.0.dist-info → arize_phoenix-4.20.1.dist-info}/RECORD +35 -33
  3. phoenix/db/bulk_inserter.py +24 -98
  4. phoenix/db/insertion/document_annotation.py +13 -0
  5. phoenix/db/insertion/span_annotation.py +13 -0
  6. phoenix/db/insertion/trace_annotation.py +13 -0
  7. phoenix/db/insertion/types.py +34 -28
  8. phoenix/server/api/context.py +9 -7
  9. phoenix/server/api/dataloaders/__init__.py +0 -47
  10. phoenix/server/api/dataloaders/span_annotations.py +6 -9
  11. phoenix/server/api/mutations/dataset_mutations.py +44 -4
  12. phoenix/server/api/mutations/experiment_mutations.py +2 -0
  13. phoenix/server/api/mutations/project_mutations.py +5 -5
  14. phoenix/server/api/mutations/span_annotations_mutations.py +10 -2
  15. phoenix/server/api/mutations/trace_annotations_mutations.py +10 -2
  16. phoenix/server/api/queries.py +9 -0
  17. phoenix/server/api/routers/v1/datasets.py +2 -0
  18. phoenix/server/api/routers/v1/experiment_evaluations.py +2 -0
  19. phoenix/server/api/routers/v1/experiment_runs.py +2 -0
  20. phoenix/server/api/routers/v1/experiments.py +2 -0
  21. phoenix/server/api/routers/v1/spans.py +15 -9
  22. phoenix/server/api/routers/v1/traces.py +15 -11
  23. phoenix/server/api/types/Dataset.py +6 -1
  24. phoenix/server/api/types/Experiment.py +6 -1
  25. phoenix/server/api/types/Project.py +4 -1
  26. phoenix/server/api/types/Span.py +14 -13
  27. phoenix/server/app.py +25 -8
  28. phoenix/server/dml_event.py +136 -0
  29. phoenix/server/dml_event_handler.py +272 -0
  30. phoenix/server/types.py +106 -1
  31. phoenix/session/client.py +2 -2
  32. phoenix/version.py +1 -1
  33. {arize_phoenix-4.19.0.dist-info → arize_phoenix-4.20.1.dist-info}/WHEEL +0 -0
  34. {arize_phoenix-4.19.0.dist-info → arize_phoenix-4.20.1.dist-info}/licenses/IP_NOTICE +0 -0
  35. {arize_phoenix-4.19.0.dist-info → arize_phoenix-4.20.1.dist-info}/licenses/LICENSE +0 -0
phoenix/server/dml_event_handler.py ADDED
@@ -0,0 +1,272 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from asyncio import gather
5
+ from inspect import getmro
6
+ from itertools import chain
7
+ from typing import (
8
+ Any,
9
+ Callable,
10
+ Generic,
11
+ Iterable,
12
+ Iterator,
13
+ Mapping,
14
+ Optional,
15
+ Set,
16
+ Tuple,
17
+ Type,
18
+ TypedDict,
19
+ TypeVar,
20
+ Union,
21
+ cast,
22
+ )
23
+
24
+ from sqlalchemy import Select, select
25
+ from typing_extensions import TypeAlias, Unpack
26
+
27
+ from phoenix.db.models import (
28
+ Base,
29
+ DocumentAnnotation,
30
+ Project,
31
+ Span,
32
+ SpanAnnotation,
33
+ Trace,
34
+ TraceAnnotation,
35
+ )
36
+ from phoenix.server.api.dataloaders import CacheForDataLoaders
37
+ from phoenix.server.dml_event import (
38
+ DmlEvent,
39
+ DocumentAnnotationDmlEvent,
40
+ SpanAnnotationDmlEvent,
41
+ SpanDeleteEvent,
42
+ SpanDmlEvent,
43
+ TraceAnnotationDmlEvent,
44
+ )
45
+ from phoenix.server.types import (
46
+ BatchedCaller,
47
+ CanSetLastUpdatedAt,
48
+ DbSessionFactory,
49
+ )
50
+
51
+ _DmlEventT = TypeVar("_DmlEventT", bound=DmlEvent)
52
+
53
+
54
class _DmlEventQueue(Generic[_DmlEventT]):
    """Deduplicating holding area for DML events awaiting batch processing.

    Backed by a set, so putting an equal event twice stores it only once.
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self._pending: Set[_DmlEventT] = set()

    @property
    def empty(self) -> bool:
        """True when no events are queued."""
        return len(self._pending) == 0

    def put(self, event: _DmlEventT) -> None:
        """Queue an event; duplicates (by equality) are collapsed."""
        self._pending.add(event)

    def clear(self) -> None:
        """Discard every queued event."""
        self._pending.clear()

    def __iter__(self) -> Iterator[_DmlEventT]:
        """Iterate over the queued events in no particular order."""
        return iter(self._pending)
71
+
72
+
73
class _HandlerParams(TypedDict):
    """Keyword arguments shared by every DML event handler's __init__."""

    db: DbSessionFactory  # factory producing async DB sessions
    last_updated_at: CanSetLastUpdatedAt  # sink for per-row last-updated timestamps
    cache_for_dataloaders: Optional[CacheForDataLoaders]  # caches to invalidate; may be None
    sleep_seconds: float  # polling interval between batch flushes
78
+
79
+
80
class _HasLastUpdatedAt(ABC):
    # Cooperative-init mixin: captures the last-updated-at sink and forwards
    # any remaining kwargs along the MRO via super().__init__.
    def __init__(
        self,
        last_updated_at: CanSetLastUpdatedAt,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self._last_updated_at = last_updated_at
88
+
89
+
90
class _HasCacheForDataLoaders(ABC):
    # Cooperative-init mixin: captures the optional dataloader caches and
    # forwards any remaining kwargs along the MRO via super().__init__.
    def __init__(
        self,
        cache_for_dataloaders: Optional[CacheForDataLoaders] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self._cache_for_dataloaders = cache_for_dataloaders
98
+
99
+
100
class _DmlEventHandler(
    _HasLastUpdatedAt,
    _HasCacheForDataLoaders,
    BatchedCaller[_DmlEventT],
    Generic[_DmlEventT],
    ABC,
):
    # Base class for batched DML event handlers: events accumulate in a
    # deduplicating queue and are flushed periodically by BatchedCaller.
    #
    # cast() narrows the class-level factory to the generic batch-protocol
    # signature expected by _HasBatch._batch_factory.
    _batch_factory = cast(Callable[[], _DmlEventQueue[_DmlEventT]], _DmlEventQueue)

    def __init__(self, *, db: DbSessionFactory, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self._db = db

    def __hash__(self) -> int:
        # Identity-based hash so handler instances can be stored in
        # sets/frozensets (see DmlEventHandler._all_handlers).
        return id(self)
115
+
116
+
117
class _GenericDmlEventHandler(_DmlEventHandler[DmlEvent]):
    """Stamps a last-updated timestamp for every row id of every batched event."""

    async def __call__(self) -> None:
        for event in self._batch:
            table = event.table
            for row_id in event.ids:
                self._update(table, row_id)

    def _update(self, table: Type[Base], id_: int) -> None:
        # Record "now" as the last update time for the (table, id_) pair.
        self._last_updated_at.set(table, id_)
125
+
126
+
127
class _SpanDmlEventHandler(_DmlEventHandler[SpanDmlEvent]):
    """Invalidates per-project dataloader caches after span writes."""

    async def __call__(self) -> None:
        cache = self._cache_for_dataloaders
        if not cache:
            return
        # Deduplicate project ids across every event in the batch.
        project_ids = {id_ for event in self._batch for id_ in event.ids}
        for project_id in project_ids:
            self._clear(cache, project_id)

    @staticmethod
    def _clear(cache: CacheForDataLoaders, project_id: int) -> None:
        # Span writes change latency, token, record-count, and time-range stats.
        cache.latency_ms_quantile.invalidate(project_id)
        cache.token_count.invalidate(project_id)
        cache.record_count.invalidate(project_id)
        cache.min_start_or_max_end_time.invalidate(project_id)
139
+
140
+
141
class _SpanDeleteEventHandler(_SpanDmlEventHandler):
    # Span deletion must also drop the per-project summary caches.
    # NOTE(review): this override replaces (not extends) the parent's
    # _clear; the parent's stat invalidations presumably still occur
    # because a SpanDeleteEvent is also dispatched to _SpanDmlEventHandler
    # via the MRO walk in DmlEventHandler.put — confirm.
    @staticmethod
    def _clear(cache: CacheForDataLoaders, project_id: int) -> None:
        cache.annotation_summary.invalidate_project(project_id)
        cache.evaluation_summary.invalidate_project(project_id)
        cache.document_evaluation_summary.invalidate_project(project_id)
147
+
148
+
149
# Annotation tables handled by the annotation DML event handlers below.
_AnnotationTable: TypeAlias = Union[
    Type[SpanAnnotation],
    Type[TraceAnnotation],
    Type[DocumentAnnotation],
]

# Constrained (not bound) TypeVar: exactly one of the three annotation
# DML event types.
_AnnotationDmlEventT = TypeVar(
    "_AnnotationDmlEventT",
    SpanAnnotationDmlEvent,
    TraceAnnotationDmlEvent,
    DocumentAnnotationDmlEvent,
)
161
+
162
+
163
class _AnnotationDmlEventHandler(
    _DmlEventHandler[_AnnotationDmlEventT],
    Generic[_AnnotationDmlEventT],
    ABC,
):
    """Invalidates summary caches for projects touched by annotation DML.

    Subclasses set `_table` and extend `_base_stmt` (in __init__) with the
    joins needed to reach the annotation table from Project/Trace.
    """

    _table: _AnnotationTable
    # Distinct project ids reachable from the touched annotations; the
    # annotation name column is added in __init__ only when a cache is
    # configured (it is needed only for cache keys).
    _base_stmt: Union[Select[Tuple[int, str]], Select[Tuple[int]]] = (
        select(Project.id).join_from(Project, Trace).distinct()
    )

    def __init__(self, **kwargs: Unpack[_HandlerParams]) -> None:
        super().__init__(**kwargs)
        self._stmt = self._base_stmt
        if self._cache_for_dataloaders:
            self._stmt = self._stmt.add_columns(self._table.name)

    def _get_stmt(self) -> Union[Select[Tuple[int, str]], Select[Tuple[int]]]:
        # Restrict to the annotation ids collected in the current batch.
        ids = {id_ for event in self._batch for id_ in event.ids}
        return self._stmt.where(self._table.id.in_(ids))

    @staticmethod
    @abstractmethod
    def _clear(cache: CacheForDataLoaders, project_id: int, name: str) -> None: ...

    async def __call__(self) -> None:
        # Without a cache there is nothing to invalidate, so skip the
        # database round-trip entirely. (Previously the query ran even with
        # no cache configured and every row was silently discarded; the
        # per-row cache check is also hoisted out of the loop.)
        if not (cache := self._cache_for_dataloaders):
            return
        async with self._db() as session:
            async for row in await session.stream(self._get_stmt()):
                # row.name exists here because __init__ added the name
                # column whenever a cache was configured.
                self._clear(cache, row.id, row.name)
192
+
193
+
194
class _SpanAnnotationDmlEventHandler(_AnnotationDmlEventHandler[SpanAnnotationDmlEvent]):
    """Handles span-annotation DML; reaches the annotations via Trace -> Span."""

    _table = SpanAnnotation

    def __init__(self, **kwargs: Unpack[_HandlerParams]) -> None:
        super().__init__(**kwargs)
        stmt = self._stmt.join_from(Trace, Span)
        self._stmt = stmt.join_from(Span, self._table)

    @staticmethod
    def _clear(cache: CacheForDataLoaders, project_id: int, name: str) -> None:
        # Span-scoped summary caches are keyed by (project, name, kind).
        key = (project_id, name, "span")
        cache.annotation_summary.invalidate(key)
        cache.evaluation_summary.invalidate(key)
205
+
206
+
207
class _TraceAnnotationDmlEventHandler(_AnnotationDmlEventHandler[TraceAnnotationDmlEvent]):
    """Handles trace-annotation DML; the annotations join directly to Trace."""

    _table = TraceAnnotation

    def __init__(self, **kwargs: Unpack[_HandlerParams]) -> None:
        super().__init__(**kwargs)
        self._stmt = self._stmt.join_from(Trace, self._table)

    @staticmethod
    def _clear(cache: CacheForDataLoaders, project_id: int, name: str) -> None:
        # Trace-scoped summary caches are keyed by (project, name, kind).
        key = (project_id, name, "trace")
        cache.annotation_summary.invalidate(key)
        cache.evaluation_summary.invalidate(key)
218
+
219
+
220
class _DocumentAnnotationDmlEventHandler(_AnnotationDmlEventHandler[DocumentAnnotationDmlEvent]):
    """Handles document-annotation DML; reaches the annotations via Trace -> Span."""

    _table = DocumentAnnotation

    def __init__(self, **kwargs: Unpack[_HandlerParams]) -> None:
        super().__init__(**kwargs)
        stmt = self._stmt.join_from(Trace, Span)
        self._stmt = stmt.join_from(Span, self._table)

    @staticmethod
    def _clear(cache: CacheForDataLoaders, project_id: int, name: str) -> None:
        # Document evaluation summaries are keyed by (project, name) only.
        cache.document_evaluation_summary.invalidate((project_id, name))
230
+
231
+
232
class DmlEventHandler:
    """Fans DML events out to type-specific batched handlers.

    Events put here are routed, via the event class's MRO, to every handler
    registered for the event's type or for any of its DmlEvent ancestors.
    Use as an async context manager to start/stop the handlers' polling tasks.
    """

    def __init__(
        self,
        *,
        db: DbSessionFactory,
        last_updated_at: CanSetLastUpdatedAt,
        cache_for_dataloaders: Optional[CacheForDataLoaders] = None,
        sleep_seconds: float = 0.1,
    ) -> None:
        params = _HandlerParams(
            db=db,
            last_updated_at=last_updated_at,
            cache_for_dataloaders=cache_for_dataloaders,
            sleep_seconds=sleep_seconds,
        )
        # One handler list per event type; lists allow several handlers to
        # share a registration key.
        self._handlers: Mapping[Type[DmlEvent], Iterable[_DmlEventHandler[Any]]] = {
            DmlEvent: [_GenericDmlEventHandler(**params)],
            SpanDmlEvent: [_SpanDmlEventHandler(**params)],
            SpanDeleteEvent: [_SpanDeleteEventHandler(**params)],
            SpanAnnotationDmlEvent: [_SpanAnnotationDmlEventHandler(**params)],
            TraceAnnotationDmlEvent: [_TraceAnnotationDmlEventHandler(**params)],
            DocumentAnnotationDmlEvent: [_DocumentAnnotationDmlEventHandler(**params)],
        }
        self._all_handlers = frozenset(chain.from_iterable(self._handlers.values()))

    async def __aenter__(self) -> None:
        # Launch every handler's background polling task.
        await gather(*(handler.start() for handler in self._all_handlers))

    async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
        # Cancel every handler's background polling task.
        await gather(*(handler.stop() for handler in self._all_handlers))

    def put(self, event: DmlEvent) -> None:
        """Route a non-empty DmlEvent to all handlers along its type lineage."""
        if not isinstance(event, DmlEvent) or not event:
            return
        for cls in getmro(type(event)):
            if not issubclass(cls, DmlEvent):
                continue
            handlers = self._handlers.get(cls)
            if not handlers:
                continue
            for handler in handlers:
                handler.put(event)
            if cls is DmlEvent:
                # DmlEvent is the root of the dispatch hierarchy; nothing
                # above it in the MRO carries a registration.
                break
phoenix/server/types.py CHANGED
@@ -1,10 +1,36 @@
1
- from typing import AsyncContextManager, Callable
1
+ from abc import ABC, abstractmethod
2
+ from asyncio import Task, create_task, sleep
3
+ from collections import defaultdict
4
+ from datetime import datetime, timezone
5
+ from typing import (
6
+ Any,
7
+ AsyncContextManager,
8
+ Callable,
9
+ DefaultDict,
10
+ Generic,
11
+ Iterator,
12
+ List,
13
+ Optional,
14
+ Protocol,
15
+ Type,
16
+ TypeVar,
17
+ )
2
18
 
19
+ from cachetools import LRUCache
3
20
  from sqlalchemy.ext.asyncio import AsyncSession
4
21
 
22
+ from phoenix.db import models
5
23
  from phoenix.db.helpers import SupportedSQLDialect
6
24
 
7
25
 
26
class CanSetLastUpdatedAt(Protocol):
    """Structural type: can record a last-updated timestamp for a table row."""

    def set(self, table: Type[models.Base], id_: int) -> None: ...
28
+
29
+
30
class CanGetLastUpdatedAt(Protocol):
    """Structural type: can look up last-updated timestamps by table (and row)."""

    def get(self, table: Type[models.Base], id_: Optional[int] = None) -> Optional[datetime]: ...
32
+
33
+
8
34
  class DbSessionFactory:
9
35
  def __init__(
10
36
  self,
@@ -16,3 +42,82 @@ class DbSessionFactory:
16
42
 
17
43
  def __call__(self) -> AsyncContextManager[AsyncSession]:
18
44
  return self._db()
45
+
46
+
47
_AnyT = TypeVar("_AnyT")
# Contravariant: protocols consuming _ItemT_contra may accept supertypes.
_ItemT_contra = TypeVar("_ItemT_contra", contravariant=True)
49
+
50
+
51
class CanPutItem(Protocol[_ItemT_contra]):
    """Structural type: accepts items via put()."""

    def put(self, item: _ItemT_contra) -> None: ...
53
+
54
+
55
class _Batch(CanPutItem[_AnyT], Protocol[_AnyT]):
    # A put-able, clearable, iterable collection with an emptiness probe —
    # the container contract used by _HasBatch/BatchedCaller.
    @property
    def empty(self) -> bool: ...
    def clear(self) -> None: ...
    def __iter__(self) -> Iterator[_AnyT]: ...
60
+
61
+
62
class _HasBatch(Generic[_ItemT_contra], ABC):
    # Subclasses supply a zero-argument factory producing the batch container.
    _batch_factory: Callable[[], _Batch[_ItemT_contra]]

    def __init__(self) -> None:
        self._batch = self._batch_factory()

    def put(self, item: _ItemT_contra) -> None:
        # Delegate to the underlying batch container.
        self._batch.put(item)
70
+
71
+
72
class BatchedCaller(_HasBatch[_AnyT], Generic[_AnyT], ABC):
    """Periodically awaits `self()` whenever the batch is non-empty.

    start() launches a background loop that sleeps `sleep_seconds`, then —
    if items have accumulated — awaits the subclass's __call__ and clears
    the batch. stop() cancels all outstanding tasks.
    """

    def __init__(self, *, sleep_seconds: float = 0.1, **kwargs: Any) -> None:
        assert sleep_seconds > 0
        super().__init__(**kwargs)
        self._running = False
        self._seconds = sleep_seconds
        # Holds the main loop task plus whichever inner task (sleep or
        # flush) is currently being awaited, so stop() can cancel mid-await.
        self._tasks: List[Task[None]] = []

    async def start(self) -> None:
        self._running = True
        # Guard against double-start: only spawn the loop when no task exists.
        if not self._tasks:
            self._tasks.append(create_task(self._run()))

    async def stop(self) -> None:
        self._running = False
        # Cancel the in-flight inner task before the main loop (reverse order).
        for task in reversed(self._tasks):
            if not task.done():
                task.cancel()
        self._tasks.clear()

    @abstractmethod
    async def __call__(self) -> None: ...

    async def _run(self) -> None:
        while self._running:
            # Sleep via a tracked task so stop() can cancel the wait.
            self._tasks.append(create_task(sleep(self._seconds)))
            await self._tasks[-1]
            self._tasks.pop()
            if self._batch.empty:
                continue
            # Likewise track the in-flight flush so it too is cancellable.
            self._tasks.append(create_task(self()))
            await self._tasks[-1]
            self._tasks.pop()
            # NOTE(review): cleared only after the flush completes; if
            # self() raises, the loop task dies and items stay queued.
            self._batch.clear()
106
+
107
+
108
class LastUpdatedAt:
    """Tracks, per table, the most recent update time of individual rows.

    Each table gets its own LRU cache (100 entries) mapping a row id to the
    UTC timestamp at which set() was last called for that row.
    """

    def __init__(self) -> None:
        factory: Callable[[], "LRUCache[int, datetime]"] = lambda: LRUCache(maxsize=100)
        self._cache: DefaultDict[
            Type[models.Base],
            LRUCache[int, datetime],
        ] = defaultdict(factory)

    def get(self, table: Type[models.Base], id_: Optional[int] = None) -> Optional[datetime]:
        """Latest timestamp for one row, or across the whole table when id_ is None."""
        cache = self._cache.get(table)
        if not cache:
            # Unknown table or empty cache: nothing recorded yet.
            return None
        if id_ is not None:
            return cache.get(id_)
        return max(filter(bool, cache.values()), default=None)

    def set(self, table: Type[models.Base], id_: int) -> None:
        """Stamp the given row with the current UTC time."""
        self._cache[table][id_] = datetime.now(timezone.utc)
phoenix/session/client.py CHANGED
@@ -324,7 +324,7 @@ class Client(TraceDataExtractor):
324
324
  Dataset: The dataset object.
325
325
  """
326
326
  response = self._client.get(
327
- urljoin(self._base_url, "/v1/datasets"),
327
+ urljoin(self._base_url, "v1/datasets"),
328
328
  params={"name": name},
329
329
  )
330
330
  response.raise_for_status()
@@ -366,7 +366,7 @@ class Client(TraceDataExtractor):
366
366
  raise ValueError("Dataset id or name must be provided.")
367
367
 
368
368
  response = self._client.get(
369
- urljoin(self._base_url, f"/v1/datasets/{quote(id)}/examples"),
369
+ urljoin(self._base_url, f"v1/datasets/{quote(id)}/examples"),
370
370
  params={"version_id": version_id} if version_id else None,
371
371
  )
372
372
  response.raise_for_status()
phoenix/version.py CHANGED
@@ -1 +1 @@
1
- __version__ = "4.19.0"
1
+ __version__ = "4.20.1"