ursa-ai 0.4.2__py3-none-any.whl → 0.6.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ursa-ai might be problematic.

@@ -0,0 +1,403 @@
+ import os
+ import pprint
+ import subprocess
+ from typing import Annotated, Any, Dict, List, Literal, Mapping
+
+ from langchain_community.tools import DuckDuckGoSearchResults
+ from langchain_core.messages import HumanMessage, SystemMessage
+ from langchain_core.tools import tool
+ from langchain_openai import ChatOpenAI
+ from langgraph.graph import END, START, StateGraph
+ from langgraph.prebuilt import InjectedState
+ from typing_extensions import TypedDict
+
+ from ..prompt_library.optimization_prompts import (
+     code_generator_prompt,
+     discretizer_prompt,
+     explainer_prompt,
+     extractor_prompt,
+     feasibility_prompt,
+     math_formulator_prompt,
+     solver_selector_prompt,
+     verifier_prompt,
+ )
+ from ..tools.feasibility_tools import feasibility_check_auto as fca
+ from ..util.helperFunctions import extract_tool_calls, run_tool_calls
+ from ..util.optimization_schema import ProblemSpec, SolverSpec
+ from .base import BaseAgent
+
+ # --- ANSI color codes ---
+ GREEN = "\033[92m"
+ BLUE = "\033[94m"
+ RED = "\033[91m"
+ RESET = "\033[0m"
+ BOLD = "\033[1m"
+
+
+ class OptimizerState(TypedDict):
+     user_input: str
+     problem: str
+     problem_spec: ProblemSpec
+     solver: SolverSpec
+     code: str
+     problem_diagnostic: List[Dict]
+     summary: str
+
+
+ class OptimizationAgent(BaseAgent):
+     def __init__(self, llm="OpenAI/gpt-4o", *args, **kwargs):
+         super().__init__(llm, *args, **kwargs)
+         self.extractor_prompt = extractor_prompt
+         self.explainer_prompt = explainer_prompt
+         self.verifier_prompt = verifier_prompt
+         self.code_generator_prompt = code_generator_prompt
+         self.solver_selector_prompt = solver_selector_prompt
+         self.math_formulator_prompt = math_formulator_prompt
+         self.discretizer_prompt = discretizer_prompt
+         self.feasibility_prompt = feasibility_prompt
+         self.tools = [fca]  # [run_cmd, write_code, search_tool, fca]
+         self.llm = self.llm.bind_tools(self.tools)
+         # Map tool names to callables for dispatching tool calls
+         self.tool_maps = {
+             (getattr(t, "name", None) or getattr(t, "__name__", None)): t
+             for t in self.tools
+         }
+
+         self._action = self._build_graph()
+
+     # Extract a clean problem statement from the raw user input
+     def extractor(self, state: OptimizerState) -> OptimizerState:
+         new_state = state.copy()
+         new_state["problem"] = self.llm.invoke([
+             SystemMessage(content=self.extractor_prompt),
+             HumanMessage(content=new_state["user_input"]),
+         ]).content
+
+         new_state["problem_diagnostic"] = []
+
+         print("Extractor:\n")
+         pprint.pprint(new_state["problem"])
+         return new_state
+
+     def formulator(self, state: OptimizerState) -> OptimizerState:
+         new_state = state.copy()
+
+         llm_out = self.llm.with_structured_output(
+             ProblemSpec, include_raw=True
+         ).invoke([
+             SystemMessage(content=self.math_formulator_prompt),
+             HumanMessage(content=state["problem"]),
+         ])
+         new_state["problem_spec"] = llm_out["parsed"]
+         new_state["problem_diagnostic"].extend(
+             extract_tool_calls(llm_out["raw"])
+         )
+
+         print("Formulator:\n")
+         pprint.pprint(new_state["problem_spec"])
+         return new_state
+
+     def discretizer(self, state: OptimizerState) -> OptimizerState:
+         new_state = state.copy()
+
+         llm_out = self.llm.with_structured_output(
+             ProblemSpec, include_raw=True
+         ).invoke([
+             SystemMessage(content=self.discretizer_prompt),
+             HumanMessage(content=state["problem"]),
+         ])
+         new_state["problem_spec"] = llm_out["parsed"]
+         new_state["problem_diagnostic"].extend(
+             extract_tool_calls(llm_out["raw"])
+         )
+
+         print("Discretizer:\n")
+         pprint.pprint(new_state["problem_spec"])
+
+         return new_state
+
+     def tester(self, state: OptimizerState) -> OptimizerState:
+         new_state = state.copy()
+
+         llm_out = self.llm.bind(tool_choice="required").invoke([
+             SystemMessage(content=self.feasibility_prompt),
+             HumanMessage(content=str(state["code"])),
+         ])
+
+         tool_log = run_tool_calls(llm_out, self.tool_maps)
+         new_state["problem_diagnostic"].extend(tool_log)
+
+         print("Feasibility Tester:\n")
+         for msg in new_state["problem_diagnostic"]:
+             msg.pretty_print()
+         return new_state
+
+     def selector(self, state: OptimizerState) -> OptimizerState:
+         new_state = state.copy()
+
+         llm_out = self.llm.with_structured_output(
+             SolverSpec, include_raw=True
+         ).invoke([
+             SystemMessage(content=self.solver_selector_prompt),
+             HumanMessage(content=str(state["problem_spec"])),
+         ])
+         new_state["solver"] = llm_out["parsed"]
+
+         print("Selector:\n")
+         pprint.pprint(new_state["solver"])
+         return new_state
+
+     def generator(self, state: OptimizerState) -> OptimizerState:
+         new_state = state.copy()
+
+         new_state["code"] = self.llm.invoke([
+             SystemMessage(content=self.code_generator_prompt),
+             HumanMessage(content=str(state["problem_spec"])),
+         ]).content
+
+         print("Generator:\n")
+         pprint.pprint(new_state["code"])
+         return new_state
+
+     def verifier(self, state: OptimizerState) -> OptimizerState:
+         new_state = state.copy()
+
+         llm_out = self.llm.with_structured_output(
+             ProblemSpec, include_raw=True
+         ).invoke([
+             SystemMessage(content=self.verifier_prompt),
+             HumanMessage(content=str(state["problem_spec"]) + state["code"]),
+         ])
+         new_state["problem_spec"] = llm_out["parsed"]
+         # With include_raw=True, llm_out is a dict; tool calls live on the raw message
+         if getattr(llm_out["raw"], "tool_calls", None):
+             tool_log = run_tool_calls(llm_out["raw"], self.tool_maps)
+             new_state["problem_diagnostic"].extend(tool_log)
+
+         print("Verifier:\n")
+         pprint.pprint(new_state["problem_spec"])
+         return new_state
+
+     def explainer(self, state: OptimizerState) -> OptimizerState:
+         new_state = state.copy()
+
+         new_state["summary"] = self.llm.invoke([
+             SystemMessage(content=self.explainer_prompt),
+             HumanMessage(content=state["problem"] + str(state["problem_spec"])),
+             *state["problem_diagnostic"],
+         ]).content
+
+         print("Summary:\n")
+         pprint.pprint(new_state["summary"])
+         return new_state
+
+     def _build_graph(self):
+         graph = StateGraph(OptimizerState)
+
+         self.add_node(graph, self.extractor, "Problem Extractor")
+         self.add_node(graph, self.formulator, "Math Formulator")
+         self.add_node(graph, self.selector, "Solver Selector")
+         self.add_node(graph, self.generator, "Code Generator")
+         self.add_node(graph, self.verifier, "Verifier")
+         self.add_node(graph, self.explainer, "Explainer")
+         self.add_node(graph, self.tester, "Feasibility Tester")
+         self.add_node(graph, self.discretizer, "Discretizer")
+
+         graph.add_edge(START, "Problem Extractor")
+         graph.add_edge("Problem Extractor", "Math Formulator")
+         graph.add_conditional_edges(
+             "Math Formulator",
+             should_discretize,
+             {"discretize": "Discretizer", "continue": "Solver Selector"},
+         )
+         graph.add_edge("Discretizer", "Solver Selector")
+         graph.add_edge("Solver Selector", "Code Generator")
+         graph.add_edge("Code Generator", "Feasibility Tester")
+         graph.add_edge("Feasibility Tester", "Verifier")
+         graph.add_conditional_edges(
+             "Verifier",
+             should_continue,
+             {"continue": "Explainer", "error": "Problem Extractor"},
+         )
+         graph.add_edge("Explainer", END)
+
+         return graph.compile()
+
+     def _invoke(
+         self, inputs: Mapping[str, Any], recursion_limit: int = 100000, **_
+     ):
+         config = self.build_config(
+             recursion_limit=recursion_limit, tags=["graph"]
+         )
+         if "user_input" not in inputs:
+             try:
+                 inputs["user_input"] = inputs["messages"][0].content
+             except (KeyError, IndexError):
+                 raise ValueError("'user_input' is a required argument")
+
+         return self._action.invoke(inputs, config)
+
+
+ # Optional graph visualization (requires extra dependencies):
+ # try:
+ #     png_bytes = compiled_graph.get_graph().draw_mermaid_png()
+ #     img = mpimg.imread(io.BytesIO(png_bytes), format="png")  # decode bytes -> array
+ #     plt.imshow(img)
+ #     plt.axis("off")
+ #     plt.show()
+ # except Exception as e:
+ #     print(e)
+
+
+ @tool
+ def run_cmd(query: str, state: Annotated[dict, InjectedState]) -> str:
+     """
+     Run a command-line command using Python's subprocess package.
+
+     Args:
+         query: the command to run, split on spaces and passed to subprocess.Popen.
+     """
+     workspace_dir = state["workspace"]
+     print("RUNNING: ", query)
+     try:
+         process = subprocess.Popen(
+             query.split(" "),
+             stdout=subprocess.PIPE,
+             stderr=subprocess.PIPE,
+             text=True,
+             cwd=workspace_dir,
+         )
+
+         stdout, stderr = process.communicate(timeout=60000)  # generous timeout, in seconds
+     except KeyboardInterrupt:
+         print("Keyboard Interrupt of command: ", query)
+         stdout, stderr = "", "KeyboardInterrupt:"
+
+     print("STDOUT: ", stdout)
+     print("STDERR: ", stderr)
+
+     return f"STDOUT: {stdout} and STDERR: {stderr}"
+
+
+ @tool
+ def write_code(code: str, filename: str, state: Annotated[dict, InjectedState]):
+     """
+     Writes Python or Julia code to a file in the given workspace as requested.
+
+     Args:
+         code: The code to write
+         filename: the filename with an appropriate extension for the programming language (.py for Python, .jl for Julia, etc.)
+
+     Returns:
+         A status message indicating whether the file was written successfully
+     """
+     workspace_dir = state["workspace"]
+     print("Writing filename ", filename)
+     try:
+         # Extract code if wrapped in markdown code blocks
+         if "```" in code:
+             code_parts = code.split("```")
+             if len(code_parts) >= 3:
+                 # Extract the actual code
+                 if "\n" in code_parts[1]:
+                     code = "\n".join(code_parts[1].strip().split("\n")[1:])
+                 else:
+                     code = code_parts[2].strip()
+
+         # Write code to a file
+         code_file = os.path.join(workspace_dir, filename)
+
+         with open(code_file, "w") as f:
+             f.write(code)
+         print(f"Written code to file: {code_file}")
+
+         return f"File {filename} written successfully."
+
+     except Exception as e:
+         print(f"Error writing code: {str(e)}")
+         # Report the failure back to the calling agent
+         return f"Failed to write {filename} successfully."
+
+
+ search_tool = DuckDuckGoSearchResults(output_format="json", num_results=10)
+ # search_tool = TavilySearchResults(max_results=10, search_depth="advanced", include_answer=True)
+
+
+ # A function to test if discretization is needed
+ def should_discretize(
+     state: OptimizerState,
+ ) -> Literal["discretize", "continue"]:
+     cons = state["problem_spec"]["constraints"]
+     decs = state["problem_spec"]["decision_variables"]
+
+     if any("infinite-dimensional" in t["tags"] for t in cons) or any(
+         "infinite-dimensional" in t["type"] for t in decs
+     ):
+         # Infinite-dimensional constraints/decision variables must be discretized
+         return "discretize"
+
+     return "continue"
+
+
+ # Define the function that determines whether to continue or not
+ def should_continue(state: OptimizerState) -> Literal["error", "continue"]:
+     spec = state["problem_spec"]
+     try:
+         status = spec["status"].lower()
+     except KeyError:
+         status = spec["spec"]["status"].lower()
+     if "verified" in status:
+         return "continue"
+     else:
+         # Otherwise, loop back to the Problem Extractor and retry
+         return "error"
+
+
+ def main():
+     model = ChatOpenAI(
+         model="gpt-4o", max_tokens=10000, timeout=None, max_retries=2
+     )
+     execution_agent = OptimizationAgent(llm=model)
+     # execution_agent = execution_agent.bind_tools(feasibility_checker)
+     problem_string = """
+     Solve the following optimal power flow problem.
+     System topology and data:
+     - Three buses (nodes) labeled 1, 2 and 3.
+     - One generator at each bus; each can only inject power (no negative output).
+     - Loads of 1 p.u. at bus 1, 2 p.u. at bus 2, and 4 p.u. at bus 3.
+     - Transmission lines connecting every pair of buses, with susceptances (B):
+       - Line 1–2: B₁₂ = 10
+       - Line 1–3: B₁₃ = 20
+       - Line 2–3: B₂₃ = 30
+
+     Decision variables:
+     - Voltage angles θ₁, θ₂, θ₃ (in radians) at buses 1–3.
+     - Generator outputs Pᵍ₁, Pᵍ₂, Pᵍ₃ ≥ 0 (in per-unit).
+
+     Reference angle:
+     - To fix the overall angle-shift ambiguity, we set θ₁ = 0 (“slack” or reference bus).
+
+     Objective:
+     - Minimize total generation cost with
+       - c1 = 1
+       - c2 = 10
+       - c3 = 100
+
+     Line-flow limits:
+     - Lines 1-2 and 1-3 are thermal-limited to ±0.5 p.u.; line 2-3 is unconstrained.
+
+     In words:
+     We choose how much each generator should produce (at non-negative cost) and the voltage angles at each bus (with bus 1 set to zero) so that supply exactly meets demand, flows on the critical lines don’t exceed their limits, and the total cost is as small as possible.
+     Use the tools at your disposal to check if your formulation is feasible.
+     """
+     inputs = {"user_input": problem_string}
+     result = execution_agent.invoke(inputs)
+     print(result["summary"])
+     return result
+
+
+ if __name__ == "__main__":
+     main()
+
+
+ # Objective: minimize c1*P1 + c2*P2 + c3*P3 over the generator outputs Pg
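Taken together, the new module wires a staged LangGraph pipeline: extract the problem, formulate it mathematically, discretize if it is infinite-dimensional, pick a solver, generate code, run the feasibility tool, verify, and either explain the result or loop back to extraction. A minimal usage sketch, assuming an `ursa.agents` import path (the file's location inside the package is not shown in this diff):

    from langchain_openai import ChatOpenAI
    from ursa.agents import OptimizationAgent  # hypothetical import path

    agent = OptimizationAgent(llm=ChatOpenAI(model="gpt-4o"))
    result = agent.invoke(
        {"user_input": "Minimize x**2 + y**2 subject to x + y >= 1"}
    )
    print(result["summary"])  # final write-up produced by the Explainer node

The public `invoke` entry point comes from `BaseAgent` and delegates to `_invoke`, which fills in `user_input` from the first message when only `messages` is supplied.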
@@ -1,10 +1,10 @@
  # from langgraph.checkpoint.memory import MemorySaver
  # from langchain_core.runnables.graph import MermaidDrawMethod
- from typing import Annotated, Any, Dict, List, Optional
+ from typing import Annotated, Any, Dict, Iterator, List, Mapping, Optional

  from langchain_core.language_models import BaseChatModel
  from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
- from langgraph.graph import END, START, StateGraph
+ from langgraph.graph import StateGraph
  from langgraph.graph.message import add_messages
  from pydantic import Field
  from typing_extensions import TypedDict
@@ -36,7 +36,7 @@ class PlanningAgent(BaseAgent):
          self.planner_prompt = planner_prompt
          self.formalize_prompt = formalize_prompt
          self.reflection_prompt = reflection_prompt
-         self._initialize_agent()
+         self._action = self._build_graph()

      def generation_node(self, state: PlanningState) -> PlanningState:
          print("PlanningAgent: generating . . .")
@@ -48,7 +48,8 @@ class PlanningAgent(BaseAgent):
          return {
              "messages": [
                  self.llm.invoke(
-                     messages, {"configurable": {"thread_id": self.thread_id}}
+                     messages,
+                     self.build_config(tags=["planner", "generate"]),
                  )
              ]
          }
@@ -64,7 +65,8 @@ class PlanningAgent(BaseAgent):
          for _ in range(10):
              try:
                  res = self.llm.invoke(
-                     translated, {"configurable": {"thread_id": self.thread_id}}
+                     translated,
+                     self.build_config(tags=["planner", "formalize"]),
                  )
                  json_out = extract_json(res.content)
                  break
@@ -88,39 +90,72 @@ class PlanningAgent(BaseAgent):
          ]
          translated = [SystemMessage(content=reflection_prompt)] + translated
          res = self.llm.invoke(
-             translated, {"configurable": {"thread_id": self.thread_id}}
+             translated,
+             self.build_config(tags=["planner", "reflect"]),
          )
          return {"messages": [HumanMessage(content=res.content)]}

-     def _initialize_agent(self):
-         self.graph = StateGraph(PlanningState)
-         self.graph.add_node("generate", self.generation_node)
-         self.graph.add_node("reflect", self.reflection_node)
-         self.graph.add_node("formalize", self.formalize_node)
+     def _build_graph(self):
+         graph = StateGraph(PlanningState)
+         self.add_node(graph, self.generation_node, "generate")
+         self.add_node(graph, self.reflection_node, "reflect")
+         self.add_node(graph, self.formalize_node, "formalize")

-         self.graph.add_edge(START, "generate")
-         self.graph.add_edge("generate", "reflect")
-         self.graph.add_edge("formalize", END)
+         # Edges
+         graph.set_entry_point("generate")
+         graph.add_edge("generate", "reflect")
+         graph.set_finish_point("formalize")

-         self.graph.add_conditional_edges(
+         # Time the router logic too
+         graph.add_conditional_edges(
              "reflect",
-             should_continue,
+             self._wrap_cond(should_continue, "should_continue", "planner"),
              {"generate": "generate", "formalize": "formalize"},
          )

          # memory = MemorySaver()
          # self.action = self.graph.compile(checkpointer=memory)
-         self.action = self.graph.compile(checkpointer=self.checkpointer)
+         return graph.compile(checkpointer=self.checkpointer)
          # self.action.get_graph().draw_mermaid_png(output_file_path="planning_agent_graph.png", draw_method=MermaidDrawMethod.PYPPETEER)

-     def run(self, prompt, recursion_limit=100):
-         initial_state = {"messages": [HumanMessage(content=prompt)]}
-         return self.action.invoke(
-             initial_state,
-             {
-                 "recursion_limit": recursion_limit,
-                 "configurable": {"thread_id": self.thread_id},
-             },
+     def _invoke(
+         self, inputs: Mapping[str, Any], recursion_limit: int = 1000, **_
+     ):
+         config = self.build_config(
+             recursion_limit=recursion_limit, tags=["graph"]
+         )
+         return self._action.invoke(inputs, config)
+
+     def _stream(
+         self,
+         inputs: Mapping[str, Any],
+         *,
+         config: dict | None = None,
+         recursion_limit: int = 1000,
+         **_,
+     ) -> Iterator[dict]:
+         # If you have defaults, merge them here:
+         default = self.build_config(
+             recursion_limit=recursion_limit, tags=["planner"]
+         )
+         if config:
+             merged = {**default, **config}
+             if "configurable" in config:
+                 merged["configurable"] = {
+                     **default.get("configurable", {}),
+                     **config["configurable"],
+                 }
+         else:
+             merged = default
+
+         # Delegate to the compiled graph's stream
+         yield from self._action.stream(inputs, merged)
+
+     # prevent bypass
+     @property
+     def action(self):
+         raise AttributeError(
+             "Use .stream(...) or .invoke(...); direct .action access is unsupported."
          )

@@ -137,15 +172,15 @@ def should_continue(state: PlanningState):

  def main():
      planning_agent = PlanningAgent()
-     for event in planning_agent.action.stream(
+
+     for event in planning_agent.stream(
          {
              "messages": [
                  HumanMessage(
-                     content="Find a city with as least 10 vowels in its name."  # "Write an essay on ideal high-entropy alloys for spacecraft."
+                     content="Find a city with at least 10 vowels in its name."
                  )
              ],
          },
-         config,
      ):
          print("-" * 30)
          print(event.keys())
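The new `_stream` wrapper merges any caller-supplied config over the defaults from `build_config`: top-level keys are shallow-merged, with the override winning, while the nested `configurable` dicts are combined key by key. A small sketch of that merge semantics, with made-up values for illustration:

    default = {"recursion_limit": 1000, "tags": ["planner"], "configurable": {"thread_id": "t-1"}}
    override = {"tags": ["cli"], "configurable": {"run_name": "demo"}}

    merged = {**default, **override}
    merged["configurable"] = {**default["configurable"], **override["configurable"]}

    assert merged == {
        "recursion_limit": 1000,  # kept from the defaults
        "tags": ["cli"],          # replaced wholesale by the override
        "configurable": {"thread_id": "t-1", "run_name": "demo"},  # merged per key
    }

Note that list-valued keys such as `tags` are replaced rather than concatenated, so passing a custom config drops the default `"planner"` tag unless the caller includes it.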