ai-data-science-team 0.0.0.9007__py3-none-any.whl → 0.0.0.9008__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,286 @@
1
+
2
+ from langchain_core.messages import BaseMessage
3
+ from langgraph.checkpoint.memory import MemorySaver
4
+ from langgraph.types import Checkpointer
5
+
6
+ from langgraph.graph import START, END, StateGraph
7
+ from langgraph.graph.state import CompiledStateGraph
8
+ from langgraph.types import Command
9
+
10
+ from typing import TypedDict, Annotated, Sequence
11
+ import operator
12
+
13
+ from typing_extensions import TypedDict, Literal
14
+
15
+ import pandas as pd
16
+ from IPython.display import Markdown
17
+
18
+ from ai_data_science_team.templates import BaseAgent
19
+ from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent
20
+ from ai_data_science_team.utils.plotly import plotly_from_dict
21
+
22
+
23
+
24
class SQLDataAnalyst(BaseAgent):
    """
    SQLDataAnalyst is a multi-agent class that combines SQL database querying and data visualization capabilities.

    Parameters:
    -----------
    model:
        The language model to be used for the agents.
    sql_database_agent: SQLDatabaseAgent
        The SQL Database Agent. Its `_compiled_graph` is used as the first node of the multi-agent graph.
    data_visualization_agent: DataVisualizationAgent
        The Data Visualization Agent. Its `_compiled_graph` is used as the plotting node of the multi-agent graph.
    checkpointer: Checkpointer, optional
        Passed to `make_sql_data_analyst()` when compiling the multi-agent graph. Default: None.

    Methods:
    --------
    ainvoke_agent(user_instructions, **kwargs)
        Asynchronously invokes the SQL Data Analyst Multi-Agent with the given user instructions.
        This is a coroutine and must be awaited.
    invoke_agent(user_instructions, **kwargs)
        Invokes the SQL Data Analyst Multi-Agent with the given user instructions.
    get_data_sql()
        Returns the SQL data as a Pandas DataFrame.
    get_plotly_graph()
        Returns the Plotly graph as a Plotly object.
    get_sql_query_code(markdown=False)
        Returns the SQL query code as a string, optionally formatted as a Markdown code block.
    get_sql_database_function(markdown=False)
        Returns the SQL database function as a string, optionally formatted as a Markdown code block.
    get_data_visualization_function(markdown=False)
        Returns the data visualization function as a string, optionally formatted as a Markdown code block.

    All getters return None when the agent has not been run yet, or when the
    corresponding key is missing or empty in the response.
    """

    def __init__(
        self,
        model,
        sql_database_agent: SQLDatabaseAgent,
        data_visualization_agent: DataVisualizationAgent,
        checkpointer: Checkpointer = None,
    ):
        # Same attributes as BaseAgent.__init__, but with explicit, named parameters.
        self._params = {
            "model": model,
            "sql_database_agent": sql_database_agent,
            "data_visualization_agent": data_visualization_agent,
            "checkpointer": checkpointer,
        }
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Create or rebuild the compiled graph for the SQL Data Analyst Multi-Agent.
        Running this method resets the response to None.
        """
        self.response = None
        return make_sql_data_analyst(
            model=self._params["model"],
            # The sub-agents are wrapper objects; the graph needs their compiled graphs.
            sql_database_agent=self._params["sql_database_agent"]._compiled_graph,
            data_visualization_agent=self._params["data_visualization_agent"]._compiled_graph,
            checkpointer=self._params["checkpointer"],
        )

    def update_params(self, **kwargs):
        """
        Updates the agent's parameters (e.g. model, sql_database_agent, etc.)
        and rebuilds the compiled graph (which also resets `response` to None).
        """
        self._params.update(kwargs)
        self._compiled_graph = self._make_compiled_graph()

    async def ainvoke_agent(self, user_instructions, **kwargs):
        """
        Asynchronously invokes the SQL Data Analyst Multi-Agent.

        Parameters:
        ----------
        user_instructions: str
            The user's instructions for the combined SQL and (optionally) Data Visualization agents.
        **kwargs:
            Additional keyword arguments to pass to the compiled graph's `ainvoke` method.

        Returns:
        -------
        None. The response is stored in the `response` attribute.

        Example:
        --------
        ``` python
        analyst = SQLDataAnalyst(model=llm, sql_database_agent=sql_agent, data_visualization_agent=viz_agent)
        await analyst.ainvoke_agent("Show total sales by month as a line chart")
        analyst.get_data_sql()
        ```
        """
        # The compiled graph's `ainvoke` returns a coroutine; it must be awaited so
        # that `self.response` holds the final state rather than a pending coroutine.
        response = await self._compiled_graph.ainvoke(
            {"user_instructions": user_instructions},
            **kwargs,
        )
        self.response = response

    def invoke_agent(self, user_instructions, **kwargs):
        """
        Invokes the SQL Data Analyst Multi-Agent.

        Parameters:
        ----------
        user_instructions: str
            The user's instructions for the combined SQL and (optionally) Data Visualization agents.
        **kwargs:
            Additional keyword arguments to pass to the compiled graph's `invoke` method.

        Returns:
        -------
        None. The response is stored in the `response` attribute.

        Example:
        --------
        ``` python
        analyst = SQLDataAnalyst(model=llm, sql_database_agent=sql_agent, data_visualization_agent=viz_agent)
        analyst.invoke_agent("Show total sales by month as a line chart")
        analyst.get_plotly_graph()
        ```
        """
        response = self._compiled_graph.invoke(
            {"user_instructions": user_instructions},
            **kwargs,
        )
        self.response = response

    def _get_response_value(self, key):
        """Return `self.response[key]`, or None if there is no response or the value is missing/empty."""
        if not self.response:
            return None
        return self.response.get(key) or None

    def _get_code(self, key, language, markdown):
        """Return the code stored under `key`, optionally wrapped in a Markdown code fence."""
        code = self._get_response_value(key)
        if code is None or not markdown:
            return code
        return Markdown(f"```{language}\n{code}\n```")

    def get_data_sql(self):
        """
        Returns the SQL data as a Pandas DataFrame.
        """
        data_sql = self._get_response_value("data_sql")
        if data_sql is None:
            return None
        return pd.DataFrame(data_sql)

    def get_plotly_graph(self):
        """
        Returns the Plotly graph as a Plotly object.
        """
        plotly_graph = self._get_response_value("plotly_graph")
        if plotly_graph is None:
            return None
        return plotly_from_dict(plotly_graph)

    def get_sql_query_code(self, markdown=False):
        """
        Returns the SQL query code as a string.

        Parameters:
        ----------
        markdown: bool
            If True, returns the code as a Markdown code block for Jupyter (IPython).
            For streamlit, use `st.code()` instead.
        """
        return self._get_code("sql_query_code", "sql", markdown)

    def get_sql_database_function(self, markdown=False):
        """
        Returns the SQL database function as a string.

        Parameters:
        ----------
        markdown: bool
            If True, returns the function as a Markdown code block for Jupyter (IPython).
            For streamlit, use `st.code()` instead.
        """
        return self._get_code("sql_database_function", "python", markdown)

    def get_data_visualization_function(self, markdown=False):
        """
        Returns the data visualization function as a string.

        Parameters:
        ----------
        markdown: bool
            If True, returns the function as a Markdown code block for Jupyter (IPython).
            For streamlit, use `st.code()` instead.
        """
        return self._get_code("data_visualization_function", "python", markdown)
208
+
209
+
210
+
211
def make_sql_data_analyst(
    model,
    sql_database_agent: CompiledStateGraph,
    data_visualization_agent: CompiledStateGraph,
    checkpointer: Checkpointer = None,
):
    """
    Creates a multi-agent system that takes in user instructions, queries a SQL
    database, and returns a table or (when the user asks for one) a plot.

    - Agent 1: SQL Database Agent made with `make_sql_database_agent()`
    - Agent 2: Data Visualization Agent made with `make_data_visualization_agent()`

    Graph: START -> sql_database_agent -> route_to_visualization,
    which goes either to data_visualization_agent -> END, or directly to END.

    Parameters:
    ----------
    model:
        The language model used to decide whether a plot is required.
    sql_database_agent: CompiledStateGraph
        The SQL Database Agent made with `make_sql_database_agent()`.
    data_visualization_agent: CompiledStateGraph
        The Data Visualization Agent made with `make_data_visualization_agent()`.
    checkpointer: Checkpointer (optional)
        The checkpointer to save the state of the multi-agent system.
        Default: None

    Returns:
    -------
    CompiledStateGraph
        The compiled multi-agent system.
    """

    llm = model

    class PrimaryState(TypedDict):
        # Reducer: message updates are concatenated onto the existing sequence.
        messages: Annotated[Sequence[BaseMessage], operator.add]
        user_instructions: str
        sql_query_code: str
        sql_database_function: str
        data_sql: dict
        data_raw: dict
        plot_required: bool
        data_visualization_function: str
        plotly_graph: dict

    def route_to_visualization(state) -> Command[Literal["data_visualization_agent", "__end__"]]:
        """
        Ask the LLM whether the user wants a plot and route accordingly.

        Copies `data_sql` into `data_raw` (presumably the key the visualization
        agent reads — confirm) and records the decision in `plot_required`.
        """
        response = llm.invoke(f"Respond in 1 word ('plot' or 'table'). Is the user requesting a plot? If unknown, select 'table'. \n\n User Instructions:\n{state.get('user_instructions')}")

        # Chat models return a message with `.content`; plain LLMs may return a str.
        answer = str(getattr(response, "content", response)).strip().lower()
        # Models often reply "Plot", "plot." or "**plot**" rather than the bare word;
        # compare only the first word with surrounding punctuation/markdown removed.
        # Anything unrecognised falls back to 'table', as the prompt requests.
        first_word = answer.split()[0].strip(".,!?:;'\"`*") if answer else ""
        plot_required = first_word == "plot"
        goto = "data_visualization_agent" if plot_required else "__end__"

        return Command(
            update={
                "data_raw": state.get("data_sql"),
                "plot_required": plot_required,
            },
            goto=goto,
        )

    workflow = StateGraph(PrimaryState)

    workflow.add_node("sql_database_agent", sql_database_agent)
    workflow.add_node("route_to_visualization", route_to_visualization)
    workflow.add_node("data_visualization_agent", data_visualization_agent)

    workflow.add_edge(START, "sql_database_agent")
    workflow.add_edge("sql_database_agent", "route_to_visualization")
    # route_to_visualization has no static outgoing edge: it routes via Command(goto=...).
    workflow.add_edge("data_visualization_agent", END)

    app = workflow.compile(checkpointer=checkpointer)

    return app
286
+
@@ -0,0 +1,2 @@
1
+ # TODO: Implement the supervised data analyst agent
2
+ # https://langchain-ai.github.io/langgraph/tutorials/multi_agent/agent_supervisor/#create-agent-supervisor
@@ -4,5 +4,6 @@ from ai_data_science_team.templates.agent_templates import(
4
4
  node_func_fix_agent_code,
5
5
  node_func_explain_agent_code,
6
6
  node_func_execute_agent_from_sql_connection,
7
- create_coding_agent_graph
7
+ create_coding_agent_graph,
8
+ BaseAgent,
8
9
  )
@@ -1,15 +1,202 @@
1
1
  from langchain_core.messages import AIMessage
2
2
  from langgraph.graph import StateGraph, END
3
3
  from langgraph.types import interrupt, Command
4
+ from langgraph.graph.state import CompiledStateGraph
5
+
6
+ from langchain_core.runnables import RunnableConfig
7
+ from langgraph.pregel.types import StreamMode
4
8
 
5
9
  import pandas as pd
6
10
  import sqlalchemy as sql
7
11
 
8
- from typing import Any, Callable, Dict, Type, Optional
12
+ from typing import Any, Callable, Dict, Type, Optional, Union
9
13
 
10
14
  from ai_data_science_team.tools.parsers import PythonOutputParser
11
15
  from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top
12
16
 
17
+ from IPython.display import Image, display
18
+ import pandas as pd
19
+
20
class BaseAgent(CompiledStateGraph):
    """
    A generic base class for agents that interact with compiled state graphs.

    Provides shared functionality for handling parameters, responses, and state
    graph operations. The agent wraps a compiled graph built by
    `_make_compiled_graph()` and stored in `self._compiled_graph`. Attributes not
    found on the agent itself are looked up on that wrapped graph.

    NOTE(review): `CompiledStateGraph.__init__` is never called. Presumably the
    subclassing exists so agents can be passed where a CompiledStateGraph is
    expected — confirm. All graph work goes through `self._compiled_graph`.
    """

    def __init__(self, **params):
        """
        Initialize the agent with provided parameters.

        Parameters:
            **params: Arbitrary keyword arguments representing the agent's parameters.
        """
        self._params = params
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Subclasses should override this method to create a specific compiled graph.
        """
        raise NotImplementedError("Subclasses must implement the `_make_compiled_graph` method.")

    def update_params(self, **kwargs):
        """
        Update one or more parameters and rebuild the compiled graph.

        Parameters:
            **kwargs: Parameters to update.
        """
        self._params.update(kwargs)
        self._compiled_graph = self._make_compiled_graph()

    def __getattr__(self, name: str):
        """
        Delegate attribute access to the compiled graph if the attribute is not found.

        Parameters:
            name (str): The attribute name.

        Returns:
            Any: The attribute from the compiled graph.

        Raises:
            AttributeError: If the compiled graph has not been created yet, or it
                does not have the attribute either.
        """
        # __getattr__ only runs after normal lookup fails. Reading
        # `self._compiled_graph` here when it is not set yet would re-enter
        # __getattr__ forever. This happens, for example, before __init__
        # finishes or on instances created via __new__/copy/pickle.
        # Reading the instance __dict__ directly avoids that recursion.
        try:
            compiled_graph = self.__dict__["_compiled_graph"]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        return getattr(compiled_graph, name)

    def invoke(
        self,
        input: Union[dict[str, Any], Any],
        config: Optional[RunnableConfig] = None,
        **kwargs
    ):
        """
        Wrapper for self._compiled_graph.invoke()

        Parameters:
            input: The input data for the graph. It can be a dictionary or any other type.
            config: Optional. The configuration for the graph run.
            **kwargs: Arguments to pass to self._compiled_graph.invoke()

        Returns:
            Any: The agent's response (also stored in `self.response`).
        """
        self.response = self._compiled_graph.invoke(input=input, config=config, **kwargs)
        return self.response

    async def ainvoke(
        self,
        input: Union[dict[str, Any], Any],
        config: Optional[RunnableConfig] = None,
        **kwargs
    ):
        """
        Wrapper for self._compiled_graph.ainvoke(). This is a coroutine and must be awaited.

        Parameters:
            input: The input data for the graph. It can be a dictionary or any other type.
            config: Optional. The configuration for the graph run.
            **kwargs: Arguments to pass to self._compiled_graph.ainvoke()

        Returns:
            Any: The agent's response (also stored in `self.response`).
        """
        # Await the result so `self.response` holds the final state,
        # not a pending coroutine.
        self.response = await self._compiled_graph.ainvoke(input=input, config=config, **kwargs)
        return self.response

    def stream(
        self,
        input: Union[dict[str, Any], Any],
        config: Optional[RunnableConfig] = None,
        stream_mode: Optional[Union[StreamMode, list[StreamMode]]] = None,
        **kwargs
    ):
        """
        Wrapper for self._compiled_graph.stream()

        Parameters:
            input: The input to the graph.
            config: The configuration to use for the run.
            stream_mode: The mode to stream output, defaults to the graph's stream_mode.
                Options include 'values', 'updates', and 'debug'.
                values: Emit the current values of the state for each step.
                updates: Emit only the updates to the state for each step.
                    Output is a dict with the node name as key and the updated values as value.
                debug: Emit debug events for each step.
            **kwargs: Arguments to pass to self._compiled_graph.stream()

        Returns:
            The iterator returned by the compiled graph's `stream()`. It is also
            stored in `self.response`, so response getters will not work on it.
        """
        self.response = self._compiled_graph.stream(input=input, config=config, stream_mode=stream_mode, **kwargs)
        return self.response

    def astream(
        self,
        input: Union[dict[str, Any], Any],
        config: Optional[RunnableConfig] = None,
        stream_mode: Optional[Union[StreamMode, list[StreamMode]]] = None,
        **kwargs
    ):
        """
        Wrapper for self._compiled_graph.astream()

        Parameters:
            input: The input to the graph.
            config: The configuration to use for the run.
            stream_mode: The mode to stream output, defaults to the graph's stream_mode.
                Options include 'values', 'updates', and 'debug'.
                values: Emit the current values of the state for each step.
                updates: Emit only the updates to the state for each step.
                    Output is a dict with the node name as key and the updated values as value.
                debug: Emit debug events for each step.
            **kwargs: Arguments to pass to self._compiled_graph.astream()

        Returns:
            The async iterator returned by the compiled graph's `astream()` (use
            with `async for`). It is also stored in `self.response`.
        """
        self.response = self._compiled_graph.astream(input=input, config=config, stream_mode=stream_mode, **kwargs)
        return self.response

    def get_state_keys(self):
        """
        Returns a list of keys that the state graph response contains.

        Returns:
            list: A list of keys in the response.
        """
        return list(self.get_state_properties().keys())

    def get_state_properties(self):
        """
        Returns detailed properties of the state graph response.

        Returns:
            dict: The properties of the response (from the output JSON schema).
        """
        # Query the wrapped graph directly instead of running the inherited
        # method against this (uninitialised) CompiledStateGraph subclass.
        return self._compiled_graph.get_output_jsonschema()['properties']

    def get_response(self):
        """
        Returns the response generated by the agent.

        Returns:
            Any: The agent's response.
        """
        return self.response

    def show(self, xray: int = 0):
        """
        Displays the agent's state graph as a Mermaid diagram.

        Parameters:
            xray (int): If set to 1, displays subgraph levels. Defaults to 0.
        """
        display(Image(self._compiled_graph.get_graph(xray=xray).draw_mermaid_png()))
196
+
197
+
198
+
199
+
13
200
  def create_coding_agent_graph(
14
201
  GraphState: Type,
15
202
  node_functions: Dict[str, Callable],
@@ -79,35 +266,37 @@ def create_coding_agent_graph(
79
266
 
80
267
  workflow = StateGraph(GraphState)
81
268
 
82
- # Conditionally add the recommended-steps node
83
- if not bypass_recommended_steps:
84
- workflow.add_node(recommended_steps_node_name, node_functions[recommended_steps_node_name])
269
+ # * NODES
85
270
 
86
271
  # Always add create, execute, and fix nodes
87
272
  workflow.add_node(create_code_node_name, node_functions[create_code_node_name])
88
273
  workflow.add_node(execute_code_node_name, node_functions[execute_code_node_name])
89
274
  workflow.add_node(fix_code_node_name, node_functions[fix_code_node_name])
90
275
 
276
+ # Conditionally add the recommended-steps node
277
+ if not bypass_recommended_steps:
278
+ workflow.add_node(recommended_steps_node_name, node_functions[recommended_steps_node_name])
279
+
280
+ # Conditionally add the human review node
281
+ if human_in_the_loop:
282
+ workflow.add_node(human_review_node_name, node_functions[human_review_node_name])
283
+
91
284
  # Conditionally add the explanation node
92
285
  if not bypass_explain_code:
93
286
  workflow.add_node(explain_code_node_name, node_functions[explain_code_node_name])
94
287
 
288
+ # * EDGES
289
+
95
290
  # Set the entry point
96
291
  entry_point = create_code_node_name if bypass_recommended_steps else recommended_steps_node_name
292
+
97
293
  workflow.set_entry_point(entry_point)
98
294
 
99
- # Add edges for recommended steps
100
295
  if not bypass_recommended_steps:
101
- if human_in_the_loop:
102
- workflow.add_edge(recommended_steps_node_name, human_review_node_name)
103
- else:
104
- workflow.add_edge(recommended_steps_node_name, create_code_node_name)
105
- elif human_in_the_loop:
106
- # Skip recommended steps but still include human review
107
- workflow.add_edge(create_code_node_name, human_review_node_name)
296
+ workflow.add_edge(recommended_steps_node_name, create_code_node_name)
108
297
 
109
- # Create -> Execute
110
298
  workflow.add_edge(create_code_node_name, execute_code_node_name)
299
+ workflow.add_edge(fix_code_node_name, execute_code_node_name)
111
300
 
112
301
  # Define a helper to check if we have an error & can still retry
113
302
  def error_and_can_retry(state):
@@ -117,39 +306,43 @@ def create_coding_agent_graph(
117
306
  and state.get(max_retries_key) is not None
118
307
  and state[retry_count_key] < state[max_retries_key]
119
308
  )
120
-
121
- # ---- Split into two branches for bypass_explain_code ----
122
- if not bypass_explain_code:
123
- # If we are NOT bypassing explain, the next node is fix_code if error,
124
- # else explain_code. Then we wire explain_code -> END afterward.
309
+
310
+ # If human in the loop, add a branch for human review
311
+ if human_in_the_loop:
125
312
  workflow.add_conditional_edges(
126
313
  execute_code_node_name,
127
- lambda s: "fix_code" if error_and_can_retry(s) else "explain_code",
314
+ lambda s: "fix_code" if error_and_can_retry(s) else "human_review",
128
315
  {
316
+ "human_review": human_review_node_name,
129
317
  "fix_code": fix_code_node_name,
130
- "explain_code": explain_code_node_name,
131
318
  },
132
319
  )
133
- # Fix code -> Execute again
134
- workflow.add_edge(fix_code_node_name, execute_code_node_name)
135
- # explain_code -> END
136
- workflow.add_edge(explain_code_node_name, END)
137
320
  else:
138
- # If we ARE bypassing explain_code, the next node is fix_code if error,
139
- # else straight to END.
140
- workflow.add_conditional_edges(
141
- execute_code_node_name,
142
- lambda s: "fix_code" if error_and_can_retry(s) else "END",
143
- {
144
- "fix_code": fix_code_node_name,
145
- "END": END,
146
- },
147
- )
148
- # Fix code -> Execute again
149
- workflow.add_edge(fix_code_node_name, execute_code_node_name)
321
+ # If no human review, the next node is fix_code if error, else explain_code.
322
+ if not bypass_explain_code:
323
+ workflow.add_conditional_edges(
324
+ execute_code_node_name,
325
+ lambda s: "fix_code" if error_and_can_retry(s) else "explain_code",
326
+ {
327
+ "fix_code": fix_code_node_name,
328
+ "explain_code": explain_code_node_name,
329
+ },
330
+ )
331
+ else:
332
+ workflow.add_conditional_edges(
333
+ execute_code_node_name,
334
+ lambda s: "fix_code" if error_and_can_retry(s) else "END",
335
+ {
336
+ "fix_code": fix_code_node_name,
337
+ "END": END,
338
+ },
339
+ )
340
+
341
+ if not bypass_explain_code:
342
+ workflow.add_edge(explain_code_node_name, END)
150
343
 
151
344
  # Finally, compile
152
- if human_in_the_loop and checkpointer is not None:
345
+ if human_in_the_loop:
153
346
  app = workflow.compile(checkpointer=checkpointer)
154
347
  else:
155
348
  app = workflow.compile()
@@ -165,6 +358,8 @@ def node_func_human_review(
165
358
  no_goto: str,
166
359
  user_instructions_key: str = "user_instructions",
167
360
  recommended_steps_key: str = "recommended_steps",
361
+ code_snippet_key: str = "code_snippet",
362
+ code_type: str = "python"
168
363
  ) -> Command[str]:
169
364
  """
170
365
  A generic function to handle human review steps.
@@ -183,6 +378,10 @@ def node_func_human_review(
183
378
  The key in the state to store user instructions.
184
379
  recommended_steps_key : str, optional
185
380
  The key in the state to store recommended steps.
381
+ code_snippet_key : str, optional
382
+ The key in the state to store the code snippet.
383
+ code_type : str, optional
384
+ The type of code snippet to display (e.g., "python").
186
385
 
187
386
  Returns
188
387
  -------
@@ -190,9 +389,11 @@ def node_func_human_review(
190
389
  A Command object directing the next state and updates to the state.
191
390
  """
192
391
  print(" * HUMAN REVIEW")
392
+
393
+ code_markdown=f"```{code_type}\n" + state.get(code_snippet_key)+"\n```"
193
394
 
194
395
  # Display instructions and get user response
195
- user_input = interrupt(value=prompt_text.format(steps=state.get(recommended_steps_key, '')))
396
+ user_input = interrupt(value=prompt_text.format(steps=state.get(recommended_steps_key, '') + "\n\n" + code_markdown))
196
397
 
197
398
  # Decide next steps based on user input
198
399
  if user_input.strip().lower() == "yes":
@@ -200,11 +401,11 @@ def node_func_human_review(
200
401
  update = {}
201
402
  else:
202
403
  goto = no_goto
203
- modifications = "Modifications: \n" + user_input
404
+ modifications = "User Has Requested Modifications To Previous Code: \n" + user_input
204
405
  if state.get(user_instructions_key) is None:
205
- update = {user_instructions_key: modifications}
406
+ update = {user_instructions_key: modifications + "\n\nPrevious Code:\n" + code_markdown}
206
407
  else:
207
- update = {user_instructions_key: state.get(user_instructions_key) + modifications}
408
+ update = {user_instructions_key: state.get(user_instructions_key) + modifications + "\n\nPrevious Code:\n" + code_markdown}
208
409
 
209
410
  return Command(goto=goto, update=update)
210
411
 
@@ -394,6 +595,7 @@ def node_func_fix_agent_code(
394
595
  retry_count_key: str = "retry_count",
395
596
  log: bool = False,
396
597
  file_path: str = "logs/agent_function.py",
598
+ function_name: str = "agent_function"
397
599
  ) -> dict:
398
600
  """
399
601
  Generic function to fix a given piece of agent code using an LLM and a prompt template.
@@ -420,6 +622,8 @@ def node_func_fix_agent_code(
420
622
  Whether to log the returned code to a file.
421
623
  file_path : str, optional
422
624
  The path to the file where the code will be logged.
625
+ function_name : str, optional
626
+ The name of the function in the code snippet that will be fixed.
423
627
 
424
628
  Returns
425
629
  -------
@@ -436,7 +640,8 @@ def node_func_fix_agent_code(
436
640
  # Format the prompt with the code snippet and the error
437
641
  prompt = prompt_template.format(
438
642
  code_snippet=code_snippet,
439
- error=error_message
643
+ error=error_message,
644
+ function_name=function_name,
440
645
  )
441
646
 
442
647
  # Execute the prompt with the LLM