arize-phoenix 4.20.0__py3-none-any.whl → 4.20.2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.

Potentially problematic release: this version of arize-phoenix might be problematic.

Files changed (46)
  1. {arize_phoenix-4.20.0.dist-info → arize_phoenix-4.20.2.dist-info}/METADATA +2 -1
  2. {arize_phoenix-4.20.0.dist-info → arize_phoenix-4.20.2.dist-info}/RECORD +45 -43
  3. phoenix/db/bulk_inserter.py +24 -98
  4. phoenix/db/insertion/document_annotation.py +13 -0
  5. phoenix/db/insertion/span.py +9 -0
  6. phoenix/db/insertion/span_annotation.py +13 -0
  7. phoenix/db/insertion/trace_annotation.py +13 -0
  8. phoenix/db/insertion/types.py +34 -28
  9. phoenix/db/migrations/versions/10460e46d750_datasets.py +28 -2
  10. phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +134 -0
  11. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +28 -2
  12. phoenix/db/models.py +9 -1
  13. phoenix/server/api/context.py +8 -6
  14. phoenix/server/api/dataloaders/__init__.py +0 -47
  15. phoenix/server/api/dataloaders/token_counts.py +2 -7
  16. phoenix/server/api/input_types/SpanSort.py +3 -8
  17. phoenix/server/api/mutations/dataset_mutations.py +9 -3
  18. phoenix/server/api/mutations/experiment_mutations.py +2 -0
  19. phoenix/server/api/mutations/project_mutations.py +5 -5
  20. phoenix/server/api/mutations/span_annotations_mutations.py +10 -2
  21. phoenix/server/api/mutations/trace_annotations_mutations.py +10 -2
  22. phoenix/server/api/queries.py +9 -0
  23. phoenix/server/api/routers/v1/datasets.py +2 -0
  24. phoenix/server/api/routers/v1/experiment_evaluations.py +2 -0
  25. phoenix/server/api/routers/v1/experiment_runs.py +2 -0
  26. phoenix/server/api/routers/v1/experiments.py +2 -0
  27. phoenix/server/api/routers/v1/spans.py +12 -8
  28. phoenix/server/api/routers/v1/traces.py +12 -10
  29. phoenix/server/api/types/Dataset.py +6 -1
  30. phoenix/server/api/types/Experiment.py +6 -1
  31. phoenix/server/api/types/Project.py +4 -1
  32. phoenix/server/api/types/Span.py +5 -17
  33. phoenix/server/app.py +25 -8
  34. phoenix/server/dml_event.py +136 -0
  35. phoenix/server/dml_event_handler.py +272 -0
  36. phoenix/server/static/.vite/manifest.json +14 -14
  37. phoenix/server/static/assets/{components-CAummAJx.js → components-BSw2e1Zr.js} +108 -100
  38. phoenix/server/static/assets/{index-Cg5hdf3g.js → index-BYUFcdtx.js} +1 -1
  39. phoenix/server/static/assets/{pages-BU__X1UX.js → pages-p_fuED5k.js} +251 -237
  40. phoenix/server/static/assets/{vendor-arizeai-CkyzG9Wl.js → vendor-arizeai-CIETbKDq.js} +28 -28
  41. phoenix/server/types.py +106 -1
  42. phoenix/version.py +1 -1
  43. phoenix/db/migrations/types.py +0 -29
  44. {arize_phoenix-4.20.0.dist-info → arize_phoenix-4.20.2.dist-info}/WHEEL +0 -0
  45. {arize_phoenix-4.20.0.dist-info → arize_phoenix-4.20.2.dist-info}/licenses/IP_NOTICE +0 -0
  46. {arize_phoenix-4.20.0.dist-info → arize_phoenix-4.20.2.dist-info}/licenses/LICENSE +0 -0
phoenix/db/insertion/types.py CHANGED
@@ -21,11 +21,12 @@ from typing import (
 )
 
 from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.sql.dml import ReturningInsert
+from sqlalchemy.sql.dml import Insert
 
 from phoenix.db import models
 from phoenix.db.insertion.constants import DEFAULT_RETRY_ALLOWANCE, DEFAULT_RETRY_DELAY_SEC
-from phoenix.db.insertion.helpers import as_kv, insert_on_conflict
+from phoenix.db.insertion.helpers import insert_on_conflict
+from phoenix.server.dml_event import DmlEvent
 from phoenix.server.types import DbSessionFactory
 
 logger = logging.getLogger("__name__")
@@ -40,6 +41,7 @@ _AnyT = TypeVar("_AnyT")
 _PrecursorT = TypeVar("_PrecursorT")
 _InsertableT = TypeVar("_InsertableT", bound=Insertable)
 _RowT = TypeVar("_RowT", bound=models.Base)
+_DmlEventT = TypeVar("_DmlEventT", bound=DmlEvent)
 
 
 @dataclass(frozen=True)
@@ -56,7 +58,7 @@ class Postponed(Received[_AnyT]):
     retries_left: int = field(default=DEFAULT_RETRY_ALLOWANCE)
 
 
-class QueueInserter(ABC, Generic[_PrecursorT, _InsertableT, _RowT]):
+class QueueInserter(ABC, Generic[_PrecursorT, _InsertableT, _RowT, _DmlEventT]):
     table: Type[_RowT]
     unique_by: Sequence[str]
 
@@ -97,59 +99,63 @@ class QueueInserter(ABC, Generic[_PrecursorT, _InsertableT, _RowT]):
         List[Received[_PrecursorT]],
     ]: ...
 
-    async def insert(self) -> Tuple[Type[_RowT], List[int]]:
+    async def insert(self) -> Optional[List[_DmlEventT]]:
         if not self._queue:
-            return self.table, []
-        parcels = self._queue
-        self._queue = []
-        inserted_ids: List[int] = []
+            return None
+        self._queue, parcels = [], self._queue
+        events: List[_DmlEventT] = []
         async with self._db() as session:
             to_insert, to_postpone, _ = await self._partition(session, *parcels)
             if to_insert:
-                inserted_ids, to_retry, _ = await self._insert(session, *to_insert)
-                to_postpone.extend(to_retry)
+                events, to_retry, _ = await self._insert(session, *to_insert)
+                if to_retry:
+                    to_postpone.extend(to_retry)
         if to_postpone:
             loop = asyncio.get_running_loop()
             loop.call_later(self._retry_delay_sec, self._queue.extend, to_postpone)
-        return self.table, inserted_ids
+        return events
 
-    def _stmt(self, *records: Mapping[str, Any]) -> ReturningInsert[Tuple[int]]:
-        pk = next(c for c in self.table.__table__.c if c.primary_key)
+    def _insert_on_conflict(self, *records: Mapping[str, Any]) -> Insert:
         return insert_on_conflict(
             *records,
             table=self.table,
             unique_by=self.unique_by,
             dialect=self._db.dialect,
-        ).returning(pk)
+        )
+
+    @abstractmethod
+    async def _events(
+        self,
+        session: AsyncSession,
+        *insertions: _InsertableT,
+    ) -> List[_DmlEventT]: ...
 
     async def _insert(
        self,
        session: AsyncSession,
-        *insertions: Received[_InsertableT],
-    ) -> Tuple[List[int], List[Postponed[_PrecursorT]], List[Received[_InsertableT]]]:
-        records = [dict(as_kv(ins.item.row)) for ins in insertions]
-        inserted_ids: List[int] = []
+        *parcels: Received[_InsertableT],
+    ) -> Tuple[
+        List[_DmlEventT],
+        List[Postponed[_PrecursorT]],
+        List[Received[_InsertableT]],
+    ]:
        to_retry: List[Postponed[_PrecursorT]] = []
        failures: List[Received[_InsertableT]] = []
-        stmt = self._stmt(*records)
+        events: List[_DmlEventT] = []
        try:
            async with session.begin_nested():
-                ids = [id_ async for id_ in await session.stream_scalars(stmt)]
-                inserted_ids.extend(ids)
+                events.extend(await self._events(session, *(p.item for p in parcels)))
        except BaseException:
            logger.exception(
                f"Failed to bulk insert for {self.table.__name__}. "
-                f"Will try to insert ({len(records)} records) individually instead."
+                f"Will try to insert ({len(parcels)} records) individually instead."
            )
-            for i, record in enumerate(records):
-                stmt = self._stmt(record)
+            for p in parcels:
                try:
                    async with session.begin_nested():
-                        ids = [id_ async for id_ in await session.stream_scalars(stmt)]
-                        inserted_ids.extend(ids)
+                        events.extend(await self._events(session, p.item))
                except BaseException:
                    logger.exception(f"Failed to insert for {self.table.__name__}.")
-                    p = insertions[i]
                    if isinstance(p, Postponed) and p.retries_left == 1:
                        failures.append(p)
                    else:
@@ -162,7 +168,7 @@ class QueueInserter(ABC, Generic[_PrecursorT, _InsertableT, _RowT]):
                            else self._retry_allowance,
                        )
                    )
-        return inserted_ids, to_retry, failures
+        return events, to_retry, failures
 
 
 class Precursors(ABC):
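
The new abstract `_events` hook means each concrete inserter now returns typed DML events instead of raw row ids. A minimal sketch of what a subclass might look like, assuming hypothetical `Precursors`/`Insertables` member names, an assumed `SpanAnnotationDmlEvent` event type, and an assumed unique constraint (the real implementations are the small additions to the phoenix/db/insertion/*.py files listed above):

```python
from typing import List

from sqlalchemy.ext.asyncio import AsyncSession

from phoenix.db import models
from phoenix.db.insertion.helpers import as_kv
from phoenix.db.insertion.types import Insertables, Precursors, QueueInserter
from phoenix.server.dml_event import SpanAnnotationDmlEvent  # assumed name


class SpanAnnotationQueueInserter(
    QueueInserter[
        Precursors.SpanAnnotation,    # assumed member name
        Insertables.SpanAnnotation,   # assumed member name
        models.SpanAnnotation,
        SpanAnnotationDmlEvent,
    ]
):
    table = models.SpanAnnotation
    unique_by = ("name", "span_rowid")  # assumed unique constraint

    async def _events(
        self,
        session: AsyncSession,
        *insertions: "Insertables.SpanAnnotation",
    ) -> List[SpanAnnotationDmlEvent]:
        # Build the ON CONFLICT statement via the base class helper, ask for
        # the primary keys back, and wrap them in a single DML event.
        records = [dict(as_kv(ins.row)) for ins in insertions]
        stmt = self._insert_on_conflict(*records).returning(self.table.id)
        ids = (await session.scalars(stmt)).all()
        return [SpanAnnotationDmlEvent(tuple(ids))]
```
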
phoenix/db/migrations/versions/10460e46d750_datasets.py CHANGED
@@ -6,11 +6,37 @@ Create Date: 2024-05-10 11:24:23.985834
 
 """
 
-from typing import Sequence, Union
+from typing import Any, Sequence, Union
 
 import sqlalchemy as sa
 from alembic import op
-from phoenix.db.migrations.types import JSON_
+from sqlalchemy import JSON
+from sqlalchemy.dialects import postgresql
+from sqlalchemy.ext.compiler import compiles
+
+
+class JSONB(JSON):
+    # See https://docs.sqlalchemy.org/en/20/core/custom_types.html
+    __visit_name__ = "JSONB"
+
+
+@compiles(JSONB, "sqlite")  # type: ignore
+def _(*args: Any, **kwargs: Any) -> str:
+    # See https://docs.sqlalchemy.org/en/20/core/custom_types.html
+    return "JSONB"
+
+
+JSON_ = (
+    JSON()
+    .with_variant(
+        postgresql.JSONB(),  # type: ignore
+        "postgresql",
+    )
+    .with_variant(
+        JSONB(),
+        "sqlite",
+    )
+)
 
 # revision identifiers, used by Alembic.
 revision: str = "10460e46d750"
phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py ADDED
@@ -0,0 +1,134 @@
+"""add token columns to spans table
+
+Revision ID: 3be8647b87d8
+Revises: 10460e46d750
+Create Date: 2024-08-03 22:11:28.733133
+
+"""
+
+from typing import Any, Dict, List, Optional, Sequence, TypedDict, Union
+
+import sqlalchemy as sa
+from alembic import op
+from openinference.semconv.trace import SpanAttributes
+from sqlalchemy import (
+    JSON,
+    Dialect,
+    MetaData,
+    TypeDecorator,
+    update,
+)
+from sqlalchemy.dialects import postgresql
+from sqlalchemy.ext.asyncio.engine import AsyncConnection
+from sqlalchemy.ext.compiler import compiles
+from sqlalchemy.orm import (
+    DeclarativeBase,
+    Mapped,
+    mapped_column,
+)
+
+
+class JSONB(JSON):
+    # See https://docs.sqlalchemy.org/en/20/core/custom_types.html
+    __visit_name__ = "JSONB"
+
+
+@compiles(JSONB, "sqlite")  # type: ignore
+def _(*args: Any, **kwargs: Any) -> str:
+    # See https://docs.sqlalchemy.org/en/20/core/custom_types.html
+    return "JSONB"
+
+
+JSON_ = (
+    JSON()
+    .with_variant(
+        postgresql.JSONB(),  # type: ignore
+        "postgresql",
+    )
+    .with_variant(
+        JSONB(),
+        "sqlite",
+    )
+)
+
+
+class JsonDict(TypeDecorator[Dict[str, Any]]):
+    # See https://docs.sqlalchemy.org/en/20/core/custom_types.html
+    cache_ok = True
+    impl = JSON_
+
+    def process_bind_param(self, value: Optional[Dict[str, Any]], _: Dialect) -> Dict[str, Any]:
+        return value if isinstance(value, dict) else {}
+
+
+class JsonList(TypeDecorator[List[Any]]):
+    # See https://docs.sqlalchemy.org/en/20/core/custom_types.html
+    cache_ok = True
+    impl = JSON_
+
+    def process_bind_param(self, value: Optional[List[Any]], _: Dialect) -> List[Any]:
+        return value if isinstance(value, list) else []
+
+
+class ExperimentRunOutput(TypedDict, total=False):
+    task_output: Any
+
+
+class Base(DeclarativeBase):
+    # Enforce best practices for naming constraints
+    # https://alembic.sqlalchemy.org/en/latest/naming.html#integration-of-naming-conventions-into-operations-autogenerate
+    metadata = MetaData(
+        naming_convention={
+            "ix": "ix_%(table_name)s_%(column_0_N_name)s",
+            "uq": "uq_%(table_name)s_%(column_0_N_name)s",
+            "ck": "ck_%(table_name)s_`%(constraint_name)s`",
+            "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
+            "pk": "pk_%(table_name)s",
+        }
+    )
+    type_annotation_map = {
+        Dict[str, Any]: JsonDict,
+        List[Dict[str, Any]]: JsonList,
+        ExperimentRunOutput: JsonDict,
+    }
+
+
+class Span(Base):
+    __tablename__ = "spans"
+    id: Mapped[int] = mapped_column(primary_key=True)
+    attributes: Mapped[Dict[str, Any]]
+    llm_token_count_prompt: Mapped[Optional[int]]
+    llm_token_count_completion: Mapped[Optional[int]]
+
+
+# revision identifiers, used by Alembic.
+revision: str = "3be8647b87d8"
+down_revision: Union[str, None] = "10460e46d750"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT.split(".")
+LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION.split(".")
+
+
+async def get_token_counts_from_attributes(connection: AsyncConnection) -> None:
+    """
+    Gets token counts from attributes if present.
+    """
+    await connection.execute(
+        update(Span).values(
+            llm_token_count_prompt=Span.attributes[LLM_TOKEN_COUNT_PROMPT].as_float(),
+            llm_token_count_completion=Span.attributes[LLM_TOKEN_COUNT_COMPLETION].as_float(),
+        )
+    )
+
+
+def upgrade() -> None:
+    op.add_column("spans", sa.Column("llm_token_count_prompt", sa.Integer, nullable=True))
+    op.add_column("spans", sa.Column("llm_token_count_completion", sa.Integer, nullable=True))
+    op.run_async(get_token_counts_from_attributes)
+
+
+def downgrade() -> None:
+    op.drop_column("spans", "llm_token_count_completion")
+    op.drop_column("spans", "llm_token_count_prompt")
phoenix/db/migrations/versions/cf03bd6bae1d_init.py CHANGED
@@ -6,11 +6,37 @@ Create Date: 2024-04-03 19:41:48.871555
 
 """
 
-from typing import Sequence, Union
+from typing import Any, Sequence, Union
 
 import sqlalchemy as sa
 from alembic import op
-from phoenix.db.migrations.types import JSON_
+from sqlalchemy import JSON
+from sqlalchemy.dialects import postgresql
+from sqlalchemy.ext.compiler import compiles
+
+
+class JSONB(JSON):
+    # See https://docs.sqlalchemy.org/en/20/core/custom_types.html
+    __visit_name__ = "JSONB"
+
+
+@compiles(JSONB, "sqlite")  # type: ignore
+def _(*args: Any, **kwargs: Any) -> str:
+    # See https://docs.sqlalchemy.org/en/20/core/custom_types.html
+    return "JSONB"
+
+
+JSON_ = (
+    JSON()
+    .with_variant(
+        postgresql.JSONB(),  # type: ignore
+        "postgresql",
+    )
+    .with_variant(
+        JSONB(),
+        "sqlite",
+    )
+)
 
 # revision identifiers, used by Alembic.
 revision: str = "cf03bd6bae1d"
phoenix/db/models.py CHANGED
@@ -197,7 +197,7 @@ class Span(Base):
         ForeignKey("traces.id", ondelete="CASCADE"),
         index=True,
     )
-    span_id: Mapped[str]
+    span_id: Mapped[str] = mapped_column(index=True)
     parent_id: Mapped[Optional[str]] = mapped_column(index=True)
     name: Mapped[str]
     span_kind: Mapped[str]
@@ -214,6 +214,8 @@
     cumulative_error_count: Mapped[int]
     cumulative_llm_token_count_prompt: Mapped[int]
     cumulative_llm_token_count_completion: Mapped[int]
+    llm_token_count_prompt: Mapped[Optional[int]]
+    llm_token_count_completion: Mapped[Optional[int]]
 
     @hybrid_property
     def latency_ms(self) -> float:
@@ -230,6 +232,12 @@
     def cumulative_llm_token_count_total(self) -> int:
         return self.cumulative_llm_token_count_prompt + self.cumulative_llm_token_count_completion
 
+    @hybrid_property
+    def llm_token_count_total(self) -> Optional[int]:
+        if self.llm_token_count_prompt is None and self.llm_token_count_completion is None:
+            return None
+        return (self.llm_token_count_prompt or 0) + (self.llm_token_count_completion or 0)
+
     trace: Mapped["Trace"] = relationship("Trace", back_populates="spans")
     document_annotations: Mapped[List["DocumentAnnotation"]] = relationship(back_populates="span")
     dataset_examples: Mapped[List["DatasetExample"]] = relationship(back_populates="span")
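
The new `llm_token_count_total` hybrid treats a missing prompt or completion count as zero and only returns None when both are missing. A minimal sketch of the Python-side semantics, using transient `models.Span` instances (the SQL-side expression is whatever SQLAlchemy derives for the sort path in SpanSort.py below):

```python
# Instance-level behavior of the hybrid property.
span = models.Span(llm_token_count_prompt=7, llm_token_count_completion=None)
assert span.llm_token_count_total == 7  # a missing side counts as 0

empty = models.Span()
assert empty.llm_token_count_total is None  # both missing -> no total
```
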
phoenix/server/api/context.py CHANGED
@@ -1,10 +1,8 @@
 from dataclasses import dataclass
-from datetime import datetime
 from pathlib import Path
-from typing import Callable, Optional
+from typing import Any, Optional
 
 from strawberry.fastapi import BaseContext
-from typing_extensions import TypeAlias
 
 from phoenix.core.model_schema import Model
 from phoenix.server.api.dataloaders import (
@@ -34,7 +32,8 @@ from phoenix.server.api.dataloaders import (
     TraceEvaluationsDataLoader,
     TraceRowIdsDataLoader,
 )
-from phoenix.server.types import DbSessionFactory
+from phoenix.server.dml_event import DmlEvent
+from phoenix.server.types import CanGetLastUpdatedAt, CanPutItem, DbSessionFactory
 
 
 @dataclass
@@ -65,7 +64,9 @@ class DataLoaders:
     project_by_name: ProjectByNameDataLoader
 
 
-ProjectRowId: TypeAlias = int
+class _NoOp:
+    def get(self, *args: Any, **kwargs: Any) -> Any: ...
+    def put(self, *args: Any, **kwargs: Any) -> Any: ...
 
 
 @dataclass
@@ -75,6 +76,7 @@ class Context(BaseContext):
     cache_for_dataloaders: Optional[CacheForDataLoaders]
     model: Model
     export_path: Path
+    last_updated_at: CanGetLastUpdatedAt = _NoOp()
+    event_queue: CanPutItem[DmlEvent] = _NoOp()
     corpus: Optional[Model] = None
-    streaming_last_updated_at: Callable[[ProjectRowId], Optional[datetime]] = lambda _: None
     read_only: bool = False
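
`_NoOp` structurally satisfies both protocols, so a bare `Context` still works when no event queue or last-updated tracker is wired in. A sketch of the assumed shape of the two protocols imported from phoenix.server.types (their actual definitions are among the +106 lines added to types.py, file 41; the exact `get` parameters here are assumptions):

```python
from datetime import datetime
from typing import Any, Optional, Protocol, TypeVar

_ItemT_contra = TypeVar("_ItemT_contra", contravariant=True)


class CanGetLastUpdatedAt(Protocol):
    # Assumed signature: look up the last-updated timestamp for a table/row.
    def get(self, table: Any, id_: Optional[int] = None) -> Optional[datetime]: ...


class CanPutItem(Protocol[_ItemT_contra]):
    # Assumed signature: enqueue an item (here, a DmlEvent) for processing.
    def put(self, item: _ItemT_contra) -> None: ...
```
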
phoenix/server/api/dataloaders/__init__.py CHANGED
@@ -1,12 +1,4 @@
 from dataclasses import dataclass, field
-from functools import singledispatchmethod
-
-from phoenix.db.insertion.evaluation import (
-    DocumentEvaluationInsertionEvent,
-    SpanEvaluationInsertionEvent,
-    TraceEvaluationInsertionEvent,
-)
-from phoenix.db.insertion.span import ClearProjectSpansEvent, SpanInsertionEvent
 
 from .annotation_summaries import AnnotationSummaryCache, AnnotationSummaryDataLoader
 from .average_experiment_run_latency import AverageExperimentRunLatencyDataLoader
@@ -88,42 +80,3 @@ class CacheForDataLoaders:
     token_count: TokenCountCache = field(
         default_factory=TokenCountCache,
     )
-
-    def _update_spans(self, project_rowid: int) -> None:
-        self.latency_ms_quantile.invalidate(project_rowid)
-        self.token_count.invalidate(project_rowid)
-        self.record_count.invalidate(project_rowid)
-        self.min_start_or_max_end_time.invalidate(project_rowid)
-
-    def _clear_spans(self, project_rowid: int) -> None:
-        self._update_spans(project_rowid)
-        self.annotation_summary.invalidate_project(project_rowid)
-        self.evaluation_summary.invalidate_project(project_rowid)
-        self.document_evaluation_summary.invalidate_project(project_rowid)
-
-    @singledispatchmethod
-    def invalidate(self, event: SpanInsertionEvent) -> None:
-        project_rowid, *_ = event
-        self._update_spans(project_rowid)
-
-    @invalidate.register
-    def _(self, event: ClearProjectSpansEvent) -> None:
-        project_rowid, *_ = event
-        self._clear_spans(project_rowid)
-
-    @invalidate.register
-    def _(self, event: DocumentEvaluationInsertionEvent) -> None:
-        project_rowid, evaluation_name = event
-        self.document_evaluation_summary.invalidate((project_rowid, evaluation_name))
-
-    @invalidate.register
-    def _(self, event: SpanEvaluationInsertionEvent) -> None:
-        project_rowid, evaluation_name = event
-        self.annotation_summary.invalidate((project_rowid, evaluation_name, "span"))
-        self.evaluation_summary.invalidate((project_rowid, evaluation_name, "span"))
-
-    @invalidate.register
-    def _(self, event: TraceEvaluationInsertionEvent) -> None:
-        project_rowid, evaluation_name = event
-        self.annotation_summary.invalidate((project_rowid, evaluation_name, "trace"))
-        self.evaluation_summary.invalidate((project_rowid, evaluation_name, "trace"))
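
The invalidation logic removed here moves into the new phoenix/server/dml_event_handler.py (+272 lines, file 35), which is not shown in this diff. A hedged sketch of how a single-dispatch handler consuming `DmlEvent`s might reproduce the removed behavior — the handler structure and the `ids` field name are assumptions, not the actual implementation:

```python
from functools import singledispatch

from phoenix.server.api.dataloaders import CacheForDataLoaders
from phoenix.server.dml_event import DmlEvent, SpanDeleteEvent


@singledispatch
def handle(event: DmlEvent, cache: CacheForDataLoaders) -> None:
    ...  # default: nothing to invalidate


@handle.register
def _(event: SpanDeleteEvent, cache: CacheForDataLoaders) -> None:
    # SpanDeleteEvent is constructed with the affected project rowids
    # (see project_mutations.py below), so invalidate per-project caches.
    for project_rowid in event.ids:  # the `ids` field name is assumed
        cache.token_count.invalidate(project_rowid)
        cache.record_count.invalidate(project_rowid)
        cache.latency_ms_quantile.invalidate(project_rowid)
        cache.min_start_or_max_end_time.invalidate(project_rowid)
```
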
phoenix/server/api/dataloaders/token_counts.py CHANGED
@@ -10,7 +10,6 @@ from typing import (
 )
 
 from cachetools import LFUCache, TTLCache
-from openinference.semconv.trace import SpanAttributes
 from sqlalchemy import Select, func, select
 from sqlalchemy.sql.functions import coalesce
 from strawberry.dataloader import AbstractCache, DataLoader
@@ -107,8 +106,8 @@ def _get_stmt(
     *params: Param,
 ) -> Select[Any]:
     (start_time, end_time), filter_condition = segment
-    prompt = func.sum(models.Span.attributes[_LLM_TOKEN_COUNT_PROMPT].as_float())
-    completion = func.sum(models.Span.attributes[_LLM_TOKEN_COUNT_COMPLETION].as_float())
+    prompt = func.sum(models.Span.llm_token_count_prompt)
+    completion = func.sum(models.Span.llm_token_count_completion)
     total = coalesce(prompt, 0) + coalesce(completion, 0)
     pid = models.Trace.project_rowid
     stmt: Select[Any] = (
@@ -130,7 +129,3 @@ def _get_stmt(
     stmt = sf(stmt)
     stmt = stmt.where(pid.in_([rowid for rowid, _ in params]))
     return stmt
-
-
-_LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT.split(".")
-_LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION.split(".")
phoenix/server/api/input_types/SpanSort.py CHANGED
@@ -3,7 +3,6 @@ from enum import Enum, auto
 from typing import Any, Optional, Protocol
 
 import strawberry
-from openinference.semconv.trace import SpanAttributes
 from sqlalchemy import and_, desc, nulls_last
 from sqlalchemy.orm import InstrumentedAttribute
 from sqlalchemy.sql.expression import Select
@@ -16,10 +15,6 @@ from phoenix.server.api.types.pagination import CursorSortColumnDataType
 from phoenix.server.api.types.SortDir import SortDir
 from phoenix.trace.schemas import SpanID
 
-LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT.split(".")
-LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION.split(".")
-LLM_TOKEN_COUNT_TOTAL = SpanAttributes.LLM_TOKEN_COUNT_TOTAL.split(".")
-
 
 @strawberry.enum
 class SpanColumn(Enum):
@@ -47,11 +42,11 @@ class SpanColumn(Enum):
         elif self is SpanColumn.latencyMs:
             expr = models.Span.latency_ms
         elif self is SpanColumn.tokenCountTotal:
-            expr = models.Span.attributes[LLM_TOKEN_COUNT_TOTAL].as_float()
+            expr = models.Span.llm_token_count_total
         elif self is SpanColumn.tokenCountPrompt:
-            expr = models.Span.attributes[LLM_TOKEN_COUNT_PROMPT].as_float()
+            expr = models.Span.llm_token_count_prompt
         elif self is SpanColumn.tokenCountCompletion:
-            expr = models.Span.attributes[LLM_TOKEN_COUNT_COMPLETION].as_float()
+            expr = models.Span.llm_token_count_completion
         elif self is SpanColumn.cumulativeTokenCountTotal:
             expr = (
                 models.Span.cumulative_llm_token_count_prompt
phoenix/server/api/mutations/dataset_mutations.py CHANGED
@@ -33,6 +33,7 @@ from phoenix.server.api.types.DatasetExample import DatasetExample
 from phoenix.server.api.types.node import from_global_id_with_expected_type
 from phoenix.server.api.types.Span import Span
 from phoenix.server.api.utils import delete_projects, delete_traces
+from phoenix.server.dml_event import DatasetDeleteEvent, DatasetInsertEvent
 
 
 @strawberry.type
@@ -62,6 +63,7 @@ class DatasetMutationMixin:
             .returning(models.Dataset)
         )
         assert dataset is not None
+        info.context.event_queue.put(DatasetInsertEvent((dataset.id,)))
         return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
 
     @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
@@ -90,6 +92,7 @@ class DatasetMutationMixin:
             .values(**patch)
         )
         assert dataset is not None
+        info.context.event_queue.put(DatasetInsertEvent((dataset.id,)))
         return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
 
     @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
@@ -218,6 +221,7 @@ class DatasetMutationMixin:
                 for dataset_example_rowid, span in zip(dataset_example_rowids, spans)
             ],
         )
+        info.context.event_queue.put(DatasetInsertEvent((dataset.id,)))
         return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
 
     @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
@@ -303,6 +307,7 @@ class DatasetMutationMixin:
                 )
             ],
         )
+        info.context.event_queue.put(DatasetInsertEvent((dataset.id,)))
         return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
 
     @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
@@ -333,6 +338,7 @@ class DatasetMutationMixin:
             delete_traces(info.context.db, *eval_trace_ids),
             return_exceptions=True,
         )
+        info.context.event_queue.put(DatasetDeleteEvent((dataset.id,)))
         return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
 
     @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
@@ -424,7 +430,7 @@ class DatasetMutationMixin:
                 for revision, patch, example_id in zip(revisions, patches, example_ids)
             ],
         )
-
+        info.context.event_queue.put(DatasetInsertEvent((dataset.id,)))
         return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
 
     @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
@@ -507,8 +513,8 @@ class DatasetMutationMixin:
                 for dataset_example_rowid in example_db_ids
             ],
         )
-
-        return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
+        info.context.event_queue.put(DatasetInsertEvent((dataset.id,)))
+        return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
 
 
 def _span_attribute(semconv: str) -> Any:
phoenix/server/api/mutations/experiment_mutations.py CHANGED
@@ -14,6 +14,7 @@ from phoenix.server.api.mutations.auth import IsAuthenticated
 from phoenix.server.api.types.Experiment import Experiment, to_gql_experiment
 from phoenix.server.api.types.node import from_global_id_with_expected_type
 from phoenix.server.api.utils import delete_projects, delete_traces
+from phoenix.server.dml_event import ExperimentDeleteEvent
 
 
 @strawberry.type
@@ -66,6 +67,7 @@ class ExperimentMutationMixin:
             delete_traces(info.context.db, *eval_trace_ids),
             return_exceptions=True,
         )
+        info.context.event_queue.put(ExperimentDeleteEvent(tuple(experiments.keys())))
         return ExperimentMutationPayload(
             experiments=[
                 to_gql_experiment(experiments[experiment_id]) for experiment_id in experiment_ids
phoenix/server/api/mutations/project_mutations.py CHANGED
@@ -6,23 +6,23 @@ from strawberry.types import Info
 
 from phoenix.config import DEFAULT_PROJECT_NAME
 from phoenix.db import models
-from phoenix.db.insertion.span import ClearProjectSpansEvent
 from phoenix.server.api.context import Context
 from phoenix.server.api.input_types.ClearProjectInput import ClearProjectInput
 from phoenix.server.api.mutations.auth import IsAuthenticated
 from phoenix.server.api.queries import Query
 from phoenix.server.api.types.node import from_global_id_with_expected_type
+from phoenix.server.dml_event import ProjectDeleteEvent, SpanDeleteEvent
 
 
 @strawberry.type
 class ProjectMutationMixin:
     @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
     async def delete_project(self, info: Info[Context, None], id: GlobalID) -> Query:
-        node_id = from_global_id_with_expected_type(global_id=id, expected_type_name="Project")
+        project_id = from_global_id_with_expected_type(global_id=id, expected_type_name="Project")
         async with info.context.db() as session:
             project = await session.scalar(
                 select(models.Project)
-                .where(models.Project.id == node_id)
+                .where(models.Project.id == project_id)
                 .options(load_only(models.Project.name))
             )
             if project is None:
@@ -30,6 +30,7 @@ class ProjectMutationMixin:
             if project.name == DEFAULT_PROJECT_NAME:
                 raise ValueError(f"Cannot delete the {DEFAULT_PROJECT_NAME} project")
             await session.delete(project)
+        info.context.event_queue.put(ProjectDeleteEvent((project_id,)))
         return Query()
 
     @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
@@ -42,6 +43,5 @@ class ProjectMutationMixin:
         delete_statement = delete_statement.where(models.Trace.start_time < input.end_time)
         async with info.context.db() as session:
             await session.execute(delete_statement)
-        if cache := info.context.cache_for_dataloaders:
-            cache.invalidate(ClearProjectSpansEvent(project_rowid=project_id))
+        info.context.event_queue.put(SpanDeleteEvent((project_id,)))
         return Query()
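
Every call site in this diff constructs an event with a single tuple of affected row ids. A sketch of the assumed shape of the events defined in the new phoenix/server/dml_event.py (+136 lines, file 34) — the field name and class hierarchy are assumptions consistent with the call sites above:

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class DmlEvent:
    # Affected row ids; every mutation above passes exactly one tuple.
    ids: Tuple[int, ...] = ()


class ProjectDeleteEvent(DmlEvent): ...


class SpanDeleteEvent(DmlEvent): ...


# Usage mirroring the mutations above:
# info.context.event_queue.put(ProjectDeleteEvent((project_id,)))
```
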