ursa-ai 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. ursa/__init__.py +3 -0
  2. ursa/agents/__init__.py +32 -0
  3. ursa/agents/acquisition_agents.py +812 -0
  4. ursa/agents/arxiv_agent.py +429 -0
  5. ursa/agents/base.py +728 -0
  6. ursa/agents/chat_agent.py +60 -0
  7. ursa/agents/code_review_agent.py +341 -0
  8. ursa/agents/execution_agent.py +915 -0
  9. ursa/agents/hypothesizer_agent.py +614 -0
  10. ursa/agents/lammps_agent.py +465 -0
  11. ursa/agents/mp_agent.py +204 -0
  12. ursa/agents/optimization_agent.py +410 -0
  13. ursa/agents/planning_agent.py +219 -0
  14. ursa/agents/rag_agent.py +304 -0
  15. ursa/agents/recall_agent.py +54 -0
  16. ursa/agents/websearch_agent.py +196 -0
  17. ursa/cli/__init__.py +363 -0
  18. ursa/cli/hitl.py +516 -0
  19. ursa/cli/hitl_api.py +75 -0
  20. ursa/observability/metrics_charts.py +1279 -0
  21. ursa/observability/metrics_io.py +11 -0
  22. ursa/observability/metrics_session.py +750 -0
  23. ursa/observability/pricing.json +97 -0
  24. ursa/observability/pricing.py +321 -0
  25. ursa/observability/timing.py +1466 -0
  26. ursa/prompt_library/__init__.py +0 -0
  27. ursa/prompt_library/code_review_prompts.py +51 -0
  28. ursa/prompt_library/execution_prompts.py +50 -0
  29. ursa/prompt_library/hypothesizer_prompts.py +17 -0
  30. ursa/prompt_library/literature_prompts.py +11 -0
  31. ursa/prompt_library/optimization_prompts.py +131 -0
  32. ursa/prompt_library/planning_prompts.py +79 -0
  33. ursa/prompt_library/websearch_prompts.py +131 -0
  34. ursa/tools/__init__.py +0 -0
  35. ursa/tools/feasibility_checker.py +114 -0
  36. ursa/tools/feasibility_tools.py +1075 -0
  37. ursa/tools/run_command.py +27 -0
  38. ursa/tools/write_code.py +42 -0
  39. ursa/util/__init__.py +0 -0
  40. ursa/util/diff_renderer.py +128 -0
  41. ursa/util/helperFunctions.py +142 -0
  42. ursa/util/logo_generator.py +625 -0
  43. ursa/util/memory_logger.py +183 -0
  44. ursa/util/optimization_schema.py +78 -0
  45. ursa/util/parse.py +405 -0
  46. ursa_ai-0.9.1.dist-info/METADATA +304 -0
  47. ursa_ai-0.9.1.dist-info/RECORD +51 -0
  48. ursa_ai-0.9.1.dist-info/WHEEL +5 -0
  49. ursa_ai-0.9.1.dist-info/entry_points.txt +2 -0
  50. ursa_ai-0.9.1.dist-info/licenses/LICENSE +8 -0
  51. ursa_ai-0.9.1.dist-info/top_level.txt +1 -0
ursa/agents/lammps_agent.py
@@ -0,0 +1,465 @@
+ import json
+ import os
+ import subprocess
+ from typing import Any, Mapping, Optional, TypedDict
+
+ import tiktoken
+ from langchain.chat_models import BaseChatModel
+ from langchain_core.output_parsers import StrOutputParser
+ from langchain_core.prompts import ChatPromptTemplate
+ from langgraph.graph import END, StateGraph
+
+ from .base import BaseAgent
+
+ working = True
+ try:
+     import atomman as am
+     import trafilatura
+ except Exception:
+     working = False
+
+
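+ # Graph state: simulation_task/elements/template drive the workflow;
+ # matches, idx, and summaries track the potential-selection loop; the
+ # run_* fields and fix_attempts capture each LAMMPS execution attempt.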
+ class LammpsState(TypedDict, total=False):
+     simulation_task: str
+     elements: list[str]
+     template: Optional[str]
+     chosen_potential: Optional[Any]
+
+     matches: list[Any]
+     idx: int
+     summaries: list[str]
+     full_texts: list[str]
+     summaries_combined: str
+
+     input_script: str
+     run_returncode: Optional[int]
+     run_stdout: str
+     run_stderr: str
+
+     fix_attempts: int
+
+
+ class LammpsAgent(BaseAgent):
+     def __init__(
+         self,
+         llm: BaseChatModel,
+         max_potentials: int = 5,
+         max_fix_attempts: int = 10,
+         find_potential_only: bool = False,
+         mpi_procs: int = 8,
+         workspace: str = "./workspace",
+         lammps_cmd: str = "lmp_mpi",
+         mpirun_cmd: str = "mpirun",
+         tiktoken_model: str = "gpt-5-mini",
+         max_tokens: int = 200000,
+         **kwargs,
+     ):
+         if not working:
+             raise ImportError(
+                 "LAMMPS agent requires the atomman and trafilatura dependencies. These can be installed using 'pip install ursa-ai[lammps]' or, if working from a local installation, 'pip install -e .[lammps]' ."
+             )
+         self.max_potentials = max_potentials
+         self.max_fix_attempts = max_fix_attempts
+         self.find_potential_only = find_potential_only
+         self.mpi_procs = mpi_procs
+         self.lammps_cmd = lammps_cmd
+         self.mpirun_cmd = mpirun_cmd
+         self.tiktoken_model = tiktoken_model
+         self.max_tokens = max_tokens
+
+         self.pair_styles = [
+             "eam",
+             "eam/alloy",
+             "eam/fs",
+             "meam",
+             "adp",
+             "kim",
+             "snap",
+             "quip",
+             "mlip",
+             "pace",
+             "nep",
+         ]
+
+         self.workspace = workspace
+         os.makedirs(self.workspace, exist_ok=True)
+
+         super().__init__(llm, **kwargs)
+
+         self.str_parser = StrOutputParser()
+
+         self.summ_chain = (
+             ChatPromptTemplate.from_template(
+                 "Here is some data about an interatomic potential: {metadata}\n\n"
+                 "Briefly summarize why it could be useful for this task: {simulation_task}."
+             )
+             | self.llm
+             | self.str_parser
+         )
+
+         self.choose_chain = (
+             ChatPromptTemplate.from_template(
+                 "Here are the summaries of a certain number of interatomic potentials: {summaries_combined}\n\n"
+                 "Pick one potential which would be most useful for this task: {simulation_task}.\n\n"
+                 "Return your answer **only** as valid JSON, with no extra text or formatting.\n\n"
+                 "Use this exact schema:\n"
+                 "{{\n"
+                 '    "Chosen index": <int>,\n'
+                 '    "rationale": "<string>",\n'
+                 '    "Potential name": "<string>"\n'
+                 "}}\n"
+             )
+             | self.llm
+             | self.str_parser
+         )
+
+         self.author_chain = (
+             ChatPromptTemplate.from_template(
+                 "Your task is to write a LAMMPS input file for this purpose: {simulation_task}.\n"
+                 "Note that all potential files are in the './' directory.\n"
+                 "Here is some information about the pair_style and pair_coeff that might be useful in writing the input file: {pair_info}.\n"
+                 "If a template for the input file is provided, you should adapt it appropriately to meet the task requirements.\n"
+                 "Template provided (if any): {template}\n"
+                 "Ensure that all output data is written only to the './log.lammps' file. Do not create any other output file.\n"
+                 "To create the log, use only the 'log ./log.lammps' command. Do not use any other command like 'echo' or 'screen'.\n"
+                 "Return your answer **only** as valid JSON, with no extra text or formatting.\n"
+                 "Use this exact schema:\n"
+                 "{{\n"
+                 '    "input_script": "<string>"\n'
+                 "}}\n"
+             )
+             | self.llm
+             | self.str_parser
+         )
+
+         self.fix_chain = (
+             ChatPromptTemplate.from_template(
+                 "You are part of a larger scientific workflow whose purpose is to accomplish this task: {simulation_task}\n"
+                 "For this purpose, this input file for LAMMPS was written: {input_script}\n"
+                 "However, when running the simulation, an error was raised.\n"
+                 "Here is the full stdout message that includes the error message: {err_message}\n"
+                 "Your task is to write a new input file that resolves the error.\n"
+                 "Note that all potential files are in the './' directory.\n"
+                 "Here is some information about the pair_style and pair_coeff that might be useful in writing the input file: {pair_info}.\n"
+                 "If a template for the input file is provided, you should adapt it appropriately to meet the task requirements.\n"
+                 "Template provided (if any): {template}\n"
+                 "Ensure that all output data is written only to the './log.lammps' file. Do not create any other output file.\n"
+                 "To create the log, use only the 'log ./log.lammps' command. Do not use any other command like 'echo' or 'screen'.\n"
+                 "Return your answer **only** as valid JSON, with no extra text or formatting.\n"
+                 "Use this exact schema:\n"
+                 "{{\n"
+                 '    "input_script": "<string>"\n'
+                 "}}\n"
+             )
+             | self.llm
+             | self.str_parser
+         )
+
+         self._action = self._build_graph()
+
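+     # LLM replies are requested as bare JSON, but models sometimes wrap them
+     # in a markdown code fence; strip any fence before parsing.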
+     @staticmethod
+     def _safe_json_loads(s: str) -> dict[str, Any]:
+         s = s.strip()
+         if s.startswith("```"):
+             s = s.strip("`")
+             i = s.find("\n")
+             if i != -1:
+                 s = s[i + 1 :].strip()
+         return json.loads(s)
+
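+     # Fetch a potential's reference page with trafilatura and, when tiktoken
+     # recognizes the model, truncate it to max_tokens so the summarization
+     # prompt stays within the context window.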
+     def _fetch_and_trim_text(self, url: str) -> str:
+         downloaded = trafilatura.fetch_url(url)
+         if not downloaded:
+             return "No metadata available"
+         text = trafilatura.extract(
+             downloaded,
+             include_comments=False,
+             include_tables=True,
+             include_links=False,
+             favor_recall=True,
+         )
+         if not text:
+             return "No metadata available"
+         text = text.strip()
+         try:
+             enc = tiktoken.encoding_for_model(self.tiktoken_model)
+             toks = enc.encode(text)
+             if len(toks) > self.max_tokens:
+                 toks = toks[: self.max_tokens]
+                 text = enc.decode(toks)
+         except Exception:
+             pass
+         return text
+
+     def _entry_router(self, state: LammpsState) -> dict:
+         if self.find_potential_only and state.get("chosen_potential"):
+             raise Exception(
+                 "You cannot set find_potential_only=True and also specify your own potential!"
+             )
+
+         if not state.get("chosen_potential"):
+             self.potential_summaries_dir = os.path.join(
+                 self.workspace, "potential_summaries"
+             )
+             os.makedirs(self.potential_summaries_dir, exist_ok=True)
+         return {}
+
+     def _find_potentials(self, state: LammpsState) -> LammpsState:
+         db = am.library.Database(remote=True)
+         matches = db.get_lammps_potentials(
+             pair_style=self.pair_styles, elements=state["elements"]
+         )
+
+         return {
+             **state,
+             "matches": list(matches),
+             "idx": 0,
+             "summaries": [],
+             "full_texts": [],
+             "fix_attempts": 0,
+         }
+
+     def _should_summarize(self, state: LammpsState) -> str:
+         matches = state.get("matches", [])
+         i = state.get("idx", 0)
+         if not matches:
+             print("No potentials found in NIST for this task. Exiting....")
+             return "done_no_matches"
+         if i < min(self.max_potentials, len(matches)):
+             return "summarize_one"
+         return "summarize_done"
+
+     def _summarize_one(self, state: LammpsState) -> LammpsState:
+         i = state["idx"]
+         print(f"Summarizing potential #{i}")
+         match = state["matches"][i]
+         md = match.metadata()
+
+         if md.get("comments") is None:
+             text = "No metadata available"
+             summary = "No summary available"
+         else:
+             lines = md["comments"].split("\n")
+             url = lines[1] if len(lines) > 1 else ""
+             text = (
+                 self._fetch_and_trim_text(url)
+                 if url
+                 else "No metadata available"
+             )
+             summary = self.summ_chain.invoke({
+                 "metadata": text,
+                 "simulation_task": state["simulation_task"],
+             })
+
+         summary_file = os.path.join(
+             self.potential_summaries_dir, "potential_" + str(i) + ".txt"
+         )
+         with open(summary_file, "w") as f:
+             f.write(summary)
+
+         return {
+             **state,
+             "idx": i + 1,
+             "summaries": [*state["summaries"], summary],
+             "full_texts": [*state["full_texts"], text],
+         }
+
+     def _build_summaries(self, state: LammpsState) -> LammpsState:
+         parts = []
+         for i, s in enumerate(state["summaries"]):
+             rec = state["matches"][i]
+             parts.append(f"\nSummary of potential #{i}: {rec.id}\n{s}\n")
+         return {**state, "summaries_combined": "".join(parts)}
+
+     def _choose(self, state: LammpsState) -> LammpsState:
+         print("Choosing one potential for this task...")
+         choice = self.choose_chain.invoke({
+             "summaries_combined": state["summaries_combined"],
+             "simulation_task": state["simulation_task"],
+         })
+         choice_dict = self._safe_json_loads(choice)
+         chosen_index = int(choice_dict["Chosen index"])
+
+         print(f"Chosen potential #{chosen_index}")
+         print("Rationale for choosing this potential:")
+         print(choice_dict["rationale"])
+
+         chosen_potential = state["matches"][chosen_index]
+
+         out_file = os.path.join(self.potential_summaries_dir, "Rationale.txt")
+         with open(out_file, "w") as f:
+             f.write(f"Chosen potential #{chosen_index}")
+             f.write("\n")
+             f.write("Rationale for choosing this potential:")
+             f.write("\n")
+             f.write(choice_dict["rationale"])
+
+         return {**state, "chosen_potential": chosen_potential}
+
+     def _route_after_summarization(self, state: LammpsState) -> str:
+         if self.find_potential_only:
+             return "Exit"
+         return "continue_author"
+
+     def _author(self, state: LammpsState) -> LammpsState:
+         print("First attempt at writing LAMMPS input file....")
+         state["chosen_potential"].download_files(self.workspace)
+         pair_info = state["chosen_potential"].pair_info()
+         authored_json = self.author_chain.invoke({
+             "simulation_task": state["simulation_task"],
+             "pair_info": pair_info,
+             "template": state["template"],
+         })
+         script_dict = self._safe_json_loads(authored_json)
+         input_script = script_dict["input_script"]
+         with open(os.path.join(self.workspace, "in.lammps"), "w") as f:
+             f.write(input_script)
+         return {**state, "input_script": input_script}
+
+     def _run_lammps(self, state: LammpsState) -> LammpsState:
+         print("Running LAMMPS....")
+         result = subprocess.run(
+             [
+                 self.mpirun_cmd,
+                 "-np",
+                 str(self.mpi_procs),
+                 self.lammps_cmd,
+                 "-in",
+                 "in.lammps",
+             ],
+             cwd=self.workspace,
+             stdout=subprocess.PIPE,
+             stderr=subprocess.PIPE,
+             text=True,
+             check=False,
+         )
+         return {
+             **state,
+             "run_returncode": result.returncode,
+             "run_stdout": result.stdout,
+             "run_stderr": result.stderr,
+         }
+
+     def _route_run(self, state: LammpsState) -> str:
+         rc = state.get("run_returncode", 0)
+         attempts = state.get("fix_attempts", 0)
+         if rc == 0:
+             print("LAMMPS run successful! Exiting...")
+             return "done_success"
+         if attempts < self.max_fix_attempts:
+             print("LAMMPS run Failed. Attempting to rewrite input file...")
+             return "need_fix"
+         print("LAMMPS run Failed and maximum fix attempts reached. Exiting...")
+         return "done_failed"
+
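+     # One repair pass: feed the failing script plus the captured stdout back
+     # to the LLM, overwrite in.lammps, and count the attempt.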
+     def _fix(self, state: LammpsState) -> LammpsState:
+         pair_info = state["chosen_potential"].pair_info()
+         err_blob = state.get("run_stdout")
+
+         fixed_json = self.fix_chain.invoke({
+             "simulation_task": state["simulation_task"],
+             "input_script": state["input_script"],
+             "err_message": err_blob,
+             "pair_info": pair_info,
+             "template": state["template"],
+         })
+         script_dict = self._safe_json_loads(fixed_json)
+         new_input = script_dict["input_script"]
+         with open(os.path.join(self.workspace, "in.lammps"), "w") as f:
+             f.write(new_input)
+         return {
+             **state,
+             "input_script": new_input,
+             "fix_attempts": state.get("fix_attempts", 0) + 1,
+         }
+
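+     # Wiring: _entry_router goes straight to _author when a potential was
+     # supplied, otherwise _find_potentials -> summarize loop -> _choose ->
+     # _author; _run_lammps and _fix then cycle until success or the
+     # max_fix_attempts budget is exhausted.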
+     def _build_graph(self):
+         g = StateGraph(LammpsState)
+
+         self.add_node(g, self._entry_router)
+         self.add_node(g, self._find_potentials)
+         self.add_node(g, self._summarize_one)
+         self.add_node(g, self._build_summaries)
+         self.add_node(g, self._choose)
+         self.add_node(g, self._author)
+         self.add_node(g, self._run_lammps)
+         self.add_node(g, self._fix)
+
+         g.set_entry_point("_entry_router")
+
+         g.add_conditional_edges(
+             "_entry_router",
+             lambda state: "user_choice"
+             if state.get("chosen_potential")
+             else "agent_choice",
+             {
+                 "user_choice": "_author",
+                 "agent_choice": "_find_potentials",
+             },
+         )
+
+         g.add_conditional_edges(
+             "_find_potentials",
+             self._should_summarize,
+             {
+                 "summarize_one": "_summarize_one",
+                 "summarize_done": "_build_summaries",
+                 "done_no_matches": END,
+             },
+         )
+
+         g.add_conditional_edges(
+             "_summarize_one",
+             self._should_summarize,
+             {
+                 "summarize_one": "_summarize_one",
+                 "summarize_done": "_build_summaries",
+             },
+         )
+
+         g.add_edge("_build_summaries", "_choose")
+
+         g.add_conditional_edges(
+             "_choose",
+             self._route_after_summarization,
+             {
+                 "continue_author": "_author",
+                 "Exit": END,
+             },
+         )
+
+         g.add_edge("_author", "_run_lammps")
+
+         g.add_conditional_edges(
+             "_run_lammps",
+             self._route_run,
+             {
+                 "need_fix": "_fix",
+                 "done_success": END,
+                 "done_failed": END,
+             },
+         )
+         g.add_edge("_fix", "_run_lammps")
+         return g.compile(checkpointer=self.checkpointer)
+
+     def _invoke(
+         self,
+         inputs: Mapping[str, Any],
+         *,
+         summarize: bool | None = None,
+         recursion_limit: int = 999_999,
+         **_,
+     ) -> str:
+         config = self.build_config(
+             recursion_limit=recursion_limit, tags=["graph"]
+         )
+
+         if "simulation_task" not in inputs or "elements" not in inputs:
+             raise KeyError(
+                 "'simulation_task' and 'elements' are required arguments"
+             )
+
+         if "template" not in inputs:
+             inputs = {**inputs, "template": "No template provided."}
+
+         return self._action.invoke(inputs, config)
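
For orientation, a minimal usage sketch of this agent. It is illustrative only: it assumes a chat model created with langchain's init_chat_model (any BaseChatModel works, and the model name is a placeholder), assumes lmp_mpi and mpirun are on PATH, and mirrors the keyword-style invoke call used in mp_agent's __main__ block below.

from langchain.chat_models import init_chat_model

from ursa.agents.lammps_agent import LammpsAgent

# Illustrative model name; the LAMMPS and MPI binaries must be installed.
llm = init_chat_model("openai:gpt-4o")

agent = LammpsAgent(llm, mpi_procs=4, workspace="./workspace")
result = agent.invoke(
    simulation_task="Relax an fcc Al cell and report the equilibrium lattice constant",
    elements=["Al"],
)
print(result)

With find_potential_only=True the graph stops after writing the choice rationale, without authoring or running an input script.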
ursa/agents/mp_agent.py
@@ -0,0 +1,204 @@
+ import json
+ import os
+ import re
+ from concurrent.futures import ThreadPoolExecutor
+ from typing import Any, Mapping, TypedDict
+
+ from langchain.chat_models import BaseChatModel
+ from langchain_core.output_parsers import StrOutputParser
+ from langchain_core.prompts import ChatPromptTemplate
+ from langgraph.graph import StateGraph
+ from mp_api.client import MPRester
+ from tqdm import tqdm
+
+ from .base import BaseAgent
+
+
+ class PaperMetadata(TypedDict):
+     arxiv_id: str
+     full_text: str
+
+
+ class PaperState(TypedDict, total=False):
+     query: str
+     context: str
+     papers: list[PaperMetadata]
+     summaries: list[str]
+     final_summary: str
+
+
+ def remove_surrogates(text: str) -> str:
+     return re.sub(r"[\ud800-\udfff]", "", text)
+
+
+ class MaterialsProjectAgent(BaseAgent):
+     def __init__(
+         self,
+         llm: BaseChatModel,
+         summarize: bool = True,
+         max_results: int = 3,
+         database_path: str = "mp_database",
+         summaries_path: str = "mp_summaries",
+         **kwargs,
+     ):
+         super().__init__(llm, **kwargs)
+         self.summarize = summarize
+         self.max_results = max_results
+         self.database_path = database_path
+         self.summaries_path = summaries_path
+
+         os.makedirs(self.database_path, exist_ok=True)
+         os.makedirs(self.summaries_path, exist_ok=True)
+
+         self._action = self._build_graph()
+
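+     # Query the Materials Project REST API; MPRester() with no arguments
+     # picks up the API key from the environment (typically MP_API_KEY).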
+     def _fetch_node(self, state: dict) -> dict:
+         f = state["query"]
+         els = f["elements"]  # e.g. ["Ga","In"]
+         bg = (f["band_gap_min"], f["band_gap_max"])
+         e_above_hull = (0, 0)  # only on-hull (stable)
+         mats = []
+         with MPRester() as mpr:
+             # get ALL matching materials…
+             all_results = mpr.materials.summary.search(
+                 elements=els,
+                 band_gap=bg,
+                 energy_above_hull=e_above_hull,
+                 is_stable=True,  # equivalent filter
+             )
+             # …then take only the first `max_results`
+             for doc in all_results[: self.max_results]:
+                 mid = doc.material_id
+                 data = doc.dict()
+                 # cache to disk
+                 path = os.path.join(self.database_path, f"{mid}.json")
+                 if not os.path.exists(path):
+                     with open(path, "w") as f:
+                         json.dump(data, f, indent=2)
+                 mats.append({"material_id": mid, "metadata": data})
+
+         return {**state, "materials": mats}
+
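+     # Summaries are generated in parallel (up to 8 worker threads) and
+     # cached on disk per material id, so repeat queries skip the LLM call.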
+     def _summarize_node(self, state: dict) -> dict:
+         """Summarize each material via LLM over its metadata."""
+         # prompt template
+         prompt = ChatPromptTemplate.from_template("""
+ You are a materials-science assistant. Given the following metadata about a material, produce a concise summary focusing on its key properties:
+
+ {metadata}
+ """)
+         chain = prompt | self.llm | StrOutputParser()
+
+         summaries = [None] * len(state["materials"])
+
+         def process(i, mat):
+             mid = mat["material_id"]
+             meta = mat["metadata"]
+             # flatten metadata to text
+             text = "\n".join(f"{k}: {v}" for k, v in meta.items())
+             # build or load summary
+             summary_file = os.path.join(
+                 self.summaries_path, f"{mid}_summary.txt"
+             )
+             if os.path.exists(summary_file):
+                 with open(summary_file) as f:
+                     return i, f.read()
+             # optional: vectorize & retrieve, but here we just summarize full text
+             result = chain.invoke({"metadata": text})
+             with open(summary_file, "w") as f:
+                 f.write(result)
+             return i, result
+
+         with ThreadPoolExecutor(
+             max_workers=min(8, len(state["materials"]))
+         ) as exe:
+             futures = [
+                 exe.submit(process, i, m)
+                 for i, m in enumerate(state["materials"])
+             ]
+             for future in tqdm(futures, desc="Summarizing materials"):
+                 i, summ = future.result()
+                 summaries[i] = summ
+
+         return {**state, "summaries": summaries}
+
+     def _aggregate_node(self, state: dict) -> dict:
+         """Combine all summaries into a single, coherent answer."""
+         combined = "\n\n----\n\n".join(
+             f"[{i + 1}] {m['material_id']}\n\n{summary}"
+             for i, (m, summary) in enumerate(
+                 zip(state["materials"], state["summaries"])
+             )
+         )
+
+         prompt = ChatPromptTemplate.from_template("""
+ You are a materials informatics assistant. Below are brief summaries of several materials:
+
+ {summaries}
+
+ Answer the user’s question in context:
+
+ {context}
+ """)
+         chain = prompt | self.llm | StrOutputParser()
+         final = chain.invoke({
+             "summaries": combined,
+             "context": state["context"],
+         })
+         return {**state, "final_summary": final}
+
+     def _build_graph(self):
+         graph = StateGraph(dict)  # using plain dict for state
+         self.add_node(graph, self._fetch_node)
+         if self.summarize:
+             self.add_node(graph, self._summarize_node)
+             self.add_node(graph, self._aggregate_node)
+
+             graph.set_entry_point("_fetch_node")
+             graph.add_edge("_fetch_node", "_summarize_node")
+             graph.add_edge("_summarize_node", "_aggregate_node")
+             graph.set_finish_point("_aggregate_node")
+         else:
+             graph.set_entry_point("_fetch_node")
+             graph.set_finish_point("_fetch_node")
+         return graph.compile(checkpointer=self.checkpointer)
+
+     def _invoke(
+         self,
+         inputs: Mapping[str, Any],
+         *,
+         summarize: bool | None = None,
+         recursion_limit: int = 1000,
+         **_,
+     ) -> str:
+         config = self.build_config(
+             recursion_limit=recursion_limit, tags=["graph"]
+         )
+
+         if "query" not in inputs:
+             if "mp_query" in inputs:
+                 # make a shallow copy and rename the key
+                 inputs = dict(inputs)
+                 inputs["query"] = inputs.pop("mp_query")
+             else:
+                 raise KeyError(
+                     "Missing 'query' in inputs (alias 'mp_query' also accepted)."
+                 )
+
+         result = self._action.invoke(inputs, config)
+
+         use_summary = self.summarize if summarize is None else summarize
+         return (
+             result.get("final_summary", "No summary generated.")
+             if use_summary
+             else "\n\nFinished Fetching Materials Database Information!"
+         )
+
+
+ if __name__ == "__main__":
+     agent = MaterialsProjectAgent()
+     resp = agent.invoke(
+         mp_query="LiFePO4",
+         context="What is its band gap and stability, and any synthesis challenges?",
+     )
+     print(resp)
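
A matching sketch for the Materials Project agent. Two points grounded in the code above: the constructor requires an llm (the __main__ block omits it), and _fetch_node reads query["elements"], query["band_gap_min"], and query["band_gap_max"], so the query is a filter dict rather than the bare formula string shown in __main__. The model name and filter values here are illustrative, and MPRester needs a Materials Project API key (typically via the MP_API_KEY environment variable).

from langchain.chat_models import init_chat_model

from ursa.agents.mp_agent import MaterialsProjectAgent

# Illustrative model name; any BaseChatModel works.
llm = init_chat_model("openai:gpt-4o")

agent = MaterialsProjectAgent(llm, max_results=3)
resp = agent.invoke(
    mp_query={  # _fetch_node expects these three filter fields
        "elements": ["Ga", "In"],
        "band_gap_min": 1.0,
        "band_gap_max": 3.0,
    },
    context="Which of these candidates is the most promising wide-band-gap semiconductor?",
)
print(resp)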