ai-data-science-team 0.0.0.9006__py3-none-any.whl → 0.0.0.9008__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,286 @@
1
+
2
+ from langchain_core.messages import BaseMessage
3
+ from langgraph.checkpoint.memory import MemorySaver
4
+ from langgraph.types import Checkpointer
5
+
6
+ from langgraph.graph import START, END, StateGraph
7
+ from langgraph.graph.state import CompiledStateGraph
8
+ from langgraph.types import Command
9
+
10
+ from typing import TypedDict, Annotated, Sequence
11
+ import operator
12
+
13
+ from typing_extensions import TypedDict, Literal
14
+
15
+ import pandas as pd
16
+ from IPython.display import Markdown
17
+
18
+ from ai_data_science_team.templates import BaseAgent
19
+ from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent
20
+ from ai_data_science_team.utils.plotly import plotly_from_dict
21
+
22
+
23
+
24
class SQLDataAnalyst(BaseAgent):
    """
    SQLDataAnalyst is a multi-agent class that combines SQL database querying and data visualization capabilities.

    Parameters:
    -----------
    model:
        The language model to be used for the agents.
    sql_database_agent: SQLDatabaseAgent
        The SQL Database Agent.
    data_visualization_agent: DataVisualizationAgent
        The Data Visualization Agent.
    checkpointer: Checkpointer (optional)
        The checkpointer passed to the compiled multi-agent graph. Default: None.

    Methods:
    --------
    update_params(**kwargs)
        Updates the agent's parameters and rebuilds the compiled graph.
    ainvoke_agent(user_instructions, **kwargs)
        Asynchronously invokes the SQL Data Analyst Multi-Agent with the given user instructions.
    invoke_agent(user_instructions, **kwargs)
        Invokes the SQL Data Analyst Multi-Agent with the given user instructions.
    get_data_sql()
        Returns the SQL data as a Pandas DataFrame.
    get_plotly_graph()
        Returns the Plotly graph as a Plotly object.
    get_sql_query_code(markdown=False)
        Returns the SQL query code as a string, optionally formatted as a Markdown code block.
    get_sql_database_function(markdown=False)
        Returns the SQL database function as a string, optionally formatted as a Markdown code block.
    get_data_visualization_function(markdown=False)
        Returns the data visualization function as a string, optionally formatted as a Markdown code block.
    """

    def __init__(
        self,
        model,
        sql_database_agent: SQLDatabaseAgent,
        data_visualization_agent: DataVisualizationAgent,
        checkpointer: Checkpointer = None,
    ):
        self._params = {
            "model": model,
            "sql_database_agent": sql_database_agent,
            "data_visualization_agent": data_visualization_agent,
            "checkpointer": checkpointer,
        }
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Create or rebuild the compiled graph for the SQL Data Analyst Multi-Agent.
        Running this method resets the response to None.
        """
        self.response = None
        # The multi-agent graph is wired from the sub-agents' compiled graphs,
        # not from the agent wrapper objects themselves.
        return make_sql_data_analyst(
            model=self._params["model"],
            sql_database_agent=self._params["sql_database_agent"]._compiled_graph,
            data_visualization_agent=self._params["data_visualization_agent"]._compiled_graph,
            checkpointer=self._params["checkpointer"],
        )

    def update_params(self, **kwargs):
        """
        Updates the agent's parameters (e.g. model, sql_database_agent, etc.)
        and rebuilds the compiled graph.
        """
        self._params.update(kwargs)
        self._compiled_graph = self._make_compiled_graph()

    async def ainvoke_agent(self, user_instructions, **kwargs):
        """
        Asynchronously invokes the SQL Data Analyst Multi-Agent.

        This is a coroutine: it must be awaited. The awaited result of the
        compiled graph's `ainvoke` is stored in `response`, so the `get_*`
        accessors work once the call has completed.

        Parameters:
        ----------
        user_instructions: str
            The user's instructions for the combined SQL and (optionally) Data Visualization agents.
        **kwargs:
            Additional keyword arguments to pass to the compiled graph's `ainvoke` method.

        Returns:
        -------
        None. The response is stored in the `response` attribute.

        Example:
        --------
        ``` python
        # `sql_data_analyst` is an already constructed SQLDataAnalyst
        await sql_data_analyst.ainvoke_agent("Plot the total sales by month.")
        sql_data_analyst.get_plotly_graph()
        ```
        """
        # Await the coroutine; storing it un-awaited would leave `response`
        # holding a coroutine object that the getters cannot read.
        self.response = await self._compiled_graph.ainvoke(
            {"user_instructions": user_instructions},
            **kwargs,
        )

    def invoke_agent(self, user_instructions, **kwargs):
        """
        Invokes the SQL Data Analyst Multi-Agent.

        Parameters:
        ----------
        user_instructions: str
            The user's instructions for the combined SQL and (optionally) Data Visualization agents.
        **kwargs:
            Additional keyword arguments to pass to the compiled graph's `invoke` method.

        Returns:
        -------
        None. The response is stored in the `response` attribute.

        Example:
        --------
        ``` python
        # `sql_data_analyst` is an already constructed SQLDataAnalyst
        sql_data_analyst.invoke_agent("Show the top 10 customers as a table.")
        sql_data_analyst.get_data_sql()
        ```
        """
        self.response = self._compiled_graph.invoke(
            {"user_instructions": user_instructions},
            **kwargs,
        )

    def _get_response_value(self, key):
        """
        Returns `response[key]` when a response exists and the value is truthy, otherwise None.
        """
        if not self.response:
            return None
        return self.response.get(key) or None

    def _get_code(self, key, language, markdown):
        """
        Returns the code stored under `key`, optionally wrapped in a Markdown
        code fence tagged with `language`. Returns None when the key is missing or empty.
        """
        code = self._get_response_value(key)
        if code is None:
            return None
        if markdown:
            return Markdown(f"```{language}\n{code}\n```")
        return code

    def get_data_sql(self):
        """
        Returns the SQL data as a Pandas DataFrame.

        Returns None when there is no response or no `data_sql` in it.
        """
        data_sql = self._get_response_value("data_sql")
        if data_sql is None:
            return None
        return pd.DataFrame(data_sql)

    def get_plotly_graph(self):
        """
        Returns the Plotly graph as a Plotly object.

        Returns None when there is no response or no `plotly_graph` in it.
        """
        plotly_graph = self._get_response_value("plotly_graph")
        if plotly_graph is None:
            return None
        return plotly_from_dict(plotly_graph)

    def get_sql_query_code(self, markdown=False):
        """
        Returns the SQL query code as a string.

        Parameters:
        ----------
        markdown: bool
            If True, returns the code as a Markdown code block for Jupyter (IPython).
            For streamlit, use `st.code()` instead.
        """
        return self._get_code("sql_query_code", "sql", markdown)

    def get_sql_database_function(self, markdown=False):
        """
        Returns the SQL database function as a string.

        Parameters:
        ----------
        markdown: bool
            If True, returns the function as a Markdown code block for Jupyter (IPython).
            For streamlit, use `st.code()` instead.
        """
        return self._get_code("sql_database_function", "python", markdown)

    def get_data_visualization_function(self, markdown=False):
        """
        Returns the data visualization function as a string.

        Parameters:
        ----------
        markdown: bool
            If True, returns the function as a Markdown code block for Jupyter (IPython).
            For streamlit, use `st.code()` instead.
        """
        return self._get_code("data_visualization_function", "python", markdown)
208
+
209
+
210
+
211
def make_sql_data_analyst(
    model,
    sql_database_agent: CompiledStateGraph,
    data_visualization_agent: CompiledStateGraph,
    checkpointer: Checkpointer = None
):
    """
    Creates a multi-agent system that takes in user instructions, queries a SQL
    database, and returns a table or (when the user asks for one) a plot.

    - Agent 1: SQL Database Agent made with `make_sql_database_agent()`
    - Agent 2: Data Visualization Agent made with `make_data_visualization_agent()`

    The graph always runs the SQL agent first. A routing node then asks the
    model whether the user requested a plot; if so the data visualization agent
    runs, otherwise the graph ends with the SQL result.

    Parameters:
    ----------
    model:
        The language model to be used for the agents.
    sql_database_agent: CompiledStateGraph
        The SQL Database Agent made with `make_sql_database_agent()`.
    data_visualization_agent: CompiledStateGraph
        The Data Visualization Agent made with `make_data_visualization_agent()`.
    checkpointer: Checkpointer (optional)
        The checkpointer to save the state of the multi-agent system.
        Default: None

    Returns:
    -------
    CompiledStateGraph
        The compiled multi-agent system.
    """

    llm = model

    class PrimaryState(TypedDict):
        messages: Annotated[Sequence[BaseMessage], operator.add]
        user_instructions: str
        sql_query_code: str
        sql_database_function: str
        data_sql: dict
        data_raw: dict
        plot_required: bool
        data_visualization_function: str
        plotly_graph: dict

    def route_to_visualization(state) -> Command[Literal["data_visualization_agent", "__end__"]]:
        """
        Asks the model whether a plot was requested and routes accordingly.
        Copies `data_sql` into `data_raw` and records the decision in `plot_required`.
        """
        response = llm.invoke(f"Respond in 1 word ('plot' or 'table'). Is the user requesting a plot? If unknown, select 'table'. \n\n User Instructions:\n{state.get('user_instructions')}")

        # Chat models return a message with `.content`; fall back to the raw
        # value for models that return a plain string.
        content = getattr(response, "content", response)

        # Models frequently answer "Plot", "plot." or "'plot'"; normalise case,
        # surrounding whitespace, quotes and trailing punctuation so those
        # answers are not silently treated as "table".
        answer = str(content).strip().strip("'\"`.!").strip().lower()

        plot_required = answer == "plot"
        goto = "data_visualization_agent" if plot_required else "__end__"

        return Command(
            update={
                'data_raw': state.get("data_sql"),
                'plot_required': plot_required,
            },
            goto=goto
        )

    workflow = StateGraph(PrimaryState)

    workflow.add_node("sql_database_agent", sql_database_agent)
    workflow.add_node("route_to_visualization", route_to_visualization)
    workflow.add_node("data_visualization_agent", data_visualization_agent)

    # The router has no outgoing static edge: it chooses its successor at
    # runtime through the `goto` of the Command it returns.
    workflow.add_edge(START, "sql_database_agent")
    workflow.add_edge("sql_database_agent", "route_to_visualization")
    workflow.add_edge("data_visualization_agent", END)

    app = workflow.compile(checkpointer=checkpointer)

    return app
286
+
@@ -0,0 +1,2 @@
1
+ # TODO: Implement the supervised data analyst agent
2
+ # https://langchain-ai.github.io/langgraph/tutorials/multi_agent/agent_supervisor/#create-agent-supervisor
@@ -0,0 +1,9 @@
1
+ from ai_data_science_team.templates.agent_templates import(
2
+ node_func_execute_agent_code_on_data,
3
+ node_func_human_review,
4
+ node_func_fix_agent_code,
5
+ node_func_explain_agent_code,
6
+ node_func_execute_agent_from_sql_connection,
7
+ create_coding_agent_graph,
8
+ BaseAgent,
9
+ )
@@ -1,15 +1,202 @@
1
1
  from langchain_core.messages import AIMessage
2
2
  from langgraph.graph import StateGraph, END
3
3
  from langgraph.types import interrupt, Command
4
+ from langgraph.graph.state import CompiledStateGraph
5
+
6
+ from langchain_core.runnables import RunnableConfig
7
+ from langgraph.pregel.types import StreamMode
4
8
 
5
9
  import pandas as pd
6
10
  import sqlalchemy as sql
7
11
 
8
- from typing import Any, Callable, Dict, Type, Optional
12
+ from typing import Any, Callable, Dict, Type, Optional, Union
9
13
 
10
14
  from ai_data_science_team.tools.parsers import PythonOutputParser
11
15
  from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top
12
16
 
17
+ from IPython.display import Image, display
18
+ import pandas as pd
19
+
20
class BaseAgent(CompiledStateGraph):
    """
    A generic base class for agents that interact with compiled state graphs.

    Provides shared functionality for handling parameters, responses, and state
    graph operations. The compiled graph is held in `self._compiled_graph`;
    attributes not found on the agent itself are delegated to it.
    """

    def __init__(self, **params):
        """
        Initialize the agent with provided parameters.

        Parameters:
            **params: Arbitrary keyword arguments representing the agent's parameters.
        """
        self._params = params
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Subclasses should override this method to create a specific compiled graph.
        """
        raise NotImplementedError("Subclasses must implement the `_make_compiled_graph` method.")

    def update_params(self, **kwargs):
        """
        Update one or more parameters and rebuild the compiled graph.

        Parameters:
            **kwargs: Parameters to update.
        """
        self._params.update(kwargs)
        self._compiled_graph = self._make_compiled_graph()

    def __getattr__(self, name: str):
        """
        Delegate attribute access to the compiled graph if the attribute is not found.

        Parameters:
            name (str): The attribute name.

        Returns:
            Any: The attribute from the compiled graph.

        Raises:
            AttributeError: If the compiled graph has not been created yet, or
                if it does not have the requested attribute.
        """
        # Python only calls __getattr__ after normal lookup fails. Reading
        # `self._compiled_graph` here would re-enter __getattr__ whenever the
        # graph is not set yet (during __init__, or after copy/unpickle) and
        # recurse until RecursionError. Look in the instance dict instead.
        compiled_graph = self.__dict__.get("_compiled_graph")
        if compiled_graph is None:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(compiled_graph, name)

    def invoke(
        self,
        input: Union[dict[str, Any], Any],
        config: Optional[RunnableConfig] = None,
        **kwargs
    ):
        """
        Wrapper for self._compiled_graph.invoke()

        Parameters:
            input: The input data for the graph. It can be a dictionary or any other type.
            config: Optional. The configuration for the graph run.
            **kwargs: Arguments to pass to self._compiled_graph.invoke()

        Returns:
            Any: The agent's response.
        """
        self.response = self._compiled_graph.invoke(input=input, config=config, **kwargs)
        return self.response

    async def ainvoke(
        self,
        input: Union[dict[str, Any], Any],
        config: Optional[RunnableConfig] = None,
        **kwargs
    ):
        """
        Wrapper for self._compiled_graph.ainvoke()

        This is a coroutine and must be awaited. The awaited result (not the
        coroutine object) is stored in `self.response`.

        Parameters:
            input: The input data for the graph. It can be a dictionary or any other type.
            config: Optional. The configuration for the graph run.
            **kwargs: Arguments to pass to self._compiled_graph.ainvoke()

        Returns:
            Any: The agent's response.
        """
        self.response = await self._compiled_graph.ainvoke(input=input, config=config, **kwargs)
        return self.response

    def stream(
        self,
        input: Union[dict[str, Any], Any],
        config: Optional[RunnableConfig] = None,
        stream_mode: Optional[Union[StreamMode, list[StreamMode]]] = None,
        **kwargs
    ):
        """
        Wrapper for self._compiled_graph.stream()

        Parameters:
            input: The input to the graph.
            config: The configuration to use for the run.
            stream_mode: The mode to stream output, defaults to self.stream_mode.
                Options are 'values', 'updates', and 'debug'.
                values: Emit the current values of the state for each step.
                updates: Emit only the updates to the state for each step.
                    Output is a dict with the node name as key and the updated values as value.
                debug: Emit debug events for each step.
            **kwargs: Arguments to pass to self._compiled_graph.stream()

        Returns:
            The object returned by self._compiled_graph.stream(). The same
            object is stored in `self.response`; it is not consumed here.
        """
        self.response = self._compiled_graph.stream(input=input, config=config, stream_mode=stream_mode, **kwargs)
        return self.response

    def astream(
        self,
        input: Union[dict[str, Any], Any],
        config: Optional[RunnableConfig] = None,
        stream_mode: Optional[Union[StreamMode, list[StreamMode]]] = None,
        **kwargs
    ):
        """
        Wrapper for self._compiled_graph.astream()

        Parameters:
            input: The input to the graph.
            config: The configuration to use for the run.
            stream_mode: The mode to stream output, defaults to self.stream_mode.
                Options are 'values', 'updates', and 'debug'.
                values: Emit the current values of the state for each step.
                updates: Emit only the updates to the state for each step.
                    Output is a dict with the node name as key and the updated values as value.
                debug: Emit debug events for each step.
            **kwargs: Arguments to pass to self._compiled_graph.astream()

        Returns:
            The object returned by self._compiled_graph.astream(). The same
            object is stored in `self.response`; it is not consumed here.
        """
        self.response = self._compiled_graph.astream(input=input, config=config, stream_mode=stream_mode, **kwargs)
        return self.response

    def get_state_keys(self):
        """
        Returns a list of keys that the state graph response contains.

        Returns:
            list: A list of keys in the response.
        """
        return list(self.get_output_jsonschema()['properties'].keys())

    def get_state_properties(self):
        """
        Returns detailed properties of the state graph response.

        Returns:
            dict: The properties of the response.
        """
        return self.get_output_jsonschema()['properties']

    def get_response(self):
        """
        Returns the response generated by the agent.

        Returns:
            Any: The agent's response.
        """
        return self.response

    def show(self, xray: int = 0):
        """
        Displays the agent's state graph as a Mermaid diagram.

        Parameters:
            xray (int): If set to 1, displays subgraph levels. Defaults to 0.
        """
        display(Image(self.get_graph(xray=xray).draw_mermaid_png()))
196
+
197
+
198
+
199
+
13
200
  def create_coding_agent_graph(
14
201
  GraphState: Type,
15
202
  node_functions: Dict[str, Callable],
@@ -79,35 +266,37 @@ def create_coding_agent_graph(
79
266
 
80
267
  workflow = StateGraph(GraphState)
81
268
 
82
- # Conditionally add the recommended-steps node
83
- if not bypass_recommended_steps:
84
- workflow.add_node(recommended_steps_node_name, node_functions[recommended_steps_node_name])
269
+ # * NODES
85
270
 
86
271
  # Always add create, execute, and fix nodes
87
272
  workflow.add_node(create_code_node_name, node_functions[create_code_node_name])
88
273
  workflow.add_node(execute_code_node_name, node_functions[execute_code_node_name])
89
274
  workflow.add_node(fix_code_node_name, node_functions[fix_code_node_name])
90
275
 
276
+ # Conditionally add the recommended-steps node
277
+ if not bypass_recommended_steps:
278
+ workflow.add_node(recommended_steps_node_name, node_functions[recommended_steps_node_name])
279
+
280
+ # Conditionally add the human review node
281
+ if human_in_the_loop:
282
+ workflow.add_node(human_review_node_name, node_functions[human_review_node_name])
283
+
91
284
  # Conditionally add the explanation node
92
285
  if not bypass_explain_code:
93
286
  workflow.add_node(explain_code_node_name, node_functions[explain_code_node_name])
94
287
 
288
+ # * EDGES
289
+
95
290
  # Set the entry point
96
291
  entry_point = create_code_node_name if bypass_recommended_steps else recommended_steps_node_name
292
+
97
293
  workflow.set_entry_point(entry_point)
98
294
 
99
- # Add edges for recommended steps
100
295
  if not bypass_recommended_steps:
101
- if human_in_the_loop:
102
- workflow.add_edge(recommended_steps_node_name, human_review_node_name)
103
- else:
104
- workflow.add_edge(recommended_steps_node_name, create_code_node_name)
105
- elif human_in_the_loop:
106
- # Skip recommended steps but still include human review
107
- workflow.add_edge(create_code_node_name, human_review_node_name)
296
+ workflow.add_edge(recommended_steps_node_name, create_code_node_name)
108
297
 
109
- # Create -> Execute
110
298
  workflow.add_edge(create_code_node_name, execute_code_node_name)
299
+ workflow.add_edge(fix_code_node_name, execute_code_node_name)
111
300
 
112
301
  # Define a helper to check if we have an error & can still retry
113
302
  def error_and_can_retry(state):
@@ -117,39 +306,43 @@ def create_coding_agent_graph(
117
306
  and state.get(max_retries_key) is not None
118
307
  and state[retry_count_key] < state[max_retries_key]
119
308
  )
120
-
121
- # ---- Split into two branches for bypass_explain_code ----
122
- if not bypass_explain_code:
123
- # If we are NOT bypassing explain, the next node is fix_code if error,
124
- # else explain_code. Then we wire explain_code -> END afterward.
309
+
310
+ # If human in the loop, add a branch for human review
311
+ if human_in_the_loop:
125
312
  workflow.add_conditional_edges(
126
313
  execute_code_node_name,
127
- lambda s: "fix_code" if error_and_can_retry(s) else "explain_code",
314
+ lambda s: "fix_code" if error_and_can_retry(s) else "human_review",
128
315
  {
316
+ "human_review": human_review_node_name,
129
317
  "fix_code": fix_code_node_name,
130
- "explain_code": explain_code_node_name,
131
318
  },
132
319
  )
133
- # Fix code -> Execute again
134
- workflow.add_edge(fix_code_node_name, execute_code_node_name)
135
- # explain_code -> END
136
- workflow.add_edge(explain_code_node_name, END)
137
320
  else:
138
- # If we ARE bypassing explain_code, the next node is fix_code if error,
139
- # else straight to END.
140
- workflow.add_conditional_edges(
141
- execute_code_node_name,
142
- lambda s: "fix_code" if error_and_can_retry(s) else "END",
143
- {
144
- "fix_code": fix_code_node_name,
145
- "END": END,
146
- },
147
- )
148
- # Fix code -> Execute again
149
- workflow.add_edge(fix_code_node_name, execute_code_node_name)
321
+ # If no human review, the next node is fix_code if error, else explain_code.
322
+ if not bypass_explain_code:
323
+ workflow.add_conditional_edges(
324
+ execute_code_node_name,
325
+ lambda s: "fix_code" if error_and_can_retry(s) else "explain_code",
326
+ {
327
+ "fix_code": fix_code_node_name,
328
+ "explain_code": explain_code_node_name,
329
+ },
330
+ )
331
+ else:
332
+ workflow.add_conditional_edges(
333
+ execute_code_node_name,
334
+ lambda s: "fix_code" if error_and_can_retry(s) else "END",
335
+ {
336
+ "fix_code": fix_code_node_name,
337
+ "END": END,
338
+ },
339
+ )
340
+
341
+ if not bypass_explain_code:
342
+ workflow.add_edge(explain_code_node_name, END)
150
343
 
151
344
  # Finally, compile
152
- if human_in_the_loop and checkpointer is not None:
345
+ if human_in_the_loop:
153
346
  app = workflow.compile(checkpointer=checkpointer)
154
347
  else:
155
348
  app = workflow.compile()
@@ -165,6 +358,8 @@ def node_func_human_review(
165
358
  no_goto: str,
166
359
  user_instructions_key: str = "user_instructions",
167
360
  recommended_steps_key: str = "recommended_steps",
361
+ code_snippet_key: str = "code_snippet",
362
+ code_type: str = "python"
168
363
  ) -> Command[str]:
169
364
  """
170
365
  A generic function to handle human review steps.
@@ -183,6 +378,10 @@ def node_func_human_review(
183
378
  The key in the state to store user instructions.
184
379
  recommended_steps_key : str, optional
185
380
  The key in the state to store recommended steps.
381
+ code_snippet_key : str, optional
382
+ The key in the state to store the code snippet.
383
+ code_type : str, optional
384
+ The type of code snippet to display (e.g., "python").
186
385
 
187
386
  Returns
188
387
  -------
@@ -190,9 +389,11 @@ def node_func_human_review(
190
389
  A Command object directing the next state and updates to the state.
191
390
  """
192
391
  print(" * HUMAN REVIEW")
392
+
393
+ code_markdown=f"```{code_type}\n" + state.get(code_snippet_key)+"\n```"
193
394
 
194
395
  # Display instructions and get user response
195
- user_input = interrupt(value=prompt_text.format(steps=state.get(recommended_steps_key, '')))
396
+ user_input = interrupt(value=prompt_text.format(steps=state.get(recommended_steps_key, '') + "\n\n" + code_markdown))
196
397
 
197
398
  # Decide next steps based on user input
198
399
  if user_input.strip().lower() == "yes":
@@ -200,11 +401,11 @@ def node_func_human_review(
200
401
  update = {}
201
402
  else:
202
403
  goto = no_goto
203
- modifications = "Modifications: \n" + user_input
404
+ modifications = "User Has Requested Modifications To Previous Code: \n" + user_input
204
405
  if state.get(user_instructions_key) is None:
205
- update = {user_instructions_key: modifications}
406
+ update = {user_instructions_key: modifications + "\n\nPrevious Code:\n" + code_markdown}
206
407
  else:
207
- update = {user_instructions_key: state.get(user_instructions_key) + modifications}
408
+ update = {user_instructions_key: state.get(user_instructions_key) + modifications + "\n\nPrevious Code:\n" + code_markdown}
208
409
 
209
410
  return Command(goto=goto, update=update)
210
411
 
@@ -394,6 +595,7 @@ def node_func_fix_agent_code(
394
595
  retry_count_key: str = "retry_count",
395
596
  log: bool = False,
396
597
  file_path: str = "logs/agent_function.py",
598
+ function_name: str = "agent_function"
397
599
  ) -> dict:
398
600
  """
399
601
  Generic function to fix a given piece of agent code using an LLM and a prompt template.
@@ -420,6 +622,8 @@ def node_func_fix_agent_code(
420
622
  Whether to log the returned code to a file.
421
623
  file_path : str, optional
422
624
  The path to the file where the code will be logged.
625
+ function_name : str, optional
626
+ The name of the function in the code snippet that will be fixed.
423
627
 
424
628
  Returns
425
629
  -------
@@ -436,7 +640,8 @@ def node_func_fix_agent_code(
436
640
  # Format the prompt with the code snippet and the error
437
641
  prompt = prompt_template.format(
438
642
  code_snippet=code_snippet,
439
- error=error_message
643
+ error=error_message,
644
+ function_name=function_name,
440
645
  )
441
646
 
442
647
  # Execute the prompt with the LLM