ursa-ai 0.4.2__py3-none-any.whl → 0.6.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ursa-ai might be problematic; see the package's registry page for more details.

@@ -0,0 +1,319 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ from dataclasses import asdict, dataclass
6
+ from decimal import ROUND_HALF_UP, Decimal, getcontext
7
+ from typing import Any, Dict, Optional, Tuple
8
+
9
# Use 28 significant digits for Decimal money math.
# NOTE(review): this sets the precision of the *current thread's* decimal
# context at import time. 28 is already Python's default, so the line mainly
# matters if other code changed the precision before this module was imported;
# in that case it silently overrides that setting.
getcontext().prec = 28
10
+
11
+
12
+ # ---------- Model pricing schema ----------
13
+
14
+
15
@dataclass
class ModelPricing:
    """Token prices for one model, in USD per 1,000 tokens.

    Fields:
        input_per_1k: price of non-cached input (prompt) tokens.
        output_per_1k: price of output (completion) tokens.
        reasoning_per_1k: price of reasoning tokens; ``None`` charges 0.
        cached_input_multiplier: factor applied to ``input_per_1k`` for
            cached input tokens (``Decimal("1")`` means no discount).
    """

    # Prices are USD per 1,000 tokens
    input_per_1k: Decimal
    output_per_1k: Decimal
    # None --> charge 0 for reasoning tokens
    reasoning_per_1k: Optional[Decimal] = None
    # e.g., 0.25 if your provider discounts cached prompt tokens
    cached_input_multiplier: Decimal = Decimal("1")

    @staticmethod
    def _token_count(value: Any) -> Decimal:
        """Coerce a raw token count to a finite, non-negative Decimal (else 0)."""
        if value is None:
            return Decimal("0")
        try:
            count = Decimal(str(value))
        except (ArithmeticError, ValueError, TypeError):
            # Unparseable text raises decimal.InvalidOperation (an ArithmeticError).
            return Decimal("0")
        # NaN would make the ordering comparisons below raise InvalidOperation,
        # and infinite/negative counts would yield meaningless costs.
        if not count.is_finite() or count < 0:
            return Decimal("0")
        return count

    @staticmethod
    def _first_present(usage: Dict[str, Any], *keys: str) -> Any:
        """Return the first value among *keys* in *usage* that is not None."""
        for key in keys:
            value = usage.get(key)
            if value is not None:
                return value
        return None

    def price_tokens(self, usage: Dict[str, Any]) -> Dict[str, Decimal]:
        """Compute cost components from a token-usage dict.

        Recognized keys: ``input_tokens`` (falling back to ``prompt_tokens``),
        ``output_tokens`` (falling back to ``completion_tokens``),
        ``cached_tokens`` and ``reasoning_tokens``. Missing, ``None``,
        unparseable, non-finite or negative counts are treated as 0.

        Cached tokens are billed at ``input_per_1k * cached_input_multiplier``.
        The remaining input tokens (``input - cached``, never below 0) are
        billed at the full input rate.

        Returns:
            Dict with Decimal values for ``input_cost``,
            ``cached_input_cost``, ``output_cost``, ``reasoning_cost`` and
            ``total_cost`` (the sum of the other four).
        """
        per_k = Decimal(1000)
        in_t = self._token_count(
            self._first_present(usage, "input_tokens", "prompt_tokens")
        )
        out_t = self._token_count(
            self._first_present(usage, "output_tokens", "completion_tokens")
        )
        cached_t = self._token_count(usage.get("cached_tokens"))
        reasoning_t = self._token_count(usage.get("reasoning_tokens"))

        # Non-cached share of the input; cached tokens are billed separately.
        eff_in = max(in_t - cached_t, Decimal("0"))

        input_cost = (eff_in / per_k) * self.input_per_1k
        cached_input_cost = (
            (cached_t / per_k) * self.input_per_1k * self.cached_input_multiplier
        )
        output_cost = (out_t / per_k) * self.output_per_1k
        reasoning_cost = Decimal("0")
        if self.reasoning_per_1k is not None and reasoning_t > 0:
            reasoning_cost = (reasoning_t / per_k) * self.reasoning_per_1k

        total_cost = input_cost + cached_input_cost + output_cost + reasoning_cost
        return {
            "input_cost": input_cost,
            "cached_input_cost": cached_input_cost,
            "output_cost": output_cost,
            "reasoning_cost": reasoning_cost,
            "total_cost": total_cost,
        }
72
+
73
+
74
+ # ---------- Registry & resolution ----------
75
+
76
+
77
+ def _dec(x: str | float | int) -> Decimal:
78
+ try:
79
+ return Decimal(str(x))
80
+ except Exception:
81
+ return Decimal("0")
82
+
83
+
84
# Built-in pricing table, keyed by normalized model name (USD per 1K tokens).
# Every default price is $0.00 so nothing is attributed a cost by accident;
# fill in your organization's rates as needed.
#
# Examples — edit to match your negotiated prices:
#   "openai/gpt-4o": ModelPricing(_dec("5.00"), _dec("15.00")),
#   "openai/o3-mini": ModelPricing(
#       _dec("2.50"), _dec("10.00"), reasoning_per_1k=_dec("5.00")
#   ),
DEFAULT_REGISTRY: Dict[str, ModelPricing] = {
    "openai/o3": ModelPricing(
        input_per_1k=_dec("0.00"),
        output_per_1k=_dec("0.00"),
        reasoning_per_1k=_dec("0.00"),
    ),
    # Wildcard entry: any model name starting with "local" (see find_pricing).
    "local/*": ModelPricing(
        input_per_1k=_dec("0.00"),
        output_per_1k=_dec("0.00"),
    ),
}
95
+
96
+
97
def normalize_model_name(name: str) -> str:
    """Canonicalize a model name by trimming whitespace and lowercasing it.

    Falsy input (``None`` or ``""``) yields ``""``.
    """
    if not name:
        return ""
    return name.strip().lower()
99
+
100
+
101
def resolve_model_name(event: Dict[str, Any]) -> str:
    """Return the normalized model name for an LLM event.

    The first non-empty source wins:
    ``event["metadata"]["model"]``, then ``event["metadata"]["ls_model_name"]``,
    then ``event["name"]`` with every ``"llm:"`` occurrence removed.
    """
    meta = event.get("metadata") or {}
    for key in ("model", "ls_model_name"):
        candidate = meta.get(key)
        if candidate:
            return normalize_model_name(candidate)
    raw_name = event.get("name") or ""
    return normalize_model_name(raw_name.replace("llm:", ""))
108
+
109
+
110
def find_pricing(
    model: str, registry: Dict[str, ModelPricing]
) -> Optional[ModelPricing]:
    """Look up the pricing entry for an (already normalized) model name.

    Resolution order:
      1. exact key match;
      2. wildcard keys ending in ``"/*"`` (e.g. ``"local/*"``). The prefix
         *including the slash* must start *model*, and the longest (most
         specific) matching prefix wins;
      3. a hyphenated ``provider-model`` spelling as ``provider/model``,
         tried as an exact key: first with only the leading hyphen replaced
         (``"openai-gpt-4o"`` -> ``"openai/gpt-4o"``), then with every hyphen
         replaced (the earlier behaviour).

    Returns:
        The matching ModelPricing, or ``None`` if nothing matches.
    """
    if model in registry:
        return registry[model]

    # Keep the trailing "/" so "local/*" matches "local/llama" but not
    # unrelated names that merely start with "local" (e.g. "localai-x").
    best = None
    best_len = -1
    for key, mp in registry.items():
        if not key.endswith("/*"):
            continue
        prefix = key[:-1]
        if model.startswith(prefix) and len(prefix) > best_len:
            best, best_len = mp, len(prefix)
    if best_len >= 0:
        return best

    for candidate in (model.replace("-", "/", 1), model.replace("-", "/")):
        if candidate in registry:
            return registry[candidate]
    return None
124
+
125
+
126
def default_registry_path() -> str:
    """Return the path of the ``pricing.json`` that sits beside this module."""
    module_dir = os.path.dirname(__file__)
    return os.path.join(module_dir, "pricing.json")
129
+
130
+
131
def _pricing_from_mapping(entry: Dict[str, Any]) -> ModelPricing:
    """Build a ModelPricing from a dict of USD-per-1K prices.

    Missing input/output prices default to 0 and a missing cache multiplier
    defaults to 1. ``reasoning_per_1k`` stays ``None`` unless given.
    """
    reasoning = entry.get("reasoning_per_1k")
    return ModelPricing(
        _dec(entry.get("input_per_1k", 0)),
        _dec(entry.get("output_per_1k", 0)),
        _dec(reasoning) if reasoning is not None else None,
        _dec(entry.get("cached_input_multiplier", 1)),
    )


def load_registry(
    path: Optional[str] = None,
    overrides: Optional[Dict[str, Any]] = None,
    use_default_if_missing: bool = True,
) -> Dict[str, ModelPricing]:
    """Build the pricing registry from DEFAULT_REGISTRY, a JSON file and overrides.

    The JSON file is the first of these that applies:
      1) explicit *path* (if provided), else
      2) ``$URSA_PRICING_JSON`` (if set), else
      3) ``pricing.json`` next to this module (if present and
         *use_default_if_missing* is true).
    If the chosen file does not exist, it is skipped without trying the later
    options, and DEFAULT_REGISTRY is used on its own (plus overrides).

    File entries are ignored when their value is not an object, or when it
    defines neither ``input_per_1k`` nor ``output_per_1k`` (e.g. ``"_note"``).
    *overrides* are applied last. Each override value may be a dict in the
    same format or a ready-made ModelPricing. All keys are normalized with
    ``normalize_model_name``.

    Raises:
        ValueError: the file is not valid JSON (``json.JSONDecodeError``), or
            its top-level value is a non-empty non-object.
    """
    reg: Dict[str, ModelPricing] = dict(DEFAULT_REGISTRY)

    # 1) explicit path from caller wins, 2) else env var
    candidate = path or os.environ.get("URSA_PRICING_JSON")

    # 3) else module-local pricing.json
    if not candidate and use_default_if_missing:
        local_path = default_registry_path()
        if os.path.exists(local_path):
            candidate = local_path

    if candidate and os.path.exists(candidate):
        with open(candidate, "r", encoding="utf-8") as f:
            data = json.load(f) or {}
        if not isinstance(data, dict):
            raise ValueError(
                f"pricing file {candidate!r} must contain a JSON object, "
                f"got {type(data).__name__}"
            )
        for k, v in data.items():
            # Ignore non-model notes like "_note"
            if not isinstance(v, dict) or (
                "input_per_1k" not in v and "output_per_1k" not in v
            ):
                continue
            reg[normalize_model_name(k)] = _pricing_from_mapping(v)

    # Apply programmatic overrides last
    for k, v in (overrides or {}).items():
        reg[normalize_model_name(k)] = (
            v if isinstance(v, ModelPricing) else _pricing_from_mapping(v)
        )

    return reg
192
+
193
+
194
+ # ---------- Core pricing application ----------
195
+
196
+
197
+ def _has_provider_cost(roll: Dict[str, Any]) -> bool:
198
+ # Treat nonzero provider totals as authoritative
199
+ try:
200
+ return any([
201
+ float(roll.get("total_cost", 0) or 0) > 0,
202
+ float(roll.get("input_cost", 0) or 0) > 0,
203
+ float(roll.get("output_cost", 0) or 0) > 0,
204
+ ])
205
+ except Exception:
206
+ return False
207
+
208
+
209
+ def _round_money(x: Decimal) -> float:
210
+ return float(x.quantize(Decimal("0.000001"), rounding=ROUND_HALF_UP))
211
+
212
+
213
def _provider_amount(value: Any) -> Decimal:
    """Coerce a provider-reported cost field to a finite Decimal (0 if unusable)."""
    try:
        amount = Decimal(str(value or 0))
    except (ArithmeticError, ValueError, TypeError):
        return Decimal("0")
    return amount if amount.is_finite() else Decimal("0")


def price_event(
    event: Dict[str, Any],
    registry: Dict[str, ModelPricing],
    overwrite: bool = False,
) -> Tuple[Dict[str, Any], Optional[Decimal], str]:
    """Attach costs to one LLM event unless the provider already reported them.

    Returns ``(event, total_cost_decimal_or_None, cost_source)`` with
    ``cost_source`` in {"provider", "computed", "no_usage", "no_pricing"}:

    * ``"no_usage"``: ``event["metrics"]["usage_rollup"]`` is missing/empty.
    * ``"provider"``: the rollup already has a positive cost and *overwrite*
      is false. The total is the rollup's ``total_cost``, or
      ``input_cost + output_cost`` when ``total_cost`` is missing/non-positive.
    * ``"no_pricing"``: no registry entry matches the resolved model name.
    * ``"computed"``: costs were computed from token counts. The rollup's
      ``input_cost``/``output_cost``/``total_cost`` are filled where falsy
      (or always, with *overwrite*). ``cost_details`` and ``cost_source``
      are added to ``event["metrics"]``.

    In the computed case *event* is modified in place: its ``metrics`` dict
    (created if absent) is updated and given a new ``usage_rollup`` dict.
    """
    metrics = event.get("metrics") or {}
    roll = metrics.get("usage_rollup") or {}
    if not roll:
        return (event, None, "no_usage")

    if _has_provider_cost(roll) and not overwrite:
        # Respect provider-reported cost. Some rollups carry only the
        # components, so sum them instead of reporting a zero total.
        total = _provider_amount(roll.get("total_cost"))
        if total <= 0:
            total = _provider_amount(roll.get("input_cost")) + _provider_amount(
                roll.get("output_cost")
            )
        return (event, total, "provider")

    model = resolve_model_name(event)
    mp = find_pricing(model, registry)
    if not mp:
        return (event, None, "no_pricing")

    # Compute costs from tokens
    costs = mp.price_tokens(roll)

    # Build the updated rollup as a new dict; it replaces the old one on the
    # event below, so the event itself is still updated in place.
    roll = dict(roll)
    for key in ("input_cost", "output_cost", "total_cost"):
        if overwrite or not roll.get(key):
            roll[key] = _round_money(costs[key])
    # Granular breakdown so the computation can be inspected later
    metrics["cost_details"] = {
        "source": "computed",
        "model_resolved": model,
        "pricing_used": asdict(mp),
        "components_usd": {k: _round_money(v) for k, v in costs.items()},
    }
    metrics["cost_source"] = "computed"
    metrics["usage_rollup"] = roll
    event["metrics"] = metrics
    return (event, costs["total_cost"], "computed")
255
+
256
+
257
def price_payload(
    payload: Dict[str, Any],
    registry: Optional[Dict[str, ModelPricing]] = None,
    overwrite: bool = False,
) -> Dict[str, Any]:
    """Enrich a metrics payload in place with LLM costs.

    Each entry of ``payload["llm_events"]`` is passed through ``price_event``.
    A ``payload["costs"]`` block then receives the rounded grand total, the
    per-model totals, per-source event counts and a registry note.

    Args:
        payload: Metrics payload; ``llm_events`` may be missing or empty.
        registry: Pricing registry. ``None`` loads the default via
            ``load_registry()``; an explicitly passed empty dict is honored
            (every event then prices as "no_pricing" unless provider-costed).
        overwrite: Forwarded to ``price_event``.

    Returns:
        The same *payload* object.
    """
    # `registry or load_registry()` would also discard an intentionally
    # empty registry, so test for None explicitly.
    reg = registry if registry is not None else load_registry()
    llm_events = payload.get("llm_events") or []
    total = Decimal("0")
    by_model: Dict[str, Decimal] = {}
    sources = {"provider": 0, "computed": 0, "no_usage": 0, "no_pricing": 0}

    for ev in llm_events:
        ev2, cost_dec, src = price_event(ev, reg, overwrite=overwrite)
        sources[src] = sources.get(src, 0) + 1
        if cost_dec is None:
            continue
        model = resolve_model_name(ev2)
        total += cost_dec
        by_model[model] = by_model.get(model, Decimal("0")) + cost_dec

    costs = payload.setdefault("costs", {})
    costs["total_usd"] = _round_money(total)
    costs["by_model_usd"] = {k: _round_money(v) for k, v in by_model.items()}
    costs["event_sources"] = sources
    costs["registry_note"] = (
        "Edit pricing via DEFAULT_REGISTRY, pricing.json, or overrides."
    )

    return payload
291
+
292
+
293
+ # ---------- Convenience file I/O ----------
294
+
295
+
296
def price_file(
    in_path: str,
    out_path: Optional[str] = None,
    registry_path: Optional[str] = None,
    overwrite: bool = False,
) -> str:
    """Price a metrics JSON file (from timing.py) and write the enriched copy.

    Args:
        in_path: Metrics JSON file to read.
        out_path: Destination path; when empty, defaults to the input path
            with its extension replaced by ``.priced.json``.
        registry_path: Optional pricing JSON passed to ``load_registry``.
        overwrite: Forwarded to ``price_payload``.

    Returns:
        The path that was written.
    """
    with open(in_path, "r", encoding="utf-8") as src:
        data = json.load(src)

    priced = price_payload(
        data, registry=load_registry(path=registry_path), overwrite=overwrite
    )

    target = out_path or f"{os.path.splitext(in_path)[0]}.priced.json"
    with open(target, "w", encoding="utf-8") as dst:
        json.dump(priced, dst, ensure_ascii=False, indent=2)
    return target