ai-data-science-team 0.0.0.9007__py3-none-any.whl → 0.0.0.9008__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_data_science_team/_version.py +1 -1
- ai_data_science_team/agents/__init__.py +4 -4
- ai_data_science_team/agents/data_cleaning_agent.py +225 -84
- ai_data_science_team/agents/data_visualization_agent.py +460 -27
- ai_data_science_team/agents/data_wrangling_agent.py +455 -16
- ai_data_science_team/agents/feature_engineering_agent.py +429 -25
- ai_data_science_team/agents/sql_database_agent.py +367 -21
- ai_data_science_team/multiagents/__init__.py +1 -0
- ai_data_science_team/multiagents/sql_data_analyst.py +286 -0
- ai_data_science_team/multiagents/supervised_data_analyst.py +2 -0
- ai_data_science_team/templates/__init__.py +2 -1
- ai_data_science_team/templates/agent_templates.py +247 -42
- ai_data_science_team/tools/regex.py +28 -1
- ai_data_science_team/utils/__init__.py +0 -0
- ai_data_science_team/utils/plotly.py +24 -0
- {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9008.dist-info}/METADATA +76 -28
- ai_data_science_team-0.0.0.9008.dist-info/RECORD +26 -0
- ai_data_science_team-0.0.0.9007.dist-info/RECORD +0 -21
- {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9008.dist-info}/LICENSE +0 -0
- {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9008.dist-info}/WHEEL +0 -0
- {ai_data_science_team-0.0.0.9007.dist-info → ai_data_science_team-0.0.0.9008.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,286 @@
|
|
1
|
+
|
2
|
+
from langchain_core.messages import BaseMessage
|
3
|
+
from langgraph.checkpoint.memory import MemorySaver
|
4
|
+
from langgraph.types import Checkpointer
|
5
|
+
|
6
|
+
from langgraph.graph import START, END, StateGraph
|
7
|
+
from langgraph.graph.state import CompiledStateGraph
|
8
|
+
from langgraph.types import Command
|
9
|
+
|
10
|
+
from typing import TypedDict, Annotated, Sequence
|
11
|
+
import operator
|
12
|
+
|
13
|
+
from typing_extensions import TypedDict, Literal
|
14
|
+
|
15
|
+
import pandas as pd
|
16
|
+
from IPython.display import Markdown
|
17
|
+
|
18
|
+
from ai_data_science_team.templates import BaseAgent
|
19
|
+
from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent
|
20
|
+
from ai_data_science_team.utils.plotly import plotly_from_dict
|
21
|
+
|
22
|
+
|
23
|
+
|
24
|
+
class SQLDataAnalyst(BaseAgent):
    """
    SQLDataAnalyst is a multi-agent class that combines SQL database querying and data visualization capabilities.

    Parameters:
    -----------
    model:
        The language model to be used for the agents.
    sql_database_agent: SQLDatabaseAgent
        The SQL Database Agent.
    data_visualization_agent: DataVisualizationAgent
        The Data Visualization Agent.
    checkpointer: Checkpointer (optional)
        Forwarded to `make_sql_data_analyst()` when the graph is compiled.
        Default: None

    Methods:
    --------
    update_params(**kwargs)
        Updates the stored parameters and rebuilds the compiled graph.
    ainvoke_agent(user_instructions, **kwargs)
        Asynchronously invokes the SQL Data Analyst Multi-Agent with the given user instructions.
    invoke_agent(user_instructions, **kwargs)
        Invokes the SQL Data Analyst Multi-Agent with the given user instructions.
    get_data_sql()
        Returns the SQL data as a Pandas DataFrame.
    get_plotly_graph()
        Returns the Plotly graph as a Plotly object.
    get_sql_query_code(markdown=False)
        Returns the SQL query code as a string, optionally formatted as a Markdown code block.
    get_sql_database_function(markdown=False)
        Returns the SQL database function as a string, optionally formatted as a Markdown code block.
    get_data_visualization_function(markdown=False)
        Returns the data visualization function as a string, optionally formatted as a Markdown code block.
    """

    def __init__(
        self,
        model,
        sql_database_agent: SQLDatabaseAgent,
        data_visualization_agent: DataVisualizationAgent,
        checkpointer: Checkpointer = None,
    ):
        self._params = {
            "model": model,
            "sql_database_agent": sql_database_agent,
            "data_visualization_agent": data_visualization_agent,
            "checkpointer": checkpointer,
        }
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Create or rebuild the compiled graph for the SQL Data Analyst Multi-Agent.
        Running this method resets the response to None.
        """
        self.response = None
        params = self._params
        # The sub-agents are wrapper objects; the multi-agent graph is built
        # from their underlying compiled graphs.
        return make_sql_data_analyst(
            model=params["model"],
            sql_database_agent=params["sql_database_agent"]._compiled_graph,
            data_visualization_agent=params["data_visualization_agent"]._compiled_graph,
            checkpointer=params["checkpointer"],
        )

    def update_params(self, **kwargs):
        """
        Updates the agent's parameters (e.g. model, sql_database_agent, etc.)
        and rebuilds the compiled graph.

        Rebuilding the graph also resets the stored response to None.
        """
        self._params.update(kwargs)
        self._compiled_graph = self._make_compiled_graph()

    async def ainvoke_agent(self, user_instructions, **kwargs):
        """
        Asynchronously invokes the SQL Data Analyst Multi-Agent.

        Parameters:
        ----------
        user_instructions: str
            The user's instructions for the combined SQL and (optionally) Data Visualization agents.
        **kwargs:
            Additional keyword arguments to pass to the compiled graph's `ainvoke` method.

        Returns:
        -------
        None. The response is stored in the `response` attribute.

        Example:
        --------
        ``` python
        await sql_data_analyst.ainvoke_agent(user_instructions="Plot sales by month.")
        sql_data_analyst.get_data_sql()
        ```
        """
        # `ainvoke` returns a coroutine: it must be awaited so that
        # `self.response` holds the final state rather than the coroutine.
        self.response = await self._compiled_graph.ainvoke(
            {"user_instructions": user_instructions},
            **kwargs,
        )

    def invoke_agent(self, user_instructions, **kwargs):
        """
        Invokes the SQL Data Analyst Multi-Agent.

        Parameters:
        ----------
        user_instructions: str
            The user's instructions for the combined SQL and (optionally) Data Visualization agents.
        **kwargs:
            Additional keyword arguments to pass to the compiled graph's `invoke` method.

        Returns:
        -------
        None. The response is stored in the `response` attribute.

        Example:
        --------
        ``` python
        sql_data_analyst.invoke_agent(user_instructions="Plot sales by month.")
        sql_data_analyst.get_data_sql()
        ```
        """
        self.response = self._compiled_graph.invoke(
            {"user_instructions": user_instructions},
            **kwargs,
        )

    def _get_code(self, key, language, markdown):
        """
        Shared lookup for the code getters.

        Returns None when there is no response or the value under `key` is
        empty; otherwise the stored string, wrapped in a fenced Markdown
        block of the given `language` when `markdown` is True.
        """
        code = self.response.get(key) if self.response else None
        if not code:
            return None
        if markdown:
            return Markdown(f"```{language}\n{code}\n```")
        return code

    def get_data_sql(self):
        """
        Returns the SQL data as a Pandas DataFrame.

        Returns None when no response is stored or it has no `data_sql` value.
        """
        data_sql = self.response.get("data_sql") if self.response else None
        if data_sql:
            return pd.DataFrame(data_sql)

    def get_plotly_graph(self):
        """
        Returns the Plotly graph as a Plotly object.

        Returns None when no response is stored or it has no `plotly_graph` value.
        """
        plotly_graph = self.response.get("plotly_graph") if self.response else None
        if plotly_graph:
            return plotly_from_dict(plotly_graph)

    def get_sql_query_code(self, markdown=False):
        """
        Returns the SQL query code as a string.

        Parameters:
        ----------
        markdown: bool
            If True, returns the code as a Markdown code block for Jupyter (IPython).
            For streamlit, use `st.code()` instead.
        """
        return self._get_code("sql_query_code", "sql", markdown)

    def get_sql_database_function(self, markdown=False):
        """
        Returns the SQL database function as a string.

        Parameters:
        ----------
        markdown: bool
            If True, returns the function as a Markdown code block for Jupyter (IPython).
            For streamlit, use `st.code()` instead.
        """
        return self._get_code("sql_database_function", "python", markdown)

    def get_data_visualization_function(self, markdown=False):
        """
        Returns the data visualization function as a string.

        Parameters:
        ----------
        markdown: bool
            If True, returns the function as a Markdown code block for Jupyter (IPython).
            For streamlit, use `st.code()` instead.
        """
        return self._get_code("data_visualization_function", "python", markdown)
|
208
|
+
|
209
|
+
|
210
|
+
|
211
|
+
def make_sql_data_analyst(
    model,
    sql_database_agent: CompiledStateGraph,
    data_visualization_agent: CompiledStateGraph,
    checkpointer: Checkpointer = None
):
    """
    Creates a multi-agent system that takes in a SQL query and returns a plot or table.

    - Agent 1: SQL Database Agent made with `make_sql_database_agent()`
    - Agent 2: Data Visualization Agent made with `make_data_visualization_agent()`

    The graph runs the SQL agent first, then a routing node asks the model
    whether the user wants a plot. If so, the visualization agent runs;
    otherwise the graph ends.

    Parameters:
    ----------
    model:
        The language model to be used for the agents.
    sql_database_agent: CompiledStateGraph
        The SQL Database Agent made with `make_sql_database_agent()`.
    data_visualization_agent: CompiledStateGraph
        The Data Visualization Agent made with `make_data_visualization_agent()`.
    checkpointer: Checkpointer (optional)
        The checkpointer to save the state of the multi-agent system.
        Default: None

    Returns:
    -------
    CompiledStateGraph
        The compiled multi-agent system.
    """

    llm = model

    class PrimaryState(TypedDict):
        messages: Annotated[Sequence[BaseMessage], operator.add]
        user_instructions: str
        sql_query_code: str
        sql_database_function: str
        data_sql: dict
        data_raw: dict
        plot_required: bool
        data_visualization_function: str
        plotly_graph: dict

    def route_to_visualization(state) -> Command[Literal["data_visualization_agent", "__end__"]]:
        """Ask the model whether a plot is requested and route accordingly."""

        response = llm.invoke(f"Respond in 1 word ('plot' or 'table'). Is the user requesting a plot? If unknown, select 'table'. \n\n User Instructions:\n{state.get('user_instructions')}")

        # Models often reply with surrounding whitespace, quotes, punctuation
        # or different casing ("Plot", "'plot'", "plot."). Normalize before
        # comparing so those replies are not silently treated as 'table'.
        answer = str(response.content).strip().strip("'\"`.!").strip().lower()

        plot_required = answer == "plot"
        goto = "data_visualization_agent" if plot_required else "__end__"

        return Command(
            update={
                # The SQL result is copied to `data_raw` for the downstream agent.
                'data_raw': state.get("data_sql"),
                'plot_required': plot_required,
            },
            goto=goto
        )

    workflow = StateGraph(PrimaryState)

    workflow.add_node("sql_database_agent", sql_database_agent)
    workflow.add_node("route_to_visualization", route_to_visualization)
    workflow.add_node("data_visualization_agent", data_visualization_agent)

    workflow.add_edge(START, "sql_database_agent")
    workflow.add_edge("sql_database_agent", "route_to_visualization")
    # No static edge leaves the router: its returned Command picks the target.
    workflow.add_edge("data_visualization_agent", END)

    app = workflow.compile(checkpointer=checkpointer)

    return app
|
286
|
+
|
@@ -1,15 +1,202 @@
|
|
1
1
|
from langchain_core.messages import AIMessage
|
2
2
|
from langgraph.graph import StateGraph, END
|
3
3
|
from langgraph.types import interrupt, Command
|
4
|
+
from langgraph.graph.state import CompiledStateGraph
|
5
|
+
|
6
|
+
from langchain_core.runnables import RunnableConfig
|
7
|
+
from langgraph.pregel.types import StreamMode
|
4
8
|
|
5
9
|
import pandas as pd
|
6
10
|
import sqlalchemy as sql
|
7
11
|
|
8
|
-
from typing import Any, Callable, Dict, Type, Optional
|
12
|
+
from typing import Any, Callable, Dict, Type, Optional, Union
|
9
13
|
|
10
14
|
from ai_data_science_team.tools.parsers import PythonOutputParser
|
11
15
|
from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top
|
12
16
|
|
17
|
+
from IPython.display import Image, display
|
18
|
+
import pandas as pd
|
19
|
+
|
20
|
+
class BaseAgent(CompiledStateGraph):
    """
    A generic base class for agents that interact with compiled state graphs.

    Provides shared functionality for handling parameters, responses, and state
    graph operations. The wrapped graph lives in `self._compiled_graph`; the
    most recent result of an invoke/stream call is kept in `self.response`.
    """

    def __init__(self, **params):
        """
        Initialize the agent with provided parameters.

        Parameters:
            **params: Arbitrary keyword arguments representing the agent's parameters.
        """
        self._params = params
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Subclasses should override this method to create a specific compiled graph.
        """
        raise NotImplementedError("Subclasses must implement the `_make_compiled_graph` method.")

    def update_params(self, **kwargs):
        """
        Update one or more parameters and rebuild the compiled graph.

        Parameters:
            **kwargs: Parameters to update.
        """
        self._params.update(kwargs)
        self._compiled_graph = self._make_compiled_graph()

    def __getattr__(self, name: str):
        """
        Delegate attribute access to the compiled graph if the attribute is not found.

        Parameters:
            name (str): The attribute name.

        Returns:
            Any: The attribute from the compiled graph.

        Raises:
            AttributeError: If the compiled graph has not been set yet, or
                if it does not have the attribute either.
        """
        # `__getattr__` only runs after normal lookup fails. If the missing
        # name is `_compiled_graph` itself (graph not built yet), looking it
        # up again below would recurse without bound, so fail explicitly.
        if name == "_compiled_graph":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self._compiled_graph, name)

    def invoke(
        self,
        input: Union[dict[str, Any], Any],
        config: Optional[RunnableConfig] = None,
        **kwargs
    ):
        """
        Wrapper for self._compiled_graph.invoke()

        Parameters:
            input: The input data for the graph. It can be a dictionary or any other type.
            config: Optional. The configuration for the graph run.
            **kwargs: Arguments to pass to self._compiled_graph.invoke()

        Returns:
            Any: The agent's response.
        """
        self.response = self._compiled_graph.invoke(input=input, config=config, **kwargs)
        return self.response

    async def ainvoke(
        self,
        input: Union[dict[str, Any], Any],
        config: Optional[RunnableConfig] = None,
        **kwargs
    ):
        """
        Wrapper for self._compiled_graph.ainvoke()

        Parameters:
            input: The input data for the graph. It can be a dictionary or any other type.
            config: Optional. The configuration for the graph run.
            **kwargs: Arguments to pass to self._compiled_graph.ainvoke()

        Returns:
            Any: The agent's response.
        """
        # Await here so `self.response` stores the result, not a coroutine.
        # Callers that already `await agent.ainvoke(...)` are unaffected.
        self.response = await self._compiled_graph.ainvoke(input=input, config=config, **kwargs)
        return self.response

    def stream(
        self,
        input: dict[str, Any] | Any,
        config: RunnableConfig | None = None,
        stream_mode: StreamMode | list[StreamMode] | None = None,
        **kwargs
    ):
        """
        Wrapper for self._compiled_graph.stream()

        Parameters:
            input: The input to the graph.
            config: The configuration to use for the run.
            stream_mode: The mode to stream output; forwarded unchanged to the compiled graph.
                Options are 'values', 'updates', and 'debug'.
                values: Emit the current values of the state for each step.
                updates: Emit only the updates to the state for each step.
                    Output is a dict with the node name as key and the updated values as value.
                debug: Emit debug events for each step.
            **kwargs: Arguments to pass to self._compiled_graph.stream()

        Returns:
            Any: Whatever the compiled graph's `stream()` returns; it is also stored in `self.response`.
        """
        self.response = self._compiled_graph.stream(input=input, config=config, stream_mode=stream_mode, **kwargs)
        return self.response

    def astream(
        self,
        input: dict[str, Any] | Any,
        config: RunnableConfig | None = None,
        stream_mode: StreamMode | list[StreamMode] | None = None,
        **kwargs
    ):
        """
        Wrapper for self._compiled_graph.astream()

        Parameters:
            input: The input to the graph.
            config: The configuration to use for the run.
            stream_mode: The mode to stream output; forwarded unchanged to the compiled graph.
                Options are 'values', 'updates', and 'debug'.
                values: Emit the current values of the state for each step.
                updates: Emit only the updates to the state for each step.
                    Output is a dict with the node name as key and the updated values as value.
                debug: Emit debug events for each step.
            **kwargs: Arguments to pass to self._compiled_graph.astream()

        Returns:
            Any: Whatever the compiled graph's `astream()` returns; it is also stored in `self.response`.
        """
        self.response = self._compiled_graph.astream(input=input, config=config, stream_mode=stream_mode, **kwargs)
        return self.response

    def get_state_keys(self):
        """
        Returns a list of keys that the state graph response contains.

        Returns:
            list: A list of keys in the response.
        """
        # Ask the wrapped graph directly: this wrapper instance is never
        # initialised as a graph itself, so the inherited method is not
        # a reliable source for the schema.
        return list(self._compiled_graph.get_output_jsonschema()['properties'])

    def get_state_properties(self):
        """
        Returns detailed properties of the state graph response.

        Returns:
            dict: The properties of the response.
        """
        return self._compiled_graph.get_output_jsonschema()['properties']

    def get_response(self):
        """
        Returns the response generated by the agent.

        Returns:
            Any: The agent's response.
        """
        return self.response

    def show(self, xray: int = 0):
        """
        Displays the agent's state graph as a Mermaid diagram.

        Parameters:
            xray (int): If set to 1, displays subgraph levels. Defaults to 0.
        """
        graph = self._compiled_graph.get_graph(xray=xray)
        display(Image(graph.draw_mermaid_png()))
|
196
|
+
|
197
|
+
|
198
|
+
|
199
|
+
|
13
200
|
def create_coding_agent_graph(
|
14
201
|
GraphState: Type,
|
15
202
|
node_functions: Dict[str, Callable],
|
@@ -79,35 +266,37 @@ def create_coding_agent_graph(
|
|
79
266
|
|
80
267
|
workflow = StateGraph(GraphState)
|
81
268
|
|
82
|
-
#
|
83
|
-
if not bypass_recommended_steps:
|
84
|
-
workflow.add_node(recommended_steps_node_name, node_functions[recommended_steps_node_name])
|
269
|
+
# * NODES
|
85
270
|
|
86
271
|
# Always add create, execute, and fix nodes
|
87
272
|
workflow.add_node(create_code_node_name, node_functions[create_code_node_name])
|
88
273
|
workflow.add_node(execute_code_node_name, node_functions[execute_code_node_name])
|
89
274
|
workflow.add_node(fix_code_node_name, node_functions[fix_code_node_name])
|
90
275
|
|
276
|
+
# Conditionally add the recommended-steps node
|
277
|
+
if not bypass_recommended_steps:
|
278
|
+
workflow.add_node(recommended_steps_node_name, node_functions[recommended_steps_node_name])
|
279
|
+
|
280
|
+
# Conditionally add the human review node
|
281
|
+
if human_in_the_loop:
|
282
|
+
workflow.add_node(human_review_node_name, node_functions[human_review_node_name])
|
283
|
+
|
91
284
|
# Conditionally add the explanation node
|
92
285
|
if not bypass_explain_code:
|
93
286
|
workflow.add_node(explain_code_node_name, node_functions[explain_code_node_name])
|
94
287
|
|
288
|
+
# * EDGES
|
289
|
+
|
95
290
|
# Set the entry point
|
96
291
|
entry_point = create_code_node_name if bypass_recommended_steps else recommended_steps_node_name
|
292
|
+
|
97
293
|
workflow.set_entry_point(entry_point)
|
98
294
|
|
99
|
-
# Add edges for recommended steps
|
100
295
|
if not bypass_recommended_steps:
|
101
|
-
|
102
|
-
workflow.add_edge(recommended_steps_node_name, human_review_node_name)
|
103
|
-
else:
|
104
|
-
workflow.add_edge(recommended_steps_node_name, create_code_node_name)
|
105
|
-
elif human_in_the_loop:
|
106
|
-
# Skip recommended steps but still include human review
|
107
|
-
workflow.add_edge(create_code_node_name, human_review_node_name)
|
296
|
+
workflow.add_edge(recommended_steps_node_name, create_code_node_name)
|
108
297
|
|
109
|
-
# Create -> Execute
|
110
298
|
workflow.add_edge(create_code_node_name, execute_code_node_name)
|
299
|
+
workflow.add_edge(fix_code_node_name, execute_code_node_name)
|
111
300
|
|
112
301
|
# Define a helper to check if we have an error & can still retry
|
113
302
|
def error_and_can_retry(state):
|
@@ -117,39 +306,43 @@ def create_coding_agent_graph(
|
|
117
306
|
and state.get(max_retries_key) is not None
|
118
307
|
and state[retry_count_key] < state[max_retries_key]
|
119
308
|
)
|
120
|
-
|
121
|
-
#
|
122
|
-
if
|
123
|
-
# If we are NOT bypassing explain, the next node is fix_code if error,
|
124
|
-
# else explain_code. Then we wire explain_code -> END afterward.
|
309
|
+
|
310
|
+
# If human in the loop, add a branch for human review
|
311
|
+
if human_in_the_loop:
|
125
312
|
workflow.add_conditional_edges(
|
126
313
|
execute_code_node_name,
|
127
|
-
lambda s: "fix_code" if error_and_can_retry(s) else "
|
314
|
+
lambda s: "fix_code" if error_and_can_retry(s) else "human_review",
|
128
315
|
{
|
316
|
+
"human_review": human_review_node_name,
|
129
317
|
"fix_code": fix_code_node_name,
|
130
|
-
"explain_code": explain_code_node_name,
|
131
318
|
},
|
132
319
|
)
|
133
|
-
# Fix code -> Execute again
|
134
|
-
workflow.add_edge(fix_code_node_name, execute_code_node_name)
|
135
|
-
# explain_code -> END
|
136
|
-
workflow.add_edge(explain_code_node_name, END)
|
137
320
|
else:
|
138
|
-
# If
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
321
|
+
# If no human review, the next node is fix_code if error, else explain_code.
|
322
|
+
if not bypass_explain_code:
|
323
|
+
workflow.add_conditional_edges(
|
324
|
+
execute_code_node_name,
|
325
|
+
lambda s: "fix_code" if error_and_can_retry(s) else "explain_code",
|
326
|
+
{
|
327
|
+
"fix_code": fix_code_node_name,
|
328
|
+
"explain_code": explain_code_node_name,
|
329
|
+
},
|
330
|
+
)
|
331
|
+
else:
|
332
|
+
workflow.add_conditional_edges(
|
333
|
+
execute_code_node_name,
|
334
|
+
lambda s: "fix_code" if error_and_can_retry(s) else "END",
|
335
|
+
{
|
336
|
+
"fix_code": fix_code_node_name,
|
337
|
+
"END": END,
|
338
|
+
},
|
339
|
+
)
|
340
|
+
|
341
|
+
if not bypass_explain_code:
|
342
|
+
workflow.add_edge(explain_code_node_name, END)
|
150
343
|
|
151
344
|
# Finally, compile
|
152
|
-
if human_in_the_loop
|
345
|
+
if human_in_the_loop:
|
153
346
|
app = workflow.compile(checkpointer=checkpointer)
|
154
347
|
else:
|
155
348
|
app = workflow.compile()
|
@@ -165,6 +358,8 @@ def node_func_human_review(
|
|
165
358
|
no_goto: str,
|
166
359
|
user_instructions_key: str = "user_instructions",
|
167
360
|
recommended_steps_key: str = "recommended_steps",
|
361
|
+
code_snippet_key: str = "code_snippet",
|
362
|
+
code_type: str = "python"
|
168
363
|
) -> Command[str]:
|
169
364
|
"""
|
170
365
|
A generic function to handle human review steps.
|
@@ -183,6 +378,10 @@ def node_func_human_review(
|
|
183
378
|
The key in the state to store user instructions.
|
184
379
|
recommended_steps_key : str, optional
|
185
380
|
The key in the state to store recommended steps.
|
381
|
+
code_snippet_key : str, optional
|
382
|
+
The key in the state to store the code snippet.
|
383
|
+
code_type : str, optional
|
384
|
+
The type of code snippet to display (e.g., "python").
|
186
385
|
|
187
386
|
Returns
|
188
387
|
-------
|
@@ -190,9 +389,11 @@ def node_func_human_review(
|
|
190
389
|
A Command object directing the next state and updates to the state.
|
191
390
|
"""
|
192
391
|
print(" * HUMAN REVIEW")
|
392
|
+
|
393
|
+
code_markdown=f"```{code_type}\n" + state.get(code_snippet_key)+"\n```"
|
193
394
|
|
194
395
|
# Display instructions and get user response
|
195
|
-
user_input = interrupt(value=prompt_text.format(steps=state.get(recommended_steps_key, '')))
|
396
|
+
user_input = interrupt(value=prompt_text.format(steps=state.get(recommended_steps_key, '') + "\n\n" + code_markdown))
|
196
397
|
|
197
398
|
# Decide next steps based on user input
|
198
399
|
if user_input.strip().lower() == "yes":
|
@@ -200,11 +401,11 @@ def node_func_human_review(
|
|
200
401
|
update = {}
|
201
402
|
else:
|
202
403
|
goto = no_goto
|
203
|
-
modifications = "Modifications: \n" + user_input
|
404
|
+
modifications = "User Has Requested Modifications To Previous Code: \n" + user_input
|
204
405
|
if state.get(user_instructions_key) is None:
|
205
|
-
update = {user_instructions_key: modifications}
|
406
|
+
update = {user_instructions_key: modifications + "\n\nPrevious Code:\n" + code_markdown}
|
206
407
|
else:
|
207
|
-
update = {user_instructions_key: state.get(user_instructions_key) + modifications}
|
408
|
+
update = {user_instructions_key: state.get(user_instructions_key) + modifications + "\n\nPrevious Code:\n" + code_markdown}
|
208
409
|
|
209
410
|
return Command(goto=goto, update=update)
|
210
411
|
|
@@ -394,6 +595,7 @@ def node_func_fix_agent_code(
|
|
394
595
|
retry_count_key: str = "retry_count",
|
395
596
|
log: bool = False,
|
396
597
|
file_path: str = "logs/agent_function.py",
|
598
|
+
function_name: str = "agent_function"
|
397
599
|
) -> dict:
|
398
600
|
"""
|
399
601
|
Generic function to fix a given piece of agent code using an LLM and a prompt template.
|
@@ -420,6 +622,8 @@ def node_func_fix_agent_code(
|
|
420
622
|
Whether to log the returned code to a file.
|
421
623
|
file_path : str, optional
|
422
624
|
The path to the file where the code will be logged.
|
625
|
+
function_name : str, optional
|
626
|
+
The name of the function in the code snippet that will be fixed.
|
423
627
|
|
424
628
|
Returns
|
425
629
|
-------
|
@@ -436,7 +640,8 @@ def node_func_fix_agent_code(
|
|
436
640
|
# Format the prompt with the code snippet and the error
|
437
641
|
prompt = prompt_template.format(
|
438
642
|
code_snippet=code_snippet,
|
439
|
-
error=error_message
|
643
|
+
error=error_message,
|
644
|
+
function_name=function_name,
|
440
645
|
)
|
441
646
|
|
442
647
|
# Execute the prompt with the LLM
|