ai-data-science-team 0.0.0.9008__py3-none-any.whl → 0.0.0.9010__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29) hide show
  1. ai_data_science_team/_version.py +1 -1
  2. ai_data_science_team/agents/__init__.py +0 -1
  3. ai_data_science_team/agents/data_cleaning_agent.py +50 -39
  4. ai_data_science_team/agents/data_loader_tools_agent.py +69 -0
  5. ai_data_science_team/agents/data_visualization_agent.py +45 -50
  6. ai_data_science_team/agents/data_wrangling_agent.py +50 -49
  7. ai_data_science_team/agents/feature_engineering_agent.py +48 -67
  8. ai_data_science_team/agents/sql_database_agent.py +130 -76
  9. ai_data_science_team/ml_agents/__init__.py +2 -0
  10. ai_data_science_team/ml_agents/h2o_ml_agent.py +852 -0
  11. ai_data_science_team/ml_agents/mlflow_tools_agent.py +327 -0
  12. ai_data_science_team/multiagents/sql_data_analyst.py +120 -9
  13. ai_data_science_team/parsers/__init__.py +0 -0
  14. ai_data_science_team/{tools → parsers}/parsers.py +0 -1
  15. ai_data_science_team/templates/__init__.py +1 -0
  16. ai_data_science_team/templates/agent_templates.py +78 -7
  17. ai_data_science_team/tools/data_loader.py +378 -0
  18. ai_data_science_team/tools/{metadata.py → dataframe.py} +0 -91
  19. ai_data_science_team/tools/h2o.py +643 -0
  20. ai_data_science_team/tools/mlflow.py +961 -0
  21. ai_data_science_team/tools/sql.py +126 -0
  22. ai_data_science_team/{tools → utils}/regex.py +59 -1
  23. {ai_data_science_team-0.0.0.9008.dist-info → ai_data_science_team-0.0.0.9010.dist-info}/METADATA +56 -24
  24. ai_data_science_team-0.0.0.9010.dist-info/RECORD +35 -0
  25. ai_data_science_team-0.0.0.9008.dist-info/RECORD +0 -26
  26. /ai_data_science_team/{tools → utils}/logging.py +0 -0
  27. {ai_data_science_team-0.0.0.9008.dist-info → ai_data_science_team-0.0.0.9010.dist-info}/LICENSE +0 -0
  28. {ai_data_science_team-0.0.0.9008.dist-info → ai_data_science_team-0.0.0.9010.dist-info}/WHEEL +0 -0
  29. {ai_data_science_team-0.0.0.9008.dist-info → ai_data_science_team-0.0.0.9010.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,327 @@
1
+
2
+ from typing import Any, Optional, Annotated, Sequence
3
+ import operator
4
+
5
+ import pandas as pd
6
+
7
+ from IPython.display import Markdown
8
+
9
+ from langchain_core.messages import BaseMessage, AIMessage
10
+
11
+ from langgraph.prebuilt import create_react_agent, ToolNode
12
+ from langgraph.prebuilt.chat_agent_executor import AgentState
13
+ from langgraph.graph import START, END, StateGraph
14
+
15
+ from ai_data_science_team.templates import BaseAgent
16
+ from ai_data_science_team.utils.regex import format_agent_name
17
+ from ai_data_science_team.tools.mlflow import (
18
+ mlflow_search_experiments,
19
+ mlflow_search_runs,
20
+ mlflow_create_experiment,
21
+ mlflow_predict_from_run_id,
22
+ mlflow_launch_ui,
23
+ mlflow_stop_ui,
24
+ mlflow_list_artifacts,
25
+ mlflow_download_artifacts,
26
+ mlflow_list_registered_models,
27
+ mlflow_search_registered_models,
28
+ mlflow_get_model_version_details,
29
+ )
30
+
31
# Name of this agent: passed to format_agent_name() for the console banner and
# used as the `role` of the final AIMessage returned by the graph node.
AGENT_NAME = "mlflow_tools_agent"

# TOOL SETUP
# Tools handed to the tool-calling agent. All of them are imported from
# ai_data_science_team.tools.mlflow.
tools = [
    mlflow_search_experiments,
    mlflow_search_runs,
    mlflow_create_experiment,
    mlflow_predict_from_run_id,
    mlflow_launch_ui,
    mlflow_stop_ui,
    mlflow_list_artifacts,
    mlflow_download_artifacts,
    mlflow_list_registered_models,
    mlflow_search_registered_models,
    mlflow_get_model_version_details,
]
47
+
48
class MLflowToolsAgent(BaseAgent):
    """
    An agent that can interact with MLflow by calling tools.

    Current tools include:
    - Search Experiments
    - Search Runs
    - Create Experiment
    - Predict (from a Run ID)
    - Launch / Stop the MLflow UI
    - List / Download Artifacts
    - List / Search Registered Models
    - Get Model Version Details

    Parameters:
    ----------
    model : langchain.llms.base.LLM
        The language model used to generate the tool calling agent.
    mlflow_tracking_uri : str, optional
        The tracking URI for MLflow. Defaults to None.
    mlflow_registry_uri : str, optional
        The registry URI for MLflow. Defaults to None.
    **react_agent_kwargs : dict, optional
        Additional keyword arguments to pass to the agent's react agent.

    Methods:
    --------
    update_params(**kwargs):
        Updates the agent's parameters and rebuilds the compiled graph.
    ainvoke_agent(user_instructions: str=None, data_raw: pd.DataFrame=None, **kwargs):
        Asynchronously runs the agent with the given user instructions.
    invoke_agent(user_instructions: str=None, data_raw: pd.DataFrame=None, **kwargs):
        Runs the agent with the given user instructions.
    get_internal_messages(markdown: bool=False):
        Returns the internal messages from the agent's response.
    get_mlflow_artifacts(as_dataframe: bool=False):
        Returns the MLflow artifacts from the agent's response.
    get_ai_message(markdown: bool=False):
        Returns the AI message from the agent's response.

    Examples:
    --------
    ```python
    from ai_data_science_team.ml_agents import MLflowToolsAgent

    mlflow_agent = MLflowToolsAgent(llm)

    mlflow_agent.invoke_agent(user_instructions="List the MLflow experiments")

    mlflow_agent.get_response()

    mlflow_agent.get_internal_messages(markdown=True)

    mlflow_agent.get_ai_message(markdown=True)

    mlflow_agent.get_mlflow_artifacts(as_dataframe=True)

    ```

    Returns
    -------
    MLflowToolsAgent : langchain.graphs.CompiledStateGraph
        An instance of the MLflow Tools Agent.

    """

    def __init__(
        self,
        model: Any,
        mlflow_tracking_uri: Optional[str] = None,
        mlflow_registry_uri: Optional[str] = None,
        **react_agent_kwargs,
    ):
        # Everything needed to (re)build the graph is kept in one dict so that
        # update_params() can patch it and rebuild.
        self._params = {
            "model": model,
            "mlflow_tracking_uri": mlflow_tracking_uri,
            "mlflow_registry_uri": mlflow_registry_uri,
            **react_agent_kwargs,
        }
        self._compiled_graph = self._make_compiled_graph()
        self.response = None

    def _make_compiled_graph(self):
        """
        Creates the compiled graph for the agent.

        Also clears any stored response, since it was produced by the
        previous graph.
        """
        self.response = None
        return make_mlflow_tools_agent(**self._params)

    def update_params(self, **kwargs):
        """
        Updates the agent's parameters and rebuilds the compiled graph.
        """
        self._params.update(kwargs)
        self._compiled_graph = self._make_compiled_graph()

    @staticmethod
    def _build_graph_input(user_instructions, data_raw):
        """
        Builds the input state shared by invoke_agent and ainvoke_agent.

        The DataFrame (if any) is converted with `to_dict()` before being
        placed in the graph state.
        """
        return {
            "user_instructions": user_instructions,
            "data_raw": data_raw.to_dict() if data_raw is not None else None,
        }

    async def ainvoke_agent(
        self,
        user_instructions: Optional[str] = None,
        data_raw: Optional[pd.DataFrame] = None,
        **kwargs
    ):
        """
        Asynchronously runs the agent with the given user instructions.

        The result is stored on `self.response`; nothing is returned.

        Parameters:
        ----------
        user_instructions : str, optional
            The user instructions to pass to the agent.
        data_raw : pd.DataFrame, optional
            The data to pass to the agent. Used for prediction and tool calls where data is required.
        kwargs : dict, optional
            Additional keyword arguments to pass to the agents ainvoke method.

        """
        self.response = await self._compiled_graph.ainvoke(
            self._build_graph_input(user_instructions, data_raw),
            **kwargs
        )
        return None

    def invoke_agent(
        self,
        user_instructions: Optional[str] = None,
        data_raw: Optional[pd.DataFrame] = None,
        **kwargs
    ):
        """
        Runs the agent with the given user instructions.

        The result is stored on `self.response`; nothing is returned.

        Parameters:
        ----------
        user_instructions : str, optional
            The user instructions to pass to the agent.
        data_raw : pd.DataFrame, optional
            The raw data to pass to the agent. Used for prediction and tool calls where data is required.
        kwargs : dict, optional
            Additional keyword arguments to pass to the agents invoke method.

        """
        self.response = self._compiled_graph.invoke(
            self._build_graph_input(user_instructions, data_raw),
            **kwargs
        )
        return None

    def get_internal_messages(self, markdown: bool = False):
        """
        Returns the internal messages from the agent's response.

        Returns None if the agent has not been invoked yet. With
        `markdown=True` the messages are rendered into a single Markdown
        object; otherwise the raw message list is returned.
        """
        if not self.response:
            return None

        internal_messages = self.response["internal_messages"]
        if not markdown:
            return internal_messages

        # Only build the pretty-printed text when it is actually requested.
        pretty_print = "\n\n".join(
            f"### {msg.type.upper()}\n\nID: {msg.id}\n\nContent:\n\n{msg.content}"
            for msg in internal_messages
        )
        return Markdown(pretty_print)

    def get_mlflow_artifacts(self, as_dataframe: bool = False):
        """
        Returns the MLflow artifacts from the agent's response.

        Returns None if the agent has not been invoked yet. With
        `as_dataframe=True` the artifacts are wrapped in a pandas DataFrame.
        """
        if not self.response:
            return None

        artifacts = self.response["mlflow_artifacts"]
        if as_dataframe:
            return pd.DataFrame(artifacts)
        return artifacts

    def get_ai_message(self, markdown: bool = False):
        """
        Returns the AI message from the agent's response.

        Returns None if the agent has not been invoked yet or the response
        holds no messages.
        """
        if not self.response:
            return None

        messages = self.response.get("messages")
        if not messages:
            return None

        content = messages[0].content
        return Markdown(content) if markdown else content
229
+
230
+
231
+
232
+
233
def make_mlflow_tools_agent(
    model: Any,
    mlflow_tracking_uri: Optional[str] = None,
    mlflow_registry_uri: Optional[str] = None,
    **react_agent_kwargs,
):
    """
    MLflow Tool Calling Agent

    Builds and compiles a single-node graph whose node runs a ReAct
    tool-calling agent over the module-level MLflow `tools`.

    Parameters
    ----------
    model : Any
        The language model passed to `create_react_agent`.
    mlflow_tracking_uri : str, optional
        If given, passed to `mlflow.set_tracking_uri`.
    mlflow_registry_uri : str, optional
        If given, passed to `mlflow.set_registry_uri`.
    **react_agent_kwargs
        Extra keyword arguments forwarded to `create_react_agent`.

    Returns
    -------
    The compiled state graph.

    Raises
    ------
    ImportError
        If `mlflow` is not installed.
    """

    try:
        import mlflow
    except ImportError as err:
        # Returning the message would hand callers a str where they expect a
        # compiled graph and defer the failure to an unrelated AttributeError.
        raise ImportError(
            "MLflow is not installed. Please install it by running: !pip install mlflow"
        ) from err

    if mlflow_tracking_uri is not None:
        mlflow.set_tracking_uri(mlflow_tracking_uri)

    if mlflow_registry_uri is not None:
        mlflow.set_registry_uri(mlflow_registry_uri)

    class GraphState(AgentState):
        internal_messages: Annotated[Sequence[BaseMessage], operator.add]
        user_instructions: str
        data_raw: dict
        mlflow_artifacts: dict

    # Built once here rather than on every node invocation: neither the tool
    # node nor the react agent depends on the per-call state.
    react_agent = create_react_agent(
        model,
        tools=ToolNode(tools=tools),
        state_schema=GraphState,
        **react_agent_kwargs,
    )

    def mlflow_tools_agent_node(state):
        """
        Postprocesses the MLflow state, keeping only the last message
        and extracting the last tool artifact.
        """
        print(format_agent_name(AGENT_NAME))
        print(" * RUN REACT TOOL-CALLING AGENT")

        response = react_agent.invoke(
            {
                "messages": [("user", state["user_instructions"])],
                "data_raw": state["data_raw"],
            },
        )

        print(" * POST-PROCESS RESULTS")

        internal_messages = response['messages']

        # Nothing came back from the react agent: report an empty result.
        if not internal_messages:
            return {
                "internal_messages": [],
                "mlflow_artifacts": None,
            }

        # Re-wrap the final message content so it carries this agent's role.
        last_ai_message = AIMessage(internal_messages[-1].content, role=AGENT_NAME)

        # The artifact, if any, is read from the second-to-last message
        # (the one immediately preceding the final answer).
        last_tool_artifact = None
        if len(internal_messages) > 1:
            previous_message = internal_messages[-2]
            if isinstance(previous_message, dict):
                last_tool_artifact = previous_message.get("artifact")
            else:
                last_tool_artifact = getattr(previous_message, "artifact", None)

        return {
            "messages": [last_ai_message],
            "internal_messages": internal_messages,
            "mlflow_artifacts": last_tool_artifact,
        }

    workflow = StateGraph(GraphState)

    workflow.add_node("mlflow_tools_agent", mlflow_tools_agent_node)

    workflow.add_edge(START, "mlflow_tools_agent")
    workflow.add_edge("mlflow_tools_agent", END)

    return workflow.compile()
327
+
@@ -1,24 +1,24 @@
1
1
 
2
2
  from langchain_core.messages import BaseMessage
3
- from langgraph.checkpoint.memory import MemorySaver
4
3
  from langgraph.types import Checkpointer
5
4
 
6
5
  from langgraph.graph import START, END, StateGraph
7
6
  from langgraph.graph.state import CompiledStateGraph
8
7
  from langgraph.types import Command
9
8
 
10
- from typing import TypedDict, Annotated, Sequence
9
+ from typing import TypedDict, Annotated, Sequence, Literal
11
10
  import operator
12
11
 
13
- from typing_extensions import TypedDict, Literal
12
+ from typing_extensions import TypedDict
14
13
 
15
14
  import pandas as pd
15
+ import json
16
16
  from IPython.display import Markdown
17
17
 
18
18
  from ai_data_science_team.templates import BaseAgent
19
19
  from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent
20
20
  from ai_data_science_team.utils.plotly import plotly_from_dict
21
-
21
+ from ai_data_science_team.utils.regex import remove_consecutive_duplicates, get_generic_summary
22
22
 
23
23
 
24
24
  class SQLDataAnalyst(BaseAgent):
@@ -90,7 +90,7 @@ class SQLDataAnalyst(BaseAgent):
90
90
  self._params[k] = v
91
91
  self._compiled_graph = self._make_compiled_graph()
92
92
 
93
- def ainvoke_agent(self, user_instructions, **kwargs):
93
+ async def ainvoke_agent(self, user_instructions, max_retries:int=3, retry_count:int=0, **kwargs):
94
94
  """
95
95
  Asynchronously invokes the SQL Data Analyst Multi-Agent.
96
96
 
@@ -108,15 +108,53 @@ class SQLDataAnalyst(BaseAgent):
108
108
  Example:
109
109
  --------
110
110
  ``` python
111
- # TODO
111
+ from langchain_openai import ChatOpenAI
112
+ import sqlalchemy as sql
113
+ from ai_data_science_team.multiagents import SQLDataAnalyst
114
+ from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent
115
+
116
+ llm = ChatOpenAI(model = "gpt-4o-mini")
117
+
118
+ sql_engine = sql.create_engine("sqlite:///data/northwind.db")
119
+
120
+ conn = sql_engine.connect()
121
+
122
+ sql_data_analyst = SQLDataAnalyst(
123
+ model = llm,
124
+ sql_database_agent = SQLDatabaseAgent(
125
+ model = llm,
126
+ connection = conn,
127
+ n_samples = 1,
128
+ ),
129
+ data_visualization_agent = DataVisualizationAgent(
130
+ model = llm,
131
+ n_samples = 10,
132
+ )
133
+ )
134
+
135
+ await sql_data_analyst.ainvoke_agent(
136
+ user_instructions = "Make a plot of sales revenue by month by territory. Make a dropdown for the user to select the territory.",
137
+ )
138
+
139
+ sql_data_analyst.get_sql_query_code()
140
+
141
+ sql_data_analyst.get_data_sql()
142
+
143
+ sql_data_analyst.get_plotly_graph()
112
144
  ```
113
145
  """
114
- response = self._compiled_graph.ainvoke({
146
+ response = await self._compiled_graph.ainvoke({
115
147
  "user_instructions": user_instructions,
148
+ "max_retries": max_retries,
149
+ "retry_count": retry_count,
116
150
  }, **kwargs)
151
+
152
+ if response.get("messages"):
153
+ response["messages"] = remove_consecutive_duplicates(response["messages"])
154
+
117
155
  self.response = response
118
156
 
119
- def invoke_agent(self, user_instructions, **kwargs):
157
+ def invoke_agent(self, user_instructions, max_retries:int=3, retry_count:int=0, **kwargs):
120
158
  """
121
159
  Invokes the SQL Data Analyst Multi-Agent.
122
160
 
@@ -124,6 +162,10 @@ class SQLDataAnalyst(BaseAgent):
124
162
  ----------
125
163
  user_instructions: str
126
164
  The user's instructions for the combined SQL and (optionally) Data Visualization agents.
165
+ max_retries (int):
166
+ Maximum retry attempts for cleaning.
167
+ retry_count (int):
168
+ Current retry attempt.
127
169
  **kwargs:
128
170
  Additional keyword arguments to pass to the compiled graph's `invoke` method.
129
171
 
@@ -134,14 +176,53 @@ class SQLDataAnalyst(BaseAgent):
134
176
  Example:
135
177
  --------
136
178
  ``` python
137
- # TODO
179
+ from langchain_openai import ChatOpenAI
180
+ import sqlalchemy as sql
181
+ from ai_data_science_team.multiagents import SQLDataAnalyst
182
+ from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent
183
+
184
+ llm = ChatOpenAI(model = "gpt-4o-mini")
185
+
186
+ sql_engine = sql.create_engine("sqlite:///data/northwind.db")
187
+
188
+ conn = sql_engine.connect()
189
+
190
+ sql_data_analyst = SQLDataAnalyst(
191
+ model = llm,
192
+ sql_database_agent = SQLDatabaseAgent(
193
+ model = llm,
194
+ connection = conn,
195
+ n_samples = 1,
196
+ ),
197
+ data_visualization_agent = DataVisualizationAgent(
198
+ model = llm,
199
+ n_samples = 10,
200
+ )
201
+ )
202
+
203
+ sql_data_analyst.invoke_agent(
204
+ user_instructions = "Make a plot of sales revenue by month by territory. Make a dropdown for the user to select the territory.",
205
+ )
206
+
207
+ sql_data_analyst.get_sql_query_code()
208
+
209
+ sql_data_analyst.get_data_sql()
210
+
211
+ sql_data_analyst.get_plotly_graph()
138
212
  ```
139
213
  """
140
214
  response = self._compiled_graph.invoke({
141
215
  "user_instructions": user_instructions,
216
+ "max_retries": max_retries,
217
+ "retry_count": retry_count,
142
218
  }, **kwargs)
219
+
220
+ if response.get("messages"):
221
+ response["messages"] = remove_consecutive_duplicates(response["messages"])
222
+
143
223
  self.response = response
144
224
 
225
+
145
226
  def get_data_sql(self):
146
227
  """
147
228
  Returns the SQL data as a Pandas DataFrame.
@@ -205,6 +286,34 @@ class SQLDataAnalyst(BaseAgent):
205
286
  if markdown:
206
287
  return Markdown(f"```python\n{self.response.get('data_visualization_function')}\n```")
207
288
  return self.response.get("data_visualization_function")
289
+
290
def get_workflow_summary(self, markdown=False):
    """
    Returns a summary of the SQL Data Analyst workflow.

    Each message in the response is expected to hold JSON content; it is
    parsed with `json.loads` and passed to `get_generic_summary`.

    Parameters:
    ----------
    markdown: bool
        If True, returns a Markdown object consisting of a header (listing
        the role of each message) followed by the per-message reports.
        If False, returns only the per-message reports joined as a plain
        string (the header is not included).

    Returns None when there is no response or it holds no messages.
    """
    if not self.response:
        return None

    # Fetch once; every later step works on the same message list.
    messages = self.get_response()['messages']
    if not messages:
        return None

    agent_labels = [
        f"- **Agent {position}:** {msg.role}"
        for position, msg in enumerate(messages, start=1)
    ]

    # Construct header
    header = f"# SQL Data Analyst Workflow Summary Report\n\nThis agentic workflow contains {len(messages)} agents:\n\n" + "\n".join(agent_labels)

    reports = [get_generic_summary(json.loads(msg.content)) for msg in messages]
    report_text = "\n\n".join(reports)

    if markdown:
        return Markdown(header + report_text)
    return report_text
208
317
 
209
318
 
210
319
 
@@ -250,6 +359,8 @@ def make_sql_data_analyst(
250
359
  plot_required: bool
251
360
  data_visualization_function: str
252
361
  plotly_graph: dict
362
+ max_retries: int
363
+ retry_count: int
253
364
 
254
365
  def route_to_visualization(state) -> Command[Literal["data_visualization_agent", "__end__"]]:
255
366
 
File without changes
@@ -3,7 +3,6 @@
3
3
  # ***
4
4
  # Parsers
5
5
 
6
- from langchain_core.output_parsers import JsonOutputParser
7
6
  from langchain_core.output_parsers import BaseOutputParser
8
7
 
9
8
  import re
@@ -3,6 +3,7 @@ from ai_data_science_team.templates.agent_templates import(
3
3
  node_func_human_review,
4
4
  node_func_fix_agent_code,
5
5
  node_func_explain_agent_code,
6
+ node_func_report_agent_outputs,
6
7
  node_func_execute_agent_from_sql_connection,
7
8
  create_coding_agent_graph,
8
9
  BaseAgent,
@@ -8,11 +8,16 @@ from langgraph.pregel.types import StreamMode
8
8
 
9
9
  import pandas as pd
10
10
  import sqlalchemy as sql
11
+ import json
11
12
 
12
- from typing import Any, Callable, Dict, Type, Optional, Union
13
+ from typing import Any, Callable, Dict, Type, Optional, Union, List
13
14
 
14
- from ai_data_science_team.tools.parsers import PythonOutputParser
15
- from ai_data_science_team.tools.regex import relocate_imports_inside_function, add_comments_to_top
15
+ from ai_data_science_team.parsers.parsers import PythonOutputParser
16
+ from ai_data_science_team.utils.regex import (
17
+ relocate_imports_inside_function,
18
+ add_comments_to_top,
19
+ remove_consecutive_duplicates
20
+ )
16
21
 
17
22
  from IPython.display import Image, display
18
23
  import pandas as pd
@@ -82,9 +87,13 @@ class BaseAgent(CompiledStateGraph):
82
87
  Any: The agent's response.
83
88
  """
84
89
  self.response = self._compiled_graph.invoke(input=input, config=config,**kwargs)
90
+
91
+ if self.response.get("messages"):
92
+ self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])
93
+
85
94
  return self.response
86
95
 
87
- def ainvoke(
96
+ async def ainvoke(
88
97
  self,
89
98
  input: Union[dict[str, Any], Any],
90
99
  config: Optional[RunnableConfig] = None,
@@ -101,7 +110,11 @@ class BaseAgent(CompiledStateGraph):
101
110
  Returns:
102
111
  Any: The agent's response.
103
112
  """
104
- self.response = self._compiled_graph.ainvoke(input=input, config=config,**kwargs)
113
+ self.response = await self._compiled_graph.ainvoke(input=input, config=config,**kwargs)
114
+
115
+ if self.response.get("messages"):
116
+ self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])
117
+
105
118
  return self.response
106
119
 
107
120
  def stream(
@@ -129,9 +142,13 @@ class BaseAgent(CompiledStateGraph):
129
142
  Any: The agent's response.
130
143
  """
131
144
  self.response = self._compiled_graph.stream(input=input, config=config, stream_mode=stream_mode, **kwargs)
145
+
146
+ if self.response.get("messages"):
147
+ self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])
148
+
132
149
  return self.response
133
150
 
134
- def astream(
151
+ async def astream(
135
152
  self,
136
153
  input: dict[str, Any] | Any,
137
154
  config: RunnableConfig | None = None,
@@ -155,7 +172,11 @@ class BaseAgent(CompiledStateGraph):
155
172
  Returns:
156
173
  Any: The agent's response.
157
174
  """
158
- self.response = self._compiled_graph.astream(input=input, config=config, stream_mode=stream_mode, **kwargs)
175
+ self.response = await self._compiled_graph.astream(input=input, config=config, stream_mode=stream_mode, **kwargs)
176
+
177
+ if self.response.get("messages"):
178
+ self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])
179
+
159
180
  return self.response
160
181
 
161
182
  def get_state_keys(self):
@@ -183,6 +204,9 @@ class BaseAgent(CompiledStateGraph):
183
204
  Returns:
184
205
  Any: The agent's response.
185
206
  """
207
+ if self.response.get("messages"):
208
+ self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])
209
+
186
210
  return self.response
187
211
 
188
212
  def show(self, xray: int = 0):
@@ -729,3 +753,50 @@ def node_func_explain_agent_code(
729
753
  # Return an error message if there was a problem with the code
730
754
  message = AIMessage(content=error_message)
731
755
  return {result_key: [message]}
756
+
757
+
758
+
759
def node_func_report_agent_outputs(
    state: Dict[str, Any],
    keys_to_include: List[str],
    result_key: str,
    role: str,
    custom_title: str = "Agent Output Summary"
) -> Dict[str, Any]:
    """
    Gathers relevant data directly from the state (filtered by `keys_to_include`)
    and returns them as a structured message under `result_key`.

    No LLM is used.

    Parameters
    ----------
    state : Dict[str, Any]
        The current state dictionary holding all agent variables.
    keys_to_include : List[str]
        The list of keys in `state` to include in the output. A key missing
        from `state` is reported as the placeholder `<key_not_found_in_state>`.
    result_key : str
        The key under which the final structured message is returned.
    role : str
        The role that will be used in the final AIMessage (e.g., "DataCleaningAgent").
    custom_title : str, optional
        A title or heading for your report. Defaults to "Agent Output Summary".

    Returns
    -------
    Dict[str, Any]
        `{result_key: [AIMessage]}` where the message content is the report
        serialized as indented JSON.
    """
    print(" * REPORT AGENT OUTPUTS")

    final_report = {"report_title": custom_title}
    final_report.update(
        (key, state.get(key, f"<{key}_not_found_in_state>"))
        for key in keys_to_include
    )

    # State values are not guaranteed to be JSON-native. `default=str` is a
    # deliberate fallback for this human-readable report: such values are
    # rendered via str() instead of aborting the node with a TypeError.
    report_json = json.dumps(final_report, indent=2, default=str)

    # Wrap it in a list of messages (like the current "messages" pattern).
    return {
        result_key: [
            AIMessage(
                content=report_json,
                role=role
            )
        ]
    }
802
+