ursa-ai 0.4.2__py3-none-any.whl → 0.6.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ursa-ai has been flagged as potentially problematic; see the package's registry page for more details.
- ursa/agents/__init__.py +2 -0
- ursa/agents/arxiv_agent.py +88 -99
- ursa/agents/base.py +369 -2
- ursa/agents/execution_agent.py +92 -48
- ursa/agents/hypothesizer_agent.py +39 -42
- ursa/agents/lammps_agent.py +51 -29
- ursa/agents/mp_agent.py +45 -20
- ursa/agents/optimization_agent.py +403 -0
- ursa/agents/planning_agent.py +63 -28
- ursa/agents/rag_agent.py +303 -0
- ursa/agents/recall_agent.py +35 -5
- ursa/agents/websearch_agent.py +44 -54
- ursa/cli/__init__.py +127 -0
- ursa/cli/hitl.py +426 -0
- ursa/observability/pricing.py +319 -0
- ursa/observability/timing.py +1441 -0
- ursa/prompt_library/execution_prompts.py +7 -0
- ursa/prompt_library/optimization_prompts.py +131 -0
- ursa/tools/feasibility_checker.py +114 -0
- ursa/tools/feasibility_tools.py +1075 -0
- ursa/util/helperFunctions.py +142 -0
- ursa/util/optimization_schema.py +78 -0
- {ursa_ai-0.4.2.dist-info → ursa_ai-0.6.0rc1.dist-info}/METADATA +123 -4
- ursa_ai-0.6.0rc1.dist-info/RECORD +39 -0
- ursa_ai-0.6.0rc1.dist-info/entry_points.txt +2 -0
- ursa_ai-0.4.2.dist-info/RECORD +0 -27
- {ursa_ai-0.4.2.dist-info → ursa_ai-0.6.0rc1.dist-info}/WHEEL +0 -0
- {ursa_ai-0.4.2.dist-info → ursa_ai-0.6.0rc1.dist-info}/licenses/LICENSE +0 -0
- {ursa_ai-0.4.2.dist-info → ursa_ai-0.6.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,319 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
from dataclasses import asdict, dataclass
|
|
6
|
+
from decimal import ROUND_HALF_UP, Decimal, getcontext
|
|
7
|
+
from typing import Any, Dict, Optional, Tuple
|
|
8
|
+
|
|
9
|
+
# NOTE: this sets the precision of the *current thread's* decimal context at
# import time, so it affects every Decimal operation in that thread, not only
# the ones in this module. 28 significant digits is also the stdlib default.
getcontext().prec = 28 # robust money math
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# ---------- Model pricing schema ----------
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class ModelPricing:
    """Token prices for one model, in USD per 1,000 tokens.

    All amounts are Decimals so that cost arithmetic stays exact.
    """

    # Prices are USD per 1,000 tokens
    input_per_1k: Decimal
    output_per_1k: Decimal
    # None --> reasoning tokens are not charged (cost 0)
    reasoning_per_1k: Optional[Decimal] = None
    # Factor applied to input_per_1k for cached prompt tokens,
    # e.g. 0.25 if your provider discounts cached prompt tokens
    cached_input_multiplier: Decimal = Decimal("1")

    def price_tokens(self, usage: Dict[str, Any]) -> Dict[str, Decimal]:
        """Compute cost components from a token-usage dict.

        Recognised keys (missing, None, unparseable, non-finite or negative
        values all count as 0 tokens):

        * ``input_tokens`` (falls back to ``prompt_tokens``)
        * ``output_tokens`` (falls back to ``completion_tokens``)
        * ``cached_tokens`` -- treated as a subset of the input tokens and
          billed at ``input_per_1k * cached_input_multiplier``
        * ``reasoning_tokens`` -- billed only when ``reasoning_per_1k`` is set

        Returns:
            Dict of Decimals with keys ``input_cost``, ``cached_input_cost``,
            ``output_cost``, ``reasoning_cost`` and ``total_cost``.
        """

        def _to_dec(x: Any) -> Decimal:
            # Token counts must be finite and non-negative. NaN/Infinity
            # (json.load yields them from NaN/Infinity literals) would make
            # the Decimal ordering comparisons below raise, and negatives
            # would produce negative cost, so both collapse to 0.
            if x is None:
                return Decimal("0")
            try:
                value = Decimal(str(x))
            except (ArithmeticError, TypeError, ValueError):
                return Decimal("0")
            return value if value.is_finite() and value > 0 else Decimal("0")

        def _first(*keys: str) -> Any:
            # First key whose value is not None, so an explicit
            # ``input_tokens: None`` still falls back to ``prompt_tokens``.
            for key in keys:
                value = usage.get(key)
                if value is not None:
                    return value
            return 0

        per_k = Decimal(1000)
        in_t = _to_dec(_first("input_tokens", "prompt_tokens"))
        out_t = _to_dec(_first("output_tokens", "completion_tokens"))
        reasoning_t = _to_dec(usage.get("reasoning_tokens", 0))
        # Cached tokens are part of the prompt: never bill more cached tokens
        # than there were input tokens; the rest is billed at full price.
        cached_t = min(_to_dec(usage.get("cached_tokens", 0)), in_t)
        eff_in = in_t - cached_t

        input_cost = (eff_in / per_k) * self.input_per_1k
        cached_input_cost = (
            (cached_t / per_k) * self.input_per_1k * self.cached_input_multiplier
        )
        output_cost = (out_t / per_k) * self.output_per_1k
        reasoning_cost = Decimal("0")
        if self.reasoning_per_1k is not None and reasoning_t > 0:
            reasoning_cost = (reasoning_t / per_k) * self.reasoning_per_1k

        total_cost = input_cost + cached_input_cost + output_cost + reasoning_cost
        return {
            "input_cost": input_cost,
            "cached_input_cost": cached_input_cost,
            "output_cost": output_cost,
            "reasoning_cost": reasoning_cost,
            "total_cost": total_cost,
        }
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
# ---------- Registry & resolution ----------
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def _dec(x: str | float | int) -> Decimal:
    """Convert a price or multiplier to ``Decimal``, falling back to 0.

    Best-effort by design. Unparseable values (``None``, ``"abc"``) and
    non-finite ones (``NaN``/``Infinity``, which ``json.load`` produces from
    the corresponding literals) become ``Decimal("0")``. This keeps a bad
    pricing entry from poisoning later sums or breaking quantization.
    """
    try:
        value = Decimal(str(x))
    except (ArithmeticError, TypeError, ValueError):
        # decimal.InvalidOperation (an ArithmeticError) for unparseable text
        return Decimal("0")
    return value if value.is_finite() else Decimal("0")
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
# Built-in fallback prices (USD per 1K tokens). Every entry is deliberately
# $0.00 so that no cost is attributed to a model by accident; fill in the
# values your organisation actually pays.
# Examples, edit to match your negotiated prices:
#   "openai/gpt-4o": ModelPricing(_dec("5.00"), _dec("15.00")),
#   "openai/o3-mini": ModelPricing(_dec("2.50"), _dec("10.00"), reasoning_per_1k=_dec("5.00")),
DEFAULT_REGISTRY: Dict[str, ModelPricing] = {}
DEFAULT_REGISTRY["openai/o3"] = ModelPricing(
    input_per_1k=_dec("0.00"),
    output_per_1k=_dec("0.00"),
    reasoning_per_1k=_dec("0.00"),
)
# "<prefix>/*" keys are wildcard entries covering a whole provider.
DEFAULT_REGISTRY["local/*"] = ModelPricing(
    input_per_1k=_dec("0.00"),
    output_per_1k=_dec("0.00"),
)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def normalize_model_name(name: str) -> str:
    """Canonicalise a model identifier: trim whitespace and lowercase.

    A falsy ``name`` (``None`` or ``""``) yields the empty string.
    """
    if not name:
        return ""
    return name.strip().lower()
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def resolve_model_name(event: Dict[str, Any]) -> str:
    """Return the normalised model name for an LLM event.

    Sources, in order of preference: ``metadata["model"]``, then
    ``metadata["ls_model_name"]``, then the event ``name`` with every
    ``"llm:"`` substring removed. The first truthy one wins.
    """
    meta = event.get("metadata") or {}
    candidate = meta.get("model") or meta.get("ls_model_name")
    if not candidate:
        candidate = (event.get("name") or "").replace("llm:", "")
    return normalize_model_name(candidate)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def find_pricing(
    model: str, registry: Dict[str, ModelPricing]
) -> Optional[ModelPricing]:
    """Look up the pricing entry for a (normalised) model name.

    Candidate names are tried in order: the name itself, then its
    provider-separator variants (``"openai-o3"`` -> ``"openai/o3"``; first the
    first hyphen only, then every hyphen, as before). For each candidate:

    1. an exact registry key wins;
    2. otherwise a wildcard key ``"<prefix>/*"`` matches names that start with
       ``"<prefix>/"``. If several wildcards match, the longest (most
       specific) prefix wins.

    Returns:
        The matching entry, or ``None`` if nothing matches.
    """
    # Keep the trailing "/" of each wildcard prefix so that "local/*" does not
    # also match unrelated names such as "localai/x".
    wildcards = sorted(
        ((key[:-1], mp) for key, mp in registry.items() if key.endswith("/*")),
        key=lambda item: len(item[0]),
        reverse=True,  # most specific prefix first; ties keep registry order
    )

    candidates = [model]
    for alt in (model.replace("-", "/", 1), model.replace("-", "/")):
        if alt not in candidates:
            candidates.append(alt)

    for name in candidates:
        if name in registry:
            return registry[name]
        for prefix, mp in wildcards:
            if name.startswith(prefix):
                return mp
    return None
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def default_registry_path() -> str:
    """Return the path of ``pricing.json`` in the same directory as this module.

    This only builds the path; it does not check that the file exists.
    """
    module_dir, _ = os.path.split(__file__)
    return os.path.join(module_dir, "pricing.json")
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def _pricing_from_dict(spec: Dict[str, Any]) -> ModelPricing:
    """Build a ModelPricing from a pricing.json-style entry (USD per 1K tokens).

    A missing input/output price defaults to 0. A missing or null
    ``reasoning_per_1k`` means reasoning tokens are not charged. A missing
    ``cached_input_multiplier`` defaults to 1 (no discount).
    """
    reasoning = spec.get("reasoning_per_1k")
    return ModelPricing(
        _dec(spec.get("input_per_1k", 0)),
        _dec(spec.get("output_per_1k", 0)),
        _dec(reasoning) if reasoning is not None else None,
        _dec(spec.get("cached_input_multiplier", 1)),
    )


def load_registry(
    path: Optional[str] = None,
    overrides: Optional[Dict[str, Any]] = None,
    use_default_if_missing: bool = True,
) -> Dict[str, ModelPricing]:
    """
    Load pricing registry from:
      1) explicit `path` (if provided), else
      2) $URSA_PRICING_JSON (if set), else
      3) pricing.json next to pricing.py (if present, and use_default_if_missing)
      4) fall back to DEFAULT_REGISTRY

    The result always starts as a copy of DEFAULT_REGISTRY. Entries from the
    selected JSON file are layered on top, then entries from ``overrides``.
    All keys are normalised with ``normalize_model_name``. At most one JSON
    file is read: if the selected path does not exist, no file is loaded and
    the later sources are not tried.

    ``overrides`` values may be pricing dicts (same keys as the JSON file) or
    ready-made ``ModelPricing`` instances.

    Raises:
        ValueError: if the JSON file's top level is not an object. Malformed
            JSON raises ``json.JSONDecodeError``, which is also a ValueError.
    """
    reg: Dict[str, ModelPricing] = dict(DEFAULT_REGISTRY)

    # 1) explicit path from caller wins, 2) else env var
    candidate = path or os.environ.get("URSA_PRICING_JSON")

    # 3) else module-local pricing.json
    if not candidate and use_default_if_missing:
        local_path = default_registry_path()
        if os.path.exists(local_path):
            candidate = local_path

    # Load if we have a candidate
    if candidate and os.path.exists(candidate):
        with open(candidate, "r", encoding="utf-8") as f:
            data = json.load(f) or {}
        if not isinstance(data, dict):
            raise ValueError(
                f"pricing file {candidate!r} must contain a JSON object, "
                f"got {type(data).__name__}"
            )
        for k, v in data.items():
            # Ignore non-model notes like "_note"
            if not isinstance(v, dict) or (
                "input_per_1k" not in v and "output_per_1k" not in v
            ):
                continue
            reg[normalize_model_name(k)] = _pricing_from_dict(v)

    # Apply programmatic overrides last
    for k, v in (overrides or {}).items():
        reg[normalize_model_name(k)] = (
            v if isinstance(v, ModelPricing) else _pricing_from_dict(v)
        )

    return reg
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
# ---------- Core pricing application ----------
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def _has_provider_cost(roll: Dict[str, Any]) -> bool:
    """Return True when the rollup already carries a provider-reported cost.

    Any positive ``total_cost``, ``input_cost`` or ``output_cost`` counts. If
    any of the three cannot be read as a float, or ``roll`` has no ``.get``,
    the rollup is treated as having no provider cost.
    """
    try:
        amounts = [
            float(roll.get(field, 0) or 0)
            for field in ("total_cost", "input_cost", "output_cost")
        ]
    except Exception:
        # Unreadable rollup: let the caller compute costs from tokens instead.
        return False
    return any(amount > 0 for amount in amounts)
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def _round_money(x: Decimal) -> float:
    """Round ``x`` to 6 decimal places (ties away from zero) and return a float.

    Used for the JSON-facing USD amounts.
    """
    six_places = Decimal("0.000001")
    rounded = x.quantize(six_places, rounding=ROUND_HALF_UP)
    return float(rounded)
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def price_event(
    event: Dict[str, Any],
    registry: Dict[str, ModelPricing],
    overwrite: bool = False,
) -> Tuple[Dict[str, Any], Optional[Decimal], str]:
    """
    Price one LLM event from its ``metrics["usage_rollup"]`` block.

    Returns (event, total_cost_decimal_or_None, cost_source)
    cost_source ∈ {"provider", "computed", "no_usage", "no_pricing"}

    * "no_usage": the event has no (or an empty) usage rollup.
    * "provider": the rollup already has a positive cost and ``overwrite`` is
      False. The reported ``total_cost`` is returned. If the provider gave
      only component costs (total missing or 0), their sum is returned.
    * "no_pricing": no registry entry matches the resolved model name.
    * "computed": cost derived from token counts via the registry.

    Only the "computed" path mutates ``event``. Its ``metrics`` dict gains
    ``cost_details`` and ``cost_source``, and ``metrics["usage_rollup"]`` is
    replaced by a copy with ``input_cost``/``output_cost``/``total_cost``
    filled in (only empty fields, unless ``overwrite``).
    """
    metrics = event.get("metrics") or {}
    roll = metrics.get("usage_rollup") or {}
    if not roll:
        return (event, None, "no_usage")

    if _has_provider_cost(roll) and not overwrite:
        # Respect provider-reported cost
        total = Decimal(str(roll.get("total_cost", 0) or 0))
        if total == 0:
            # Components without a total: don't report the event as free.
            total = Decimal(str(roll.get("input_cost", 0) or 0)) + Decimal(
                str(roll.get("output_cost", 0) or 0)
            )
        return (event, total, "provider")

    model = resolve_model_name(event)
    mp = find_pricing(model, registry)
    if mp is None:
        return (event, None, "no_pricing")

    # Compute costs from tokens
    costs = mp.price_tokens(roll)

    # Fill cost fields on a copy of the rollup (only empty ones unless
    # overwrite). The copy is stored back on the event's metrics below, so the
    # event is still enriched in place.
    roll = dict(roll)
    for key in ("input_cost", "output_cost", "total_cost"):
        if overwrite or not roll.get(key):
            roll[key] = _round_money(costs[key])

    # asdict() keeps the prices as Decimal, which json.dump cannot serialise;
    # store them as floats like every other money value written here.
    pricing_used = {
        field: float(value) if isinstance(value, Decimal) else value
        for field, value in asdict(mp).items()
    }
    # Granular breakdown so the computation can be inspected later
    metrics["cost_details"] = {
        "source": "computed",
        "model_resolved": model,
        "pricing_used": pricing_used,
        "components_usd": {k: _round_money(v) for k, v in costs.items()},
    }
    metrics["cost_source"] = "computed"
    metrics["usage_rollup"] = roll
    event["metrics"] = metrics
    return (event, costs["total_cost"], "computed")
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
def price_payload(
    payload: Dict[str, Any],
    registry: Optional[Dict[str, ModelPricing]] = None,
    overwrite: bool = False,
) -> Dict[str, Any]:
    """
    Enriches payload in-place with computed costs where missing.
    Adds a `costs` block with totals and by-model aggregation.

    Args:
        payload: metrics dict; its ``llm_events`` list (if any) is priced
            event by event with ``price_event``.
        registry: pricing registry to use. ``None`` loads one with
            ``load_registry()``. An explicitly passed empty dict is honoured,
            meaning no pricing is available.
        overwrite: forwarded to ``price_event``.

    Returns:
        The same ``payload`` object. ``payload["costs"]`` holds
        ``total_usd``, ``by_model_usd``, ``event_sources`` and
        ``registry_note``. Existing keys in a dict-valued ``costs`` are kept;
        a non-dict value (e.g. ``null`` from JSON) is replaced.
    """
    reg = load_registry() if registry is None else registry
    llm_events = payload.get("llm_events") or []
    total = Decimal("0")
    by_model: Dict[str, Decimal] = {}
    sources = {"provider": 0, "computed": 0, "no_usage": 0, "no_pricing": 0}

    for ev in llm_events:
        ev2, cost_dec, src = price_event(ev, reg, overwrite=overwrite)
        sources[src] = sources.get(src, 0) + 1
        if cost_dec is None:
            continue
        model = resolve_model_name(ev2)
        total += cost_dec
        by_model[model] = by_model.get(model, Decimal("0")) + cost_dec

    costs = payload.get("costs")
    if not isinstance(costs, dict):
        costs = {}
        payload["costs"] = costs
    costs["total_usd"] = _round_money(total)
    costs["by_model_usd"] = {k: _round_money(v) for k, v in by_model.items()}
    costs["event_sources"] = sources
    costs["registry_note"] = (
        "Edit pricing via DEFAULT_REGISTRY, pricing.json, or overrides."
    )

    return payload
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
# ---------- Convenience file I/O ----------
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def _json_default(obj: Any) -> Any:
    """``json.dump`` fallback: write Decimal amounts as floats, reject anything else."""
    if isinstance(obj, Decimal):
        return float(obj)
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def price_file(
    in_path: str,
    out_path: Optional[str] = None,
    registry_path: Optional[str] = None,
    overwrite: bool = False,
) -> str:
    """
    Reads a metrics JSON file (from timing.py), enriches with costs, writes result.
    If out_path is None, writes alongside input as '<name>.priced.json'.
    Returns output path.

    Args:
        in_path: metrics JSON file to read.
        out_path: destination file. Falsy means ``<in_path without ext>.priced.json``.
        registry_path: pricing JSON forwarded to ``load_registry(path=...)``.
        overwrite: forwarded to ``price_payload``.
    """
    with open(in_path, "r", encoding="utf-8") as f:
        payload = json.load(f)

    reg = load_registry(path=registry_path)
    payload = price_payload(payload, registry=reg, overwrite=overwrite)

    if not out_path:
        out_path = f"{os.path.splitext(in_path)[0]}.priced.json"

    with open(out_path, "w", encoding="utf-8") as f:
        # Pricing snapshots (dataclasses.asdict of ModelPricing) may contain
        # Decimal values; without a default hook json.dump would raise midway
        # and leave a truncated output file.
        json.dump(payload, f, ensure_ascii=False, indent=2, default=_json_default)
    return out_path
|