arize-phoenix 4.20.0__py3-none-any.whl → 4.20.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (33) hide show
  1. {arize_phoenix-4.20.0.dist-info → arize_phoenix-4.20.1.dist-info}/METADATA +2 -1
  2. {arize_phoenix-4.20.0.dist-info → arize_phoenix-4.20.1.dist-info}/RECORD +33 -31
  3. phoenix/db/bulk_inserter.py +24 -98
  4. phoenix/db/insertion/document_annotation.py +13 -0
  5. phoenix/db/insertion/span_annotation.py +13 -0
  6. phoenix/db/insertion/trace_annotation.py +13 -0
  7. phoenix/db/insertion/types.py +34 -28
  8. phoenix/server/api/context.py +8 -6
  9. phoenix/server/api/dataloaders/__init__.py +0 -47
  10. phoenix/server/api/mutations/dataset_mutations.py +9 -3
  11. phoenix/server/api/mutations/experiment_mutations.py +2 -0
  12. phoenix/server/api/mutations/project_mutations.py +5 -5
  13. phoenix/server/api/mutations/span_annotations_mutations.py +10 -2
  14. phoenix/server/api/mutations/trace_annotations_mutations.py +10 -2
  15. phoenix/server/api/queries.py +9 -0
  16. phoenix/server/api/routers/v1/datasets.py +2 -0
  17. phoenix/server/api/routers/v1/experiment_evaluations.py +2 -0
  18. phoenix/server/api/routers/v1/experiment_runs.py +2 -0
  19. phoenix/server/api/routers/v1/experiments.py +2 -0
  20. phoenix/server/api/routers/v1/spans.py +12 -8
  21. phoenix/server/api/routers/v1/traces.py +12 -10
  22. phoenix/server/api/types/Dataset.py +6 -1
  23. phoenix/server/api/types/Experiment.py +6 -1
  24. phoenix/server/api/types/Project.py +4 -1
  25. phoenix/server/api/types/Span.py +2 -2
  26. phoenix/server/app.py +25 -8
  27. phoenix/server/dml_event.py +136 -0
  28. phoenix/server/dml_event_handler.py +272 -0
  29. phoenix/server/types.py +106 -1
  30. phoenix/version.py +1 -1
  31. {arize_phoenix-4.20.0.dist-info → arize_phoenix-4.20.1.dist-info}/WHEEL +0 -0
  32. {arize_phoenix-4.20.0.dist-info → arize_phoenix-4.20.1.dist-info}/licenses/IP_NOTICE +0 -0
  33. {arize_phoenix-4.20.0.dist-info → arize_phoenix-4.20.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,5 +1,5 @@
1
1
  from datetime import datetime
2
- from typing import List, Optional
2
+ from typing import ClassVar, List, Optional, Type
3
3
 
4
4
  import strawberry
5
5
  from sqlalchemy import select
@@ -23,6 +23,7 @@ from phoenix.server.api.types.Project import Project
23
23
 
24
24
  @strawberry.type
25
25
  class Experiment(Node):
26
+ _table: ClassVar[Type[models.Base]] = models.Experiment
26
27
  cached_sequence_number: Private[Optional[int]] = None
27
28
  id_attr: NodeID[int]
28
29
  name: str
@@ -127,6 +128,10 @@ class Experiment(Node):
127
128
  gradient_end_color=db_project.gradient_end_color,
128
129
  )
129
130
 
131
+ @strawberry.field
132
+ def last_updated_at(self, info: Info[Context, None]) -> Optional[datetime]:
133
+ return info.context.last_updated_at.get(self._table, self.id_attr)
134
+
130
135
 
131
136
  def to_gql_experiment(
132
137
  experiment: models.Experiment,
@@ -2,8 +2,10 @@ import operator
2
2
  from datetime import datetime
3
3
  from typing import (
4
4
  Any,
5
+ ClassVar,
5
6
  List,
6
7
  Optional,
8
+ Type,
7
9
  )
8
10
 
9
11
  import strawberry
@@ -38,6 +40,7 @@ from phoenix.trace.dsl import SpanFilter
38
40
 
39
41
  @strawberry.type
40
42
  class Project(Node):
43
+ _table: ClassVar[Type[models.Base]] = models.Project
41
44
  id_attr: NodeID[int]
42
45
  name: str
43
46
  gradient_start_color: str
@@ -397,7 +400,7 @@ class Project(Node):
397
400
  self,
398
401
  info: Info[Context, None],
399
402
  ) -> Optional[datetime]:
400
- return info.context.streaming_last_updated_at(self.id_attr)
403
+ return info.context.last_updated_at.get(self._table, self.id_attr)
401
404
 
402
405
  @strawberry.field
403
406
  async def validate_span_filter_condition(self, condition: str) -> ValidationResult:
@@ -194,8 +194,8 @@ class Span(Node):
194
194
  ) -> List[SpanAnnotation]:
195
195
  span_id = self.id_attr
196
196
  annotations = await info.context.data_loaders.span_annotations.load(span_id)
197
- sort_key = SpanAnnotationColumn.createdAt.value
198
- sort_descending = True
197
+ sort_key = SpanAnnotationColumn.name.value
198
+ sort_descending = False
199
199
  if sort:
200
200
  sort_key = sort.col.value
201
201
  sort_descending = sort.dir is SortDir.desc
phoenix/server/app.py CHANGED
@@ -2,7 +2,6 @@ import asyncio
2
2
  import contextlib
3
3
  import json
4
4
  import logging
5
- from datetime import datetime
6
5
  from functools import cached_property
7
6
  from pathlib import Path
8
7
  from typing import (
@@ -87,9 +86,16 @@ from phoenix.server.api.dataloaders import (
87
86
  from phoenix.server.api.routers.v1 import REST_API_VERSION
88
87
  from phoenix.server.api.routers.v1 import router as v1_router
89
88
  from phoenix.server.api.schema import schema
89
+ from phoenix.server.dml_event import DmlEvent
90
+ from phoenix.server.dml_event_handler import DmlEventHandler
90
91
  from phoenix.server.grpc_server import GrpcServer
91
92
  from phoenix.server.telemetry import initialize_opentelemetry_tracer_provider
92
- from phoenix.server.types import DbSessionFactory
93
+ from phoenix.server.types import (
94
+ CanGetLastUpdatedAt,
95
+ CanPutItem,
96
+ DbSessionFactory,
97
+ LastUpdatedAt,
98
+ )
93
99
  from phoenix.trace.schemas import Span
94
100
  from phoenix.utilities.client import PHOENIX_SERVER_VERSION_HEADER
95
101
 
@@ -220,6 +226,7 @@ def _lifespan(
220
226
  *,
221
227
  dialect: SupportedSQLDialect,
222
228
  bulk_inserter: BulkInserter,
229
+ dml_event_handler: DmlEventHandler,
223
230
  tracer_provider: Optional["TracerProvider"] = None,
224
231
  enable_prometheus: bool = False,
225
232
  clean_ups: Iterable[Callable[[], None]] = (),
@@ -239,8 +246,9 @@ def _lifespan(
239
246
  disabled=read_only,
240
247
  tracer_provider=tracer_provider,
241
248
  enable_prometheus=enable_prometheus,
242
- ):
249
+ ), dml_event_handler:
243
250
  yield {
251
+ "event_queue": dml_event_handler,
244
252
  "enqueue": enqueue,
245
253
  "queue_span_for_bulk_insert": queue_span,
246
254
  "queue_evaluation_for_bulk_insert": queue_evaluation,
@@ -263,9 +271,10 @@ def create_graphql_router(
263
271
  db: DbSessionFactory,
264
272
  model: Model,
265
273
  export_path: Path,
274
+ last_updated_at: CanGetLastUpdatedAt,
266
275
  corpus: Optional[Model] = None,
267
- streaming_last_updated_at: Callable[[ProjectRowId], Optional[datetime]] = lambda _: None,
268
276
  cache_for_dataloaders: Optional[CacheForDataLoaders] = None,
277
+ event_queue: CanPutItem[DmlEvent],
269
278
  read_only: bool = False,
270
279
  ) -> GraphQLRouter: # type: ignore[type-arg]
271
280
  def get_context() -> Context:
@@ -274,7 +283,8 @@ def create_graphql_router(
274
283
  model=model,
275
284
  corpus=corpus,
276
285
  export_path=export_path,
277
- streaming_last_updated_at=streaming_last_updated_at,
286
+ last_updated_at=last_updated_at,
287
+ event_queue=event_queue,
278
288
  data_loaders=DataLoaders(
279
289
  average_experiment_run_latency=AverageExperimentRunLatencyDataLoader(db),
280
290
  dataset_example_revisions=DatasetExampleRevisionsDataLoader(db),
@@ -420,11 +430,16 @@ def create_app(
420
430
  cache_for_dataloaders = (
421
431
  CacheForDataLoaders() if db.dialect is SupportedSQLDialect.SQLITE else None
422
432
  )
423
-
433
+ last_updated_at = LastUpdatedAt()
434
+ dml_event_handler = DmlEventHandler(
435
+ db=db,
436
+ cache_for_dataloaders=cache_for_dataloaders,
437
+ last_updated_at=last_updated_at,
438
+ )
424
439
  bulk_inserter = BulkInserter(
425
440
  db,
426
441
  enable_prometheus=enable_prometheus,
427
- cache_for_dataloaders=cache_for_dataloaders,
442
+ event_queue=dml_event_handler,
428
443
  initial_batch_of_spans=initial_batch_of_spans,
429
444
  initial_batch_of_evaluations=initial_batch_of_evaluations,
430
445
  )
@@ -460,7 +475,8 @@ def create_app(
460
475
  model=model,
461
476
  corpus=corpus,
462
477
  export_path=export_path,
463
- streaming_last_updated_at=bulk_inserter.last_updated_at,
478
+ last_updated_at=last_updated_at,
479
+ event_queue=dml_event_handler,
464
480
  cache_for_dataloaders=cache_for_dataloaders,
465
481
  read_only=read_only,
466
482
  )
@@ -477,6 +493,7 @@ def create_app(
477
493
  dialect=db.dialect,
478
494
  read_only=read_only,
479
495
  bulk_inserter=bulk_inserter,
496
+ dml_event_handler=dml_event_handler,
480
497
  tracer_provider=tracer_provider,
481
498
  enable_prometheus=enable_prometheus,
482
499
  clean_ups=clean_ups,
@@ -0,0 +1,136 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC
4
+ from dataclasses import dataclass, field
5
+ from typing import ClassVar, Tuple, Type
6
+
7
+ from phoenix.db import models
8
+
9
+
10
@dataclass(frozen=True)
class DmlEvent(ABC):
    """Base event for a Data Manipulation Language operation
    (insert/update/delete) against one ORM table.

    Concrete subclasses bind ``table`` to the affected model; ``ids`` holds
    the primary keys of the touched rows.
    """

    # ORM model whose rows this event refers to; set by each subclass.
    table: ClassVar[Type[models.Base]]
    # Primary keys of the affected rows; an empty tuple means a no-op event.
    ids: Tuple[int, ...] = ()

    def __bool__(self) -> bool:
        # An event carrying no row ids conveys no information.
        return len(self.ids) > 0

    def __hash__(self) -> int:
        # Identity hash: downstream queues deduplicate per event object,
        # not per content.
        return id(self)
25
+
26
+
27
# Concrete DML events, one family per ORM model. Each family binds ``table``
# on its base class and derives insert/delete variants from it.


@dataclass(frozen=True)
class ProjectDmlEvent(DmlEvent):
    """DML event touching rows of the Project table."""

    table = models.Project


@dataclass(frozen=True)
class ProjectDeleteEvent(ProjectDmlEvent):
    """One or more projects were deleted."""


@dataclass(frozen=True)
class SpanDmlEvent(ProjectDmlEvent):
    """Span-level DML; ``ids`` are the affected *project* rows
    (table inherited from ProjectDmlEvent)."""


@dataclass(frozen=True)
class SpanInsertEvent(SpanDmlEvent):
    """Spans were inserted."""


@dataclass(frozen=True)
class SpanDeleteEvent(SpanDmlEvent):
    """Spans were deleted."""


@dataclass(frozen=True)
class DatasetDmlEvent(DmlEvent):
    """DML event touching rows of the Dataset table."""

    table = models.Dataset


@dataclass(frozen=True)
class DatasetInsertEvent(DatasetDmlEvent):
    """Datasets were inserted."""


@dataclass(frozen=True)
class DatasetDeleteEvent(DatasetDmlEvent):
    """Datasets were deleted."""


@dataclass(frozen=True)
class ExperimentDmlEvent(DmlEvent):
    """DML event touching rows of the Experiment table."""

    table = models.Experiment


@dataclass(frozen=True)
class ExperimentInsertEvent(ExperimentDmlEvent):
    """Experiments were inserted."""


@dataclass(frozen=True)
class ExperimentDeleteEvent(ExperimentDmlEvent):
    """Experiments were deleted."""


@dataclass(frozen=True)
class ExperimentRunDmlEvent(DmlEvent):
    """DML event touching rows of the ExperimentRun table."""

    table = models.ExperimentRun


@dataclass(frozen=True)
class ExperimentRunInsertEvent(ExperimentRunDmlEvent):
    """Experiment runs were inserted."""


@dataclass(frozen=True)
class ExperimentRunDeleteEvent(ExperimentRunDmlEvent):
    """Experiment runs were deleted."""


@dataclass(frozen=True)
class ExperimentRunAnnotationDmlEvent(DmlEvent):
    """DML event touching rows of the ExperimentRunAnnotation table."""

    table = models.ExperimentRunAnnotation


@dataclass(frozen=True)
class ExperimentRunAnnotationInsertEvent(ExperimentRunAnnotationDmlEvent):
    """Experiment-run annotations were inserted."""


@dataclass(frozen=True)
class ExperimentRunAnnotationDeleteEvent(ExperimentRunAnnotationDmlEvent):
    """Experiment-run annotations were deleted."""


@dataclass(frozen=True)
class SpanAnnotationDmlEvent(DmlEvent):
    """DML event touching rows of the SpanAnnotation table."""

    table = models.SpanAnnotation


@dataclass(frozen=True)
class SpanAnnotationInsertEvent(SpanAnnotationDmlEvent):
    """Span annotations were inserted."""


@dataclass(frozen=True)
class SpanAnnotationDeleteEvent(SpanAnnotationDmlEvent):
    """Span annotations were deleted."""


@dataclass(frozen=True)
class TraceAnnotationDmlEvent(DmlEvent):
    """DML event touching rows of the TraceAnnotation table."""

    table = models.TraceAnnotation


@dataclass(frozen=True)
class TraceAnnotationInsertEvent(TraceAnnotationDmlEvent):
    """Trace annotations were inserted."""


@dataclass(frozen=True)
class TraceAnnotationDeleteEvent(TraceAnnotationDmlEvent):
    """Trace annotations were deleted."""


@dataclass(frozen=True)
class DocumentAnnotationDmlEvent(DmlEvent):
    """DML event touching rows of the DocumentAnnotation table."""

    table = models.DocumentAnnotation


@dataclass(frozen=True)
class DocumentAnnotationInsertEvent(DocumentAnnotationDmlEvent):
    """Document annotations were inserted."""


@dataclass(frozen=True)
class DocumentAnnotationDeleteEvent(DocumentAnnotationDmlEvent):
    """Document annotations were deleted."""
@@ -0,0 +1,272 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from asyncio import gather
5
+ from inspect import getmro
6
+ from itertools import chain
7
+ from typing import (
8
+ Any,
9
+ Callable,
10
+ Generic,
11
+ Iterable,
12
+ Iterator,
13
+ Mapping,
14
+ Optional,
15
+ Set,
16
+ Tuple,
17
+ Type,
18
+ TypedDict,
19
+ TypeVar,
20
+ Union,
21
+ cast,
22
+ )
23
+
24
+ from sqlalchemy import Select, select
25
+ from typing_extensions import TypeAlias, Unpack
26
+
27
+ from phoenix.db.models import (
28
+ Base,
29
+ DocumentAnnotation,
30
+ Project,
31
+ Span,
32
+ SpanAnnotation,
33
+ Trace,
34
+ TraceAnnotation,
35
+ )
36
+ from phoenix.server.api.dataloaders import CacheForDataLoaders
37
+ from phoenix.server.dml_event import (
38
+ DmlEvent,
39
+ DocumentAnnotationDmlEvent,
40
+ SpanAnnotationDmlEvent,
41
+ SpanDeleteEvent,
42
+ SpanDmlEvent,
43
+ TraceAnnotationDmlEvent,
44
+ )
45
+ from phoenix.server.types import (
46
+ BatchedCaller,
47
+ CanSetLastUpdatedAt,
48
+ DbSessionFactory,
49
+ )
50
+
51
+ _DmlEventT = TypeVar("_DmlEventT", bound=DmlEvent)
52
+
53
+
54
class _DmlEventQueue(Generic[_DmlEventT]):
    """Deduplicating holding pen for DML events awaiting a batched handler.

    Backed by a set, so the same event object enqueued twice is processed
    once (events hash by identity).
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self._pending: Set[_DmlEventT] = set()

    @property
    def empty(self) -> bool:
        """True when no events are waiting."""
        return len(self._pending) == 0

    def put(self, event: _DmlEventT) -> None:
        self._pending.add(event)

    def clear(self) -> None:
        self._pending.clear()

    def __iter__(self) -> Iterator[_DmlEventT]:
        return iter(self._pending)
71
+
72
+
73
+ class _HandlerParams(TypedDict):
74
+ db: DbSessionFactory
75
+ last_updated_at: CanSetLastUpdatedAt
76
+ cache_for_dataloaders: Optional[CacheForDataLoaders]
77
+ sleep_seconds: float
78
+
79
+
80
+ class _HasLastUpdatedAt(ABC):
81
+ def __init__(
82
+ self,
83
+ last_updated_at: CanSetLastUpdatedAt,
84
+ **kwargs: Any,
85
+ ) -> None:
86
+ super().__init__(**kwargs)
87
+ self._last_updated_at = last_updated_at
88
+
89
+
90
+ class _HasCacheForDataLoaders(ABC):
91
+ def __init__(
92
+ self,
93
+ cache_for_dataloaders: Optional[CacheForDataLoaders] = None,
94
+ **kwargs: Any,
95
+ ) -> None:
96
+ super().__init__(**kwargs)
97
+ self._cache_for_dataloaders = cache_for_dataloaders
98
+
99
+
100
class _DmlEventHandler(
    _HasLastUpdatedAt,
    _HasCacheForDataLoaders,
    BatchedCaller[_DmlEventT],
    Generic[_DmlEventT],
    ABC,
):
    """Common base for batched DML event handlers.

    Combines the last-updated sink, the optional dataloader cache, and the
    periodic flush machinery of ``BatchedCaller``; each instance batches its
    events in its own deduplicating ``_DmlEventQueue``.
    """

    _batch_factory: Callable[[], _DmlEventQueue[_DmlEventT]] = _DmlEventQueue

    def __init__(self, *, db: DbSessionFactory, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self._db = db

    def __hash__(self) -> int:
        # Identity hash so distinct handler instances can live in sets.
        return id(self)
115
+
116
+
117
class _GenericDmlEventHandler(_DmlEventHandler[DmlEvent]):
    """Catch-all handler: stamps a last-updated time for every affected row."""

    async def __call__(self) -> None:
        for event in self._batch:
            for row_id in event.ids:
                self._update(event.table, row_id)

    def _update(self, table: Type[Base], id_: int) -> None:
        self._last_updated_at.set(table, id_)
125
+
126
+
127
class _SpanDmlEventHandler(_DmlEventHandler[SpanDmlEvent]):
    """Invalidates per-project span statistic caches after span DML.

    Event ``ids`` here are project row ids (see ``SpanDmlEvent``).
    """

    async def __call__(self) -> None:
        cache = self._cache_for_dataloaders
        if not cache:
            return
        affected_projects = set(chain.from_iterable(event.ids for event in self._batch))
        for project_id in affected_projects:
            self._clear(cache, project_id)

    @staticmethod
    def _clear(cache: CacheForDataLoaders, project_id: int) -> None:
        cache.latency_ms_quantile.invalidate(project_id)
        cache.token_count.invalidate(project_id)
        cache.record_count.invalidate(project_id)
        cache.min_start_or_max_end_time.invalidate(project_id)
139
+
140
+
141
class _SpanDeleteEventHandler(_SpanDmlEventHandler):
    """On span deletion, drops whole-project summary caches.

    Overrides (does not extend) the parent's ``_clear``; the parent handler
    is registered separately and still runs for the same event.
    """

    @staticmethod
    def _clear(cache: CacheForDataLoaders, project_id: int) -> None:
        cache.annotation_summary.invalidate_project(project_id)
        cache.evaluation_summary.invalidate_project(project_id)
        cache.document_evaluation_summary.invalidate_project(project_id)
147
+
148
+
149
+ _AnnotationTable: TypeAlias = Union[
150
+ Type[SpanAnnotation],
151
+ Type[TraceAnnotation],
152
+ Type[DocumentAnnotation],
153
+ ]
154
+
155
+ _AnnotationDmlEventT = TypeVar(
156
+ "_AnnotationDmlEventT",
157
+ SpanAnnotationDmlEvent,
158
+ TraceAnnotationDmlEvent,
159
+ DocumentAnnotationDmlEvent,
160
+ )
161
+
162
+
163
class _AnnotationDmlEventHandler(
    _DmlEventHandler[_AnnotationDmlEventT],
    Generic[_AnnotationDmlEventT],
    ABC,
):
    """Resolves annotation row ids to (project id, annotation name) pairs
    and invalidates the matching summary caches.
    """

    # Annotation ORM table; bound by each concrete subclass.
    _table: _AnnotationTable
    # Distinct projects reachable via traces; subclasses extend the joins.
    _base_stmt: Union[Select[Tuple[int, str]], Select[Tuple[int]]] = (
        select(Project.id).join_from(Project, Trace).distinct()
    )

    def __init__(self, **kwargs: Unpack[_HandlerParams]) -> None:
        super().__init__(**kwargs)
        self._stmt = self._base_stmt
        # The annotation name column is only needed when a cache keyed by
        # annotation name is present.
        if self._cache_for_dataloaders:
            self._stmt = self._stmt.add_columns(self._table.name)

    def _get_stmt(self) -> Union[Select[Tuple[int, str]], Select[Tuple[int]]]:
        batched_ids = set(chain.from_iterable(event.ids for event in self._batch))
        return self._stmt.where(self._table.id.in_(batched_ids))

    @staticmethod
    @abstractmethod
    def _clear(cache: CacheForDataLoaders, project_id: int, name: str) -> None: ...

    async def __call__(self) -> None:
        cache = self._cache_for_dataloaders
        async with self._db() as session:
            result = await session.stream(self._get_stmt())
            async for row in result:
                # row.name only exists when the cache (and thus the extra
                # column) was configured, which is exactly when we clear.
                if cache:
                    self._clear(cache, row.id, row.name)
192
+
193
+
194
class _SpanAnnotationDmlEventHandler(_AnnotationDmlEventHandler[SpanAnnotationDmlEvent]):
    """Span annotations: project is reached via trace -> span -> annotation."""

    _table = SpanAnnotation

    def __init__(self, **kwargs: Unpack[_HandlerParams]) -> None:
        super().__init__(**kwargs)
        self._stmt = self._stmt.join_from(Trace, Span).join_from(Span, self._table)

    @staticmethod
    def _clear(cache: CacheForDataLoaders, project_id: int, name: str) -> None:
        key = (project_id, name, "span")
        cache.annotation_summary.invalidate(key)
        cache.evaluation_summary.invalidate(key)


class _TraceAnnotationDmlEventHandler(_AnnotationDmlEventHandler[TraceAnnotationDmlEvent]):
    """Trace annotations: project is reached via trace -> annotation."""

    _table = TraceAnnotation

    def __init__(self, **kwargs: Unpack[_HandlerParams]) -> None:
        super().__init__(**kwargs)
        self._stmt = self._stmt.join_from(Trace, self._table)

    @staticmethod
    def _clear(cache: CacheForDataLoaders, project_id: int, name: str) -> None:
        key = (project_id, name, "trace")
        cache.annotation_summary.invalidate(key)
        cache.evaluation_summary.invalidate(key)


class _DocumentAnnotationDmlEventHandler(_AnnotationDmlEventHandler[DocumentAnnotationDmlEvent]):
    """Document annotations: summary cache is keyed by (project, name) only."""

    _table = DocumentAnnotation

    def __init__(self, **kwargs: Unpack[_HandlerParams]) -> None:
        super().__init__(**kwargs)
        self._stmt = self._stmt.join_from(Trace, Span).join_from(Span, self._table)

    @staticmethod
    def _clear(cache: CacheForDataLoaders, project_id: int, name: str) -> None:
        cache.document_evaluation_summary.invalidate((project_id, name))
230
+
231
+
232
class DmlEventHandler:
    """Routes DML events to batched background handlers.

    Used as an async context manager: entering starts every handler's
    background flush task, exiting stops them. ``put`` fans an event out to
    the handlers registered for each class in the event's MRO, stopping at
    the ``DmlEvent`` root.
    """

    def __init__(
        self,
        *,
        db: DbSessionFactory,
        last_updated_at: CanSetLastUpdatedAt,
        cache_for_dataloaders: Optional[CacheForDataLoaders] = None,
        sleep_seconds: float = 0.1,
    ) -> None:
        kwargs = _HandlerParams(
            db=db,
            last_updated_at=last_updated_at,
            cache_for_dataloaders=cache_for_dataloaders,
            sleep_seconds=sleep_seconds,
        )
        # One handler list per event class; every event also matches DmlEvent.
        self._handlers: Mapping[Type[DmlEvent], Iterable[_DmlEventHandler[Any]]] = {
            DmlEvent: [_GenericDmlEventHandler(**kwargs)],
            SpanDmlEvent: [_SpanDmlEventHandler(**kwargs)],
            SpanDeleteEvent: [_SpanDeleteEventHandler(**kwargs)],
            SpanAnnotationDmlEvent: [_SpanAnnotationDmlEventHandler(**kwargs)],
            TraceAnnotationDmlEvent: [_TraceAnnotationDmlEventHandler(**kwargs)],
            DocumentAnnotationDmlEvent: [_DocumentAnnotationDmlEventHandler(**kwargs)],
        }
        self._all_handlers = frozenset(chain.from_iterable(self._handlers.values()))

    async def __aenter__(self) -> None:
        await gather(*(handler.start() for handler in self._all_handlers))

    async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
        await gather(*(handler.stop() for handler in self._all_handlers))

    def put(self, event: DmlEvent) -> None:
        """Enqueue ``event`` with every matching handler; drop empty events."""
        if not isinstance(event, DmlEvent) or not event:
            return
        for cls in getmro(type(event)):
            if not issubclass(cls, DmlEvent):
                continue
            handlers = self._handlers.get(cls)
            if not handlers:
                continue
            for handler in handlers:
                handler.put(event)
            if cls is DmlEvent:
                # Reached the root of the event hierarchy; stop walking.
                break
phoenix/server/types.py CHANGED
@@ -1,10 +1,36 @@
1
- from typing import AsyncContextManager, Callable
1
+ from abc import ABC, abstractmethod
2
+ from asyncio import Task, create_task, sleep
3
+ from collections import defaultdict
4
+ from datetime import datetime, timezone
5
+ from typing import (
6
+ Any,
7
+ AsyncContextManager,
8
+ Callable,
9
+ DefaultDict,
10
+ Generic,
11
+ Iterator,
12
+ List,
13
+ Optional,
14
+ Protocol,
15
+ Type,
16
+ TypeVar,
17
+ )
2
18
 
19
+ from cachetools import LRUCache
3
20
  from sqlalchemy.ext.asyncio import AsyncSession
4
21
 
22
+ from phoenix.db import models
5
23
  from phoenix.db.helpers import SupportedSQLDialect
6
24
 
7
25
 
26
class CanSetLastUpdatedAt(Protocol):
    """Write side of the last-updated registry: record a row modification."""

    def set(self, table: Type[models.Base], id_: int) -> None: ...


class CanGetLastUpdatedAt(Protocol):
    """Read side of the last-updated registry: look up a modification time."""

    def get(self, table: Type[models.Base], id_: Optional[int] = None) -> Optional[datetime]: ...
32
+
33
+
8
34
  class DbSessionFactory:
9
35
  def __init__(
10
36
  self,
@@ -16,3 +42,82 @@ class DbSessionFactory:
16
42
 
17
43
  def __call__(self) -> AsyncContextManager[AsyncSession]:
18
44
  return self._db()
45
+
46
+
47
_AnyT = TypeVar("_AnyT")
_ItemT_contra = TypeVar("_ItemT_contra", contravariant=True)


class CanPutItem(Protocol[_ItemT_contra]):
    """Structural interface for anything accepting items via ``put``."""

    def put(self, item: _ItemT_contra) -> None: ...


class _Batch(CanPutItem[_AnyT], Protocol[_AnyT]):
    """A clearable, iterable holding pen for items awaiting processing."""

    @property
    def empty(self) -> bool: ...
    def clear(self) -> None: ...
    def __iter__(self) -> Iterator[_AnyT]: ...


class _HasBatch(Generic[_ItemT_contra], ABC):
    """Owns a batch created by the subclass-supplied ``_batch_factory``."""

    _batch_factory: Callable[[], _Batch[_ItemT_contra]]

    def __init__(self) -> None:
        self._batch = self._batch_factory()

    def put(self, item: _ItemT_contra) -> None:
        self._batch.put(item)


class BatchedCaller(_HasBatch[_AnyT], Generic[_AnyT], ABC):
    """Periodically awaits ``self()`` over whatever accumulated in the batch.

    ``start`` launches a background loop that sleeps ``sleep_seconds``
    between flushes; ``stop`` cancels any in-flight sleep or handler task.
    Subclasses implement ``__call__`` to process ``self._batch``.
    """

    def __init__(self, *, sleep_seconds: float = 0.1, **kwargs: Any) -> None:
        assert sleep_seconds > 0
        super().__init__(**kwargs)
        self._running = False
        self._interval = sleep_seconds
        self._pending: List[Task[None]] = []

    async def start(self) -> None:
        self._running = True
        # Idempotent: at most one background loop per caller.
        if not self._pending:
            self._pending.append(create_task(self._run()))

    async def stop(self) -> None:
        self._running = False
        # Cancel newest first: inner sleep/handler tasks before the loop itself.
        for task in reversed(self._pending):
            if not task.done():
                task.cancel()
        self._pending.clear()

    @abstractmethod
    async def __call__(self) -> None: ...

    async def _await_cancellable(self, task: Task[None]) -> None:
        # Track the task so stop() can cancel it while we are awaiting it.
        self._pending.append(task)
        await task
        self._pending.pop()

    async def _run(self) -> None:
        while self._running:
            await self._await_cancellable(create_task(sleep(self._interval)))
            if self._batch.empty:
                continue
            await self._await_cancellable(create_task(self()))
            # NOTE(review): items put() while __call__ was awaiting are cleared
            # here without being processed — acceptable for cache invalidation,
            # but confirm before reusing this for anything lossless.
            self._batch.clear()
106
+
107
+
108
class LastUpdatedAt:
    """Tracks, per table, the most recent modification time of recent rows.

    Backed by one bounded LRU per table, so memory stays constant; rows
    evicted from an LRU are simply forgotten.
    """

    def __init__(self) -> None:
        self._cache: DefaultDict[
            Type[models.Base],
            LRUCache[int, datetime],
        ] = defaultdict(lambda: LRUCache(maxsize=100))

    def get(self, table: Type[models.Base], id_: Optional[int] = None) -> Optional[datetime]:
        """Last update time for row ``id_``, or the table-wide max when
        ``id_`` is omitted; None when nothing is recorded."""
        entries = self._cache.get(table)  # .get() avoids creating an empty LRU
        if not entries:
            return None
        if id_ is not None:
            return entries.get(id_)
        return max((ts for ts in entries.values() if ts), default=None)

    def set(self, table: Type[models.Base], id_: int) -> None:
        """Record that row ``id_`` of ``table`` was modified just now (UTC)."""
        self._cache[table][id_] = datetime.now(timezone.utc)
phoenix/version.py CHANGED
@@ -1 +1 @@
1
- __version__ = "4.20.0"
1
+ __version__ = "4.20.1"