arize-phoenix 3.25.0__py3-none-any.whl → 4.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of arize-phoenix might be problematic.

Files changed (113)
  1. {arize_phoenix-3.25.0.dist-info → arize_phoenix-4.0.0.dist-info}/METADATA +26 -4
  2. {arize_phoenix-3.25.0.dist-info → arize_phoenix-4.0.0.dist-info}/RECORD +80 -75
  3. phoenix/__init__.py +9 -5
  4. phoenix/config.py +109 -53
  5. phoenix/datetime_utils.py +18 -1
  6. phoenix/db/README.md +25 -0
  7. phoenix/db/__init__.py +4 -0
  8. phoenix/db/alembic.ini +119 -0
  9. phoenix/db/bulk_inserter.py +206 -0
  10. phoenix/db/engines.py +152 -0
  11. phoenix/db/helpers.py +47 -0
  12. phoenix/db/insertion/evaluation.py +209 -0
  13. phoenix/db/insertion/helpers.py +54 -0
  14. phoenix/db/insertion/span.py +142 -0
  15. phoenix/db/migrate.py +71 -0
  16. phoenix/db/migrations/env.py +121 -0
  17. phoenix/db/migrations/script.py.mako +26 -0
  18. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +280 -0
  19. phoenix/db/models.py +371 -0
  20. phoenix/exceptions.py +5 -1
  21. phoenix/server/api/context.py +40 -3
  22. phoenix/server/api/dataloaders/__init__.py +97 -0
  23. phoenix/server/api/dataloaders/cache/__init__.py +3 -0
  24. phoenix/server/api/dataloaders/cache/two_tier_cache.py +67 -0
  25. phoenix/server/api/dataloaders/document_evaluation_summaries.py +152 -0
  26. phoenix/server/api/dataloaders/document_evaluations.py +37 -0
  27. phoenix/server/api/dataloaders/document_retrieval_metrics.py +98 -0
  28. phoenix/server/api/dataloaders/evaluation_summaries.py +151 -0
  29. phoenix/server/api/dataloaders/latency_ms_quantile.py +198 -0
  30. phoenix/server/api/dataloaders/min_start_or_max_end_times.py +93 -0
  31. phoenix/server/api/dataloaders/record_counts.py +125 -0
  32. phoenix/server/api/dataloaders/span_descendants.py +64 -0
  33. phoenix/server/api/dataloaders/span_evaluations.py +37 -0
  34. phoenix/server/api/dataloaders/token_counts.py +138 -0
  35. phoenix/server/api/dataloaders/trace_evaluations.py +37 -0
  36. phoenix/server/api/input_types/SpanSort.py +138 -68
  37. phoenix/server/api/routers/v1/__init__.py +11 -0
  38. phoenix/server/api/routers/v1/evaluations.py +275 -0
  39. phoenix/server/api/routers/v1/spans.py +126 -0
  40. phoenix/server/api/routers/v1/traces.py +82 -0
  41. phoenix/server/api/schema.py +112 -48
  42. phoenix/server/api/types/DocumentEvaluationSummary.py +1 -1
  43. phoenix/server/api/types/Evaluation.py +29 -12
  44. phoenix/server/api/types/EvaluationSummary.py +29 -44
  45. phoenix/server/api/types/MimeType.py +2 -2
  46. phoenix/server/api/types/Model.py +9 -9
  47. phoenix/server/api/types/Project.py +240 -171
  48. phoenix/server/api/types/Span.py +87 -131
  49. phoenix/server/api/types/Trace.py +29 -20
  50. phoenix/server/api/types/pagination.py +151 -10
  51. phoenix/server/app.py +263 -35
  52. phoenix/server/grpc_server.py +93 -0
  53. phoenix/server/main.py +75 -60
  54. phoenix/server/openapi/docs.py +218 -0
  55. phoenix/server/prometheus.py +23 -7
  56. phoenix/server/static/index.js +662 -643
  57. phoenix/server/telemetry.py +68 -0
  58. phoenix/services.py +4 -0
  59. phoenix/session/client.py +34 -30
  60. phoenix/session/data_extractor.py +8 -3
  61. phoenix/session/session.py +176 -155
  62. phoenix/settings.py +13 -0
  63. phoenix/trace/attributes.py +349 -0
  64. phoenix/trace/dsl/README.md +116 -0
  65. phoenix/trace/dsl/filter.py +660 -192
  66. phoenix/trace/dsl/helpers.py +24 -5
  67. phoenix/trace/dsl/query.py +562 -185
  68. phoenix/trace/fixtures.py +69 -7
  69. phoenix/trace/otel.py +33 -199
  70. phoenix/trace/schemas.py +14 -8
  71. phoenix/trace/span_evaluations.py +5 -2
  72. phoenix/utilities/__init__.py +0 -26
  73. phoenix/utilities/span_store.py +0 -23
  74. phoenix/version.py +1 -1
  75. phoenix/core/project.py +0 -773
  76. phoenix/core/traces.py +0 -96
  77. phoenix/datasets/dataset.py +0 -214
  78. phoenix/datasets/fixtures.py +0 -24
  79. phoenix/datasets/schema.py +0 -31
  80. phoenix/experimental/evals/__init__.py +0 -73
  81. phoenix/experimental/evals/evaluators.py +0 -413
  82. phoenix/experimental/evals/functions/__init__.py +0 -4
  83. phoenix/experimental/evals/functions/classify.py +0 -453
  84. phoenix/experimental/evals/functions/executor.py +0 -353
  85. phoenix/experimental/evals/functions/generate.py +0 -138
  86. phoenix/experimental/evals/functions/processing.py +0 -76
  87. phoenix/experimental/evals/models/__init__.py +0 -14
  88. phoenix/experimental/evals/models/anthropic.py +0 -175
  89. phoenix/experimental/evals/models/base.py +0 -170
  90. phoenix/experimental/evals/models/bedrock.py +0 -221
  91. phoenix/experimental/evals/models/litellm.py +0 -134
  92. phoenix/experimental/evals/models/openai.py +0 -453
  93. phoenix/experimental/evals/models/rate_limiters.py +0 -246
  94. phoenix/experimental/evals/models/vertex.py +0 -173
  95. phoenix/experimental/evals/models/vertexai.py +0 -186
  96. phoenix/experimental/evals/retrievals.py +0 -96
  97. phoenix/experimental/evals/templates/__init__.py +0 -50
  98. phoenix/experimental/evals/templates/default_templates.py +0 -472
  99. phoenix/experimental/evals/templates/template.py +0 -195
  100. phoenix/experimental/evals/utils/__init__.py +0 -172
  101. phoenix/experimental/evals/utils/threads.py +0 -27
  102. phoenix/server/api/routers/evaluation_handler.py +0 -110
  103. phoenix/server/api/routers/span_handler.py +0 -70
  104. phoenix/server/api/routers/trace_handler.py +0 -60
  105. phoenix/storage/span_store/__init__.py +0 -23
  106. phoenix/storage/span_store/text_file.py +0 -85
  107. phoenix/trace/dsl/missing.py +0 -60
  108. {arize_phoenix-3.25.0.dist-info → arize_phoenix-4.0.0.dist-info}/WHEEL +0 -0
  109. {arize_phoenix-3.25.0.dist-info → arize_phoenix-4.0.0.dist-info}/licenses/IP_NOTICE +0 -0
  110. {arize_phoenix-3.25.0.dist-info → arize_phoenix-4.0.0.dist-info}/licenses/LICENSE +0 -0
  111. /phoenix/{datasets → db/insertion}/__init__.py +0 -0
  112. /phoenix/{experimental → db/migrations}/__init__.py +0 -0
  113. /phoenix/{storage → server/openapi}/__init__.py +0 -0
phoenix/server/api/dataloaders/latency_ms_quantile.py (new file)
@@ -0,0 +1,198 @@
+ from collections import defaultdict
+ from datetime import datetime
+ from typing import (
+     Any,
+     AsyncContextManager,
+     AsyncIterator,
+     Callable,
+     DefaultDict,
+     List,
+     Literal,
+     Mapping,
+     Optional,
+     Tuple,
+     cast,
+ )
+
+ from cachetools import LFUCache, TTLCache
+ from sqlalchemy import (
+     ARRAY,
+     Float,
+     Integer,
+     Select,
+     SQLColumnExpression,
+     Values,
+     column,
+     func,
+     select,
+     values,
+ )
+ from sqlalchemy.ext.asyncio import AsyncSession
+ from sqlalchemy.sql.functions import percentile_cont
+ from strawberry.dataloader import AbstractCache, DataLoader
+ from typing_extensions import TypeAlias, assert_never
+
+ from phoenix.db import models
+ from phoenix.db.helpers import SupportedSQLDialect
+ from phoenix.server.api.dataloaders.cache import TwoTierCache
+ from phoenix.server.api.input_types.TimeRange import TimeRange
+ from phoenix.trace.dsl import SpanFilter
+
+ Kind: TypeAlias = Literal["span", "trace"]
+ ProjectRowId: TypeAlias = int
+ TimeInterval: TypeAlias = Tuple[Optional[datetime], Optional[datetime]]
+ FilterCondition: TypeAlias = Optional[str]
+ Probability: TypeAlias = float
+ QuantileValue: TypeAlias = float
+
+ Segment: TypeAlias = Tuple[Kind, TimeInterval, FilterCondition]
+ Param: TypeAlias = Tuple[ProjectRowId, Probability]
+
+ Key: TypeAlias = Tuple[Kind, ProjectRowId, Optional[TimeRange], FilterCondition, Probability]
+ Result: TypeAlias = Optional[QuantileValue]
+ ResultPosition: TypeAlias = int
+ DEFAULT_VALUE: Result = None
+
+ FloatCol: TypeAlias = SQLColumnExpression[Float[float]]
+
+
+ def _cache_key_fn(key: Key) -> Tuple[Segment, Param]:
+     kind, project_rowid, time_range, filter_condition, probability = key
+     interval = (
+         (time_range.start, time_range.end) if isinstance(time_range, TimeRange) else (None, None)
+     )
+     return (kind, interval, filter_condition), (project_rowid, probability)
+
+
+ _Section: TypeAlias = ProjectRowId
+ _SubKey: TypeAlias = Tuple[TimeInterval, FilterCondition, Kind, Probability]
+
+
+ class LatencyMsQuantileCache(
+     TwoTierCache[Key, Result, _Section, _SubKey],
+ ):
+     def __init__(self) -> None:
+         super().__init__(
+             # TTL=3600 (1-hour) because time intervals are always moving forward, but
+             # interval endpoints are rounded down to the hour by the UI, so anything
+             # older than an hour most likely won't be a cache-hit anyway.
+             main_cache=TTLCache(maxsize=64, ttl=3600),
+             sub_cache_factory=lambda: LFUCache(maxsize=2 * 2 * 2 * 16),
+         )
+
+     def _cache_key(self, key: Key) -> Tuple[_Section, _SubKey]:
+         (kind, interval, filter_condition), (project_rowid, probability) = _cache_key_fn(key)
+         return project_rowid, (interval, filter_condition, kind, probability)
+
+
+ class LatencyMsQuantileDataLoader(DataLoader[Key, Result]):
+     def __init__(
+         self,
+         db: Callable[[], AsyncContextManager[AsyncSession]],
+         cache_map: Optional[AbstractCache[Key, Result]] = None,
+     ) -> None:
+         super().__init__(
+             load_fn=self._load_fn,
+             cache_key_fn=_cache_key_fn,
+             cache_map=cache_map,
+         )
+         self._db = db
+
+     async def _load_fn(self, keys: List[Key]) -> List[Result]:
+         results: List[Result] = [DEFAULT_VALUE] * len(keys)
+         arguments: DefaultDict[
+             Segment,
+             DefaultDict[Param, List[ResultPosition]],
+         ] = defaultdict(lambda: defaultdict(list))
+         for position, key in enumerate(keys):
+             segment, param = _cache_key_fn(key)
+             arguments[segment][param].append(position)
+         async with self._db() as session:
+             dialect = SupportedSQLDialect(session.bind.dialect.name)
+             for segment, params in arguments.items():
+                 async for position, quantile_value in _get_results(
+                     dialect, session, segment, params
+                 ):
+                     results[position] = quantile_value
+         return results
+
+
+ async def _get_results(
+     dialect: SupportedSQLDialect,
+     session: AsyncSession,
+     segment: Segment,
+     params: Mapping[Param, List[ResultPosition]],
+ ) -> AsyncIterator[Tuple[ResultPosition, QuantileValue]]:
+     kind, (start_time, end_time), filter_condition = segment
+     stmt = select(models.Trace.project_rowid)
+     if kind == "trace":
+         latency_column = cast(FloatCol, models.Trace.latency_ms)
+         time_column = models.Trace.start_time
+     elif kind == "span":
+         latency_column = cast(FloatCol, models.Span.latency_ms)
+         time_column = models.Span.start_time
+         stmt = stmt.join(models.Span)
+         if filter_condition:
+             sf = SpanFilter(filter_condition)
+             stmt = sf(stmt)
+     else:
+         assert_never(kind)
+     if start_time:
+         stmt = stmt.where(start_time <= time_column)
+     if end_time:
+         stmt = stmt.where(time_column < end_time)
+     if dialect is SupportedSQLDialect.POSTGRESQL:
+         results = _get_results_postgresql(session, stmt, latency_column, params)
+     elif dialect is SupportedSQLDialect.SQLITE:
+         results = _get_results_sqlite(session, stmt, latency_column, params)
+     else:
+         assert_never(dialect)
+     async for position, quantile_value in results:
+         yield position, quantile_value
+
+
+ async def _get_results_sqlite(
+     session: AsyncSession,
+     base_stmt: Select[Any],
+     latency_column: FloatCol,
+     params: Mapping[Param, List[ResultPosition]],
+ ) -> AsyncIterator[Tuple[ResultPosition, QuantileValue]]:
+     projects_per_prob: DefaultDict[Probability, List[ProjectRowId]] = defaultdict(list)
+     for project_rowid, probability in params.keys():
+         projects_per_prob[probability].append(project_rowid)
+     pid = models.Trace.project_rowid
+     for probability, project_rowids in projects_per_prob.items():
+         pctl: FloatCol = func.percentile(latency_column, probability * 100)
+         stmt = base_stmt.add_columns(pctl)
+         stmt = stmt.where(pid.in_(project_rowids))
+         stmt = stmt.group_by(pid)
+         data = await session.stream(stmt)
+         async for project_rowid, quantile_value in data:
+             for position in params[(project_rowid, probability)]:
+                 yield position, quantile_value
+
+
+ async def _get_results_postgresql(
+     session: AsyncSession,
+     base_stmt: Select[Any],
+     latency_column: FloatCol,
+     params: Mapping[Param, List[ResultPosition]],
+ ) -> AsyncIterator[Tuple[ResultPosition, QuantileValue]]:
+     probs_per_project: DefaultDict[ProjectRowId, List[Probability]] = defaultdict(list)
+     for project_rowid, probability in params.keys():
+         probs_per_project[project_rowid].append(probability)
+     pp: Values = values(
+         column("project_rowid", Integer),
+         column("probabilities", ARRAY(Float[float])),
+         name="project_probabilities",
+     ).data(probs_per_project.items())  # type: ignore
+     pid = models.Trace.project_rowid
+     pctl: FloatCol = percentile_cont(pp.c.probabilities).within_group(latency_column)
+     stmt = base_stmt.add_columns(pp.c.probabilities, pctl)
+     stmt = stmt.join(pp, pid == pp.c.project_rowid)
+     stmt = stmt.group_by(pid, pp.c.probabilities)
+     data = await session.stream(stmt)
+     async for project_rowid, probabilities, quantile_values in data:
+         for probability, quantile_value in zip(probabilities, quantile_values):
+             for position in params[(project_rowid, probability)]:
+                 yield position, quantile_value
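
The loader above answers per-project latency quantile queries in batches: keys sharing the same (kind, time interval, filter) segment are grouped into a single SQL statement per dialect. A minimal usage sketch, assuming the loader is exposed on the GraphQL request context; the `data_loaders.latency_ms_quantile` attribute and the surrounding resolver are illustrative, not part of this diff:

    # Hypothetical resolver-side calls; each key follows the Key alias above:
    # (kind, project_rowid, time_range, filter_condition, probability)
    p50 = await info.context.data_loaders.latency_ms_quantile.load(
        ("trace", project_rowid, None, None, 0.50)
    )
    p99 = await info.context.data_loaders.latency_ms_quantile.load(
        ("trace", project_rowid, None, None, 0.99)
    )
    # Strawberry's DataLoader coalesces both awaits into one _load_fn batch.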
phoenix/server/api/dataloaders/min_start_or_max_end_times.py (new file)
@@ -0,0 +1,93 @@
+ from collections import defaultdict
+ from datetime import datetime
+ from typing import (
+     AsyncContextManager,
+     Callable,
+     DefaultDict,
+     List,
+     Literal,
+     Optional,
+     Tuple,
+ )
+
+ from cachetools import LFUCache
+ from sqlalchemy import func, select
+ from sqlalchemy.ext.asyncio import AsyncSession
+ from strawberry.dataloader import AbstractCache, DataLoader
+ from typing_extensions import TypeAlias, assert_never
+
+ from phoenix.db import models
+ from phoenix.server.api.dataloaders.cache import TwoTierCache
+
+ Kind: TypeAlias = Literal["start", "end"]
+ ProjectRowId: TypeAlias = int
+
+ Segment: TypeAlias = ProjectRowId
+ Param: TypeAlias = Kind
+
+ Key: TypeAlias = Tuple[ProjectRowId, Kind]
+ Result: TypeAlias = Optional[datetime]
+ ResultPosition: TypeAlias = int
+ DEFAULT_VALUE: Result = None
+
+ _Section = ProjectRowId
+ _SubKey = Kind
+
+
+ class MinStartOrMaxEndTimeCache(
+     TwoTierCache[Key, Result, _Section, _SubKey],
+ ):
+     def __init__(self) -> None:
+         super().__init__(
+             main_cache=LFUCache(maxsize=64),
+             sub_cache_factory=lambda: LFUCache(maxsize=2),
+         )
+
+     def _cache_key(self, key: Key) -> Tuple[_Section, _SubKey]:
+         return key
+
+
+ class MinStartOrMaxEndTimeDataLoader(DataLoader[Key, Result]):
+     def __init__(
+         self,
+         db: Callable[[], AsyncContextManager[AsyncSession]],
+         cache_map: Optional[AbstractCache[Key, Result]] = None,
+     ) -> None:
+         super().__init__(
+             load_fn=self._load_fn,
+             cache_map=cache_map,
+         )
+         self._db = db
+
+     async def _load_fn(self, keys: List[Key]) -> List[Result]:
+         results: List[Result] = [DEFAULT_VALUE] * len(keys)
+         arguments: DefaultDict[
+             Segment,
+             DefaultDict[Param, List[ResultPosition]],
+         ] = defaultdict(lambda: defaultdict(list))
+         for position, key in enumerate(keys):
+             segment, param = key
+             arguments[segment][param].append(position)
+         pid = models.Trace.project_rowid
+         stmt = (
+             select(
+                 pid,
+                 func.min(models.Trace.start_time).label("min_start"),
+                 func.max(models.Trace.end_time).label("max_end"),
+             )
+             .where(pid.in_(arguments.keys()))
+             .group_by(pid)
+         )
+         async with self._db() as session:
+             data = await session.stream(stmt)
+             async for project_rowid, min_start, max_end in data:
+                 for kind, positions in arguments[project_rowid].items():
+                     if kind == "start":
+                         for position in positions:
+                             results[position] = min_start
+                     elif kind == "end":
+                         for position in positions:
+                             results[position] = max_end
+                     else:
+                         assert_never(kind)
+         return results
phoenix/server/api/dataloaders/record_counts.py (new file)
@@ -0,0 +1,125 @@
+ from collections import defaultdict
+ from datetime import datetime
+ from typing import (
+     Any,
+     AsyncContextManager,
+     Callable,
+     DefaultDict,
+     List,
+     Literal,
+     Optional,
+     Tuple,
+ )
+
+ from cachetools import LFUCache, TTLCache
+ from sqlalchemy import Select, func, select
+ from sqlalchemy.ext.asyncio import AsyncSession
+ from strawberry.dataloader import AbstractCache, DataLoader
+ from typing_extensions import TypeAlias, assert_never
+
+ from phoenix.db import models
+ from phoenix.server.api.dataloaders.cache import TwoTierCache
+ from phoenix.server.api.input_types.TimeRange import TimeRange
+ from phoenix.trace.dsl import SpanFilter
+
+ Kind: TypeAlias = Literal["span", "trace"]
+ ProjectRowId: TypeAlias = int
+ TimeInterval: TypeAlias = Tuple[Optional[datetime], Optional[datetime]]
+ FilterCondition: TypeAlias = Optional[str]
+ SpanCount: TypeAlias = int
+
+ Segment: TypeAlias = Tuple[Kind, TimeInterval, FilterCondition]
+ Param: TypeAlias = ProjectRowId
+
+ Key: TypeAlias = Tuple[Kind, ProjectRowId, Optional[TimeRange], FilterCondition]
+ Result: TypeAlias = SpanCount
+ ResultPosition: TypeAlias = int
+ DEFAULT_VALUE: Result = 0
+
+
+ def _cache_key_fn(key: Key) -> Tuple[Segment, Param]:
+     kind, project_rowid, time_range, filter_condition = key
+     interval = (
+         (time_range.start, time_range.end) if isinstance(time_range, TimeRange) else (None, None)
+     )
+     return (kind, interval, filter_condition), project_rowid
+
+
+ _Section: TypeAlias = ProjectRowId
+ _SubKey: TypeAlias = Tuple[TimeInterval, FilterCondition, Kind]
+
+
+ class RecordCountCache(
+     TwoTierCache[Key, Result, _Section, _SubKey],
+ ):
+     def __init__(self) -> None:
+         super().__init__(
+             # TTL=3600 (1-hour) because time intervals are always moving forward, but
+             # interval endpoints are rounded down to the hour by the UI, so anything
+             # older than an hour most likely won't be a cache-hit anyway.
+             main_cache=TTLCache(maxsize=64, ttl=3600),
+             sub_cache_factory=lambda: LFUCache(maxsize=2 * 2 * 2),
+         )
+
+     def _cache_key(self, key: Key) -> Tuple[_Section, _SubKey]:
+         (kind, interval, filter_condition), project_rowid = _cache_key_fn(key)
+         return project_rowid, (interval, filter_condition, kind)
+
+
+ class RecordCountDataLoader(DataLoader[Key, Result]):
+     def __init__(
+         self,
+         db: Callable[[], AsyncContextManager[AsyncSession]],
+         cache_map: Optional[AbstractCache[Key, Result]] = None,
+     ) -> None:
+         super().__init__(
+             load_fn=self._load_fn,
+             cache_key_fn=_cache_key_fn,
+             cache_map=cache_map,
+         )
+         self._db = db
+
+     async def _load_fn(self, keys: List[Key]) -> List[Result]:
+         results: List[Result] = [DEFAULT_VALUE] * len(keys)
+         arguments: DefaultDict[
+             Segment,
+             DefaultDict[Param, List[ResultPosition]],
+         ] = defaultdict(lambda: defaultdict(list))
+         for position, key in enumerate(keys):
+             segment, param = _cache_key_fn(key)
+             arguments[segment][param].append(position)
+         async with self._db() as session:
+             for segment, params in arguments.items():
+                 stmt = _get_stmt(segment, *params.keys())
+                 data = await session.stream(stmt)
+                 async for project_rowid, count in data:
+                     for position in params[project_rowid]:
+                         results[position] = count
+         return results
+
+
+ def _get_stmt(
+     segment: Segment,
+     *project_rowids: Param,
+ ) -> Select[Any]:
+     kind, (start_time, end_time), filter_condition = segment
+     pid = models.Trace.project_rowid
+     stmt = select(pid)
+     if kind == "span":
+         time_column = models.Span.start_time
+         stmt = stmt.join(models.Span)
+         if filter_condition:
+             sf = SpanFilter(filter_condition)
+             stmt = sf(stmt)
+     elif kind == "trace":
+         time_column = models.Trace.start_time
+     else:
+         assert_never(kind)
+     stmt = stmt.add_columns(func.count().label("count"))
+     stmt = stmt.where(pid.in_(project_rowids))
+     stmt = stmt.group_by(pid)
+     if start_time:
+         stmt = stmt.where(start_time <= time_column)
+     if end_time:
+         stmt = stmt.where(time_column < end_time)
+     return stmt
phoenix/server/api/dataloaders/span_descendants.py (new file)
@@ -0,0 +1,64 @@
+ from random import randint
+ from typing import (
+     AsyncContextManager,
+     Callable,
+     Dict,
+     List,
+ )
+
+ from aioitertools.itertools import groupby
+ from sqlalchemy import select
+ from sqlalchemy.ext.asyncio import AsyncSession
+ from sqlalchemy.orm import contains_eager
+ from strawberry.dataloader import DataLoader
+ from typing_extensions import TypeAlias
+
+ from phoenix.db import models
+
+ SpanId: TypeAlias = str
+
+ Key: TypeAlias = SpanId
+ Result: TypeAlias = List[models.Span]
+
+
+ class SpanDescendantsDataLoader(DataLoader[Key, Result]):
+     def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
+         super().__init__(load_fn=self._load_fn)
+         self._db = db
+
+     async def _load_fn(self, keys: List[Key]) -> List[Result]:
+         root_ids = set(keys)
+         root_id_label = f"root_id_{randint(0, 10**6):06}"
+         descendant_ids = (
+             select(
+                 models.Span.id,
+                 models.Span.span_id,
+                 models.Span.parent_id.label(root_id_label),
+             )
+             .where(models.Span.parent_id.in_(root_ids))
+             .cte(recursive=True)
+         )
+         parent_ids = descendant_ids.alias()
+         descendant_ids = descendant_ids.union_all(
+             select(
+                 models.Span.id,
+                 models.Span.span_id,
+                 parent_ids.c[root_id_label],
+             ).join(
+                 parent_ids,
+                 models.Span.parent_id == parent_ids.c.span_id,
+             )
+         )
+         stmt = (
+             select(descendant_ids.c[root_id_label], models.Span)
+             .join(descendant_ids, models.Span.id == descendant_ids.c.id)
+             .join(models.Trace)
+             .options(contains_eager(models.Span.trace))
+             .order_by(descendant_ids.c[root_id_label])
+         )
+         results: Dict[SpanId, Result] = {key: [] for key in keys}
+         async with self._db() as session:
+             data = await session.stream(stmt)
+             async for root_id, group in groupby(data, key=lambda d: d[0]):
+                 results[root_id].extend(span for _, span in group)
+         return [results[key].copy() for key in keys]
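
The recursive CTE above seeds on the requested span IDs and repeatedly joins `models.Span.parent_id` to the previous level's `span_id`, so each root's full descendant subtree comes back in one round trip. A minimal usage sketch (variable names are illustrative, not from this diff):

    # Hypothetical caller; the key is the root span's OpenTelemetry span_id string.
    descendants = await span_descendants_loader.load(root_span_id)
    # -> List[models.Span] covering children, grandchildren, and deeper levels.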
phoenix/server/api/dataloaders/span_evaluations.py (new file)
@@ -0,0 +1,37 @@
+ from collections import defaultdict
+ from typing import (
+     AsyncContextManager,
+     Callable,
+     DefaultDict,
+     List,
+ )
+
+ from sqlalchemy import select
+ from sqlalchemy.ext.asyncio import AsyncSession
+ from strawberry.dataloader import DataLoader
+ from typing_extensions import TypeAlias
+
+ from phoenix.db import models
+ from phoenix.server.api.types.Evaluation import SpanEvaluation
+
+ Key: TypeAlias = int
+ Result: TypeAlias = List[SpanEvaluation]
+
+
+ class SpanEvaluationsDataLoader(DataLoader[Key, Result]):
+     def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
+         super().__init__(load_fn=self._load_fn)
+         self._db = db
+
+     async def _load_fn(self, keys: List[Key]) -> List[Result]:
+         span_evaluations_by_id: DefaultDict[Key, Result] = defaultdict(list)
+         msa = models.SpanAnnotation
+         async with self._db() as session:
+             data = await session.stream_scalars(
+                 select(msa).where(msa.span_rowid.in_(keys)).where(msa.annotator_kind == "LLM")
+             )
+             async for span_evaluation in data:
+                 span_evaluations_by_id[span_evaluation.span_rowid].append(
+                     SpanEvaluation.from_sql_span_annotation(span_evaluation)
+                 )
+         return [span_evaluations_by_id[key] for key in keys]
phoenix/server/api/dataloaders/token_counts.py (new file)
@@ -0,0 +1,138 @@
+ from collections import defaultdict
+ from datetime import datetime
+ from typing import (
+     Any,
+     AsyncContextManager,
+     Callable,
+     DefaultDict,
+     List,
+     Literal,
+     Optional,
+     Tuple,
+ )
+
+ from cachetools import LFUCache, TTLCache
+ from openinference.semconv.trace import SpanAttributes
+ from sqlalchemy import Select, func, select
+ from sqlalchemy.ext.asyncio import AsyncSession
+ from sqlalchemy.sql.functions import coalesce
+ from strawberry.dataloader import AbstractCache, DataLoader
+ from typing_extensions import TypeAlias
+
+ from phoenix.db import models
+ from phoenix.server.api.dataloaders.cache import TwoTierCache
+ from phoenix.server.api.input_types.TimeRange import TimeRange
+ from phoenix.trace.dsl import SpanFilter
+
+ Kind: TypeAlias = Literal["prompt", "completion", "total"]
+ ProjectRowId: TypeAlias = int
+ TimeInterval: TypeAlias = Tuple[Optional[datetime], Optional[datetime]]
+ FilterCondition: TypeAlias = Optional[str]
+ TokenCount: TypeAlias = int
+
+ Segment: TypeAlias = Tuple[TimeInterval, FilterCondition]
+ Param: TypeAlias = Tuple[ProjectRowId, Kind]
+
+ Key: TypeAlias = Tuple[Kind, ProjectRowId, Optional[TimeRange], FilterCondition]
+ Result: TypeAlias = TokenCount
+ ResultPosition: TypeAlias = int
+ DEFAULT_VALUE: Result = 0
+
+
+ def _cache_key_fn(key: Key) -> Tuple[Segment, Param]:
+     kind, project_rowid, time_range, filter_condition = key
+     interval = (
+         (time_range.start, time_range.end) if isinstance(time_range, TimeRange) else (None, None)
+     )
+     return (interval, filter_condition), (project_rowid, kind)
+
+
+ _Section: TypeAlias = ProjectRowId
+ _SubKey: TypeAlias = Tuple[TimeInterval, FilterCondition, Kind]
+
+
+ class TokenCountCache(
+     TwoTierCache[Key, Result, _Section, _SubKey],
+ ):
+     def __init__(self) -> None:
+         super().__init__(
+             # TTL=3600 (1-hour) because time intervals are always moving forward, but
+             # interval endpoints are rounded down to the hour by the UI, so anything
+             # older than an hour most likely won't be a cache-hit anyway.
+             main_cache=TTLCache(maxsize=64, ttl=3600),
+             sub_cache_factory=lambda: LFUCache(maxsize=2 * 2 * 3),
+         )
+
+     def _cache_key(self, key: Key) -> Tuple[_Section, _SubKey]:
+         (interval, filter_condition), (project_rowid, kind) = _cache_key_fn(key)
+         return project_rowid, (interval, filter_condition, kind)
+
+
+ class TokenCountDataLoader(DataLoader[Key, Result]):
+     def __init__(
+         self,
+         db: Callable[[], AsyncContextManager[AsyncSession]],
+         cache_map: Optional[AbstractCache[Key, Result]] = None,
+     ) -> None:
+         super().__init__(
+             load_fn=self._load_fn,
+             cache_key_fn=_cache_key_fn,
+             cache_map=cache_map,
+         )
+         self._db = db
+
+     async def _load_fn(self, keys: List[Key]) -> List[Result]:
+         results: List[Result] = [DEFAULT_VALUE] * len(keys)
+         arguments: DefaultDict[
+             Segment,
+             DefaultDict[Param, List[ResultPosition]],
+         ] = defaultdict(lambda: defaultdict(list))
+         for position, key in enumerate(keys):
+             segment, param = _cache_key_fn(key)
+             arguments[segment][param].append(position)
+         async with self._db() as session:
+             for segment, params in arguments.items():
+                 stmt = _get_stmt(segment, *params.keys())
+                 data = await session.stream(stmt)
+                 async for project_rowid, prompt, completion, total in data:
+                     for position in params[(project_rowid, "prompt")]:
+                         results[position] = prompt
+                     for position in params[(project_rowid, "completion")]:
+                         results[position] = completion
+                     for position in params[(project_rowid, "total")]:
+                         results[position] = total
+         return results
+
+
+ def _get_stmt(
+     segment: Segment,
+     *params: Param,
+ ) -> Select[Any]:
+     (start_time, end_time), filter_condition = segment
+     prompt = func.sum(models.Span.attributes[_LLM_TOKEN_COUNT_PROMPT].as_float())
+     completion = func.sum(models.Span.attributes[_LLM_TOKEN_COUNT_COMPLETION].as_float())
+     total = coalesce(prompt, 0) + coalesce(completion, 0)
+     pid = models.Trace.project_rowid
+     stmt: Select[Any] = (
+         select(
+             pid,
+             prompt.label("prompt"),
+             completion.label("completion"),
+             total.label("total"),
+         )
+         .join_from(models.Trace, models.Span)
+         .group_by(pid)
+     )
+     if start_time:
+         stmt = stmt.where(start_time <= models.Span.start_time)
+     if end_time:
+         stmt = stmt.where(models.Span.start_time < end_time)
+     if filter_condition:
+         sf = SpanFilter(filter_condition)
+         stmt = sf(stmt)
+     stmt = stmt.where(pid.in_([rowid for rowid, _ in params]))
+     return stmt
+
+
+ _LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT.split(".")
+ _LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION.split(".")
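
Because the grouped statement returns the prompt, completion, and total sums in a single row per project, keys that differ only in Kind are satisfied by the same query execution. A short sketch (names illustrative, not from this diff):

    # Hypothetical caller; keys follow (kind, project_rowid, time_range, filter_condition).
    prompt_tokens = await token_count_loader.load(("prompt", project_rowid, None, None))
    total_tokens = await token_count_loader.load(("total", project_rowid, None, None))
    # Both resolve from one _get_stmt(...) execution for this segment.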
phoenix/server/api/dataloaders/trace_evaluations.py (new file)
@@ -0,0 +1,37 @@
+ from collections import defaultdict
+ from typing import (
+     AsyncContextManager,
+     Callable,
+     DefaultDict,
+     List,
+ )
+
+ from sqlalchemy import select
+ from sqlalchemy.ext.asyncio import AsyncSession
+ from strawberry.dataloader import DataLoader
+ from typing_extensions import TypeAlias
+
+ from phoenix.db import models
+ from phoenix.server.api.types.Evaluation import TraceEvaluation
+
+ Key: TypeAlias = int
+ Result: TypeAlias = List[TraceEvaluation]
+
+
+ class TraceEvaluationsDataLoader(DataLoader[Key, Result]):
+     def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
+         super().__init__(load_fn=self._load_fn)
+         self._db = db
+
+     async def _load_fn(self, keys: List[Key]) -> List[Result]:
+         trace_evaluations_by_id: DefaultDict[Key, Result] = defaultdict(list)
+         mta = models.TraceAnnotation
+         async with self._db() as session:
+             data = await session.stream_scalars(
+                 select(mta).where(mta.trace_rowid.in_(keys)).where(mta.annotator_kind == "LLM")
+             )
+             async for trace_evaluation in data:
+                 trace_evaluations_by_id[trace_evaluation.trace_rowid].append(
+                     TraceEvaluation.from_sql_trace_annotation(trace_evaluation)
+                 )
+         return [trace_evaluations_by_id[key] for key in keys]