arize-phoenix 3.25.0__py3-none-any.whl → 4.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of arize-phoenix might be problematic.

Files changed (113)
  1. {arize_phoenix-3.25.0.dist-info → arize_phoenix-4.0.1.dist-info}/METADATA +26 -4
  2. {arize_phoenix-3.25.0.dist-info → arize_phoenix-4.0.1.dist-info}/RECORD +80 -75
  3. phoenix/__init__.py +9 -5
  4. phoenix/config.py +109 -53
  5. phoenix/datetime_utils.py +18 -1
  6. phoenix/db/README.md +25 -0
  7. phoenix/db/__init__.py +4 -0
  8. phoenix/db/alembic.ini +119 -0
  9. phoenix/db/bulk_inserter.py +206 -0
  10. phoenix/db/engines.py +152 -0
  11. phoenix/db/helpers.py +47 -0
  12. phoenix/db/insertion/evaluation.py +209 -0
  13. phoenix/db/insertion/helpers.py +51 -0
  14. phoenix/db/insertion/span.py +142 -0
  15. phoenix/db/migrate.py +71 -0
  16. phoenix/db/migrations/env.py +121 -0
  17. phoenix/db/migrations/script.py.mako +26 -0
  18. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +280 -0
  19. phoenix/db/models.py +371 -0
  20. phoenix/exceptions.py +5 -1
  21. phoenix/server/api/context.py +40 -3
  22. phoenix/server/api/dataloaders/__init__.py +97 -0
  23. phoenix/server/api/dataloaders/cache/__init__.py +3 -0
  24. phoenix/server/api/dataloaders/cache/two_tier_cache.py +67 -0
  25. phoenix/server/api/dataloaders/document_evaluation_summaries.py +152 -0
  26. phoenix/server/api/dataloaders/document_evaluations.py +37 -0
  27. phoenix/server/api/dataloaders/document_retrieval_metrics.py +98 -0
  28. phoenix/server/api/dataloaders/evaluation_summaries.py +151 -0
  29. phoenix/server/api/dataloaders/latency_ms_quantile.py +198 -0
  30. phoenix/server/api/dataloaders/min_start_or_max_end_times.py +93 -0
  31. phoenix/server/api/dataloaders/record_counts.py +125 -0
  32. phoenix/server/api/dataloaders/span_descendants.py +64 -0
  33. phoenix/server/api/dataloaders/span_evaluations.py +37 -0
  34. phoenix/server/api/dataloaders/token_counts.py +138 -0
  35. phoenix/server/api/dataloaders/trace_evaluations.py +37 -0
  36. phoenix/server/api/input_types/SpanSort.py +138 -68
  37. phoenix/server/api/routers/v1/__init__.py +11 -0
  38. phoenix/server/api/routers/v1/evaluations.py +275 -0
  39. phoenix/server/api/routers/v1/spans.py +126 -0
  40. phoenix/server/api/routers/v1/traces.py +82 -0
  41. phoenix/server/api/schema.py +112 -48
  42. phoenix/server/api/types/DocumentEvaluationSummary.py +1 -1
  43. phoenix/server/api/types/Evaluation.py +29 -12
  44. phoenix/server/api/types/EvaluationSummary.py +29 -44
  45. phoenix/server/api/types/MimeType.py +2 -2
  46. phoenix/server/api/types/Model.py +9 -9
  47. phoenix/server/api/types/Project.py +240 -171
  48. phoenix/server/api/types/Span.py +87 -131
  49. phoenix/server/api/types/Trace.py +29 -20
  50. phoenix/server/api/types/pagination.py +151 -10
  51. phoenix/server/app.py +263 -35
  52. phoenix/server/grpc_server.py +93 -0
  53. phoenix/server/main.py +75 -60
  54. phoenix/server/openapi/docs.py +218 -0
  55. phoenix/server/prometheus.py +23 -7
  56. phoenix/server/static/index.js +662 -643
  57. phoenix/server/telemetry.py +68 -0
  58. phoenix/services.py +4 -0
  59. phoenix/session/client.py +34 -30
  60. phoenix/session/data_extractor.py +8 -3
  61. phoenix/session/session.py +176 -155
  62. phoenix/settings.py +13 -0
  63. phoenix/trace/attributes.py +349 -0
  64. phoenix/trace/dsl/README.md +116 -0
  65. phoenix/trace/dsl/filter.py +660 -192
  66. phoenix/trace/dsl/helpers.py +24 -5
  67. phoenix/trace/dsl/query.py +562 -185
  68. phoenix/trace/fixtures.py +69 -7
  69. phoenix/trace/otel.py +44 -200
  70. phoenix/trace/schemas.py +14 -8
  71. phoenix/trace/span_evaluations.py +5 -2
  72. phoenix/utilities/__init__.py +0 -26
  73. phoenix/utilities/span_store.py +0 -23
  74. phoenix/version.py +1 -1
  75. phoenix/core/project.py +0 -773
  76. phoenix/core/traces.py +0 -96
  77. phoenix/datasets/dataset.py +0 -214
  78. phoenix/datasets/fixtures.py +0 -24
  79. phoenix/datasets/schema.py +0 -31
  80. phoenix/experimental/evals/__init__.py +0 -73
  81. phoenix/experimental/evals/evaluators.py +0 -413
  82. phoenix/experimental/evals/functions/__init__.py +0 -4
  83. phoenix/experimental/evals/functions/classify.py +0 -453
  84. phoenix/experimental/evals/functions/executor.py +0 -353
  85. phoenix/experimental/evals/functions/generate.py +0 -138
  86. phoenix/experimental/evals/functions/processing.py +0 -76
  87. phoenix/experimental/evals/models/__init__.py +0 -14
  88. phoenix/experimental/evals/models/anthropic.py +0 -175
  89. phoenix/experimental/evals/models/base.py +0 -170
  90. phoenix/experimental/evals/models/bedrock.py +0 -221
  91. phoenix/experimental/evals/models/litellm.py +0 -134
  92. phoenix/experimental/evals/models/openai.py +0 -453
  93. phoenix/experimental/evals/models/rate_limiters.py +0 -246
  94. phoenix/experimental/evals/models/vertex.py +0 -173
  95. phoenix/experimental/evals/models/vertexai.py +0 -186
  96. phoenix/experimental/evals/retrievals.py +0 -96
  97. phoenix/experimental/evals/templates/__init__.py +0 -50
  98. phoenix/experimental/evals/templates/default_templates.py +0 -472
  99. phoenix/experimental/evals/templates/template.py +0 -195
  100. phoenix/experimental/evals/utils/__init__.py +0 -172
  101. phoenix/experimental/evals/utils/threads.py +0 -27
  102. phoenix/server/api/routers/evaluation_handler.py +0 -110
  103. phoenix/server/api/routers/span_handler.py +0 -70
  104. phoenix/server/api/routers/trace_handler.py +0 -60
  105. phoenix/storage/span_store/__init__.py +0 -23
  106. phoenix/storage/span_store/text_file.py +0 -85
  107. phoenix/trace/dsl/missing.py +0 -60
  108. {arize_phoenix-3.25.0.dist-info → arize_phoenix-4.0.1.dist-info}/WHEEL +0 -0
  109. {arize_phoenix-3.25.0.dist-info → arize_phoenix-4.0.1.dist-info}/licenses/IP_NOTICE +0 -0
  110. {arize_phoenix-3.25.0.dist-info → arize_phoenix-4.0.1.dist-info}/licenses/LICENSE +0 -0
  111. /phoenix/{datasets → db/insertion}/__init__.py +0 -0
  112. /phoenix/{experimental → db/migrations}/__init__.py +0 -0
  113. /phoenix/{storage → server/openapi}/__init__.py +0 -0
phoenix/server/api/context.py
@@ -1,20 +1,57 @@
  from dataclasses import dataclass
+ from datetime import datetime
  from pathlib import Path
- from typing import Optional, Union
+ from typing import AsyncContextManager, Callable, Optional, Union

+ from sqlalchemy.ext.asyncio import AsyncSession
  from starlette.requests import Request
  from starlette.responses import Response
  from starlette.websockets import WebSocket
+ from typing_extensions import TypeAlias

  from phoenix.core.model_schema import Model
- from phoenix.core.traces import Traces
+ from phoenix.server.api.dataloaders import (
+     CacheForDataLoaders,
+     DocumentEvaluationsDataLoader,
+     DocumentEvaluationSummaryDataLoader,
+     DocumentRetrievalMetricsDataLoader,
+     EvaluationSummaryDataLoader,
+     LatencyMsQuantileDataLoader,
+     MinStartOrMaxEndTimeDataLoader,
+     RecordCountDataLoader,
+     SpanDescendantsDataLoader,
+     SpanEvaluationsDataLoader,
+     TokenCountDataLoader,
+     TraceEvaluationsDataLoader,
+ )
+
+
+ @dataclass
+ class DataLoaders:
+     document_evaluation_summaries: DocumentEvaluationSummaryDataLoader
+     document_evaluations: DocumentEvaluationsDataLoader
+     document_retrieval_metrics: DocumentRetrievalMetricsDataLoader
+     evaluation_summaries: EvaluationSummaryDataLoader
+     latency_ms_quantile: LatencyMsQuantileDataLoader
+     min_start_or_max_end_times: MinStartOrMaxEndTimeDataLoader
+     record_counts: RecordCountDataLoader
+     span_descendants: SpanDescendantsDataLoader
+     span_evaluations: SpanEvaluationsDataLoader
+     token_counts: TokenCountDataLoader
+     trace_evaluations: TraceEvaluationsDataLoader
+
+
+ ProjectRowId: TypeAlias = int


  @dataclass
  class Context:
      request: Union[Request, WebSocket]
      response: Optional[Response]
+     db: Callable[[], AsyncContextManager[AsyncSession]]
+     data_loaders: DataLoaders
+     cache_for_dataloaders: Optional[CacheForDataLoaders]
      model: Model
      export_path: Path
      corpus: Optional[Model] = None
-     traces: Optional[Traces] = None
+     streaming_last_updated_at: Callable[[ProjectRowId], Optional[datetime]] = lambda _: None
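Note (not part of the diff): a minimal sketch of how a GraphQL resolver might consume the new Context fields, assuming Strawberry passes the dataclass above as info.context; the resolver name and field are illustrative only.

    # Hypothetical resolver; `info.context` is assumed to be the Context dataclass above.
    async def resolve_span_evaluations(info, span_rowid: int):
        ctx = info.context
        # Ad-hoc SQL goes through the async session factory stored on the context ...
        async with ctx.db() as session:
            ...  # e.g. await session.execute(...)
        # ... while per-field lookups are batched through the dataloaders.
        return await ctx.data_loaders.span_evaluations.load(span_rowid)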
phoenix/server/api/dataloaders/__init__.py
@@ -0,0 +1,97 @@
+ from dataclasses import dataclass, field
+ from functools import singledispatchmethod
+
+ from phoenix.db.insertion.evaluation import (
+     DocumentEvaluationInsertionEvent,
+     SpanEvaluationInsertionEvent,
+     TraceEvaluationInsertionEvent,
+ )
+ from phoenix.db.insertion.span import ClearProjectSpansEvent, SpanInsertionEvent
+
+ from .document_evaluation_summaries import (
+     DocumentEvaluationSummaryCache,
+     DocumentEvaluationSummaryDataLoader,
+ )
+ from .document_evaluations import DocumentEvaluationsDataLoader
+ from .document_retrieval_metrics import DocumentRetrievalMetricsDataLoader
+ from .evaluation_summaries import EvaluationSummaryCache, EvaluationSummaryDataLoader
+ from .latency_ms_quantile import LatencyMsQuantileCache, LatencyMsQuantileDataLoader
+ from .min_start_or_max_end_times import MinStartOrMaxEndTimeCache, MinStartOrMaxEndTimeDataLoader
+ from .record_counts import RecordCountCache, RecordCountDataLoader
+ from .span_descendants import SpanDescendantsDataLoader
+ from .span_evaluations import SpanEvaluationsDataLoader
+ from .token_counts import TokenCountCache, TokenCountDataLoader
+ from .trace_evaluations import TraceEvaluationsDataLoader
+
+ __all__ = [
+     "CacheForDataLoaders",
+     "DocumentEvaluationSummaryDataLoader",
+     "DocumentEvaluationsDataLoader",
+     "DocumentRetrievalMetricsDataLoader",
+     "EvaluationSummaryDataLoader",
+     "LatencyMsQuantileDataLoader",
+     "MinStartOrMaxEndTimeDataLoader",
+     "RecordCountDataLoader",
+     "SpanDescendantsDataLoader",
+     "SpanEvaluationsDataLoader",
+     "TokenCountDataLoader",
+     "TraceEvaluationsDataLoader",
+ ]
+
+
+ @dataclass(frozen=True)
+ class CacheForDataLoaders:
+     document_evaluation_summary: DocumentEvaluationSummaryCache = field(
+         default_factory=DocumentEvaluationSummaryCache,
+     )
+     evaluation_summary: EvaluationSummaryCache = field(
+         default_factory=EvaluationSummaryCache,
+     )
+     latency_ms_quantile: LatencyMsQuantileCache = field(
+         default_factory=LatencyMsQuantileCache,
+     )
+     min_start_or_max_end_time: MinStartOrMaxEndTimeCache = field(
+         default_factory=MinStartOrMaxEndTimeCache,
+     )
+     record_count: RecordCountCache = field(
+         default_factory=RecordCountCache,
+     )
+     token_count: TokenCountCache = field(
+         default_factory=TokenCountCache,
+     )
+
+     def _update_spans(self, project_rowid: int) -> None:
+         self.latency_ms_quantile.invalidate(project_rowid)
+         self.token_count.invalidate(project_rowid)
+         self.record_count.invalidate(project_rowid)
+         self.min_start_or_max_end_time.invalidate(project_rowid)
+
+     def _clear_spans(self, project_rowid: int) -> None:
+         self._update_spans(project_rowid)
+         self.evaluation_summary.invalidate_project(project_rowid)
+         self.document_evaluation_summary.invalidate_project(project_rowid)
+
+     @singledispatchmethod
+     def invalidate(self, event: SpanInsertionEvent) -> None:
+         project_rowid, *_ = event
+         self._update_spans(project_rowid)
+
+     @invalidate.register
+     def _(self, event: ClearProjectSpansEvent) -> None:
+         project_rowid, *_ = event
+         self._clear_spans(project_rowid)
+
+     @invalidate.register
+     def _(self, event: DocumentEvaluationInsertionEvent) -> None:
+         project_rowid, evaluation_name = event
+         self.document_evaluation_summary.invalidate((project_rowid, evaluation_name))
+
+     @invalidate.register
+     def _(self, event: SpanEvaluationInsertionEvent) -> None:
+         project_rowid, evaluation_name = event
+         self.evaluation_summary.invalidate((project_rowid, evaluation_name, "span"))
+
+     @invalidate.register
+     def _(self, event: TraceEvaluationInsertionEvent) -> None:
+         project_rowid, evaluation_name = event
+         self.evaluation_summary.invalidate((project_rowid, evaluation_name, "trace"))
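Note (not part of the diff): a minimal sketch of how the single-dispatch invalidation above is meant to be driven. The event signatures are assumed (the classes unpack like tuples of project rowid and, for evaluations, an evaluation name); the real call sites live in the bulk inserter.

    # Hypothetical usage; event classes come from phoenix.db.insertion.*
    cache = CacheForDataLoaders()
    project_rowid, eval_name = 1, "relevance"  # illustrative values
    # A span insertion dispatches to the base overload, invalidating the
    # per-project span-derived caches (latency, token/record counts, time bounds).
    cache.invalidate(SpanInsertionEvent(project_rowid))  # assumed signature
    # An evaluation insertion dispatches to a registered overload, invalidating
    # only the matching (project, eval_name, kind) summary entries.
    cache.invalidate(SpanEvaluationInsertionEvent(project_rowid, eval_name))  # assumed signature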
phoenix/server/api/dataloaders/cache/__init__.py
@@ -0,0 +1,3 @@
+ from phoenix.server.api.dataloaders.cache.two_tier_cache import TwoTierCache
+
+ __all__ = ("TwoTierCache",)
phoenix/server/api/dataloaders/cache/two_tier_cache.py
@@ -0,0 +1,67 @@
+ """
+ The primary intent of a two-tier system is to make cache invalidation more efficient,
+ because the cache keys are typically tuples such as (project_id, time_interval, ...),
+ but we need to invalidate subsets of keys, e.g. all those associated with a
+ specific project, very frequently (i.e. essentially at each span insertion). In a
+ single-tier system we would need to check all the keys to see if they are in the
+ subset that we want to invalidate.
+ """
+
+ from abc import ABC, abstractmethod
+ from asyncio import Future
+ from typing import Any, Callable, Generic, Optional, Tuple, TypeVar
+
+ from cachetools import Cache
+ from strawberry.dataloader import AbstractCache
+
+ _Key = TypeVar("_Key")
+ _Result = TypeVar("_Result")
+
+ _Section = TypeVar("_Section")
+ _SubKey = TypeVar("_SubKey")
+
+
+ class TwoTierCache(
+     AbstractCache[_Key, _Result],
+     Generic[_Key, _Result, _Section, _SubKey],
+     ABC,
+ ):
+     def __init__(
+         self,
+         main_cache: "Cache[_Section, Cache[_SubKey, Future[_Result]]]",
+         sub_cache_factory: Callable[[], "Cache[_SubKey, Future[_Result]]"],
+         *args: Any,
+         **kwargs: Any,
+     ) -> None:
+         super().__init__(*args, **kwargs)
+         self._cache = main_cache
+         self._sub_cache_factory = sub_cache_factory
+
+     @abstractmethod
+     def _cache_key(self, key: _Key) -> Tuple[_Section, _SubKey]: ...
+
+     def invalidate(self, section: _Section) -> None:
+         if sub_cache := self._cache.get(section):
+             sub_cache.clear()
+
+     def get(self, key: _Key) -> Optional["Future[_Result]"]:
+         section, sub_key = self._cache_key(key)
+         if not (sub_cache := self._cache.get(section)):
+             return None
+         return sub_cache.get(sub_key)
+
+     def set(self, key: _Key, value: "Future[_Result]") -> None:
+         section, sub_key = self._cache_key(key)
+         if (sub_cache := self._cache.get(section)) is None:
+             self._cache[section] = sub_cache = self._sub_cache_factory()
+         sub_cache[sub_key] = value
+
+     def delete(self, key: _Key) -> None:
+         section, sub_key = self._cache_key(key)
+         if sub_cache := self._cache.get(section):
+             del sub_cache[sub_key]
+             if not sub_cache:
+                 del self._cache[section]
+
+     def clear(self) -> None:
+         self._cache.clear()
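Note (not part of the diff): the two-tier idea from the module docstring is easiest to see with a concrete subclass. A minimal sketch, assuming sections are project ids and sub-keys are time-interval strings; the class name and cache sizes are illustrative only.

    from typing import Tuple

    from cachetools import LFUCache, TTLCache

    from phoenix.server.api.dataloaders.cache import TwoTierCache

    class ProjectScopedCache(TwoTierCache[Tuple[int, str], float, int, str]):
        def __init__(self) -> None:
            super().__init__(
                main_cache=TTLCache(maxsize=128, ttl=3600),     # tier 1: one slot per project
                sub_cache_factory=lambda: LFUCache(maxsize=8),  # tier 2: entries within a project
            )

        def _cache_key(self, key: Tuple[int, str]) -> Tuple[int, str]:
            project_id, interval = key
            return project_id, interval  # section = project id, sub-key = interval

    # cache.invalidate(project_id) drops only that project's sub-cache,
    # without scanning every (project_id, interval) key in the cache.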
phoenix/server/api/dataloaders/document_evaluation_summaries.py
@@ -0,0 +1,152 @@
+ from collections import defaultdict
+ from datetime import datetime
+ from typing import (
+     Any,
+     AsyncContextManager,
+     Callable,
+     DefaultDict,
+     List,
+     Optional,
+     Tuple,
+ )
+
+ import numpy as np
+ from aioitertools.itertools import groupby
+ from cachetools import LFUCache, TTLCache
+ from sqlalchemy import Select, select
+ from sqlalchemy.ext.asyncio import AsyncSession
+ from strawberry.dataloader import AbstractCache, DataLoader
+ from typing_extensions import TypeAlias
+
+ from phoenix.db import models
+ from phoenix.db.helpers import SupportedSQLDialect, num_docs_col
+ from phoenix.metrics.retrieval_metrics import RetrievalMetrics
+ from phoenix.server.api.dataloaders.cache import TwoTierCache
+ from phoenix.server.api.input_types.TimeRange import TimeRange
+ from phoenix.server.api.types.DocumentEvaluationSummary import DocumentEvaluationSummary
+ from phoenix.trace.dsl import SpanFilter
+
+ ProjectRowId: TypeAlias = int
+ TimeInterval: TypeAlias = Tuple[Optional[datetime], Optional[datetime]]
+ FilterCondition: TypeAlias = Optional[str]
+ EvalName: TypeAlias = str
+
+ Segment: TypeAlias = Tuple[ProjectRowId, TimeInterval, FilterCondition]
+ Param: TypeAlias = EvalName
+
+ Key: TypeAlias = Tuple[ProjectRowId, Optional[TimeRange], FilterCondition, EvalName]
+ Result: TypeAlias = Optional[DocumentEvaluationSummary]
+ ResultPosition: TypeAlias = int
+ DEFAULT_VALUE: Result = None
+
+
+ def _cache_key_fn(key: Key) -> Tuple[Segment, Param]:
+     project_rowid, time_range, filter_condition, eval_name = key
+     interval = (
+         (time_range.start, time_range.end) if isinstance(time_range, TimeRange) else (None, None)
+     )
+     return (project_rowid, interval, filter_condition), eval_name
+
+
+ _Section: TypeAlias = Tuple[ProjectRowId, EvalName]
+ _SubKey: TypeAlias = Tuple[TimeInterval, FilterCondition]
+
+
+ class DocumentEvaluationSummaryCache(
+     TwoTierCache[Key, Result, _Section, _SubKey],
+ ):
+     def __init__(self) -> None:
+         super().__init__(
+             # TTL=3600 (1-hour) because time intervals are always moving forward, but
+             # interval endpoints are rounded down to the hour by the UI, so anything
+             # older than an hour most likely won't be a cache-hit anyway.
+             main_cache=TTLCache(maxsize=64 * 32, ttl=3600),
+             sub_cache_factory=lambda: LFUCache(maxsize=2 * 2),
+         )
+
+     def invalidate_project(self, project_rowid: ProjectRowId) -> None:
+         for section in self._cache.keys():
+             if section[0] == project_rowid:
+                 del self._cache[section]
+
+     def _cache_key(self, key: Key) -> Tuple[_Section, _SubKey]:
+         (project_rowid, interval, filter_condition), eval_name = _cache_key_fn(key)
+         return (project_rowid, eval_name), (interval, filter_condition)
+
+
+ class DocumentEvaluationSummaryDataLoader(DataLoader[Key, Result]):
+     def __init__(
+         self,
+         db: Callable[[], AsyncContextManager[AsyncSession]],
+         cache_map: Optional[AbstractCache[Key, Result]] = None,
+     ) -> None:
+         super().__init__(
+             load_fn=self._load_fn,
+             cache_key_fn=_cache_key_fn,
+             cache_map=cache_map,
+         )
+         self._db = db
+
+     async def _load_fn(self, keys: List[Key]) -> List[Result]:
+         results: List[Result] = [DEFAULT_VALUE] * len(keys)
+         arguments: DefaultDict[
+             Segment,
+             DefaultDict[Param, List[ResultPosition]],
+         ] = defaultdict(lambda: defaultdict(list))
+         for position, key in enumerate(keys):
+             segment, param = _cache_key_fn(key)
+             arguments[segment][param].append(position)
+         for segment, params in arguments.items():
+             async with self._db() as session:
+                 dialect = SupportedSQLDialect(session.bind.dialect.name)
+                 stmt = _get_stmt(dialect, segment, *params.keys())
+                 data = await session.stream(stmt)
+                 async for eval_name, group in groupby(data, lambda d: d.name):
+                     metrics_collection = []
+                     async for (_, num_docs), subgroup in groupby(
+                         group, lambda g: (g.id, g.num_docs)
+                     ):
+                         scores = [np.nan] * num_docs
+                         for row in subgroup:
+                             scores[row.document_position] = row.score
+                         metrics_collection.append(RetrievalMetrics(scores))
+                     summary = DocumentEvaluationSummary(
+                         evaluation_name=eval_name,
+                         metrics_collection=metrics_collection,
+                     )
+                     for position in params[eval_name]:
+                         results[position] = summary
+         return results
+
+
+ def _get_stmt(
+     dialect: SupportedSQLDialect,
+     segment: Segment,
+     *eval_names: Param,
+ ) -> Select[Any]:
+     project_rowid, (start_time, end_time), filter_condition = segment
+     mda = models.DocumentAnnotation
+     stmt = (
+         select(
+             mda.name,
+             models.Span.id,
+             num_docs_col(dialect),
+             mda.score,
+             mda.document_position,
+         )
+         .join(models.Trace)
+         .where(models.Trace.project_rowid == project_rowid)
+         .join(mda)
+         .where(mda.name.in_(eval_names))
+         .where(mda.annotator_kind == "LLM")
+         .where(mda.score.is_not(None))
+         .order_by(mda.name, models.Span.id)
+     )
+     if start_time:
+         stmt = stmt.where(start_time <= models.Span.start_time)
+     if end_time:
+         stmt = stmt.where(models.Span.start_time < end_time)
+     if filter_condition:
+         span_filter = SpanFilter(condition=filter_condition)
+         stmt = span_filter(stmt)
+     return stmt
phoenix/server/api/dataloaders/document_evaluations.py
@@ -0,0 +1,37 @@
+ from collections import defaultdict
+ from typing import (
+     AsyncContextManager,
+     Callable,
+     DefaultDict,
+     List,
+ )
+
+ from sqlalchemy import select
+ from sqlalchemy.ext.asyncio import AsyncSession
+ from strawberry.dataloader import DataLoader
+ from typing_extensions import TypeAlias
+
+ from phoenix.db import models
+ from phoenix.server.api.types.Evaluation import DocumentEvaluation
+
+ Key: TypeAlias = int
+ Result: TypeAlias = List[DocumentEvaluation]
+
+
+ class DocumentEvaluationsDataLoader(DataLoader[Key, Result]):
+     def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
+         super().__init__(load_fn=self._load_fn)
+         self._db = db
+
+     async def _load_fn(self, keys: List[Key]) -> List[Result]:
+         document_evaluations_by_id: DefaultDict[Key, Result] = defaultdict(list)
+         mda = models.DocumentAnnotation
+         async with self._db() as session:
+             data = await session.stream_scalars(
+                 select(mda).where(mda.span_rowid.in_(keys)).where(mda.annotator_kind == "LLM")
+             )
+             async for document_evaluation in data:
+                 document_evaluations_by_id[document_evaluation.span_rowid].append(
+                     DocumentEvaluation.from_sql_document_annotation(document_evaluation)
+                 )
+         return [document_evaluations_by_id[key] for key in keys]
phoenix/server/api/dataloaders/document_retrieval_metrics.py
@@ -0,0 +1,98 @@
+ from collections import defaultdict
+ from typing import (
+     AsyncContextManager,
+     Callable,
+     DefaultDict,
+     Dict,
+     List,
+     Optional,
+     Set,
+     Tuple,
+ )
+
+ import numpy as np
+ from aioitertools.itertools import groupby
+ from sqlalchemy import select
+ from sqlalchemy.ext.asyncio import AsyncSession
+ from strawberry.dataloader import DataLoader
+ from typing_extensions import TypeAlias
+
+ from phoenix.db import models
+ from phoenix.metrics.retrieval_metrics import RetrievalMetrics
+ from phoenix.server.api.types.DocumentRetrievalMetrics import DocumentRetrievalMetrics
+
+ RowId: TypeAlias = int
+ NumDocs: TypeAlias = int
+ EvalName: TypeAlias = Optional[str]
+
+ Key: TypeAlias = Tuple[RowId, EvalName, NumDocs]
+ Result: TypeAlias = List[DocumentRetrievalMetrics]
+
+
+ class DocumentRetrievalMetricsDataLoader(DataLoader[Key, Result]):
+     def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
+         super().__init__(load_fn=self._load_fn)
+         self._db = db
+
+     async def _load_fn(self, keys: List[Key]) -> List[Result]:
+         mda = models.DocumentAnnotation
+         stmt = (
+             select(
+                 mda.span_rowid,
+                 mda.name,
+                 mda.score,
+                 mda.document_position,
+             )
+             .where(mda.score != None)  # noqa: E711
+             .where(mda.annotator_kind == "LLM")
+             .where(mda.document_position >= 0)
+             .order_by(mda.span_rowid, mda.name)
+         )
+         # Using CTE with VALUES clause is possible in SQLite, but not in
+         # SQLAlchemy v2.0.29, hence the workaround below with over-fetching.
+         # We could use CTE with VALUES for postgresql, but for now we'll keep
+         # it simple and just use one approach for all backends.
+         all_row_ids = {row_id for row_id, _, _ in keys}
+         stmt = stmt.where(mda.span_rowid.in_(all_row_ids))
+         all_eval_names = {eval_name for _, eval_name, _ in keys}
+         if None not in all_eval_names:
+             stmt = stmt.where(mda.name.in_(all_eval_names))
+         max_position = max(num_docs for _, _, num_docs in keys)
+         stmt = stmt.where(mda.document_position < max_position)
+         results: Dict[Key, Result] = {key: [] for key in keys}
+         requested_num_docs: DefaultDict[Tuple[RowId, EvalName], Set[NumDocs]] = defaultdict(set)
+         for row_id, eval_name, num_docs in results.keys():
+             requested_num_docs[(row_id, eval_name)].add(num_docs)
+         async with self._db() as session:
+             data = await session.stream(stmt)
+             async for (span_rowid, name), group in groupby(data, lambda r: (r.span_rowid, r.name)):
+                 # We need to fulfill two types of potential requests: 1. when it
+                 # specifies an evaluation name, and 2. when it doesn't care about
+                 # the evaluation name by specifying None.
+                 max_requested_num_docs = max(
+                     (
+                         num_docs
+                         for eval_name in (name, None)
+                         for num_docs in (requested_num_docs.get((span_rowid, eval_name)) or ())
+                     ),
+                     default=0,
+                 )
+                 if max_requested_num_docs <= 0:
+                     # We have over-fetched. Skip this group.
+                     continue
+                 scores = [np.nan] * max_requested_num_docs
+                 for row in group:
+                     # Length check is necessary due to over-fetching.
+                     if row.document_position < len(scores):
+                         scores[row.document_position] = row.score
+                 for eval_name in (name, None):
+                     for num_docs in requested_num_docs.get((span_rowid, eval_name)) or ():
+                         metrics = RetrievalMetrics(scores[:num_docs])
+                         doc_metrics = DocumentRetrievalMetrics(
+                             evaluation_name=name, metrics=metrics
+                         )
+                         key = (span_rowid, eval_name, num_docs)
+                         results[key].append(doc_metrics)
+         # Make sure to copy the result, so we don't return the same list
+         # object to two different requesters.
+         return [results[key].copy() for key in keys]
phoenix/server/api/dataloaders/evaluation_summaries.py
@@ -0,0 +1,151 @@
+ from collections import defaultdict
+ from datetime import datetime
+ from typing import (
+     Any,
+     AsyncContextManager,
+     Callable,
+     DefaultDict,
+     List,
+     Literal,
+     Optional,
+     Tuple,
+ )
+
+ import pandas as pd
+ from aioitertools.itertools import groupby
+ from cachetools import LFUCache, TTLCache
+ from sqlalchemy import Select, func, or_, select
+ from sqlalchemy.ext.asyncio import AsyncSession
+ from strawberry.dataloader import AbstractCache, DataLoader
+ from typing_extensions import TypeAlias, assert_never
+
+ from phoenix.db import models
+ from phoenix.server.api.dataloaders.cache import TwoTierCache
+ from phoenix.server.api.input_types.TimeRange import TimeRange
+ from phoenix.server.api.types.EvaluationSummary import EvaluationSummary
+ from phoenix.trace.dsl import SpanFilter
+
+ Kind: TypeAlias = Literal["span", "trace"]
+ ProjectRowId: TypeAlias = int
+ TimeInterval: TypeAlias = Tuple[Optional[datetime], Optional[datetime]]
+ FilterCondition: TypeAlias = Optional[str]
+ EvalName: TypeAlias = str
+
+ Segment: TypeAlias = Tuple[Kind, ProjectRowId, TimeInterval, FilterCondition]
+ Param: TypeAlias = EvalName
+
+ Key: TypeAlias = Tuple[Kind, ProjectRowId, Optional[TimeRange], FilterCondition, EvalName]
+ Result: TypeAlias = Optional[EvaluationSummary]
+ ResultPosition: TypeAlias = int
+ DEFAULT_VALUE: Result = None
+
+
+ def _cache_key_fn(key: Key) -> Tuple[Segment, Param]:
+     kind, project_rowid, time_range, filter_condition, eval_name = key
+     interval = (
+         (time_range.start, time_range.end) if isinstance(time_range, TimeRange) else (None, None)
+     )
+     return (kind, project_rowid, interval, filter_condition), eval_name
+
+
+ _Section: TypeAlias = Tuple[ProjectRowId, EvalName, Kind]
+ _SubKey: TypeAlias = Tuple[TimeInterval, FilterCondition]
+
+
+ class EvaluationSummaryCache(
+     TwoTierCache[Key, Result, _Section, _SubKey],
+ ):
+     def __init__(self) -> None:
+         super().__init__(
+             # TTL=3600 (1-hour) because time intervals are always moving forward, but
+             # interval endpoints are rounded down to the hour by the UI, so anything
+             # older than an hour most likely won't be a cache-hit anyway.
+             main_cache=TTLCache(maxsize=64 * 32 * 2, ttl=3600),
+             sub_cache_factory=lambda: LFUCache(maxsize=2 * 2),
+         )
+
+     def invalidate_project(self, project_rowid: ProjectRowId) -> None:
+         for section in self._cache.keys():
+             if section[0] == project_rowid:
+                 del self._cache[section]
+
+     def _cache_key(self, key: Key) -> Tuple[_Section, _SubKey]:
+         (kind, project_rowid, interval, filter_condition), eval_name = _cache_key_fn(key)
+         return (project_rowid, eval_name, kind), (interval, filter_condition)
+
+
+ class EvaluationSummaryDataLoader(DataLoader[Key, Result]):
+     def __init__(
+         self,
+         db: Callable[[], AsyncContextManager[AsyncSession]],
+         cache_map: Optional[AbstractCache[Key, Result]] = None,
+     ) -> None:
+         super().__init__(
+             load_fn=self._load_fn,
+             cache_key_fn=_cache_key_fn,
+             cache_map=cache_map,
+         )
+         self._db = db
+
+     async def _load_fn(self, keys: List[Key]) -> List[Result]:
+         results: List[Result] = [DEFAULT_VALUE] * len(keys)
+         arguments: DefaultDict[
+             Segment,
+             DefaultDict[Param, List[ResultPosition]],
+         ] = defaultdict(lambda: defaultdict(list))
+         for position, key in enumerate(keys):
+             segment, param = _cache_key_fn(key)
+             arguments[segment][param].append(position)
+         for segment, params in arguments.items():
+             stmt = _get_stmt(segment, *params.keys())
+             async with self._db() as session:
+                 data = await session.stream(stmt)
+                 async for eval_name, group in groupby(data, lambda row: row.name):
+                     summary = EvaluationSummary(pd.DataFrame(group))
+                     for position in params[eval_name]:
+                         results[position] = summary
+         return results
+
+
+ def _get_stmt(
+     segment: Segment,
+     *eval_names: Param,
+ ) -> Select[Any]:
+     kind, project_rowid, (start_time, end_time), filter_condition = segment
+     stmt = select()
+     if kind == "span":
+         msa = models.SpanAnnotation
+         name_column, label_column, score_column = msa.name, msa.label, msa.score
+         annotator_kind_column = msa.annotator_kind
+         time_column = models.Span.start_time
+         stmt = stmt.join(models.Span).join_from(models.Span, models.Trace)
+         if filter_condition:
+             sf = SpanFilter(filter_condition)
+             stmt = sf(stmt)
+     elif kind == "trace":
+         mta = models.TraceAnnotation
+         name_column, label_column, score_column = mta.name, mta.label, mta.score
+         annotator_kind_column = mta.annotator_kind
+         time_column = models.Trace.start_time
+         stmt = stmt.join(models.Trace)
+     else:
+         assert_never(kind)
+     stmt = stmt.add_columns(
+         name_column,
+         label_column,
+         func.count().label("record_count"),
+         func.count(label_column).label("label_count"),
+         func.count(score_column).label("score_count"),
+         func.sum(score_column).label("score_sum"),
+     )
+     stmt = stmt.group_by(name_column, label_column)
+     stmt = stmt.order_by(name_column, label_column)
+     stmt = stmt.where(models.Trace.project_rowid == project_rowid)
+     stmt = stmt.where(annotator_kind_column == "LLM")
+     stmt = stmt.where(or_(score_column.is_not(None), label_column.is_not(None)))
+     stmt = stmt.where(name_column.in_(eval_names))
+     if start_time:
+         stmt = stmt.where(start_time <= time_column)
+     if end_time:
+         stmt = stmt.where(time_column < end_time)
+     return stmt
+ return stmt