ursa-ai 0.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ursa/__init__.py +3 -0
- ursa/agents/__init__.py +32 -0
- ursa/agents/acquisition_agents.py +812 -0
- ursa/agents/arxiv_agent.py +429 -0
- ursa/agents/base.py +728 -0
- ursa/agents/chat_agent.py +60 -0
- ursa/agents/code_review_agent.py +341 -0
- ursa/agents/execution_agent.py +915 -0
- ursa/agents/hypothesizer_agent.py +614 -0
- ursa/agents/lammps_agent.py +465 -0
- ursa/agents/mp_agent.py +204 -0
- ursa/agents/optimization_agent.py +410 -0
- ursa/agents/planning_agent.py +219 -0
- ursa/agents/rag_agent.py +304 -0
- ursa/agents/recall_agent.py +54 -0
- ursa/agents/websearch_agent.py +196 -0
- ursa/cli/__init__.py +363 -0
- ursa/cli/hitl.py +516 -0
- ursa/cli/hitl_api.py +75 -0
- ursa/observability/metrics_charts.py +1279 -0
- ursa/observability/metrics_io.py +11 -0
- ursa/observability/metrics_session.py +750 -0
- ursa/observability/pricing.json +97 -0
- ursa/observability/pricing.py +321 -0
- ursa/observability/timing.py +1466 -0
- ursa/prompt_library/__init__.py +0 -0
- ursa/prompt_library/code_review_prompts.py +51 -0
- ursa/prompt_library/execution_prompts.py +50 -0
- ursa/prompt_library/hypothesizer_prompts.py +17 -0
- ursa/prompt_library/literature_prompts.py +11 -0
- ursa/prompt_library/optimization_prompts.py +131 -0
- ursa/prompt_library/planning_prompts.py +79 -0
- ursa/prompt_library/websearch_prompts.py +131 -0
- ursa/tools/__init__.py +0 -0
- ursa/tools/feasibility_checker.py +114 -0
- ursa/tools/feasibility_tools.py +1075 -0
- ursa/tools/run_command.py +27 -0
- ursa/tools/write_code.py +42 -0
- ursa/util/__init__.py +0 -0
- ursa/util/diff_renderer.py +128 -0
- ursa/util/helperFunctions.py +142 -0
- ursa/util/logo_generator.py +625 -0
- ursa/util/memory_logger.py +183 -0
- ursa/util/optimization_schema.py +78 -0
- ursa/util/parse.py +405 -0
- ursa_ai-0.9.1.dist-info/METADATA +304 -0
- ursa_ai-0.9.1.dist-info/RECORD +51 -0
- ursa_ai-0.9.1.dist-info/WHEEL +5 -0
- ursa_ai-0.9.1.dist-info/entry_points.txt +2 -0
- ursa_ai-0.9.1.dist-info/licenses/LICENSE +8 -0
- ursa_ai-0.9.1.dist-info/top_level.txt +1 -0
ursa/agents/lammps_agent.py
ADDED
@@ -0,0 +1,465 @@
import json
import os
import subprocess
from typing import Any, Mapping, Optional, TypedDict

import tiktoken
from langchain.chat_models import BaseChatModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import END, StateGraph

from .base import BaseAgent

working = True
try:
    import atomman as am
    import trafilatura
except Exception:
    working = False


class LammpsState(TypedDict, total=False):
    simulation_task: str
    elements: list[str]
    template: Optional[str]
    chosen_potential: Optional[Any]

    matches: list[Any]
    idx: int
    summaries: list[str]
    full_texts: list[str]
    summaries_combined: str

    input_script: str
    run_returncode: Optional[int]
    run_stdout: str
    run_stderr: str

    fix_attempts: int


class LammpsAgent(BaseAgent):
    def __init__(
        self,
        llm: BaseChatModel,
        max_potentials: int = 5,
        max_fix_attempts: int = 10,
        find_potential_only: bool = False,
        mpi_procs: int = 8,
        workspace: str = "./workspace",
        lammps_cmd: str = "lmp_mpi",
        mpirun_cmd: str = "mpirun",
        tiktoken_model: str = "gpt-5-mini",
        max_tokens: int = 200000,
        **kwargs,
    ):
        if not working:
            raise ImportError(
                "LAMMPS agent requires the atomman and trafilatura dependencies. These can be installed using 'pip install ursa-ai[lammps]' or, if working from a local installation, 'pip install -e .[lammps]' ."
            )
        self.max_potentials = max_potentials
        self.max_fix_attempts = max_fix_attempts
        self.find_potential_only = find_potential_only
        self.mpi_procs = mpi_procs
        self.lammps_cmd = lammps_cmd
        self.mpirun_cmd = mpirun_cmd
        self.tiktoken_model = tiktoken_model
        self.max_tokens = max_tokens

        self.pair_styles = [
            "eam",
            "eam/alloy",
            "eam/fs",
            "meam",
            "adp",
            "kim",
            "snap",
            "quip",
            "mlip",
            "pace",
            "nep",
        ]

        self.workspace = workspace
        os.makedirs(self.workspace, exist_ok=True)

        super().__init__(llm, **kwargs)

        self.str_parser = StrOutputParser()

        self.summ_chain = (
            ChatPromptTemplate.from_template(
                "Here is some data about an interatomic potential: {metadata}\n\n"
                "Briefly summarize why it could be useful for this task: {simulation_task}."
            )
            | self.llm
            | self.str_parser
        )

        self.choose_chain = (
            ChatPromptTemplate.from_template(
                "Here are the summaries of a certain number of interatomic potentials: {summaries_combined}\n\n"
                "Pick one potential which would be most useful for this task: {simulation_task}.\n\n"
                "Return your answer **only** as valid JSON, with no extra text or formatting.\n\n"
                "Use this exact schema:\n"
                "{{\n"
                '  "Chosen index": <int>,\n'
                '  "rationale": "<string>",\n'
                '  "Potential name": "<string>"\n'
                "}}\n"
            )
            | self.llm
            | self.str_parser
        )

        self.author_chain = (
            ChatPromptTemplate.from_template(
                "Your task is to write a LAMMPS input file for this purpose: {simulation_task}.\n"
                "Note that all potential files are in the './' directory.\n"
                "Here is some information about the pair_style and pair_coeff that might be useful in writing the input file: {pair_info}.\n"
                "If a template for the input file is provided, you should adapt it appropriately to meet the task requirements.\n"
                "Template provided (if any): {template}\n"
                "Ensure that all output data is written only to the './log.lammps' file. Do not create any other output file.\n"
                "To create the log, use only the 'log ./log.lammps' command. Do not use any other command like 'echo' or 'screen'.\n"
                "Return your answer **only** as valid JSON, with no extra text or formatting.\n"
                "Use this exact schema:\n"
                "{{\n"
                '  "input_script": "<string>"\n'
                "}}\n"
            )
            | self.llm
            | self.str_parser
        )

        self.fix_chain = (
            ChatPromptTemplate.from_template(
                "You are part of a larger scientific workflow whose purpose is to accomplish this task: {simulation_task}\n"
                "For this purpose, this input file for LAMMPS was written: {input_script}\n"
                "However, when running the simulation, an error was raised.\n"
                "Here is the full stdout message that includes the error message: {err_message}\n"
                "Your task is to write a new input file that resolves the error.\n"
                "Note that all potential files are in the './' directory.\n"
                "Here is some information about the pair_style and pair_coeff that might be useful in writing the input file: {pair_info}.\n"
                "If a template for the input file is provided, you should adapt it appropriately to meet the task requirements.\n"
                "Template provided (if any): {template}\n"
                "Ensure that all output data is written only to the './log.lammps' file. Do not create any other output file.\n"
                "To create the log, use only the 'log ./log.lammps' command. Do not use any other command like 'echo' or 'screen'.\n"
                "Return your answer **only** as valid JSON, with no extra text or formatting.\n"
                "Use this exact schema:\n"
                "{{\n"
                '  "input_script": "<string>"\n'
                "}}\n"
            )
            | self.llm
            | self.str_parser
        )

        self._action = self._build_graph()

    @staticmethod
    def _safe_json_loads(s: str) -> dict[str, Any]:
        s = s.strip()
        if s.startswith("```"):
            s = s.strip("`")
            i = s.find("\n")
            if i != -1:
                s = s[i + 1 :].strip()
        return json.loads(s)

    def _fetch_and_trim_text(self, url: str) -> str:
        downloaded = trafilatura.fetch_url(url)
        if not downloaded:
            return "No metadata available"
        text = trafilatura.extract(
            downloaded,
            include_comments=False,
            include_tables=True,
            include_links=False,
            favor_recall=True,
        )
        if not text:
            return "No metadata available"
        text = text.strip()
        try:
            enc = tiktoken.encoding_for_model(self.tiktoken_model)
            toks = enc.encode(text)
            if len(toks) > self.max_tokens:
                toks = toks[: self.max_tokens]
                text = enc.decode(toks)
        except Exception:
            pass
        return text

    def _entry_router(self, state: LammpsState) -> dict:
        if self.find_potential_only and state.get("chosen_potential"):
            raise Exception(
                "You cannot set find_potential_only=True and also specify your own potential!"
            )

        if not state.get("chosen_potential"):
            self.potential_summaries_dir = os.path.join(
                self.workspace, "potential_summaries"
            )
            os.makedirs(self.potential_summaries_dir, exist_ok=True)
        return {}

    def _find_potentials(self, state: LammpsState) -> LammpsState:
        db = am.library.Database(remote=True)
        matches = db.get_lammps_potentials(
            pair_style=self.pair_styles, elements=state["elements"]
        )

        return {
            **state,
            "matches": list(matches),
            "idx": 0,
            "summaries": [],
            "full_texts": [],
            "fix_attempts": 0,
        }

    def _should_summarize(self, state: LammpsState) -> str:
        matches = state.get("matches", [])
        i = state.get("idx", 0)
        if not matches:
            print("No potentials found in NIST for this task. Exiting....")
            return "done_no_matches"
        if i < min(self.max_potentials, len(matches)):
            return "summarize_one"
        return "summarize_done"

    def _summarize_one(self, state: LammpsState) -> LammpsState:
        i = state["idx"]
        print(f"Summarizing potential #{i}")
        match = state["matches"][i]
        md = match.metadata()

        if md.get("comments") is None:
            text = "No metadata available"
            summary = "No summary available"
        else:
            lines = md["comments"].split("\n")
            url = lines[1] if len(lines) > 1 else ""
            text = (
                self._fetch_and_trim_text(url)
                if url
                else "No metadata available"
            )
            summary = self.summ_chain.invoke({
                "metadata": text,
                "simulation_task": state["simulation_task"],
            })

        summary_file = os.path.join(
            self.potential_summaries_dir, "potential_" + str(i) + ".txt"
        )
        with open(summary_file, "w") as f:
            f.write(summary)

        return {
            **state,
            "idx": i + 1,
            "summaries": [*state["summaries"], summary],
            "full_texts": [*state["full_texts"], text],
        }

    def _build_summaries(self, state: LammpsState) -> LammpsState:
        parts = []
        for i, s in enumerate(state["summaries"]):
            rec = state["matches"][i]
            parts.append(f"\nSummary of potential #{i}: {rec.id}\n{s}\n")
        return {**state, "summaries_combined": "".join(parts)}

    def _choose(self, state: LammpsState) -> LammpsState:
        print("Choosing one potential for this task...")
        choice = self.choose_chain.invoke({
            "summaries_combined": state["summaries_combined"],
            "simulation_task": state["simulation_task"],
        })
        choice_dict = self._safe_json_loads(choice)
        chosen_index = int(choice_dict["Chosen index"])

        print(f"Chosen potential #{chosen_index}")
        print("Rationale for choosing this potential:")
        print(choice_dict["rationale"])

        chosen_potential = state["matches"][chosen_index]

        out_file = os.path.join(self.potential_summaries_dir, "Rationale.txt")
        with open(out_file, "w") as f:
            f.write(f"Chosen potential #{chosen_index}")
            f.write("\n")
            f.write("Rationale for choosing this potential:")
            f.write("\n")
            f.write(choice_dict["rationale"])

        return {**state, "chosen_potential": chosen_potential}

    def _route_after_summarization(self, state: LammpsState) -> str:
        if self.find_potential_only:
            return "Exit"
        return "continue_author"

    def _author(self, state: LammpsState) -> LammpsState:
        print("First attempt at writing LAMMPS input file....")
        state["chosen_potential"].download_files(self.workspace)
        pair_info = state["chosen_potential"].pair_info()
        authored_json = self.author_chain.invoke({
            "simulation_task": state["simulation_task"],
            "pair_info": pair_info,
            "template": state["template"],
        })
        script_dict = self._safe_json_loads(authored_json)
        input_script = script_dict["input_script"]
        with open(os.path.join(self.workspace, "in.lammps"), "w") as f:
            f.write(input_script)
        return {**state, "input_script": input_script}

    def _run_lammps(self, state: LammpsState) -> LammpsState:
        print("Running LAMMPS....")
        result = subprocess.run(
            [
                self.mpirun_cmd,
                "-np",
                str(self.mpi_procs),
                self.lammps_cmd,
                "-in",
                "in.lammps",
            ],
            cwd=self.workspace,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            check=False,
        )
        return {
            **state,
            "run_returncode": result.returncode,
            "run_stdout": result.stdout,
            "run_stderr": result.stderr,
        }

    def _route_run(self, state: LammpsState) -> str:
        rc = state.get("run_returncode", 0)
        attempts = state.get("fix_attempts", 0)
        if rc == 0:
            print("LAMMPS run successful! Exiting...")
            return "done_success"
        if attempts < self.max_fix_attempts:
            print("LAMMPS run Failed. Attempting to rewrite input file...")
            return "need_fix"
        print("LAMMPS run Failed and maximum fix attempts reached. Exiting...")
        return "done_failed"

    def _fix(self, state: LammpsState) -> LammpsState:
        pair_info = state["chosen_potential"].pair_info()
        err_blob = state.get("run_stdout")

        fixed_json = self.fix_chain.invoke({
            "simulation_task": state["simulation_task"],
            "input_script": state["input_script"],
            "err_message": err_blob,
            "pair_info": pair_info,
            "template": state["template"],
        })
        script_dict = self._safe_json_loads(fixed_json)
        new_input = script_dict["input_script"]
        with open(os.path.join(self.workspace, "in.lammps"), "w") as f:
            f.write(new_input)
        return {
            **state,
            "input_script": new_input,
            "fix_attempts": state.get("fix_attempts", 0) + 1,
        }

    def _build_graph(self):
        g = StateGraph(LammpsState)

        self.add_node(g, self._entry_router)
        self.add_node(g, self._find_potentials)
        self.add_node(g, self._summarize_one)
        self.add_node(g, self._build_summaries)
        self.add_node(g, self._choose)
        self.add_node(g, self._author)
        self.add_node(g, self._run_lammps)
        self.add_node(g, self._fix)

        g.set_entry_point("_entry_router")

        g.add_conditional_edges(
            "_entry_router",
            lambda state: "user_choice"
            if state.get("chosen_potential")
            else "agent_choice",
            {
                "user_choice": "_author",
                "agent_choice": "_find_potentials",
            },
        )

        g.add_conditional_edges(
            "_find_potentials",
            self._should_summarize,
            {
                "summarize_one": "_summarize_one",
                "summarize_done": "_build_summaries",
                "done_no_matches": END,
            },
        )

        g.add_conditional_edges(
            "_summarize_one",
            self._should_summarize,
            {
                "summarize_one": "_summarize_one",
                "summarize_done": "_build_summaries",
            },
        )

        g.add_edge("_build_summaries", "_choose")

        g.add_conditional_edges(
            "_choose",
            self._route_after_summarization,
            {
                "continue_author": "_author",
                "Exit": END,
            },
        )

        g.add_edge("_author", "_run_lammps")

        g.add_conditional_edges(
            "_run_lammps",
            self._route_run,
            {
                "need_fix": "_fix",
                "done_success": END,
                "done_failed": END,
            },
        )
        g.add_edge("_fix", "_run_lammps")
        return g.compile(checkpointer=self.checkpointer)

    def _invoke(
        self,
        inputs: Mapping[str, Any],
        *,
        summarize: bool | None = None,
        recursion_limit: int = 999_999,
        **_,
    ) -> str:
        config = self.build_config(
            recursion_limit=recursion_limit, tags=["graph"]
        )

        if "simulation_task" not in inputs or "elements" not in inputs:
            raise KeyError(
                "'simulation_task' and 'elements' are required arguments"
            )

        if "template" not in inputs:
            inputs = {**inputs, "template": "No template provided."}

        return self._action.invoke(inputs, config)
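
A minimal usage sketch for the agent above (not part of the package): it assumes BaseAgent exposes an invoke() wrapper that forwards keyword arguments to _invoke as the inputs mapping, as the __main__ block of mp_agent.py below suggests, and that a chat model such as langchain_openai's ChatOpenAI is available. The model name, task, and element list are illustrative.

# Hypothetical usage sketch; the model choice and task are assumptions, not package defaults.
from langchain_openai import ChatOpenAI

from ursa.agents.lammps_agent import LammpsAgent

llm = ChatOpenAI(model="gpt-4o")  # any BaseChatModel should work here
agent = LammpsAgent(llm, workspace="./workspace", lammps_cmd="lmp_mpi")

# _invoke requires 'simulation_task' and 'elements'; 'template' and
# 'chosen_potential' are optional.
result = agent.invoke(
    simulation_task="Relax an fcc aluminum cell and report the lattice constant.",
    elements=["Al"],
)

If find_potential_only=True is passed to the constructor, the graph ends after choosing and documenting a potential instead of authoring and running an input script.
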
ursa/agents/mp_agent.py
ADDED
@@ -0,0 +1,204 @@
import json
import os
import re
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Mapping, TypedDict

from langchain.chat_models import BaseChatModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import StateGraph
from mp_api.client import MPRester
from tqdm import tqdm

from .base import BaseAgent


class PaperMetadata(TypedDict):
    arxiv_id: str
    full_text: str


class PaperState(TypedDict, total=False):
    query: str
    context: str
    papers: list[PaperMetadata]
    summaries: list[str]
    final_summary: str


def remove_surrogates(text: str) -> str:
    return re.sub(r"[\ud800-\udfff]", "", text)


class MaterialsProjectAgent(BaseAgent):
    def __init__(
        self,
        llm: BaseChatModel,
        summarize: bool = True,
        max_results: int = 3,
        database_path: str = "mp_database",
        summaries_path: str = "mp_summaries",
        **kwargs,
    ):
        super().__init__(llm, **kwargs)
        self.summarize = summarize
        self.max_results = max_results
        self.database_path = database_path
        self.summaries_path = summaries_path

        os.makedirs(self.database_path, exist_ok=True)
        os.makedirs(self.summaries_path, exist_ok=True)

        self._action = self._build_graph()

    def _fetch_node(self, state: dict) -> dict:
        f = state["query"]
        els = f["elements"]  # e.g. ["Ga","In"]
        bg = (f["band_gap_min"], f["band_gap_max"])
        e_above_hull = (0, 0)  # only on-hull (stable)
        mats = []
        with MPRester() as mpr:
            # get ALL matching materials…
            all_results = mpr.materials.summary.search(
                elements=els,
                band_gap=bg,
                energy_above_hull=e_above_hull,
                is_stable=True,  # equivalent filter
            )
            # …then take only the first `max_results`
            for doc in all_results[: self.max_results]:
                mid = doc.material_id
                data = doc.dict()
                # cache to disk
                path = os.path.join(self.database_path, f"{mid}.json")
                if not os.path.exists(path):
                    with open(path, "w") as f:
                        json.dump(data, f, indent=2)
                mats.append({"material_id": mid, "metadata": data})

        return {**state, "materials": mats}

    def _summarize_node(self, state: dict) -> dict:
        """Summarize each material via LLM over its metadata."""
        # prompt template
        prompt = ChatPromptTemplate.from_template("""
You are a materials-science assistant. Given the following metadata about a material, produce a concise summary focusing on its key properties:

{metadata}
""")
        chain = prompt | self.llm | StrOutputParser()

        summaries = [None] * len(state["materials"])

        def process(i, mat):
            mid = mat["material_id"]
            meta = mat["metadata"]
            # flatten metadata to text
            text = "\n".join(f"{k}: {v}" for k, v in meta.items())
            # build or load summary
            summary_file = os.path.join(
                self.summaries_path, f"{mid}_summary.txt"
            )
            if os.path.exists(summary_file):
                with open(summary_file) as f:
                    return i, f.read()
            # optional: vectorize & retrieve, but here we just summarize full text
            result = chain.invoke({"metadata": text})
            with open(summary_file, "w") as f:
                f.write(result)
            return i, result

        with ThreadPoolExecutor(
            max_workers=min(8, len(state["materials"]))
        ) as exe:
            futures = [
                exe.submit(process, i, m)
                for i, m in enumerate(state["materials"])
            ]
            for future in tqdm(futures, desc="Summarizing materials"):
                i, summ = future.result()
                summaries[i] = summ

        return {**state, "summaries": summaries}

    def _aggregate_node(self, state: dict) -> dict:
        """Combine all summaries into a single, coherent answer."""
        combined = "\n\n----\n\n".join(
            f"[{i + 1}] {m['material_id']}\n\n{summary}"
            for i, (m, summary) in enumerate(
                zip(state["materials"], state["summaries"])
            )
        )

        prompt = ChatPromptTemplate.from_template("""
You are a materials informatics assistant. Below are brief summaries of several materials:

{summaries}

Answer the user’s question in context:

{context}
""")
        chain = prompt | self.llm | StrOutputParser()
        final = chain.invoke({
            "summaries": combined,
            "context": state["context"],
        })
        return {**state, "final_summary": final}

    def _build_graph(self):
        graph = StateGraph(dict)  # using plain dict for state
        self.add_node(graph, self._fetch_node)
        if self.summarize:
            self.add_node(graph, self._summarize_node)
            self.add_node(graph, self._aggregate_node)

            graph.set_entry_point("_fetch_node")
            graph.add_edge("_fetch_node", "_summarize_node")
            graph.add_edge("_summarize_node", "_aggregate_node")
            graph.set_finish_point("_aggregate_node")
        else:
            graph.set_entry_point("_fetch_node")
            graph.set_finish_point("_fetch_node")
        return graph.compile(checkpointer=self.checkpointer)

    def _invoke(
        self,
        inputs: Mapping[str, Any],
        *,
        summarize: bool | None = None,
        recursion_limit: int = 1000,
        **_,
    ) -> str:
        config = self.build_config(
            recursion_limit=recursion_limit, tags=["graph"]
        )

        if "query" not in inputs:
            if "mp_query" in inputs:
                # make a shallow copy and rename the key
                inputs = dict(inputs)
                inputs["query"] = inputs.pop("mp_query")
            else:
                raise KeyError(
                    "Missing 'query' in inputs (alias 'mp_query' also accepted)."
                )

        result = self._action.invoke(inputs, config)

        use_summary = self.summarize if summarize is None else summarize
        return (
            result.get("final_summary", "No summary generated.")
            if use_summary
            else "\n\nFinished Fetching Materials Database Information!"
        )


if __name__ == "__main__":
    agent = MaterialsProjectAgent()
    resp = agent.invoke(
        mp_query="LiFePO4",
        context="What is its band gap and stability, and any synthesis challenges?",
    )
    print(resp)