ursa-ai 0.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ursa/__init__.py +3 -0
- ursa/agents/__init__.py +32 -0
- ursa/agents/acquisition_agents.py +812 -0
- ursa/agents/arxiv_agent.py +429 -0
- ursa/agents/base.py +728 -0
- ursa/agents/chat_agent.py +60 -0
- ursa/agents/code_review_agent.py +341 -0
- ursa/agents/execution_agent.py +915 -0
- ursa/agents/hypothesizer_agent.py +614 -0
- ursa/agents/lammps_agent.py +465 -0
- ursa/agents/mp_agent.py +204 -0
- ursa/agents/optimization_agent.py +410 -0
- ursa/agents/planning_agent.py +219 -0
- ursa/agents/rag_agent.py +304 -0
- ursa/agents/recall_agent.py +54 -0
- ursa/agents/websearch_agent.py +196 -0
- ursa/cli/__init__.py +363 -0
- ursa/cli/hitl.py +516 -0
- ursa/cli/hitl_api.py +75 -0
- ursa/observability/metrics_charts.py +1279 -0
- ursa/observability/metrics_io.py +11 -0
- ursa/observability/metrics_session.py +750 -0
- ursa/observability/pricing.json +97 -0
- ursa/observability/pricing.py +321 -0
- ursa/observability/timing.py +1466 -0
- ursa/prompt_library/__init__.py +0 -0
- ursa/prompt_library/code_review_prompts.py +51 -0
- ursa/prompt_library/execution_prompts.py +50 -0
- ursa/prompt_library/hypothesizer_prompts.py +17 -0
- ursa/prompt_library/literature_prompts.py +11 -0
- ursa/prompt_library/optimization_prompts.py +131 -0
- ursa/prompt_library/planning_prompts.py +79 -0
- ursa/prompt_library/websearch_prompts.py +131 -0
- ursa/tools/__init__.py +0 -0
- ursa/tools/feasibility_checker.py +114 -0
- ursa/tools/feasibility_tools.py +1075 -0
- ursa/tools/run_command.py +27 -0
- ursa/tools/write_code.py +42 -0
- ursa/util/__init__.py +0 -0
- ursa/util/diff_renderer.py +128 -0
- ursa/util/helperFunctions.py +142 -0
- ursa/util/logo_generator.py +625 -0
- ursa/util/memory_logger.py +183 -0
- ursa/util/optimization_schema.py +78 -0
- ursa/util/parse.py +405 -0
- ursa_ai-0.9.1.dist-info/METADATA +304 -0
- ursa_ai-0.9.1.dist-info/RECORD +51 -0
- ursa_ai-0.9.1.dist-info/WHEEL +5 -0
- ursa_ai-0.9.1.dist-info/entry_points.txt +2 -0
- ursa_ai-0.9.1.dist-info/licenses/LICENSE +8 -0
- ursa_ai-0.9.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,410 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import pprint
|
|
3
|
+
import subprocess
|
|
4
|
+
from typing import Annotated, Any, Literal, Mapping, TypedDict
|
|
5
|
+
|
|
6
|
+
from langchain.chat_models import BaseChatModel
|
|
7
|
+
from langchain_community.tools import DuckDuckGoSearchResults
|
|
8
|
+
from langchain_core.messages import HumanMessage, SystemMessage
|
|
9
|
+
from langchain_core.tools import tool
|
|
10
|
+
from langchain_openai import ChatOpenAI
|
|
11
|
+
from langgraph.graph import END, START, StateGraph
|
|
12
|
+
from langgraph.prebuilt import InjectedState
|
|
13
|
+
|
|
14
|
+
from ..prompt_library.optimization_prompts import (
|
|
15
|
+
code_generator_prompt,
|
|
16
|
+
discretizer_prompt,
|
|
17
|
+
explainer_prompt,
|
|
18
|
+
extractor_prompt,
|
|
19
|
+
feasibility_prompt,
|
|
20
|
+
math_formulator_prompt,
|
|
21
|
+
solver_selector_prompt,
|
|
22
|
+
verifier_prompt,
|
|
23
|
+
)
|
|
24
|
+
from ..tools.feasibility_tools import feasibility_check_auto as fca
|
|
25
|
+
from ..util.helperFunctions import extract_tool_calls, run_tool_calls
|
|
26
|
+
from ..util.optimization_schema import ProblemSpec, SolverSpec
|
|
27
|
+
from .base import BaseAgent
|
|
28
|
+
|
|
29
|
+
# --- ANSI color codes ---
# Terminal SGR escape sequences: bright foreground colours, plus RESET (clear
# all attributes) and BOLD.
GREEN, BLUE, RED = "\033[92m", "\033[94m", "\033[91m"
RESET, BOLD = "\033[0m", "\033[1m"
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class OptimizerState(TypedDict):
    """Graph state passed between the OptimizationAgent nodes.

    Each node takes a shallow copy of the state, fills in the key(s) it is
    responsible for, and returns the copy.
    """

    # Raw request text. _invoke requires it (or derives it from messages[0]).
    user_input: str
    # Problem statement extracted from user_input by the extractor node.
    problem: str
    # Structured formulation from the formulator/discretizer; the verifier
    # replaces it with its own parsed output.
    problem_spec: ProblemSpec
    # Solver choice produced by the selector node.
    solver: SolverSpec
    # Generated solver code (model output text) from the generator node.
    code: str
    # Accumulated tool-call records. The tester calls .pretty_print() on each
    # entry and the explainer passes them to the model as messages, so the
    # entries are presumably LangChain message objects rather than dicts.
    problem_diagnostic: list[Any]
    # Final explanation produced by the explainer node.
    summary: str
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class OptimizationAgent(BaseAgent):
    """Formulate, code up, check and explain an optimization problem.

    The compiled LangGraph pipeline (see ``_build_graph``) runs:
    Problem Extractor -> Math Formulator -> [Discretizer] -> Solver Selector
    -> Code Generator -> Feasibility Tester -> Verifier, and then either the
    Explainer (when ``should_continue`` reports a verified spec) or back to
    the Problem Extractor.
    """

    def __init__(
        self,
        llm: BaseChatModel,
        *args,
        **kwargs,
    ):
        super().__init__(llm, *args, **kwargs)
        # Prompt texts used by the graph nodes.
        self.extractor_prompt = extractor_prompt
        self.explainer_prompt = explainer_prompt
        self.verifier_prompt = verifier_prompt
        self.code_generator_prompt = code_generator_prompt
        self.solver_selector_prompt = solver_selector_prompt
        self.math_formulator_prompt = math_formulator_prompt
        self.discretizer_prompt = discretizer_prompt
        self.feasibility_prompt = feasibility_prompt
        # Only the feasibility checker is bound to the model. run_cmd,
        # write_code and search_tool exist in this module but are not
        # registered here.
        self.tools = [fca]
        self.llm = self.llm.bind_tools(self.tools)
        # Tool name -> tool object, passed to run_tool_calls.
        self.tool_maps = {
            (getattr(t, "name", None) or getattr(t, "__name__", None)): t
            for t in self.tools
        }

        self._action = self._build_graph()

    def extractor(self, state: OptimizerState) -> OptimizerState:
        """Extract a problem statement from ``user_input`` and reset diagnostics."""
        new_state = state.copy()
        new_state["problem"] = self.llm.invoke([
            SystemMessage(content=self.extractor_prompt),
            HumanMessage(content=new_state["user_input"]),
        ]).content

        # Fresh diagnostic log. This node is also where the graph re-enters
        # after a failed verification.
        new_state["problem_diagnostic"] = []

        print("Extractor:\n")
        pprint.pprint(new_state["problem"])
        return new_state

    def formulator(self, state: OptimizerState) -> OptimizerState:
        """Turn the extracted problem text into a structured ``ProblemSpec``."""
        new_state = state.copy()

        llm_out = self.llm.with_structured_output(
            ProblemSpec, include_raw=True
        ).invoke([
            SystemMessage(content=self.math_formulator_prompt),
            HumanMessage(content=state["problem"]),
        ])
        new_state["problem_spec"] = llm_out["parsed"]
        # state.copy() is shallow: build a new list rather than extending the
        # list object still referenced by the incoming state.
        new_state["problem_diagnostic"] = [
            *state["problem_diagnostic"],
            *extract_tool_calls(llm_out["raw"]),
        ]

        print("Formulator:\n")
        pprint.pprint(new_state["problem_spec"])
        return new_state

    def discretizer(self, state: OptimizerState) -> OptimizerState:
        """Re-formulate the problem as a discretized ``ProblemSpec``.

        Only reached when ``should_discretize`` detects infinite-dimensional
        constraints or decision variables.
        """
        new_state = state.copy()

        llm_out = self.llm.with_structured_output(
            ProblemSpec, include_raw=True
        ).invoke([
            SystemMessage(content=self.discretizer_prompt),
            HumanMessage(content=state["problem"]),
        ])
        new_state["problem_spec"] = llm_out["parsed"]
        # New list instead of mutating the shared one (shallow copy above).
        new_state["problem_diagnostic"] = [
            *state["problem_diagnostic"],
            *extract_tool_calls(llm_out["raw"]),
        ]

        print("Discretizer:\n")
        pprint.pprint(new_state["problem_spec"])

        return new_state

    def tester(self, state: OptimizerState) -> OptimizerState:
        """Force a tool call on the generated code and log the tool results."""
        new_state = state.copy()

        # tool_choice="required" makes the model emit at least one tool call.
        llm_out = self.llm.bind(tool_choice="required").invoke([
            SystemMessage(content=self.feasibility_prompt),
            HumanMessage(content=str(state["code"])),
        ])

        tool_log = run_tool_calls(llm_out, self.tool_maps)
        new_state["problem_diagnostic"] = [
            *state["problem_diagnostic"],
            *tool_log,
        ]

        print("Feasibility Tester:\n")
        for msg in new_state["problem_diagnostic"]:
            msg.pretty_print()
        return new_state

    def selector(self, state: OptimizerState) -> OptimizerState:
        """Choose a solver for the current ``problem_spec``."""
        new_state = state.copy()

        llm_out = self.llm.with_structured_output(
            SolverSpec, include_raw=True
        ).invoke([
            SystemMessage(content=self.solver_selector_prompt),
            HumanMessage(content=str(state["problem_spec"])),
        ])
        new_state["solver"] = llm_out["parsed"]

        print("Selector:\n ")
        pprint.pprint(new_state["solver"])
        return new_state

    def generator(self, state: OptimizerState) -> OptimizerState:
        """Generate solver code for the current ``problem_spec``."""
        new_state = state.copy()

        new_state["code"] = self.llm.invoke([
            SystemMessage(content=self.code_generator_prompt),
            HumanMessage(content=str(state["problem_spec"])),
        ]).content

        print("Generator:\n")
        pprint.pprint(new_state["code"])
        return new_state

    def verifier(self, state: OptimizerState) -> OptimizerState:
        """Re-check the spec against the generated code.

        Stores a new ``problem_spec``. ``should_continue`` reads its status
        to decide the next node.
        """
        new_state = state.copy()

        llm_out = self.llm.with_structured_output(
            ProblemSpec, include_raw=True
        ).invoke([
            SystemMessage(content=self.verifier_prompt),
            HumanMessage(content=str(state["problem_spec"]) + state["code"]),
        ])
        new_state["problem_spec"] = llm_out["parsed"]
        # NOTE(review): with include_raw=True, ``llm_out`` is a dict
        # ({"raw", "parsed", "parsing_error"}), so this hasattr check is
        # always False and the branch never runs. Left as-is until it is
        # decided whether llm_out["raw"].tool_calls should be executed here.
        if hasattr(llm_out, "tool_calls"):
            tool_log = run_tool_calls(llm_out, self.tool_maps)
            new_state["problem_diagnostic"] = [
                *state["problem_diagnostic"],
                *tool_log,
            ]

        print("Verifier:\n ")
        pprint.pprint(new_state["problem_spec"])
        return new_state

    def explainer(self, state: OptimizerState) -> OptimizerState:
        """Summarise the problem, spec and diagnostics into ``summary``."""
        new_state = state.copy()

        new_state["summary"] = self.llm.invoke([
            SystemMessage(content=self.explainer_prompt),
            HumanMessage(content=state["problem"] + str(state["problem_spec"])),
            # Diagnostic entries are passed to the model as chat messages.
            *state["problem_diagnostic"],
        ]).content

        print("Summary:\n")
        pprint.pprint(new_state["summary"])
        return new_state

    def _build_graph(self):
        """Wire the nodes into a LangGraph ``StateGraph`` and compile it."""
        graph = StateGraph(OptimizerState)

        self.add_node(graph, self.extractor, "Problem Extractor")
        self.add_node(graph, self.formulator, "Math Formulator")
        self.add_node(graph, self.selector, "Solver Selector")
        self.add_node(graph, self.generator, "Code Generator")
        self.add_node(graph, self.verifier, "Verifier")
        self.add_node(graph, self.explainer, "Explainer")
        self.add_node(graph, self.tester, "Feasibility Tester")
        self.add_node(graph, self.discretizer, "Discretizer")

        graph.add_edge(START, "Problem Extractor")
        graph.add_edge("Problem Extractor", "Math Formulator")
        graph.add_conditional_edges(
            "Math Formulator",
            should_discretize,
            {"discretize": "Discretizer", "continue": "Solver Selector"},
        )
        graph.add_edge("Discretizer", "Solver Selector")
        graph.add_edge("Solver Selector", "Code Generator")
        graph.add_edge("Code Generator", "Feasibility Tester")
        graph.add_edge("Feasibility Tester", "Verifier")
        # A failed verification loops back to the start of the pipeline.
        graph.add_conditional_edges(
            "Verifier",
            should_continue,
            {"continue": "Explainer", "error": "Problem Extractor"},
        )
        graph.add_edge("Explainer", END)

        return graph.compile()

    def _invoke(
        self, inputs: Mapping[str, Any], recursion_limit: int = 100000, **_
    ):
        """Run the compiled graph on ``inputs`` and return its final state.

        If ``user_input`` is missing, the content of ``inputs["messages"][0]``
        is used instead. The caller's mapping is not modified.

        Raises:
            KeyError: if neither ``user_input`` nor a first message exists.
        """
        config = self.build_config(
            recursion_limit=recursion_limit, tags=["graph"]
        )
        if "user_input" not in inputs:
            try:
                first_message = inputs["messages"][0]
            except (KeyError, IndexError):
                # Raise a real exception type; raising a plain string is
                # itself a TypeError.
                raise KeyError("'user_input' is a required argument") from None
            # Build a new dict: ``inputs`` is a read-only Mapping that may
            # belong to the caller.
            inputs = {**inputs, "user_input": first_message.content}

        return self._action.invoke(inputs, config)


# Optional: render the compiled graph for inspection with
#   png_bytes = agent._action.get_graph().draw_mermaid_png()
# (requires extra dependencies).
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
@tool
def run_cmd(query: str, state: Annotated[dict, InjectedState]) -> str:
    """
    Run a commandline command from using the subprocess package in python

    Args:
        query: commandline command to be run as a string given to the subprocess.run command.
    """
    # The docstring above is the tool description shown to the model and is
    # intentionally left unchanged.
    import shlex

    # NOTE(review): OptimizerState has no "workspace" key, so this raises
    # KeyError unless the injected state supplies one. The tool is currently
    # not registered in OptimizationAgent.tools.
    workspace_dir = state["workspace"]
    print("RUNNING: ", query)
    process = None
    try:
        # shlex.split honours shell-style quoting and runs of whitespace.
        # str.split(" ") broke quoted arguments and produced empty argv items.
        argv = shlex.split(query)
        if not argv:
            raise ValueError("empty command")
        process = subprocess.Popen(
            argv,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            cwd=workspace_dir,
        )

        stdout, stderr = process.communicate(timeout=60000)
    except KeyboardInterrupt:
        print("Keyboard Interrupt of command: ", query)
        # Do not leave the child running after an interrupt.
        if process is not None and process.poll() is None:
            process.kill()
            process.wait()
        stdout, stderr = "", "KeyboardInterrupt:"
    except subprocess.TimeoutExpired:
        # communicate() does not kill the child on timeout. Kill it and
        # collect whatever output it produced.
        process.kill()
        stdout, stderr = process.communicate()
        stderr = f"{stderr}\nTimeoutExpired: command exceeded 60000 seconds"
    except (OSError, ValueError) as exc:
        # Examples: command not found, bad cwd, unbalanced quotes, empty
        # command. Report the error to the model instead of crashing the tool.
        stdout, stderr = "", f"{type(exc).__name__}: {exc}"

    print("STDOUT: ", stdout)
    print("STDERR: ", stderr)

    return f"STDOUT: {stdout} and STDERR: {stderr}"
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
@tool
def write_code(
    code: str, filename: str, state: Annotated[dict, InjectedState]
) -> str:
    """
    Writes python or Julia code to a file in the given workspace as requested.

    Args:
        code: The code to write
        filename: the filename with an appropriate extension for programming language (.py for python, .jl for Julia, etc.)

    Returns:
        Execution results
    """
    # The docstring above is the tool description shown to the model and is
    # intentionally left unchanged. The actual return value is a status
    # message, not execution results.
    import re

    workspace_dir = state["workspace"]
    print("Writing filename ", filename)
    try:
        # Extract code if wrapped in markdown code blocks
        if "```" in code:
            code_parts = code.split("```")
            if len(code_parts) >= 3:
                fenced = code_parts[1]
                if "\n" in fenced:
                    first_line, _, rest = fenced.partition("\n")
                    if re.fullmatch(r"[\w.+#-]*", first_line.strip()):
                        # The first line is an info string ("python", "julia")
                        # or empty. Drop only that line. The old
                        # strip-then-drop-first-line logic discarded real code
                        # when the fence had no language tag.
                        code = rest.rstrip()
                    else:
                        # Code starts on the fence line itself; keep it.
                        code = fenced.strip()
                else:
                    code = code_parts[2].strip()

        # Write code to a file.
        # NOTE(review): filename comes from the model and is not confined to
        # workspace_dir (e.g. "../x" or an absolute path would escape it).
        code_file = os.path.join(workspace_dir, filename)

        with open(code_file, "w", encoding="utf-8") as f:
            f.write(code)
        print(f"Written code to file: {code_file}")

        return f"File {filename} written successfully."

    except Exception as e:
        print(f"Error generating code: {str(e)}")
        # Best-effort tool: report the failure to the model instead of raising.
        return f"Failed to write {filename} successfully."
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
# Module-level web-search tool (JSON output, up to 10 results). It is built at
# import time but is not part of OptimizationAgent.tools.
search_tool = DuckDuckGoSearchResults(
    num_results=10,
    output_format="json",
)
# Alternative backend:
# search_tool = TavilySearchResults(max_results=10, search_depth="advanced", include_answer=True)
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
# A function to test if discretization is needed
def should_discretize(
    state: "OptimizerState",
) -> Literal["discretize", "continue"]:
    """Route after the Math Formulator node.

    Returns "discretize" when any constraint is tagged "infinite-dimensional"
    or any decision variable's ``type`` contains "infinite-dimensional".
    Otherwise returns "continue". The returned strings must match the keys
    of the conditional-edge mapping in ``OptimizationAgent._build_graph``.
    """
    cons = state["problem_spec"]["constraints"]
    decs = state["problem_spec"]["decision_variables"]

    if any("infinite-dimensional" in t["tags"] for t in cons) or any(
        "infinite-dimensional" in t["type"] for t in decs
    ):
        # The formulation has infinite-dimensional constraints or decision
        # variables, so it must be discretized first.
        return "discretize"

    return "continue"
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
# Define the function that determines whether to continue or not
def should_continue(state: "OptimizerState") -> Literal["error", "continue"]:
    """Route after the Verifier node.

    Reads the status from ``problem_spec["status"]``, falling back to
    ``problem_spec["spec"]["status"]``. Returns "continue" (to the Explainer)
    when the status marks the spec as verified, otherwise "error" (back to
    the Problem Extractor).
    """
    import re

    spec = state["problem_spec"]
    try:
        status = spec["status"].lower()
    except KeyError:
        status = spec["spec"]["status"].lower()
    # Match "verified" as a whole word and reject explicit negation. A plain
    # substring test accepted "unverified" and "not_verified".
    words = re.findall(r"[a-z]+", status)
    if "verified" in words and "not" not in words:
        return "continue"
    # Not verified: send the problem back through the pipeline.
    return "error"
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
def main():
    """Run the agent on a small DC optimal-power-flow example.

    Prints and returns the final graph state.
    """
    model = ChatOpenAI(
        model="gpt-5-mini", max_tokens=10000, timeout=None, max_retries=2
    )
    optimization_agent = OptimizationAgent(llm=model)
    problem_string = """
    Solve the following optimal power flow problem
    System topology and data:
        - Three buses (nodes) labeled 1, 2 and 3.
        - One generator at each bus; each can only inject power (no negative output).
        - Loads of 1 p.u. at bus 1, 2 p.u. at bus 2, and 4 p.u. at bus 3.
        - Transmission lines connecting every pair of buses, with susceptances (B):
            - Line 1–2: B₁₂ = 10
            - Line 1–3: B₁₃ = 20
            - Line 2–3: B₂₃ = 30

    Decision variables:
        - Voltage angles θ₁, θ₂, θ₃ (in radians) at buses 1–3.
        - Generator outputs Pᵍ₁, Pᵍ₂, Pᵍ₃ ≥ 0 (in per-unit).

    Reference angle:
        - To fix the overall angle‐shift ambiguity, we set θ₁ = 0 (“slack” or reference bus).

    Objective:
        - Minimize total generation cost with
            - 𝑐1 = 1
            - 𝑐2 = 10
            - 𝑐3 = 100

    Line‐flow limits
        - Lines 1-2 and 1-3 are thermal‐limited to ±0.5 p.u., line 2-3 is unconstrained.

    In words:
    We choose how much each generator should produce (at non-negative cost) and the voltage angles at each bus (with bus 1 set to zero) so that supply exactly meets demand, flows on the critical lines don’t exceed their limits, and the total cost is as small as possible.
    Use the tools at your disposal to check if your formulation is feasible.
    """
    inputs = {"user_input": problem_string}
    # Assumes invoke() returns the final OptimizerState from _invoke.
    result = optimization_agent.invoke(inputs)
    # OptimizerState has no "messages" key; the explainer stores its final
    # explanation under "summary".
    print(result["summary"])
    return result
|
|
404
|
+
|
|
405
|
+
|
|
406
|
+
# Allow running this module directly as a demo.
if __name__ == "__main__":
    main()


# Objective of the demo problem in main(): minimise over the generator outputs
#   c1*P1 + c2*P2 + c3*P3
|
|
@@ -0,0 +1,219 @@
|
|
|
1
|
+
# from langgraph.checkpoint.memory import MemorySaver
|
|
2
|
+
# from langchain_core.runnables.graph import MermaidDrawMethod
|
|
3
|
+
from typing import Annotated, Any, Dict, Iterator, List, Mapping, Optional
|
|
4
|
+
|
|
5
|
+
from langchain.chat_models import BaseChatModel
|
|
6
|
+
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
|
7
|
+
from langgraph.graph import StateGraph
|
|
8
|
+
from langgraph.graph.message import add_messages
|
|
9
|
+
from pydantic import Field
|
|
10
|
+
from typing_extensions import TypedDict
|
|
11
|
+
|
|
12
|
+
from ..prompt_library.planning_prompts import (
|
|
13
|
+
formalize_prompt,
|
|
14
|
+
planner_prompt,
|
|
15
|
+
reflection_prompt,
|
|
16
|
+
)
|
|
17
|
+
from ..util.parse import extract_json
|
|
18
|
+
from .base import BaseAgent
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class PlanningState(TypedDict):
    """Graph state shared by the PlanningAgent nodes.

    TypedDict fields cannot carry default values. A pydantic ``Field(...)``
    assigned here is ignored at runtime and rejected by type checkers, so
    optional keys are read with ``.get`` where they are used.
    """

    # Conversation history. LangGraph's add_messages reducer merges each
    # node's returned messages into this list instead of replacing it.
    messages: Annotated[list, add_messages]
    # Ordered steps in the solution plan, written by the formalize node.
    plan_steps: List[Dict[str, Any]]
    # Maximum number of reflection rounds; the router falls back to 3 when
    # it is absent.
    reflection_steps: Optional[int]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class PlanningAgent(BaseAgent):
    """Draft a plan, refine it through reflection, then formalize it to JSON.

    Graph: generate -> reflect -> (generate | formalize). "formalize" is the
    finish point, and the routing after "reflect" is decided by the
    module-level ``should_continue``.
    """

    def __init__(
        self,
        llm: BaseChatModel,
        **kwargs,
    ):
        super().__init__(llm, **kwargs)
        self.planner_prompt = planner_prompt
        self.formalize_prompt = formalize_prompt
        self.reflection_prompt = reflection_prompt
        self._action = self._build_graph()

    @staticmethod
    def _swap_roles(messages):
        """Return ``messages`` with AI and human roles swapped after the first.

        The first message is kept unchanged. Every later message is rebuilt
        from its content only: AI turns become HumanMessage and human turns
        become AIMessage, presumably so the reviewing model sees the planner's
        drafts as user input. Any other message type raises KeyError.
        """
        cls_map = {"ai": HumanMessage, "human": AIMessage}
        return [messages[0]] + [
            cls_map[msg.type](content=msg.content) for msg in messages[1:]
        ]

    def generation_node(self, state: PlanningState) -> PlanningState:
        """Draft or redraft the plan from the conversation so far.

        The first message is replaced by, or prefixed with, the planner system
        prompt. Only the new model message is returned; the add_messages
        reducer appends it to the state.
        """
        print("PlanningAgent: generating . . .")
        # Work on a copy so the graph state's message list is not mutated.
        messages = list(state["messages"])
        if isinstance(messages[0], SystemMessage):
            messages[0] = SystemMessage(content=self.planner_prompt)
        else:
            messages = [SystemMessage(content=self.planner_prompt)] + messages
        return {
            "messages": [
                self.llm.invoke(
                    messages,
                    self.build_config(tags=["planner", "generate"]),
                )
            ]
        }

    def formalize_node(self, state: PlanningState) -> PlanningState:
        """Convert the conversation into JSON plan steps.

        Retries up to 10 times while the response does not parse.

        Raises:
            ValueError: if no attempt yields parseable JSON.
        """
        print("PlanningAgent: formalizing . . .")
        translated = [
            SystemMessage(content=self.formalize_prompt)
        ] + self._swap_roles(state["messages"])
        max_attempts = 10
        last_error = None
        for _ in range(max_attempts):
            try:
                res = self.llm.invoke(
                    translated,
                    self.build_config(tags=["planner", "formalize"]),
                )
                json_out = extract_json(res.content)
                break
            except ValueError as exc:
                last_error = exc
                translated.append(
                    HumanMessage(
                        content="Your response was not valid JSON. Try again."
                    )
                )
        else:
            # Fail loudly instead of falling through with json_out unbound.
            raise ValueError(
                "PlanningAgent: formalizer did not return valid JSON after "
                f"{max_attempts} attempts"
            ) from last_error
        return {
            "messages": [HumanMessage(content=res.content)],
            "plan_steps": json_out,
        }

    def reflection_node(self, state: PlanningState) -> PlanningState:
        """Critique the latest plan. The review is returned as a HumanMessage."""
        print("PlanningAgent: reflecting . . .")
        # Use the instance prompt (not the module global) so that overriding
        # self.reflection_prompt takes effect, as for the other nodes.
        translated = [
            SystemMessage(content=self.reflection_prompt)
        ] + self._swap_roles(state["messages"])
        res = self.llm.invoke(
            translated,
            self.build_config(tags=["planner", "reflect"]),
        )
        return {"messages": [HumanMessage(content=res.content)]}

    def _build_graph(self):
        """Wire generate/reflect/formalize into a compiled StateGraph."""
        graph = StateGraph(PlanningState)
        self.add_node(graph, self.generation_node, "generate")
        self.add_node(graph, self.reflection_node, "reflect")
        self.add_node(graph, self.formalize_node, "formalize")

        # Edges
        graph.set_entry_point("generate")
        graph.add_edge("generate", "reflect")
        graph.set_finish_point("formalize")

        # Time the router logic too
        graph.add_conditional_edges(
            "reflect",
            self._wrap_cond(should_continue, "should_continue", "planner"),
            {"generate": "generate", "formalize": "formalize"},
        )

        return graph.compile(checkpointer=self.checkpointer)

    def _invoke(
        self, inputs: Mapping[str, Any], recursion_limit: int = 1000, **_
    ):
        """Run the compiled graph to completion and return its final state."""
        config = self.build_config(
            recursion_limit=recursion_limit, tags=["graph"]
        )
        return self._action.invoke(inputs, config)

    def _stream(
        self,
        inputs: Mapping[str, Any],
        *,
        config: dict | None = None,
        recursion_limit: int = 1000,
        **_,
    ) -> Iterator[dict]:
        """Stream graph events, merging ``config`` over the default config.

        Keys under "configurable" are merged individually; other keys from
        ``config`` override the defaults.
        """
        default = self.build_config(
            recursion_limit=recursion_limit, tags=["planner"]
        )
        if config:
            merged = {**default, **config}
            if "configurable" in config:
                merged["configurable"] = {
                    **default.get("configurable", {}),
                    **config["configurable"],
                }
        else:
            merged = default

        # Delegate to the compiled graph's stream
        yield from self._action.stream(inputs, merged)

    # prevent bypass
    @property
    def action(self):
        """Blocked: use ``.stream(...)`` or ``.invoke(...)`` instead."""
        raise AttributeError(
            "Use .stream(...) or .invoke(...); direct .action access is unsupported."
        )
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
# Example LangGraph run config with a fixed checkpoint thread id.
# NOTE(review): not referenced elsewhere in this module.
config = dict(configurable=dict(thread_id="1"))
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def should_continue(state: "PlanningState"):
    """Decide whether the planner should revise again or formalize the plan.

    Returns "formalize" once the reflection budget is spent or when the latest
    reviewer message contains "[APPROVED]". Otherwise returns "generate".

    The budget is ``state["reflection_steps"]`` (3 when missing or None). It
    is compared with the number of AI (planner) messages: each generate step
    adds exactly one, and each is followed by a reflect step. The previous
    total-message-count check was only correct for a budget of 3.
    """
    review_max_length = 0  # 0 = no limit, else some character limit like 300

    messages = state.get("messages") or []
    # Latest reviewer output (if present)
    last_content = messages[-1].content if messages else ""
    if not isinstance(last_content, str):
        # Non-text (e.g. list) content: compare against its string form.
        last_content = str(last_content)

    max_reflections = state.get("reflection_steps")
    if max_reflections is None:
        max_reflections = 3

    reflections_done = sum(
        1 for msg in messages if getattr(msg, "type", None) == "ai"
    )

    # Hit the reflection cap?
    if reflections_done >= max_reflections:
        print(
            f"PlanningAgent: reached reflection limit ({max_reflections}); formalizing . . ."
        )
        return "formalize"

    # Approved?
    if "[APPROVED]" in last_content:
        print("PlanningAgent: [APPROVED] — formalizing . . .")
        return "formalize"

    # Not approved — print a concise reason before another cycle
    reason = " ".join(last_content.strip().split())  # collapse whitespace
    if review_max_length > 0 and len(reason) > review_max_length:
        reason = reason[:review_max_length] + ". . ."
    print(
        f"PlanningAgent: not approved — iterating again. Reviewer notes: {reason}"
    )

    return "generate"
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def main():
    """Stream a PlanningAgent run on a toy prompt.

    Prints the latest message emitted by each node.
    """
    # PlanningAgent requires an llm argument; build one for the demo.
    from langchain_openai import ChatOpenAI

    planning_agent = PlanningAgent(llm=ChatOpenAI(model="gpt-5-mini"))

    for event in planning_agent.stream(
        {
            "messages": [
                HumanMessage(
                    content="Find a city with at least 10 vowels in its name."
                )
            ],
        },
    ):
        print("-" * 30)
        print(event.keys())
        # Each event maps the node name to that node's state update.
        node_name = next(iter(event))
        print(event[node_name]["messages"][-1].content)
        print("-" * 30)
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
# Allow running this module directly as a demo.
if __name__ == "__main__":
    main()
|