ai-data-science-team 0.0.0.9008__py3-none-any.whl → 0.0.0.9010__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_data_science_team/_version.py +1 -1
- ai_data_science_team/agents/__init__.py +0 -1
- ai_data_science_team/agents/data_cleaning_agent.py +50 -39
- ai_data_science_team/agents/data_loader_tools_agent.py +69 -0
- ai_data_science_team/agents/data_visualization_agent.py +45 -50
- ai_data_science_team/agents/data_wrangling_agent.py +50 -49
- ai_data_science_team/agents/feature_engineering_agent.py +48 -67
- ai_data_science_team/agents/sql_database_agent.py +130 -76
- ai_data_science_team/ml_agents/__init__.py +2 -0
- ai_data_science_team/ml_agents/h2o_ml_agent.py +852 -0
- ai_data_science_team/ml_agents/mlflow_tools_agent.py +327 -0
- ai_data_science_team/multiagents/sql_data_analyst.py +120 -9
- ai_data_science_team/parsers/__init__.py +0 -0
- ai_data_science_team/{tools → parsers}/parsers.py +0 -1
- ai_data_science_team/templates/__init__.py +1 -0
- ai_data_science_team/templates/agent_templates.py +78 -7
- ai_data_science_team/tools/data_loader.py +378 -0
- ai_data_science_team/tools/{metadata.py → dataframe.py} +0 -91
- ai_data_science_team/tools/h2o.py +643 -0
- ai_data_science_team/tools/mlflow.py +961 -0
- ai_data_science_team/tools/sql.py +126 -0
- ai_data_science_team/{tools → utils}/regex.py +59 -1
- {ai_data_science_team-0.0.0.9008.dist-info → ai_data_science_team-0.0.0.9010.dist-info}/METADATA +56 -24
- ai_data_science_team-0.0.0.9010.dist-info/RECORD +35 -0
- ai_data_science_team-0.0.0.9008.dist-info/RECORD +0 -26
- /ai_data_science_team/{tools → utils}/logging.py +0 -0
- {ai_data_science_team-0.0.0.9008.dist-info → ai_data_science_team-0.0.0.9010.dist-info}/LICENSE +0 -0
- {ai_data_science_team-0.0.0.9008.dist-info → ai_data_science_team-0.0.0.9010.dist-info}/WHEEL +0 -0
- {ai_data_science_team-0.0.0.9008.dist-info → ai_data_science_team-0.0.0.9010.dist-info}/top_level.txt +0 -0
| @@ -0,0 +1,327 @@ | |
| 1 | 
            +
             | 
| 2 | 
            +
            from typing import Any, Optional, Annotated, Sequence
         | 
| 3 | 
            +
            import operator
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import pandas as pd
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            from IPython.display import Markdown
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            from langchain_core.messages import BaseMessage, AIMessage
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            from langgraph.prebuilt import create_react_agent, ToolNode
         | 
| 12 | 
            +
            from langgraph.prebuilt.chat_agent_executor import AgentState
         | 
| 13 | 
            +
            from langgraph.graph import START, END, StateGraph
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            from ai_data_science_team.templates import BaseAgent
         | 
| 16 | 
            +
            from ai_data_science_team.utils.regex import format_agent_name
         | 
| 17 | 
            +
            from ai_data_science_team.tools.mlflow import (
         | 
| 18 | 
            +
                mlflow_search_experiments, 
         | 
| 19 | 
            +
                mlflow_search_runs,
         | 
| 20 | 
            +
                mlflow_create_experiment, 
         | 
| 21 | 
            +
                mlflow_predict_from_run_id,
         | 
| 22 | 
            +
                mlflow_launch_ui,
         | 
| 23 | 
            +
                mlflow_stop_ui,
         | 
| 24 | 
            +
                mlflow_list_artifacts,
         | 
| 25 | 
            +
                mlflow_download_artifacts,
         | 
| 26 | 
            +
                mlflow_list_registered_models,
         | 
| 27 | 
            +
                mlflow_search_registered_models,
         | 
| 28 | 
            +
                mlflow_get_model_version_details,
         | 
| 29 | 
            +
            )
         | 
| 30 | 
            +
             | 
| 31 | 
            +
            AGENT_NAME = "mlflow_tools_agent"
         | 
| 32 | 
            +
             | 
| 33 | 
            +
            # TOOL SETUP
         | 
| 34 | 
            +
            tools = [
         | 
| 35 | 
            +
                mlflow_search_experiments, 
         | 
| 36 | 
            +
                mlflow_search_runs, 
         | 
| 37 | 
            +
                mlflow_create_experiment, 
         | 
| 38 | 
            +
                mlflow_predict_from_run_id,
         | 
| 39 | 
            +
                mlflow_launch_ui,
         | 
| 40 | 
            +
                mlflow_stop_ui,
         | 
| 41 | 
            +
                mlflow_list_artifacts,
         | 
| 42 | 
            +
                mlflow_download_artifacts,
         | 
| 43 | 
            +
                mlflow_list_registered_models,
         | 
| 44 | 
            +
                mlflow_search_registered_models,
         | 
| 45 | 
            +
                mlflow_get_model_version_details,
         | 
| 46 | 
            +
            ]
         | 
| 47 | 
            +
             | 
| 48 | 
            +
            class MLflowToolsAgent(BaseAgent):
         | 
| 49 | 
            +
                """
         | 
| 50 | 
            +
                An agent that can interact with MLflow by calling tools.
         | 
| 51 | 
            +
                
         | 
| 52 | 
            +
                Current tools include:
         | 
| 53 | 
            +
                - List Experiments
         | 
| 54 | 
            +
                - Search Runs
         | 
| 55 | 
            +
                - Create Experiment
         | 
| 56 | 
            +
                - Predict (from a Run ID)
         | 
| 57 | 
            +
                
         | 
| 58 | 
            +
                Parameters:
         | 
| 59 | 
            +
                ----------
         | 
| 60 | 
            +
                model : langchain.llms.base.LLM
         | 
| 61 | 
            +
                    The language model used to generate the tool calling agent.
         | 
| 62 | 
            +
                mlflow_tracking_uri : str, optional
         | 
| 63 | 
            +
                    The tracking URI for MLflow. Defaults to None.
         | 
| 64 | 
            +
                mlflow_registry_uri : str, optional
         | 
| 65 | 
            +
                    The registry URI for MLflow. Defaults to None.
         | 
| 66 | 
            +
                **react_agent_kwargs : dict, optional
         | 
| 67 | 
            +
                    Additional keyword arguments to pass to the agent's react agent.
         | 
| 68 | 
            +
                
         | 
| 69 | 
            +
                Methods:
         | 
| 70 | 
            +
                --------
         | 
| 71 | 
            +
                update_params(**kwargs):
         | 
| 72 | 
            +
                    Updates the agent's parameters and rebuilds the compiled graph.
         | 
| 73 | 
            +
                ainvoke_agent(user_instructions: str=None, data_raw: pd.DataFrame=None, **kwargs):
         | 
| 74 | 
            +
                    Asynchronously runs the agent with the given user instructions.
         | 
| 75 | 
            +
                invoke_agent(user_instructions: str=None, data_raw: pd.DataFrame=None, **kwargs):
         | 
| 76 | 
            +
                    Runs the agent with the given user instructions.
         | 
| 77 | 
            +
                get_internal_messages(markdown: bool=False):
         | 
| 78 | 
            +
                    Returns the internal messages from the agent's response.
         | 
| 79 | 
            +
                get_mlflow_artifacts(as_dataframe: bool=False):
         | 
| 80 | 
            +
                    Returns the MLflow artifacts from the agent's response.
         | 
| 81 | 
            +
                get_ai_message(markdown: bool=False):
         | 
| 82 | 
            +
                    Returns the AI message from the agent's response.
         | 
| 83 | 
            +
                
         | 
| 84 | 
            +
                
         | 
| 85 | 
            +
                
         | 
| 86 | 
            +
                Examples:
         | 
| 87 | 
            +
                --------
         | 
| 88 | 
            +
                ```python
         | 
| 89 | 
            +
                from ai_data_science_team.ml_agents import MLflowToolsAgent
         | 
| 90 | 
            +
                
         | 
| 91 | 
            +
                mlflow_agent = MLflowToolsAgent(llm)
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                mlflow_agent.invoke_agent(user_instructions="List the MLflow experiments")
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                mlflow_agent.get_response()
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                mlflow_agent.get_internal_messages(markdown=True)
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                mlflow_agent.get_ai_message(markdown=True)
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                mlflow_agent.get_mlflow_artifacts(as_dataframe=True)
         | 
| 102 | 
            +
                
         | 
| 103 | 
            +
                ```
         | 
| 104 | 
            +
                
         | 
| 105 | 
            +
                Returns
         | 
| 106 | 
            +
                -------
         | 
| 107 | 
            +
                MLflowToolsAgent : langchain.graphs.CompiledStateGraph 
         | 
| 108 | 
            +
                    An instance of the MLflow Tools Agent.
         | 
| 109 | 
            +
                
         | 
| 110 | 
            +
                """
         | 
| 111 | 
            +
                
         | 
| 112 | 
            +
                def __init__(
         | 
| 113 | 
            +
                    self, 
         | 
| 114 | 
            +
                    model: Any,
         | 
| 115 | 
            +
                    mlflow_tracking_uri: Optional[str]=None,
         | 
| 116 | 
            +
                    mlflow_registry_uri: Optional[str]=None,
         | 
| 117 | 
            +
                    **react_agent_kwargs,
         | 
| 118 | 
            +
                ):
         | 
| 119 | 
            +
                    self._params = {
         | 
| 120 | 
            +
                        "model": model,
         | 
| 121 | 
            +
                        "mlflow_tracking_uri": mlflow_tracking_uri,
         | 
| 122 | 
            +
                        "mlflow_registry_uri": mlflow_registry_uri,
         | 
| 123 | 
            +
                        **react_agent_kwargs,
         | 
| 124 | 
            +
                    }
         | 
| 125 | 
            +
                    self._compiled_graph = self._make_compiled_graph()
         | 
| 126 | 
            +
                    self.response = None
         | 
| 127 | 
            +
                
         | 
| 128 | 
            +
                def _make_compiled_graph(self):
         | 
| 129 | 
            +
                    """
         | 
| 130 | 
            +
                    Creates the compiled graph for the agent.
         | 
| 131 | 
            +
                    """
         | 
| 132 | 
            +
                    self.response = None
         | 
| 133 | 
            +
                    return make_mlflow_tools_agent(**self._params)
         | 
| 134 | 
            +
                
         | 
| 135 | 
            +
                
         | 
| 136 | 
            +
                def update_params(self, **kwargs):
         | 
| 137 | 
            +
                    """
         | 
| 138 | 
            +
                    Updates the agent's parameters and rebuilds the compiled graph.
         | 
| 139 | 
            +
                    """
         | 
| 140 | 
            +
                    for k, v in kwargs.items():
         | 
| 141 | 
            +
                        self._params[k] = v
         | 
| 142 | 
            +
                    self._compiled_graph = self._make_compiled_graph()
         | 
| 143 | 
            +
                    
         | 
| 144 | 
            +
                async def ainvoke_agent(
         | 
| 145 | 
            +
                    self, 
         | 
| 146 | 
            +
                    user_instructions: str=None, 
         | 
| 147 | 
            +
                    data_raw: pd.DataFrame=None, 
         | 
| 148 | 
            +
                    **kwargs
         | 
| 149 | 
            +
                ):
         | 
| 150 | 
            +
                    """
         | 
| 151 | 
            +
                    Asynchronously runs the agent with the given user instructions.
         | 
| 152 | 
            +
                    
         | 
| 153 | 
            +
                    Parameters:
         | 
| 154 | 
            +
                    ----------
         | 
| 155 | 
            +
                    user_instructions : str, optional
         | 
| 156 | 
            +
                        The user instructions to pass to the agent.
         | 
| 157 | 
            +
                    data_raw : pd.DataFrame, optional
         | 
| 158 | 
            +
                        The data to pass to the agent. Used for prediction and tool calls where data is required.
         | 
| 159 | 
            +
                    kwargs : dict, optional
         | 
| 160 | 
            +
                        Additional keyword arguments to pass to the agent's ainvoke method.
         | 
| 161 | 
            +
                    
         | 
| 162 | 
            +
                    """
         | 
| 163 | 
            +
                    response = await self._compiled_graph.ainvoke(
         | 
| 164 | 
            +
                        {
         | 
| 165 | 
            +
                            "user_instructions": user_instructions,
         | 
| 166 | 
            +
                            "data_raw": data_raw.to_dict() if data_raw is not None else None,
         | 
| 167 | 
            +
                        }, 
         | 
| 168 | 
            +
                        **kwargs
         | 
| 169 | 
            +
                    )
         | 
| 170 | 
            +
                    self.response = response
         | 
| 171 | 
            +
                    return None
         | 
| 172 | 
            +
                
         | 
| 173 | 
            +
                def invoke_agent(
         | 
| 174 | 
            +
                    self, 
         | 
| 175 | 
            +
                    user_instructions: str=None, 
         | 
| 176 | 
            +
                    data_raw: pd.DataFrame=None, 
         | 
| 177 | 
            +
                    **kwargs
         | 
| 178 | 
            +
                ):
         | 
| 179 | 
            +
                    """
         | 
| 180 | 
            +
                    Runs the agent with the given user instructions.
         | 
| 181 | 
            +
                    
         | 
| 182 | 
            +
                    Parameters:
         | 
| 183 | 
            +
                    ----------
         | 
| 184 | 
            +
                    user_instructions : str, optional
         | 
| 185 | 
            +
                        The user instructions to pass to the agent.
         | 
| 186 | 
            +
                    data_raw : pd.DataFrame, optional
         | 
| 187 | 
            +
                        The raw data to pass to the agent. Used for prediction and tool calls where data is required.
         | 
| 188 | 
            +
                    kwargs : dict, optional
         | 
| 189 | 
            +
                        Additional keyword arguments to pass to the agent's invoke method.
         | 
| 190 | 
            +
                    
         | 
| 191 | 
            +
                    """
         | 
| 192 | 
            +
                    response = self._compiled_graph.invoke(
         | 
| 193 | 
            +
                        {
         | 
| 194 | 
            +
                            "user_instructions": user_instructions,
         | 
| 195 | 
            +
                            "data_raw": data_raw.to_dict() if data_raw is not None else None,
         | 
| 196 | 
            +
                        },
         | 
| 197 | 
            +
                        **kwargs
         | 
| 198 | 
            +
                    )
         | 
| 199 | 
            +
                    self.response = response
         | 
| 200 | 
            +
                    return None
         | 
| 201 | 
            +
                
         | 
| 202 | 
            +
                def get_internal_messages(self, markdown: bool=False):
         | 
| 203 | 
            +
                    """
         | 
| 204 | 
            +
                    Returns the internal messages from the agent's response.
         | 
| 205 | 
            +
                    """
         | 
| 206 | 
            +
                    pretty_print = "\n\n".join([f"### {msg.type.upper()}\n\nID: {msg.id}\n\nContent:\n\n{msg.content}" for msg in self.response["internal_messages"]])       
         | 
| 207 | 
            +
                    if markdown:
         | 
| 208 | 
            +
                        return Markdown(pretty_print)
         | 
| 209 | 
            +
                    else:
         | 
| 210 | 
            +
                        return self.response["internal_messages"]
         | 
| 211 | 
            +
                
         | 
| 212 | 
            +
                def get_mlflow_artifacts(self, as_dataframe: bool=False):
         | 
| 213 | 
            +
                    """
         | 
| 214 | 
            +
                    Returns the MLflow artifacts from the agent's response.
         | 
| 215 | 
            +
                    """
         | 
| 216 | 
            +
                    if as_dataframe:
         | 
| 217 | 
            +
                        return pd.DataFrame(self.response["mlflow_artifacts"])
         | 
| 218 | 
            +
                    else:
         | 
| 219 | 
            +
                        return self.response["mlflow_artifacts"]
         | 
| 220 | 
            +
                
         | 
| 221 | 
            +
                def get_ai_message(self, markdown: bool=False):
         | 
| 222 | 
            +
                    """
         | 
| 223 | 
            +
                    Returns the AI message from the agent's response.
         | 
| 224 | 
            +
                    """
         | 
| 225 | 
            +
                    if markdown:
         | 
| 226 | 
            +
                        return Markdown(self.response["messages"][0].content)
         | 
| 227 | 
            +
                    else:
         | 
| 228 | 
            +
                        return self.response["messages"][0].content
         | 
| 229 | 
            +
                        
         | 
| 230 | 
            +
                
         | 
| 231 | 
            +
                
         | 
| 232 | 
            +
             | 
| 233 | 
            +
            def make_mlflow_tools_agent(
         | 
| 234 | 
            +
                model: Any,
         | 
| 235 | 
            +
                mlflow_tracking_uri: str=None,
         | 
| 236 | 
            +
                mlflow_registry_uri: str=None,
         | 
| 237 | 
            +
                **react_agent_kwargs,
         | 
| 238 | 
            +
            ):
         | 
| 239 | 
            +
                """
         | 
| 240 | 
            +
                MLflow Tool Calling Agent
         | 
| 241 | 
            +
                """
         | 
| 242 | 
            +
                
         | 
| 243 | 
            +
                try:
         | 
| 244 | 
            +
                    import mlflow
         | 
| 245 | 
            +
                except ImportError:
         | 
| 246 | 
            +
                    return "MLflow is not installed. Please install it by running: !pip install mlflow"
         | 
| 247 | 
            +
                
         | 
| 248 | 
            +
                if mlflow_tracking_uri is not None:
         | 
| 249 | 
            +
                    mlflow.set_tracking_uri(mlflow_tracking_uri)
         | 
| 250 | 
            +
                
         | 
| 251 | 
            +
                if mlflow_registry_uri is not None:
         | 
| 252 | 
            +
                    mlflow.set_registry_uri(mlflow_registry_uri)
         | 
| 253 | 
            +
                
         | 
| 254 | 
            +
                class GraphState(AgentState):
         | 
| 255 | 
            +
                    internal_messages: Annotated[Sequence[BaseMessage], operator.add]
         | 
| 256 | 
            +
                    user_instructions: str
         | 
| 257 | 
            +
                    data_raw: dict
         | 
| 258 | 
            +
                    mlflow_artifacts: dict
         | 
| 259 | 
            +
             | 
| 260 | 
            +
                
         | 
| 261 | 
            +
                def mflfow_tools_agent(state):
         | 
| 262 | 
            +
                    """
         | 
| 263 | 
            +
                    Postprocesses the MLflow state, keeping only the last message
         | 
| 264 | 
            +
                    and extracting the last tool artifact.
         | 
| 265 | 
            +
                    """
         | 
| 266 | 
            +
                    print(format_agent_name(AGENT_NAME))
         | 
| 267 | 
            +
                    print("    * RUN REACT TOOL-CALLING AGENT")
         | 
| 268 | 
            +
                    
         | 
| 269 | 
            +
                    tool_node = ToolNode(
         | 
| 270 | 
            +
                        tools=tools
         | 
| 271 | 
            +
                    )
         | 
| 272 | 
            +
                    
         | 
| 273 | 
            +
                    mlflow_agent = create_react_agent(
         | 
| 274 | 
            +
                        model, 
         | 
| 275 | 
            +
                        tools=tool_node, 
         | 
| 276 | 
            +
                        state_schema=GraphState,
         | 
| 277 | 
            +
                        **react_agent_kwargs,
         | 
| 278 | 
            +
                    )
         | 
| 279 | 
            +
                    
         | 
| 280 | 
            +
                    response = mlflow_agent.invoke(
         | 
| 281 | 
            +
                        {
         | 
| 282 | 
            +
                            "messages": [("user", state["user_instructions"])],
         | 
| 283 | 
            +
                            "data_raw": state["data_raw"],
         | 
| 284 | 
            +
                        },
         | 
| 285 | 
            +
                    )
         | 
| 286 | 
            +
                    
         | 
| 287 | 
            +
                    print("    * POST-PROCESS RESULTS")
         | 
| 288 | 
            +
             | 
| 289 | 
            +
                    internal_messages = response['messages']
         | 
| 290 | 
            +
             | 
| 291 | 
            +
                    # Return early with empty results if the react agent produced no messages
         | 
| 292 | 
            +
                    if not internal_messages:
         | 
| 293 | 
            +
                        return {
         | 
| 294 | 
            +
                            "internal_messages": [],
         | 
| 295 | 
            +
                            "mlflow_artifacts": None,
         | 
| 296 | 
            +
                        }
         | 
| 297 | 
            +
             | 
| 298 | 
            +
                    # Get the last AI message
         | 
| 299 | 
            +
                    last_ai_message = AIMessage(internal_messages[-1].content, role = AGENT_NAME)
         | 
| 300 | 
            +
             | 
| 301 | 
            +
                    # Get the last tool artifact safely
         | 
| 302 | 
            +
                    last_tool_artifact = None
         | 
| 303 | 
            +
                    if len(internal_messages) > 1:
         | 
| 304 | 
            +
                        last_message = internal_messages[-2]  # Get second-to-last message
         | 
| 305 | 
            +
                        if hasattr(last_message, "artifact"):  # Check if it has an "artifact"
         | 
| 306 | 
            +
                            last_tool_artifact = last_message.artifact
         | 
| 307 | 
            +
                        elif isinstance(last_message, dict) and "artifact" in last_message:
         | 
| 308 | 
            +
                            last_tool_artifact = last_message["artifact"]
         | 
| 309 | 
            +
             | 
| 310 | 
            +
                    return {
         | 
| 311 | 
            +
                        "messages": [last_ai_message], 
         | 
| 312 | 
            +
                        "internal_messages": internal_messages,
         | 
| 313 | 
            +
                        "mlflow_artifacts": last_tool_artifact,
         | 
| 314 | 
            +
                    }
         | 
| 315 | 
            +
             | 
| 316 | 
            +
                
         | 
| 317 | 
            +
                workflow = StateGraph(GraphState)
         | 
| 318 | 
            +
                
         | 
| 319 | 
            +
                workflow.add_node("mlflow_tools_agent", mflfow_tools_agent)
         | 
| 320 | 
            +
                
         | 
| 321 | 
            +
                workflow.add_edge(START, "mlflow_tools_agent")
         | 
| 322 | 
            +
                workflow.add_edge("mlflow_tools_agent", END)
         | 
| 323 | 
            +
                
         | 
| 324 | 
            +
                app = workflow.compile()
         | 
| 325 | 
            +
             | 
| 326 | 
            +
                return app
         | 
| 327 | 
            +
                
         | 
| @@ -1,24 +1,24 @@ | |
| 1 1 |  | 
| 2 2 | 
             
            from langchain_core.messages import BaseMessage
         | 
| 3 | 
            -
            from langgraph.checkpoint.memory import MemorySaver
         | 
| 4 3 | 
             
            from langgraph.types import Checkpointer
         | 
| 5 4 |  | 
| 6 5 | 
             
            from langgraph.graph import START, END, StateGraph
         | 
| 7 6 | 
             
            from langgraph.graph.state import CompiledStateGraph
         | 
| 8 7 | 
             
            from langgraph.types import Command
         | 
| 9 8 |  | 
| 10 | 
            -
            from typing import TypedDict, Annotated, Sequence
         | 
| 9 | 
            +
            from typing import TypedDict, Annotated, Sequence, Literal
         | 
| 11 10 | 
             
            import operator
         | 
| 12 11 |  | 
| 13 | 
            -
            from typing_extensions import TypedDict | 
| 12 | 
            +
            from typing_extensions import TypedDict
         | 
| 14 13 |  | 
| 15 14 | 
             
            import pandas as pd
         | 
| 15 | 
            +
            import json
         | 
| 16 16 | 
             
            from IPython.display import Markdown
         | 
| 17 17 |  | 
| 18 18 | 
             
            from ai_data_science_team.templates import BaseAgent
         | 
| 19 19 | 
             
            from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent
         | 
| 20 20 | 
             
            from ai_data_science_team.utils.plotly import plotly_from_dict
         | 
| 21 | 
            -
             | 
| 21 | 
            +
            from ai_data_science_team.utils.regex import remove_consecutive_duplicates, get_generic_summary
         | 
| 22 22 |  | 
| 23 23 |  | 
| 24 24 | 
             
            class SQLDataAnalyst(BaseAgent):
         | 
| @@ -90,7 +90,7 @@ class SQLDataAnalyst(BaseAgent): | |
| 90 90 | 
             
                        self._params[k] = v
         | 
| 91 91 | 
             
                    self._compiled_graph = self._make_compiled_graph()
         | 
| 92 92 |  | 
| 93 | 
            -
                def ainvoke_agent(self, user_instructions, **kwargs):
         | 
| 93 | 
            +
                async def ainvoke_agent(self, user_instructions, max_retries:int=3, retry_count:int=0, **kwargs):
         | 
| 94 94 | 
             
                    """
         | 
| 95 95 | 
             
                    Asynchronously invokes the SQL Data Analyst Multi-Agent.
         | 
| 96 96 |  | 
| @@ -108,15 +108,53 @@ class SQLDataAnalyst(BaseAgent): | |
| 108 108 | 
             
                    Example:
         | 
| 109 109 | 
             
                    --------
         | 
| 110 110 | 
             
                    ``` python
         | 
| 111 | 
            -
                     | 
| 111 | 
            +
                    from langchain_openai import ChatOpenAI
         | 
| 112 | 
            +
                    import sqlalchemy as sql
         | 
| 113 | 
            +
                    from ai_data_science_team.multiagents import SQLDataAnalyst
         | 
| 114 | 
            +
                    from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent
         | 
| 115 | 
            +
                    
         | 
| 116 | 
            +
                    llm = ChatOpenAI(model = "gpt-4o-mini")
         | 
| 117 | 
            +
                    
         | 
| 118 | 
            +
                    sql_engine = sql.create_engine("sqlite:///data/northwind.db")
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                    conn = sql_engine.connect()
         | 
| 121 | 
            +
                    
         | 
| 122 | 
            +
                    sql_data_analyst = SQLDataAnalyst(
         | 
| 123 | 
            +
                        model = llm,
         | 
| 124 | 
            +
                        sql_database_agent = SQLDatabaseAgent(
         | 
| 125 | 
            +
                            model = llm,
         | 
| 126 | 
            +
                            connection = conn,
         | 
| 127 | 
            +
                            n_samples = 1,
         | 
| 128 | 
            +
                        ),
         | 
| 129 | 
            +
                        data_visualization_agent = DataVisualizationAgent(
         | 
| 130 | 
            +
                            model = llm,
         | 
| 131 | 
            +
                            n_samples = 10,
         | 
| 132 | 
            +
                        )
         | 
| 133 | 
            +
                    )
         | 
| 134 | 
            +
                    
         | 
| 135 | 
            +
                    sql_data_analyst.ainvoke_agent(
         | 
| 136 | 
            +
                        user_instructions = "Make a plot of sales revenue by month by territory. Make a dropdown for the user to select the territory.",
         | 
| 137 | 
            +
                    )
         | 
| 138 | 
            +
                    
         | 
| 139 | 
            +
                    sql_data_analyst.get_sql_query_code()
         | 
| 140 | 
            +
                    
         | 
| 141 | 
            +
                    sql_data_analyst.get_data_sql()
         | 
| 142 | 
            +
                    
         | 
| 143 | 
            +
                    sql_data_analyst.get_plotly_graph()
         | 
| 112 144 | 
             
                    ```
         | 
| 113 145 | 
             
                    """
         | 
| 114 | 
            -
                    response = self._compiled_graph.ainvoke({
         | 
| 146 | 
            +
                    response = await self._compiled_graph.ainvoke({
         | 
| 115 147 | 
             
                        "user_instructions": user_instructions,
         | 
| 148 | 
            +
                        "max_retries": max_retries,
         | 
| 149 | 
            +
                        "retry_count": retry_count,
         | 
| 116 150 | 
             
                    }, **kwargs)
         | 
| 151 | 
            +
                    
         | 
| 152 | 
            +
                    if response.get("messages"):
         | 
| 153 | 
            +
                        response["messages"] = remove_consecutive_duplicates(response["messages"])
         | 
| 154 | 
            +
                    
         | 
| 117 155 | 
             
                    self.response = response
         | 
| 118 156 |  | 
| 119 | 
            -
                def invoke_agent(self, user_instructions, **kwargs):
         | 
| 157 | 
            +
                def invoke_agent(self, user_instructions, max_retries:int=3, retry_count:int=0, **kwargs):
         | 
| 120 158 | 
             
                    """
         | 
| 121 159 | 
             
                    Invokes the SQL Data Analyst Multi-Agent.
         | 
| 122 160 |  | 
| @@ -124,6 +162,10 @@ class SQLDataAnalyst(BaseAgent): | |
| 124 162 | 
             
                    ----------
         | 
| 125 163 | 
             
                    user_instructions: str
         | 
| 126 164 | 
             
                        The user's instructions for the combined SQL and (optionally) Data Visualization agents.
         | 
| 165 | 
            +
                    max_retries (int): 
         | 
| 166 | 
            +
                            Maximum retry attempts for cleaning.
         | 
| 167 | 
            +
                    retry_count (int): 
         | 
| 168 | 
            +
                        Current retry attempt.
         | 
| 127 169 | 
             
                    **kwargs:
         | 
| 128 170 | 
             
                        Additional keyword arguments to pass to the compiled graph's `invoke` method.
         | 
| 129 171 |  | 
| @@ -134,14 +176,53 @@ class SQLDataAnalyst(BaseAgent): | |
| 134 176 | 
             
                    Example:
         | 
| 135 177 | 
             
                    --------
         | 
| 136 178 | 
             
                    ``` python
         | 
| 137 | 
            -
                     | 
| 179 | 
            +
                    from langchain_openai import ChatOpenAI
         | 
| 180 | 
            +
                    import sqlalchemy as sql
         | 
| 181 | 
            +
                    from ai_data_science_team.multiagents import SQLDataAnalyst
         | 
| 182 | 
            +
                    from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent
         | 
| 183 | 
            +
                    
         | 
| 184 | 
            +
                    llm = ChatOpenAI(model = "gpt-4o-mini")
         | 
| 185 | 
            +
                    
         | 
| 186 | 
            +
                    sql_engine = sql.create_engine("sqlite:///data/northwind.db")
         | 
| 187 | 
            +
             | 
| 188 | 
            +
                    conn = sql_engine.connect()
         | 
| 189 | 
            +
                    
         | 
| 190 | 
            +
                    sql_data_analyst = SQLDataAnalyst(
         | 
| 191 | 
            +
                        model = llm,
         | 
| 192 | 
            +
                        sql_database_agent = SQLDatabaseAgent(
         | 
| 193 | 
            +
                            model = llm,
         | 
| 194 | 
            +
                            connection = conn,
         | 
| 195 | 
            +
                            n_samples = 1,
         | 
| 196 | 
            +
                        ),
         | 
| 197 | 
            +
                        data_visualization_agent = DataVisualizationAgent(
         | 
| 198 | 
            +
                            model = llm,
         | 
| 199 | 
            +
                            n_samples = 10,
         | 
| 200 | 
            +
                        )
         | 
| 201 | 
            +
                    )
         | 
| 202 | 
            +
                    
         | 
| 203 | 
            +
                    sql_data_analyst.invoke_agent(
         | 
| 204 | 
            +
                        user_instructions = "Make a plot of sales revenue by month by territory. Make a dropdown for the user to select the territory.",
         | 
| 205 | 
            +
                    )
         | 
| 206 | 
            +
                    
         | 
| 207 | 
            +
                    sql_data_analyst.get_sql_query_code()
         | 
| 208 | 
            +
                    
         | 
| 209 | 
            +
                    sql_data_analyst.get_data_sql()
         | 
| 210 | 
            +
                    
         | 
| 211 | 
            +
                    sql_data_analyst.get_plotly_graph()
         | 
| 138 212 | 
             
                    ```
         | 
| 139 213 | 
             
                    """
         | 
| 140 214 | 
             
                    response = self._compiled_graph.invoke({
         | 
| 141 215 | 
             
                        "user_instructions": user_instructions,
         | 
| 216 | 
            +
                        "max_retries": max_retries,
         | 
| 217 | 
            +
                        "retry_count": retry_count,
         | 
| 142 218 | 
             
                    }, **kwargs)
         | 
| 219 | 
            +
                    
         | 
| 220 | 
            +
                    if response.get("messages"):
         | 
| 221 | 
            +
                        response["messages"] = remove_consecutive_duplicates(response["messages"])
         | 
| 222 | 
            +
                    
         | 
| 143 223 | 
             
                    self.response = response
         | 
| 144 224 |  | 
| 225 | 
            +
                    
         | 
| 145 226 | 
             
                def get_data_sql(self):
         | 
| 146 227 | 
             
                    """
         | 
| 147 228 | 
             
                    Returns the SQL data as a Pandas DataFrame.
         | 
| @@ -205,6 +286,34 @@ class SQLDataAnalyst(BaseAgent): | |
| 205 286 | 
             
                            if markdown:
         | 
| 206 287 | 
             
                                return Markdown(f"```python\n{self.response.get('data_visualization_function')}\n```")
         | 
| 207 288 | 
             
                            return self.response.get("data_visualization_function")
         | 
| 289 | 
            +
                        
         | 
| 290 | 
            +
                def get_workflow_summary(self, markdown=False):
         | 
| 291 | 
            +
                    """
         | 
| 292 | 
            +
                    Returns a summary of the SQL Data Analyst workflow.
         | 
| 293 | 
            +
                    
         | 
| 294 | 
            +
                    Parameters:
         | 
| 295 | 
            +
                    ----------
         | 
| 296 | 
            +
                    markdown: bool
         | 
| 297 | 
            +
                        If True, returns the summary as a Markdown-formatted string.
         | 
| 298 | 
            +
                    """
         | 
| 299 | 
            +
                    if self.response and self.get_response()['messages']:
         | 
| 300 | 
            +
                        
         | 
| 301 | 
            +
                        agents = [self.get_response()['messages'][i].role for i in range(len(self.get_response()['messages']))]
         | 
| 302 | 
            +
                        
         | 
| 303 | 
            +
                        agent_labels = []
         | 
| 304 | 
            +
                        for i in range(len(agents)):
         | 
| 305 | 
            +
                            agent_labels.append(f"- **Agent {i+1}:** {agents[i]}")
         | 
| 306 | 
            +
                        
         | 
| 307 | 
            +
                        # Construct header
         | 
| 308 | 
            +
                        header = f"# SQL Data Analyst Workflow Summary Report\n\nThis agentic workflow contains {len(agents)} agents:\n\n" + "\n".join(agent_labels)
         | 
| 309 | 
            +
                        
         | 
| 310 | 
            +
                        reports = []
         | 
| 311 | 
            +
                        for msg in self.get_response()['messages']:
         | 
| 312 | 
            +
                            reports.append(get_generic_summary(json.loads(msg.content)))
         | 
| 313 | 
            +
                            
         | 
| 314 | 
            +
                        if markdown:
         | 
| 315 | 
            +
                            return Markdown(header + "\n\n".join(reports))
         | 
| 316 | 
            +
                        return "\n\n".join(reports)
         | 
| 208 317 |  | 
| 209 318 |  | 
| 210 319 |  | 
| @@ -250,6 +359,8 @@ def make_sql_data_analyst( | |
| 250 359 | 
             
                    plot_required: bool
         | 
| 251 360 | 
             
                    data_visualization_function: str
         | 
| 252 361 | 
             
                    plotly_graph: dict
         | 
| 362 | 
            +
                    max_retries: int
         | 
| 363 | 
            +
                    retry_count: int
         | 
| 253 364 |  | 
| 254 365 | 
             
                def route_to_visualization(state) -> Command[Literal["data_visualization_agent", "__end__"]]: 
         | 
| 255 366 |  | 
| 
            File without changes
         | 
| @@ -8,11 +8,16 @@ from langgraph.pregel.types import StreamMode | |
| 8 8 |  | 
| 9 9 | 
             
            import pandas as pd
         | 
| 10 10 | 
             
            import sqlalchemy as sql
         | 
| 11 | 
            +
            import json
         | 
| 11 12 |  | 
| 12 | 
            -
            from typing import Any, Callable, Dict, Type, Optional, Union
         | 
| 13 | 
            +
            from typing import Any, Callable, Dict, Type, Optional, Union, List
         | 
| 13 14 |  | 
| 14 | 
            -
            from ai_data_science_team. | 
| 15 | 
            -
            from ai_data_science_team. | 
| 15 | 
            +
            from ai_data_science_team.parsers.parsers import PythonOutputParser
         | 
| 16 | 
            +
            from ai_data_science_team.utils.regex import (
         | 
| 17 | 
            +
                relocate_imports_inside_function, 
         | 
| 18 | 
            +
                add_comments_to_top,
         | 
| 19 | 
            +
                remove_consecutive_duplicates
         | 
| 20 | 
            +
            )
         | 
| 16 21 |  | 
| 17 22 | 
             
            from IPython.display import Image, display
         | 
| 18 23 | 
             
            import pandas as pd
         | 
| @@ -82,9 +87,13 @@ class BaseAgent(CompiledStateGraph): | |
| 82 87 | 
             
                        Any: The agent's response.
         | 
| 83 88 | 
             
                    """
         | 
| 84 89 | 
             
                    self.response = self._compiled_graph.invoke(input=input, config=config,**kwargs)
         | 
| 90 | 
            +
                    
         | 
| 91 | 
            +
                    if self.response.get("messages"):
         | 
| 92 | 
            +
                        self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])
         | 
| 93 | 
            +
                    
         | 
| 85 94 | 
             
                    return self.response
         | 
| 86 95 |  | 
| 87 | 
            -
                def ainvoke(
         | 
| 96 | 
            +
                async def ainvoke(
         | 
| 88 97 | 
             
                    self, 
         | 
| 89 98 | 
             
                    input: Union[dict[str, Any], Any], 
         | 
| 90 99 | 
             
                    config: Optional[RunnableConfig] = None, 
         | 
| @@ -101,7 +110,11 @@ class BaseAgent(CompiledStateGraph): | |
| 101 110 | 
             
                    Returns:
         | 
| 102 111 | 
             
                        Any: The agent's response.
         | 
| 103 112 | 
             
                    """
         | 
| 104 | 
            -
                    self.response = self._compiled_graph.ainvoke(input=input, config=config,**kwargs)
         | 
| 113 | 
            +
                    self.response = await self._compiled_graph.ainvoke(input=input, config=config,**kwargs)
         | 
| 114 | 
            +
                    
         | 
| 115 | 
            +
                    if self.response.get("messages"):
         | 
| 116 | 
            +
                        self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])
         | 
| 117 | 
            +
                    
         | 
| 105 118 | 
             
                    return self.response
         | 
| 106 119 |  | 
| 107 120 | 
             
                def stream(
         | 
| @@ -129,9 +142,13 @@ class BaseAgent(CompiledStateGraph): | |
| 129 142 | 
             
                        Any: The agent's response.
         | 
| 130 143 | 
             
                    """
         | 
| 131 144 | 
             
                    self.response = self._compiled_graph.stream(input=input, config=config, stream_mode=stream_mode, **kwargs)
         | 
| 145 | 
            +
                    
         | 
| 146 | 
            +
                    if self.response.get("messages"):
         | 
| 147 | 
            +
                        self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])        
         | 
| 148 | 
            +
                    
         | 
| 132 149 | 
             
                    return self.response
         | 
| 133 150 |  | 
| 134 | 
            -
                def astream(
         | 
| 151 | 
            +
                async def astream(
         | 
| 135 152 | 
             
                    self,
         | 
| 136 153 | 
             
                    input: dict[str, Any] | Any,
         | 
| 137 154 | 
             
                    config: RunnableConfig | None = None,
         | 
| @@ -155,7 +172,11 @@ class BaseAgent(CompiledStateGraph): | |
| 155 172 | 
             
                    Returns:
         | 
| 156 173 | 
             
                        Any: The agent's response.
         | 
| 157 174 | 
             
                    """
         | 
| 158 | 
            -
                    self.response = self._compiled_graph.astream(input=input, config=config, stream_mode=stream_mode, **kwargs)
         | 
| 175 | 
            +
                    self.response = await self._compiled_graph.astream(input=input, config=config, stream_mode=stream_mode, **kwargs)
         | 
| 176 | 
            +
                    
         | 
| 177 | 
            +
                    if self.response.get("messages"):
         | 
| 178 | 
            +
                        self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])
         | 
| 179 | 
            +
                    
         | 
| 159 180 | 
             
                    return self.response
         | 
| 160 181 |  | 
| 161 182 | 
             
                def get_state_keys(self):
         | 
| @@ -183,6 +204,9 @@ class BaseAgent(CompiledStateGraph): | |
| 183 204 | 
             
                    Returns:
         | 
| 184 205 | 
             
                        Any: The agent's response.
         | 
| 185 206 | 
             
                    """
         | 
| 207 | 
            +
                    if self.response.get("messages"):
         | 
| 208 | 
            +
                        self.response["messages"] = remove_consecutive_duplicates(self.response["messages"])  
         | 
| 209 | 
            +
                    
         | 
| 186 210 | 
             
                    return self.response
         | 
| 187 211 |  | 
| 188 212 | 
             
                def show(self, xray: int = 0):
         | 
| @@ -729,3 +753,50 @@ def node_func_explain_agent_code( | |
| 729 753 | 
             
                    # Return an error message if there was a problem with the code
         | 
| 730 754 | 
             
                    message = AIMessage(content=error_message)
         | 
| 731 755 | 
             
                    return {result_key: [message]}
         | 
| 756 | 
            +
             | 
| 757 | 
            +
             | 
| 758 | 
            +
             | 
| 759 | 
            +
            def node_func_report_agent_outputs(
         | 
| 760 | 
            +
                state: Dict[str, Any],
         | 
| 761 | 
            +
                keys_to_include: List[str],
         | 
| 762 | 
            +
                result_key: str,
         | 
| 763 | 
            +
                role: str,
         | 
| 764 | 
            +
                custom_title: str = "Agent Output Summary"
         | 
| 765 | 
            +
            ) -> Dict[str, Any]:
         | 
| 766 | 
            +
                """
         | 
| 767 | 
            +
                Gathers relevant data directly from the state (filtered by `keys_to_include`) 
         | 
| 768 | 
            +
                and returns them as a structured message in `state[result_key]`.
         | 
| 769 | 
            +
             | 
| 770 | 
            +
                No LLM is used.
         | 
| 771 | 
            +
             | 
| 772 | 
            +
                Parameters
         | 
| 773 | 
            +
                ----------
         | 
| 774 | 
            +
                state : Dict[str, Any]
         | 
| 775 | 
            +
                    The current state dictionary holding all agent variables.
         | 
| 776 | 
            +
                keys_to_include : List[str]
         | 
| 777 | 
            +
                    The list of keys in `state` to include in the output.
         | 
| 778 | 
            +
                result_key : str
         | 
| 779 | 
            +
                    The key in `state` under which we'll store the final structured message.
         | 
| 780 | 
            +
                role : str
         | 
| 781 | 
            +
                    The role that will be used in the final AIMessage (e.g., "DataCleaningAgent").
         | 
| 782 | 
            +
                custom_title : str, optional
         | 
| 783 | 
            +
                    A title or heading for your report. Defaults to "Agent Output Summary".
         | 
| 784 | 
            +
                """
         | 
| 785 | 
            +
                print("    * REPORT AGENT OUTPUTS")
         | 
| 786 | 
            +
             | 
| 787 | 
            +
                final_report = {"report_title": custom_title}
         | 
| 788 | 
            +
             | 
| 789 | 
            +
                for key in keys_to_include:
         | 
| 790 | 
            +
                    final_report[key] = state.get(key, f"<{key}_not_found_in_state>")
         | 
| 791 | 
            +
             | 
| 792 | 
            +
                # Wrap it in a list of messages (like the current "messages" pattern).
         | 
| 793 | 
            +
                # You can serialize this dictionary as JSON or just cast it to string.
         | 
| 794 | 
            +
                return {
         | 
| 795 | 
            +
                    result_key: [
         | 
| 796 | 
            +
                        AIMessage(
         | 
| 797 | 
            +
                            content=json.dumps(final_report, indent=2), 
         | 
| 798 | 
            +
                            role=role
         | 
| 799 | 
            +
                        )
         | 
| 800 | 
            +
                    ]
         | 
| 801 | 
            +
                }
         | 
| 802 | 
            +
             |