ai-data-science-team 0.0.0.9006__py3-none-any.whl → 0.0.0.9008__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_data_science_team/_version.py +1 -1
- ai_data_science_team/agents/__init__.py +5 -4
- ai_data_science_team/agents/data_cleaning_agent.py +371 -45
- ai_data_science_team/agents/data_visualization_agent.py +764 -0
- ai_data_science_team/agents/data_wrangling_agent.py +507 -23
- ai_data_science_team/agents/feature_engineering_agent.py +467 -34
- ai_data_science_team/agents/sql_database_agent.py +394 -30
- ai_data_science_team/multiagents/__init__.py +1 -0
- ai_data_science_team/multiagents/sql_data_analyst.py +286 -0
- ai_data_science_team/multiagents/supervised_data_analyst.py +2 -0
- ai_data_science_team/templates/__init__.py +9 -0
- ai_data_science_team/templates/agent_templates.py +247 -42
- ai_data_science_team/tools/metadata.py +110 -47
- ai_data_science_team/tools/regex.py +33 -0
- ai_data_science_team/utils/__init__.py +0 -0
- ai_data_science_team/utils/plotly.py +24 -0
- ai_data_science_team-0.0.0.9008.dist-info/METADATA +231 -0
- ai_data_science_team-0.0.0.9008.dist-info/RECORD +26 -0
- {ai_data_science_team-0.0.0.9006.dist-info → ai_data_science_team-0.0.0.9008.dist-info}/WHEEL +1 -1
- ai_data_science_team-0.0.0.9006.dist-info/METADATA +0 -165
- ai_data_science_team-0.0.0.9006.dist-info/RECORD +0 -20
- {ai_data_science_team-0.0.0.9006.dist-info → ai_data_science_team-0.0.0.9008.dist-info}/LICENSE +0 -0
- {ai_data_science_team-0.0.0.9006.dist-info → ai_data_science_team-0.0.0.9008.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,286 @@
|
|
1
|
+
|
2
|
+
from langchain_core.messages import BaseMessage
|
3
|
+
from langgraph.checkpoint.memory import MemorySaver
|
4
|
+
from langgraph.types import Checkpointer
|
5
|
+
|
6
|
+
from langgraph.graph import START, END, StateGraph
|
7
|
+
from langgraph.graph.state import CompiledStateGraph
|
8
|
+
from langgraph.types import Command
|
9
|
+
|
10
|
+
from typing import TypedDict, Annotated, Sequence
|
11
|
+
import operator
|
12
|
+
|
13
|
+
from typing_extensions import TypedDict, Literal
|
14
|
+
|
15
|
+
import pandas as pd
|
16
|
+
from IPython.display import Markdown
|
17
|
+
|
18
|
+
from ai_data_science_team.templates import BaseAgent
|
19
|
+
from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent
|
20
|
+
from ai_data_science_team.utils.plotly import plotly_from_dict
|
21
|
+
|
22
|
+
|
23
|
+
|
24
|
+
class SQLDataAnalyst(BaseAgent):
    """
    SQLDataAnalyst is a multi-agent class that combines SQL database querying and data visualization capabilities.

    Parameters:
    -----------
    model:
        The language model to be used for the agents.
    sql_database_agent: SQLDatabaseAgent
        The SQL Database Agent.
    data_visualization_agent: DataVisualizationAgent
        The Data Visualization Agent.
    checkpointer: Checkpointer (optional)
        The checkpointer passed through to the compiled multi-agent graph.
        Default: None

    Methods:
    --------
    ainvoke_agent(user_instructions, **kwargs)
        Asynchronously invokes the SQL Data Analyst Multi-Agent with the given user instructions.
        This is a coroutine and must be awaited.
    invoke_agent(user_instructions, **kwargs)
        Invokes the SQL Data Analyst Multi-Agent with the given user instructions.
    get_data_sql()
        Returns the SQL data as a Pandas DataFrame.
    get_plotly_graph()
        Returns the Plotly graph as a Plotly object.
    get_sql_query_code(markdown=False)
        Returns the SQL query code as a string, optionally formatted as a Markdown code block.
    get_sql_database_function(markdown=False)
        Returns the SQL database function as a string, optionally formatted as a Markdown code block.
    get_data_visualization_function(markdown=False)
        Returns the data visualization function as a string, optionally formatted as a Markdown code block.
    """

    def __init__(
        self,
        model,
        sql_database_agent: SQLDatabaseAgent,
        data_visualization_agent: DataVisualizationAgent,
        checkpointer: Checkpointer = None,
    ):
        self._params = {
            "model": model,
            "sql_database_agent": sql_database_agent,
            "data_visualization_agent": data_visualization_agent,
            "checkpointer": checkpointer,
        }
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Create or rebuild the compiled graph for the SQL Data Analyst Multi-Agent.
        Running this method resets the response to None.
        """
        self.response = None
        # The sub-agents are wrapper objects; the multi-agent graph is wired
        # from their underlying compiled graphs.
        return make_sql_data_analyst(
            model=self._params["model"],
            sql_database_agent=self._params["sql_database_agent"]._compiled_graph,
            data_visualization_agent=self._params["data_visualization_agent"]._compiled_graph,
            checkpointer=self._params["checkpointer"],
        )

    def update_params(self, **kwargs):
        """
        Updates the agent's parameters (e.g. model, sql_database_agent, etc.)
        and rebuilds the compiled graph.
        """
        for k, v in kwargs.items():
            self._params[k] = v
        self._compiled_graph = self._make_compiled_graph()

    async def ainvoke_agent(self, user_instructions, **kwargs):
        """
        Asynchronously invokes the SQL Data Analyst Multi-Agent.

        This is a coroutine: it must be awaited so that the graph actually
        runs and `response` holds the resulting state rather than a pending
        coroutine object.

        Parameters:
        ----------
        user_instructions: str
            The user's instructions for the combined SQL and (optionally) Data Visualization agents.
        **kwargs:
            Additional keyword arguments to pass to the compiled graph's `ainvoke` method.

        Returns:
        -------
        None. The response is stored in the `response` attribute.

        Example:
        --------
        ``` python
        await sql_data_analyst.ainvoke_agent(user_instructions="...")
        sql_data_analyst.get_data_sql()
        ```
        """
        response = await self._compiled_graph.ainvoke({
            "user_instructions": user_instructions,
        }, **kwargs)
        self.response = response

    def invoke_agent(self, user_instructions, **kwargs):
        """
        Invokes the SQL Data Analyst Multi-Agent.

        Parameters:
        ----------
        user_instructions: str
            The user's instructions for the combined SQL and (optionally) Data Visualization agents.
        **kwargs:
            Additional keyword arguments to pass to the compiled graph's `invoke` method.

        Returns:
        -------
        None. The response is stored in the `response` attribute.

        Example:
        --------
        ``` python
        sql_data_analyst.invoke_agent(user_instructions="...")
        sql_data_analyst.get_data_sql()
        ```
        """
        response = self._compiled_graph.invoke({
            "user_instructions": user_instructions,
        }, **kwargs)
        self.response = response

    def _get_response_code(self, key, code_type, markdown):
        """
        Shared lookup for the code-returning getters.

        Returns None when there is no response yet or the entry under `key`
        is missing/empty; otherwise returns the stored string, wrapped in an
        IPython `Markdown` fenced block of `code_type` when `markdown` is True.
        """
        if not self.response:
            return None
        code = self.response.get(key)
        if not code:
            return None
        if markdown:
            return Markdown(f"```{code_type}\n{code}\n```")
        return code

    def get_data_sql(self):
        """
        Returns the SQL data as a Pandas DataFrame.

        Returns None when no response is stored or it has no `data_sql` entry.
        """
        if self.response:
            if self.response.get("data_sql"):
                return pd.DataFrame(self.response.get("data_sql"))
        return None

    def get_plotly_graph(self):
        """
        Returns the Plotly graph as a Plotly object.

        Returns None when no response is stored or it has no `plotly_graph` entry.
        """
        if self.response:
            if self.response.get("plotly_graph"):
                return plotly_from_dict(self.response.get("plotly_graph"))
        return None

    def get_sql_query_code(self, markdown=False):
        """
        Returns the SQL query code as a string.

        Parameters:
        ----------
        markdown: bool
            If True, returns the code as a Markdown code block for Jupyter (IPython).
            For streamlit, use `st.code()` instead.
        """
        return self._get_response_code("sql_query_code", "sql", markdown)

    def get_sql_database_function(self, markdown=False):
        """
        Returns the SQL database function as a string.

        Parameters:
        ----------
        markdown: bool
            If True, returns the function as a Markdown code block for Jupyter (IPython).
            For streamlit, use `st.code()` instead.
        """
        return self._get_response_code("sql_database_function", "python", markdown)

    def get_data_visualization_function(self, markdown=False):
        """
        Returns the data visualization function as a string.

        Parameters:
        ----------
        markdown: bool
            If True, returns the function as a Markdown code block for Jupyter (IPython).
            For streamlit, use `st.code()` instead.
        """
        return self._get_response_code("data_visualization_function", "python", markdown)
|
208
|
+
|
209
|
+
|
210
|
+
|
211
|
+
def make_sql_data_analyst(
    model,
    sql_database_agent: CompiledStateGraph,
    data_visualization_agent: CompiledStateGraph,
    checkpointer: Checkpointer = None
):
    """
    Creates a multi-agent system that takes in the user's instructions, runs the
    SQL Database Agent, and returns either a table or (when requested) a plot.

    - Agent 1: SQL Database Agent made with `make_sql_database_agent()`
    - Agent 2: Data Visualization Agent made with `make_data_visualization_agent()`

    Graph layout: START -> sql_database_agent -> route_to_visualization, which
    then goes either to data_visualization_agent -> END or straight to END.

    Parameters:
    ----------
    model:
        The language model to be used for the agents.
    sql_database_agent: CompiledStateGraph
        The SQL Database Agent made with `make_sql_database_agent()`.
    data_visualization_agent: CompiledStateGraph
        The Data Visualization Agent made with `make_data_visualization_agent()`.
    checkpointer: Checkpointer (optional)
        The checkpointer to save the state of the multi-agent system.
        Default: None

    Returns:
    -------
    CompiledStateGraph
        The compiled multi-agent system.
    """

    llm = model

    class PrimaryState(TypedDict):
        messages: Annotated[Sequence[BaseMessage], operator.add]
        user_instructions: str
        sql_query_code: str
        sql_database_function: str
        data_sql: dict
        data_raw: dict
        plot_required: bool
        data_visualization_function: str
        plotly_graph: dict

    def route_to_visualization(state) -> Command[Literal["data_visualization_agent", "__end__"]]:
        # Ask the model for a one-word decision on whether a plot is wanted.
        response = llm.invoke(f"Respond in 1 word ('plot' or 'table'). Is the user requesting a plot? If unknown, select 'table'. \n\n User Instructions:\n{state.get('user_instructions')}")

        # Models frequently decorate a one-word answer ("Plot", "plot.",
        # "'plot'\n"). Normalise whitespace, quotes, trailing punctuation and
        # case so those replies are not silently treated as "table".
        answer = str(response.content).strip(" \t\r\n'\"`.!").lower()

        plot_required = answer == "plot"
        goto = "data_visualization_agent" if plot_required else "__end__"

        return Command(
            update={
                # Copy the SQL result into `data_raw` for the downstream node.
                'data_raw': state.get("data_sql"),
                'plot_required': plot_required,
            },
            goto=goto
        )

    workflow = StateGraph(PrimaryState)

    workflow.add_node("sql_database_agent", sql_database_agent)
    workflow.add_node("route_to_visualization", route_to_visualization)
    workflow.add_node("data_visualization_agent", data_visualization_agent)

    workflow.add_edge(START, "sql_database_agent")
    workflow.add_edge("sql_database_agent", "route_to_visualization")
    # No static edge leaves the router: its Command's `goto` picks the next node.
    workflow.add_edge("data_visualization_agent", END)

    app = workflow.compile(checkpointer=checkpointer)

    return app
|
286
|
+
|
@@ -0,0 +1,9 @@
|
|
1
|
+
from ai_data_science_team.templates.agent_templates import(
|
2
|
+
node_func_execute_agent_code_on_data,
|
3
|
+
node_func_human_review,
|
4
|
+
node_func_fix_agent_code,
|
5
|
+
node_func_explain_agent_code,
|
6
|
+
node_func_execute_agent_from_sql_connection,
|
7
|
+
create_coding_agent_graph,
|
8
|
+
BaseAgent,
|
9
|
+
)
|
@@ -1,15 +1,202 @@
|
|
1
1
|
from langchain_core.messages import AIMessage
|
2
2
|
from langgraph.graph import StateGraph, END
|
3
3
|
from langgraph.types import interrupt, Command
|
4
|
+
from langgraph.graph.state import CompiledStateGraph
|
5
|
+
|
6
|
+
from langchain_core.runnables import RunnableConfig
|
7
|
+
from langgraph.pregel.types import StreamMode
|
4
8
|
|
5
9
|
import pandas as pd
|
6
10
|
import sqlalchemy as sql
|
7
11
|
|
8
|
-
from typing import Any, Callable, Dict, Type, Optional
|
12
|
+
from typing import Any, Callable, Dict, Type, Optional, Union
|
9
13
|
|
10
14
|
from ai_data_science_team.tools.parsers import PythonOutputParser
|
11
15
|
from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top
|
12
16
|
|
17
|
+
from IPython.display import Image, display
|
18
|
+
import pandas as pd
|
19
|
+
|
20
|
+
class BaseAgent(CompiledStateGraph):
    """
    A generic base class for agents that interact with compiled state graphs.

    Provides shared functionality for handling parameters, responses, and state
    graph operations. Subclasses supply the graph by overriding
    `_make_compiled_graph`; attributes not found on the agent itself are
    looked up on that compiled graph.
    """

    def __init__(self, **params):
        """
        Initialize the agent with provided parameters.

        Parameters:
            **params: Arbitrary keyword arguments representing the agent's parameters.
        """
        self._params = params
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Subclasses should override this method to create a specific compiled graph.
        """
        raise NotImplementedError("Subclasses must implement the `_make_compiled_graph` method.")

    def update_params(self, **kwargs):
        """
        Update one or more parameters and rebuild the compiled graph.

        Parameters:
            **kwargs: Parameters to update.
        """
        self._params.update(kwargs)
        self._compiled_graph = self._make_compiled_graph()

    def __getattr__(self, name: str):
        """
        Delegate attribute access to the compiled graph if the attribute is not found.

        Parameters:
            name (str): The attribute name.

        Returns:
            Any: The attribute from the compiled graph.

        Raises:
            AttributeError: If `_compiled_graph` itself has not been set yet.
        """
        # `__getattr__` only runs when normal lookup fails. If the missing
        # attribute is `_compiled_graph` itself (instance not initialised yet,
        # or being probed by copy/pickle), delegating would re-enter this
        # method forever and end in RecursionError.
        if name == "_compiled_graph":
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '_compiled_graph'"
            )
        return getattr(self._compiled_graph, name)

    def invoke(
        self,
        input: Union[dict[str, Any], Any],
        config: Optional[RunnableConfig] = None,
        **kwargs
    ):
        """
        Wrapper for self._compiled_graph.invoke()

        Parameters:
            input: The input data for the graph. It can be a dictionary or any other type.
            config: Optional. The configuration for the graph run.
            **kwargs: Arguments to pass to self._compiled_graph.invoke()

        Returns:
            Any: The agent's response.
        """
        self.response = self._compiled_graph.invoke(input=input, config=config, **kwargs)
        return self.response

    async def ainvoke(
        self,
        input: Union[dict[str, Any], Any],
        config: Optional[RunnableConfig] = None,
        **kwargs
    ):
        """
        Wrapper for self._compiled_graph.ainvoke()

        This is a coroutine. Awaiting it runs the graph and stores the result
        in `self.response` (rather than storing the pending coroutine object).

        Parameters:
            input: The input data for the graph. It can be a dictionary or any other type.
            config: Optional. The configuration for the graph run.
            **kwargs: Arguments to pass to self._compiled_graph.ainvoke()

        Returns:
            Any: The agent's response.
        """
        self.response = await self._compiled_graph.ainvoke(input=input, config=config, **kwargs)
        return self.response

    def stream(
        self,
        input: Union[dict[str, Any], Any],
        config: Optional[RunnableConfig] = None,
        stream_mode: Optional[Union[StreamMode, list[StreamMode]]] = None,
        **kwargs
    ):
        """
        Wrapper for self._compiled_graph.stream()

        Parameters:
            input: The input to the graph.
            config: The configuration to use for the run.
            stream_mode: The mode to stream output, defaults to self.stream_mode.
                Options are 'values', 'updates', and 'debug'.
                values: Emit the current values of the state for each step.
                updates: Emit only the updates to the state for each step.
                    Output is a dict with the node name as key and the updated values as value.
                debug: Emit debug events for each step.
            **kwargs: Arguments to pass to self._compiled_graph.stream()

        Returns:
            Any: The object returned by the compiled graph's `stream()`; the
            same object is also stored in `self.response`.
        """
        self.response = self._compiled_graph.stream(input=input, config=config, stream_mode=stream_mode, **kwargs)
        return self.response

    def astream(
        self,
        input: Union[dict[str, Any], Any],
        config: Optional[RunnableConfig] = None,
        stream_mode: Optional[Union[StreamMode, list[StreamMode]]] = None,
        **kwargs
    ):
        """
        Wrapper for self._compiled_graph.astream()

        Parameters:
            input: The input to the graph.
            config: The configuration to use for the run.
            stream_mode: The mode to stream output, defaults to self.stream_mode.
                Options are 'values', 'updates', and 'debug'.
                values: Emit the current values of the state for each step.
                updates: Emit only the updates to the state for each step.
                    Output is a dict with the node name as key and the updated values as value.
                debug: Emit debug events for each step.
            **kwargs: Arguments to pass to self._compiled_graph.astream()

        Returns:
            Any: The object returned by the compiled graph's `astream()`; the
            same object is also stored in `self.response`.
        """
        self.response = self._compiled_graph.astream(input=input, config=config, stream_mode=stream_mode, **kwargs)
        return self.response

    def get_state_keys(self):
        """
        Returns a list of keys that the state graph response contains.

        Returns:
            list: A list of keys in the response.
        """
        return list(self.get_output_jsonschema()['properties'].keys())

    def get_state_properties(self):
        """
        Returns detailed properties of the state graph response.

        Returns:
            dict: The properties of the response.
        """
        return self.get_output_jsonschema()['properties']

    def get_response(self):
        """
        Returns the response generated by the agent.

        Returns:
            Any: The agent's response.
        """
        return self.response

    def show(self, xray: int = 0):
        """
        Displays the agent's state graph as a Mermaid diagram.

        Parameters:
            xray (int): If set to 1, displays subgraph levels. Defaults to 0.
        """
        display(Image(self.get_graph(xray=xray).draw_mermaid_png()))
|
196
|
+
|
197
|
+
|
198
|
+
|
199
|
+
|
13
200
|
def create_coding_agent_graph(
|
14
201
|
GraphState: Type,
|
15
202
|
node_functions: Dict[str, Callable],
|
@@ -79,35 +266,37 @@ def create_coding_agent_graph(
|
|
79
266
|
|
80
267
|
workflow = StateGraph(GraphState)
|
81
268
|
|
82
|
-
#
|
83
|
-
if not bypass_recommended_steps:
|
84
|
-
workflow.add_node(recommended_steps_node_name, node_functions[recommended_steps_node_name])
|
269
|
+
# * NODES
|
85
270
|
|
86
271
|
# Always add create, execute, and fix nodes
|
87
272
|
workflow.add_node(create_code_node_name, node_functions[create_code_node_name])
|
88
273
|
workflow.add_node(execute_code_node_name, node_functions[execute_code_node_name])
|
89
274
|
workflow.add_node(fix_code_node_name, node_functions[fix_code_node_name])
|
90
275
|
|
276
|
+
# Conditionally add the recommended-steps node
|
277
|
+
if not bypass_recommended_steps:
|
278
|
+
workflow.add_node(recommended_steps_node_name, node_functions[recommended_steps_node_name])
|
279
|
+
|
280
|
+
# Conditionally add the human review node
|
281
|
+
if human_in_the_loop:
|
282
|
+
workflow.add_node(human_review_node_name, node_functions[human_review_node_name])
|
283
|
+
|
91
284
|
# Conditionally add the explanation node
|
92
285
|
if not bypass_explain_code:
|
93
286
|
workflow.add_node(explain_code_node_name, node_functions[explain_code_node_name])
|
94
287
|
|
288
|
+
# * EDGES
|
289
|
+
|
95
290
|
# Set the entry point
|
96
291
|
entry_point = create_code_node_name if bypass_recommended_steps else recommended_steps_node_name
|
292
|
+
|
97
293
|
workflow.set_entry_point(entry_point)
|
98
294
|
|
99
|
-
# Add edges for recommended steps
|
100
295
|
if not bypass_recommended_steps:
|
101
|
-
|
102
|
-
workflow.add_edge(recommended_steps_node_name, human_review_node_name)
|
103
|
-
else:
|
104
|
-
workflow.add_edge(recommended_steps_node_name, create_code_node_name)
|
105
|
-
elif human_in_the_loop:
|
106
|
-
# Skip recommended steps but still include human review
|
107
|
-
workflow.add_edge(create_code_node_name, human_review_node_name)
|
296
|
+
workflow.add_edge(recommended_steps_node_name, create_code_node_name)
|
108
297
|
|
109
|
-
# Create -> Execute
|
110
298
|
workflow.add_edge(create_code_node_name, execute_code_node_name)
|
299
|
+
workflow.add_edge(fix_code_node_name, execute_code_node_name)
|
111
300
|
|
112
301
|
# Define a helper to check if we have an error & can still retry
|
113
302
|
def error_and_can_retry(state):
|
@@ -117,39 +306,43 @@ def create_coding_agent_graph(
|
|
117
306
|
and state.get(max_retries_key) is not None
|
118
307
|
and state[retry_count_key] < state[max_retries_key]
|
119
308
|
)
|
120
|
-
|
121
|
-
#
|
122
|
-
if
|
123
|
-
# If we are NOT bypassing explain, the next node is fix_code if error,
|
124
|
-
# else explain_code. Then we wire explain_code -> END afterward.
|
309
|
+
|
310
|
+
# If human in the loop, add a branch for human review
|
311
|
+
if human_in_the_loop:
|
125
312
|
workflow.add_conditional_edges(
|
126
313
|
execute_code_node_name,
|
127
|
-
lambda s: "fix_code" if error_and_can_retry(s) else "
|
314
|
+
lambda s: "fix_code" if error_and_can_retry(s) else "human_review",
|
128
315
|
{
|
316
|
+
"human_review": human_review_node_name,
|
129
317
|
"fix_code": fix_code_node_name,
|
130
|
-
"explain_code": explain_code_node_name,
|
131
318
|
},
|
132
319
|
)
|
133
|
-
# Fix code -> Execute again
|
134
|
-
workflow.add_edge(fix_code_node_name, execute_code_node_name)
|
135
|
-
# explain_code -> END
|
136
|
-
workflow.add_edge(explain_code_node_name, END)
|
137
320
|
else:
|
138
|
-
# If
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
321
|
+
# If no human review, the next node is fix_code if error, else explain_code.
|
322
|
+
if not bypass_explain_code:
|
323
|
+
workflow.add_conditional_edges(
|
324
|
+
execute_code_node_name,
|
325
|
+
lambda s: "fix_code" if error_and_can_retry(s) else "explain_code",
|
326
|
+
{
|
327
|
+
"fix_code": fix_code_node_name,
|
328
|
+
"explain_code": explain_code_node_name,
|
329
|
+
},
|
330
|
+
)
|
331
|
+
else:
|
332
|
+
workflow.add_conditional_edges(
|
333
|
+
execute_code_node_name,
|
334
|
+
lambda s: "fix_code" if error_and_can_retry(s) else "END",
|
335
|
+
{
|
336
|
+
"fix_code": fix_code_node_name,
|
337
|
+
"END": END,
|
338
|
+
},
|
339
|
+
)
|
340
|
+
|
341
|
+
if not bypass_explain_code:
|
342
|
+
workflow.add_edge(explain_code_node_name, END)
|
150
343
|
|
151
344
|
# Finally, compile
|
152
|
-
if human_in_the_loop
|
345
|
+
if human_in_the_loop:
|
153
346
|
app = workflow.compile(checkpointer=checkpointer)
|
154
347
|
else:
|
155
348
|
app = workflow.compile()
|
@@ -165,6 +358,8 @@ def node_func_human_review(
|
|
165
358
|
no_goto: str,
|
166
359
|
user_instructions_key: str = "user_instructions",
|
167
360
|
recommended_steps_key: str = "recommended_steps",
|
361
|
+
code_snippet_key: str = "code_snippet",
|
362
|
+
code_type: str = "python"
|
168
363
|
) -> Command[str]:
|
169
364
|
"""
|
170
365
|
A generic function to handle human review steps.
|
@@ -183,6 +378,10 @@ def node_func_human_review(
|
|
183
378
|
The key in the state to store user instructions.
|
184
379
|
recommended_steps_key : str, optional
|
185
380
|
The key in the state to store recommended steps.
|
381
|
+
code_snippet_key : str, optional
|
382
|
+
The key in the state to store the code snippet.
|
383
|
+
code_type : str, optional
|
384
|
+
The type of code snippet to display (e.g., "python").
|
186
385
|
|
187
386
|
Returns
|
188
387
|
-------
|
@@ -190,9 +389,11 @@ def node_func_human_review(
|
|
190
389
|
A Command object directing the next state and updates to the state.
|
191
390
|
"""
|
192
391
|
print(" * HUMAN REVIEW")
|
392
|
+
|
393
|
+
code_markdown=f"```{code_type}\n" + state.get(code_snippet_key)+"\n```"
|
193
394
|
|
194
395
|
# Display instructions and get user response
|
195
|
-
user_input = interrupt(value=prompt_text.format(steps=state.get(recommended_steps_key, '')))
|
396
|
+
user_input = interrupt(value=prompt_text.format(steps=state.get(recommended_steps_key, '') + "\n\n" + code_markdown))
|
196
397
|
|
197
398
|
# Decide next steps based on user input
|
198
399
|
if user_input.strip().lower() == "yes":
|
@@ -200,11 +401,11 @@ def node_func_human_review(
|
|
200
401
|
update = {}
|
201
402
|
else:
|
202
403
|
goto = no_goto
|
203
|
-
modifications = "Modifications: \n" + user_input
|
404
|
+
modifications = "User Has Requested Modifications To Previous Code: \n" + user_input
|
204
405
|
if state.get(user_instructions_key) is None:
|
205
|
-
update = {user_instructions_key: modifications}
|
406
|
+
update = {user_instructions_key: modifications + "\n\nPrevious Code:\n" + code_markdown}
|
206
407
|
else:
|
207
|
-
update = {user_instructions_key: state.get(user_instructions_key) + modifications}
|
408
|
+
update = {user_instructions_key: state.get(user_instructions_key) + modifications + "\n\nPrevious Code:\n" + code_markdown}
|
208
409
|
|
209
410
|
return Command(goto=goto, update=update)
|
210
411
|
|
@@ -394,6 +595,7 @@ def node_func_fix_agent_code(
|
|
394
595
|
retry_count_key: str = "retry_count",
|
395
596
|
log: bool = False,
|
396
597
|
file_path: str = "logs/agent_function.py",
|
598
|
+
function_name: str = "agent_function"
|
397
599
|
) -> dict:
|
398
600
|
"""
|
399
601
|
Generic function to fix a given piece of agent code using an LLM and a prompt template.
|
@@ -420,6 +622,8 @@ def node_func_fix_agent_code(
|
|
420
622
|
Whether to log the returned code to a file.
|
421
623
|
file_path : str, optional
|
422
624
|
The path to the file where the code will be logged.
|
625
|
+
function_name : str, optional
|
626
|
+
The name of the function in the code snippet that will be fixed.
|
423
627
|
|
424
628
|
Returns
|
425
629
|
-------
|
@@ -436,7 +640,8 @@ def node_func_fix_agent_code(
|
|
436
640
|
# Format the prompt with the code snippet and the error
|
437
641
|
prompt = prompt_template.format(
|
438
642
|
code_snippet=code_snippet,
|
439
|
-
error=error_message
|
643
|
+
error=error_message,
|
644
|
+
function_name=function_name,
|
440
645
|
)
|
441
646
|
|
442
647
|
# Execute the prompt with the LLM
|