ursa-ai 0.4.2__py3-none-any.whl → 0.6.0rc1__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release: this version of ursa-ai might be problematic.
- ursa/agents/__init__.py +2 -0
- ursa/agents/arxiv_agent.py +88 -99
- ursa/agents/base.py +369 -2
- ursa/agents/execution_agent.py +92 -48
- ursa/agents/hypothesizer_agent.py +39 -42
- ursa/agents/lammps_agent.py +51 -29
- ursa/agents/mp_agent.py +45 -20
- ursa/agents/optimization_agent.py +403 -0
- ursa/agents/planning_agent.py +63 -28
- ursa/agents/rag_agent.py +303 -0
- ursa/agents/recall_agent.py +35 -5
- ursa/agents/websearch_agent.py +44 -54
- ursa/cli/__init__.py +127 -0
- ursa/cli/hitl.py +426 -0
- ursa/observability/pricing.py +319 -0
- ursa/observability/timing.py +1441 -0
- ursa/prompt_library/execution_prompts.py +7 -0
- ursa/prompt_library/optimization_prompts.py +131 -0
- ursa/tools/feasibility_checker.py +114 -0
- ursa/tools/feasibility_tools.py +1075 -0
- ursa/util/helperFunctions.py +142 -0
- ursa/util/optimization_schema.py +78 -0
- {ursa_ai-0.4.2.dist-info → ursa_ai-0.6.0rc1.dist-info}/METADATA +123 -4
- ursa_ai-0.6.0rc1.dist-info/RECORD +39 -0
- ursa_ai-0.6.0rc1.dist-info/entry_points.txt +2 -0
- ursa_ai-0.4.2.dist-info/RECORD +0 -27
- {ursa_ai-0.4.2.dist-info → ursa_ai-0.6.0rc1.dist-info}/WHEEL +0 -0
- {ursa_ai-0.4.2.dist-info → ursa_ai-0.6.0rc1.dist-info}/licenses/LICENSE +0 -0
- {ursa_ai-0.4.2.dist-info → ursa_ai-0.6.0rc1.dist-info}/top_level.txt +0 -0
ursa/agents/execution_agent.py
CHANGED
@@ -3,7 +3,7 @@ import os
 # from langchain_core.runnables.graph import MermaidDrawMethod
 import subprocess
 from pathlib import Path
-from typing import Annotated, Any, Literal, Optional
+from typing import Annotated, Any, Literal, Mapping, Optional

 import randomname
 from langchain_community.tools import (
@@ -17,11 +17,11 @@ from langchain_core.messages import (
     ToolMessage,
 )
 from langchain_core.tools import InjectedToolCallId, tool
-from langgraph.graph import
+from langgraph.graph import StateGraph
 from langgraph.graph.message import add_messages
 from langgraph.prebuilt import InjectedState, ToolNode
 from langgraph.types import Command
-from litellm import ContentPolicyViolationError
+from litellm.exceptions import ContentPolicyViolationError

 # Rich
 from rich import get_console
@@ -29,7 +29,11 @@ from rich.panel import Panel
 from rich.syntax import Syntax
 from typing_extensions import TypedDict

-from ..prompt_library.execution_prompts import
+from ..prompt_library.execution_prompts import (
+    executor_prompt,
+    safety_prompt,
+    summarize_prompt,
+)
 from ..util.diff_renderer import DiffRenderer
 from ..util.memory_logger import AgentMemory
 from .base import BaseAgent
@@ -62,6 +66,7 @@ class ExecutionAgent(BaseAgent):
     ):
         super().__init__(llm, **kwargs)
         self.agent_memory = agent_memory
+        self.safety_prompt = safety_prompt
         self.executor_prompt = executor_prompt
         self.summarize_prompt = summarize_prompt
         self.tools = [run_cmd, write_code, edit_code, search_tool]
@@ -69,7 +74,7 @@
         self.llm = self.llm.bind_tools(self.tools)
         self.log_state = log_state

-        self.
+        self._action = self._build_graph()

     # Define the function that calls the model
     def query_executor(self, state: ExecutionState) -> ExecutionState:
@@ -115,8 +120,7 @@
         ] + state["messages"]
         try:
             response = self.llm.invoke(
-                new_state["messages"],
-                {"configurable": {"thread_id": self.thread_id}},
+                new_state["messages"], self.build_config(tags=["agent"])
             )
         except ContentPolicyViolationError as e:
             print("Error: ", e, " ", new_state["messages"][-1].content)
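The new `self.build_config(...)` calls replace hand-rolled `{"configurable": {"thread_id": ...}}` dicts; the helper itself lives in the reworked `ursa/agents/base.py` (+369 lines), which this diff does not show. A minimal sketch of what a helper with this call signature plausibly returns; every name here is inferred from the call sites, not from the actual base.py:

    from typing import Any, Optional

    class BaseAgentSketch:
        """Illustrative stand-in for the real BaseAgent in ursa/agents/base.py."""

        def __init__(self, thread_id: str = "demo-thread"):
            self.thread_id = thread_id

        def build_config(
            self,
            recursion_limit: Optional[int] = None,
            tags: Optional[list[str]] = None,
        ) -> dict[str, Any]:
            # A LangChain RunnableConfig is a plain dict: thread_id goes under
            # "configurable"; tags ride along for tracing and observability.
            config: dict[str, Any] = {
                "configurable": {"thread_id": self.thread_id},
                "tags": list(tags or []),
            }
            if recursion_limit is not None:
                config["recursion_limit"] = recursion_limit
            return config

Whatever the real implementation does, the visible effect in this diff is that every LLM and graph call is now tagged ("agent", "summarize", "safety_check", "graph"), which lines up with the new observability modules (pricing.py, timing.py).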
@@ -129,7 +133,7 @@
         messages = [SystemMessage(content=summarize_prompt)] + state["messages"]
         try:
             response = self.llm.invoke(
-                messages,
+                messages, self.build_config(tags=["summarize"])
             )
         except ContentPolicyViolationError as e:
             print("Error: ", e, " ", messages[-1].content)
@@ -181,11 +185,8 @@
         if call_name == "run_cmd":
             query = tool_call["args"]["query"]
             safety_check = self.llm.invoke(
-
-
-                "the files are from a trusted source. "
-                f"Explain why, followed by an answer [YES] or [NO]. Is this command safe to run: {query}"
-            )
+                self.safety_prompt + query,
+                self.build_config(tags=["safety_check"]),
             )

         if "[NO]" in safety_check.content:
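The old call assembled the safety prompt inline (its surviving fragments are the removed lines above); it now ships as `safety_prompt` in `ursa/prompt_library/execution_prompts.py` (+7 lines), which this diff does not show. A hypothetical reconstruction, consistent with the surviving fragments and with the `[NO]` check that follows:

    # Hypothetical: the real constant lives in
    # ursa/prompt_library/execution_prompts.py and is not shown in this diff.
    safety_prompt = (
        "Assume the environment is trusted and "
        "the files are from a trusted source. "
        "Explain why, followed by an answer [YES] or [NO]. "
        "Is this command safe to run: "
    )

Since the new code computes `self.safety_prompt + query`, the constant presumably ends mid-sentence so the command can be appended directly.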
@@ -222,53 +223,90 @@

         return new_state

-    def
-
+    def _build_graph(self):
+        graph = StateGraph(ExecutionState)

-        self.
-        self.
-        self.
-        self.
+        self.add_node(graph, self.query_executor, "agent")
+        self.add_node(graph, self.tool_node, "action")
+        self.add_node(graph, self.summarize, "summarize")
+        self.add_node(graph, self.safety_check, "safety_check")

         # Set the entrypoint as `agent`
         # This means that this node is the first one called
-
+        graph.set_entry_point("agent")

-
+        graph.add_conditional_edges(
             "agent",
-            should_continue,
-            {
-                "continue": "safety_check",
-                "summarize": "summarize",
-            },
+            self._wrap_cond(should_continue, "should_continue", "execution"),
+            {"continue": "safety_check", "summarize": "summarize"},
         )

-
+        graph.add_conditional_edges(
             "safety_check",
-            command_safe,
-            {
-                "safe": "action",
-                "unsafe": "agent",
-            },
+            self._wrap_cond(command_safe, "command_safe", "execution"),
+            {"safe": "action", "unsafe": "agent"},
         )

-
-
+        graph.add_edge("action", "agent")
+        graph.set_finish_point("summarize")

-
+        return graph.compile(checkpointer=self.checkpointer)
         # self.action.get_graph().draw_mermaid_png(output_file_path="execution_agent_graph.png", draw_method=MermaidDrawMethod.PYPPETEER)

-    def
-        inputs
-
-
-
-
-
-
+    def _invoke(
+        self, inputs: Mapping[str, Any], recursion_limit: int = 999_999, **_
+    ):
+        config = self.build_config(
+            recursion_limit=recursion_limit, tags=["graph"]
+        )
+        return self._action.invoke(inputs, config)
+
+    # this is trying to stop people bypassing invoke
+    @property
+    def action(self):
+        raise AttributeError(
+            "Use .stream(...) or .invoke(...); direct .action access is unsupported."
         )


+def _snip_text(text: str, max_chars: int) -> tuple[str, bool]:
+    if text is None:
+        return "", False
+    if max_chars <= 0:
+        return "", len(text) > 0
+    if len(text) <= max_chars:
+        return text, False
+    head = max_chars // 2
+    tail = max_chars - head
+    return (
+        text[:head]
+        + f"\n... [snipped {len(text) - max_chars} chars] ...\n"
+        + text[-tail:],
+        True,
+    )
+
+
+def _fit_streams_to_budget(stdout: str, stderr: str, total_budget: int):
+    label_overhead = len("STDOUT:\n") + len("\nSTDERR:\n")
+    budget = max(0, total_budget - label_overhead)
+
+    if len(stdout) + len(stderr) <= budget:
+        return stdout, stderr
+
+    total_len = max(1, len(stdout) + len(stderr))
+    stdout_budget = int(budget * (len(stdout) / total_len))
+    stderr_budget = budget - stdout_budget
+
+    stdout_snip, _ = _snip_text(stdout, stdout_budget)
+    stderr_snip, _ = _snip_text(stderr, stderr_budget)
+    return stdout_snip, stderr_snip
+
+
+# the idea here is that we just set a limit - the user could overload
+# that in their env, or maybe we could pull this out of the LLM parameters
+MAX_TOOL_MSG_CHARS = int(os.getenv("MAX_TOOL_MSG_CHARS", "50000"))
+
+
 @tool
 def run_cmd(query: str, state: Annotated[dict, InjectedState]) -> str:
     """
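A quick check of the snipping behavior defined above, assuming `_snip_text` is importable from `ursa.agents.execution_agent` (it is module-level in this diff):

    from ursa.agents.execution_agent import _snip_text

    # 100 chars into a 20-char budget: keeps the first 10 and last 10,
    # with a "... [snipped 80 chars] ..." marker in between.
    snipped, truncated = _snip_text("a" * 100, max_chars=20)
    assert truncated is True
    assert snipped.startswith("a" * 10) and snipped.endswith("a" * 10)
    assert "[snipped 80 chars]" in snipped

`_fit_streams_to_budget` then splits the overall budget between stdout and stderr in proportion to their lengths, so each stream keeps a share of the final tool message proportional to its size.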
@@ -293,10 +331,15 @@ def run_cmd(query: str, state: Annotated[dict, InjectedState]) -> str:
     print("Keyboard Interrupt of command: ", query)
     stdout, stderr = "", "KeyboardInterrupt:"

-
-
+    # Fit BOTH streams under a single overall cap
+    stdout_fit, stderr_fit = _fit_streams_to_budget(
+        stdout or "", stderr or "", MAX_TOOL_MSG_CHARS
+    )
+
+    print("STDOUT: ", stdout_fit)
+    print("STDERR: ", stderr_fit)

-    return f"STDOUT
+    return f"STDOUT:\n{stdout_fit}\nSTDERR:\n{stderr_fit}"


 def _strip_fences(snippet: str) -> str:
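One subtlety worth noting: `MAX_TOOL_MSG_CHARS` is evaluated once, at module import time, so an override has to be in the environment before the module is imported (or set in the shell that launches the process):

    import os

    # Must be set before ursa.agents.execution_agent is imported; the cap
    # is a module-level constant read via os.getenv at import time.
    os.environ["MAX_TOOL_MSG_CHARS"] = "20000"

    from ursa.agents.execution_agent import MAX_TOOL_MSG_CHARS
    print(MAX_TOOL_MSG_CHARS)  # 20000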
@@ -499,8 +542,9 @@ def main():
     inputs = {
         "messages": [HumanMessage(content=problem_string)]
     }  # , "workspace":"dummy_test"}
-    result = execution_agent.
-        inputs,
+    result = execution_agent.invoke(
+        inputs,
+        config={"configurable": {"thread_id": execution_agent.thread_id}},
     )
     print(result["messages"][-1].content)
     return result
ursa/agents/hypothesizer_agent.py
CHANGED

@@ -3,12 +3,12 @@ import ast
 # from langchain_community.tools import TavilySearchResults
 # from textwrap import dedent
 from datetime import datetime
-from typing import List, Literal, TypedDict
+from typing import Any, List, Literal, Mapping, TypedDict

 from langchain_community.tools import DuckDuckGoSearchResults
 from langchain_core.language_models import BaseChatModel
 from langchain_core.messages import HumanMessage, SystemMessage
-from langgraph.graph import
+from langgraph.graph import StateGraph

 from ..prompt_library.hypothesizer_prompts import (
     competitor_prompt,
@@ -53,7 +53,7 @@ class HypothesizerAgent(BaseAgent):
         # max_results=10, search_depth="advanced", include_answer=False
         # )

-        self.
+        self._action = self._build_graph()

     def agent1_generate_solution(
         self, state: HypothesizerState
@@ -444,68 +444,65 @@
         )
         return new_state

-    def
+    def _build_graph(self):
         # Initialize the graph
-
+        graph = StateGraph(HypothesizerState)

         # Add nodes
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
-        self.
-
+        self.add_node(graph, self.agent1_generate_solution, "agent1")
+        self.add_node(graph, self.agent2_critique, "agent2")
+        self.add_node(graph, self.agent3_competitor_perspective, "agent3")
+        self.add_node(graph, self.increment_iteration, "increment_iteration")
+        self.add_node(graph, self.generate_solution, "finalize")
+        self.add_node(graph, self.print_visited_sites, "print_sites")
+        self.add_node(
+            graph, self.summarize_process_as_latex, "summarize_as_latex"
         )
         # self.graph.add_node("compile_pdf", compile_summary_to_pdf)

         # Add simple edges for the known flow
-
-
-
+        graph.add_edge("agent1", "agent2")
+        graph.add_edge("agent2", "agent3")
+        graph.add_edge("agent3", "increment_iteration")

         # Then from increment_iteration, we have a conditional:
         # If we 'continue', we go back to agent1
         # If we 'finish', we jump to the finalize node
-
+        graph.add_conditional_edges(
             "increment_iteration",
             should_continue,
             {"continue": "agent1", "finish": "finalize"},
         )

-
-
-        self.graph.add_edge("print_sites", END)
+        graph.add_edge("finalize", "summarize_as_latex")
+        graph.add_edge("summarize_as_latex", "print_sites")
         # self.graph.add_edge("summarize_as_latex", "compile_pdf")
         # self.graph.add_edge("compile_pdf", "print_sites")

         # Set the entry point
-
+        graph.set_entry_point("agent1")
+        graph.set_finish_point("print_sites")

-
+        return graph.compile(checkpointer=self.checkpointer)
         # self.action.get_graph().draw_mermaid_png(output_file_path="hypothesizer_agent_graph.png", draw_method=MermaidDrawMethod.PYPPETEER)

-    def
-
-
-
-        max_iterations=max_iter,
-        agent1_solution=[],
-        agent2_critiques=[],
-        agent3_perspectives=[],
-        solution="",
+    def _invoke(
+        self, inputs: Mapping[str, Any], recursion_limit: int = 100000, **_
+    ):
+        config = self.build_config(
+            recursion_limit=recursion_limit, tags=["graph"]
         )
-
-
-
-
-
-
-
-
+        if "prompt" not in inputs:
+            raise KeyError("'prompt' is a required arguments")
+
+        inputs["max_iterations"] = inputs.get("max_iterations", 3)
+        inputs["current_iteration"] = 0
+        inputs["agent1_solution"] = []
+        inputs["agent2_critiques"] = []
+        inputs["agent3_perspectives"] = []
+        inputs["solution"] = ""
+
+        return self._action.invoke(inputs, config)


 def should_continue(state: HypothesizerState) -> Literal["continue", "finish"]:
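With `_invoke` seeding the iteration counter and per-agent lists itself, a caller now only has to supply `prompt` (and may override `max_iterations`, which defaults to 3). A usage sketch, assuming the public `invoke` inherited from `BaseAgent` delegates to `_invoke` the same way the execution agent's does:

    from ursa.agents.hypothesizer_agent import HypothesizerAgent

    agent = HypothesizerAgent(llm)  # llm: any BaseChatModel; construction details elided
    result = agent.invoke({"prompt": "How could we harden this release process?"})
    print(result["solution"])  # final solution string from HypothesizerState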
@@ -583,7 +580,7 @@ if __name__ == "__main__":

     print("[DEBUG] Invoking the graph...")
     # Run the graph
-    result = hypothesizer_agent.
+    result = hypothesizer_agent.invoke(
         initial_state,
         {
             "recursion_limit": 999999,
ursa/agents/lammps_agent.py
CHANGED
@@ -1,17 +1,22 @@
 import json
 import os
 import subprocess
-from typing import Any, Dict, List, Optional, TypedDict
+from typing import Any, Dict, List, Mapping, Optional, TypedDict

-import atomman as am
 import tiktoken
-import trafilatura
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import ChatPromptTemplate
 from langgraph.graph import END, StateGraph

 from .base import BaseAgent

+working = True
+try:
+    import atomman as am
+    import trafilatura
+except Exception:
+    working = False
+

 class LammpsState(TypedDict, total=False):
     simulation_task: str
@@ -50,6 +55,10 @@ class LammpsAgent(BaseAgent):
         max_tokens: int = 200000,
         **kwargs,
     ):
+        if not working:
+            raise ImportError(
+                "LAMMPS agent requires the atomman and trafilatura dependencies. These can be installed using 'pip install ursa-ai[lammps]' or, if working from a local installation, 'pip install -e .[lammps]' ."
+            )
         self.max_potentials = max_potentials
         self.max_fix_attempts = max_fix_attempts
         self.mpi_procs = mpi_procs
@@ -144,7 +153,7 @@
             | self.str_parser
         )

-        self.
+        self._action = self._build_graph()

     @staticmethod
     def _safe_json_loads(s: str) -> Dict[str, Any]:
@@ -340,53 +349,66 @@
     def _build_graph(self):
         g = StateGraph(LammpsState)

-
-
-
-
-
-
-
+        self.add_node(g, self._find_potentials)
+        self.add_node(g, self._summarize_one)
+        self.add_node(g, self._build_summaries)
+        self.add_node(g, self._choose)
+        self.add_node(g, self._author)
+        self.add_node(g, self._run_lammps)
+        self.add_node(g, self._fix)

-        g.set_entry_point("
+        g.set_entry_point("_find_potentials")

         g.add_conditional_edges(
-            "
+            "_find_potentials",
             self._should_summarize,
             {
-                "summarize_one": "
-                "summarize_done": "
+                "summarize_one": "_summarize_one",
+                "summarize_done": "_build_summaries",
                 "done_no_matches": END,
             },
         )

         g.add_conditional_edges(
-            "
+            "_summarize_one",
             self._should_summarize,
             {
-                "summarize_one": "
-                "summarize_done": "
+                "summarize_one": "_summarize_one",
+                "summarize_done": "_build_summaries",
             },
         )

-        g.add_edge("
-        g.add_edge("
-        g.add_edge("
+        g.add_edge("_build_summaries", "_choose")
+        g.add_edge("_choose", "_author")
+        g.add_edge("_author", "_run_lammps")

         g.add_conditional_edges(
-            "
+            "_run_lammps",
             self._route_run,
             {
-                "need_fix": "
+                "need_fix": "_fix",
                 "done_success": END,
                 "done_failed": END,
             },
         )
-        g.add_edge("
-        return g
+        g.add_edge("_fix", "_run_lammps")
+        return g.compile(checkpointer=self.checkpointer)

-    def
-
-
-
+    def _invoke(
+        self,
+        inputs: Mapping[str, Any],
+        *,
+        summarize: bool | None = None,
+        recursion_limit: int = 1000,
+        **_,
+    ) -> str:
+        config = self.build_config(
+            recursion_limit=recursion_limit, tags=["graph"]
         )
+
+        if "simulation_task" not in inputs or "elements" not in inputs:
+            raise KeyError(
+                "'simulation_task' and 'elements' are required arguments"
+            )
+
+        return self._action.invoke(inputs, config)
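Between the import guard and the key check, the failure modes here are explicit: constructing the agent without the optional dependencies raises `ImportError`, and invoking it without both required keys raises `KeyError`. A usage sketch under those constraints (construction details are elided and assumed):

    # Requires: pip install ursa-ai[lammps]
    from ursa.agents.lammps_agent import LammpsAgent

    agent = LammpsAgent(llm)  # llm and other constructor kwargs elided
    report = agent.invoke({
        "simulation_task": "estimate the melting point of aluminum",
        "elements": ["Al"],
    })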
ursa/agents/mp_agent.py
CHANGED
@@ -2,7 +2,7 @@ import json
 import os
 import re
 from concurrent.futures import ThreadPoolExecutor
-from typing import Dict
+from typing import Any, Dict, Mapping

 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import ChatPromptTemplate
@@ -50,7 +50,7 @@ class MaterialsProjectAgent(BaseAgent):
         os.makedirs(self.database_path, exist_ok=True)
         os.makedirs(self.summaries_path, exist_ok=True)

-        self.
+        self._action = self._build_graph()

     def _fetch_node(self, state: Dict) -> Dict:
         f = state["query"]
@@ -148,31 +148,56 @@ You are a materials-science assistant. Given the following metadata about a material:
         return {**state, "final_summary": final}

     def _build_graph(self):
-
-
+        graph = StateGraph(dict)  # using plain dict for state
+        self.add_node(graph, self._fetch_node)
         if self.summarize:
-
-
-
-
-
-
+            self.add_node(graph, self._summarize_node)
+            self.add_node(graph, self._aggregate_node)
+
+            graph.set_entry_point("_fetch_node")
+            graph.add_edge("_fetch_node", "_summarize_node")
+            graph.add_edge("_summarize_node", "_aggregate_node")
+            graph.set_finish_point("_aggregate_node")
         else:
-
-
-        return
+            graph.set_entry_point("_fetch_node")
+            graph.set_finish_point("_fetch_node")
+        return graph.compile(checkpointer=self.checkpointer)

-    def
-
-
-
-
-
+    def _invoke(
+        self,
+        inputs: Mapping[str, Any],
+        *,
+        summarize: bool | None = None,
+        recursion_limit: int = 1000,
+        **_,
+    ) -> str:
+        config = self.build_config(
+            recursion_limit=recursion_limit, tags=["graph"]
+        )
+
+        if "query" not in inputs:
+            if "mp_query" in inputs:
+                # make a shallow copy and rename the key
+                inputs = dict(inputs)
+                inputs["query"] = inputs.pop("mp_query")
+            else:
+                raise KeyError(
+                    "Missing 'query' in inputs (alias 'mp_query' also accepted)."
+                )
+
+        result = self._action.invoke(inputs, config)
+
+        use_summary = self.summarize if summarize is None else summarize
+        return (
+            result.get("final_summary", "No summary generated.")
+            if use_summary
+            else "\n\nFinished Fetching Materials Database Information!"
+        )


 if __name__ == "__main__":
     agent = MaterialsProjectAgent()
-    resp = agent.
+    resp = agent.invoke(
         mp_query="LiFePO4",
         context="What is its band gap and stability, and any synthesis challenges?",
     )
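Note that the `__main__` block calls `invoke` with keyword arguments while `_invoke` receives a single `inputs` mapping, so the `BaseAgent.invoke` wrapper (added in base.py, not shown in this diff) presumably packs kwargs into that mapping before delegating. The `mp_query` → `query` aliasing itself is easy to isolate; a self-contained sketch mirroring the logic above:

    from typing import Any, Mapping

    def normalize_inputs(inputs: Mapping[str, Any]) -> Mapping[str, Any]:
        # Mirrors MaterialsProjectAgent._invoke: accept 'mp_query' as an
        # alias for 'query', copying rather than mutating the caller's dict.
        if "query" not in inputs:
            if "mp_query" in inputs:
                inputs = dict(inputs)
                inputs["query"] = inputs.pop("mp_query")
            else:
                raise KeyError(
                    "Missing 'query' in inputs (alias 'mp_query' also accepted)."
                )
        return inputs

    print(normalize_inputs({"mp_query": "LiFePO4"}))  # {'query': 'LiFePO4'}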