ursa-ai 0.4.2__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff shows the differences between two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ursa-ai might be problematic. Click here for more details.

@@ -3,7 +3,7 @@ import os
3
3
  # from langchain_core.runnables.graph import MermaidDrawMethod
4
4
  import subprocess
5
5
  from pathlib import Path
6
- from typing import Annotated, Any, Literal, Optional
6
+ from typing import Annotated, Any, Literal, Mapping, Optional
7
7
 
8
8
  import randomname
9
9
  from langchain_community.tools import (
@@ -17,11 +17,11 @@ from langchain_core.messages import (
17
17
  ToolMessage,
18
18
  )
19
19
  from langchain_core.tools import InjectedToolCallId, tool
20
- from langgraph.graph import END, START, StateGraph
20
+ from langgraph.graph import StateGraph
21
21
  from langgraph.graph.message import add_messages
22
22
  from langgraph.prebuilt import InjectedState, ToolNode
23
23
  from langgraph.types import Command
24
- from litellm import ContentPolicyViolationError
24
+ from litellm.exceptions import ContentPolicyViolationError
25
25
 
26
26
  # Rich
27
27
  from rich import get_console
@@ -29,7 +29,11 @@ from rich.panel import Panel
29
29
  from rich.syntax import Syntax
30
30
  from typing_extensions import TypedDict
31
31
 
32
- from ..prompt_library.execution_prompts import executor_prompt, summarize_prompt
32
+ from ..prompt_library.execution_prompts import (
33
+ executor_prompt,
34
+ safety_prompt,
35
+ summarize_prompt,
36
+ )
33
37
  from ..util.diff_renderer import DiffRenderer
34
38
  from ..util.memory_logger import AgentMemory
35
39
  from .base import BaseAgent
@@ -62,6 +66,7 @@ class ExecutionAgent(BaseAgent):
62
66
  ):
63
67
  super().__init__(llm, **kwargs)
64
68
  self.agent_memory = agent_memory
69
+ self.safety_prompt = safety_prompt
65
70
  self.executor_prompt = executor_prompt
66
71
  self.summarize_prompt = summarize_prompt
67
72
  self.tools = [run_cmd, write_code, edit_code, search_tool]
@@ -69,7 +74,7 @@ class ExecutionAgent(BaseAgent):
69
74
  self.llm = self.llm.bind_tools(self.tools)
70
75
  self.log_state = log_state
71
76
 
72
- self._initialize_agent()
77
+ self._action = self._build_graph()
73
78
 
74
79
  # Define the function that calls the model
75
80
  def query_executor(self, state: ExecutionState) -> ExecutionState:
@@ -115,8 +120,7 @@ class ExecutionAgent(BaseAgent):
115
120
  ] + state["messages"]
116
121
  try:
117
122
  response = self.llm.invoke(
118
- new_state["messages"],
119
- {"configurable": {"thread_id": self.thread_id}},
123
+ new_state["messages"], self.build_config(tags=["agent"])
120
124
  )
121
125
  except ContentPolicyViolationError as e:
122
126
  print("Error: ", e, " ", new_state["messages"][-1].content)
@@ -129,7 +133,7 @@ class ExecutionAgent(BaseAgent):
129
133
  messages = [SystemMessage(content=summarize_prompt)] + state["messages"]
130
134
  try:
131
135
  response = self.llm.invoke(
132
- messages, {"configurable": {"thread_id": self.thread_id}}
136
+ messages, self.build_config(tags=["summarize"])
133
137
  )
134
138
  except ContentPolicyViolationError as e:
135
139
  print("Error: ", e, " ", messages[-1].content)
@@ -181,11 +185,8 @@ class ExecutionAgent(BaseAgent):
181
185
  if call_name == "run_cmd":
182
186
  query = tool_call["args"]["query"]
183
187
  safety_check = self.llm.invoke(
184
- (
185
- "Assume commands to run/install python and Julia files are safe because "
186
- "the files are from a trusted source. "
187
- f"Explain why, followed by an answer [YES] or [NO]. Is this command safe to run: {query}"
188
- )
188
+ self.safety_prompt + query,
189
+ self.build_config(tags=["safety_check"]),
189
190
  )
190
191
 
191
192
  if "[NO]" in safety_check.content:
@@ -222,53 +223,90 @@ class ExecutionAgent(BaseAgent):
222
223
 
223
224
  return new_state
224
225
 
225
- def _initialize_agent(self):
226
- self.graph = StateGraph(ExecutionState)
226
+ def _build_graph(self):
227
+ graph = StateGraph(ExecutionState)
227
228
 
228
- self.graph.add_node("agent", self.query_executor)
229
- self.graph.add_node("action", self.tool_node)
230
- self.graph.add_node("summarize", self.summarize)
231
- self.graph.add_node("safety_check", self.safety_check)
229
+ self.add_node(graph, self.query_executor, "agent")
230
+ self.add_node(graph, self.tool_node, "action")
231
+ self.add_node(graph, self.summarize, "summarize")
232
+ self.add_node(graph, self.safety_check, "safety_check")
232
233
 
233
234
  # Set the entrypoint as `agent`
234
235
  # This means that this node is the first one called
235
- self.graph.add_edge(START, "agent")
236
+ graph.set_entry_point("agent")
236
237
 
237
- self.graph.add_conditional_edges(
238
+ graph.add_conditional_edges(
238
239
  "agent",
239
- should_continue,
240
- {
241
- "continue": "safety_check",
242
- "summarize": "summarize",
243
- },
240
+ self._wrap_cond(should_continue, "should_continue", "execution"),
241
+ {"continue": "safety_check", "summarize": "summarize"},
244
242
  )
245
243
 
246
- self.graph.add_conditional_edges(
244
+ graph.add_conditional_edges(
247
245
  "safety_check",
248
- command_safe,
249
- {
250
- "safe": "action",
251
- "unsafe": "agent",
252
- },
246
+ self._wrap_cond(command_safe, "command_safe", "execution"),
247
+ {"safe": "action", "unsafe": "agent"},
253
248
  )
254
249
 
255
- self.graph.add_edge("action", "agent")
256
- self.graph.add_edge("summarize", END)
250
+ graph.add_edge("action", "agent")
251
+ graph.set_finish_point("summarize")
257
252
 
258
- self.action = self.graph.compile(checkpointer=self.checkpointer)
253
+ return graph.compile(checkpointer=self.checkpointer)
259
254
  # self.action.get_graph().draw_mermaid_png(output_file_path="execution_agent_graph.png", draw_method=MermaidDrawMethod.PYPPETEER)
260
255
 
261
- def run(self, prompt, recursion_limit=1000):
262
- inputs = {"messages": [HumanMessage(content=prompt)]}
263
- return self.action.invoke(
264
- inputs,
265
- {
266
- "recursion_limit": recursion_limit,
267
- "configurable": {"thread_id": self.thread_id},
268
- },
256
+ def _invoke(
257
+ self, inputs: Mapping[str, Any], recursion_limit: int = 999_999, **_
258
+ ):
259
+ config = self.build_config(
260
+ recursion_limit=recursion_limit, tags=["graph"]
261
+ )
262
+ return self._action.invoke(inputs, config)
263
+
264
+ # this is trying to stop people bypassing invoke
265
+ @property
266
+ def action(self):
267
+ raise AttributeError(
268
+ "Use .stream(...) or .invoke(...); direct .action access is unsupported."
269
269
  )
270
270
 
271
271
 
272
+ def _snip_text(text: str, max_chars: int) -> tuple[str, bool]:
273
+ if text is None:
274
+ return "", False
275
+ if max_chars <= 0:
276
+ return "", len(text) > 0
277
+ if len(text) <= max_chars:
278
+ return text, False
279
+ head = max_chars // 2
280
+ tail = max_chars - head
281
+ return (
282
+ text[:head]
283
+ + f"\n... [snipped {len(text) - max_chars} chars] ...\n"
284
+ + text[-tail:],
285
+ True,
286
+ )
287
+
288
+
289
+ def _fit_streams_to_budget(stdout: str, stderr: str, total_budget: int):
290
+ label_overhead = len("STDOUT:\n") + len("\nSTDERR:\n")
291
+ budget = max(0, total_budget - label_overhead)
292
+
293
+ if len(stdout) + len(stderr) <= budget:
294
+ return stdout, stderr
295
+
296
+ total_len = max(1, len(stdout) + len(stderr))
297
+ stdout_budget = int(budget * (len(stdout) / total_len))
298
+ stderr_budget = budget - stdout_budget
299
+
300
+ stdout_snip, _ = _snip_text(stdout, stdout_budget)
301
+ stderr_snip, _ = _snip_text(stderr, stderr_budget)
302
+ return stdout_snip, stderr_snip
303
+
304
+
305
+ # the idea here is that we just set a limit - the user could overload
306
+ # that in their env, or maybe we could pull this out of the LLM parameters
307
+ MAX_TOOL_MSG_CHARS = int(os.getenv("MAX_TOOL_MSG_CHARS", "50000"))
308
+
309
+
272
310
  @tool
273
311
  def run_cmd(query: str, state: Annotated[dict, InjectedState]) -> str:
274
312
  """
@@ -293,10 +331,15 @@ def run_cmd(query: str, state: Annotated[dict, InjectedState]) -> str:
293
331
  print("Keyboard Interrupt of command: ", query)
294
332
  stdout, stderr = "", "KeyboardInterrupt:"
295
333
 
296
- print("STDOUT: ", stdout)
297
- print("STDERR: ", stderr)
334
+ # Fit BOTH streams under a single overall cap
335
+ stdout_fit, stderr_fit = _fit_streams_to_budget(
336
+ stdout or "", stderr or "", MAX_TOOL_MSG_CHARS
337
+ )
338
+
339
+ print("STDOUT: ", stdout_fit)
340
+ print("STDERR: ", stderr_fit)
298
341
 
299
- return f"STDOUT: {stdout} and STDERR: {stderr}"
342
+ return f"STDOUT:\n{stdout_fit}\nSTDERR:\n{stderr_fit}"
300
343
 
301
344
 
302
345
  def _strip_fences(snippet: str) -> str:
@@ -499,8 +542,9 @@ def main():
499
542
  inputs = {
500
543
  "messages": [HumanMessage(content=problem_string)]
501
544
  } # , "workspace":"dummy_test"}
502
- result = execution_agent.action.invoke(
503
- inputs, {"configurable": {"thread_id": execution_agent.thread_id}}
545
+ result = execution_agent.invoke(
546
+ inputs,
547
+ config={"configurable": {"thread_id": execution_agent.thread_id}},
504
548
  )
505
549
  print(result["messages"][-1].content)
506
550
  return result
@@ -3,12 +3,12 @@ import ast
3
3
  # from langchain_community.tools import TavilySearchResults
4
4
  # from textwrap import dedent
5
5
  from datetime import datetime
6
- from typing import List, Literal, TypedDict
6
+ from typing import Any, List, Literal, Mapping, TypedDict
7
7
 
8
8
  from langchain_community.tools import DuckDuckGoSearchResults
9
9
  from langchain_core.language_models import BaseChatModel
10
10
  from langchain_core.messages import HumanMessage, SystemMessage
11
- from langgraph.graph import END, StateGraph
11
+ from langgraph.graph import StateGraph
12
12
 
13
13
  from ..prompt_library.hypothesizer_prompts import (
14
14
  competitor_prompt,
@@ -53,7 +53,7 @@ class HypothesizerAgent(BaseAgent):
53
53
  # max_results=10, search_depth="advanced", include_answer=False
54
54
  # )
55
55
 
56
- self._initialize_agent()
56
+ self._action = self._build_graph()
57
57
 
58
58
  def agent1_generate_solution(
59
59
  self, state: HypothesizerState
@@ -444,68 +444,65 @@ class HypothesizerAgent(BaseAgent):
444
444
  )
445
445
  return new_state
446
446
 
447
- def _initialize_agent(self):
447
+ def _build_graph(self):
448
448
  # Initialize the graph
449
- self.graph = StateGraph(HypothesizerState)
449
+ graph = StateGraph(HypothesizerState)
450
450
 
451
451
  # Add nodes
452
- self.graph.add_node("agent1", self.agent1_generate_solution)
453
- self.graph.add_node("agent2", self.agent2_critique)
454
- self.graph.add_node("agent3", self.agent3_competitor_perspective)
455
- self.graph.add_node("increment_iteration", self.increment_iteration)
456
- self.graph.add_node("finalize", self.generate_solution)
457
- self.graph.add_node("print_sites", self.print_visited_sites)
458
- self.graph.add_node(
459
- "summarize_as_latex", self.summarize_process_as_latex
452
+ self.add_node(graph, self.agent1_generate_solution, "agent1")
453
+ self.add_node(graph, self.agent2_critique, "agent2")
454
+ self.add_node(graph, self.agent3_competitor_perspective, "agent3")
455
+ self.add_node(graph, self.increment_iteration, "increment_iteration")
456
+ self.add_node(graph, self.generate_solution, "finalize")
457
+ self.add_node(graph, self.print_visited_sites, "print_sites")
458
+ self.add_node(
459
+ graph, self.summarize_process_as_latex, "summarize_as_latex"
460
460
  )
461
461
  # self.graph.add_node("compile_pdf", compile_summary_to_pdf)
462
462
 
463
463
  # Add simple edges for the known flow
464
- self.graph.add_edge("agent1", "agent2")
465
- self.graph.add_edge("agent2", "agent3")
466
- self.graph.add_edge("agent3", "increment_iteration")
464
+ graph.add_edge("agent1", "agent2")
465
+ graph.add_edge("agent2", "agent3")
466
+ graph.add_edge("agent3", "increment_iteration")
467
467
 
468
468
  # Then from increment_iteration, we have a conditional:
469
469
  # If we 'continue', we go back to agent1
470
470
  # If we 'finish', we jump to the finalize node
471
- self.graph.add_conditional_edges(
471
+ graph.add_conditional_edges(
472
472
  "increment_iteration",
473
473
  should_continue,
474
474
  {"continue": "agent1", "finish": "finalize"},
475
475
  )
476
476
 
477
- self.graph.add_edge("finalize", "summarize_as_latex")
478
- self.graph.add_edge("summarize_as_latex", "print_sites")
479
- self.graph.add_edge("print_sites", END)
477
+ graph.add_edge("finalize", "summarize_as_latex")
478
+ graph.add_edge("summarize_as_latex", "print_sites")
480
479
  # self.graph.add_edge("summarize_as_latex", "compile_pdf")
481
480
  # self.graph.add_edge("compile_pdf", "print_sites")
482
481
 
483
482
  # Set the entry point
484
- self.graph.set_entry_point("agent1")
483
+ graph.set_entry_point("agent1")
484
+ graph.set_finish_point("print_sites")
485
485
 
486
- self.action = self.graph.compile(checkpointer=self.checkpointer)
486
+ return graph.compile(checkpointer=self.checkpointer)
487
487
  # self.action.get_graph().draw_mermaid_png(output_file_path="hypothesizer_agent_graph.png", draw_method=MermaidDrawMethod.PYPPETEER)
488
488
 
489
- def run(self, prompt, max_iter=3, recursion_limit=99999):
490
- # Initialize the state
491
- initial_state = HypothesizerState(
492
- question=prompt,
493
- current_iteration=0,
494
- max_iterations=max_iter,
495
- agent1_solution=[],
496
- agent2_critiques=[],
497
- agent3_perspectives=[],
498
- solution="",
489
+ def _invoke(
490
+ self, inputs: Mapping[str, Any], recursion_limit: int = 100000, **_
491
+ ):
492
+ config = self.build_config(
493
+ recursion_limit=recursion_limit, tags=["graph"]
499
494
  )
500
- # Run the graph
501
- result = self.action.invoke(
502
- initial_state,
503
- {
504
- "recursion_limit": recursion_limit,
505
- "configurable": {"thread_id": self.thread_id},
506
- },
507
- )
508
- return result["solution"]
495
+ if "prompt" not in inputs:
496
+ raise KeyError("'prompt' is a required arguments")
497
+
498
+ inputs["max_iterations"] = inputs.get("max_iterations", 3)
499
+ inputs["current_iteration"] = 0
500
+ inputs["agent1_solution"] = []
501
+ inputs["agent2_critiques"] = []
502
+ inputs["agent3_perspectives"] = []
503
+ inputs["solution"] = ""
504
+
505
+ return self._action.invoke(inputs, config)
509
506
 
510
507
 
511
508
  def should_continue(state: HypothesizerState) -> Literal["continue", "finish"]:
@@ -583,7 +580,7 @@ if __name__ == "__main__":
583
580
 
584
581
  print("[DEBUG] Invoking the graph...")
585
582
  # Run the graph
586
- result = hypothesizer_agent.action.invoke(
583
+ result = hypothesizer_agent.invoke(
587
584
  initial_state,
588
585
  {
589
586
  "recursion_limit": 999999,
@@ -1,17 +1,22 @@
1
1
  import json
2
2
  import os
3
3
  import subprocess
4
- from typing import Any, Dict, List, Optional, TypedDict
4
+ from typing import Any, Dict, List, Mapping, Optional, TypedDict
5
5
 
6
- import atomman as am
7
6
  import tiktoken
8
- import trafilatura
9
7
  from langchain_core.output_parsers import StrOutputParser
10
8
  from langchain_core.prompts import ChatPromptTemplate
11
9
  from langgraph.graph import END, StateGraph
12
10
 
13
11
  from .base import BaseAgent
14
12
 
13
+ working = True
14
+ try:
15
+ import atomman as am
16
+ import trafilatura
17
+ except Exception:
18
+ working = False
19
+
15
20
 
16
21
  class LammpsState(TypedDict, total=False):
17
22
  simulation_task: str
@@ -50,6 +55,10 @@ class LammpsAgent(BaseAgent):
50
55
  max_tokens: int = 200000,
51
56
  **kwargs,
52
57
  ):
58
+ if not working:
59
+ raise ImportError(
60
+ "LAMMPS agent requires the atomman and trafilatura dependencies. These can be installed using 'pip install ursa-ai[lammps]' or, if working from a local installation, 'pip install -e .[lammps]' ."
61
+ )
53
62
  self.max_potentials = max_potentials
54
63
  self.max_fix_attempts = max_fix_attempts
55
64
  self.mpi_procs = mpi_procs
@@ -144,7 +153,7 @@ class LammpsAgent(BaseAgent):
144
153
  | self.str_parser
145
154
  )
146
155
 
147
- self.graph = self._build_graph().compile()
156
+ self._action = self._build_graph()
148
157
 
149
158
  @staticmethod
150
159
  def _safe_json_loads(s: str) -> Dict[str, Any]:
@@ -340,53 +349,66 @@ class LammpsAgent(BaseAgent):
340
349
  def _build_graph(self):
341
350
  g = StateGraph(LammpsState)
342
351
 
343
- g.add_node("find_potentials", self._find_potentials)
344
- g.add_node("summarize_one", self._summarize_one)
345
- g.add_node("build_summaries", self._build_summaries)
346
- g.add_node("choose", self._choose)
347
- g.add_node("author", self._author)
348
- g.add_node("run_lammps", self._run_lammps)
349
- g.add_node("fix", self._fix)
352
+ self.add_node(g, self._find_potentials)
353
+ self.add_node(g, self._summarize_one)
354
+ self.add_node(g, self._build_summaries)
355
+ self.add_node(g, self._choose)
356
+ self.add_node(g, self._author)
357
+ self.add_node(g, self._run_lammps)
358
+ self.add_node(g, self._fix)
350
359
 
351
- g.set_entry_point("find_potentials")
360
+ g.set_entry_point("_find_potentials")
352
361
 
353
362
  g.add_conditional_edges(
354
- "find_potentials",
363
+ "_find_potentials",
355
364
  self._should_summarize,
356
365
  {
357
- "summarize_one": "summarize_one",
358
- "summarize_done": "build_summaries",
366
+ "summarize_one": "_summarize_one",
367
+ "summarize_done": "_build_summaries",
359
368
  "done_no_matches": END,
360
369
  },
361
370
  )
362
371
 
363
372
  g.add_conditional_edges(
364
- "summarize_one",
373
+ "_summarize_one",
365
374
  self._should_summarize,
366
375
  {
367
- "summarize_one": "summarize_one",
368
- "summarize_done": "build_summaries",
376
+ "summarize_one": "_summarize_one",
377
+ "summarize_done": "_build_summaries",
369
378
  },
370
379
  )
371
380
 
372
- g.add_edge("build_summaries", "choose")
373
- g.add_edge("choose", "author")
374
- g.add_edge("author", "run_lammps")
381
+ g.add_edge("_build_summaries", "_choose")
382
+ g.add_edge("_choose", "_author")
383
+ g.add_edge("_author", "_run_lammps")
375
384
 
376
385
  g.add_conditional_edges(
377
- "run_lammps",
386
+ "_run_lammps",
378
387
  self._route_run,
379
388
  {
380
- "need_fix": "fix",
389
+ "need_fix": "_fix",
381
390
  "done_success": END,
382
391
  "done_failed": END,
383
392
  },
384
393
  )
385
- g.add_edge("fix", "run_lammps")
386
- return g
394
+ g.add_edge("_fix", "_run_lammps")
395
+ return g.compile(checkpointer=self.checkpointer)
387
396
 
388
- def run(self, simulation_task, elements):
389
- return self.graph.invoke(
390
- {"simulation_task": simulation_task, "elements": elements},
391
- {"recursion_limit": 999_999},
397
+ def _invoke(
398
+ self,
399
+ inputs: Mapping[str, Any],
400
+ *,
401
+ summarize: bool | None = None,
402
+ recursion_limit: int = 1000,
403
+ **_,
404
+ ) -> str:
405
+ config = self.build_config(
406
+ recursion_limit=recursion_limit, tags=["graph"]
392
407
  )
408
+
409
+ if "simulation_task" not in inputs or "elements" not in inputs:
410
+ raise KeyError(
411
+ "'simulation_task' and 'elements' are required arguments"
412
+ )
413
+
414
+ return self._action.invoke(inputs, config)
ursa/agents/mp_agent.py CHANGED
@@ -2,7 +2,7 @@ import json
2
2
  import os
3
3
  import re
4
4
  from concurrent.futures import ThreadPoolExecutor
5
- from typing import Dict
5
+ from typing import Any, Dict, Mapping
6
6
 
7
7
  from langchain_core.output_parsers import StrOutputParser
8
8
  from langchain_core.prompts import ChatPromptTemplate
@@ -50,7 +50,7 @@ class MaterialsProjectAgent(BaseAgent):
50
50
  os.makedirs(self.database_path, exist_ok=True)
51
51
  os.makedirs(self.summaries_path, exist_ok=True)
52
52
 
53
- self.graph = self._build_graph()
53
+ self._action = self._build_graph()
54
54
 
55
55
  def _fetch_node(self, state: Dict) -> Dict:
56
56
  f = state["query"]
@@ -148,31 +148,56 @@ You are a materials-science assistant. Given the following metadata about a mate
148
148
  return {**state, "final_summary": final}
149
149
 
150
150
  def _build_graph(self):
151
- g = StateGraph(dict) # using plain dict for state
152
- g.add_node("fetch", self._fetch_node)
151
+ graph = StateGraph(dict) # using plain dict for state
152
+ self.add_node(graph, self._fetch_node)
153
153
  if self.summarize:
154
- g.add_node("summarize", self._summarize_node)
155
- g.add_node("aggregate", self._aggregate_node)
156
- g.set_entry_point("fetch")
157
- g.add_edge("fetch", "summarize")
158
- g.add_edge("summarize", "aggregate")
159
- g.set_finish_point("aggregate")
154
+ self.add_node(graph, self._summarize_node)
155
+ self.add_node(graph, self._aggregate_node)
156
+
157
+ graph.set_entry_point("_fetch_node")
158
+ graph.add_edge("_fetch_node", "_summarize_node")
159
+ graph.add_edge("_summarize_node", "_aggregate_node")
160
+ graph.set_finish_point("_aggregate_node")
160
161
  else:
161
- g.set_entry_point("fetch")
162
- g.set_finish_point("fetch")
163
- return g.compile()
162
+ graph.set_entry_point("_fetch_node")
163
+ graph.set_finish_point("_fetch_node")
164
+ return graph.compile(checkpointer=self.checkpointer)
164
165
 
165
- def run(self, mp_query: str, context: str) -> str:
166
- state = {"query": mp_query, "context": context}
167
- out = self.graph.invoke(state)
168
- if self.summarize:
169
- return out.get("final_summary", "")
170
- return json.dumps(out.get("materials", []), indent=2)
166
+ def _invoke(
167
+ self,
168
+ inputs: Mapping[str, Any],
169
+ *,
170
+ summarize: bool | None = None,
171
+ recursion_limit: int = 1000,
172
+ **_,
173
+ ) -> str:
174
+ config = self.build_config(
175
+ recursion_limit=recursion_limit, tags=["graph"]
176
+ )
177
+
178
+ if "query" not in inputs:
179
+ if "mp_query" in inputs:
180
+ # make a shallow copy and rename the key
181
+ inputs = dict(inputs)
182
+ inputs["query"] = inputs.pop("mp_query")
183
+ else:
184
+ raise KeyError(
185
+ "Missing 'query' in inputs (alias 'mp_query' also accepted)."
186
+ )
187
+
188
+ result = self._action.invoke(inputs, config)
189
+
190
+ use_summary = self.summarize if summarize is None else summarize
191
+ return (
192
+ result.get("final_summary", "No summary generated.")
193
+ if use_summary
194
+ else "\n\nFinished Fetching Materials Database Information!"
195
+ )
171
196
 
172
197
 
173
198
  if __name__ == "__main__":
174
199
  agent = MaterialsProjectAgent()
175
- resp = agent.run(
200
+ resp = agent.invoke(
176
201
  mp_query="LiFePO4",
177
202
  context="What is its band gap and stability, and any synthesis challenges?",
178
203
  )