ai-data-science-team 0.0.0.9013__py3-none-any.whl → 0.0.0.9014__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,22 @@
1
+ from ai_data_science_team.agents import (
2
+ DataCleaningAgent,
3
+ DataLoaderToolsAgent,
4
+ DataVisualizationAgent,
5
+ SQLDatabaseAgent,
6
+ DataWranglingAgent,
7
+ FeatureEngineeringAgent,
8
+ )
9
+
10
+ from ai_data_science_team.ds_agents import (
11
+ EDAToolsAgent,
12
+ )
13
+
14
+ from ai_data_science_team.ml_agents import (
15
+ H2OMLAgent,
16
+ MLflowToolsAgent,
17
+ )
18
+
19
+ from ai_data_science_team.multiagents import (
20
+ SQLDataAnalyst,
21
+ PandasDataAnalyst,
22
+ )
@@ -1 +1 @@
1
- __version__ = "0.0.0.9013"
1
+ __version__ = "0.0.0.9014"
@@ -12,6 +12,7 @@ from langchain_core.messages import BaseMessage
12
12
 
13
13
  from langgraph.types import Command
14
14
  from langgraph.checkpoint.memory import MemorySaver
15
+ from langgraph.types import Checkpointer
15
16
 
16
17
  import os
17
18
  import json
@@ -85,6 +86,8 @@ class DataCleaningAgent(BaseAgent):
85
86
  If True, skips the default recommended cleaning steps. Defaults to False.
86
87
  bypass_explain_code : bool, optional
87
88
  If True, skips the step that provides code explanations. Defaults to False.
89
+ checkpointer : langgraph.types.Checkpointer, optional
90
+ Checkpointer to save and load the agent's state. Defaults to None.
88
91
 
89
92
  Methods
90
93
  -------
@@ -159,7 +162,8 @@ class DataCleaningAgent(BaseAgent):
159
162
  overwrite=True,
160
163
  human_in_the_loop=False,
161
164
  bypass_recommended_steps=False,
162
- bypass_explain_code=False
165
+ bypass_explain_code=False,
166
+ checkpointer: Checkpointer = None
163
167
  ):
164
168
  self._params = {
165
169
  "model": model,
@@ -172,6 +176,7 @@ class DataCleaningAgent(BaseAgent):
172
176
  "human_in_the_loop": human_in_the_loop,
173
177
  "bypass_recommended_steps": bypass_recommended_steps,
174
178
  "bypass_explain_code": bypass_explain_code,
179
+ "checkpointer": checkpointer
175
180
  }
176
181
  self._compiled_graph = self._make_compiled_graph()
177
182
  self.response = None
@@ -320,7 +325,8 @@ def make_data_cleaning_agent(
320
325
  overwrite = True,
321
326
  human_in_the_loop=False,
322
327
  bypass_recommended_steps=False,
323
- bypass_explain_code=False
328
+ bypass_explain_code=False,
329
+ checkpointer: Checkpointer = None
324
330
  ):
325
331
  """
326
332
  Creates a data cleaning agent that can be run on a dataset. The agent can be used to clean a dataset in a variety of
@@ -369,6 +375,8 @@ def make_data_cleaning_agent(
369
375
  Bypass the recommendation step, by default False
370
376
  bypass_explain_code : bool, optional
371
377
  Bypass the code explanation step, by default False.
378
+ checkpointer : langgraph.types.Checkpointer, optional
379
+ Checkpointer to save and load the agent's state. Defaults to None.
372
380
 
373
381
  Examples
374
382
  -------
@@ -400,6 +408,11 @@ def make_data_cleaning_agent(
400
408
  """
401
409
  llm = model
402
410
 
411
+ if human_in_the_loop:
412
+ if checkpointer is None:
413
+ print("Human in the loop is enabled. A checkpointer is required. Setting to MemorySaver().")
414
+ checkpointer = MemorySaver()
415
+
403
416
  # Human in th loop requires recommended steps
404
417
  if bypass_recommended_steps and human_in_the_loop:
405
418
  bypass_recommended_steps = False
@@ -680,9 +693,10 @@ def make_data_cleaning_agent(
680
693
  error_key="data_cleaner_error",
681
694
  human_in_the_loop=human_in_the_loop,
682
695
  human_review_node_name="human_review",
683
- checkpointer=MemorySaver() if human_in_the_loop else None,
696
+ checkpointer=checkpointer,
684
697
  bypass_recommended_steps=bypass_recommended_steps,
685
698
  bypass_explain_code=bypass_explain_code,
699
+ agent_name=AGENT_NAME,
686
700
  )
687
701
 
688
702
  return app
@@ -13,6 +13,7 @@ from langchain_core.messages import BaseMessage, AIMessage
13
13
 
14
14
  from langgraph.prebuilt import create_react_agent, ToolNode
15
15
  from langgraph.prebuilt.chat_agent_executor import AgentState
16
+ from langgraph.types import Checkpointer
16
17
  from langgraph.graph import START, END, StateGraph
17
18
 
18
19
  from ai_data_science_team.templates import BaseAgent
@@ -50,6 +51,8 @@ class DataLoaderToolsAgent(BaseAgent):
50
51
  Additional keyword arguments to pass to the create_react_agent function.
51
52
  invoke_react_agent_kwargs : dict
52
53
  Additional keyword arguments to pass to the invoke method of the react agent.
54
+ checkpointer : langgraph.types.Checkpointer
55
+ A checkpointer to use for saving and loading the agent's state.
53
56
 
54
57
  Methods:
55
58
  --------
@@ -73,11 +76,13 @@ class DataLoaderToolsAgent(BaseAgent):
73
76
  model: Any,
74
77
  create_react_agent_kwargs: Optional[Dict]={},
75
78
  invoke_react_agent_kwargs: Optional[Dict]={},
79
+ checkpointer: Optional[Checkpointer]=None,
76
80
  ):
77
81
  self._params = {
78
82
  "model": model,
79
83
  "create_react_agent_kwargs": create_react_agent_kwargs,
80
84
  "invoke_react_agent_kwargs": invoke_react_agent_kwargs,
85
+ "checkpointer": checkpointer,
81
86
  }
82
87
  self._compiled_graph = self._make_compiled_graph()
83
88
  self.response = None
@@ -188,6 +193,7 @@ def make_data_loader_tools_agent(
188
193
  model: Any,
189
194
  create_react_agent_kwargs: Optional[Dict]={},
190
195
  invoke_react_agent_kwargs: Optional[Dict]={},
196
+ checkpointer: Optional[Checkpointer]=None,
191
197
  ):
192
198
  """
193
199
  Creates a Data Loader Agent that can interact with data loading tools.
@@ -200,6 +206,8 @@ def make_data_loader_tools_agent(
200
206
  Additional keyword arguments to pass to the create_react_agent function.
201
207
  invoke_react_agent_kwargs : dict
202
208
  Additional keyword arguments to pass to the invoke method of the react agent.
209
+ checkpointer : langgraph.types.Checkpointer
210
+ A checkpointer to use for saving and loading the agent's state.
203
211
 
204
212
  Returns:
205
213
  --------
@@ -228,6 +236,7 @@ def make_data_loader_tools_agent(
228
236
  model,
229
237
  tools=tool_node,
230
238
  state_schema=GraphState,
239
+ checkpointer=checkpointer,
231
240
  **create_react_agent_kwargs,
232
241
  )
233
242
 
@@ -277,7 +286,10 @@ def make_data_loader_tools_agent(
277
286
  workflow.add_edge(START, "data_loader_agent")
278
287
  workflow.add_edge("data_loader_agent", END)
279
288
 
280
- app = workflow.compile()
289
+ app = workflow.compile(
290
+ checkpointer=checkpointer,
291
+ name=AGENT_NAME,
292
+ )
281
293
 
282
294
  return app
283
295
 
@@ -14,6 +14,7 @@ from langchain_core.messages import BaseMessage
14
14
 
15
15
  from langgraph.types import Command
16
16
  from langgraph.checkpoint.memory import MemorySaver
17
+ from langgraph.types import Checkpointer
17
18
 
18
19
  import os
19
20
  import json
@@ -85,6 +86,8 @@ class DataVisualizationAgent(BaseAgent):
85
86
  If True, skips the default recommended visualization steps. Defaults to False.
86
87
  bypass_explain_code : bool, optional
87
88
  If True, skips the step that provides code explanations. Defaults to False.
89
+ checkpointer : langgraph.types.Checkpointer
90
+ A checkpointer to use for saving and loading the agent
88
91
 
89
92
  Methods
90
93
  -------
@@ -161,7 +164,8 @@ class DataVisualizationAgent(BaseAgent):
161
164
  overwrite=True,
162
165
  human_in_the_loop=False,
163
166
  bypass_recommended_steps=False,
164
- bypass_explain_code=False
167
+ bypass_explain_code=False,
168
+ checkpointer=None,
165
169
  ):
166
170
  self._params = {
167
171
  "model": model,
@@ -174,6 +178,7 @@ class DataVisualizationAgent(BaseAgent):
174
178
  "human_in_the_loop": human_in_the_loop,
175
179
  "bypass_recommended_steps": bypass_recommended_steps,
176
180
  "bypass_explain_code": bypass_explain_code,
181
+ "checkpointer": checkpointer,
177
182
  }
178
183
  self._compiled_graph = self._make_compiled_graph()
179
184
  self.response = None
@@ -385,7 +390,8 @@ def make_data_visualization_agent(
385
390
  overwrite=True,
386
391
  human_in_the_loop=False,
387
392
  bypass_recommended_steps=False,
388
- bypass_explain_code=False
393
+ bypass_explain_code=False,
394
+ checkpointer=None,
389
395
  ):
390
396
  """
391
397
  Creates a data visualization agent that can generate Plotly charts based on user-defined instructions or
@@ -423,6 +429,8 @@ def make_data_visualization_agent(
423
429
  If True, skips the default recommended visualization steps. Defaults to False.
424
430
  bypass_explain_code : bool, optional
425
431
  If True, skips the step that provides code explanations. Defaults to False.
432
+ checkpointer : langgraph.types.Checkpointer
433
+ A checkpointer to use for saving and loading the agent
426
434
 
427
435
  Examples
428
436
  --------
@@ -455,6 +463,11 @@ def make_data_visualization_agent(
455
463
 
456
464
  llm = model
457
465
 
466
+ if human_in_the_loop:
467
+ if checkpointer is None:
468
+ print("Human in the loop is enabled. A checkpointer is required. Setting to MemorySaver().")
469
+ checkpointer = MemorySaver()
470
+
458
471
  # Human in th loop requires recommended steps
459
472
  if bypass_recommended_steps and human_in_the_loop:
460
473
  bypass_recommended_steps = False
@@ -751,9 +764,10 @@ def make_data_visualization_agent(
751
764
  error_key="data_visualization_error",
752
765
  human_in_the_loop=human_in_the_loop, # or False
753
766
  human_review_node_name="human_review",
754
- checkpointer=MemorySaver() if human_in_the_loop else None,
767
+ checkpointer=checkpointer,
755
768
  bypass_recommended_steps=bypass_recommended_steps,
756
769
  bypass_explain_code=bypass_explain_code,
770
+ agent_name=AGENT_NAME,
757
771
  )
758
772
 
759
773
  return app
@@ -13,7 +13,7 @@ from IPython.display import Markdown
13
13
 
14
14
  from langchain.prompts import PromptTemplate
15
15
  from langchain_core.messages import BaseMessage
16
- from langgraph.types import Command
16
+ from langgraph.types import Command, Checkpointer
17
17
  from langgraph.checkpoint.memory import MemorySaver
18
18
 
19
19
  from ai_data_science_team.templates import(
@@ -83,6 +83,8 @@ class DataWranglingAgent(BaseAgent):
83
83
  If True, skips the step that generates recommended data wrangling steps. Defaults to False.
84
84
  bypass_explain_code : bool, optional
85
85
  If True, skips the step that provides code explanations. Defaults to False.
86
+ checkpointer : Checkpointer, optional
87
+ A checkpointer object to save and load the agent's state. Defaults to None.
86
88
 
87
89
  Methods
88
90
  -------
@@ -180,7 +182,8 @@ class DataWranglingAgent(BaseAgent):
180
182
  overwrite=True,
181
183
  human_in_the_loop=False,
182
184
  bypass_recommended_steps=False,
183
- bypass_explain_code=False
185
+ bypass_explain_code=False,
186
+ checkpointer=None,
184
187
  ):
185
188
  self._params = {
186
189
  "model": model,
@@ -192,7 +195,8 @@ class DataWranglingAgent(BaseAgent):
192
195
  "overwrite": overwrite,
193
196
  "human_in_the_loop": human_in_the_loop,
194
197
  "bypass_recommended_steps": bypass_recommended_steps,
195
- "bypass_explain_code": bypass_explain_code
198
+ "bypass_explain_code": bypass_explain_code,
199
+ "checkpointer": checkpointer,
196
200
  }
197
201
  self._compiled_graph = self._make_compiled_graph()
198
202
  self.response = None
@@ -443,7 +447,8 @@ def make_data_wrangling_agent(
443
447
  overwrite=True,
444
448
  human_in_the_loop=False,
445
449
  bypass_recommended_steps=False,
446
- bypass_explain_code=False
450
+ bypass_explain_code=False,
451
+ checkpointer=None,
447
452
  ):
448
453
  """
449
454
  Creates a data wrangling agent that can be run on one or more datasets. The agent can be
@@ -488,6 +493,8 @@ def make_data_wrangling_agent(
488
493
  Bypass the recommendation step, by default False
489
494
  bypass_explain_code : bool, optional
490
495
  Bypass the code explanation step, by default False.
496
+ checkpointer : Checkpointer, optional
497
+ A checkpointer object to save and load the agent's state. Defaults to None.
491
498
 
492
499
  Example
493
500
  -------
@@ -520,6 +527,11 @@ def make_data_wrangling_agent(
520
527
  """
521
528
  llm = model
522
529
 
530
+ if human_in_the_loop:
531
+ if checkpointer is None:
532
+ print("Human in the loop is enabled. A checkpointer is required. Setting to MemorySaver().")
533
+ checkpointer = MemorySaver()
534
+
523
535
  # Human in th loop requires recommended steps
524
536
  if bypass_recommended_steps and human_in_the_loop:
525
537
  bypass_recommended_steps = False
@@ -569,7 +581,7 @@ def make_data_wrangling_agent(
569
581
 
570
582
  # Create a summary for all datasets
571
583
  # We'll include a short sample and info for each dataset
572
- all_datasets_summary = get_dataframe_summary(dataframes, n_sample=n_samples)
584
+ all_datasets_summary = get_dataframe_summary(dataframes, n_sample=n_samples, skip_stats=True)
573
585
 
574
586
  # Join all datasets summaries into one big text block
575
587
  all_datasets_summary_str = "\n\n".join(all_datasets_summary)
@@ -642,7 +654,7 @@ def make_data_wrangling_agent(
642
654
 
643
655
  # Create a summary for all datasets
644
656
  # We'll include a short sample and info for each dataset
645
- all_datasets_summary = get_dataframe_summary(dataframes, n_sample=n_samples)
657
+ all_datasets_summary = get_dataframe_summary(dataframes, n_sample=n_samples, skip_stats=True)
646
658
 
647
659
  # Join all datasets summaries into one big text block
648
660
  all_datasets_summary_str = "\n\n".join(all_datasets_summary)
@@ -654,9 +666,12 @@ def make_data_wrangling_agent(
654
666
 
655
667
  data_wrangling_prompt = PromptTemplate(
656
668
  template="""
657
- You are a Data Wrangling Coding Agent. Your job is to create a {function_name}() function that can be run on the provided data.
669
+ You are a Pandas Data Wrangling Coding Agent. Your job is to create a {function_name}() function that can be run on the provided data. You should use Pandas and NumPy for data wrangling operations.
670
+
671
+ User instructions:
672
+ {user_instructions}
658
673
 
659
- Follow these recommended steps:
674
+ Follow these recommended steps (if present):
660
675
  {recommended_steps}
661
676
 
662
677
  If multiple datasets are provided, you may need to merge or join them. Make sure to handle that scenario based on the recommended steps and user instructions.
@@ -685,17 +700,21 @@ def make_data_wrangling_agent(
685
700
  1. If the incoming data is not a list. Convert it to a list first.
686
701
  2. Do not specify data types inside the function arguments.
687
702
 
703
+ Important Notes:
704
+ 1. Do Not use Print statements to display the data. Return the data frame instead with the data wrangling operation performed.
705
+
688
706
  Make sure to explain any non-trivial steps with inline comments. Follow user instructions. Comment code thoroughly.
689
707
 
690
708
 
691
709
  """,
692
- input_variables=["recommended_steps", "all_datasets_summary", "function_name"]
710
+ input_variables=["recommended_steps", "user_instructions", "all_datasets_summary", "function_name"]
693
711
  )
694
712
 
695
713
  data_wrangling_agent = data_wrangling_prompt | llm | PythonOutputParser()
696
714
 
697
715
  response = data_wrangling_agent.invoke({
698
716
  "recommended_steps": state.get("recommended_steps"),
717
+ "user_instructions": state.get("user_instructions"),
699
718
  "all_datasets_summary": all_datasets_summary_str,
700
719
  "function_name": function_name
701
720
  })
@@ -835,9 +854,10 @@ def make_data_wrangling_agent(
835
854
  error_key="data_wrangler_error",
836
855
  human_in_the_loop=human_in_the_loop,
837
856
  human_review_node_name="human_review",
838
- checkpointer=MemorySaver() if human_in_the_loop else None,
857
+ checkpointer=checkpointer,
839
858
  bypass_recommended_steps=bypass_recommended_steps,
840
859
  bypass_explain_code=bypass_explain_code,
860
+ agent_name=AGENT_NAME,
841
861
  )
842
862
 
843
863
  return app
@@ -10,7 +10,7 @@ import operator
10
10
  from langchain.prompts import PromptTemplate
11
11
  from langchain_core.messages import BaseMessage
12
12
 
13
- from langgraph.types import Command
13
+ from langgraph.types import Command, Checkpointer
14
14
  from langgraph.checkpoint.memory import MemorySaver
15
15
 
16
16
  import os
@@ -84,6 +84,8 @@ class FeatureEngineeringAgent(BaseAgent):
84
84
  If True, skips the default recommended steps. Defaults to False.
85
85
  bypass_explain_code : bool, optional
86
86
  If True, skips the step that provides code explanations. Defaults to False.
87
+ checkpointer : Checkpointer, optional
88
+ Checkpointer to save and load the agent's state. Defaults to None.
87
89
 
88
90
  Methods
89
91
  -------
@@ -170,7 +172,8 @@ class FeatureEngineeringAgent(BaseAgent):
170
172
  overwrite=True,
171
173
  human_in_the_loop=False,
172
174
  bypass_recommended_steps=False,
173
- bypass_explain_code=False
175
+ bypass_explain_code=False,
176
+ checkpointer=None,
174
177
  ):
175
178
  self._params = {
176
179
  "model": model,
@@ -182,7 +185,8 @@ class FeatureEngineeringAgent(BaseAgent):
182
185
  "overwrite": overwrite,
183
186
  "human_in_the_loop": human_in_the_loop,
184
187
  "bypass_recommended_steps": bypass_recommended_steps,
185
- "bypass_explain_code": bypass_explain_code
188
+ "bypass_explain_code": bypass_explain_code,
189
+ "checkpointer": checkpointer,
186
190
  }
187
191
  self._compiled_graph = self._make_compiled_graph()
188
192
  self.response = None
@@ -400,6 +404,7 @@ def make_feature_engineering_agent(
400
404
  human_in_the_loop=False,
401
405
  bypass_recommended_steps=False,
402
406
  bypass_explain_code=False,
407
+ checkpointer=None,
403
408
  ):
404
409
  """
405
410
  Creates a feature engineering agent that can be run on a dataset. The agent applies various feature engineering
@@ -448,6 +453,8 @@ def make_feature_engineering_agent(
448
453
  Bypass the recommendation step, by default False
449
454
  bypass_explain_code : bool, optional
450
455
  Bypass the code explanation step, by default False.
456
+ checkpointer : Checkpointer, optional
457
+ Checkpointer to save and load the agent's state. Defaults to None.
451
458
 
452
459
  Examples
453
460
  -------
@@ -480,6 +487,11 @@ def make_feature_engineering_agent(
480
487
  """
481
488
  llm = model
482
489
 
490
+ if human_in_the_loop:
491
+ if checkpointer is None:
492
+ print("Human in the loop is enabled. A checkpointer is required. Setting to MemorySaver().")
493
+ checkpointer = MemorySaver()
494
+
483
495
  # Human in th loop requires recommended steps
484
496
  if bypass_recommended_steps and human_in_the_loop:
485
497
  bypass_recommended_steps = False
@@ -782,9 +794,10 @@ def make_feature_engineering_agent(
782
794
  retry_count_key = "retry_count",
783
795
  human_in_the_loop=human_in_the_loop,
784
796
  human_review_node_name="human_review",
785
- checkpointer=MemorySaver(),
797
+ checkpointer=checkpointer,
786
798
  bypass_recommended_steps=bypass_recommended_steps,
787
799
  bypass_explain_code=bypass_explain_code,
800
+ agent_name=AGENT_NAME,
788
801
  )
789
802
 
790
803
  return app
@@ -7,7 +7,7 @@ from langchain.prompts import PromptTemplate
7
7
  from langchain_core.messages import BaseMessage
8
8
  from langchain_core.output_parsers import JsonOutputParser
9
9
 
10
- from langgraph.types import Command
10
+ from langgraph.types import Command, Checkpointer
11
11
  from langgraph.checkpoint.memory import MemorySaver
12
12
 
13
13
  import os
@@ -75,6 +75,8 @@ class SQLDatabaseAgent(BaseAgent):
75
75
  If True, skips the step that generates recommended SQL steps. Defaults to False.
76
76
  bypass_explain_code : bool, optional
77
77
  If True, skips the step that provides code explanations. Defaults to False.
78
+ checkpointer : Checkpointer, optional
79
+ A checkpointer to save and load the agent's state. Defaults to None.
78
80
  smart_schema_pruning : bool, optional
79
81
  If True, filters the tables and columns based on the user instructions and recommended steps. Defaults to False.
80
82
 
@@ -157,6 +159,7 @@ class SQLDatabaseAgent(BaseAgent):
157
159
  human_in_the_loop=False,
158
160
  bypass_recommended_steps=False,
159
161
  bypass_explain_code=False,
162
+ checkpointer=None,
160
163
  smart_schema_pruning=False,
161
164
  ):
162
165
  self._params = {
@@ -171,6 +174,7 @@ class SQLDatabaseAgent(BaseAgent):
171
174
  "human_in_the_loop": human_in_the_loop,
172
175
  "bypass_recommended_steps": bypass_recommended_steps,
173
176
  "bypass_explain_code": bypass_explain_code,
177
+ "checkpointer": checkpointer,
174
178
  "smart_schema_pruning": smart_schema_pruning,
175
179
  }
176
180
  self._compiled_graph = self._make_compiled_graph()
@@ -365,6 +369,7 @@ def make_sql_database_agent(
365
369
  human_in_the_loop=False,
366
370
  bypass_recommended_steps=False,
367
371
  bypass_explain_code=False,
372
+ checkpointer=None,
368
373
  smart_schema_pruning=False,
369
374
  ):
370
375
  """
@@ -394,6 +399,8 @@ def make_sql_database_agent(
394
399
  Bypass the recommendation step, by default False
395
400
  bypass_explain_code : bool, optional
396
401
  Bypass the code explanation step, by default False.
402
+ checkpointer : Checkpointer, optional
403
+ A checkpointer to save and load the agent's state. Defaults to None.
397
404
  smart_schema_pruning : bool, optional
398
405
  If True, filters the tables and columns with an extra LLM step to reduce tokens for large databases. Increases processing time but can avoid errors due to hitting max token limits with large databases. Defaults to False.
399
406
 
@@ -432,6 +439,11 @@ def make_sql_database_agent(
432
439
 
433
440
  llm = model
434
441
 
442
+ if human_in_the_loop:
443
+ if checkpointer is None:
444
+ print("Human in the loop is enabled. A checkpointer is required. Setting to MemorySaver().")
445
+ checkpointer = MemorySaver()
446
+
435
447
  # Human in th loop requires recommended steps
436
448
  if bypass_recommended_steps and human_in_the_loop:
437
449
  bypass_recommended_steps = False
@@ -742,9 +754,10 @@ def {function_name}(connection):
742
754
  error_key="sql_database_error",
743
755
  human_in_the_loop=human_in_the_loop,
744
756
  human_review_node_name="human_review",
745
- checkpointer=MemorySaver() if human_in_the_loop else None,
757
+ checkpointer=checkpointer,
746
758
  bypass_recommended_steps=bypass_recommended_steps,
747
759
  bypass_explain_code=bypass_explain_code,
760
+ agent_name=AGENT_NAME,
748
761
  )
749
762
 
750
763
  return app
@@ -1,12 +1,8 @@
1
1
 
2
2
 
3
- from typing import Any, Optional, Annotated, Sequence, List, Dict, Tuple
3
+ from typing import Any, Optional, Annotated, Sequence, Dict
4
4
  import operator
5
5
  import pandas as pd
6
- import os
7
- from io import StringIO, BytesIO
8
- import base64
9
- import matplotlib.pyplot as plt
10
6
 
11
7
  from IPython.display import Markdown
12
8
 
@@ -14,6 +10,7 @@ from langchain_core.messages import BaseMessage, AIMessage
14
10
  from langgraph.prebuilt import create_react_agent, ToolNode
15
11
  from langgraph.prebuilt.chat_agent_executor import AgentState
16
12
  from langgraph.graph import START, END, StateGraph
13
+ from langgraph.types import Checkpointer
17
14
 
18
15
  from ai_data_science_team.templates import BaseAgent
19
16
  from ai_data_science_team.utils.regex import format_agent_name
@@ -52,6 +49,8 @@ class EDAToolsAgent(BaseAgent):
52
49
  Additional kwargs for create_react_agent.
53
50
  invoke_react_agent_kwargs : dict
54
51
  Additional kwargs for agent invocation.
52
+ checkpointer : Checkpointer, optional
53
+ The checkpointer for the agent.
55
54
  """
56
55
 
57
56
  def __init__(
@@ -59,11 +58,13 @@ class EDAToolsAgent(BaseAgent):
59
58
  model: Any,
60
59
  create_react_agent_kwargs: Optional[Dict] = {},
61
60
  invoke_react_agent_kwargs: Optional[Dict] = {},
61
+ checkpointer: Optional[Checkpointer] = None,
62
62
  ):
63
63
  self._params = {
64
64
  "model": model,
65
65
  "create_react_agent_kwargs": create_react_agent_kwargs,
66
66
  "invoke_react_agent_kwargs": invoke_react_agent_kwargs,
67
+ "checkpointer": checkpointer
67
68
  }
68
69
  self._compiled_graph = self._make_compiled_graph()
69
70
  self.response = None
@@ -176,6 +177,7 @@ def make_eda_tools_agent(
176
177
  model: Any,
177
178
  create_react_agent_kwargs: Optional[Dict] = {},
178
179
  invoke_react_agent_kwargs: Optional[Dict] = {},
180
+ checkpointer: Optional[Checkpointer] = None,
179
181
  ):
180
182
  """
181
183
  Creates an Exploratory Data Analyst Agent that can interact with EDA tools.
@@ -188,6 +190,8 @@ def make_eda_tools_agent(
188
190
  Additional kwargs for create_react_agent.
189
191
  invoke_react_agent_kwargs : dict
190
192
  Additional kwargs for agent invocation.
193
+ checkpointer : Checkpointer, optional
194
+ The checkpointer for the agent.
191
195
 
192
196
  Returns:
193
197
  -------
@@ -215,6 +219,7 @@ def make_eda_tools_agent(
215
219
  tools=tool_node,
216
220
  state_schema=GraphState,
217
221
  **create_react_agent_kwargs,
222
+ checkpointer=checkpointer,
218
223
  )
219
224
 
220
225
  response = eda_agent.invoke(
@@ -254,5 +259,9 @@ def make_eda_tools_agent(
254
259
  workflow.add_edge(START, "exploratory_agent")
255
260
  workflow.add_edge("exploratory_agent", END)
256
261
 
257
- app = workflow.compile()
262
+ app = workflow.compile(
263
+ checkpointer=checkpointer,
264
+ name=AGENT_NAME,
265
+ )
266
+
258
267
  return app
@@ -5,7 +5,7 @@
5
5
 
6
6
  import os
7
7
  import json
8
- from typing import TypedDict, Annotated, Sequence, Literal
8
+ from typing import TypedDict, Annotated, Sequence, Literal, Optional
9
9
  import operator
10
10
 
11
11
  import pandas as pd
@@ -14,7 +14,7 @@ from IPython.display import Markdown
14
14
  from langchain.prompts import PromptTemplate
15
15
  from langchain_core.messages import BaseMessage
16
16
 
17
- from langgraph.types import Command
17
+ from langgraph.types import Command, Checkpointer
18
18
  from langgraph.checkpoint.memory import MemorySaver
19
19
 
20
20
  from ai_data_science_team.templates import(
@@ -79,6 +79,8 @@ class H2OMLAgent(BaseAgent):
79
79
  Name of the MLflow experiment (created if doesn't exist).
80
80
  mlflow_run_name : str, default None
81
81
  A custom name for the MLflow run.
82
+ checkpointer : langgraph.checkpoint.memory.MemorySaver, optional
83
+ A checkpointer object for saving the agent's state. Defaults to None.
82
84
 
83
85
 
84
86
  Methods
@@ -176,6 +178,7 @@ class H2OMLAgent(BaseAgent):
176
178
  mlflow_tracking_uri=None,
177
179
  mlflow_experiment_name="H2O AutoML",
178
180
  mlflow_run_name=None,
181
+ checkpointer: Optional[Checkpointer]=None,
179
182
  ):
180
183
  self._params = {
181
184
  "model": model,
@@ -193,6 +196,7 @@ class H2OMLAgent(BaseAgent):
193
196
  "mlflow_tracking_uri": mlflow_tracking_uri,
194
197
  "mlflow_experiment_name": mlflow_experiment_name,
195
198
  "mlflow_run_name": mlflow_run_name,
199
+ "checkpointer": checkpointer,
196
200
  }
197
201
  self._compiled_graph = self._make_compiled_graph()
198
202
  self.response = None
@@ -350,6 +354,7 @@ def make_h2o_ml_agent(
350
354
  mlflow_tracking_uri=None,
351
355
  mlflow_experiment_name="H2O AutoML",
352
356
  mlflow_run_name=None,
357
+ checkpointer=None,
353
358
  ):
354
359
  """
355
360
  Creates a machine learning agent that uses H2O for AutoML.
@@ -384,6 +389,12 @@ def make_h2o_ml_agent(
384
389
  " pip install h2o\n\n"
385
390
  "Visit https://docs.h2o.ai/h2o/latest-stable/h2o-docs/downloading.html for details."
386
391
  ) from e
392
+
393
+ if human_in_the_loop:
394
+ if checkpointer is None:
395
+ print("Human in the loop is enabled. A checkpointer is required. Setting to MemorySaver().")
396
+ checkpointer = MemorySaver()
397
+
387
398
 
388
399
  # Define GraphState
389
400
  class GraphState(TypedDict):
@@ -844,9 +855,10 @@ def make_h2o_ml_agent(
844
855
  retry_count_key="retry_count",
845
856
  human_in_the_loop=human_in_the_loop,
846
857
  human_review_node_name="human_review",
847
- checkpointer=MemorySaver(),
858
+ checkpointer=checkpointer,
848
859
  bypass_recommended_steps=bypass_recommended_steps,
849
860
  bypass_explain_code=bypass_explain_code,
861
+ agent_name=AGENT_NAME,
850
862
  )
851
863
 
852
864
  return app