arize-phoenix 3.25.0__py3-none-any.whl → 4.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of arize-phoenix has been flagged as potentially problematic.
- {arize_phoenix-3.25.0.dist-info → arize_phoenix-4.0.0.dist-info}/METADATA +26 -4
- {arize_phoenix-3.25.0.dist-info → arize_phoenix-4.0.0.dist-info}/RECORD +80 -75
- phoenix/__init__.py +9 -5
- phoenix/config.py +109 -53
- phoenix/datetime_utils.py +18 -1
- phoenix/db/README.md +25 -0
- phoenix/db/__init__.py +4 -0
- phoenix/db/alembic.ini +119 -0
- phoenix/db/bulk_inserter.py +206 -0
- phoenix/db/engines.py +152 -0
- phoenix/db/helpers.py +47 -0
- phoenix/db/insertion/evaluation.py +209 -0
- phoenix/db/insertion/helpers.py +54 -0
- phoenix/db/insertion/span.py +142 -0
- phoenix/db/migrate.py +71 -0
- phoenix/db/migrations/env.py +121 -0
- phoenix/db/migrations/script.py.mako +26 -0
- phoenix/db/migrations/versions/cf03bd6bae1d_init.py +280 -0
- phoenix/db/models.py +371 -0
- phoenix/exceptions.py +5 -1
- phoenix/server/api/context.py +40 -3
- phoenix/server/api/dataloaders/__init__.py +97 -0
- phoenix/server/api/dataloaders/cache/__init__.py +3 -0
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +67 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +152 -0
- phoenix/server/api/dataloaders/document_evaluations.py +37 -0
- phoenix/server/api/dataloaders/document_retrieval_metrics.py +98 -0
- phoenix/server/api/dataloaders/evaluation_summaries.py +151 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +198 -0
- phoenix/server/api/dataloaders/min_start_or_max_end_times.py +93 -0
- phoenix/server/api/dataloaders/record_counts.py +125 -0
- phoenix/server/api/dataloaders/span_descendants.py +64 -0
- phoenix/server/api/dataloaders/span_evaluations.py +37 -0
- phoenix/server/api/dataloaders/token_counts.py +138 -0
- phoenix/server/api/dataloaders/trace_evaluations.py +37 -0
- phoenix/server/api/input_types/SpanSort.py +138 -68
- phoenix/server/api/routers/v1/__init__.py +11 -0
- phoenix/server/api/routers/v1/evaluations.py +275 -0
- phoenix/server/api/routers/v1/spans.py +126 -0
- phoenix/server/api/routers/v1/traces.py +82 -0
- phoenix/server/api/schema.py +112 -48
- phoenix/server/api/types/DocumentEvaluationSummary.py +1 -1
- phoenix/server/api/types/Evaluation.py +29 -12
- phoenix/server/api/types/EvaluationSummary.py +29 -44
- phoenix/server/api/types/MimeType.py +2 -2
- phoenix/server/api/types/Model.py +9 -9
- phoenix/server/api/types/Project.py +240 -171
- phoenix/server/api/types/Span.py +87 -131
- phoenix/server/api/types/Trace.py +29 -20
- phoenix/server/api/types/pagination.py +151 -10
- phoenix/server/app.py +263 -35
- phoenix/server/grpc_server.py +93 -0
- phoenix/server/main.py +75 -60
- phoenix/server/openapi/docs.py +218 -0
- phoenix/server/prometheus.py +23 -7
- phoenix/server/static/index.js +662 -643
- phoenix/server/telemetry.py +68 -0
- phoenix/services.py +4 -0
- phoenix/session/client.py +34 -30
- phoenix/session/data_extractor.py +8 -3
- phoenix/session/session.py +176 -155
- phoenix/settings.py +13 -0
- phoenix/trace/attributes.py +349 -0
- phoenix/trace/dsl/README.md +116 -0
- phoenix/trace/dsl/filter.py +660 -192
- phoenix/trace/dsl/helpers.py +24 -5
- phoenix/trace/dsl/query.py +562 -185
- phoenix/trace/fixtures.py +69 -7
- phoenix/trace/otel.py +33 -199
- phoenix/trace/schemas.py +14 -8
- phoenix/trace/span_evaluations.py +5 -2
- phoenix/utilities/__init__.py +0 -26
- phoenix/utilities/span_store.py +0 -23
- phoenix/version.py +1 -1
- phoenix/core/project.py +0 -773
- phoenix/core/traces.py +0 -96
- phoenix/datasets/dataset.py +0 -214
- phoenix/datasets/fixtures.py +0 -24
- phoenix/datasets/schema.py +0 -31
- phoenix/experimental/evals/__init__.py +0 -73
- phoenix/experimental/evals/evaluators.py +0 -413
- phoenix/experimental/evals/functions/__init__.py +0 -4
- phoenix/experimental/evals/functions/classify.py +0 -453
- phoenix/experimental/evals/functions/executor.py +0 -353
- phoenix/experimental/evals/functions/generate.py +0 -138
- phoenix/experimental/evals/functions/processing.py +0 -76
- phoenix/experimental/evals/models/__init__.py +0 -14
- phoenix/experimental/evals/models/anthropic.py +0 -175
- phoenix/experimental/evals/models/base.py +0 -170
- phoenix/experimental/evals/models/bedrock.py +0 -221
- phoenix/experimental/evals/models/litellm.py +0 -134
- phoenix/experimental/evals/models/openai.py +0 -453
- phoenix/experimental/evals/models/rate_limiters.py +0 -246
- phoenix/experimental/evals/models/vertex.py +0 -173
- phoenix/experimental/evals/models/vertexai.py +0 -186
- phoenix/experimental/evals/retrievals.py +0 -96
- phoenix/experimental/evals/templates/__init__.py +0 -50
- phoenix/experimental/evals/templates/default_templates.py +0 -472
- phoenix/experimental/evals/templates/template.py +0 -195
- phoenix/experimental/evals/utils/__init__.py +0 -172
- phoenix/experimental/evals/utils/threads.py +0 -27
- phoenix/server/api/routers/evaluation_handler.py +0 -110
- phoenix/server/api/routers/span_handler.py +0 -70
- phoenix/server/api/routers/trace_handler.py +0 -60
- phoenix/storage/span_store/__init__.py +0 -23
- phoenix/storage/span_store/text_file.py +0 -85
- phoenix/trace/dsl/missing.py +0 -60
- {arize_phoenix-3.25.0.dist-info → arize_phoenix-4.0.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-3.25.0.dist-info → arize_phoenix-4.0.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-3.25.0.dist-info → arize_phoenix-4.0.0.dist-info}/licenses/LICENSE +0 -0
- /phoenix/{datasets → db/insertion}/__init__.py +0 -0
- /phoenix/{experimental → db/migrations}/__init__.py +0 -0
- /phoenix/{storage → server/openapi}/__init__.py +0 -0
--- /dev/null
+++ b/phoenix/server/api/dataloaders/latency_ms_quantile.py
@@ -0,0 +1,198 @@
+from collections import defaultdict
+from datetime import datetime
+from typing import (
+    Any,
+    AsyncContextManager,
+    AsyncIterator,
+    Callable,
+    DefaultDict,
+    List,
+    Literal,
+    Mapping,
+    Optional,
+    Tuple,
+    cast,
+)
+
+from cachetools import LFUCache, TTLCache
+from sqlalchemy import (
+    ARRAY,
+    Float,
+    Integer,
+    Select,
+    SQLColumnExpression,
+    Values,
+    column,
+    func,
+    select,
+    values,
+)
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.sql.functions import percentile_cont
+from strawberry.dataloader import AbstractCache, DataLoader
+from typing_extensions import TypeAlias, assert_never
+
+from phoenix.db import models
+from phoenix.db.helpers import SupportedSQLDialect
+from phoenix.server.api.dataloaders.cache import TwoTierCache
+from phoenix.server.api.input_types.TimeRange import TimeRange
+from phoenix.trace.dsl import SpanFilter
+
+Kind: TypeAlias = Literal["span", "trace"]
+ProjectRowId: TypeAlias = int
+TimeInterval: TypeAlias = Tuple[Optional[datetime], Optional[datetime]]
+FilterCondition: TypeAlias = Optional[str]
+Probability: TypeAlias = float
+QuantileValue: TypeAlias = float
+
+Segment: TypeAlias = Tuple[Kind, TimeInterval, FilterCondition]
+Param: TypeAlias = Tuple[ProjectRowId, Probability]
+
+Key: TypeAlias = Tuple[Kind, ProjectRowId, Optional[TimeRange], FilterCondition, Probability]
+Result: TypeAlias = Optional[QuantileValue]
+ResultPosition: TypeAlias = int
+DEFAULT_VALUE: Result = None
+
+FloatCol: TypeAlias = SQLColumnExpression[Float[float]]
+
+
+def _cache_key_fn(key: Key) -> Tuple[Segment, Param]:
+    kind, project_rowid, time_range, filter_condition, probability = key
+    interval = (
+        (time_range.start, time_range.end) if isinstance(time_range, TimeRange) else (None, None)
+    )
+    return (kind, interval, filter_condition), (project_rowid, probability)
+
+
+_Section: TypeAlias = ProjectRowId
+_SubKey: TypeAlias = Tuple[TimeInterval, FilterCondition, Kind, Probability]
+
+
+class LatencyMsQuantileCache(
+    TwoTierCache[Key, Result, _Section, _SubKey],
+):
+    def __init__(self) -> None:
+        super().__init__(
+            # TTL=3600 (1-hour) because time intervals are always moving forward, but
+            # interval endpoints are rounded down to the hour by the UI, so anything
+            # older than an hour most likely won't be a cache-hit anyway.
+            main_cache=TTLCache(maxsize=64, ttl=3600),
+            sub_cache_factory=lambda: LFUCache(maxsize=2 * 2 * 2 * 16),
+        )
+
+    def _cache_key(self, key: Key) -> Tuple[_Section, _SubKey]:
+        (kind, interval, filter_condition), (project_rowid, probability) = _cache_key_fn(key)
+        return project_rowid, (interval, filter_condition, kind, probability)
+
+
+class LatencyMsQuantileDataLoader(DataLoader[Key, Result]):
+    def __init__(
+        self,
+        db: Callable[[], AsyncContextManager[AsyncSession]],
+        cache_map: Optional[AbstractCache[Key, Result]] = None,
+    ) -> None:
+        super().__init__(
+            load_fn=self._load_fn,
+            cache_key_fn=_cache_key_fn,
+            cache_map=cache_map,
+        )
+        self._db = db
+
+    async def _load_fn(self, keys: List[Key]) -> List[Result]:
+        results: List[Result] = [DEFAULT_VALUE] * len(keys)
+        arguments: DefaultDict[
+            Segment,
+            DefaultDict[Param, List[ResultPosition]],
+        ] = defaultdict(lambda: defaultdict(list))
+        for position, key in enumerate(keys):
+            segment, param = _cache_key_fn(key)
+            arguments[segment][param].append(position)
+        async with self._db() as session:
+            dialect = SupportedSQLDialect(session.bind.dialect.name)
+            for segment, params in arguments.items():
+                async for position, quantile_value in _get_results(
+                    dialect, session, segment, params
+                ):
+                    results[position] = quantile_value
+        return results
+
+
+async def _get_results(
+    dialect: SupportedSQLDialect,
+    session: AsyncSession,
+    segment: Segment,
+    params: Mapping[Param, List[ResultPosition]],
+) -> AsyncIterator[Tuple[ResultPosition, QuantileValue]]:
+    kind, (start_time, end_time), filter_condition = segment
+    stmt = select(models.Trace.project_rowid)
+    if kind == "trace":
+        latency_column = cast(FloatCol, models.Trace.latency_ms)
+        time_column = models.Trace.start_time
+    elif kind == "span":
+        latency_column = cast(FloatCol, models.Span.latency_ms)
+        time_column = models.Span.start_time
+        stmt = stmt.join(models.Span)
+        if filter_condition:
+            sf = SpanFilter(filter_condition)
+            stmt = sf(stmt)
+    else:
+        assert_never(kind)
+    if start_time:
+        stmt = stmt.where(start_time <= time_column)
+    if end_time:
+        stmt = stmt.where(time_column < end_time)
+    if dialect is SupportedSQLDialect.POSTGRESQL:
+        results = _get_results_postgresql(session, stmt, latency_column, params)
+    elif dialect is SupportedSQLDialect.SQLITE:
+        results = _get_results_sqlite(session, stmt, latency_column, params)
+    else:
+        assert_never(dialect)
+    async for position, quantile_value in results:
+        yield position, quantile_value
+
+
+async def _get_results_sqlite(
+    session: AsyncSession,
+    base_stmt: Select[Any],
+    latency_column: FloatCol,
+    params: Mapping[Param, List[ResultPosition]],
+) -> AsyncIterator[Tuple[ResultPosition, QuantileValue]]:
+    projects_per_prob: DefaultDict[Probability, List[ProjectRowId]] = defaultdict(list)
+    for project_rowid, probability in params.keys():
+        projects_per_prob[probability].append(project_rowid)
+    pid = models.Trace.project_rowid
+    for probability, project_rowids in projects_per_prob.items():
+        pctl: FloatCol = func.percentile(latency_column, probability * 100)
+        stmt = base_stmt.add_columns(pctl)
+        stmt = stmt.where(pid.in_(project_rowids))
+        stmt = stmt.group_by(pid)
+        data = await session.stream(stmt)
+        async for project_rowid, quantile_value in data:
+            for position in params[(project_rowid, probability)]:
+                yield position, quantile_value
+
+
+async def _get_results_postgresql(
+    session: AsyncSession,
+    base_stmt: Select[Any],
+    latency_column: FloatCol,
+    params: Mapping[Param, List[ResultPosition]],
+) -> AsyncIterator[Tuple[ResultPosition, QuantileValue]]:
+    probs_per_project: DefaultDict[ProjectRowId, List[Probability]] = defaultdict(list)
+    for project_rowid, probability in params.keys():
+        probs_per_project[project_rowid].append(probability)
+    pp: Values = values(
+        column("project_rowid", Integer),
+        column("probabilities", ARRAY(Float[float])),
+        name="project_probabilities",
+    ).data(probs_per_project.items())  # type: ignore
+    pid = models.Trace.project_rowid
+    pctl: FloatCol = percentile_cont(pp.c.probabilities).within_group(latency_column)
+    stmt = base_stmt.add_columns(pp.c.probabilities, pctl)
+    stmt = stmt.join(pp, pid == pp.c.project_rowid)
+    stmt = stmt.group_by(pid, pp.c.probabilities)
+    data = await session.stream(stmt)
+    async for project_rowid, probabilities, quantile_values in data:
+        for probability, quantile_value in zip(probabilities, quantile_values):
+            for position in params[(project_rowid, probability)]:
+                yield position, quantile_value
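
Each of the new dataloaders batches lookups by splitting a key into a "segment" (everything that shapes the SQL statement) and a "param" (the per-row lookup within that statement). A minimal sketch of that split for the quantile loader above, using a made-up project row id and probability (not part of the diff):

```python
# Illustrative sketch only: keys that share a segment (kind, time interval,
# filter condition) are answered by a single SQL statement per batch.
from phoenix.server.api.dataloaders.latency_ms_quantile import _cache_key_fn

segment, param = _cache_key_fn(("trace", 7, None, None, 0.95))
print(segment)  # ('trace', (None, None), None)
print(param)    # (7, 0.95)
```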
--- /dev/null
+++ b/phoenix/server/api/dataloaders/min_start_or_max_end_times.py
@@ -0,0 +1,93 @@
+from collections import defaultdict
+from datetime import datetime
+from typing import (
+    AsyncContextManager,
+    Callable,
+    DefaultDict,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+)
+
+from cachetools import LFUCache
+from sqlalchemy import func, select
+from sqlalchemy.ext.asyncio import AsyncSession
+from strawberry.dataloader import AbstractCache, DataLoader
+from typing_extensions import TypeAlias, assert_never
+
+from phoenix.db import models
+from phoenix.server.api.dataloaders.cache import TwoTierCache
+
+Kind: TypeAlias = Literal["start", "end"]
+ProjectRowId: TypeAlias = int
+
+Segment: TypeAlias = ProjectRowId
+Param: TypeAlias = Kind
+
+Key: TypeAlias = Tuple[ProjectRowId, Kind]
+Result: TypeAlias = Optional[datetime]
+ResultPosition: TypeAlias = int
+DEFAULT_VALUE: Result = None
+
+_Section = ProjectRowId
+_SubKey = Kind
+
+
+class MinStartOrMaxEndTimeCache(
+    TwoTierCache[Key, Result, _Section, _SubKey],
+):
+    def __init__(self) -> None:
+        super().__init__(
+            main_cache=LFUCache(maxsize=64),
+            sub_cache_factory=lambda: LFUCache(maxsize=2),
+        )
+
+    def _cache_key(self, key: Key) -> Tuple[_Section, _SubKey]:
+        return key
+
+
+class MinStartOrMaxEndTimeDataLoader(DataLoader[Key, Result]):
+    def __init__(
+        self,
+        db: Callable[[], AsyncContextManager[AsyncSession]],
+        cache_map: Optional[AbstractCache[Key, Result]] = None,
+    ) -> None:
+        super().__init__(
+            load_fn=self._load_fn,
+            cache_map=cache_map,
+        )
+        self._db = db
+
+    async def _load_fn(self, keys: List[Key]) -> List[Result]:
+        results: List[Result] = [DEFAULT_VALUE] * len(keys)
+        arguments: DefaultDict[
+            Segment,
+            DefaultDict[Param, List[ResultPosition]],
+        ] = defaultdict(lambda: defaultdict(list))
+        for position, key in enumerate(keys):
+            segment, param = key
+            arguments[segment][param].append(position)
+        pid = models.Trace.project_rowid
+        stmt = (
+            select(
+                pid,
+                func.min(models.Trace.start_time).label("min_start"),
+                func.max(models.Trace.end_time).label("max_end"),
+            )
+            .where(pid.in_(arguments.keys()))
+            .group_by(pid)
+        )
+        async with self._db() as session:
+            data = await session.stream(stmt)
+            async for project_rowid, min_start, max_end in data:
+                for kind, positions in arguments[project_rowid].items():
+                    if kind == "start":
+                        for position in positions:
+                            results[position] = min_start
+                    elif kind == "end":
+                        for position in positions:
+                            results[position] = max_end
+                    else:
+                        assert_never(kind)
+        return results
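
A hedged usage sketch for the loader above (not part of the diff): the session factory `db` and the project row id are assumptions, but requesting both the "start" and "end" keys in the same batch resolves them with a single GROUP BY query.

```python
# Illustrative sketch only.
import asyncio

from phoenix.server.api.dataloaders.min_start_or_max_end_times import (
    MinStartOrMaxEndTimeDataLoader,
)


async def trace_time_bounds(db, project_rowid: int):
    # `db` is assumed to be a Callable[[], AsyncContextManager[AsyncSession]].
    loader = MinStartOrMaxEndTimeDataLoader(db)
    # Both keys are dispatched in the same batch, so one grouped query runs.
    return await asyncio.gather(
        loader.load((project_rowid, "start")),
        loader.load((project_rowid, "end")),
    )
```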
--- /dev/null
+++ b/phoenix/server/api/dataloaders/record_counts.py
@@ -0,0 +1,125 @@
+from collections import defaultdict
+from datetime import datetime
+from typing import (
+    Any,
+    AsyncContextManager,
+    Callable,
+    DefaultDict,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+)
+
+from cachetools import LFUCache, TTLCache
+from sqlalchemy import Select, func, select
+from sqlalchemy.ext.asyncio import AsyncSession
+from strawberry.dataloader import AbstractCache, DataLoader
+from typing_extensions import TypeAlias, assert_never
+
+from phoenix.db import models
+from phoenix.server.api.dataloaders.cache import TwoTierCache
+from phoenix.server.api.input_types.TimeRange import TimeRange
+from phoenix.trace.dsl import SpanFilter
+
+Kind: TypeAlias = Literal["span", "trace"]
+ProjectRowId: TypeAlias = int
+TimeInterval: TypeAlias = Tuple[Optional[datetime], Optional[datetime]]
+FilterCondition: TypeAlias = Optional[str]
+SpanCount: TypeAlias = int
+
+Segment: TypeAlias = Tuple[Kind, TimeInterval, FilterCondition]
+Param: TypeAlias = ProjectRowId
+
+Key: TypeAlias = Tuple[Kind, ProjectRowId, Optional[TimeRange], FilterCondition]
+Result: TypeAlias = SpanCount
+ResultPosition: TypeAlias = int
+DEFAULT_VALUE: Result = 0
+
+
+def _cache_key_fn(key: Key) -> Tuple[Segment, Param]:
+    kind, project_rowid, time_range, filter_condition = key
+    interval = (
+        (time_range.start, time_range.end) if isinstance(time_range, TimeRange) else (None, None)
+    )
+    return (kind, interval, filter_condition), project_rowid
+
+
+_Section: TypeAlias = ProjectRowId
+_SubKey: TypeAlias = Tuple[TimeInterval, FilterCondition, Kind]
+
+
+class RecordCountCache(
+    TwoTierCache[Key, Result, _Section, _SubKey],
+):
+    def __init__(self) -> None:
+        super().__init__(
+            # TTL=3600 (1-hour) because time intervals are always moving forward, but
+            # interval endpoints are rounded down to the hour by the UI, so anything
+            # older than an hour most likely won't be a cache-hit anyway.
+            main_cache=TTLCache(maxsize=64, ttl=3600),
+            sub_cache_factory=lambda: LFUCache(maxsize=2 * 2 * 2),
+        )
+
+    def _cache_key(self, key: Key) -> Tuple[_Section, _SubKey]:
+        (kind, interval, filter_condition), project_rowid = _cache_key_fn(key)
+        return project_rowid, (interval, filter_condition, kind)
+
+
+class RecordCountDataLoader(DataLoader[Key, Result]):
+    def __init__(
+        self,
+        db: Callable[[], AsyncContextManager[AsyncSession]],
+        cache_map: Optional[AbstractCache[Key, Result]] = None,
+    ) -> None:
+        super().__init__(
+            load_fn=self._load_fn,
+            cache_key_fn=_cache_key_fn,
+            cache_map=cache_map,
+        )
+        self._db = db
+
+    async def _load_fn(self, keys: List[Key]) -> List[Result]:
+        results: List[Result] = [DEFAULT_VALUE] * len(keys)
+        arguments: DefaultDict[
+            Segment,
+            DefaultDict[Param, List[ResultPosition]],
+        ] = defaultdict(lambda: defaultdict(list))
+        for position, key in enumerate(keys):
+            segment, param = _cache_key_fn(key)
+            arguments[segment][param].append(position)
+        async with self._db() as session:
+            for segment, params in arguments.items():
+                stmt = _get_stmt(segment, *params.keys())
+                data = await session.stream(stmt)
+                async for project_rowid, count in data:
+                    for position in params[project_rowid]:
+                        results[position] = count
+        return results
+
+
+def _get_stmt(
+    segment: Segment,
+    *project_rowids: Param,
+) -> Select[Any]:
+    kind, (start_time, end_time), filter_condition = segment
+    pid = models.Trace.project_rowid
+    stmt = select(pid)
+    if kind == "span":
+        time_column = models.Span.start_time
+        stmt = stmt.join(models.Span)
+        if filter_condition:
+            sf = SpanFilter(filter_condition)
+            stmt = sf(stmt)
+    elif kind == "trace":
+        time_column = models.Trace.start_time
+    else:
+        assert_never(kind)
+    stmt = stmt.add_columns(func.count().label("count"))
+    stmt = stmt.where(pid.in_(project_rowids))
+    stmt = stmt.group_by(pid)
+    if start_time:
+        stmt = stmt.where(start_time <= time_column)
+    if end_time:
+        stmt = stmt.where(time_column < end_time)
+    return stmt
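
To see the shape of the statement `_get_stmt` builds, it can be compiled directly; a sketch (not part of the diff) with a made-up segment and project row ids, where the rendered SQL shown in the comment is only approximate:

```python
# Illustrative sketch only.
from phoenix.server.api.dataloaders.record_counts import _get_stmt

stmt = _get_stmt(("span", (None, None), None), 1, 2)
print(stmt.compile(compile_kwargs={"literal_binds": True}))
# Expect roughly:
#   SELECT traces.project_rowid, count(*) AS count
#   FROM traces JOIN spans ON ...
#   WHERE traces.project_rowid IN (1, 2)
#   GROUP BY traces.project_rowid
```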
--- /dev/null
+++ b/phoenix/server/api/dataloaders/span_descendants.py
@@ -0,0 +1,64 @@
+from random import randint
+from typing import (
+    AsyncContextManager,
+    Callable,
+    Dict,
+    List,
+)
+
+from aioitertools.itertools import groupby
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.orm import contains_eager
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+
+SpanId: TypeAlias = str
+
+Key: TypeAlias = SpanId
+Result: TypeAlias = List[models.Span]
+
+
+class SpanDescendantsDataLoader(DataLoader[Key, Result]):
+    def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: List[Key]) -> List[Result]:
+        root_ids = set(keys)
+        root_id_label = f"root_id_{randint(0, 10**6):06}"
+        descendant_ids = (
+            select(
+                models.Span.id,
+                models.Span.span_id,
+                models.Span.parent_id.label(root_id_label),
+            )
+            .where(models.Span.parent_id.in_(root_ids))
+            .cte(recursive=True)
+        )
+        parent_ids = descendant_ids.alias()
+        descendant_ids = descendant_ids.union_all(
+            select(
+                models.Span.id,
+                models.Span.span_id,
+                parent_ids.c[root_id_label],
+            ).join(
+                parent_ids,
+                models.Span.parent_id == parent_ids.c.span_id,
+            )
+        )
+        stmt = (
+            select(descendant_ids.c[root_id_label], models.Span)
+            .join(descendant_ids, models.Span.id == descendant_ids.c.id)
+            .join(models.Trace)
+            .options(contains_eager(models.Span.trace))
+            .order_by(descendant_ids.c[root_id_label])
+        )
+        results: Dict[SpanId, Result] = {key: [] for key in keys}
+        async with self._db() as session:
+            data = await session.stream(stmt)
+            async for root_id, group in groupby(data, key=lambda d: d[0]):
+                results[root_id].extend(span for _, span in group)
+        return [results[key].copy() for key in keys]
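
The descendants loader relies on a recursive CTE: the seed query selects the direct children of the requested root span ids while remembering the root, and the recursive step keeps propagating each child's root label downward. A self-contained sketch of the same pattern against a toy table, not part of the diff (all names here are hypothetical):

```python
# Illustrative sketch only: the recursive-CTE pattern in isolation.
from sqlalchemy import Column, Integer, MetaData, String, Table, select

metadata = MetaData()
spans = Table(
    "spans",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("span_id", String),
    Column("parent_id", String),
)

# Seed: direct children of the requested root span ids, remembering the root.
cte = (
    select(spans.c.id, spans.c.span_id, spans.c.parent_id.label("root_id"))
    .where(spans.c.parent_id.in_(["root-a", "root-b"]))
    .cte(recursive=True)
)
parents = cte.alias()
# Recursive step: children of anything already collected keep the same root_id.
cte = cte.union_all(
    select(spans.c.id, spans.c.span_id, parents.c.root_id).join(
        parents, spans.c.parent_id == parents.c.span_id
    )
)
print(select(cte.c.root_id, cte.c.span_id))  # renders a WITH RECURSIVE query
```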
--- /dev/null
+++ b/phoenix/server/api/dataloaders/span_evaluations.py
@@ -0,0 +1,37 @@
+from collections import defaultdict
+from typing import (
+    AsyncContextManager,
+    Callable,
+    DefaultDict,
+    List,
+)
+
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+from phoenix.server.api.types.Evaluation import SpanEvaluation
+
+Key: TypeAlias = int
+Result: TypeAlias = List[SpanEvaluation]
+
+
+class SpanEvaluationsDataLoader(DataLoader[Key, Result]):
+    def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: List[Key]) -> List[Result]:
+        span_evaluations_by_id: DefaultDict[Key, Result] = defaultdict(list)
+        msa = models.SpanAnnotation
+        async with self._db() as session:
+            data = await session.stream_scalars(
+                select(msa).where(msa.span_rowid.in_(keys)).where(msa.annotator_kind == "LLM")
+            )
+            async for span_evaluation in data:
+                span_evaluations_by_id[span_evaluation.span_rowid].append(
+                    SpanEvaluation.from_sql_span_annotation(span_evaluation)
+                )
+        return [span_evaluations_by_id[key] for key in keys]
--- /dev/null
+++ b/phoenix/server/api/dataloaders/token_counts.py
@@ -0,0 +1,138 @@
+from collections import defaultdict
+from datetime import datetime
+from typing import (
+    Any,
+    AsyncContextManager,
+    Callable,
+    DefaultDict,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+)
+
+from cachetools import LFUCache, TTLCache
+from openinference.semconv.trace import SpanAttributes
+from sqlalchemy import Select, func, select
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.sql.functions import coalesce
+from strawberry.dataloader import AbstractCache, DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+from phoenix.server.api.dataloaders.cache import TwoTierCache
+from phoenix.server.api.input_types.TimeRange import TimeRange
+from phoenix.trace.dsl import SpanFilter
+
+Kind: TypeAlias = Literal["prompt", "completion", "total"]
+ProjectRowId: TypeAlias = int
+TimeInterval: TypeAlias = Tuple[Optional[datetime], Optional[datetime]]
+FilterCondition: TypeAlias = Optional[str]
+TokenCount: TypeAlias = int
+
+Segment: TypeAlias = Tuple[TimeInterval, FilterCondition]
+Param: TypeAlias = Tuple[ProjectRowId, Kind]
+
+Key: TypeAlias = Tuple[Kind, ProjectRowId, Optional[TimeRange], FilterCondition]
+Result: TypeAlias = TokenCount
+ResultPosition: TypeAlias = int
+DEFAULT_VALUE: Result = 0
+
+
+def _cache_key_fn(key: Key) -> Tuple[Segment, Param]:
+    kind, project_rowid, time_range, filter_condition = key
+    interval = (
+        (time_range.start, time_range.end) if isinstance(time_range, TimeRange) else (None, None)
+    )
+    return (interval, filter_condition), (project_rowid, kind)
+
+
+_Section: TypeAlias = ProjectRowId
+_SubKey: TypeAlias = Tuple[TimeInterval, FilterCondition, Kind]
+
+
+class TokenCountCache(
+    TwoTierCache[Key, Result, _Section, _SubKey],
+):
+    def __init__(self) -> None:
+        super().__init__(
+            # TTL=3600 (1-hour) because time intervals are always moving forward, but
+            # interval endpoints are rounded down to the hour by the UI, so anything
+            # older than an hour most likely won't be a cache-hit anyway.
+            main_cache=TTLCache(maxsize=64, ttl=3600),
+            sub_cache_factory=lambda: LFUCache(maxsize=2 * 2 * 3),
+        )
+
+    def _cache_key(self, key: Key) -> Tuple[_Section, _SubKey]:
+        (interval, filter_condition), (project_rowid, kind) = _cache_key_fn(key)
+        return project_rowid, (interval, filter_condition, kind)
+
+
+class TokenCountDataLoader(DataLoader[Key, Result]):
+    def __init__(
+        self,
+        db: Callable[[], AsyncContextManager[AsyncSession]],
+        cache_map: Optional[AbstractCache[Key, Result]] = None,
+    ) -> None:
+        super().__init__(
+            load_fn=self._load_fn,
+            cache_key_fn=_cache_key_fn,
+            cache_map=cache_map,
+        )
+        self._db = db
+
+    async def _load_fn(self, keys: List[Key]) -> List[Result]:
+        results: List[Result] = [DEFAULT_VALUE] * len(keys)
+        arguments: DefaultDict[
+            Segment,
+            DefaultDict[Param, List[ResultPosition]],
+        ] = defaultdict(lambda: defaultdict(list))
+        for position, key in enumerate(keys):
+            segment, param = _cache_key_fn(key)
+            arguments[segment][param].append(position)
+        async with self._db() as session:
+            for segment, params in arguments.items():
+                stmt = _get_stmt(segment, *params.keys())
+                data = await session.stream(stmt)
+                async for project_rowid, prompt, completion, total in data:
+                    for position in params[(project_rowid, "prompt")]:
+                        results[position] = prompt
+                    for position in params[(project_rowid, "completion")]:
+                        results[position] = completion
+                    for position in params[(project_rowid, "total")]:
+                        results[position] = total
+        return results
+
+
+def _get_stmt(
+    segment: Segment,
+    *params: Param,
+) -> Select[Any]:
+    (start_time, end_time), filter_condition = segment
+    prompt = func.sum(models.Span.attributes[_LLM_TOKEN_COUNT_PROMPT].as_float())
+    completion = func.sum(models.Span.attributes[_LLM_TOKEN_COUNT_COMPLETION].as_float())
+    total = coalesce(prompt, 0) + coalesce(completion, 0)
+    pid = models.Trace.project_rowid
+    stmt: Select[Any] = (
+        select(
+            pid,
+            prompt.label("prompt"),
+            completion.label("completion"),
+            total.label("total"),
+        )
+        .join_from(models.Trace, models.Span)
+        .group_by(pid)
+    )
+    if start_time:
+        stmt = stmt.where(start_time <= models.Span.start_time)
+    if end_time:
+        stmt = stmt.where(models.Span.start_time < end_time)
+    if filter_condition:
+        sf = SpanFilter(filter_condition)
+        stmt = sf(stmt)
+    stmt = stmt.where(pid.in_([rowid for rowid, _ in params]))
+    return stmt
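
The token-count columns index the span's JSON `attributes` blob by a nested path: the dotted semantic-convention key is split on "." and handed to SQLAlchemy's JSON indexing with `.as_float()`. A small sketch, not part of the diff (the sample attributes payload is hypothetical):

```python
# Illustrative sketch only.
from openinference.semconv.trace import SpanAttributes

path = SpanAttributes.LLM_TOKEN_COUNT_PROMPT.split(".")
print(path)  # should print something like ['llm', 'token_count', 'prompt']

# Against a row whose attributes column holds
#   {"llm": {"token_count": {"prompt": 12, "completion": 34}}},
# models.Span.attributes[path].as_float() would evaluate to 12.0,
# and func.sum(...) aggregates it per project in the statement above.
```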
--- /dev/null
+++ b/phoenix/server/api/dataloaders/trace_evaluations.py
@@ -0,0 +1,37 @@
+from collections import defaultdict
+from typing import (
+    AsyncContextManager,
+    Callable,
+    DefaultDict,
+    List,
+)
+
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+from phoenix.server.api.types.Evaluation import TraceEvaluation
+
+Key: TypeAlias = int
+Result: TypeAlias = List[TraceEvaluation]
+
+
+class TraceEvaluationsDataLoader(DataLoader[Key, Result]):
+    def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: List[Key]) -> List[Result]:
+        trace_evaluations_by_id: DefaultDict[Key, Result] = defaultdict(list)
+        mta = models.TraceAnnotation
+        async with self._db() as session:
+            data = await session.stream_scalars(
+                select(mta).where(mta.trace_rowid.in_(keys)).where(mta.annotator_kind == "LLM")
+            )
+            async for trace_evaluation in data:
+                trace_evaluations_by_id[trace_evaluation.trace_rowid].append(
+                    TraceEvaluation.from_sql_trace_annotation(trace_evaluation)
+                )
+        return [trace_evaluations_by_id[key] for key in keys]
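
The two evaluation loaders above follow the same pattern: keys are database row ids, only annotations with `annotator_kind == "LLM"` are surfaced as evaluations, and ids with no matching annotations come back as empty lists. A hedged usage sketch, not part of the diff (the session factory `db` and the row ids are assumptions):

```python
# Illustrative sketch only.
import asyncio

from phoenix.server.api.dataloaders.span_evaluations import SpanEvaluationsDataLoader
from phoenix.server.api.dataloaders.trace_evaluations import TraceEvaluationsDataLoader


async def llm_evaluations(db, span_rowid: int, trace_rowid: int):
    span_loader = SpanEvaluationsDataLoader(db)
    trace_loader = TraceEvaluationsDataLoader(db)
    # Each loader batches the keys seen in one tick into a single IN (...) query.
    return await asyncio.gather(
        span_loader.load(span_rowid),
        trace_loader.load(trace_rowid),
    )
```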