arize-phoenix 10.14.0__py3-none-any.whl → 11.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (84) hide show
  1. {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/METADATA +3 -2
  2. {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/RECORD +82 -50
  3. phoenix/config.py +5 -2
  4. phoenix/datetime_utils.py +8 -1
  5. phoenix/db/bulk_inserter.py +40 -1
  6. phoenix/db/facilitator.py +263 -4
  7. phoenix/db/insertion/helpers.py +15 -0
  8. phoenix/db/insertion/span.py +3 -1
  9. phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
  10. phoenix/db/models.py +267 -9
  11. phoenix/db/types/model_provider.py +1 -0
  12. phoenix/db/types/token_price_customization.py +29 -0
  13. phoenix/server/api/context.py +38 -4
  14. phoenix/server/api/dataloaders/__init__.py +41 -5
  15. phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
  16. phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
  17. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
  18. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
  19. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
  20. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
  21. phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
  22. phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +58 -0
  23. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
  24. phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
  25. phoenix/server/api/dataloaders/span_cost_summary_by_project.py +140 -0
  26. phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
  27. phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
  28. phoenix/server/api/dataloaders/span_costs.py +35 -0
  29. phoenix/server/api/dataloaders/types.py +29 -0
  30. phoenix/server/api/helpers/playground_clients.py +562 -12
  31. phoenix/server/api/helpers/prompts/conversions/aws.py +83 -0
  32. phoenix/server/api/helpers/prompts/models.py +67 -0
  33. phoenix/server/api/input_types/GenerativeModelInput.py +2 -0
  34. phoenix/server/api/input_types/ProjectSessionSort.py +3 -0
  35. phoenix/server/api/input_types/SpanSort.py +17 -0
  36. phoenix/server/api/mutations/__init__.py +2 -0
  37. phoenix/server/api/mutations/chat_mutations.py +17 -0
  38. phoenix/server/api/mutations/model_mutations.py +208 -0
  39. phoenix/server/api/queries.py +82 -41
  40. phoenix/server/api/routers/v1/traces.py +11 -4
  41. phoenix/server/api/subscriptions.py +36 -2
  42. phoenix/server/api/types/CostBreakdown.py +15 -0
  43. phoenix/server/api/types/Experiment.py +59 -1
  44. phoenix/server/api/types/ExperimentRun.py +58 -4
  45. phoenix/server/api/types/GenerativeModel.py +143 -2
  46. phoenix/server/api/types/GenerativeProvider.py +33 -20
  47. phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
  48. phoenix/server/api/types/ModelInterface.py +11 -0
  49. phoenix/server/api/types/PlaygroundModel.py +10 -0
  50. phoenix/server/api/types/Project.py +42 -0
  51. phoenix/server/api/types/ProjectSession.py +44 -0
  52. phoenix/server/api/types/Span.py +137 -0
  53. phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
  54. phoenix/server/api/types/SpanCostSummary.py +10 -0
  55. phoenix/server/api/types/TokenPrice.py +16 -0
  56. phoenix/server/api/types/TokenUsage.py +3 -3
  57. phoenix/server/api/types/Trace.py +41 -0
  58. phoenix/server/app.py +59 -0
  59. phoenix/server/cost_tracking/cost_details_calculator.py +190 -0
  60. phoenix/server/cost_tracking/cost_model_lookup.py +151 -0
  61. phoenix/server/cost_tracking/helpers.py +68 -0
  62. phoenix/server/cost_tracking/model_cost_manifest.json +59 -329
  63. phoenix/server/cost_tracking/regex_specificity.py +397 -0
  64. phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
  65. phoenix/server/daemons/__init__.py +0 -0
  66. phoenix/server/daemons/generative_model_store.py +51 -0
  67. phoenix/server/daemons/span_cost_calculator.py +103 -0
  68. phoenix/server/dml_event_handler.py +1 -0
  69. phoenix/server/static/.vite/manifest.json +36 -36
  70. phoenix/server/static/assets/components-BnK9kodr.js +5055 -0
  71. phoenix/server/static/assets/{index-qiubV_74.js → index-S3YKLmbo.js} +13 -13
  72. phoenix/server/static/assets/{pages-C4V07ozl.js → pages-BW6PBHZb.js} +809 -417
  73. phoenix/server/static/assets/{vendor-Bfsiga8H.js → vendor-DqQvHbPa.js} +147 -147
  74. phoenix/server/static/assets/{vendor-arizeai-CQOWsrzm.js → vendor-arizeai-CLX44PFA.js} +1 -1
  75. phoenix/server/static/assets/{vendor-codemirror-CrcGVhB2.js → vendor-codemirror-Du3XyJnB.js} +1 -1
  76. phoenix/server/static/assets/{vendor-recharts-Yyg3G-Rq.js → vendor-recharts-B2PJDrnX.js} +25 -25
  77. phoenix/server/static/assets/{vendor-shiki-OPjag7Hm.js → vendor-shiki-CNbrFjf9.js} +1 -1
  78. phoenix/version.py +1 -1
  79. phoenix/server/cost_tracking/cost_lookup.py +0 -255
  80. phoenix/server/static/assets/components-CUUWyAMo.js +0 -4509
  81. {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/WHEEL +0 -0
  82. {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/entry_points.txt +0 -0
  83. {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/licenses/IP_NOTICE +0 -0
  84. {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,397 @@
1
+ """
2
+ Regex specificity scorer based on heuristics intended for tie-breaking.
3
+
4
+ This module provides functionality to score regex patterns based on their specificity.
5
+ More specific patterns (like exact matches with anchors) receive higher scores,
6
+ while more general patterns (like wildcards and quantifiers) receive lower scores.
7
+
8
+ Scoring Weights:
9
+ - Full anchors (^pattern$): +10000 points
10
+ - Partial anchors (^pattern or pattern$): +5000 points
11
+ - Literal characters: +1000 points each
12
+ - Escaped characters (\\. \\+ etc): +950 points each
13
+ - Character classes [abc]: +500 points
14
+ - Shorthand classes (\\d \\w \\s): +400 points
15
+ - Negated classes [^abc]: +300 points
16
+ - Negated shorthand (\\D \\W \\S): +250 points
17
+ - Exact quantifiers {n}: -50 points
18
+ - Range quantifiers {n,m}: -100 points
19
+ - Wildcards (.): -200 points
20
+ - Optional (?): -100 points
21
+ - Multiple (+ *): -150 points
22
+ - Alternation (|): -300 points
23
+
24
+ Examples:
25
+ >>> score("^abc$") # Exact match: 13010 (anchors 10000 + 3 literals + length bonus 10)
26
+ >>> score("abc") # Literal: 3006
27
+ >>> score(".*") # Wildcard: 1 (raw total is negative; result is clamped to the minimum of 1)
28
+ >>> score("[a-z]+") # Class + multiple: 362
29
+ >>> score("\\d{3}") # Shorthand + exact quantifier: 360
30
+ """
31
+
32
+ import re
33
+ from typing import Union
34
+
35
+ from typing_extensions import assert_never
36
+
37
# Specificity weights. Positive weights reward constructs that pin down what a
# pattern matches; negative weights penalize constructs that widen it.
FULL_ANCHOR = 10_000  # ^pattern$ - anchored at both ends
PARTIAL_ANCHOR = 5_000  # ^pattern or pattern$ - anchored at one end
LITERAL = 1_000  # a plain character
ESCAPED = 950  # \. \+ etc - an escaped literal, just below a plain one
CHAR_CLASS = 500  # [abc] [0-9]
SHORTHAND = 400  # \d \w \s
NEGATED_CLASS = 300  # [^abc]
NEGATED_SHORTHAND = 250  # \D \W \S
QUANTIFIER_EXACT = -50  # {n}
QUANTIFIER_RANGE = -100  # {n,m} {n,}
WILDCARD = -200  # .
OPTIONAL = -100  # ?
MULTIPLE = -150  # + *
ALTERNATION = -300  # |

# Character groups consulted when classifying a pattern element
POSITIVE_SHORTHANDS = "dws"  # letters of \d \w \s
NEGATIVE_SHORTHANDS = "DWS"  # letters of \D \W \S
META_CHARS = "()^$"  # metacharacters that contribute 0 to the score
57
+
58
+
59
# A global inline-flag group such as (?i) or (?ms). Every other "(?" construct
# - (?:...), (?P<name>...), lookarounds - is real pattern content and must not
# be skipped when looking for the start anchor.
_INLINE_FLAG_GROUP = re.compile(r"\(\?[aiLmsux]+\)")


def score(regex: Union[str, re.Pattern[str]]) -> int:
    r"""
    Score a regex pattern for specificity.

    Higher scores indicate more specific patterns. The score is the sum of:

    - An anchor bonus: FULL_ANCHOR when the pattern has both a start and an
      end anchor, PARTIAL_ANCHOR when it has exactly one
    - The per-element content score of the pattern with anchors removed
    - A length bonus of 2 points per character of the original pattern

    The result is clamped so that it is never lower than 1.

    Args:
        regex: A regex pattern string or a compiled pattern. Strings are
            validated by compiling them; compiled patterns are scored from
            their ``pattern`` attribute (bytes patterns are decoded as UTF-8).

    Returns:
        An integer score of at least 1. Very general patterns whose raw
        total is zero or negative all collapse to 1.

    Raises:
        ValueError: If a string pattern is not a valid regex.

    Examples:
        >>> score("^abc$")
        13010
        >>> score("abc")
        3006
        >>> score(".*")
        1
        >>> score("")
        1
        >>> score("[a-z]+")
        362
        >>> score(r"\d{3}")
        360
        >>> score(r"abc\$")  # escaped dollar is a literal, not an anchor
        3960

    Note:
        The scoring algorithm is designed for cost tracking scenarios
        where more specific patterns should be prioritized over general ones.
    """
    if isinstance(regex, str):
        pattern = regex
        try:
            re.compile(pattern)  # Validate regex
        except re.error as e:
            raise ValueError(f"Invalid regex pattern: {pattern}") from e
    elif isinstance(regex.pattern, str):
        pattern = regex.pattern
    elif isinstance(regex.pattern, bytes):
        pattern = regex.pattern.decode("utf-8")
    else:
        assert_never(regex.pattern)

    score_value = 0

    # Score anchors - most significant factor
    has_start_anchor = _has_start_anchor(pattern)
    has_end_anchor = _has_end_anchor(pattern)

    if has_start_anchor and has_end_anchor:
        score_value += FULL_ANCHOR
    elif has_start_anchor or has_end_anchor:
        score_value += PARTIAL_ANCHOR

    # Score pattern content
    content = _strip_anchors(pattern)
    score_value += _score_content(content)

    # Length bonus for tie-breaking (longer patterns slightly preferred)
    score_value += len(pattern) * 2

    return max(score_value, 1)


def _skip_inline_flags(pattern: str) -> int:
    """Return the index just past any run of leading global inline-flag groups."""
    pos = 0
    while (match := _INLINE_FLAG_GROUP.match(pattern, pos)) is not None:
        pos = match.end()
    return pos


def _has_end_anchor(pattern: str) -> bool:
    r"""
    Check if pattern ends with an unescaped ``$``.

    A trailing ``$`` preceded by an odd number of backslashes (``abc\$``) is
    an escaped literal dollar sign rather than an anchor.
    """
    if not pattern.endswith("$"):
        return False
    before = pattern[:-1]
    trailing_backslashes = len(before) - len(before.rstrip("\\"))
    return trailing_backslashes % 2 == 0


def _has_start_anchor(pattern: str) -> bool:
    """
    Check if pattern has a start anchor (after all leading inline flags).

    Only global flag groups such as ``(?i)`` are skipped; other groups that
    begin with ``(?`` are treated as pattern content.
    """
    i = _skip_inline_flags(pattern)
    return i < len(pattern) and pattern[i] == "^"


def _strip_anchors(pattern: str) -> str:
    r"""
    Remove leading inline flags and anchors from pattern for content analysis.

    Strips any run of leading global flag groups, then one ``^`` if present,
    then one trailing ``$`` if it is an unescaped end anchor. An escaped
    ``\$`` is kept so that it is scored as an escaped literal.
    """
    i = _skip_inline_flags(pattern)
    # Remove start anchor
    if i < len(pattern) and pattern[i] == "^":
        i += 1
    content = pattern[i:]
    # Remove end anchor
    if _has_end_anchor(content):
        content = content[:-1]
    return content
172
+
173
+
174
def _score_content(content: str) -> int:
    r"""
    Sum the specificity contribution of every element in an anchor-free pattern.

    Walks the content left to right and dispatches on the current character:
    - ``[`` is handed to ``_score_bracket`` (character classes)
    - ``{`` is handed to ``_score_quantifier`` ({n}, {n,m})
    - a backslash followed by another character is scored by ``_score_escape``
    - anything else, including a lone trailing backslash, goes to ``_score_char``

    Args:
        content: Pattern content without anchors

    Returns:
        Cumulative score for all pattern elements
    """
    total = 0
    pos = 0
    length = len(content)

    while pos < length:
        current = content[pos]

        if current == "[":
            # The helper reports where scanning should resume
            delta, pos = _score_bracket(content, pos)
        elif current == "{":
            delta, pos = _score_quantifier(content, pos)
        elif current == "\\" and pos + 1 < length:
            # Consume the backslash together with the escaped character
            delta = _score_escape(content[pos + 1])
            pos += 2
        else:
            delta = _score_char(current)
            pos += 1

        total += delta

    return total
217
+
218
+
219
def _score_escape(char: str) -> int:
    r"""
    Return the weight of an escape sequence, given the character after the backslash.

    Args:
        char: The character following the backslash

    Returns:
        SHORTHAND for \d, \w, \s; NEGATED_SHORTHAND for \D, \W, \S;
        ESCAPED for every other escaped character (\., \+, ...).
    """
    if char in POSITIVE_SHORTHANDS:
        return SHORTHAND
    if char in NEGATIVE_SHORTHANDS:
        return NEGATED_SHORTHAND
    # Anything else is treated as an escaped literal
    return ESCAPED
238
+
239
+
240
def _score_bracket(content: str, start: int) -> tuple[int, int]:
    """
    Score the character class opening at ``start`` and report where to resume.

    Args:
        content: Pattern content
        start: Starting position of the opening bracket

    Returns:
        Tuple of (score, next_position):
        - score: CHAR_CLASS for [abc], NEGATED_CLASS for [^abc]
        - next_position: Position after the closing bracket

        When no closing bracket is found the ``[`` is scored as a single
        LITERAL and next_position is ``start + 1``.
    """
    closing = _find_bracket_end(content, start)
    if closing == -1:
        # Malformed bracket, treat as literal
        return LITERAL, start + 1

    # A class is negated when its first member character is "^"
    if content.startswith("^", start + 1, closing):
        return NEGATED_CLASS, closing + 1
    return CHAR_CLASS, closing + 1
262
+
263
+
264
def _score_quantifier(content: str, start: int) -> tuple[int, int]:
    """
    Score the brace quantifier opening at ``start`` and report where to resume.

    Args:
        content: Pattern content
        start: Starting position of the opening brace

    Returns:
        Tuple of (score, next_position):
        - score: QUANTIFIER_EXACT for {n}, QUANTIFIER_RANGE when the braces
          contain a comma
        - next_position: Position after the closing brace

        When there is no closing brace, or the braces do not hold a valid
        quantifier, the ``{`` is scored as a single LITERAL and
        next_position is ``start + 1``.
    """
    closing = content.find("}", start)
    literal_brace = (LITERAL, start + 1)
    if closing == -1:
        # Malformed quantifier, treat as literal
        return literal_brace

    candidate = content[start : closing + 1]
    if not _is_valid_quantifier(candidate):
        return literal_brace

    if "," in candidate:
        return QUANTIFIER_RANGE, closing + 1
    return QUANTIFIER_EXACT, closing + 1
293
+
294
+
295
def _is_valid_quantifier(quantifier: str) -> bool:
    """
    Check if a brace expression is a quantifier under Python ``re`` syntax.

    Accepted forms are ``{n}``, ``{n,m}``, ``{n,}``, ``{,m}`` and ``{,}``,
    where n and m are runs of ASCII digits and, when both are given, n <= m.
    Python's ``re`` treats an omitted lower bound as zero and an omitted upper
    bound as unbounded, so the forms with a missing bound are valid too.
    ``{}`` is not a quantifier.

    Args:
        quantifier: Candidate string including both braces, e.g. "{3}"

    Returns:
        True if quantifier syntax is valid
    """
    if not (quantifier.startswith("{") and quantifier.endswith("}")):
        return False

    inner = quantifier[1:-1]
    lower, comma, upper = inner.partition(",")

    def _digits(text: str) -> bool:
        # isdigit() alone accepts characters such as superscript digits that
        # the regex engine does not treat as digits and that int() rejects.
        return text.isascii() and text.isdigit()

    if not comma:
        # Exact quantifier: {n}
        return _digits(lower)

    # Range quantifier: each bound is optional but must be numeric when present
    if lower and not _digits(lower):
        return False
    if upper and not _digits(upper):
        # Also rejects a second comma, e.g. {1,2,3}
        return False
    if lower and upper and int(lower) > int(upper):
        return False
    return True
339
+
340
+
341
def _score_char(char: str) -> int:
    """
    Return the weight of one unescaped character outside any class or quantifier.

    Args:
        char: Single character to score

    Returns:
        Score for the character:
        - .: WILDCARD
        - ?: OPTIONAL
        - |: ALTERNATION
        - +, *: MULTIPLE
        - any character in META_CHARS: 0
        - Other: LITERAL
    """
    special = {
        ".": WILDCARD,
        "?": OPTIONAL,
        "|": ALTERNATION,
    }.get(char)
    if special is not None:
        return special
    if char in "+*":
        return MULTIPLE
    # Metacharacters don't affect scoring; everything else is a literal
    return 0 if char in META_CHARS else LITERAL
371
+
372
+
373
def _find_bracket_end(pattern: str, start: int) -> int:
    r"""
    Find the closing bracket of the character class that opens at ``start``.

    Follows Python ``re`` set syntax:
    - a backslash escapes the next character, so ``[a\]b]`` closes at the
      final bracket
    - a ``]`` that is the first member of the set, directly after ``[`` or
      ``[^``, is a literal member, so ``[]abc]`` and ``[^]abc]`` close at
      their final bracket

    The scan moves forward once over the class, skipping escaped characters.

    Args:
        pattern: Pattern string
        start: Position of opening bracket

    Returns:
        Position of closing bracket, or -1 if not found
    """
    size = len(pattern)
    pos = start + 1

    # Optional negation marker
    if pos < size and pattern[pos] == "^":
        pos += 1
    # A leading "]" belongs to the set rather than closing it
    if pos < size and pattern[pos] == "]":
        pos += 1

    while pos < size:
        current = pattern[pos]
        if current == "\\":
            # Skip the escaped character, whatever it is
            pos += 2
        elif current == "]":
            return pos
        else:
            pos += 1
    return -1
@@ -0,0 +1,57 @@
1
+ from dataclasses import dataclass
2
+ from typing import Any, Mapping, Optional
3
+
4
+ from typing_extensions import override
5
+
6
+ from phoenix.db.types.token_price_customization import (
7
+ ThresholdBasedTokenPriceCustomization,
8
+ TokenPriceCustomization,
9
+ )
10
+ from phoenix.trace.attributes import get_attribute_value
11
+
12
+
13
@dataclass(frozen=True)
class TokenCostCalculator:
    """Prices tokens at a single flat per-token rate."""

    # Cost charged for each token
    base_rate: float

    def calculate_cost(self, attributes: Mapping[str, Any], tokens: int) -> float:
        """
        Return the cost of ``tokens`` tokens at the base rate.

        Args:
            attributes: Span attributes. Unused here; part of the signature
                so that subclasses can price based on them.
            tokens: Number of tokens to price.

        Returns:
            ``tokens`` multiplied by ``base_rate``.
        """
        rate = self.base_rate
        return tokens * rate
23
+
24
+
25
@dataclass(frozen=True)
class ThresholdBasedTokenCostCalculator(TokenCostCalculator):
    """Prices tokens at ``new_rate`` once a span attribute exceeds a threshold."""

    # Attribute key looked up with get_attribute_value
    key: str
    # The attribute value must be strictly greater than this to switch rates
    threshold: float
    # Per-token rate applied when the threshold is exceeded
    new_rate: float

    @override
    def calculate_cost(
        self,
        attributes: Mapping[str, Any],
        tokens: float,
    ) -> float:
        """
        Return the cost of ``tokens`` at whichever rate applies.

        ``new_rate`` is used only when the attribute value under ``key`` is
        truthy and strictly greater than ``threshold``. A missing, falsy, or
        not-greater value falls back to ``base_rate``.
        """
        observed = get_attribute_value(attributes, self.key)
        over_threshold = bool(observed) and observed > self.threshold
        rate = self.new_rate if over_threshold else self.base_rate
        return tokens * rate
42
+
43
+
44
def create_token_cost_calculator(
    base_rate: float,
    customization: Optional[TokenPriceCustomization] = None,
) -> TokenCostCalculator:
    """
    Build the calculator matching an optional price customization.

    Args:
        base_rate: Per-token rate used by the returned calculator.
        customization: Optional customization. Only a truthy
            ThresholdBasedTokenPriceCustomization changes the result.

    Returns:
        A ThresholdBasedTokenCostCalculator configured from the customization
        when one of that type is given, otherwise a flat-rate
        TokenCostCalculator.
    """
    is_threshold_based = bool(customization) and isinstance(
        customization, ThresholdBasedTokenPriceCustomization
    )
    if is_threshold_based:
        return ThresholdBasedTokenCostCalculator(
            base_rate=base_rate,
            key=customization.key,
            threshold=customization.threshold,
            new_rate=customization.new_rate,
        )
    # No customization, or a kind this factory does not specialize on
    return TokenCostCalculator(base_rate=base_rate)
File without changes
@@ -0,0 +1,51 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from asyncio import sleep
5
+ from datetime import datetime
6
+ from typing import Any, Mapping, Optional
7
+
8
+ import sqlalchemy as sa
9
+ from sqlalchemy.orm import joinedload
10
+
11
+ from phoenix.db import models
12
+ from phoenix.server.cost_tracking.cost_model_lookup import CostModelLookup
13
+ from phoenix.server.types import DaemonTask, DbSessionFactory
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
class GenerativeModelStore(DaemonTask):
    """
    Daemon that keeps a CostModelLookup refreshed from the database.

    ``find_model`` answers from the most recently built lookup; ``_run``
    rebuilds that lookup from the non-deleted generative models on each pass.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__()
        self._db = db
        # Starts out as a lookup built with no models until the first refresh
        self._lookup = CostModelLookup()

    def find_model(
        self,
        start_time: datetime,
        attributes: Mapping[str, Any],
    ) -> Optional[models.GenerativeModel]:
        """Delegate to the current lookup's ``find_model`` with the same arguments."""
        current_lookup = self._lookup
        return current_lookup.find_model(start_time, attributes)

    async def _run(self) -> None:
        """Refresh the lookup in a loop while the task is running."""
        while self._running:
            try:
                await self._fetch_models()
            except Exception:
                # A failed refresh is logged and the previous lookup stays in place
                logger.exception("Failed to refresh generative models")
            await sleep(5)  # Refresh every 5 seconds

    async def _fetch_models(self) -> None:
        """Load non-deleted models with their token prices and swap in a new lookup."""
        model = models.GenerativeModel
        not_deleted = sa.select(model).where(model.deleted_at.is_(None))
        stmt = not_deleted.options(joinedload(model.token_prices)).order_by(model.name)
        async with self._db() as session:
            rows = await session.scalars(stmt)
            # unique() de-duplicates the rows produced by the joined eager load
            self._lookup = CostModelLookup(rows.unique())
@@ -0,0 +1,103 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from asyncio import sleep
5
+ from datetime import datetime
6
+ from typing import Any, Mapping, NamedTuple, Optional
7
+
8
+ from sqlalchemy import inspect
9
+ from typing_extensions import TypeAlias
10
+
11
+ from phoenix.db import models
12
+ from phoenix.server.cost_tracking.cost_details_calculator import SpanCostDetailsCalculator
13
+ from phoenix.server.daemons.generative_model_store import GenerativeModelStore
14
+ from phoenix.server.types import DaemonTask, DbSessionFactory
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ _GenerativeModelId: TypeAlias = int
19
+
20
+
21
class SpanCostCalculatorQueueItem(NamedTuple):
    """One queued span awaiting cost calculation."""

    # Copied onto the resulting SpanCost as span_rowid
    span_rowid: int
    # Copied onto the resulting SpanCost as trace_rowid
    trace_rowid: int
    # Span attributes used for model lookup and cost detail calculation
    attributes: Mapping[str, Any]
    # Span start time used for model lookup and stored on the SpanCost
    span_start_time: datetime
26
+
27
+
28
class SpanCostCalculator(DaemonTask):
    """
    Daemon that turns queued spans into SpanCost rows.

    Spans are enqueued with ``put_nowait``. On each pass ``_run`` drains the
    queue, calculates a cost for every item that resolves to a generative
    model, and adds the resulting rows to a database session.
    """

    _SLEEP_INTERVAL = 5  # seconds

    def __init__(
        self,
        db: DbSessionFactory,
        model_store: GenerativeModelStore,
    ) -> None:
        super().__init__()
        self._db = db
        self._model_store = model_store
        self._queue: list[SpanCostCalculatorQueueItem] = []

    async def _run(self) -> None:
        """Process the queue in a loop while the task is running."""
        while self._running:
            try:
                await self._insert_costs()
            except Exception as e:
                logger.exception(f"Failed to insert costs: {e}")
            await sleep(self._SLEEP_INTERVAL)

    async def _insert_costs(self) -> None:
        """
        Calculate and insert costs for everything queued so far.

        Items whose cost calculation raises, or yields no cost, are skipped.
        A batch whose insert fails is logged and dropped, not retried.
        """
        if not self._queue:
            return
        # Detach the current batch before any await. put_nowait can run while
        # the database call below is suspended; clearing the shared list
        # afterwards would silently discard those newly queued items.
        batch, self._queue = self._queue, []
        costs: list[models.SpanCost] = []
        for item in batch:
            try:
                cost = self.calculate_cost(item.span_start_time, item.attributes)
            except Exception as e:
                logger.exception(f"Failed to calculate cost for span {item.span_rowid}: {e}")
                continue
            if not cost:
                continue
            cost.span_rowid = item.span_rowid
            cost.trace_rowid = item.trace_rowid
            costs.append(cost)
        if not costs:
            # Nothing to persist, so skip opening a session
            return
        try:
            async with self._db() as session:
                session.add_all(costs)
        except Exception as e:
            logger.exception(f"Failed to insert costs: {e}")

    def put_nowait(self, item: SpanCostCalculatorQueueItem) -> None:
        """Queue a span for cost calculation on the next pass."""
        self._queue.append(item)

    def calculate_cost(
        self,
        start_time: datetime,
        attributes: Mapping[str, Any],
    ) -> Optional[models.SpanCost]:
        """
        Build a SpanCost for a span, or return None when none applies.

        Args:
            start_time: Span start time, used for model lookup and stored on
                the result as ``span_start_time``.
            attributes: Span attributes, used for model lookup and detail
                calculation.

        Returns:
            A SpanCost with ``model_id``, ``span_start_time`` and its details
            set (``span_rowid`` and ``trace_rowid`` are left to the caller),
            or None when the attributes are empty, no model matches, the
            model's token prices are not loaded as a list, or no details are
            produced.
        """
        if not attributes:
            return None
        cost_model = self._model_store.find_model(
            start_time=start_time,
            attributes=attributes,
        )
        if not cost_model:
            return None
        # Inspect the already-loaded relationship value instead of touching
        # cost_model.token_prices directly.
        # NOTE(review): presumably this avoids a lazy load outside a session - confirm
        if not isinstance(inspect(cost_model).attrs.token_prices.loaded_value, list):
            return None

        calculator = SpanCostDetailsCalculator(cost_model.token_prices)
        details = calculator.calculate_details(attributes)
        if not details:
            return None

        cost = models.SpanCost(
            model_id=cost_model.id,
            span_start_time=start_time,
        )
        for detail in details:
            cost.append_detail(detail)
        return cost
@@ -120,6 +120,7 @@ class _SpanDmlEventHandler(_DmlEventHandler[SpanDmlEvent]):
120
120
  def _clear(cache: CacheForDataLoaders, project_id: int) -> None:
121
121
  cache.latency_ms_quantile.invalidate(project_id)
122
122
  cache.token_count.invalidate(project_id)
123
+ cache.token_cost.invalidate(project_id)
123
124
  cache.record_count.invalidate(project_id)
124
125
  cache.min_start_or_max_end_time.invalidate(project_id)
125
126