ai-data-science-team 0.0.0.9013__tar.gz → 0.0.0.9014__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
- {ai_data_science_team-0.0.0.9013/ai_data_science_team.egg-info → ai_data_science_team-0.0.0.9014}/PKG-INFO +2 -2
- ai_data_science_team-0.0.0.9014/ai_data_science_team/__init__.py +22 -0
- ai_data_science_team-0.0.0.9014/ai_data_science_team/_version.py +1 -0
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/ai_data_science_team/agents/data_cleaning_agent.py +17 -3
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/ai_data_science_team/agents/data_loader_tools_agent.py +13 -1
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/ai_data_science_team/agents/data_visualization_agent.py +17 -3
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/ai_data_science_team/agents/data_wrangling_agent.py +30 -10
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/ai_data_science_team/agents/feature_engineering_agent.py +17 -4
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/ai_data_science_team/agents/sql_database_agent.py +15 -2
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/ai_data_science_team/ds_agents/eda_tools_agent.py +15 -6
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/ai_data_science_team/ml_agents/h2o_ml_agent.py +15 -3
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/ai_data_science_team/ml_agents/mlflow_tools_agent.py +13 -1
- ai_data_science_team-0.0.0.9014/ai_data_science_team/multiagents/__init__.py +2 -0
- ai_data_science_team-0.0.0.9014/ai_data_science_team/multiagents/pandas_data_analyst.py +305 -0
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/ai_data_science_team/multiagents/sql_data_analyst.py +119 -30
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/ai_data_science_team/templates/agent_templates.py +41 -5
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014/ai_data_science_team.egg-info}/PKG-INFO +2 -2
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/ai_data_science_team.egg-info/SOURCES.txt +1 -0
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/ai_data_science_team.egg-info/requires.txt +1 -1
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/requirements.txt +1 -1
- ai_data_science_team-0.0.0.9013/ai_data_science_team/_version.py +0 -1
- ai_data_science_team-0.0.0.9013/ai_data_science_team/multiagents/__init__.py +0 -1
- ai_data_science_team-0.0.0.9013/ai_data_science_team/utils/__init__.py +0 -0
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/LICENSE +0 -0
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/MANIFEST.in +0 -0
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/README.md +0 -0
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/ai_data_science_team/agents/__init__.py +0 -0
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/ai_data_science_team/ds_agents/__init__.py +0 -0
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/ai_data_science_team/ds_agents/modeling_tools_agent.py +0 -0
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/ai_data_science_team/ml_agents/__init__.py +0 -0
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/ai_data_science_team/ml_agents/h2o_ml_tools_agent.py +0 -0
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/ai_data_science_team/multiagents/supervised_data_analyst.py +0 -0
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/ai_data_science_team/orchestration.py +0 -0
- {ai_data_science_team-0.0.0.9013/ai_data_science_team → ai_data_science_team-0.0.0.9014/ai_data_science_team/parsers}/__init__.py +0 -0
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/ai_data_science_team/parsers/parsers.py +0 -0
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/ai_data_science_team/templates/__init__.py +0 -0
- {ai_data_science_team-0.0.0.9013/ai_data_science_team/parsers → ai_data_science_team-0.0.0.9014/ai_data_science_team/tools}/__init__.py +0 -0
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/ai_data_science_team/tools/data_loader.py +0 -0
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/ai_data_science_team/tools/dataframe.py +0 -0
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/ai_data_science_team/tools/eda.py +0 -0
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/ai_data_science_team/tools/h2o.py +0 -0
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/ai_data_science_team/tools/mlflow.py +0 -0
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/ai_data_science_team/tools/sql.py +0 -0
- {ai_data_science_team-0.0.0.9013/ai_data_science_team/tools → ai_data_science_team-0.0.0.9014/ai_data_science_team/utils}/__init__.py +0 -0
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/ai_data_science_team/utils/html.py +0 -0
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/ai_data_science_team/utils/logging.py +0 -0
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/ai_data_science_team/utils/matplotlib.py +0 -0
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/ai_data_science_team/utils/messages.py +0 -0
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/ai_data_science_team/utils/plotly.py +0 -0
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/ai_data_science_team/utils/regex.py +0 -0
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/ai_data_science_team.egg-info/dependency_links.txt +0 -0
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/ai_data_science_team.egg-info/top_level.txt +0 -0
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/setup.cfg +0 -0
- {ai_data_science_team-0.0.0.9013 → ai_data_science_team-0.0.0.9014}/setup.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.2
|
2
2
|
Name: ai-data-science-team
|
3
|
-
Version: 0.0.0.
|
3
|
+
Version: 0.0.0.9014
|
4
4
|
Summary: Build and run an AI-powered data science team.
|
5
5
|
Home-page: https://github.com/business-science/ai-data-science-team
|
6
6
|
Author: Matt Dancho
|
@@ -18,7 +18,7 @@ Requires-Dist: langchain
|
|
18
18
|
Requires-Dist: langchain_community
|
19
19
|
Requires-Dist: langchain_openai
|
20
20
|
Requires-Dist: langchain_experimental
|
21
|
-
Requires-Dist: langgraph>=0.2.
|
21
|
+
Requires-Dist: langgraph>=0.2.74
|
22
22
|
Requires-Dist: openai
|
23
23
|
Requires-Dist: pandas
|
24
24
|
Requires-Dist: sqlalchemy
|
@@ -0,0 +1,22 @@
|
|
1
|
+
from ai_data_science_team.agents import (
|
2
|
+
DataCleaningAgent,
|
3
|
+
DataLoaderToolsAgent,
|
4
|
+
DataVisualizationAgent,
|
5
|
+
SQLDatabaseAgent,
|
6
|
+
DataWranglingAgent,
|
7
|
+
FeatureEngineeringAgent,
|
8
|
+
)
|
9
|
+
|
10
|
+
from ai_data_science_team.ds_agents import (
|
11
|
+
EDAToolsAgent,
|
12
|
+
)
|
13
|
+
|
14
|
+
from ai_data_science_team.ml_agents import (
|
15
|
+
H2OMLAgent,
|
16
|
+
MLflowToolsAgent,
|
17
|
+
)
|
18
|
+
|
19
|
+
from ai_data_science_team.multiagents import (
|
20
|
+
SQLDataAnalyst,
|
21
|
+
PandasDataAnalyst,
|
22
|
+
)
|
@@ -0,0 +1 @@
|
|
1
|
+
__version__ = "0.0.0.9014"
|
@@ -12,6 +12,7 @@ from langchain_core.messages import BaseMessage
|
|
12
12
|
|
13
13
|
from langgraph.types import Command
|
14
14
|
from langgraph.checkpoint.memory import MemorySaver
|
15
|
+
from langgraph.types import Checkpointer
|
15
16
|
|
16
17
|
import os
|
17
18
|
import json
|
@@ -85,6 +86,8 @@ class DataCleaningAgent(BaseAgent):
|
|
85
86
|
If True, skips the default recommended cleaning steps. Defaults to False.
|
86
87
|
bypass_explain_code : bool, optional
|
87
88
|
If True, skips the step that provides code explanations. Defaults to False.
|
89
|
+
checkpointer : langgraph.types.Checkpointer, optional
|
90
|
+
Checkpointer to save and load the agent's state. Defaults to None.
|
88
91
|
|
89
92
|
Methods
|
90
93
|
-------
|
@@ -159,7 +162,8 @@ class DataCleaningAgent(BaseAgent):
|
|
159
162
|
overwrite=True,
|
160
163
|
human_in_the_loop=False,
|
161
164
|
bypass_recommended_steps=False,
|
162
|
-
bypass_explain_code=False
|
165
|
+
bypass_explain_code=False,
|
166
|
+
checkpointer: Checkpointer = None
|
163
167
|
):
|
164
168
|
self._params = {
|
165
169
|
"model": model,
|
@@ -172,6 +176,7 @@ class DataCleaningAgent(BaseAgent):
|
|
172
176
|
"human_in_the_loop": human_in_the_loop,
|
173
177
|
"bypass_recommended_steps": bypass_recommended_steps,
|
174
178
|
"bypass_explain_code": bypass_explain_code,
|
179
|
+
"checkpointer": checkpointer
|
175
180
|
}
|
176
181
|
self._compiled_graph = self._make_compiled_graph()
|
177
182
|
self.response = None
|
@@ -320,7 +325,8 @@ def make_data_cleaning_agent(
|
|
320
325
|
overwrite = True,
|
321
326
|
human_in_the_loop=False,
|
322
327
|
bypass_recommended_steps=False,
|
323
|
-
bypass_explain_code=False
|
328
|
+
bypass_explain_code=False,
|
329
|
+
checkpointer: Checkpointer = None
|
324
330
|
):
|
325
331
|
"""
|
326
332
|
Creates a data cleaning agent that can be run on a dataset. The agent can be used to clean a dataset in a variety of
|
@@ -369,6 +375,8 @@ def make_data_cleaning_agent(
|
|
369
375
|
Bypass the recommendation step, by default False
|
370
376
|
bypass_explain_code : bool, optional
|
371
377
|
Bypass the code explanation step, by default False.
|
378
|
+
checkpointer : langgraph.types.Checkpointer, optional
|
379
|
+
Checkpointer to save and load the agent's state. Defaults to None.
|
372
380
|
|
373
381
|
Examples
|
374
382
|
-------
|
@@ -400,6 +408,11 @@ def make_data_cleaning_agent(
|
|
400
408
|
"""
|
401
409
|
llm = model
|
402
410
|
|
411
|
+
if human_in_the_loop:
|
412
|
+
if checkpointer is None:
|
413
|
+
print("Human in the loop is enabled. A checkpointer is required. Setting to MemorySaver().")
|
414
|
+
checkpointer = MemorySaver()
|
415
|
+
|
403
416
|
# Human in th loop requires recommended steps
|
404
417
|
if bypass_recommended_steps and human_in_the_loop:
|
405
418
|
bypass_recommended_steps = False
|
@@ -680,9 +693,10 @@ def make_data_cleaning_agent(
|
|
680
693
|
error_key="data_cleaner_error",
|
681
694
|
human_in_the_loop=human_in_the_loop,
|
682
695
|
human_review_node_name="human_review",
|
683
|
-
checkpointer=
|
696
|
+
checkpointer=checkpointer,
|
684
697
|
bypass_recommended_steps=bypass_recommended_steps,
|
685
698
|
bypass_explain_code=bypass_explain_code,
|
699
|
+
agent_name=AGENT_NAME,
|
686
700
|
)
|
687
701
|
|
688
702
|
return app
|
@@ -13,6 +13,7 @@ from langchain_core.messages import BaseMessage, AIMessage
|
|
13
13
|
|
14
14
|
from langgraph.prebuilt import create_react_agent, ToolNode
|
15
15
|
from langgraph.prebuilt.chat_agent_executor import AgentState
|
16
|
+
from langgraph.types import Checkpointer
|
16
17
|
from langgraph.graph import START, END, StateGraph
|
17
18
|
|
18
19
|
from ai_data_science_team.templates import BaseAgent
|
@@ -50,6 +51,8 @@ class DataLoaderToolsAgent(BaseAgent):
|
|
50
51
|
Additional keyword arguments to pass to the create_react_agent function.
|
51
52
|
invoke_react_agent_kwargs : dict
|
52
53
|
Additional keyword arguments to pass to the invoke method of the react agent.
|
54
|
+
checkpointer : langgraph.types.Checkpointer
|
55
|
+
A checkpointer to use for saving and loading the agent's state.
|
53
56
|
|
54
57
|
Methods:
|
55
58
|
--------
|
@@ -73,11 +76,13 @@ class DataLoaderToolsAgent(BaseAgent):
|
|
73
76
|
model: Any,
|
74
77
|
create_react_agent_kwargs: Optional[Dict]={},
|
75
78
|
invoke_react_agent_kwargs: Optional[Dict]={},
|
79
|
+
checkpointer: Optional[Checkpointer]=None,
|
76
80
|
):
|
77
81
|
self._params = {
|
78
82
|
"model": model,
|
79
83
|
"create_react_agent_kwargs": create_react_agent_kwargs,
|
80
84
|
"invoke_react_agent_kwargs": invoke_react_agent_kwargs,
|
85
|
+
"checkpointer": checkpointer,
|
81
86
|
}
|
82
87
|
self._compiled_graph = self._make_compiled_graph()
|
83
88
|
self.response = None
|
@@ -188,6 +193,7 @@ def make_data_loader_tools_agent(
|
|
188
193
|
model: Any,
|
189
194
|
create_react_agent_kwargs: Optional[Dict]={},
|
190
195
|
invoke_react_agent_kwargs: Optional[Dict]={},
|
196
|
+
checkpointer: Optional[Checkpointer]=None,
|
191
197
|
):
|
192
198
|
"""
|
193
199
|
Creates a Data Loader Agent that can interact with data loading tools.
|
@@ -200,6 +206,8 @@ def make_data_loader_tools_agent(
|
|
200
206
|
Additional keyword arguments to pass to the create_react_agent function.
|
201
207
|
invoke_react_agent_kwargs : dict
|
202
208
|
Additional keyword arguments to pass to the invoke method of the react agent.
|
209
|
+
checkpointer : langgraph.types.Checkpointer
|
210
|
+
A checkpointer to use for saving and loading the agent's state.
|
203
211
|
|
204
212
|
Returns:
|
205
213
|
--------
|
@@ -228,6 +236,7 @@ def make_data_loader_tools_agent(
|
|
228
236
|
model,
|
229
237
|
tools=tool_node,
|
230
238
|
state_schema=GraphState,
|
239
|
+
checkpointer=checkpointer,
|
231
240
|
**create_react_agent_kwargs,
|
232
241
|
)
|
233
242
|
|
@@ -277,7 +286,10 @@ def make_data_loader_tools_agent(
|
|
277
286
|
workflow.add_edge(START, "data_loader_agent")
|
278
287
|
workflow.add_edge("data_loader_agent", END)
|
279
288
|
|
280
|
-
app = workflow.compile(
|
289
|
+
app = workflow.compile(
|
290
|
+
checkpointer=checkpointer,
|
291
|
+
name=AGENT_NAME,
|
292
|
+
)
|
281
293
|
|
282
294
|
return app
|
283
295
|
|
@@ -14,6 +14,7 @@ from langchain_core.messages import BaseMessage
|
|
14
14
|
|
15
15
|
from langgraph.types import Command
|
16
16
|
from langgraph.checkpoint.memory import MemorySaver
|
17
|
+
from langgraph.types import Checkpointer
|
17
18
|
|
18
19
|
import os
|
19
20
|
import json
|
@@ -85,6 +86,8 @@ class DataVisualizationAgent(BaseAgent):
|
|
85
86
|
If True, skips the default recommended visualization steps. Defaults to False.
|
86
87
|
bypass_explain_code : bool, optional
|
87
88
|
If True, skips the step that provides code explanations. Defaults to False.
|
89
|
+
checkpointer : langgraph.types.Checkpointer
|
90
|
+
A checkpointer to use for saving and loading the agent
|
88
91
|
|
89
92
|
Methods
|
90
93
|
-------
|
@@ -161,7 +164,8 @@ class DataVisualizationAgent(BaseAgent):
|
|
161
164
|
overwrite=True,
|
162
165
|
human_in_the_loop=False,
|
163
166
|
bypass_recommended_steps=False,
|
164
|
-
bypass_explain_code=False
|
167
|
+
bypass_explain_code=False,
|
168
|
+
checkpointer=None,
|
165
169
|
):
|
166
170
|
self._params = {
|
167
171
|
"model": model,
|
@@ -174,6 +178,7 @@ class DataVisualizationAgent(BaseAgent):
|
|
174
178
|
"human_in_the_loop": human_in_the_loop,
|
175
179
|
"bypass_recommended_steps": bypass_recommended_steps,
|
176
180
|
"bypass_explain_code": bypass_explain_code,
|
181
|
+
"checkpointer": checkpointer,
|
177
182
|
}
|
178
183
|
self._compiled_graph = self._make_compiled_graph()
|
179
184
|
self.response = None
|
@@ -385,7 +390,8 @@ def make_data_visualization_agent(
|
|
385
390
|
overwrite=True,
|
386
391
|
human_in_the_loop=False,
|
387
392
|
bypass_recommended_steps=False,
|
388
|
-
bypass_explain_code=False
|
393
|
+
bypass_explain_code=False,
|
394
|
+
checkpointer=None,
|
389
395
|
):
|
390
396
|
"""
|
391
397
|
Creates a data visualization agent that can generate Plotly charts based on user-defined instructions or
|
@@ -423,6 +429,8 @@ def make_data_visualization_agent(
|
|
423
429
|
If True, skips the default recommended visualization steps. Defaults to False.
|
424
430
|
bypass_explain_code : bool, optional
|
425
431
|
If True, skips the step that provides code explanations. Defaults to False.
|
432
|
+
checkpointer : langgraph.types.Checkpointer
|
433
|
+
A checkpointer to use for saving and loading the agent
|
426
434
|
|
427
435
|
Examples
|
428
436
|
--------
|
@@ -455,6 +463,11 @@ def make_data_visualization_agent(
|
|
455
463
|
|
456
464
|
llm = model
|
457
465
|
|
466
|
+
if human_in_the_loop:
|
467
|
+
if checkpointer is None:
|
468
|
+
print("Human in the loop is enabled. A checkpointer is required. Setting to MemorySaver().")
|
469
|
+
checkpointer = MemorySaver()
|
470
|
+
|
458
471
|
# Human in th loop requires recommended steps
|
459
472
|
if bypass_recommended_steps and human_in_the_loop:
|
460
473
|
bypass_recommended_steps = False
|
@@ -751,9 +764,10 @@ def make_data_visualization_agent(
|
|
751
764
|
error_key="data_visualization_error",
|
752
765
|
human_in_the_loop=human_in_the_loop, # or False
|
753
766
|
human_review_node_name="human_review",
|
754
|
-
checkpointer=
|
767
|
+
checkpointer=checkpointer,
|
755
768
|
bypass_recommended_steps=bypass_recommended_steps,
|
756
769
|
bypass_explain_code=bypass_explain_code,
|
770
|
+
agent_name=AGENT_NAME,
|
757
771
|
)
|
758
772
|
|
759
773
|
return app
|
@@ -13,7 +13,7 @@ from IPython.display import Markdown
|
|
13
13
|
|
14
14
|
from langchain.prompts import PromptTemplate
|
15
15
|
from langchain_core.messages import BaseMessage
|
16
|
-
from langgraph.types import Command
|
16
|
+
from langgraph.types import Command, Checkpointer
|
17
17
|
from langgraph.checkpoint.memory import MemorySaver
|
18
18
|
|
19
19
|
from ai_data_science_team.templates import(
|
@@ -83,6 +83,8 @@ class DataWranglingAgent(BaseAgent):
|
|
83
83
|
If True, skips the step that generates recommended data wrangling steps. Defaults to False.
|
84
84
|
bypass_explain_code : bool, optional
|
85
85
|
If True, skips the step that provides code explanations. Defaults to False.
|
86
|
+
checkpointer : Checkpointer, optional
|
87
|
+
A checkpointer object to save and load the agent's state. Defaults to None.
|
86
88
|
|
87
89
|
Methods
|
88
90
|
-------
|
@@ -180,7 +182,8 @@ class DataWranglingAgent(BaseAgent):
|
|
180
182
|
overwrite=True,
|
181
183
|
human_in_the_loop=False,
|
182
184
|
bypass_recommended_steps=False,
|
183
|
-
bypass_explain_code=False
|
185
|
+
bypass_explain_code=False,
|
186
|
+
checkpointer=None,
|
184
187
|
):
|
185
188
|
self._params = {
|
186
189
|
"model": model,
|
@@ -192,7 +195,8 @@ class DataWranglingAgent(BaseAgent):
|
|
192
195
|
"overwrite": overwrite,
|
193
196
|
"human_in_the_loop": human_in_the_loop,
|
194
197
|
"bypass_recommended_steps": bypass_recommended_steps,
|
195
|
-
"bypass_explain_code": bypass_explain_code
|
198
|
+
"bypass_explain_code": bypass_explain_code,
|
199
|
+
"checkpointer": checkpointer,
|
196
200
|
}
|
197
201
|
self._compiled_graph = self._make_compiled_graph()
|
198
202
|
self.response = None
|
@@ -443,7 +447,8 @@ def make_data_wrangling_agent(
|
|
443
447
|
overwrite=True,
|
444
448
|
human_in_the_loop=False,
|
445
449
|
bypass_recommended_steps=False,
|
446
|
-
bypass_explain_code=False
|
450
|
+
bypass_explain_code=False,
|
451
|
+
checkpointer=None,
|
447
452
|
):
|
448
453
|
"""
|
449
454
|
Creates a data wrangling agent that can be run on one or more datasets. The agent can be
|
@@ -488,6 +493,8 @@ def make_data_wrangling_agent(
|
|
488
493
|
Bypass the recommendation step, by default False
|
489
494
|
bypass_explain_code : bool, optional
|
490
495
|
Bypass the code explanation step, by default False.
|
496
|
+
checkpointer : Checkpointer, optional
|
497
|
+
A checkpointer object to save and load the agent's state. Defaults to None.
|
491
498
|
|
492
499
|
Example
|
493
500
|
-------
|
@@ -520,6 +527,11 @@ def make_data_wrangling_agent(
|
|
520
527
|
"""
|
521
528
|
llm = model
|
522
529
|
|
530
|
+
if human_in_the_loop:
|
531
|
+
if checkpointer is None:
|
532
|
+
print("Human in the loop is enabled. A checkpointer is required. Setting to MemorySaver().")
|
533
|
+
checkpointer = MemorySaver()
|
534
|
+
|
523
535
|
# Human in th loop requires recommended steps
|
524
536
|
if bypass_recommended_steps and human_in_the_loop:
|
525
537
|
bypass_recommended_steps = False
|
@@ -569,7 +581,7 @@ def make_data_wrangling_agent(
|
|
569
581
|
|
570
582
|
# Create a summary for all datasets
|
571
583
|
# We'll include a short sample and info for each dataset
|
572
|
-
all_datasets_summary = get_dataframe_summary(dataframes, n_sample=n_samples)
|
584
|
+
all_datasets_summary = get_dataframe_summary(dataframes, n_sample=n_samples, skip_stats=True)
|
573
585
|
|
574
586
|
# Join all datasets summaries into one big text block
|
575
587
|
all_datasets_summary_str = "\n\n".join(all_datasets_summary)
|
@@ -642,7 +654,7 @@ def make_data_wrangling_agent(
|
|
642
654
|
|
643
655
|
# Create a summary for all datasets
|
644
656
|
# We'll include a short sample and info for each dataset
|
645
|
-
all_datasets_summary = get_dataframe_summary(dataframes, n_sample=n_samples)
|
657
|
+
all_datasets_summary = get_dataframe_summary(dataframes, n_sample=n_samples, skip_stats=True)
|
646
658
|
|
647
659
|
# Join all datasets summaries into one big text block
|
648
660
|
all_datasets_summary_str = "\n\n".join(all_datasets_summary)
|
@@ -654,9 +666,12 @@ def make_data_wrangling_agent(
|
|
654
666
|
|
655
667
|
data_wrangling_prompt = PromptTemplate(
|
656
668
|
template="""
|
657
|
-
You are a Data Wrangling Coding Agent. Your job is to create a {function_name}() function that can be run on the provided data.
|
669
|
+
You are a Pandas Data Wrangling Coding Agent. Your job is to create a {function_name}() function that can be run on the provided data. You should use Pandas and NumPy for data wrangling operations.
|
670
|
+
|
671
|
+
User instructions:
|
672
|
+
{user_instructions}
|
658
673
|
|
659
|
-
Follow these recommended steps:
|
674
|
+
Follow these recommended steps (if present):
|
660
675
|
{recommended_steps}
|
661
676
|
|
662
677
|
If multiple datasets are provided, you may need to merge or join them. Make sure to handle that scenario based on the recommended steps and user instructions.
|
@@ -685,17 +700,21 @@ def make_data_wrangling_agent(
|
|
685
700
|
1. If the incoming data is not a list. Convert it to a list first.
|
686
701
|
2. Do not specify data types inside the function arguments.
|
687
702
|
|
703
|
+
Important Notes:
|
704
|
+
1. Do Not use Print statements to display the data. Return the data frame instead with the data wrangling operation performed.
|
705
|
+
|
688
706
|
Make sure to explain any non-trivial steps with inline comments. Follow user instructions. Comment code thoroughly.
|
689
707
|
|
690
708
|
|
691
709
|
""",
|
692
|
-
input_variables=["recommended_steps", "all_datasets_summary", "function_name"]
|
710
|
+
input_variables=["recommended_steps", "user_instructions", "all_datasets_summary", "function_name"]
|
693
711
|
)
|
694
712
|
|
695
713
|
data_wrangling_agent = data_wrangling_prompt | llm | PythonOutputParser()
|
696
714
|
|
697
715
|
response = data_wrangling_agent.invoke({
|
698
716
|
"recommended_steps": state.get("recommended_steps"),
|
717
|
+
"user_instructions": state.get("user_instructions"),
|
699
718
|
"all_datasets_summary": all_datasets_summary_str,
|
700
719
|
"function_name": function_name
|
701
720
|
})
|
@@ -835,9 +854,10 @@ def make_data_wrangling_agent(
|
|
835
854
|
error_key="data_wrangler_error",
|
836
855
|
human_in_the_loop=human_in_the_loop,
|
837
856
|
human_review_node_name="human_review",
|
838
|
-
checkpointer=
|
857
|
+
checkpointer=checkpointer,
|
839
858
|
bypass_recommended_steps=bypass_recommended_steps,
|
840
859
|
bypass_explain_code=bypass_explain_code,
|
860
|
+
agent_name=AGENT_NAME,
|
841
861
|
)
|
842
862
|
|
843
863
|
return app
|
@@ -10,7 +10,7 @@ import operator
|
|
10
10
|
from langchain.prompts import PromptTemplate
|
11
11
|
from langchain_core.messages import BaseMessage
|
12
12
|
|
13
|
-
from langgraph.types import Command
|
13
|
+
from langgraph.types import Command, Checkpointer
|
14
14
|
from langgraph.checkpoint.memory import MemorySaver
|
15
15
|
|
16
16
|
import os
|
@@ -84,6 +84,8 @@ class FeatureEngineeringAgent(BaseAgent):
|
|
84
84
|
If True, skips the default recommended steps. Defaults to False.
|
85
85
|
bypass_explain_code : bool, optional
|
86
86
|
If True, skips the step that provides code explanations. Defaults to False.
|
87
|
+
checkpointer : Checkpointer, optional
|
88
|
+
Checkpointer to save and load the agent's state. Defaults to None.
|
87
89
|
|
88
90
|
Methods
|
89
91
|
-------
|
@@ -170,7 +172,8 @@ class FeatureEngineeringAgent(BaseAgent):
|
|
170
172
|
overwrite=True,
|
171
173
|
human_in_the_loop=False,
|
172
174
|
bypass_recommended_steps=False,
|
173
|
-
bypass_explain_code=False
|
175
|
+
bypass_explain_code=False,
|
176
|
+
checkpointer=None,
|
174
177
|
):
|
175
178
|
self._params = {
|
176
179
|
"model": model,
|
@@ -182,7 +185,8 @@ class FeatureEngineeringAgent(BaseAgent):
|
|
182
185
|
"overwrite": overwrite,
|
183
186
|
"human_in_the_loop": human_in_the_loop,
|
184
187
|
"bypass_recommended_steps": bypass_recommended_steps,
|
185
|
-
"bypass_explain_code": bypass_explain_code
|
188
|
+
"bypass_explain_code": bypass_explain_code,
|
189
|
+
"checkpointer": checkpointer,
|
186
190
|
}
|
187
191
|
self._compiled_graph = self._make_compiled_graph()
|
188
192
|
self.response = None
|
@@ -400,6 +404,7 @@ def make_feature_engineering_agent(
|
|
400
404
|
human_in_the_loop=False,
|
401
405
|
bypass_recommended_steps=False,
|
402
406
|
bypass_explain_code=False,
|
407
|
+
checkpointer=None,
|
403
408
|
):
|
404
409
|
"""
|
405
410
|
Creates a feature engineering agent that can be run on a dataset. The agent applies various feature engineering
|
@@ -448,6 +453,8 @@ def make_feature_engineering_agent(
|
|
448
453
|
Bypass the recommendation step, by default False
|
449
454
|
bypass_explain_code : bool, optional
|
450
455
|
Bypass the code explanation step, by default False.
|
456
|
+
checkpointer : Checkpointer, optional
|
457
|
+
Checkpointer to save and load the agent's state. Defaults to None.
|
451
458
|
|
452
459
|
Examples
|
453
460
|
-------
|
@@ -480,6 +487,11 @@ def make_feature_engineering_agent(
|
|
480
487
|
"""
|
481
488
|
llm = model
|
482
489
|
|
490
|
+
if human_in_the_loop:
|
491
|
+
if checkpointer is None:
|
492
|
+
print("Human in the loop is enabled. A checkpointer is required. Setting to MemorySaver().")
|
493
|
+
checkpointer = MemorySaver()
|
494
|
+
|
483
495
|
# Human in th loop requires recommended steps
|
484
496
|
if bypass_recommended_steps and human_in_the_loop:
|
485
497
|
bypass_recommended_steps = False
|
@@ -782,9 +794,10 @@ def make_feature_engineering_agent(
|
|
782
794
|
retry_count_key = "retry_count",
|
783
795
|
human_in_the_loop=human_in_the_loop,
|
784
796
|
human_review_node_name="human_review",
|
785
|
-
checkpointer=
|
797
|
+
checkpointer=checkpointer,
|
786
798
|
bypass_recommended_steps=bypass_recommended_steps,
|
787
799
|
bypass_explain_code=bypass_explain_code,
|
800
|
+
agent_name=AGENT_NAME,
|
788
801
|
)
|
789
802
|
|
790
803
|
return app
|
@@ -7,7 +7,7 @@ from langchain.prompts import PromptTemplate
|
|
7
7
|
from langchain_core.messages import BaseMessage
|
8
8
|
from langchain_core.output_parsers import JsonOutputParser
|
9
9
|
|
10
|
-
from langgraph.types import Command
|
10
|
+
from langgraph.types import Command, Checkpointer
|
11
11
|
from langgraph.checkpoint.memory import MemorySaver
|
12
12
|
|
13
13
|
import os
|
@@ -75,6 +75,8 @@ class SQLDatabaseAgent(BaseAgent):
|
|
75
75
|
If True, skips the step that generates recommended SQL steps. Defaults to False.
|
76
76
|
bypass_explain_code : bool, optional
|
77
77
|
If True, skips the step that provides code explanations. Defaults to False.
|
78
|
+
checkpointer : Checkpointer, optional
|
79
|
+
A checkpointer to save and load the agent's state. Defaults to None.
|
78
80
|
smart_schema_pruning : bool, optional
|
79
81
|
If True, filters the tables and columns based on the user instructions and recommended steps. Defaults to False.
|
80
82
|
|
@@ -157,6 +159,7 @@ class SQLDatabaseAgent(BaseAgent):
|
|
157
159
|
human_in_the_loop=False,
|
158
160
|
bypass_recommended_steps=False,
|
159
161
|
bypass_explain_code=False,
|
162
|
+
checkpointer=None,
|
160
163
|
smart_schema_pruning=False,
|
161
164
|
):
|
162
165
|
self._params = {
|
@@ -171,6 +174,7 @@ class SQLDatabaseAgent(BaseAgent):
|
|
171
174
|
"human_in_the_loop": human_in_the_loop,
|
172
175
|
"bypass_recommended_steps": bypass_recommended_steps,
|
173
176
|
"bypass_explain_code": bypass_explain_code,
|
177
|
+
"checkpointer": checkpointer,
|
174
178
|
"smart_schema_pruning": smart_schema_pruning,
|
175
179
|
}
|
176
180
|
self._compiled_graph = self._make_compiled_graph()
|
@@ -365,6 +369,7 @@ def make_sql_database_agent(
|
|
365
369
|
human_in_the_loop=False,
|
366
370
|
bypass_recommended_steps=False,
|
367
371
|
bypass_explain_code=False,
|
372
|
+
checkpointer=None,
|
368
373
|
smart_schema_pruning=False,
|
369
374
|
):
|
370
375
|
"""
|
@@ -394,6 +399,8 @@ def make_sql_database_agent(
|
|
394
399
|
Bypass the recommendation step, by default False
|
395
400
|
bypass_explain_code : bool, optional
|
396
401
|
Bypass the code explanation step, by default False.
|
402
|
+
checkpointer : Checkpointer, optional
|
403
|
+
A checkpointer to save and load the agent's state. Defaults to None.
|
397
404
|
smart_schema_pruning : bool, optional
|
398
405
|
If True, filters the tables and columns with an extra LLM step to reduce tokens for large databases. Increases processing time but can avoid errors due to hitting max token limits with large databases. Defaults to False.
|
399
406
|
|
@@ -432,6 +439,11 @@ def make_sql_database_agent(
|
|
432
439
|
|
433
440
|
llm = model
|
434
441
|
|
442
|
+
if human_in_the_loop:
|
443
|
+
if checkpointer is None:
|
444
|
+
print("Human in the loop is enabled. A checkpointer is required. Setting to MemorySaver().")
|
445
|
+
checkpointer = MemorySaver()
|
446
|
+
|
435
447
|
# Human in th loop requires recommended steps
|
436
448
|
if bypass_recommended_steps and human_in_the_loop:
|
437
449
|
bypass_recommended_steps = False
|
@@ -742,9 +754,10 @@ def {function_name}(connection):
|
|
742
754
|
error_key="sql_database_error",
|
743
755
|
human_in_the_loop=human_in_the_loop,
|
744
756
|
human_review_node_name="human_review",
|
745
|
-
checkpointer=
|
757
|
+
checkpointer=checkpointer,
|
746
758
|
bypass_recommended_steps=bypass_recommended_steps,
|
747
759
|
bypass_explain_code=bypass_explain_code,
|
760
|
+
agent_name=AGENT_NAME,
|
748
761
|
)
|
749
762
|
|
750
763
|
return app
|
@@ -1,12 +1,8 @@
|
|
1
1
|
|
2
2
|
|
3
|
-
from typing import Any, Optional, Annotated, Sequence,
|
3
|
+
from typing import Any, Optional, Annotated, Sequence, Dict
|
4
4
|
import operator
|
5
5
|
import pandas as pd
|
6
|
-
import os
|
7
|
-
from io import StringIO, BytesIO
|
8
|
-
import base64
|
9
|
-
import matplotlib.pyplot as plt
|
10
6
|
|
11
7
|
from IPython.display import Markdown
|
12
8
|
|
@@ -14,6 +10,7 @@ from langchain_core.messages import BaseMessage, AIMessage
|
|
14
10
|
from langgraph.prebuilt import create_react_agent, ToolNode
|
15
11
|
from langgraph.prebuilt.chat_agent_executor import AgentState
|
16
12
|
from langgraph.graph import START, END, StateGraph
|
13
|
+
from langgraph.types import Checkpointer
|
17
14
|
|
18
15
|
from ai_data_science_team.templates import BaseAgent
|
19
16
|
from ai_data_science_team.utils.regex import format_agent_name
|
@@ -52,6 +49,8 @@ class EDAToolsAgent(BaseAgent):
|
|
52
49
|
Additional kwargs for create_react_agent.
|
53
50
|
invoke_react_agent_kwargs : dict
|
54
51
|
Additional kwargs for agent invocation.
|
52
|
+
checkpointer : Checkpointer, optional
|
53
|
+
The checkpointer for the agent.
|
55
54
|
"""
|
56
55
|
|
57
56
|
def __init__(
|
@@ -59,11 +58,13 @@ class EDAToolsAgent(BaseAgent):
|
|
59
58
|
model: Any,
|
60
59
|
create_react_agent_kwargs: Optional[Dict] = {},
|
61
60
|
invoke_react_agent_kwargs: Optional[Dict] = {},
|
61
|
+
checkpointer: Optional[Checkpointer] = None,
|
62
62
|
):
|
63
63
|
self._params = {
|
64
64
|
"model": model,
|
65
65
|
"create_react_agent_kwargs": create_react_agent_kwargs,
|
66
66
|
"invoke_react_agent_kwargs": invoke_react_agent_kwargs,
|
67
|
+
"checkpointer": checkpointer
|
67
68
|
}
|
68
69
|
self._compiled_graph = self._make_compiled_graph()
|
69
70
|
self.response = None
|
@@ -176,6 +177,7 @@ def make_eda_tools_agent(
|
|
176
177
|
model: Any,
|
177
178
|
create_react_agent_kwargs: Optional[Dict] = {},
|
178
179
|
invoke_react_agent_kwargs: Optional[Dict] = {},
|
180
|
+
checkpointer: Optional[Checkpointer] = None,
|
179
181
|
):
|
180
182
|
"""
|
181
183
|
Creates an Exploratory Data Analyst Agent that can interact with EDA tools.
|
@@ -188,6 +190,8 @@ def make_eda_tools_agent(
|
|
188
190
|
Additional kwargs for create_react_agent.
|
189
191
|
invoke_react_agent_kwargs : dict
|
190
192
|
Additional kwargs for agent invocation.
|
193
|
+
checkpointer : Checkpointer, optional
|
194
|
+
The checkpointer for the agent.
|
191
195
|
|
192
196
|
Returns:
|
193
197
|
-------
|
@@ -215,6 +219,7 @@ def make_eda_tools_agent(
|
|
215
219
|
tools=tool_node,
|
216
220
|
state_schema=GraphState,
|
217
221
|
**create_react_agent_kwargs,
|
222
|
+
checkpointer=checkpointer,
|
218
223
|
)
|
219
224
|
|
220
225
|
response = eda_agent.invoke(
|
@@ -254,5 +259,9 @@ def make_eda_tools_agent(
|
|
254
259
|
workflow.add_edge(START, "exploratory_agent")
|
255
260
|
workflow.add_edge("exploratory_agent", END)
|
256
261
|
|
257
|
-
app = workflow.compile(
|
262
|
+
app = workflow.compile(
|
263
|
+
checkpointer=checkpointer,
|
264
|
+
name=AGENT_NAME,
|
265
|
+
)
|
266
|
+
|
258
267
|
return app
|