ai-data-science-team 0.0.0.9013__py3-none-any.whl → 0.0.0.9015__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_data_science_team/__init__.py +22 -0
- ai_data_science_team/_version.py +1 -1
- ai_data_science_team/agents/data_cleaning_agent.py +17 -3
- ai_data_science_team/agents/data_loader_tools_agent.py +13 -1
- ai_data_science_team/agents/data_visualization_agent.py +187 -130
- ai_data_science_team/agents/data_wrangling_agent.py +31 -10
- ai_data_science_team/agents/feature_engineering_agent.py +17 -4
- ai_data_science_team/agents/sql_database_agent.py +15 -2
- ai_data_science_team/ds_agents/eda_tools_agent.py +15 -6
- ai_data_science_team/ml_agents/h2o_ml_agent.py +15 -3
- ai_data_science_team/ml_agents/mlflow_tools_agent.py +13 -1
- ai_data_science_team/multiagents/__init__.py +2 -1
- ai_data_science_team/multiagents/pandas_data_analyst.py +305 -0
- ai_data_science_team/multiagents/sql_data_analyst.py +126 -48
- ai_data_science_team/templates/agent_templates.py +41 -5
- ai_data_science_team/tools/eda.py +2 -0
- {ai_data_science_team-0.0.0.9013.dist-info → ai_data_science_team-0.0.0.9015.dist-info}/METADATA +6 -5
- {ai_data_science_team-0.0.0.9013.dist-info → ai_data_science_team-0.0.0.9015.dist-info}/RECORD +21 -20
- {ai_data_science_team-0.0.0.9013.dist-info → ai_data_science_team-0.0.0.9015.dist-info}/WHEEL +1 -1
- {ai_data_science_team-0.0.0.9013.dist-info → ai_data_science_team-0.0.0.9015.dist-info}/LICENSE +0 -0
- {ai_data_science_team-0.0.0.9013.dist-info → ai_data_science_team-0.0.0.9015.dist-info}/top_level.txt +0 -0
@@ -13,7 +13,7 @@ from IPython.display import Markdown
|
|
13
13
|
|
14
14
|
from langchain.prompts import PromptTemplate
|
15
15
|
from langchain_core.messages import BaseMessage
|
16
|
-
from langgraph.types import Command
|
16
|
+
from langgraph.types import Command, Checkpointer
|
17
17
|
from langgraph.checkpoint.memory import MemorySaver
|
18
18
|
|
19
19
|
from ai_data_science_team.templates import(
|
@@ -83,6 +83,8 @@ class DataWranglingAgent(BaseAgent):
|
|
83
83
|
If True, skips the step that generates recommended data wrangling steps. Defaults to False.
|
84
84
|
bypass_explain_code : bool, optional
|
85
85
|
If True, skips the step that provides code explanations. Defaults to False.
|
86
|
+
checkpointer : Checkpointer, optional
|
87
|
+
A checkpointer object to save and load the agent's state. Defaults to None.
|
86
88
|
|
87
89
|
Methods
|
88
90
|
-------
|
@@ -180,7 +182,8 @@ class DataWranglingAgent(BaseAgent):
|
|
180
182
|
overwrite=True,
|
181
183
|
human_in_the_loop=False,
|
182
184
|
bypass_recommended_steps=False,
|
183
|
-
bypass_explain_code=False
|
185
|
+
bypass_explain_code=False,
|
186
|
+
checkpointer=None,
|
184
187
|
):
|
185
188
|
self._params = {
|
186
189
|
"model": model,
|
@@ -192,7 +195,8 @@ class DataWranglingAgent(BaseAgent):
|
|
192
195
|
"overwrite": overwrite,
|
193
196
|
"human_in_the_loop": human_in_the_loop,
|
194
197
|
"bypass_recommended_steps": bypass_recommended_steps,
|
195
|
-
"bypass_explain_code": bypass_explain_code
|
198
|
+
"bypass_explain_code": bypass_explain_code,
|
199
|
+
"checkpointer": checkpointer,
|
196
200
|
}
|
197
201
|
self._compiled_graph = self._make_compiled_graph()
|
198
202
|
self.response = None
|
@@ -443,7 +447,8 @@ def make_data_wrangling_agent(
|
|
443
447
|
overwrite=True,
|
444
448
|
human_in_the_loop=False,
|
445
449
|
bypass_recommended_steps=False,
|
446
|
-
bypass_explain_code=False
|
450
|
+
bypass_explain_code=False,
|
451
|
+
checkpointer=None,
|
447
452
|
):
|
448
453
|
"""
|
449
454
|
Creates a data wrangling agent that can be run on one or more datasets. The agent can be
|
@@ -488,6 +493,8 @@ def make_data_wrangling_agent(
|
|
488
493
|
Bypass the recommendation step, by default False
|
489
494
|
bypass_explain_code : bool, optional
|
490
495
|
Bypass the code explanation step, by default False.
|
496
|
+
checkpointer : Checkpointer, optional
|
497
|
+
A checkpointer object to save and load the agent's state. Defaults to None.
|
491
498
|
|
492
499
|
Example
|
493
500
|
-------
|
@@ -520,6 +527,11 @@ def make_data_wrangling_agent(
|
|
520
527
|
"""
|
521
528
|
llm = model
|
522
529
|
|
530
|
+
if human_in_the_loop:
|
531
|
+
if checkpointer is None:
|
532
|
+
print("Human in the loop is enabled. A checkpointer is required. Setting to MemorySaver().")
|
533
|
+
checkpointer = MemorySaver()
|
534
|
+
|
523
535
|
# Human in th loop requires recommended steps
|
524
536
|
if bypass_recommended_steps and human_in_the_loop:
|
525
537
|
bypass_recommended_steps = False
|
@@ -569,7 +581,7 @@ def make_data_wrangling_agent(
|
|
569
581
|
|
570
582
|
# Create a summary for all datasets
|
571
583
|
# We'll include a short sample and info for each dataset
|
572
|
-
all_datasets_summary = get_dataframe_summary(dataframes, n_sample=n_samples)
|
584
|
+
all_datasets_summary = get_dataframe_summary(dataframes, n_sample=n_samples, skip_stats=True)
|
573
585
|
|
574
586
|
# Join all datasets summaries into one big text block
|
575
587
|
all_datasets_summary_str = "\n\n".join(all_datasets_summary)
|
@@ -642,7 +654,7 @@ def make_data_wrangling_agent(
|
|
642
654
|
|
643
655
|
# Create a summary for all datasets
|
644
656
|
# We'll include a short sample and info for each dataset
|
645
|
-
all_datasets_summary = get_dataframe_summary(dataframes, n_sample=n_samples)
|
657
|
+
all_datasets_summary = get_dataframe_summary(dataframes, n_sample=n_samples, skip_stats=True)
|
646
658
|
|
647
659
|
# Join all datasets summaries into one big text block
|
648
660
|
all_datasets_summary_str = "\n\n".join(all_datasets_summary)
|
@@ -654,9 +666,12 @@ def make_data_wrangling_agent(
|
|
654
666
|
|
655
667
|
data_wrangling_prompt = PromptTemplate(
|
656
668
|
template="""
|
657
|
-
You are a Data Wrangling Coding Agent. Your job is to create a {function_name}() function that can be run on the provided data.
|
669
|
+
You are a Pandas Data Wrangling Coding Agent. Your job is to create a {function_name}() function that can be run on the provided data. You should use Pandas and NumPy for data wrangling operations.
|
670
|
+
|
671
|
+
User instructions:
|
672
|
+
{user_instructions}
|
658
673
|
|
659
|
-
Follow these recommended steps:
|
674
|
+
Follow these recommended steps (if present):
|
660
675
|
{recommended_steps}
|
661
676
|
|
662
677
|
If multiple datasets are provided, you may need to merge or join them. Make sure to handle that scenario based on the recommended steps and user instructions.
|
@@ -685,17 +700,22 @@ def make_data_wrangling_agent(
|
|
685
700
|
1. If the incoming data is not a list. Convert it to a list first.
|
686
701
|
2. Do not specify data types inside the function arguments.
|
687
702
|
|
703
|
+
Important Notes:
|
704
|
+
1. Do Not use Print statements to display the data. Return the data frame instead with the data wrangling operation performed.
|
705
|
+
2. Do not plot graphs. Only return the data frame.
|
706
|
+
|
688
707
|
Make sure to explain any non-trivial steps with inline comments. Follow user instructions. Comment code thoroughly.
|
689
708
|
|
690
709
|
|
691
710
|
""",
|
692
|
-
input_variables=["recommended_steps", "all_datasets_summary", "function_name"]
|
711
|
+
input_variables=["recommended_steps", "user_instructions", "all_datasets_summary", "function_name"]
|
693
712
|
)
|
694
713
|
|
695
714
|
data_wrangling_agent = data_wrangling_prompt | llm | PythonOutputParser()
|
696
715
|
|
697
716
|
response = data_wrangling_agent.invoke({
|
698
717
|
"recommended_steps": state.get("recommended_steps"),
|
718
|
+
"user_instructions": state.get("user_instructions"),
|
699
719
|
"all_datasets_summary": all_datasets_summary_str,
|
700
720
|
"function_name": function_name
|
701
721
|
})
|
@@ -835,9 +855,10 @@ def make_data_wrangling_agent(
|
|
835
855
|
error_key="data_wrangler_error",
|
836
856
|
human_in_the_loop=human_in_the_loop,
|
837
857
|
human_review_node_name="human_review",
|
838
|
-
checkpointer=
|
858
|
+
checkpointer=checkpointer,
|
839
859
|
bypass_recommended_steps=bypass_recommended_steps,
|
840
860
|
bypass_explain_code=bypass_explain_code,
|
861
|
+
agent_name=AGENT_NAME,
|
841
862
|
)
|
842
863
|
|
843
864
|
return app
|
@@ -10,7 +10,7 @@ import operator
|
|
10
10
|
from langchain.prompts import PromptTemplate
|
11
11
|
from langchain_core.messages import BaseMessage
|
12
12
|
|
13
|
-
from langgraph.types import Command
|
13
|
+
from langgraph.types import Command, Checkpointer
|
14
14
|
from langgraph.checkpoint.memory import MemorySaver
|
15
15
|
|
16
16
|
import os
|
@@ -84,6 +84,8 @@ class FeatureEngineeringAgent(BaseAgent):
|
|
84
84
|
If True, skips the default recommended steps. Defaults to False.
|
85
85
|
bypass_explain_code : bool, optional
|
86
86
|
If True, skips the step that provides code explanations. Defaults to False.
|
87
|
+
checkpointer : Checkpointer, optional
|
88
|
+
Checkpointer to save and load the agent's state. Defaults to None.
|
87
89
|
|
88
90
|
Methods
|
89
91
|
-------
|
@@ -170,7 +172,8 @@ class FeatureEngineeringAgent(BaseAgent):
|
|
170
172
|
overwrite=True,
|
171
173
|
human_in_the_loop=False,
|
172
174
|
bypass_recommended_steps=False,
|
173
|
-
bypass_explain_code=False
|
175
|
+
bypass_explain_code=False,
|
176
|
+
checkpointer=None,
|
174
177
|
):
|
175
178
|
self._params = {
|
176
179
|
"model": model,
|
@@ -182,7 +185,8 @@ class FeatureEngineeringAgent(BaseAgent):
|
|
182
185
|
"overwrite": overwrite,
|
183
186
|
"human_in_the_loop": human_in_the_loop,
|
184
187
|
"bypass_recommended_steps": bypass_recommended_steps,
|
185
|
-
"bypass_explain_code": bypass_explain_code
|
188
|
+
"bypass_explain_code": bypass_explain_code,
|
189
|
+
"checkpointer": checkpointer,
|
186
190
|
}
|
187
191
|
self._compiled_graph = self._make_compiled_graph()
|
188
192
|
self.response = None
|
@@ -400,6 +404,7 @@ def make_feature_engineering_agent(
|
|
400
404
|
human_in_the_loop=False,
|
401
405
|
bypass_recommended_steps=False,
|
402
406
|
bypass_explain_code=False,
|
407
|
+
checkpointer=None,
|
403
408
|
):
|
404
409
|
"""
|
405
410
|
Creates a feature engineering agent that can be run on a dataset. The agent applies various feature engineering
|
@@ -448,6 +453,8 @@ def make_feature_engineering_agent(
|
|
448
453
|
Bypass the recommendation step, by default False
|
449
454
|
bypass_explain_code : bool, optional
|
450
455
|
Bypass the code explanation step, by default False.
|
456
|
+
checkpointer : Checkpointer, optional
|
457
|
+
Checkpointer to save and load the agent's state. Defaults to None.
|
451
458
|
|
452
459
|
Examples
|
453
460
|
-------
|
@@ -480,6 +487,11 @@ def make_feature_engineering_agent(
|
|
480
487
|
"""
|
481
488
|
llm = model
|
482
489
|
|
490
|
+
if human_in_the_loop:
|
491
|
+
if checkpointer is None:
|
492
|
+
print("Human in the loop is enabled. A checkpointer is required. Setting to MemorySaver().")
|
493
|
+
checkpointer = MemorySaver()
|
494
|
+
|
483
495
|
# Human in th loop requires recommended steps
|
484
496
|
if bypass_recommended_steps and human_in_the_loop:
|
485
497
|
bypass_recommended_steps = False
|
@@ -782,9 +794,10 @@ def make_feature_engineering_agent(
|
|
782
794
|
retry_count_key = "retry_count",
|
783
795
|
human_in_the_loop=human_in_the_loop,
|
784
796
|
human_review_node_name="human_review",
|
785
|
-
checkpointer=
|
797
|
+
checkpointer=checkpointer,
|
786
798
|
bypass_recommended_steps=bypass_recommended_steps,
|
787
799
|
bypass_explain_code=bypass_explain_code,
|
800
|
+
agent_name=AGENT_NAME,
|
788
801
|
)
|
789
802
|
|
790
803
|
return app
|
@@ -7,7 +7,7 @@ from langchain.prompts import PromptTemplate
|
|
7
7
|
from langchain_core.messages import BaseMessage
|
8
8
|
from langchain_core.output_parsers import JsonOutputParser
|
9
9
|
|
10
|
-
from langgraph.types import Command
|
10
|
+
from langgraph.types import Command, Checkpointer
|
11
11
|
from langgraph.checkpoint.memory import MemorySaver
|
12
12
|
|
13
13
|
import os
|
@@ -75,6 +75,8 @@ class SQLDatabaseAgent(BaseAgent):
|
|
75
75
|
If True, skips the step that generates recommended SQL steps. Defaults to False.
|
76
76
|
bypass_explain_code : bool, optional
|
77
77
|
If True, skips the step that provides code explanations. Defaults to False.
|
78
|
+
checkpointer : Checkpointer, optional
|
79
|
+
A checkpointer to save and load the agent's state. Defaults to None.
|
78
80
|
smart_schema_pruning : bool, optional
|
79
81
|
If True, filters the tables and columns based on the user instructions and recommended steps. Defaults to False.
|
80
82
|
|
@@ -157,6 +159,7 @@ class SQLDatabaseAgent(BaseAgent):
|
|
157
159
|
human_in_the_loop=False,
|
158
160
|
bypass_recommended_steps=False,
|
159
161
|
bypass_explain_code=False,
|
162
|
+
checkpointer=None,
|
160
163
|
smart_schema_pruning=False,
|
161
164
|
):
|
162
165
|
self._params = {
|
@@ -171,6 +174,7 @@ class SQLDatabaseAgent(BaseAgent):
|
|
171
174
|
"human_in_the_loop": human_in_the_loop,
|
172
175
|
"bypass_recommended_steps": bypass_recommended_steps,
|
173
176
|
"bypass_explain_code": bypass_explain_code,
|
177
|
+
"checkpointer": checkpointer,
|
174
178
|
"smart_schema_pruning": smart_schema_pruning,
|
175
179
|
}
|
176
180
|
self._compiled_graph = self._make_compiled_graph()
|
@@ -365,6 +369,7 @@ def make_sql_database_agent(
|
|
365
369
|
human_in_the_loop=False,
|
366
370
|
bypass_recommended_steps=False,
|
367
371
|
bypass_explain_code=False,
|
372
|
+
checkpointer=None,
|
368
373
|
smart_schema_pruning=False,
|
369
374
|
):
|
370
375
|
"""
|
@@ -394,6 +399,8 @@ def make_sql_database_agent(
|
|
394
399
|
Bypass the recommendation step, by default False
|
395
400
|
bypass_explain_code : bool, optional
|
396
401
|
Bypass the code explanation step, by default False.
|
402
|
+
checkpointer : Checkpointer, optional
|
403
|
+
A checkpointer to save and load the agent's state. Defaults to None.
|
397
404
|
smart_schema_pruning : bool, optional
|
398
405
|
If True, filters the tables and columns with an extra LLM step to reduce tokens for large databases. Increases processing time but can avoid errors due to hitting max token limits with large databases. Defaults to False.
|
399
406
|
|
@@ -432,6 +439,11 @@ def make_sql_database_agent(
|
|
432
439
|
|
433
440
|
llm = model
|
434
441
|
|
442
|
+
if human_in_the_loop:
|
443
|
+
if checkpointer is None:
|
444
|
+
print("Human in the loop is enabled. A checkpointer is required. Setting to MemorySaver().")
|
445
|
+
checkpointer = MemorySaver()
|
446
|
+
|
435
447
|
# Human in th loop requires recommended steps
|
436
448
|
if bypass_recommended_steps and human_in_the_loop:
|
437
449
|
bypass_recommended_steps = False
|
@@ -742,9 +754,10 @@ def {function_name}(connection):
|
|
742
754
|
error_key="sql_database_error",
|
743
755
|
human_in_the_loop=human_in_the_loop,
|
744
756
|
human_review_node_name="human_review",
|
745
|
-
checkpointer=
|
757
|
+
checkpointer=checkpointer,
|
746
758
|
bypass_recommended_steps=bypass_recommended_steps,
|
747
759
|
bypass_explain_code=bypass_explain_code,
|
760
|
+
agent_name=AGENT_NAME,
|
748
761
|
)
|
749
762
|
|
750
763
|
return app
|
@@ -1,12 +1,8 @@
|
|
1
1
|
|
2
2
|
|
3
|
-
from typing import Any, Optional, Annotated, Sequence,
|
3
|
+
from typing import Any, Optional, Annotated, Sequence, Dict
|
4
4
|
import operator
|
5
5
|
import pandas as pd
|
6
|
-
import os
|
7
|
-
from io import StringIO, BytesIO
|
8
|
-
import base64
|
9
|
-
import matplotlib.pyplot as plt
|
10
6
|
|
11
7
|
from IPython.display import Markdown
|
12
8
|
|
@@ -14,6 +10,7 @@ from langchain_core.messages import BaseMessage, AIMessage
|
|
14
10
|
from langgraph.prebuilt import create_react_agent, ToolNode
|
15
11
|
from langgraph.prebuilt.chat_agent_executor import AgentState
|
16
12
|
from langgraph.graph import START, END, StateGraph
|
13
|
+
from langgraph.types import Checkpointer
|
17
14
|
|
18
15
|
from ai_data_science_team.templates import BaseAgent
|
19
16
|
from ai_data_science_team.utils.regex import format_agent_name
|
@@ -52,6 +49,8 @@ class EDAToolsAgent(BaseAgent):
|
|
52
49
|
Additional kwargs for create_react_agent.
|
53
50
|
invoke_react_agent_kwargs : dict
|
54
51
|
Additional kwargs for agent invocation.
|
52
|
+
checkpointer : Checkpointer, optional
|
53
|
+
The checkpointer for the agent.
|
55
54
|
"""
|
56
55
|
|
57
56
|
def __init__(
|
@@ -59,11 +58,13 @@ class EDAToolsAgent(BaseAgent):
|
|
59
58
|
model: Any,
|
60
59
|
create_react_agent_kwargs: Optional[Dict] = {},
|
61
60
|
invoke_react_agent_kwargs: Optional[Dict] = {},
|
61
|
+
checkpointer: Optional[Checkpointer] = None,
|
62
62
|
):
|
63
63
|
self._params = {
|
64
64
|
"model": model,
|
65
65
|
"create_react_agent_kwargs": create_react_agent_kwargs,
|
66
66
|
"invoke_react_agent_kwargs": invoke_react_agent_kwargs,
|
67
|
+
"checkpointer": checkpointer
|
67
68
|
}
|
68
69
|
self._compiled_graph = self._make_compiled_graph()
|
69
70
|
self.response = None
|
@@ -176,6 +177,7 @@ def make_eda_tools_agent(
|
|
176
177
|
model: Any,
|
177
178
|
create_react_agent_kwargs: Optional[Dict] = {},
|
178
179
|
invoke_react_agent_kwargs: Optional[Dict] = {},
|
180
|
+
checkpointer: Optional[Checkpointer] = None,
|
179
181
|
):
|
180
182
|
"""
|
181
183
|
Creates an Exploratory Data Analyst Agent that can interact with EDA tools.
|
@@ -188,6 +190,8 @@ def make_eda_tools_agent(
|
|
188
190
|
Additional kwargs for create_react_agent.
|
189
191
|
invoke_react_agent_kwargs : dict
|
190
192
|
Additional kwargs for agent invocation.
|
193
|
+
checkpointer : Checkpointer, optional
|
194
|
+
The checkpointer for the agent.
|
191
195
|
|
192
196
|
Returns:
|
193
197
|
-------
|
@@ -215,6 +219,7 @@ def make_eda_tools_agent(
|
|
215
219
|
tools=tool_node,
|
216
220
|
state_schema=GraphState,
|
217
221
|
**create_react_agent_kwargs,
|
222
|
+
checkpointer=checkpointer,
|
218
223
|
)
|
219
224
|
|
220
225
|
response = eda_agent.invoke(
|
@@ -254,5 +259,9 @@ def make_eda_tools_agent(
|
|
254
259
|
workflow.add_edge(START, "exploratory_agent")
|
255
260
|
workflow.add_edge("exploratory_agent", END)
|
256
261
|
|
257
|
-
app = workflow.compile(
|
262
|
+
app = workflow.compile(
|
263
|
+
checkpointer=checkpointer,
|
264
|
+
name=AGENT_NAME,
|
265
|
+
)
|
266
|
+
|
258
267
|
return app
|
@@ -5,7 +5,7 @@
|
|
5
5
|
|
6
6
|
import os
|
7
7
|
import json
|
8
|
-
from typing import TypedDict, Annotated, Sequence, Literal
|
8
|
+
from typing import TypedDict, Annotated, Sequence, Literal, Optional
|
9
9
|
import operator
|
10
10
|
|
11
11
|
import pandas as pd
|
@@ -14,7 +14,7 @@ from IPython.display import Markdown
|
|
14
14
|
from langchain.prompts import PromptTemplate
|
15
15
|
from langchain_core.messages import BaseMessage
|
16
16
|
|
17
|
-
from langgraph.types import Command
|
17
|
+
from langgraph.types import Command, Checkpointer
|
18
18
|
from langgraph.checkpoint.memory import MemorySaver
|
19
19
|
|
20
20
|
from ai_data_science_team.templates import(
|
@@ -79,6 +79,8 @@ class H2OMLAgent(BaseAgent):
|
|
79
79
|
Name of the MLflow experiment (created if doesn't exist).
|
80
80
|
mlflow_run_name : str, default None
|
81
81
|
A custom name for the MLflow run.
|
82
|
+
checkpointer : langgraph.checkpoint.memory.MemorySaver, optional
|
83
|
+
A checkpointer object for saving the agent's state. Defaults to None.
|
82
84
|
|
83
85
|
|
84
86
|
Methods
|
@@ -176,6 +178,7 @@ class H2OMLAgent(BaseAgent):
|
|
176
178
|
mlflow_tracking_uri=None,
|
177
179
|
mlflow_experiment_name="H2O AutoML",
|
178
180
|
mlflow_run_name=None,
|
181
|
+
checkpointer: Optional[Checkpointer]=None,
|
179
182
|
):
|
180
183
|
self._params = {
|
181
184
|
"model": model,
|
@@ -193,6 +196,7 @@ class H2OMLAgent(BaseAgent):
|
|
193
196
|
"mlflow_tracking_uri": mlflow_tracking_uri,
|
194
197
|
"mlflow_experiment_name": mlflow_experiment_name,
|
195
198
|
"mlflow_run_name": mlflow_run_name,
|
199
|
+
"checkpointer": checkpointer,
|
196
200
|
}
|
197
201
|
self._compiled_graph = self._make_compiled_graph()
|
198
202
|
self.response = None
|
@@ -350,6 +354,7 @@ def make_h2o_ml_agent(
|
|
350
354
|
mlflow_tracking_uri=None,
|
351
355
|
mlflow_experiment_name="H2O AutoML",
|
352
356
|
mlflow_run_name=None,
|
357
|
+
checkpointer=None,
|
353
358
|
):
|
354
359
|
"""
|
355
360
|
Creates a machine learning agent that uses H2O for AutoML.
|
@@ -384,6 +389,12 @@ def make_h2o_ml_agent(
|
|
384
389
|
" pip install h2o\n\n"
|
385
390
|
"Visit https://docs.h2o.ai/h2o/latest-stable/h2o-docs/downloading.html for details."
|
386
391
|
) from e
|
392
|
+
|
393
|
+
if human_in_the_loop:
|
394
|
+
if checkpointer is None:
|
395
|
+
print("Human in the loop is enabled. A checkpointer is required. Setting to MemorySaver().")
|
396
|
+
checkpointer = MemorySaver()
|
397
|
+
|
387
398
|
|
388
399
|
# Define GraphState
|
389
400
|
class GraphState(TypedDict):
|
@@ -844,9 +855,10 @@ def make_h2o_ml_agent(
|
|
844
855
|
retry_count_key="retry_count",
|
845
856
|
human_in_the_loop=human_in_the_loop,
|
846
857
|
human_review_node_name="human_review",
|
847
|
-
checkpointer=
|
858
|
+
checkpointer=checkpointer,
|
848
859
|
bypass_recommended_steps=bypass_recommended_steps,
|
849
860
|
bypass_explain_code=bypass_explain_code,
|
861
|
+
agent_name=AGENT_NAME,
|
850
862
|
)
|
851
863
|
|
852
864
|
return app
|
@@ -10,6 +10,7 @@ from langchain_core.messages import BaseMessage, AIMessage
|
|
10
10
|
|
11
11
|
from langgraph.prebuilt import create_react_agent, ToolNode
|
12
12
|
from langgraph.prebuilt.chat_agent_executor import AgentState
|
13
|
+
from langgraph.types import Checkpointer
|
13
14
|
from langgraph.graph import START, END, StateGraph
|
14
15
|
|
15
16
|
from ai_data_science_team.templates import BaseAgent
|
@@ -68,6 +69,8 @@ class MLflowToolsAgent(BaseAgent):
|
|
68
69
|
Additional keyword arguments to pass to the create_react_agent function.
|
69
70
|
invoke_react_agent_kwargs : dict
|
70
71
|
Additional keyword arguments to pass to the invoke method of the react agent.
|
72
|
+
checkpointer : langchain.checkpointing.Checkpointer, optional
|
73
|
+
A checkpointer to use for saving and loading the agent's state. Defaults to None.
|
71
74
|
|
72
75
|
Methods:
|
73
76
|
--------
|
@@ -119,6 +122,7 @@ class MLflowToolsAgent(BaseAgent):
|
|
119
122
|
mlflow_registry_uri: Optional[str]=None,
|
120
123
|
create_react_agent_kwargs: Optional[Dict]={},
|
121
124
|
invoke_react_agent_kwargs: Optional[Dict]={},
|
125
|
+
checkpointer: Optional[Checkpointer]=None,
|
122
126
|
):
|
123
127
|
self._params = {
|
124
128
|
"model": model,
|
@@ -126,6 +130,7 @@ class MLflowToolsAgent(BaseAgent):
|
|
126
130
|
"mlflow_registry_uri": mlflow_registry_uri,
|
127
131
|
"create_react_agent_kwargs": create_react_agent_kwargs,
|
128
132
|
"invoke_react_agent_kwargs": invoke_react_agent_kwargs,
|
133
|
+
"checkpointer": checkpointer,
|
129
134
|
}
|
130
135
|
self._compiled_graph = self._make_compiled_graph()
|
131
136
|
self.response = None
|
@@ -245,6 +250,7 @@ def make_mlflow_tools_agent(
|
|
245
250
|
mlflow_registry_uri: str=None,
|
246
251
|
create_react_agent_kwargs: Optional[Dict]={},
|
247
252
|
invoke_react_agent_kwargs: Optional[Dict]={},
|
253
|
+
checkpointer: Optional[Checkpointer]=None,
|
248
254
|
):
|
249
255
|
"""
|
250
256
|
MLflow Tool Calling Agent
|
@@ -261,6 +267,8 @@ def make_mlflow_tools_agent(
|
|
261
267
|
Additional keyword arguments to pass to the agent's create_react_agent method.
|
262
268
|
invoke_react_agent_kwargs : dict, optional
|
263
269
|
Additional keyword arguments to pass to the agent's invoke method.
|
270
|
+
checkpointer : langchain.checkpointing.Checkpointer, optional
|
271
|
+
A checkpointer to use for saving and loading the agent's state. Defaults to None.
|
264
272
|
|
265
273
|
Returns
|
266
274
|
-------
|
@@ -303,6 +311,7 @@ def make_mlflow_tools_agent(
|
|
303
311
|
model,
|
304
312
|
tools=tool_node,
|
305
313
|
state_schema=GraphState,
|
314
|
+
checkpointer=checkpointer,
|
306
315
|
**create_react_agent_kwargs,
|
307
316
|
)
|
308
317
|
|
@@ -354,7 +363,10 @@ def make_mlflow_tools_agent(
|
|
354
363
|
workflow.add_edge(START, "mlflow_tools_agent")
|
355
364
|
workflow.add_edge("mlflow_tools_agent", END)
|
356
365
|
|
357
|
-
app = workflow.compile(
|
366
|
+
app = workflow.compile(
|
367
|
+
checkpointer=checkpointer,
|
368
|
+
name=AGENT_NAME,
|
369
|
+
)
|
358
370
|
|
359
371
|
return app
|
360
372
|
|
@@ -1 +1,2 @@
|
|
1
|
-
from ai_data_science_team.multiagents.sql_data_analyst import SQLDataAnalyst, make_sql_data_analyst
|
1
|
+
from ai_data_science_team.multiagents.sql_data_analyst import SQLDataAnalyst, make_sql_data_analyst
|
2
|
+
from ai_data_science_team.multiagents.pandas_data_analyst import PandasDataAnalyst, make_pandas_data_analyst
|