ursa-ai 0.5.0__py3-none-any.whl → 0.6.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ursa-ai might be problematic. Click here for more details.
- ursa/agents/arxiv_agent.py +77 -47
- ursa/agents/base.py +369 -2
- ursa/agents/execution_agent.py +92 -48
- ursa/agents/hypothesizer_agent.py +39 -42
- ursa/agents/lammps_agent.py +51 -29
- ursa/agents/mp_agent.py +45 -20
- ursa/agents/optimization_agent.py +403 -0
- ursa/agents/planning_agent.py +63 -28
- ursa/agents/rag_agent.py +75 -44
- ursa/agents/recall_agent.py +35 -5
- ursa/agents/websearch_agent.py +44 -54
- ursa/cli/__init__.py +127 -0
- ursa/cli/hitl.py +426 -0
- ursa/observability/pricing.py +319 -0
- ursa/observability/timing.py +1441 -0
- ursa/prompt_library/execution_prompts.py +7 -0
- ursa/prompt_library/optimization_prompts.py +131 -0
- ursa/tools/feasibility_checker.py +114 -0
- ursa/tools/feasibility_tools.py +1075 -0
- ursa/util/helperFunctions.py +142 -0
- ursa/util/optimization_schema.py +78 -0
- {ursa_ai-0.5.0.dist-info → ursa_ai-0.6.0rc1.dist-info}/METADATA +123 -4
- ursa_ai-0.6.0rc1.dist-info/RECORD +39 -0
- ursa_ai-0.6.0rc1.dist-info/entry_points.txt +2 -0
- ursa_ai-0.5.0.dist-info/RECORD +0 -28
- {ursa_ai-0.5.0.dist-info → ursa_ai-0.6.0rc1.dist-info}/WHEEL +0 -0
- {ursa_ai-0.5.0.dist-info → ursa_ai-0.6.0rc1.dist-info}/licenses/LICENSE +0 -0
- {ursa_ai-0.5.0.dist-info → ursa_ai-0.6.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,403 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import pprint
|
|
3
|
+
import subprocess
|
|
4
|
+
from typing import Annotated, Any, Dict, List, Literal, Mapping
|
|
5
|
+
|
|
6
|
+
from langchain_community.tools import DuckDuckGoSearchResults
|
|
7
|
+
from langchain_core.messages import HumanMessage, SystemMessage
|
|
8
|
+
from langchain_core.tools import tool
|
|
9
|
+
from langchain_openai import ChatOpenAI
|
|
10
|
+
from langgraph.graph import END, START, StateGraph
|
|
11
|
+
from langgraph.prebuilt import InjectedState
|
|
12
|
+
from typing_extensions import TypedDict
|
|
13
|
+
|
|
14
|
+
from ..prompt_library.optimization_prompts import (
|
|
15
|
+
code_generator_prompt,
|
|
16
|
+
discretizer_prompt,
|
|
17
|
+
explainer_prompt,
|
|
18
|
+
extractor_prompt,
|
|
19
|
+
feasibility_prompt,
|
|
20
|
+
math_formulator_prompt,
|
|
21
|
+
solver_selector_prompt,
|
|
22
|
+
verifier_prompt,
|
|
23
|
+
)
|
|
24
|
+
from ..tools.feasibility_tools import feasibility_check_auto as fca
|
|
25
|
+
from ..util.helperFunctions import extract_tool_calls, run_tool_calls
|
|
26
|
+
from ..util.optimization_schema import ProblemSpec, SolverSpec
|
|
27
|
+
from .base import BaseAgent
|
|
28
|
+
|
|
29
|
+
# --- ANSI color codes ---
# Terminal escape sequences for coloured / bold console output.
# NOTE(review): none of these constants are referenced in the visible part of
# this module — confirm whether they are still needed.
GREEN = "\033[92m"
BLUE = "\033[94m"
RED = "\033[91m"
RESET = "\033[0m"  # resets all attributes set by the codes above
BOLD = "\033[1m"
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class OptimizerState(TypedDict):
    """Shared state passed between the nodes of the optimization graph."""

    # Raw problem statement supplied by the caller; read by the extractor.
    user_input: str
    # LLM-extracted problem description; written by the extractor.
    problem: str
    # Structured formulation; written by the formulator, discretizer and
    # verifier nodes.
    problem_spec: ProblemSpec
    # Structured solver choice; written by the selector node.
    solver: SolverSpec
    # LLM-generated solver code; written by the generator node.
    code: str
    # Accumulated tool-call diagnostics; reset by the extractor and extended
    # by the formulator, discretizer and feasibility tester.
    problem_diagnostic: List[Dict]
    # Final explanation; written by the explainer node.
    summary: str
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class OptimizationAgent(BaseAgent):
    """Agent that drives an optimization problem through a LangGraph pipeline.

    The compiled graph runs: Problem Extractor -> Math Formulator ->
    (optional Discretizer) -> Solver Selector -> Code Generator ->
    Feasibility Tester -> Verifier -> Explainer. When the verifier does not
    report a verified status the graph loops back to the Problem Extractor.
    """

    def __init__(self, llm="OpenAI/gpt-4o", *args, **kwargs):
        """Store the prompts, bind the feasibility tool and build the graph.

        Args:
            llm: Model (or model identifier) forwarded to ``BaseAgent``.
            *args: Extra positional arguments forwarded to ``BaseAgent``.
            **kwargs: Extra keyword arguments forwarded to ``BaseAgent``.
        """
        super().__init__(llm, *args, **kwargs)
        self.extractor_prompt = extractor_prompt
        self.explainer_prompt = explainer_prompt
        self.verifier_prompt = verifier_prompt
        self.code_generator_prompt = code_generator_prompt
        self.solver_selector_prompt = solver_selector_prompt
        self.math_formulator_prompt = math_formulator_prompt
        self.discretizer_prompt = discretizer_prompt
        self.feasibility_prompt = feasibility_prompt
        self.tools = [fca]  # [run_cmd, write_code, search_tool, fca]
        self.llm = self.llm.bind_tools(self.tools)
        # Map each tool's name to the tool so tool calls can be dispatched.
        self.tool_maps = {
            (getattr(t, "name", None) or getattr(t, "__name__", None)): t
            for t in self.tools
        }

        self._action = self._build_graph()

    @staticmethod
    def _with_diagnostics(state, entries):
        """Return a new diagnostics list: the state's entries plus ``entries``.

        ``state.copy()`` is shallow, so extending the list in place would also
        modify the list held by the node's input state.
        """
        return [*state.get("problem_diagnostic", []), *entries]

    # Define the function that calls the model
    def extractor(self, state: OptimizerState) -> OptimizerState:
        """Extract the problem description from ``user_input``.

        Also resets ``problem_diagnostic`` to an empty list.
        """
        new_state = state.copy()
        new_state["problem"] = self.llm.invoke([
            SystemMessage(content=self.extractor_prompt),
            HumanMessage(content=new_state["user_input"]),
        ]).content

        new_state["problem_diagnostic"] = []

        print("Extractor:\n")
        pprint.pprint(new_state["problem"])
        return new_state

    def formulator(self, state: OptimizerState) -> OptimizerState:
        """Turn the extracted problem into a structured ``ProblemSpec``."""
        new_state = state.copy()

        llm_out = self.llm.with_structured_output(
            ProblemSpec, include_raw=True
        ).invoke([
            SystemMessage(content=self.math_formulator_prompt),
            HumanMessage(content=state["problem"]),
        ])
        new_state["problem_spec"] = llm_out["parsed"]
        new_state["problem_diagnostic"] = self._with_diagnostics(
            state, extract_tool_calls(llm_out["raw"])
        )

        print("Formulator:\n")
        pprint.pprint(new_state["problem_spec"])
        return new_state

    def discretizer(self, state: OptimizerState) -> OptimizerState:
        """Re-formulate the problem using the discretizer prompt."""
        new_state = state.copy()

        llm_out = self.llm.with_structured_output(
            ProblemSpec, include_raw=True
        ).invoke([
            SystemMessage(content=self.discretizer_prompt),
            HumanMessage(content=state["problem"]),
        ])
        new_state["problem_spec"] = llm_out["parsed"]
        new_state["problem_diagnostic"] = self._with_diagnostics(
            state, extract_tool_calls(llm_out["raw"])
        )

        print("Discretizer:\n")
        pprint.pprint(new_state["problem_spec"])

        return new_state

    def tester(self, state: OptimizerState) -> OptimizerState:
        """Force a tool call on the generated code and record the results."""
        new_state = state.copy()

        llm_out = self.llm.bind(tool_choice="required").invoke([
            SystemMessage(content=self.feasibility_prompt),
            HumanMessage(content=str(state["code"])),
        ])

        tool_log = run_tool_calls(llm_out, self.tool_maps)
        new_state["problem_diagnostic"] = self._with_diagnostics(
            state, tool_log
        )

        print("Feasibility Tester:\n")
        for msg in new_state["problem_diagnostic"]:
            msg.pretty_print()
        return new_state

    def selector(self, state: OptimizerState) -> OptimizerState:
        """Choose a solver for the current problem spec."""
        new_state = state.copy()

        llm_out = self.llm.with_structured_output(
            SolverSpec, include_raw=True
        ).invoke([
            SystemMessage(content=self.solver_selector_prompt),
            HumanMessage(content=str(state["problem_spec"])),
        ])
        new_state["solver"] = llm_out["parsed"]

        print("Selector:\n ")
        pprint.pprint(new_state["solver"])
        return new_state

    def generator(self, state: OptimizerState) -> OptimizerState:
        """Ask the model for code from the problem spec; store it as ``code``."""
        new_state = state.copy()

        new_state["code"] = self.llm.invoke([
            SystemMessage(content=self.code_generator_prompt),
            HumanMessage(content=str(state["problem_spec"])),
        ]).content

        print("Generator:\n")
        pprint.pprint(new_state["code"])
        return new_state

    def verifier(self, state: OptimizerState) -> OptimizerState:
        """Re-derive the problem spec from the current spec plus the code."""
        new_state = state.copy()

        llm_out = self.llm.with_structured_output(
            ProblemSpec, include_raw=True
        ).invoke([
            SystemMessage(content=self.verifier_prompt),
            HumanMessage(content=str(state["problem_spec"]) + state["code"]),
        ])
        new_state["problem_spec"] = llm_out["parsed"]
        # NOTE(review): this method used to call run_tool_calls(llm_out, ...)
        # behind `hasattr(llm_out, "tool_calls")`. With include_raw=True the
        # result is a mapping ("raw"/"parsed"), so that guard could never be
        # true and the call never ran. The dead branch is removed instead of
        # being switched on, because the raw message's tool calls may only be
        # the structured-output call itself — confirm the intended behaviour.

        print("Verifier:\n ")
        pprint.pprint(new_state["problem_spec"])
        return new_state

    def explainer(self, state: OptimizerState) -> OptimizerState:
        """Summarise the problem, its spec and the collected diagnostics."""
        new_state = state.copy()

        new_state["summary"] = self.llm.invoke([
            SystemMessage(content=self.explainer_prompt),
            HumanMessage(content=state["problem"] + str(state["problem_spec"])),
            *state["problem_diagnostic"],
        ]).content

        print("Summary:\n")
        pprint.pprint(new_state["summary"])
        return new_state

    def _build_graph(self):
        """Wire the node methods into a state graph and compile it."""
        graph = StateGraph(OptimizerState)

        self.add_node(graph, self.extractor, "Problem Extractor")
        self.add_node(graph, self.formulator, "Math Formulator")
        self.add_node(graph, self.selector, "Solver Selector")
        self.add_node(graph, self.generator, "Code Generator")
        self.add_node(graph, self.verifier, "Verifier")
        self.add_node(graph, self.explainer, "Explainer")
        self.add_node(graph, self.tester, "Feasibility Tester")
        self.add_node(graph, self.discretizer, "Discretizer")

        graph.add_edge(START, "Problem Extractor")
        graph.add_edge("Problem Extractor", "Math Formulator")
        graph.add_conditional_edges(
            "Math Formulator",
            should_discretize,
            {"discretize": "Discretizer", "continue": "Solver Selector"},
        )
        graph.add_edge("Discretizer", "Solver Selector")
        graph.add_edge("Solver Selector", "Code Generator")
        graph.add_edge("Code Generator", "Feasibility Tester")
        graph.add_edge("Feasibility Tester", "Verifier")
        # An unverified result restarts the pipeline from the extractor.
        graph.add_conditional_edges(
            "Verifier",
            should_continue,
            {"continue": "Explainer", "error": "Problem Extractor"},
        )
        graph.add_edge("Explainer", END)

        return graph.compile()

    def _invoke(
        self, inputs: Mapping[str, Any], recursion_limit: int = 100000, **_
    ):
        """Run the compiled graph on ``inputs``.

        Args:
            inputs: Must contain ``user_input``; otherwise the content of the
                first entry of ``inputs["messages"]`` is used instead.
            recursion_limit: Forwarded to ``build_config``.

        Returns:
            Whatever the compiled graph's ``invoke`` returns.

        Raises:
            KeyError: If neither ``user_input`` nor a first message exists.
        """
        config = self.build_config(
            recursion_limit=recursion_limit, tags=["graph"]
        )
        if "user_input" not in inputs:
            try:
                user_input = inputs["messages"][0].content
            except (KeyError, IndexError) as err:
                # The original raised a bare string here, which itself fails
                # with "exceptions must derive from BaseException".
                raise KeyError(
                    "'user_input' is a required argument"
                ) from err
            # Work on a copy: ``inputs`` is typed as a read-only Mapping and
            # belongs to the caller.
            inputs = {**inputs, "user_input": user_input}

        return self._action.invoke(inputs, config)
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
######### try:
|
|
240
|
+
######### png_bytes = compiled_graph.get_graph().draw_mermaid_png()
|
|
241
|
+
######### img = mpimg.imread(io.BytesIO(png_bytes), format='png') # decode bytes -> array
|
|
242
|
+
|
|
243
|
+
######### plt.imshow(img)
|
|
244
|
+
######### plt.axis('off')
|
|
245
|
+
######### plt.show()
|
|
246
|
+
######### except Exception as e:
|
|
247
|
+
######### # This requires some extra dependencies and is optional
|
|
248
|
+
######### print(e)
|
|
249
|
+
######### pass
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
@tool
def run_cmd(query: str, state: Annotated[dict, InjectedState]) -> str:
    """
    Run a commandline command from using the subprocess package in python

    Args:
        query: commandline command to be run as a string given to the subprocess.run command.
    """
    # The docstring above is kept verbatim: the @tool decorator may expose it
    # to the model as the tool description, so it is runtime text.
    import shlex  # local import: only this tool needs it

    workspace_dir = state["workspace"]
    # NOTE(review): subprocess timeouts are in seconds, so this is ~16.7
    # hours — confirm that milliseconds were not intended.
    timeout_seconds = 60000
    print("RUNNING: ", query)

    stdout, stderr = "", ""
    try:
        # shlex honours quoting; a plain split(" ") breaks quoted arguments
        # and yields empty arguments for repeated spaces.
        argv = shlex.split(query)
        if not argv:
            raise ValueError("empty command")
        # subprocess.run kills and reaps the child if the timeout expires.
        completed = subprocess.run(
            argv,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            cwd=workspace_dir,
            timeout=timeout_seconds,
        )
        stdout, stderr = completed.stdout, completed.stderr
    except KeyboardInterrupt:
        print("Keyboard Interrupt of command: ", query)
        stdout, stderr = "", "KeyboardInterrupt:"
    except subprocess.TimeoutExpired:
        stderr = f"TimeoutExpired: command exceeded {timeout_seconds} seconds"
    except (OSError, ValueError) as err:
        # Missing executable, bad working directory, unbalanced quotes or an
        # empty command: report it to the caller instead of crashing.
        stderr = f"{type(err).__name__}: {err}"

    print("STDOUT: ", stdout)
    print("STDERR: ", stderr)

    return f"STDOUT: {stdout} and STDERR: {stderr}"
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
@tool
def write_code(code: str, filename: str, state: Annotated[dict, InjectedState]):
    """
    Writes python or Julia code to a file in the given workspace as requested.

    Args:
        code: The code to write
        filename: the filename with an appropriate extension for programming language (.py for python, .jl for Julia, etc.)

    Returns:
        Execution results
    """
    # The docstring above is kept verbatim: the @tool decorator may expose it
    # to the model as the tool description, so it is runtime text.
    workspace_dir = state["workspace"]
    print("Writing filename ", filename)
    try:
        # Extract code if wrapped in markdown code blocks
        code = _strip_markdown_fence(code)

        # Write code to a file
        # NOTE(review): filename is joined as given, so an absolute path or
        # "../" can point outside the workspace — confirm this is acceptable.
        code_file = os.path.join(workspace_dir, filename)

        # Explicit encoding so non-ASCII code does not depend on the locale.
        with open(code_file, "w", encoding="utf-8") as f:
            f.write(code)
        print(f"Written code to file: {code_file}")

        return f"File {filename} written successfully."

    except Exception as e:
        print(f"Error generating code: {str(e)}")
        # Best effort: report the failure as the tool result, do not raise.
        return f"Failed to write {filename} successfully."


def _strip_markdown_fence(code: str) -> str:
    """Return the contents of the first ``` fenced block, or ``code`` as is.

    Only the text on the opening fence line (the language tag, possibly
    empty) is dropped. The original stripped whitespace before dropping that
    line, which deleted the first line of code for fences without a tag.
    """
    parts = code.split("```")
    if len(parts) < 3:
        # No complete fence pair: leave the text untouched.
        return code
    fenced = parts[1]
    if "\n" in fenced:
        _info, _, body = fenced.partition("\n")
        return body.rstrip()
    # Single-line first fence: fall back to the text that follows it.
    return parts[2].strip()
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
# Web-search tool, instantiated when the module is imported.
# NOTE(review): OptimizationAgent currently binds only `fca`; this tool
# appears solely in the commented-out tool list there — confirm it is needed.
search_tool = DuckDuckGoSearchResults(output_format="json", num_results=10)
# search_tool = TavilySearchResults(max_results=10, search_depth="advanced", include_answer=True)
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
# A function to test if discretization is needed
|
|
327
|
+
def should_discretize(
    state: OptimizerState,
) -> Literal["discretize", "continue"]:
    """Route to the discretizer when the spec is infinite-dimensional.

    Args:
        state: Graph state whose ``problem_spec`` holds ``constraints`` and
            ``decision_variables`` collections of mappings.

    Returns:
        ``"discretize"`` if any constraint's ``tags`` or any decision
        variable's ``type`` contains ``"infinite-dimensional"``, otherwise
        ``"continue"``. The values match the keys of the routing table used
        for the "Math Formulator" conditional edge.
    """
    marker = "infinite-dimensional"
    spec = state["problem_spec"]
    constraints = spec["constraints"]
    decision_vars = spec["decision_variables"]

    # A missing or empty "tags"/"type" entry counts as "not
    # infinite-dimensional" instead of raising KeyError.
    if any(marker in (con.get("tags") or ()) for con in constraints) or any(
        marker in (var.get("type") or "") for var in decision_vars
    ):
        return "discretize"

    return "continue"
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
# Define the function that determines whether to continue or not
|
|
343
|
+
def should_continue(state: OptimizerState) -> Literal["error", "continue"]:
    """Decide the route after verification from the spec's status.

    The status is read from ``problem_spec["status"]``, or from
    ``problem_spec["spec"]["status"]`` when the top-level key is absent.

    Returns:
        ``"continue"`` when the lower-cased status contains ``"verified"``,
        otherwise ``"error"``.
    """
    problem_spec = state["problem_spec"]
    try:
        raw_status = problem_spec["status"]
    except KeyError:
        raw_status = problem_spec["spec"]["status"]

    # NOTE(review): this is a substring test, so a status such as
    # "unverified" also matches — confirm the set of possible statuses.
    return "continue" if "verified" in raw_status.lower() else "error"
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
def main():
    """Run the agent on a sample optimal-power-flow problem.

    Returns:
        The result of ``OptimizationAgent.invoke`` for the sample problem.
    """
    model = ChatOpenAI(
        model="gpt-4o", max_tokens=10000, timeout=None, max_retries=2
    )
    execution_agent = OptimizationAgent(llm=model)
    # execution_agent = execution_agent.bind_tools(feasibility_checker)
    problem_string = """
Solve the following optimal power flow problem
System topology and data:
- Three buses (nodes) labeled 1, 2 and 3.
- One generator at each bus; each can only inject power (no negative output).
- Loads of 1 p.u. at bus 1, 2 p.u. at bus 2, and 4 p.u. at bus 3.
- Transmission lines connecting every pair of buses, with susceptances (B):
- Line 1–2: B₁₂ = 10
- Line 1–3: B₁₃ = 20
- Line 2–3: B₂₃ = 30

Decision variables:
- Voltage angles θ₁, θ₂, θ₃ (in radians) at buses 1–3.
- Generator outputs Pᵍ₁, Pᵍ₂, Pᵍ₃ ≥ 0 (in per-unit).

Reference angle:
- To fix the overall angle‐shift ambiguity, we set θ₁ = 0 (“slack” or reference bus).

Objective:
- Minimize total generation cost with
- 𝑐1 = 1
- 𝑐2 = 10
- 𝑐3 = 100

Line‐flow limits
- Lines 1-2 and 1-3 are thermal‐limited to ±0.5 p.u., line 2-3 is unconstrained.

In words:
We choose how much each generator should produce (at non-negative cost) and the voltage angles at each bus (with bus 1 set to zero) so that supply exactly meets demand, flows on the critical lines don’t exceed their limits, and the total cost is as small as possible.
Use the tools at your disposal to check if your formulation is feasible.
"""
    inputs = {"user_input": problem_string}
    result = execution_agent.invoke(inputs)
    # OptimizerState has no "messages" key; the Explainer node stores the
    # final text under "summary". Assumes invoke() returns the graph's final
    # state — TODO confirm against BaseAgent.invoke.
    print(result["summary"])
    return result
|
|
397
|
+
|
|
398
|
+
|
|
399
|
+
if __name__ == "__main__":
|
|
400
|
+
main()
|
|
401
|
+
|
|
402
|
+
|
|
403
|
+
# min 𝑃𝑔 𝑐1*𝑃1 + 𝑐2 * 𝑃2 + 𝑐3 * 𝑃3
|
ursa/agents/planning_agent.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
# from langgraph.checkpoint.memory import MemorySaver
|
|
2
2
|
# from langchain_core.runnables.graph import MermaidDrawMethod
|
|
3
|
-
from typing import Annotated, Any, Dict, List, Optional
|
|
3
|
+
from typing import Annotated, Any, Dict, Iterator, List, Mapping, Optional
|
|
4
4
|
|
|
5
5
|
from langchain_core.language_models import BaseChatModel
|
|
6
6
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
|
7
|
-
from langgraph.graph import
|
|
7
|
+
from langgraph.graph import StateGraph
|
|
8
8
|
from langgraph.graph.message import add_messages
|
|
9
9
|
from pydantic import Field
|
|
10
10
|
from typing_extensions import TypedDict
|
|
@@ -36,7 +36,7 @@ class PlanningAgent(BaseAgent):
|
|
|
36
36
|
self.planner_prompt = planner_prompt
|
|
37
37
|
self.formalize_prompt = formalize_prompt
|
|
38
38
|
self.reflection_prompt = reflection_prompt
|
|
39
|
-
self.
|
|
39
|
+
self._action = self._build_graph()
|
|
40
40
|
|
|
41
41
|
def generation_node(self, state: PlanningState) -> PlanningState:
|
|
42
42
|
print("PlanningAgent: generating . . .")
|
|
@@ -48,7 +48,8 @@ class PlanningAgent(BaseAgent):
|
|
|
48
48
|
return {
|
|
49
49
|
"messages": [
|
|
50
50
|
self.llm.invoke(
|
|
51
|
-
messages,
|
|
51
|
+
messages,
|
|
52
|
+
self.build_config(tags=["planner", "generate"]),
|
|
52
53
|
)
|
|
53
54
|
]
|
|
54
55
|
}
|
|
@@ -64,7 +65,8 @@ class PlanningAgent(BaseAgent):
|
|
|
64
65
|
for _ in range(10):
|
|
65
66
|
try:
|
|
66
67
|
res = self.llm.invoke(
|
|
67
|
-
translated,
|
|
68
|
+
translated,
|
|
69
|
+
self.build_config(tags=["planner", "formalize"]),
|
|
68
70
|
)
|
|
69
71
|
json_out = extract_json(res.content)
|
|
70
72
|
break
|
|
@@ -88,39 +90,72 @@ class PlanningAgent(BaseAgent):
|
|
|
88
90
|
]
|
|
89
91
|
translated = [SystemMessage(content=reflection_prompt)] + translated
|
|
90
92
|
res = self.llm.invoke(
|
|
91
|
-
translated,
|
|
93
|
+
translated,
|
|
94
|
+
self.build_config(tags=["planner", "reflect"]),
|
|
92
95
|
)
|
|
93
96
|
return {"messages": [HumanMessage(content=res.content)]}
|
|
94
97
|
|
|
95
|
-
def
|
|
96
|
-
|
|
97
|
-
self.
|
|
98
|
-
self.
|
|
99
|
-
self.
|
|
98
|
+
def _build_graph(self):
    """Assemble and compile the planning graph.

    Registers the generate, reflect and formalize nodes, wires
    generate -> reflect, lets ``should_continue`` route from reflect to either
    generate or formalize, and compiles with ``self.checkpointer``.
    """
    graph = StateGraph(PlanningState)
    self.add_node(graph, self.generation_node, "generate")
    self.add_node(graph, self.reflection_node, "reflect")
    self.add_node(graph, self.formalize_node, "formalize")

    # Edges
    graph.set_entry_point("generate")
    graph.add_edge("generate", "reflect")
    graph.set_finish_point("formalize")

    # Time the router logic too
    # NOTE(review): _wrap_cond is defined outside this view; presumably it
    # wraps the router for timing, as the comment above says — confirm.
    graph.add_conditional_edges(
        "reflect",
        self._wrap_cond(should_continue, "should_continue", "planner"),
        {"generate": "generate", "formalize": "formalize"},
    )

    # memory = MemorySaver()
    # self.action = self.graph.compile(checkpointer=memory)
    return graph.compile(checkpointer=self.checkpointer)
|
|
114
119
|
# self.action.get_graph().draw_mermaid_png(output_file_path="planning_agent_graph.png", draw_method=MermaidDrawMethod.PYPPETEER)
|
|
115
120
|
|
|
116
|
-
def
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
121
|
+
def _invoke(self, inputs: Mapping[str, Any], recursion_limit: int = 1000, **_):
    """Run the compiled planning graph once and return its result.

    Args:
        inputs: Graph input, passed through unchanged.
        recursion_limit: Forwarded to ``build_config``.
        **_: Extra keyword arguments are accepted and ignored.
    """
    run_config = self.build_config(
        tags=["graph"],
        recursion_limit=recursion_limit,
    )
    compiled = self._action
    return compiled.invoke(inputs, run_config)
|
|
128
|
+
|
|
129
|
+
def _stream(
|
|
130
|
+
self,
|
|
131
|
+
inputs: Mapping[str, Any],
|
|
132
|
+
*,
|
|
133
|
+
config: dict | None = None,
|
|
134
|
+
recursion_limit: int = 1000,
|
|
135
|
+
**_,
|
|
136
|
+
) -> Iterator[dict]:
|
|
137
|
+
# If you have defaults, merge them here:
|
|
138
|
+
default = self.build_config(
|
|
139
|
+
recursion_limit=recursion_limit, tags=["planner"]
|
|
140
|
+
)
|
|
141
|
+
if config:
|
|
142
|
+
merged = {**default, **config}
|
|
143
|
+
if "configurable" in config:
|
|
144
|
+
merged["configurable"] = {
|
|
145
|
+
**default.get("configurable", {}),
|
|
146
|
+
**config["configurable"],
|
|
147
|
+
}
|
|
148
|
+
else:
|
|
149
|
+
merged = default
|
|
150
|
+
|
|
151
|
+
# Delegate to the compiled graph's stream
|
|
152
|
+
yield from self._action.stream(inputs, merged)
|
|
153
|
+
|
|
154
|
+
# prevent bypass
|
|
155
|
+
@property
def action(self):
    """Always raise: callers must go through ``stream`` or ``invoke``."""
    message = (
        "Use .stream(...) or .invoke(...); direct .action access is unsupported."
    )
    raise AttributeError(message)
|
|
125
160
|
|
|
126
161
|
|
|
@@ -137,15 +172,15 @@ def should_continue(state: PlanningState):
|
|
|
137
172
|
|
|
138
173
|
def main():
|
|
139
174
|
planning_agent = PlanningAgent()
|
|
140
|
-
|
|
175
|
+
|
|
176
|
+
for event in planning_agent.stream(
|
|
141
177
|
{
|
|
142
178
|
"messages": [
|
|
143
179
|
HumanMessage(
|
|
144
|
-
content="Find a city with
|
|
180
|
+
content="Find a city with at least 10 vowels in its name."
|
|
145
181
|
)
|
|
146
182
|
],
|
|
147
183
|
},
|
|
148
|
-
config,
|
|
149
184
|
):
|
|
150
185
|
print("-" * 30)
|
|
151
186
|
print(event.keys())
|