arize-phoenix 6.1.0__py3-none-any.whl → 7.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic; see the release details below for more information.

Files changed (54)
  1. {arize_phoenix-6.1.0.dist-info → arize_phoenix-7.0.0.dist-info}/METADATA +9 -8
  2. {arize_phoenix-6.1.0.dist-info → arize_phoenix-7.0.0.dist-info}/RECORD +52 -38
  3. phoenix/config.py +4 -1
  4. phoenix/db/engines.py +1 -1
  5. phoenix/db/insertion/span.py +65 -30
  6. phoenix/db/migrate.py +4 -1
  7. phoenix/db/migrations/data_migration_scripts/__init__.py +0 -0
  8. phoenix/db/migrations/data_migration_scripts/populate_project_sessions.py +199 -0
  9. phoenix/db/migrations/versions/4ded9e43755f_create_project_sessions_table.py +66 -0
  10. phoenix/db/models.py +27 -0
  11. phoenix/metrics/wrappers.py +7 -1
  12. phoenix/server/api/context.py +15 -2
  13. phoenix/server/api/dataloaders/__init__.py +14 -2
  14. phoenix/server/api/dataloaders/session_io.py +75 -0
  15. phoenix/server/api/dataloaders/session_num_traces.py +30 -0
  16. phoenix/server/api/dataloaders/session_num_traces_with_error.py +32 -0
  17. phoenix/server/api/dataloaders/session_token_usages.py +41 -0
  18. phoenix/server/api/dataloaders/session_trace_latency_ms_quantile.py +55 -0
  19. phoenix/server/api/dataloaders/trace_by_trace_ids.py +25 -0
  20. phoenix/server/api/dataloaders/trace_root_spans.py +32 -0
  21. phoenix/server/api/input_types/ProjectSessionSort.py +29 -0
  22. phoenix/server/api/mutations/chat_mutations.py +5 -0
  23. phoenix/server/api/mutations/project_mutations.py +12 -2
  24. phoenix/server/api/queries.py +14 -9
  25. phoenix/server/api/subscriptions.py +6 -0
  26. phoenix/server/api/types/EmbeddingDimension.py +1 -1
  27. phoenix/server/api/types/ExperimentRun.py +3 -4
  28. phoenix/server/api/types/ExperimentRunAnnotation.py +3 -4
  29. phoenix/server/api/types/Project.py +150 -12
  30. phoenix/server/api/types/ProjectSession.py +139 -0
  31. phoenix/server/api/types/Span.py +6 -19
  32. phoenix/server/api/types/SpanIOValue.py +15 -0
  33. phoenix/server/api/types/TokenUsage.py +11 -0
  34. phoenix/server/api/types/Trace.py +59 -2
  35. phoenix/server/app.py +15 -2
  36. phoenix/server/static/.vite/manifest.json +40 -31
  37. phoenix/server/static/assets/{components-CdiZ1Osh.js → components-DKH6AzJw.js} +410 -351
  38. phoenix/server/static/assets/index-DLV87qiO.js +93 -0
  39. phoenix/server/static/assets/{pages-FArMEfgg.js → pages-CVY3Nv4Z.js} +638 -316
  40. phoenix/server/static/assets/vendor-Cb3zlNNd.js +894 -0
  41. phoenix/server/static/assets/{vendor-arizeai-BG6iwyLC.js → vendor-arizeai-Buo4e1A6.js} +2 -2
  42. phoenix/server/static/assets/{vendor-codemirror-BotnVFFX.js → vendor-codemirror-BuAQiUVf.js} +5 -5
  43. phoenix/server/static/assets/{vendor-recharts-Dy5gEFzQ.js → vendor-recharts-Cl9dK5tC.js} +1 -1
  44. phoenix/server/static/assets/{vendor-Bnv1dNRQ.js → vendor-shiki-CazYUixL.js} +5 -898
  45. phoenix/session/client.py +13 -4
  46. phoenix/trace/fixtures.py +8 -0
  47. phoenix/trace/schemas.py +16 -0
  48. phoenix/version.py +1 -1
  49. phoenix/server/api/dataloaders/trace_row_ids.py +0 -33
  50. phoenix/server/static/assets/index-D_sCOjlG.js +0 -101
  51. {arize_phoenix-6.1.0.dist-info → arize_phoenix-7.0.0.dist-info}/WHEEL +0 -0
  52. {arize_phoenix-6.1.0.dist-info → arize_phoenix-7.0.0.dist-info}/entry_points.txt +0 -0
  53. {arize_phoenix-6.1.0.dist-info → arize_phoenix-7.0.0.dist-info}/licenses/IP_NOTICE +0 -0
  54. {arize_phoenix-6.1.0.dist-info → arize_phoenix-7.0.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,199 @@
1
+ # /// script
2
+ # dependencies = [
3
+ # "arize-phoenix[pg]",
4
+ # ]
5
+ # ///
6
+ """
7
+ Populate the `project_sessions` table with data from the traces and spans tables.
8
+
9
+ Environment variables.
10
+
11
+ - `PHOENIX_SQL_DATABASE_URL` must be set to the database connection string.
12
+ - (optional) Postgresql schema can be set via `PHOENIX_SQL_DATABASE_SCHEMA`.
13
+ """
14
+
15
+ import os
16
+ from datetime import datetime
17
+ from time import perf_counter
18
+ from typing import Any, Optional, Union
19
+
20
+ import sqlean
21
+ from openinference.semconv.trace import SpanAttributes
22
+ from sqlalchemy import (
23
+ JSON,
24
+ Engine,
25
+ NullPool,
26
+ create_engine,
27
+ event,
28
+ func,
29
+ insert,
30
+ make_url,
31
+ select,
32
+ update,
33
+ )
34
+ from sqlalchemy.dialects import postgresql
35
+ from sqlalchemy.ext.compiler import compiles
36
+ from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker
37
+
38
+ from phoenix.config import ENV_PHOENIX_SQL_DATABASE_SCHEMA, get_env_database_connection_str
39
+ from phoenix.db.engines import set_postgresql_search_path
40
+
41
+
42
class JSONB(JSON):
    """JSON subtype whose DDL name is "JSONB", so SQLite columns are created as JSONB."""

    # See https://docs.sqlalchemy.org/en/20/core/custom_types.html
    __visit_name__ = "JSONB"
45
+
46
+
47
@compiles(JSONB, "sqlite")
def _(*args: Any, **kwargs: Any) -> str:
    """Render the custom JSONB type as the literal "JSONB" column type on SQLite."""
    # See https://docs.sqlalchemy.org/en/20/core/custom_types.html
    return "JSONB"
51
+
52
+
53
# Portable JSON column type: native JSONB on PostgreSQL, the custom JSONB
# subtype on SQLite, and plain JSON on any other dialect.
_base_json = JSON()
_with_pg = _base_json.with_variant(
    postgresql.JSONB(),  # type: ignore
    "postgresql",
)
JSON_ = _with_pg.with_variant(JSONB(), "sqlite")
64
+
65
+
66
class Base(DeclarativeBase):
    """Declarative base for the minimal table mappings used by this script."""

    ...
67
+
68
+
69
class ProjectSession(Base):
    """Minimal mapping of the `project_sessions` table (the insert target)."""

    __tablename__ = "project_sessions"
    id: Mapped[int] = mapped_column(primary_key=True)
    # Session identifier taken from root-span attributes.
    session_id: Mapped[str]
    project_id: Mapped[int]
    # Start of the earliest trace and end of the latest trace in the session.
    start_time: Mapped[datetime]
    end_time: Mapped[datetime]
76
+
77
+
78
class Trace(Base):
    """Minimal mapping of the `traces` table (the update target)."""

    __tablename__ = "traces"
    id: Mapped[int] = mapped_column(primary_key=True)
    # Nullable FK that this script backfills once sessions are inserted.
    project_session_rowid: Mapped[Union[int, None]]
    project_rowid: Mapped[int]
    start_time: Mapped[datetime]
    end_time: Mapped[datetime]
85
+
86
+
87
class Span(Base):
    """Minimal mapping of the `spans` table (read-only source of session ids)."""

    __tablename__ = "spans"
    id: Mapped[int] = mapped_column(primary_key=True)
    trace_rowid: Mapped[int]
    # A NULL parent_id marks a root span.
    parent_id: Mapped[Optional[str]]
    attributes: Mapped[dict[str, Any]] = mapped_column(JSON_, nullable=False)
93
+
94
+
95
# Dot-separated semconv attribute names split into path lists for JSON indexing.
SESSION_ID = SpanAttributes.SESSION_ID.split(".")
USER_ID = SpanAttributes.USER_ID.split(".")  # NOTE(review): appears unused in this script — confirm before removing
97
+
98
+
99
def populate_project_sessions(
    engine: Engine,
) -> None:
    """Insert `project_sessions` rows derived from root spans, then link traces.

    Args:
        engine: Engine connected to the Phoenix database to migrate.
    """
    # One row per root span carrying a non-empty session.id attribute.
    # rank == 1 selects the earliest trace of each session (ties broken by
    # trace id, then span id); end_time is the max trace end time over the
    # whole session, computed with a window function.
    sessions_from_span = (
        select(
            Span.attributes[SESSION_ID].as_string().label("session_id"),
            Trace.project_rowid.label("project_id"),
            Trace.start_time.label("start_time"),
            func.row_number()
            .over(
                partition_by=Span.attributes[SESSION_ID],
                order_by=[Trace.start_time, Trace.id, Span.id],
            )
            .label("rank"),
            func.max(Trace.end_time)
            .over(partition_by=Span.attributes[SESSION_ID])
            .label("end_time"),
        )
        .join_from(Span, Trace, Span.trace_rowid == Trace.id)
        .where(Span.parent_id.is_(None))
        .where(Span.attributes[SESSION_ID].as_string() != "")
        .subquery()
    )
    # Maps each trace (via its root span's session.id) to the id of the
    # `project_sessions` row inserted above.
    sessions_for_trace_id = (
        select(
            Span.trace_rowid,
            ProjectSession.id.label("project_session_rowid"),
        )
        .join_from(
            Span,
            ProjectSession,
            Span.attributes[SESSION_ID].as_string() == ProjectSession.session_id,
        )
        .where(Span.parent_id.is_(None))
        .where(Span.attributes[SESSION_ID].as_string() != "")
        .subquery()
    )
    start_time = perf_counter()
    # Single transaction: insert sessions first, then backfill the FK on traces.
    with sessionmaker(engine).begin() as session:
        session.execute(
            insert(ProjectSession).from_select(
                [
                    "session_id",
                    "project_id",
                    "start_time",
                    "end_time",
                ],
                select(
                    sessions_from_span.c.session_id,
                    sessions_from_span.c.project_id,
                    sessions_from_span.c.start_time,
                    sessions_from_span.c.end_time,
                ).where(sessions_from_span.c.rank == 1),
            )
        )
        session.execute(
            (
                update(Trace)
                .values(project_session_rowid=sessions_for_trace_id.c.project_session_rowid)
                .where(Trace.id == sessions_for_trace_id.c.trace_rowid)
            )
        )
    elapsed_time = perf_counter() - start_time
    print(f"✅ Populated project_sessions in {elapsed_time:.3f} seconds.")
163
+
164
+
165
if __name__ == "__main__":
    # Interactively confirm the database URL read from the environment
    # before running the data migration against it.
    sql_database_url = make_url(get_env_database_connection_str())
    print(f"Using database URL: {sql_database_url}")
    ans = input("Is that correct? [y]/n: ")
    if ans.lower().startswith("n"):
        url = input("Please enter the correct database URL: ")
        sql_database_url = make_url(url)
    backend = sql_database_url.get_backend_name()
    if backend == "sqlite":
        file = sql_database_url.database
        # Connect through sqlean, mirroring how Phoenix itself opens SQLite.
        engine = create_engine(
            url=sql_database_url.set(drivername="sqlite"),
            creator=lambda: sqlean.connect(f"file:///{file}", uri=True),
            poolclass=NullPool,
            echo=True,
        )
    elif backend == "postgresql":
        # Optional schema, confirmed interactively like the URL above.
        schema = os.getenv(ENV_PHOENIX_SQL_DATABASE_SCHEMA)
        if schema:
            print(f"Using schema: {schema}")
        else:
            print("No PostgreSQL schema set. (This is the default.)")
        ans = input("Is that correct? [y]/n: ")
        if ans.lower().startswith("n"):
            schema = input("Please enter the correct schema: ")
        engine = create_engine(
            url=sql_database_url.set(drivername="postgresql+psycopg"),
            poolclass=NullPool,
            echo=True,
        )
        if schema:
            # Apply the schema as search_path on every new connection.
            event.listen(engine, "connect", set_postgresql_search_path(schema))
    else:
        raise ValueError(f"Unknown database backend: {backend}")
    populate_project_sessions(engine)
@@ -0,0 +1,66 @@
1
+ """create project_session table
2
+
3
+ Revision ID: 4ded9e43755f
4
+ Revises: cd164e83824f
5
+ Create Date: 2024-10-08 22:53:24.539786
6
+
7
+ """
8
+
9
+ from typing import Sequence, Union
10
+
11
+ import sqlalchemy as sa
12
+ from alembic import op
13
+
14
+ # revision identifiers, used by Alembic.
15
+ revision: str = "4ded9e43755f"
16
+ down_revision: Union[str, None] = "cd164e83824f"
17
+ branch_labels: Union[str, Sequence[str], None] = None
18
+ depends_on: Union[str, Sequence[str], None] = None
19
+
20
+
21
def upgrade() -> None:
    """Create the `project_sessions` table and link `traces` rows to it."""
    op.create_table(
        "project_sessions",
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column("session_id", sa.String, unique=True, nullable=False),
        sa.Column(
            "project_id",
            sa.Integer,
            sa.ForeignKey("projects.id", ondelete="CASCADE"),
            nullable=False,
            index=True,
        ),
        sa.Column(
            "start_time",
            sa.TIMESTAMP(timezone=True),
            index=True,
            nullable=False,
        ),
        sa.Column(
            "end_time",
            sa.TIMESTAMP(timezone=True),
            index=True,
            nullable=False,
        ),
    )
    # batch_alter_table so the ADD COLUMN with FK also works on SQLite.
    with op.batch_alter_table("traces") as batch_op:
        batch_op.add_column(
            sa.Column(
                "project_session_rowid",
                sa.Integer,
                sa.ForeignKey("project_sessions.id", ondelete="CASCADE"),
                nullable=True,  # pre-existing traces have no session yet
            ),
        )
    op.create_index(
        "ix_traces_project_session_rowid",
        "traces",
        ["project_session_rowid"],
    )
60
+
61
+
62
def downgrade() -> None:
    """Reverse `upgrade`: drop the FK column and the `project_sessions` table."""
    # NOTE(review): drop_index is called without table_name; this works on the
    # SQLite/PostgreSQL backends Phoenix targets, but some dialects require
    # table_name — confirm if backend support ever widens.
    op.drop_index("ix_traces_project_session_rowid")
    with op.batch_alter_table("traces") as batch_op:
        batch_op.drop_column("project_session_rowid")
    op.drop_table("project_sessions")
phoenix/db/models.py CHANGED
@@ -156,14 +156,37 @@ class Project(Base):
156
156
  )
157
157
 
158
158
 
159
+ class ProjectSession(Base):
160
+ __tablename__ = "project_sessions"
161
+ id: Mapped[int] = mapped_column(primary_key=True)
162
+ session_id: Mapped[str] = mapped_column(String, nullable=False, unique=True)
163
+ project_id: Mapped[int] = mapped_column(
164
+ ForeignKey("projects.id", ondelete="CASCADE"),
165
+ nullable=False,
166
+ index=True,
167
+ )
168
+ start_time: Mapped[datetime] = mapped_column(UtcTimeStamp, index=True, nullable=False)
169
+ end_time: Mapped[datetime] = mapped_column(UtcTimeStamp, index=True, nullable=False)
170
+ traces: Mapped[list["Trace"]] = relationship(
171
+ "Trace",
172
+ back_populates="project_session",
173
+ uselist=True,
174
+ )
175
+
176
+
159
177
  class Trace(Base):
160
178
  __tablename__ = "traces"
161
179
  id: Mapped[int] = mapped_column(primary_key=True)
162
180
  project_rowid: Mapped[int] = mapped_column(
163
181
  ForeignKey("projects.id", ondelete="CASCADE"),
182
+ nullable=False,
164
183
  index=True,
165
184
  )
166
185
  trace_id: Mapped[str]
186
+ project_session_rowid: Mapped[Optional[int]] = mapped_column(
187
+ ForeignKey("project_sessions.id", ondelete="CASCADE"),
188
+ index=True,
189
+ )
167
190
  start_time: Mapped[datetime] = mapped_column(UtcTimeStamp, index=True)
168
191
  end_time: Mapped[datetime] = mapped_column(UtcTimeStamp)
169
192
 
@@ -188,6 +211,10 @@ class Trace(Base):
188
211
  cascade="all, delete-orphan",
189
212
  uselist=True,
190
213
  )
214
+ project_session: Mapped[ProjectSession] = relationship(
215
+ "ProjectSession",
216
+ back_populates="traces",
217
+ )
191
218
  experiment_runs: Mapped[list["ExperimentRun"]] = relationship(
192
219
  primaryjoin="foreign(ExperimentRun.trace_id) == Trace.trace_id",
193
220
  back_populates="trace",
@@ -27,6 +27,8 @@ from sklearn import metrics as sk
27
27
  from sklearn.utils.multiclass import check_classification_targets
28
28
  from wrapt import PartialCallableObjectProxy
29
29
 
30
+ from phoenix.config import SKLEARN_VERSION
31
+
30
32
 
31
33
  class Eval(PartialCallableObjectProxy, ABC): # type: ignore
32
34
  def __call__(
@@ -232,5 +234,9 @@ class SkEval(Enum):
232
234
  r2_score = RegressionEval(sk.r2_score)
233
235
  recall_score = ClassificationEval(sk.recall_score)
234
236
  roc_auc_score = ScoredClassificationEval(sk.roc_auc_score)
235
- root_mean_squared_error = RegressionEval(sk.mean_squared_error, squared=False)
237
+ root_mean_squared_error = (
238
+ RegressionEval(sk.mean_squared_error, squared=False)
239
+ if SKLEARN_VERSION < (1, 6)
240
+ else RegressionEval(sk.root_mean_squared_error)
241
+ )
236
242
  zero_one_loss = ClassificationEval(sk.zero_one_loss)
@@ -31,12 +31,18 @@ from phoenix.server.api.dataloaders import (
31
31
  MinStartOrMaxEndTimeDataLoader,
32
32
  ProjectByNameDataLoader,
33
33
  RecordCountDataLoader,
34
+ SessionIODataLoader,
35
+ SessionNumTracesDataLoader,
36
+ SessionNumTracesWithErrorDataLoader,
37
+ SessionTokenUsagesDataLoader,
38
+ SessionTraceLatencyMsQuantileDataLoader,
34
39
  SpanAnnotationsDataLoader,
35
40
  SpanDatasetExamplesDataLoader,
36
41
  SpanDescendantsDataLoader,
37
42
  SpanProjectsDataLoader,
38
43
  TokenCountDataLoader,
39
- TraceRowIdsDataLoader,
44
+ TraceByTraceIdsDataLoader,
45
+ TraceRootSpansDataLoader,
40
46
  UserRolesDataLoader,
41
47
  UsersDataLoader,
42
48
  )
@@ -68,12 +74,19 @@ class DataLoaders:
68
74
  latency_ms_quantile: LatencyMsQuantileDataLoader
69
75
  min_start_or_max_end_times: MinStartOrMaxEndTimeDataLoader
70
76
  record_counts: RecordCountDataLoader
77
+ session_first_inputs: SessionIODataLoader
78
+ session_last_outputs: SessionIODataLoader
79
+ session_num_traces: SessionNumTracesDataLoader
80
+ session_num_traces_with_error: SessionNumTracesWithErrorDataLoader
81
+ session_token_usages: SessionTokenUsagesDataLoader
82
+ session_trace_latency_ms_quantile: SessionTraceLatencyMsQuantileDataLoader
71
83
  span_annotations: SpanAnnotationsDataLoader
72
84
  span_dataset_examples: SpanDatasetExamplesDataLoader
73
85
  span_descendants: SpanDescendantsDataLoader
74
86
  span_projects: SpanProjectsDataLoader
75
87
  token_counts: TokenCountDataLoader
76
- trace_row_ids: TraceRowIdsDataLoader
88
+ trace_by_trace_ids: TraceByTraceIdsDataLoader
89
+ trace_root_spans: TraceRootSpansDataLoader
77
90
  project_by_name: ProjectByNameDataLoader
78
91
  users: UsersDataLoader
79
92
  user_roles: UserRolesDataLoader
@@ -19,12 +19,18 @@ from .latency_ms_quantile import LatencyMsQuantileCache, LatencyMsQuantileDataLo
19
19
  from .min_start_or_max_end_times import MinStartOrMaxEndTimeCache, MinStartOrMaxEndTimeDataLoader
20
20
  from .project_by_name import ProjectByNameDataLoader
21
21
  from .record_counts import RecordCountCache, RecordCountDataLoader
22
+ from .session_io import SessionIODataLoader
23
+ from .session_num_traces import SessionNumTracesDataLoader
24
+ from .session_num_traces_with_error import SessionNumTracesWithErrorDataLoader
25
+ from .session_token_usages import SessionTokenUsagesDataLoader
26
+ from .session_trace_latency_ms_quantile import SessionTraceLatencyMsQuantileDataLoader
22
27
  from .span_annotations import SpanAnnotationsDataLoader
23
28
  from .span_dataset_examples import SpanDatasetExamplesDataLoader
24
29
  from .span_descendants import SpanDescendantsDataLoader
25
30
  from .span_projects import SpanProjectsDataLoader
26
31
  from .token_counts import TokenCountCache, TokenCountDataLoader
27
- from .trace_row_ids import TraceRowIdsDataLoader
32
+ from .trace_by_trace_ids import TraceByTraceIdsDataLoader
33
+ from .trace_root_spans import TraceRootSpansDataLoader
28
34
  from .user_roles import UserRolesDataLoader
29
35
  from .users import UsersDataLoader
30
36
 
@@ -45,11 +51,17 @@ __all__ = [
45
51
  "LatencyMsQuantileDataLoader",
46
52
  "MinStartOrMaxEndTimeDataLoader",
47
53
  "RecordCountDataLoader",
54
+ "SessionIODataLoader",
55
+ "SessionNumTracesDataLoader",
56
+ "SessionNumTracesWithErrorDataLoader",
57
+ "SessionTokenUsagesDataLoader",
58
+ "SessionTraceLatencyMsQuantileDataLoader",
48
59
  "SpanDatasetExamplesDataLoader",
49
60
  "SpanDescendantsDataLoader",
50
61
  "SpanProjectsDataLoader",
51
62
  "TokenCountDataLoader",
52
- "TraceRowIdsDataLoader",
63
+ "TraceByTraceIdsDataLoader",
64
+ "TraceRootSpansDataLoader",
53
65
  "ProjectByNameDataLoader",
54
66
  "SpanAnnotationsDataLoader",
55
67
  "UsersDataLoader",
@@ -0,0 +1,75 @@
1
+ from functools import cached_property
2
+ from typing import Literal, Optional, cast
3
+
4
+ from openinference.semconv.trace import SpanAttributes
5
+ from sqlalchemy import Select, func, select
6
+ from strawberry.dataloader import DataLoader
7
+ from typing_extensions import TypeAlias, assert_never
8
+
9
+ from phoenix.db import models
10
+ from phoenix.server.types import DbSessionFactory
11
+ from phoenix.trace.schemas import MimeType, SpanIOValue
12
+
13
# DataLoader key: project_sessions rowid.
Key: TypeAlias = int
Result: TypeAlias = Optional[SpanIOValue]

# Which end of the session to fetch: the first trace's root-span input,
# or the last trace's root-span output.
Kind = Literal["first_input", "last_output"]
17
+
18
+
19
class SessionIODataLoader(DataLoader[Key, Result]):
    """Batch-loads, per project session, either the input value of the first
    trace's root span or the output value of the last trace's root span,
    depending on the `kind` given at construction."""

    def __init__(self, db: DbSessionFactory, kind: Kind) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db
        self._kind = kind

    @cached_property
    def _subq(self) -> Select[tuple[Optional[int], str, str, int]]:
        """Select (session rowid, value, mime_type, rank) over root spans.

        Traces are ranked within each session by start time — ascending for
        "first_input", descending for "last_output" — so rank == 1 is the
        row we want.
        """
        stmt = (
            select(models.Trace.project_session_rowid.label("id_"))
            .join_from(models.Span, models.Trace)
            .where(models.Span.parent_id.is_(None))  # root spans only
        )
        if self._kind == "first_input":
            stmt = stmt.add_columns(
                models.Span.attributes[INPUT_VALUE].label("value"),
                models.Span.attributes[INPUT_MIME_TYPE].label("mime_type"),
                func.row_number()
                .over(
                    partition_by=models.Trace.project_session_rowid,
                    order_by=[models.Trace.start_time.asc(), models.Trace.id.asc()],
                )
                .label("rank"),
            )
        elif self._kind == "last_output":
            stmt = stmt.add_columns(
                models.Span.attributes[OUTPUT_VALUE].label("value"),
                models.Span.attributes[OUTPUT_MIME_TYPE].label("mime_type"),
                func.row_number()
                .over(
                    partition_by=models.Trace.project_session_rowid,
                    order_by=[models.Trace.start_time.desc(), models.Trace.id.desc()],
                )
                .label("rank"),
            )
        else:
            assert_never(self._kind)
        return cast(Select[tuple[Optional[int], str, str, int]], stmt)

    def _stmt(self, *keys: Key) -> Select[tuple[int, str, str]]:
        """Restrict the ranked subquery to `keys` and keep only rank-1 rows."""
        subq = self._subq.where(models.Trace.project_session_rowid.in_(keys)).subquery()
        return select(subq.c.id_, subq.c.value, subq.c.mime_type).filter_by(rank=1)

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        async with self._db() as session:
            result: dict[Key, SpanIOValue] = {
                id_: SpanIOValue(value=value, mime_type=MimeType(mime_type))
                async for id_, value, mime_type in await session.stream(self._stmt(*keys))
                if id_ is not None
            }
        # Preserve input order; None for sessions without a matching row.
        return [result.get(key) for key in keys]
70
+
71
+
72
# Dot-separated semconv attribute names split into path lists for JSON indexing.
INPUT_VALUE = SpanAttributes.INPUT_VALUE.split(".")
INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE.split(".")
OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE.split(".")
OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE.split(".")
@@ -0,0 +1,30 @@
1
+ from sqlalchemy import func, select
2
+ from strawberry.dataloader import DataLoader
3
+ from typing_extensions import TypeAlias
4
+
5
+ from phoenix.db import models
6
+ from phoenix.server.types import DbSessionFactory
7
+
8
# Key: project_sessions rowid; Result: number of traces in that session.
Key: TypeAlias = int
Result: TypeAlias = int


class SessionNumTracesDataLoader(DataLoader[Key, Result]):
    """Batch-loads the trace count for each requested project session."""

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        query = (
            select(
                models.Trace.project_session_rowid.label("id_"),
                func.count(models.Trace.id).label("value"),
            )
            .where(models.Trace.project_session_rowid.in_(keys))
            .group_by(models.Trace.project_session_rowid)
        )
        counts: dict[Key, int] = {}
        async with self._db() as session:
            async for id_, value in await session.stream(query):
                if id_ is not None:
                    counts[id_] = value
        # Sessions with no traces get a count of zero.
        return [counts.get(key, 0) for key in keys]
@@ -0,0 +1,32 @@
1
+ from sqlalchemy import func, select
2
+ from strawberry.dataloader import DataLoader
3
+ from typing_extensions import TypeAlias
4
+
5
+ from phoenix.db import models
6
+ from phoenix.server.types import DbSessionFactory
7
+
8
# Key: project_sessions rowid; Result: number of traces containing >= 1 error.
Key: TypeAlias = int
Result: TypeAlias = int


class SessionNumTracesWithErrorDataLoader(DataLoader[Key, Result]):
    """Batch-loads, per project session, the number of traces that contain
    at least one error span."""

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        stmt = (
            select(
                models.Trace.project_session_rowid.label("id_"),
                func.count(models.Trace.id).label("value"),
            )
            .join(models.Span)
            .group_by(models.Trace.project_session_rowid)
            # Restrict to root spans so each trace is counted at most once:
            # cumulative_error_count propagates up to the root, and without
            # this filter every errored span in a trace would add a row to
            # the join and inflate the count.
            .where(models.Span.parent_id.is_(None))
            .where(models.Span.cumulative_error_count > 0)
            .where(models.Trace.project_session_rowid.in_(keys))
        )
        async with self._db() as session:
            result: dict[Key, int] = {
                id_: value async for id_, value in await session.stream(stmt) if id_ is not None
            }
        # Sessions with no errored traces get a count of zero.
        return [result.get(key, 0) for key in keys]
@@ -0,0 +1,41 @@
1
+ from sqlalchemy import func, select
2
+ from sqlalchemy.sql.functions import coalesce
3
+ from strawberry.dataloader import DataLoader
4
+ from typing_extensions import TypeAlias
5
+
6
+ from phoenix.db import models
7
+ from phoenix.server.types import DbSessionFactory
8
+ from phoenix.trace.schemas import TokenUsage
9
+
10
# Key: project_sessions rowid; Result: aggregated token usage for the session.
Key: TypeAlias = int
Result: TypeAlias = TokenUsage


class SessionTokenUsagesDataLoader(DataLoader[Key, Result]):
    """Batch-loads prompt/completion token totals for each project session,
    summing the cumulative counts over every root span in the session."""

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        prompt_total = func.sum(
            coalesce(models.Span.cumulative_llm_token_count_prompt, 0)
        ).label("prompt")
        completion_total = func.sum(
            coalesce(models.Span.cumulative_llm_token_count_completion, 0)
        ).label("completion")
        query = (
            select(
                models.Trace.project_session_rowid.label("id_"),
                prompt_total,
                completion_total,
            )
            .join_from(models.Span, models.Trace)
            .where(models.Span.parent_id.is_(None))  # root spans only
            .where(models.Trace.project_session_rowid.in_(keys))
            .group_by(models.Trace.project_session_rowid)
        )
        usages: dict[Key, TokenUsage] = {}
        async with self._db() as session:
            async for id_, prompt, completion in await session.stream(query):
                if id_ is not None:
                    usages[id_] = TokenUsage(prompt=prompt, completion=completion)
        # Sessions with no rows get an all-default TokenUsage.
        return [usages.get(key, TokenUsage()) for key in keys]
@@ -0,0 +1,55 @@
1
+ from collections import defaultdict
2
+ from typing import Optional
3
+
4
+ import numpy as np
5
+ from aioitertools.itertools import groupby
6
+ from sqlalchemy import select
7
+ from strawberry.dataloader import DataLoader
8
+ from typing_extensions import TypeAlias
9
+
10
+ from phoenix.db import models
11
+ from phoenix.server.types import DbSessionFactory
12
+
13
SessionId: TypeAlias = int
Probability: TypeAlias = float
QuantileValue: TypeAlias = float

# A key pairs a session rowid with the quantile probability requested for it.
Key: TypeAlias = tuple[SessionId, Probability]
Result: TypeAlias = Optional[QuantileValue]
ResultPosition: TypeAlias = int

# Returned for sessions with no traces (nothing to aggregate).
DEFAULT_VALUE: Result = None


class SessionTraceLatencyMsQuantileDataLoader(DataLoader[Key, Result]):
    """Batch-computes trace-latency quantiles per project session.

    Several keys may share a session (different probabilities) or repeat
    entirely, so result-list positions are tracked per (session, probability)
    pair and filled in as each session's latencies are processed.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        results: list[Result] = [DEFAULT_VALUE] * len(keys)
        # session id -> probability -> positions in `keys` awaiting that value
        argument_position_map: defaultdict[
            SessionId, defaultdict[Probability, list[ResultPosition]]
        ] = defaultdict(lambda: defaultdict(list))
        session_rowids = {session_id for session_id, _ in keys}
        for position, (session_id, probability) in enumerate(keys):
            argument_position_map[session_id][probability].append(position)
        # Ordered by session rowid so groupby below sees contiguous groups.
        stmt = (
            select(
                models.Trace.project_session_rowid,
                models.Trace.latency_ms,
            )
            .where(models.Trace.project_session_rowid.in_(session_rowids))
            .order_by(models.Trace.project_session_rowid)
        )
        async with self._db() as session:
            data = await session.stream(stmt)
            async for project_session_rowid, group in groupby(
                data, lambda row: row.project_session_rowid
            ):
                # Materialize this session's latencies once, then answer every
                # probability requested for it.
                session_latencies = [row.latency_ms for row in group]
                for probability, positions in argument_position_map[project_session_rowid].items():
                    quantile_value = np.quantile(session_latencies, probability)
                    for position in positions:
                        results[position] = quantile_value
        return results
@@ -0,0 +1,25 @@
1
+ from typing import List, Optional
2
+
3
+ from sqlalchemy import select
4
+ from strawberry.dataloader import DataLoader
5
+ from typing_extensions import TypeAlias
6
+
7
+ from phoenix.db import models
8
+ from phoenix.server.types import DbSessionFactory
9
+
10
# Key: OpenTelemetry trace_id string; Result: the matching Trace row, if any.
Key: TypeAlias = str
Result: TypeAlias = Optional[models.Trace]


class TraceByTraceIdsDataLoader(DataLoader[Key, Result]):
    """Batch-loads Trace ORM rows by their trace_id strings."""

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: List[Key]) -> List[Result]:
        query = select(models.Trace).where(models.Trace.trace_id.in_(keys))
        traces_by_id: dict[Key, models.Trace] = {}
        async with self._db() as session:
            async for trace in await session.stream_scalars(query):
                traces_by_id[trace.trace_id] = trace
        # Preserve input order; None for trace ids with no matching row.
        return [traces_by_id.get(trace_id) for trace_id in keys]