ai-data-science-team 0.0.0.9013__py3-none-any.whl → 0.0.0.9015__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,22 @@
1
+ from ai_data_science_team.agents import (
2
+ DataCleaningAgent,
3
+ DataLoaderToolsAgent,
4
+ DataVisualizationAgent,
5
+ SQLDatabaseAgent,
6
+ DataWranglingAgent,
7
+ FeatureEngineeringAgent,
8
+ )
9
+
10
+ from ai_data_science_team.ds_agents import (
11
+ EDAToolsAgent,
12
+ )
13
+
14
+ from ai_data_science_team.ml_agents import (
15
+ H2OMLAgent,
16
+ MLflowToolsAgent,
17
+ )
18
+
19
+ from ai_data_science_team.multiagents import (
20
+ SQLDataAnalyst,
21
+ PandasDataAnalyst,
22
+ )
@@ -1 +1 @@
1
- __version__ = "0.0.0.9013"
1
+ __version__ = "0.0.0.9015"
@@ -12,6 +12,7 @@ from langchain_core.messages import BaseMessage
12
12
 
13
13
  from langgraph.types import Command
14
14
  from langgraph.checkpoint.memory import MemorySaver
15
+ from langgraph.types import Checkpointer
15
16
 
16
17
  import os
17
18
  import json
@@ -85,6 +86,8 @@ class DataCleaningAgent(BaseAgent):
85
86
  If True, skips the default recommended cleaning steps. Defaults to False.
86
87
  bypass_explain_code : bool, optional
87
88
  If True, skips the step that provides code explanations. Defaults to False.
89
+ checkpointer : langgraph.types.Checkpointer, optional
90
+ Checkpointer to save and load the agent's state. Defaults to None.
88
91
 
89
92
  Methods
90
93
  -------
@@ -159,7 +162,8 @@ class DataCleaningAgent(BaseAgent):
159
162
  overwrite=True,
160
163
  human_in_the_loop=False,
161
164
  bypass_recommended_steps=False,
162
- bypass_explain_code=False
165
+ bypass_explain_code=False,
166
+ checkpointer: Checkpointer = None
163
167
  ):
164
168
  self._params = {
165
169
  "model": model,
@@ -172,6 +176,7 @@ class DataCleaningAgent(BaseAgent):
172
176
  "human_in_the_loop": human_in_the_loop,
173
177
  "bypass_recommended_steps": bypass_recommended_steps,
174
178
  "bypass_explain_code": bypass_explain_code,
179
+ "checkpointer": checkpointer
175
180
  }
176
181
  self._compiled_graph = self._make_compiled_graph()
177
182
  self.response = None
@@ -320,7 +325,8 @@ def make_data_cleaning_agent(
320
325
  overwrite = True,
321
326
  human_in_the_loop=False,
322
327
  bypass_recommended_steps=False,
323
- bypass_explain_code=False
328
+ bypass_explain_code=False,
329
+ checkpointer: Checkpointer = None
324
330
  ):
325
331
  """
326
332
  Creates a data cleaning agent that can be run on a dataset. The agent can be used to clean a dataset in a variety of
@@ -369,6 +375,8 @@ def make_data_cleaning_agent(
369
375
  Bypass the recommendation step, by default False
370
376
  bypass_explain_code : bool, optional
371
377
  Bypass the code explanation step, by default False.
378
+ checkpointer : langgraph.types.Checkpointer, optional
379
+ Checkpointer to save and load the agent's state. Defaults to None.
372
380
 
373
381
  Examples
374
382
  -------
@@ -400,6 +408,11 @@ def make_data_cleaning_agent(
400
408
  """
401
409
  llm = model
402
410
 
411
+ if human_in_the_loop:
412
+ if checkpointer is None:
413
+ print("Human in the loop is enabled. A checkpointer is required. Setting to MemorySaver().")
414
+ checkpointer = MemorySaver()
415
+
403
416
  # Human in th loop requires recommended steps
404
417
  if bypass_recommended_steps and human_in_the_loop:
405
418
  bypass_recommended_steps = False
@@ -680,9 +693,10 @@ def make_data_cleaning_agent(
680
693
  error_key="data_cleaner_error",
681
694
  human_in_the_loop=human_in_the_loop,
682
695
  human_review_node_name="human_review",
683
- checkpointer=MemorySaver() if human_in_the_loop else None,
696
+ checkpointer=checkpointer,
684
697
  bypass_recommended_steps=bypass_recommended_steps,
685
698
  bypass_explain_code=bypass_explain_code,
699
+ agent_name=AGENT_NAME,
686
700
  )
687
701
 
688
702
  return app
@@ -13,6 +13,7 @@ from langchain_core.messages import BaseMessage, AIMessage
13
13
 
14
14
  from langgraph.prebuilt import create_react_agent, ToolNode
15
15
  from langgraph.prebuilt.chat_agent_executor import AgentState
16
+ from langgraph.types import Checkpointer
16
17
  from langgraph.graph import START, END, StateGraph
17
18
 
18
19
  from ai_data_science_team.templates import BaseAgent
@@ -50,6 +51,8 @@ class DataLoaderToolsAgent(BaseAgent):
50
51
  Additional keyword arguments to pass to the create_react_agent function.
51
52
  invoke_react_agent_kwargs : dict
52
53
  Additional keyword arguments to pass to the invoke method of the react agent.
54
+ checkpointer : langgraph.types.Checkpointer
55
+ A checkpointer to use for saving and loading the agent's state.
53
56
 
54
57
  Methods:
55
58
  --------
@@ -73,11 +76,13 @@ class DataLoaderToolsAgent(BaseAgent):
73
76
  model: Any,
74
77
  create_react_agent_kwargs: Optional[Dict]={},
75
78
  invoke_react_agent_kwargs: Optional[Dict]={},
79
+ checkpointer: Optional[Checkpointer]=None,
76
80
  ):
77
81
  self._params = {
78
82
  "model": model,
79
83
  "create_react_agent_kwargs": create_react_agent_kwargs,
80
84
  "invoke_react_agent_kwargs": invoke_react_agent_kwargs,
85
+ "checkpointer": checkpointer,
81
86
  }
82
87
  self._compiled_graph = self._make_compiled_graph()
83
88
  self.response = None
@@ -188,6 +193,7 @@ def make_data_loader_tools_agent(
188
193
  model: Any,
189
194
  create_react_agent_kwargs: Optional[Dict]={},
190
195
  invoke_react_agent_kwargs: Optional[Dict]={},
196
+ checkpointer: Optional[Checkpointer]=None,
191
197
  ):
192
198
  """
193
199
  Creates a Data Loader Agent that can interact with data loading tools.
@@ -200,6 +206,8 @@ def make_data_loader_tools_agent(
200
206
  Additional keyword arguments to pass to the create_react_agent function.
201
207
  invoke_react_agent_kwargs : dict
202
208
  Additional keyword arguments to pass to the invoke method of the react agent.
209
+ checkpointer : langgraph.types.Checkpointer
210
+ A checkpointer to use for saving and loading the agent's state.
203
211
 
204
212
  Returns:
205
213
  --------
@@ -228,6 +236,7 @@ def make_data_loader_tools_agent(
228
236
  model,
229
237
  tools=tool_node,
230
238
  state_schema=GraphState,
239
+ checkpointer=checkpointer,
231
240
  **create_react_agent_kwargs,
232
241
  )
233
242
 
@@ -277,7 +286,10 @@ def make_data_loader_tools_agent(
277
286
  workflow.add_edge(START, "data_loader_agent")
278
287
  workflow.add_edge("data_loader_agent", END)
279
288
 
280
- app = workflow.compile()
289
+ app = workflow.compile(
290
+ checkpointer=checkpointer,
291
+ name=AGENT_NAME,
292
+ )
281
293
 
282
294
  return app
283
295