ursa-ai 0.7.0rc1__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ursa-ai might be problematic. Click here for more details.
- ursa/agents/__init__.py +13 -2
- ursa/agents/acquisition_agents.py +812 -0
- ursa/agents/arxiv_agent.py +1 -1
- ursa/agents/base.py +352 -91
- ursa/agents/chat_agent.py +58 -0
- ursa/agents/execution_agent.py +506 -260
- ursa/agents/lammps_agent.py +81 -31
- ursa/agents/planning_agent.py +27 -2
- ursa/agents/websearch_agent.py +2 -2
- ursa/cli/__init__.py +5 -1
- ursa/cli/hitl.py +46 -34
- ursa/observability/pricing.json +85 -0
- ursa/observability/pricing.py +20 -18
- ursa/util/parse.py +316 -0
- {ursa_ai-0.7.0rc1.dist-info → ursa_ai-0.7.1.dist-info}/METADATA +5 -1
- {ursa_ai-0.7.0rc1.dist-info → ursa_ai-0.7.1.dist-info}/RECORD +20 -17
- {ursa_ai-0.7.0rc1.dist-info → ursa_ai-0.7.1.dist-info}/WHEEL +0 -0
- {ursa_ai-0.7.0rc1.dist-info → ursa_ai-0.7.1.dist-info}/entry_points.txt +0 -0
- {ursa_ai-0.7.0rc1.dist-info → ursa_ai-0.7.1.dist-info}/licenses/LICENSE +0 -0
- {ursa_ai-0.7.0rc1.dist-info → ursa_ai-0.7.1.dist-info}/top_level.txt +0 -0
ursa/agents/lammps_agent.py
CHANGED
|
@@ -21,17 +21,14 @@ except Exception:
|
|
|
21
21
|
class LammpsState(TypedDict, total=False):
|
|
22
22
|
simulation_task: str
|
|
23
23
|
elements: List[str]
|
|
24
|
+
template: Optional[str]
|
|
25
|
+
chosen_potential: Optional[Any]
|
|
24
26
|
|
|
25
27
|
matches: List[Any]
|
|
26
|
-
db_message: str
|
|
27
|
-
|
|
28
28
|
idx: int
|
|
29
29
|
summaries: List[str]
|
|
30
30
|
full_texts: List[str]
|
|
31
|
-
|
|
32
31
|
summaries_combined: str
|
|
33
|
-
choice_json: str
|
|
34
|
-
chosen_index: int
|
|
35
32
|
|
|
36
33
|
input_script: str
|
|
37
34
|
run_returncode: Optional[int]
|
|
@@ -47,6 +44,7 @@ class LammpsAgent(BaseAgent):
|
|
|
47
44
|
llm,
|
|
48
45
|
max_potentials: int = 5,
|
|
49
46
|
max_fix_attempts: int = 10,
|
|
47
|
+
find_potential_only: bool = False,
|
|
50
48
|
mpi_procs: int = 8,
|
|
51
49
|
workspace: str = "./workspace",
|
|
52
50
|
lammps_cmd: str = "lmp_mpi",
|
|
@@ -61,6 +59,7 @@ class LammpsAgent(BaseAgent):
|
|
|
61
59
|
)
|
|
62
60
|
self.max_potentials = max_potentials
|
|
63
61
|
self.max_fix_attempts = max_fix_attempts
|
|
62
|
+
self.find_potential_only = find_potential_only
|
|
64
63
|
self.mpi_procs = mpi_procs
|
|
65
64
|
self.lammps_cmd = lammps_cmd
|
|
66
65
|
self.mpirun_cmd = mpirun_cmd
|
|
@@ -72,13 +71,13 @@ class LammpsAgent(BaseAgent):
|
|
|
72
71
|
"eam/alloy",
|
|
73
72
|
"eam/fs",
|
|
74
73
|
"meam",
|
|
75
|
-
"adp",
|
|
76
|
-
"kim",
|
|
74
|
+
"adp",
|
|
75
|
+
"kim",
|
|
77
76
|
"snap",
|
|
78
77
|
"quip",
|
|
79
78
|
"mlip",
|
|
80
79
|
"pace",
|
|
81
|
-
"nep",
|
|
80
|
+
"nep",
|
|
82
81
|
]
|
|
83
82
|
|
|
84
83
|
self.workspace = workspace
|
|
@@ -116,9 +115,10 @@ class LammpsAgent(BaseAgent):
|
|
|
116
115
|
self.author_chain = (
|
|
117
116
|
ChatPromptTemplate.from_template(
|
|
118
117
|
"Your task is to write a LAMMPS input file for this purpose: {simulation_task}.\n"
|
|
119
|
-
"Here is metadata about the interatomic potential that will be used: {metadata}.\n"
|
|
120
118
|
"Note that all potential files are in the './' directory.\n"
|
|
121
119
|
"Here is some information about the pair_style and pair_coeff that might be useful in writing the input file: {pair_info}.\n"
|
|
120
|
+
"If a template for the input file is provided, you should adapt it appropriately to meet the task requirements.\n"
|
|
121
|
+
"Template provided (if any): {template}\n"
|
|
122
122
|
"Ensure that all output data is written only to the './log.lammps' file. Do not create any other output file.\n"
|
|
123
123
|
"To create the log, use only the 'log ./log.lammps' command. Do not use any other command like 'echo' or 'screen'.\n"
|
|
124
124
|
"Return your answer **only** as valid JSON, with no extra text or formatting.\n"
|
|
@@ -138,9 +138,10 @@ class LammpsAgent(BaseAgent):
|
|
|
138
138
|
"However, when running the simulation, an error was raised.\n"
|
|
139
139
|
"Here is the full stdout message that includes the error message: {err_message}\n"
|
|
140
140
|
"Your task is to write a new input file that resolves the error.\n"
|
|
141
|
-
"Here is metadata about the interatomic potential that will be used: {metadata}.\n"
|
|
142
141
|
"Note that all potential files are in the './' directory.\n"
|
|
143
142
|
"Here is some information about the pair_style and pair_coeff that might be useful in writing the input file: {pair_info}.\n"
|
|
143
|
+
"If a template for the input file is provided, you should adapt it appropriately to meet the task requirements.\n"
|
|
144
|
+
"Template provided (if any): {template}\n"
|
|
144
145
|
"Ensure that all output data is written only to the './log.lammps' file. Do not create any other output file.\n"
|
|
145
146
|
"To create the log, use only the 'log ./log.lammps' command. Do not use any other command like 'echo' or 'screen'.\n"
|
|
146
147
|
"Return your answer **only** as valid JSON, with no extra text or formatting.\n"
|
|
@@ -189,22 +190,28 @@ class LammpsAgent(BaseAgent):
|
|
|
189
190
|
pass
|
|
190
191
|
return text
|
|
191
192
|
|
|
193
|
+
def _entry_router(self, state: LammpsState) -> dict:
|
|
194
|
+
if self.find_potential_only and state.get("chosen_potential"):
|
|
195
|
+
raise Exception(
|
|
196
|
+
"You cannot set find_potential_only=True and also specify your own potential!"
|
|
197
|
+
)
|
|
198
|
+
|
|
199
|
+
if not state.get("chosen_potential"):
|
|
200
|
+
self.potential_summaries_dir = os.path.join(
|
|
201
|
+
self.workspace, "potential_summaries"
|
|
202
|
+
)
|
|
203
|
+
os.makedirs(self.potential_summaries_dir, exist_ok=True)
|
|
204
|
+
return {}
|
|
205
|
+
|
|
192
206
|
def _find_potentials(self, state: LammpsState) -> LammpsState:
|
|
193
207
|
db = am.library.Database(remote=True)
|
|
194
208
|
matches = db.get_lammps_potentials(
|
|
195
209
|
pair_style=self.pair_styles, elements=state["elements"]
|
|
196
210
|
)
|
|
197
|
-
|
|
198
|
-
if not list(matches):
|
|
199
|
-
msg_lines.append("No potentials found for this task in NIST.")
|
|
200
|
-
else:
|
|
201
|
-
msg_lines.append("Found these potentials in NIST:")
|
|
202
|
-
for rec in matches:
|
|
203
|
-
msg_lines.append(f"{rec.id} {rec.pair_style} {rec.symbols}")
|
|
211
|
+
|
|
204
212
|
return {
|
|
205
213
|
**state,
|
|
206
214
|
"matches": list(matches),
|
|
207
|
-
"db_message": "\n".join(msg_lines),
|
|
208
215
|
"idx": 0,
|
|
209
216
|
"summaries": [],
|
|
210
217
|
"full_texts": [],
|
|
@@ -243,6 +250,12 @@ class LammpsAgent(BaseAgent):
|
|
|
243
250
|
"simulation_task": state["simulation_task"],
|
|
244
251
|
})
|
|
245
252
|
|
|
253
|
+
summary_file = os.path.join(
|
|
254
|
+
self.potential_summaries_dir, "potential_" + str(i) + ".txt"
|
|
255
|
+
)
|
|
256
|
+
with open(summary_file, "w") as f:
|
|
257
|
+
f.write(summary)
|
|
258
|
+
|
|
246
259
|
return {
|
|
247
260
|
**state,
|
|
248
261
|
"idx": i + 1,
|
|
@@ -265,21 +278,36 @@ class LammpsAgent(BaseAgent):
|
|
|
265
278
|
})
|
|
266
279
|
choice_dict = self._safe_json_loads(choice)
|
|
267
280
|
chosen_index = int(choice_dict["Chosen index"])
|
|
281
|
+
|
|
268
282
|
print(f"Chosen potential #{chosen_index}")
|
|
269
283
|
print("Rationale for choosing this potential:")
|
|
270
284
|
print(choice_dict["rationale"])
|
|
271
|
-
|
|
285
|
+
|
|
286
|
+
chosen_potential = state["matches"][chosen_index]
|
|
287
|
+
|
|
288
|
+
out_file = os.path.join(self.potential_summaries_dir, "Rationale.txt")
|
|
289
|
+
with open(out_file, "w") as f:
|
|
290
|
+
f.write(f"Chosen potential #{chosen_index}")
|
|
291
|
+
f.write("\n")
|
|
292
|
+
f.write("Rationale for choosing this potential:")
|
|
293
|
+
f.write("\n")
|
|
294
|
+
f.write(choice_dict["rationale"])
|
|
295
|
+
|
|
296
|
+
return {**state, "chosen_potential": chosen_potential}
|
|
297
|
+
|
|
298
|
+
def _route_after_summarization(self, state: LammpsState) -> str:
|
|
299
|
+
if self.find_potential_only:
|
|
300
|
+
return "Exit"
|
|
301
|
+
return "continue_author"
|
|
272
302
|
|
|
273
303
|
def _author(self, state: LammpsState) -> LammpsState:
|
|
274
304
|
print("First attempt at writing LAMMPS input file....")
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
text = state["full_texts"][state["chosen_index"]]
|
|
278
|
-
pair_info = match.pair_info()
|
|
305
|
+
state["chosen_potential"].download_files(self.workspace)
|
|
306
|
+
pair_info = state["chosen_potential"].pair_info()
|
|
279
307
|
authored_json = self.author_chain.invoke({
|
|
280
308
|
"simulation_task": state["simulation_task"],
|
|
281
|
-
"metadata": text,
|
|
282
309
|
"pair_info": pair_info,
|
|
310
|
+
"template": state["template"],
|
|
283
311
|
})
|
|
284
312
|
script_dict = self._safe_json_loads(authored_json)
|
|
285
313
|
input_script = script_dict["input_script"]
|
|
@@ -324,17 +352,15 @@ class LammpsAgent(BaseAgent):
|
|
|
324
352
|
return "done_failed"
|
|
325
353
|
|
|
326
354
|
def _fix(self, state: LammpsState) -> LammpsState:
|
|
327
|
-
|
|
328
|
-
text = state["full_texts"][state["chosen_index"]]
|
|
329
|
-
pair_info = match.pair_info()
|
|
355
|
+
pair_info = state["chosen_potential"].pair_info()
|
|
330
356
|
err_blob = state.get("run_stdout")
|
|
331
357
|
|
|
332
358
|
fixed_json = self.fix_chain.invoke({
|
|
333
359
|
"simulation_task": state["simulation_task"],
|
|
334
360
|
"input_script": state["input_script"],
|
|
335
361
|
"err_message": err_blob,
|
|
336
|
-
"metadata": text,
|
|
337
362
|
"pair_info": pair_info,
|
|
363
|
+
"template": state["template"],
|
|
338
364
|
})
|
|
339
365
|
script_dict = self._safe_json_loads(fixed_json)
|
|
340
366
|
new_input = script_dict["input_script"]
|
|
@@ -349,6 +375,7 @@ class LammpsAgent(BaseAgent):
|
|
|
349
375
|
def _build_graph(self):
|
|
350
376
|
g = StateGraph(LammpsState)
|
|
351
377
|
|
|
378
|
+
self.add_node(g, self._entry_router)
|
|
352
379
|
self.add_node(g, self._find_potentials)
|
|
353
380
|
self.add_node(g, self._summarize_one)
|
|
354
381
|
self.add_node(g, self._build_summaries)
|
|
@@ -357,7 +384,18 @@ class LammpsAgent(BaseAgent):
|
|
|
357
384
|
self.add_node(g, self._run_lammps)
|
|
358
385
|
self.add_node(g, self._fix)
|
|
359
386
|
|
|
360
|
-
g.set_entry_point("
|
|
387
|
+
g.set_entry_point("_entry_router")
|
|
388
|
+
|
|
389
|
+
g.add_conditional_edges(
|
|
390
|
+
"_entry_router",
|
|
391
|
+
lambda state: "user_choice"
|
|
392
|
+
if state.get("chosen_potential")
|
|
393
|
+
else "agent_choice",
|
|
394
|
+
{
|
|
395
|
+
"user_choice": "_author",
|
|
396
|
+
"agent_choice": "_find_potentials",
|
|
397
|
+
},
|
|
398
|
+
)
|
|
361
399
|
|
|
362
400
|
g.add_conditional_edges(
|
|
363
401
|
"_find_potentials",
|
|
@@ -379,7 +417,16 @@ class LammpsAgent(BaseAgent):
|
|
|
379
417
|
)
|
|
380
418
|
|
|
381
419
|
g.add_edge("_build_summaries", "_choose")
|
|
382
|
-
|
|
420
|
+
|
|
421
|
+
g.add_conditional_edges(
|
|
422
|
+
"_choose",
|
|
423
|
+
self._route_after_summarization,
|
|
424
|
+
{
|
|
425
|
+
"continue_author": "_author",
|
|
426
|
+
"Exit": END,
|
|
427
|
+
},
|
|
428
|
+
)
|
|
429
|
+
|
|
383
430
|
g.add_edge("_author", "_run_lammps")
|
|
384
431
|
|
|
385
432
|
g.add_conditional_edges(
|
|
@@ -399,7 +446,7 @@ class LammpsAgent(BaseAgent):
|
|
|
399
446
|
inputs: Mapping[str, Any],
|
|
400
447
|
*,
|
|
401
448
|
summarize: bool | None = None,
|
|
402
|
-
recursion_limit: int =
|
|
449
|
+
recursion_limit: int = 999_999,
|
|
403
450
|
**_,
|
|
404
451
|
) -> str:
|
|
405
452
|
config = self.build_config(
|
|
@@ -411,4 +458,7 @@ class LammpsAgent(BaseAgent):
|
|
|
411
458
|
"'simulation_task' and 'elements' are required arguments"
|
|
412
459
|
)
|
|
413
460
|
|
|
461
|
+
if "template" not in inputs:
|
|
462
|
+
inputs = {**inputs, "template": "No template provided."}
|
|
463
|
+
|
|
414
464
|
return self._action.invoke(inputs, config)
|
ursa/agents/planning_agent.py
CHANGED
|
@@ -163,10 +163,35 @@ config = {"configurable": {"thread_id": "1"}}
|
|
|
163
163
|
|
|
164
164
|
|
|
165
165
|
def should_continue(state: PlanningState):
|
|
166
|
-
|
|
166
|
+
reviewMaxLength = 0 # 0 = no limit, else some character limit like 300
|
|
167
|
+
|
|
168
|
+
# Latest reviewer output (if present)
|
|
169
|
+
last_content = (
|
|
170
|
+
state["messages"][-1].content if state.get("messages") else ""
|
|
171
|
+
)
|
|
172
|
+
|
|
173
|
+
max_reflections = state.get("reflection_steps", 3)
|
|
174
|
+
|
|
175
|
+
# Hit the reflection cap?
|
|
176
|
+
if len(state["messages"]) > (max_reflections + 3):
|
|
177
|
+
print(
|
|
178
|
+
f"PlanningAgent: reached reflection limit ({max_reflections}); formalizing . . ."
|
|
179
|
+
)
|
|
167
180
|
return "formalize"
|
|
168
|
-
|
|
181
|
+
|
|
182
|
+
# Approved?
|
|
183
|
+
if "[APPROVED]" in last_content:
|
|
184
|
+
print("PlanningAgent: [APPROVED] — formalizing . . .")
|
|
169
185
|
return "formalize"
|
|
186
|
+
|
|
187
|
+
# Not approved — print a concise reason before another cycle
|
|
188
|
+
reason = " ".join(last_content.strip().split()) # collapse whitespace
|
|
189
|
+
if reviewMaxLength > 0 and len(reason) > reviewMaxLength:
|
|
190
|
+
reason = reason[:reviewMaxLength] + ". . ."
|
|
191
|
+
print(
|
|
192
|
+
f"PlanningAgent: not approved — iterating again. Reviewer notes: {reason}"
|
|
193
|
+
)
|
|
194
|
+
|
|
170
195
|
return "generate"
|
|
171
196
|
|
|
172
197
|
|
ursa/agents/websearch_agent.py
CHANGED
|
@@ -46,7 +46,7 @@ class WebSearchState(TypedDict):
|
|
|
46
46
|
# all the tokens of all the sources.
|
|
47
47
|
|
|
48
48
|
|
|
49
|
-
class
|
|
49
|
+
class WebSearchAgentLegacy(BaseAgent):
|
|
50
50
|
def __init__(
|
|
51
51
|
self, llm: str | BaseChatModel = "openai/gpt-4o-mini", **kwargs
|
|
52
52
|
):
|
|
@@ -204,7 +204,7 @@ def main():
|
|
|
204
204
|
model = ChatOpenAI(
|
|
205
205
|
model="gpt-4o", max_tokens=10000, timeout=None, max_retries=2
|
|
206
206
|
)
|
|
207
|
-
websearcher =
|
|
207
|
+
websearcher = WebSearchAgentLegacy(llm=model)
|
|
208
208
|
problem_string = "Who are the 2025 Detroit Tigers top 10 prospects and what year were they born?"
|
|
209
209
|
inputs = {
|
|
210
210
|
"messages": [HumanMessage(content=problem_string)],
|
ursa/cli/__init__.py
CHANGED
|
@@ -12,7 +12,7 @@ app = Typer()
|
|
|
12
12
|
def run(
|
|
13
13
|
workspace: Annotated[
|
|
14
14
|
Path, Option(help="Directory to store ursa ouput")
|
|
15
|
-
] = Path("
|
|
15
|
+
] = Path("ursa_workspace"),
|
|
16
16
|
llm_model_name: Annotated[
|
|
17
17
|
str,
|
|
18
18
|
Option(
|
|
@@ -48,6 +48,9 @@ def run(
|
|
|
48
48
|
)
|
|
49
49
|
),
|
|
50
50
|
] = False,
|
|
51
|
+
thread_id: Annotated[
|
|
52
|
+
str, Option(help="Thread ID for persistance", envvar="URSA_THREAD_ID")
|
|
53
|
+
] = "ursa_cli",
|
|
51
54
|
arxiv_summarize: Annotated[
|
|
52
55
|
bool,
|
|
53
56
|
Option(
|
|
@@ -104,6 +107,7 @@ def run(
|
|
|
104
107
|
emb_base_url=emb_base_url,
|
|
105
108
|
emb_api_key=emb_api_key,
|
|
106
109
|
share_key=share_key,
|
|
110
|
+
thread_id=thread_id,
|
|
107
111
|
arxiv_summarize=arxiv_summarize,
|
|
108
112
|
arxiv_process_images=arxiv_process_images,
|
|
109
113
|
arxiv_max_results=arxiv_max_results,
|
ursa/cli/hitl.py
CHANGED
|
@@ -19,6 +19,7 @@ from typer import Typer
|
|
|
19
19
|
|
|
20
20
|
from ursa.agents import (
|
|
21
21
|
ArxivAgent,
|
|
22
|
+
ChatAgent,
|
|
22
23
|
ExecutionAgent,
|
|
23
24
|
PlanningAgent,
|
|
24
25
|
RecallAgent,
|
|
@@ -63,6 +64,7 @@ class HITL:
|
|
|
63
64
|
emb_base_url: str
|
|
64
65
|
emb_api_key: Optional[str]
|
|
65
66
|
share_key: bool
|
|
67
|
+
thread_id: str
|
|
66
68
|
arxiv_summarize: bool
|
|
67
69
|
arxiv_process_images: bool
|
|
68
70
|
arxiv_max_results: int
|
|
@@ -122,9 +124,10 @@ class HITL:
|
|
|
122
124
|
|
|
123
125
|
self.last_agent_result = ""
|
|
124
126
|
self.arxiv_state = []
|
|
127
|
+
self.chatter_state = {"messages": []}
|
|
125
128
|
self.executor_state = {}
|
|
126
129
|
self.planner_state = {}
|
|
127
|
-
self.websearcher_state =
|
|
130
|
+
self.websearcher_state = []
|
|
128
131
|
|
|
129
132
|
def update_last_agent_result(self, result: str):
|
|
130
133
|
self.last_agent_result = result
|
|
@@ -149,9 +152,21 @@ class HITL:
|
|
|
149
152
|
download_papers=self.arxiv_download_papers,
|
|
150
153
|
)
|
|
151
154
|
|
|
155
|
+
@cached_property
|
|
156
|
+
def chatter(self) -> ChatAgent:
|
|
157
|
+
edb_path = self.workspace / "checkpoint.db"
|
|
158
|
+
edb_path.parent.mkdir(parents=True, exist_ok=True)
|
|
159
|
+
econn = sqlite3.connect(str(edb_path), check_same_thread=False)
|
|
160
|
+
self.chatter_checkpointer = SqliteSaver(econn)
|
|
161
|
+
return ChatAgent(
|
|
162
|
+
llm=self.model,
|
|
163
|
+
checkpointer=self.chatter_checkpointer,
|
|
164
|
+
thread_id=self.thread_id + "_chatter",
|
|
165
|
+
)
|
|
166
|
+
|
|
152
167
|
@cached_property
|
|
153
168
|
def executor(self) -> ExecutionAgent:
|
|
154
|
-
edb_path = self.workspace / "
|
|
169
|
+
edb_path = self.workspace / "checkpoint.db"
|
|
155
170
|
edb_path.parent.mkdir(parents=True, exist_ok=True)
|
|
156
171
|
econn = sqlite3.connect(str(edb_path), check_same_thread=False)
|
|
157
172
|
self.executor_checkpointer = SqliteSaver(econn)
|
|
@@ -159,29 +174,35 @@ class HITL:
|
|
|
159
174
|
llm=self.model,
|
|
160
175
|
checkpointer=self.executor_checkpointer,
|
|
161
176
|
agent_memory=self.memory,
|
|
177
|
+
thread_id=self.thread_id + "_executor",
|
|
162
178
|
)
|
|
163
179
|
|
|
164
180
|
@cached_property
|
|
165
181
|
def planner(self) -> PlanningAgent:
|
|
166
|
-
pdb_path = Path(self.workspace) / "
|
|
182
|
+
pdb_path = Path(self.workspace) / "checkpoint.db"
|
|
167
183
|
pdb_path.parent.mkdir(parents=True, exist_ok=True)
|
|
168
184
|
pconn = sqlite3.connect(str(pdb_path), check_same_thread=False)
|
|
169
185
|
self.planner_checkpointer = SqliteSaver(pconn)
|
|
170
186
|
return PlanningAgent(
|
|
171
187
|
llm=self.model,
|
|
172
188
|
checkpointer=self.planner_checkpointer,
|
|
189
|
+
thread_id=self.thread_id + "_planner",
|
|
173
190
|
)
|
|
174
191
|
|
|
175
192
|
@cached_property
|
|
176
193
|
def websearcher(self) -> WebSearchAgent:
|
|
177
|
-
rdb_path = Path(self.workspace) / "
|
|
194
|
+
rdb_path = Path(self.workspace) / "checkpoint.db"
|
|
178
195
|
rdb_path.parent.mkdir(parents=True, exist_ok=True)
|
|
179
196
|
rconn = sqlite3.connect(str(rdb_path), check_same_thread=False)
|
|
180
197
|
self.websearcher_checkpointer = SqliteSaver(rconn)
|
|
181
198
|
|
|
182
199
|
return WebSearchAgent(
|
|
183
200
|
llm=self.model,
|
|
201
|
+
max_results=10,
|
|
202
|
+
database_path="web_db",
|
|
203
|
+
summaries_path="web_summaries",
|
|
184
204
|
checkpointer=self.websearcher_checkpointer,
|
|
205
|
+
thread_id=self.thread_id + "_websearch",
|
|
185
206
|
)
|
|
186
207
|
|
|
187
208
|
@cached_property
|
|
@@ -249,13 +270,19 @@ class HITL:
|
|
|
249
270
|
return f"[Rememberer Output]:\n {memory_output}"
|
|
250
271
|
|
|
251
272
|
def run_chatter(self, prompt: str) -> str:
|
|
252
|
-
|
|
253
|
-
|
|
273
|
+
self.chatter_state["messages"].append(
|
|
274
|
+
HumanMessage(
|
|
275
|
+
content=f"The last agent output was: {self.last_agent_result}\n The user stated: {prompt}"
|
|
276
|
+
)
|
|
277
|
+
)
|
|
278
|
+
self.chatter_state = self.chatter.invoke(
|
|
279
|
+
self.chatter_state,
|
|
254
280
|
)
|
|
281
|
+
chat_output = self.chatter_state["messages"][-1]
|
|
255
282
|
|
|
256
283
|
if not isinstance(chat_output.content, str):
|
|
257
284
|
raise TypeError(
|
|
258
|
-
f"chat_output is not a str! Instead, it is: {chat_output}."
|
|
285
|
+
f"chat_output is not a str! Instead, it is: {type(chat_output.content)}."
|
|
259
286
|
)
|
|
260
287
|
|
|
261
288
|
self.update_last_agent_result(chat_output.content)
|
|
@@ -285,35 +312,20 @@ class HITL:
|
|
|
285
312
|
return f"[Planner Agent Output]:\n {self.last_agent_result}"
|
|
286
313
|
|
|
287
314
|
def run_websearcher(self, prompt: str) -> str:
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
self.websearcher_state,
|
|
297
|
-
)
|
|
298
|
-
self.update_last_agent_result(
|
|
299
|
-
self.websearcher_state["messages"][-1].content
|
|
315
|
+
llm_search_query = self.model.invoke(
|
|
316
|
+
f"The user stated {prompt}. Generate between 1 and 8 words for a search query to address the users need. Return only the words to search."
|
|
317
|
+
).content
|
|
318
|
+
print("Searching Web for ", llm_search_query)
|
|
319
|
+
if isinstance(llm_search_query, str):
|
|
320
|
+
web_result = self.websearcher.invoke(
|
|
321
|
+
query=llm_search_query,
|
|
322
|
+
context=prompt,
|
|
300
323
|
)
|
|
324
|
+
self.websearcher_state.append(web_result)
|
|
325
|
+
self.update_last_agent_result(web_result)
|
|
326
|
+
return f"[WebSearch Agent Output]:\n {self.last_agent_result}"
|
|
301
327
|
else:
|
|
302
|
-
|
|
303
|
-
"messages": [
|
|
304
|
-
HumanMessage(
|
|
305
|
-
f"The last agent output was: {self.last_agent_result}\n"
|
|
306
|
-
f"The user stated: {prompt}"
|
|
307
|
-
)
|
|
308
|
-
]
|
|
309
|
-
}
|
|
310
|
-
self.websearcher_state = self.websearcher.invoke(
|
|
311
|
-
self.websearcher_state,
|
|
312
|
-
)
|
|
313
|
-
self.update_last_agent_result(
|
|
314
|
-
self.websearcher_state["messages"][-1].content
|
|
315
|
-
)
|
|
316
|
-
return f"[Planner Agent Output]:\n {self.last_agent_result}"
|
|
328
|
+
raise RuntimeError("Unexpected error while running WebSearchAgent!")
|
|
317
329
|
|
|
318
330
|
|
|
319
331
|
class UrsaRepl(Cmd):
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
{
|
|
2
|
+
"_note": "Prices are per 1K tokens; derived from OpenAI's table (originally per 1M). Source: https://platform.openai.com/docs/pricing",
|
|
3
|
+
"local/*": { "input_per_1k": 0.0, "output_per_1k": 0.0, "cached_input_multiplier": 1.0 },
|
|
4
|
+
|
|
5
|
+
"gpt-5": { "input_per_1k": 0.00125, "output_per_1k": 0.01, "cached_input_multiplier": 0.1 },
|
|
6
|
+
"openai/gpt-5": { "input_per_1k": 0.00125, "output_per_1k": 0.01, "cached_input_multiplier": 0.1 },
|
|
7
|
+
|
|
8
|
+
"gpt-5-mini": { "input_per_1k": 0.00025, "output_per_1k": 0.002, "cached_input_multiplier": 0.1 },
|
|
9
|
+
"openai/gpt-5-mini": { "input_per_1k": 0.00025, "output_per_1k": 0.002, "cached_input_multiplier": 0.1 },
|
|
10
|
+
|
|
11
|
+
"gpt-5-nano": { "input_per_1k": 0.00005, "output_per_1k": 0.0004, "cached_input_multiplier": 0.1 },
|
|
12
|
+
"openai/gpt-5-nano": { "input_per_1k": 0.00005, "output_per_1k": 0.0004, "cached_input_multiplier": 0.1 },
|
|
13
|
+
|
|
14
|
+
"gpt-5-chat-latest": { "input_per_1k": 0.00125, "output_per_1k": 0.01, "cached_input_multiplier": 0.1 },
|
|
15
|
+
"openai/gpt-5-chat-latest": { "input_per_1k": 0.00125, "output_per_1k": 0.01, "cached_input_multiplier": 0.1 },
|
|
16
|
+
|
|
17
|
+
"gpt-5-codex": { "input_per_1k": 0.00125, "output_per_1k": 0.01, "cached_input_multiplier": 0.1 },
|
|
18
|
+
"openai/gpt-5-codex": { "input_per_1k": 0.00125, "output_per_1k": 0.01, "cached_input_multiplier": 0.1 },
|
|
19
|
+
|
|
20
|
+
"gpt-4.1": { "input_per_1k": 0.002, "output_per_1k": 0.008, "cached_input_multiplier": 0.25 },
|
|
21
|
+
"openai/gpt-4.1": { "input_per_1k": 0.002, "output_per_1k": 0.008, "cached_input_multiplier": 0.25 },
|
|
22
|
+
|
|
23
|
+
"gpt-4.1-mini": { "input_per_1k": 0.0004, "output_per_1k": 0.0016, "cached_input_multiplier": 0.25 },
|
|
24
|
+
"openai/gpt-4.1-mini": { "input_per_1k": 0.0004, "output_per_1k": 0.0016, "cached_input_multiplier": 0.25 },
|
|
25
|
+
|
|
26
|
+
"gpt-4.1-nano": { "input_per_1k": 0.0001, "output_per_1k": 0.0004, "cached_input_multiplier": 0.25 },
|
|
27
|
+
"openai/gpt-4.1-nano": { "input_per_1k": 0.0001, "output_per_1k": 0.0004, "cached_input_multiplier": 0.25 },
|
|
28
|
+
|
|
29
|
+
"gpt-4o": { "input_per_1k": 0.0025, "output_per_1k": 0.01, "cached_input_multiplier": 0.5 },
|
|
30
|
+
"openai/gpt-4o": { "input_per_1k": 0.0025, "output_per_1k": 0.01, "cached_input_multiplier": 0.5 },
|
|
31
|
+
|
|
32
|
+
"gpt-4o-2024-05-13": { "input_per_1k": 0.005, "output_per_1k": 0.015, "cached_input_multiplier": 1.0 },
|
|
33
|
+
"openai/gpt-4o-2024-05-13": { "input_per_1k": 0.005, "output_per_1k": 0.015, "cached_input_multiplier": 1.0 },
|
|
34
|
+
|
|
35
|
+
"gpt-4o-mini": { "input_per_1k": 0.00015, "output_per_1k": 0.0006, "cached_input_multiplier": 0.5 },
|
|
36
|
+
"openai/gpt-4o-mini": { "input_per_1k": 0.00015, "output_per_1k": 0.0006, "cached_input_multiplier": 0.5 },
|
|
37
|
+
|
|
38
|
+
"gpt-realtime": { "input_per_1k": 0.004, "output_per_1k": 0.016, "cached_input_multiplier": 0.1 },
|
|
39
|
+
"openai/gpt-realtime": { "input_per_1k": 0.004, "output_per_1k": 0.016, "cached_input_multiplier": 0.1 },
|
|
40
|
+
|
|
41
|
+
"gpt-4o-realtime-preview": { "input_per_1k": 0.005, "output_per_1k": 0.02, "cached_input_multiplier": 0.5 },
|
|
42
|
+
"openai/gpt-4o-realtime-preview": { "input_per_1k": 0.005, "output_per_1k": 0.02, "cached_input_multiplier": 0.5 },
|
|
43
|
+
|
|
44
|
+
"gpt-4o-mini-realtime-preview":{ "input_per_1k": 0.0006, "output_per_1k": 0.0024, "cached_input_multiplier": 0.5 },
|
|
45
|
+
"openai/gpt-4o-mini-realtime-preview": { "input_per_1k": 0.0006, "output_per_1k": 0.0024, "cached_input_multiplier": 0.5 },
|
|
46
|
+
|
|
47
|
+
"gpt-audio": { "input_per_1k": 0.0025, "output_per_1k": 0.01, "cached_input_multiplier": 1.0 },
|
|
48
|
+
"openai/gpt-audio": { "input_per_1k": 0.0025, "output_per_1k": 0.01, "cached_input_multiplier": 1.0 },
|
|
49
|
+
|
|
50
|
+
"gpt-4o-audio-preview": { "input_per_1k": 0.0025, "output_per_1k": 0.01, "cached_input_multiplier": 1.0 },
|
|
51
|
+
"openai/gpt-4o-audio-preview": { "input_per_1k": 0.0025, "output_per_1k": 0.01, "cached_input_multiplier": 1.0 },
|
|
52
|
+
|
|
53
|
+
"gpt-4o-mini-audio-preview": { "input_per_1k": 0.00015, "output_per_1k": 0.0006, "cached_input_multiplier": 1.0 },
|
|
54
|
+
"openai/gpt-4o-mini-audio-preview": { "input_per_1k": 0.00015, "output_per_1k": 0.0006, "cached_input_multiplier": 1.0 },
|
|
55
|
+
|
|
56
|
+
"o1": { "input_per_1k": 0.015, "output_per_1k": 0.06, "cached_input_multiplier": 0.5 },
|
|
57
|
+
"openai/o1": { "input_per_1k": 0.015, "output_per_1k": 0.06, "cached_input_multiplier": 0.5 },
|
|
58
|
+
|
|
59
|
+
"o1-pro": { "input_per_1k": 0.15, "output_per_1k": 0.6, "cached_input_multiplier": 1.0 },
|
|
60
|
+
"openai/o1-pro": { "input_per_1k": 0.15, "output_per_1k": 0.6, "cached_input_multiplier": 1.0 },
|
|
61
|
+
|
|
62
|
+
"o3-pro": { "input_per_1k": 0.02, "output_per_1k": 0.08, "cached_input_multiplier": 1.0 },
|
|
63
|
+
"openai/o3-pro": { "input_per_1k": 0.02, "output_per_1k": 0.08, "cached_input_multiplier": 1.0 },
|
|
64
|
+
|
|
65
|
+
"o3": { "input_per_1k": 0.002, "output_per_1k": 0.008, "cached_input_multiplier": 0.25 },
|
|
66
|
+
"openai/o3": { "input_per_1k": 0.002, "output_per_1k": 0.008, "cached_input_multiplier": 0.25 },
|
|
67
|
+
|
|
68
|
+
"o3-deep-research": { "input_per_1k": 0.01, "output_per_1k": 0.04, "cached_input_multiplier": 0.25 },
|
|
69
|
+
"openai/o3-deep-research": { "input_per_1k": 0.01, "output_per_1k": 0.04, "cached_input_multiplier": 0.25 },
|
|
70
|
+
|
|
71
|
+
"o4-mini": { "input_per_1k": 0.0011, "output_per_1k": 0.0044, "cached_input_multiplier": 0.25 },
|
|
72
|
+
"openai/o4-mini": { "input_per_1k": 0.0011, "output_per_1k": 0.0044, "cached_input_multiplier": 0.25 },
|
|
73
|
+
|
|
74
|
+
"o4-mini-deep-research": { "input_per_1k": 0.002, "output_per_1k": 0.008, "cached_input_multiplier": 0.25 },
|
|
75
|
+
"openai/o4-mini-deep-research":{ "input_per_1k": 0.002, "output_per_1k": 0.008, "cached_input_multiplier": 0.25 },
|
|
76
|
+
|
|
77
|
+
"o3-mini": { "input_per_1k": 0.0011, "output_per_1k": 0.0044, "cached_input_multiplier": 0.5 },
|
|
78
|
+
"openai/o3-mini": { "input_per_1k": 0.0011, "output_per_1k": 0.0044, "cached_input_multiplier": 0.5 },
|
|
79
|
+
|
|
80
|
+
"o1-mini": { "input_per_1k": 0.0011, "output_per_1k": 0.0044, "cached_input_multiplier": 0.5 },
|
|
81
|
+
"openai/o1-mini": { "input_per_1k": 0.0011, "output_per_1k": 0.0044, "cached_input_multiplier": 0.5 },
|
|
82
|
+
|
|
83
|
+
"codex-mini-latest": { "input_per_1k": 0.0015, "output_per_1k": 0.006, "cached_input_multiplier": 0.25 },
|
|
84
|
+
"openai/codex-mini-latest": { "input_per_1k": 0.0015, "output_per_1k": 0.006, "cached_input_multiplier": 0.25 }
|
|
85
|
+
}
|