lionagi 0.0.206__py3-none-any.whl → 0.0.208__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lionagi/_services/ollama.py +2 -2
- lionagi/core/branch/branch.py +517 -265
- lionagi/core/branch/branch_manager.py +0 -1
- lionagi/core/branch/conversation.py +640 -337
- lionagi/core/core_util.py +0 -59
- lionagi/core/sessions/session.py +137 -64
- lionagi/tools/tool_manager.py +39 -62
- lionagi/utils/__init__.py +3 -2
- lionagi/utils/call_util.py +9 -7
- lionagi/utils/sys_util.py +287 -255
- lionagi/version.py +1 -1
- {lionagi-0.0.206.dist-info → lionagi-0.0.208.dist-info}/METADATA +1 -1
- {lionagi-0.0.206.dist-info → lionagi-0.0.208.dist-info}/RECORD +16 -17
- lionagi/utils/pd_util.py +0 -57
- {lionagi-0.0.206.dist-info → lionagi-0.0.208.dist-info}/LICENSE +0 -0
- {lionagi-0.0.206.dist-info → lionagi-0.0.208.dist-info}/WHEEL +0 -0
- {lionagi-0.0.206.dist-info → lionagi-0.0.208.dist-info}/top_level.txt +0 -0
    
        lionagi/core/core_util.py
    CHANGED
    
    | @@ -1,59 +0,0 @@ | |
| 1 | 
            -
            import json
         | 
| 2 | 
            -
            from ..utils.sys_util import strip_lower
         | 
| 3 | 
            -
             | 
| 4 | 
            -
             | 
| 5 | 
            -
            def sign_message(messages, sender: str):
         | 
| 6 | 
            -
                """
         | 
| 7 | 
            -
                Sign messages with a sender identifier.
         | 
| 8 | 
            -
             | 
| 9 | 
            -
                Args:
         | 
| 10 | 
            -
                    messages (pd.DataFrame): A DataFrame containing messages with columns 'node_id', 'role', 'sender', 'timestamp', and 'content'.
         | 
| 11 | 
            -
                    sender (str): The sender identifier to be added to the messages.
         | 
| 12 | 
            -
             | 
| 13 | 
            -
                Returns:
         | 
| 14 | 
            -
                    pd.DataFrame: A new DataFrame with the sender identifier added to each message.
         | 
| 15 | 
            -
             | 
| 16 | 
            -
                Raises:
         | 
| 17 | 
            -
                    ValueError: If the 'sender' is None or 'None'.
         | 
| 18 | 
            -
                """
         | 
| 19 | 
            -
                if sender is None or strip_lower(sender) == 'none':
         | 
| 20 | 
            -
                    raise ValueError("sender cannot be None")
         | 
| 21 | 
            -
                df = messages.copy()
         | 
| 22 | 
            -
             | 
| 23 | 
            -
                for i in df.index:
         | 
| 24 | 
            -
                    if not df.loc[i, 'content'].startswith('Sender'):
         | 
| 25 | 
            -
                        df.loc[i, 'content'] = f"Sender {sender}: {df.loc[i, 'content']}"
         | 
| 26 | 
            -
                    else:
         | 
| 27 | 
            -
                        content = df.loc[i, 'content'].split(':', 1)[1]
         | 
| 28 | 
            -
                        df.loc[i, 'content'] = f"Sender {sender}: {content}"
         | 
| 29 | 
            -
                return df
         | 
| 30 | 
            -
             | 
| 31 | 
            -
             | 
| 32 | 
            -
            def validate_messages(messages):
         | 
| 33 | 
            -
                """
         | 
| 34 | 
            -
                Validate the structure and content of a messages DataFrame.
         | 
| 35 | 
            -
             | 
| 36 | 
            -
                Args:
         | 
| 37 | 
            -
                    messages (pd.DataFrame): A DataFrame containing messages with columns 'node_id', 'role', 'sender', 'timestamp', and 'content'.
         | 
| 38 | 
            -
             | 
| 39 | 
            -
                Returns:
         | 
| 40 | 
            -
                    bool: True if the messages DataFrame is valid; otherwise, raises a ValueError.
         | 
| 41 | 
            -
             | 
| 42 | 
            -
                Raises:
         | 
| 43 | 
            -
                    ValueError: If the DataFrame structure is invalid or if it contains null values, roles other than ["system", "user", "assistant"],
         | 
| 44 | 
            -
                                or content that cannot be parsed as JSON strings.
         | 
| 45 | 
            -
                """
         | 
| 46 | 
            -
                if list(messages.columns) != ['node_id', 'role', 'sender', 'timestamp', 'content']:
         | 
| 47 | 
            -
                    raise ValueError('Invalid messages dataframe. Unmatched columns.')
         | 
| 48 | 
            -
                if messages.isnull().values.any():
         | 
| 49 | 
            -
                    raise ValueError('Invalid messages dataframe. Cannot have null.')
         | 
| 50 | 
            -
                if not all(role in ['system', 'user', 'assistant'] for role in messages['role'].unique()):
         | 
| 51 | 
            -
                    raise ValueError('Invalid messages dataframe. Cannot have role other than ["system", "user", "assistant"].')
         | 
| 52 | 
            -
                for cont in messages['content']:
         | 
| 53 | 
            -
                    if cont.startswith('Sender'):
         | 
| 54 | 
            -
                        cont = cont.split(':', 1)[1]
         | 
| 55 | 
            -
                    try:
         | 
| 56 | 
            -
                        json.loads(cont)
         | 
| 57 | 
            -
                    except:
         | 
| 58 | 
            -
                        raise ValueError('Invalid messages dataframe. Content expect json string.')
         | 
| 59 | 
            -
                return True
         | 
    
        lionagi/core/sessions/session.py
    CHANGED
    
    | @@ -1,4 +1,5 @@ | |
| 1 1 | 
             
            import pandas as pd
         | 
| 2 | 
            +
             | 
| 2 3 | 
             
            from typing import Any, List, Union, Dict, Optional, Callable, Tuple
         | 
| 3 4 | 
             
            from dotenv import load_dotenv
         | 
| 4 5 |  | 
| @@ -8,30 +9,26 @@ from ..messages.messages import System, Instruction | |
| 8 9 | 
             
            from ..branch.branch import Branch
         | 
| 9 10 | 
             
            from ..branch.branch_manager import BranchManager
         | 
| 10 11 |  | 
| 11 | 
            -
             | 
| 12 12 | 
             
            load_dotenv()
         | 
| 13 13 |  | 
| 14 14 |  | 
| 15 15 | 
             
            class Session:
         | 
| 16 16 | 
             
                """
         | 
| 17 | 
            -
                 | 
| 17 | 
            +
                Manages sessions with conversation branches, tool management, and interaction logging.
         | 
| 18 18 |  | 
| 19 | 
            -
                This class  | 
| 20 | 
            -
                messages, instruction sets, and tools. It also handles logging and interactions with an external service.
         | 
| 19 | 
            +
                This class orchestrates the handling of different conversation branches, enabling distinct conversational contexts to coexist within a single session. It facilitates the integration with external services for processing chat completions, tool management, and the logging of session activities.
         | 
| 21 20 |  | 
| 22 21 | 
             
                Attributes:
         | 
| 23 | 
            -
                    branches (Dict[str, Branch]):  | 
| 24 | 
            -
                    default_branch (Branch): The  | 
| 25 | 
            -
                    default_branch_name (str):  | 
| 26 | 
            -
                    llmconfig (Dict[str, Any]):  | 
| 27 | 
            -
                     | 
| 28 | 
            -
                    service (OpenAIService): Service used for handling chat completions and other operations.
         | 
| 22 | 
            +
                    branches (Dict[str, Branch]): Maps branch names to Branch instances.
         | 
| 23 | 
            +
                    default_branch (Branch): The primary branch for the session.
         | 
| 24 | 
            +
                    default_branch_name (str): Identifier for the default branch.
         | 
| 25 | 
            +
                    llmconfig (Dict[str, Any]): Configurations for language model interactions.
         | 
| 26 | 
            +
                    service (OpenAIService): Interface for external service interactions.
         | 
| 29 27 | 
             
                """
         | 
| 30 28 | 
             
                def __init__(
         | 
| 31 29 | 
             
                    self,
         | 
| 32 30 | 
             
                    system: Optional[Union[str, System]] = None,
         | 
| 33 31 | 
             
                    sender: Optional[str] = None,
         | 
| 34 | 
            -
                    dir: Optional[str] = None,
         | 
| 35 32 | 
             
                    llmconfig: Optional[Dict[str, Any]] = None,
         | 
| 36 33 | 
             
                    service: OpenAIService = None,
         | 
| 37 34 | 
             
                    branches: Optional[Dict[str, Branch]] = None,
         | 
| @@ -39,36 +36,26 @@ class Session: | |
| 39 36 | 
             
                    default_branch_name: str = 'main',
         | 
| 40 37 | 
             
                ):
         | 
| 41 38 | 
             
                    """
         | 
| 42 | 
            -
                     | 
| 39 | 
            +
                    Initializes a session with optional settings for branches, service, and language model configurations.
         | 
| 43 40 |  | 
| 44 41 | 
             
                    Args:
         | 
| 45 | 
            -
                        system (Union[str, System]): Initial system message or  | 
| 46 | 
            -
                         | 
| 47 | 
            -
                        llmconfig (Dict[str, Any] | 
| 48 | 
            -
                        service (OpenAIService | 
| 49 | 
            -
                        branches (Dict[str, Branch] | 
| 50 | 
            -
                        default_branch (Branch | 
| 51 | 
            -
                        default_branch_name (str | 
| 42 | 
            +
                        system (Optional[Union[str, System]]): Initial system message or configuration.
         | 
| 43 | 
            +
                        sender (Optional[str]): Identifier for the sender of the system message.
         | 
| 44 | 
            +
                        llmconfig (Optional[Dict[str, Any]]): Language model configuration settings.
         | 
| 45 | 
            +
                        service (OpenAIService): External service for chat completions and other operations.
         | 
| 46 | 
            +
                        branches (Optional[Dict[str, Branch]]): Predefined conversation branches.
         | 
| 47 | 
            +
                        default_branch (Optional[Branch]): Preselected default branch for the session.
         | 
| 48 | 
            +
                        default_branch_name (str): Name for the default branch, defaults to 'main'.
         | 
| 52 49 | 
             
                    """
         | 
| 53 50 |  | 
| 54 51 | 
             
                    self.branches = branches if isinstance(branches, dict) else {}
         | 
| 55 52 | 
             
                    if service is None:
         | 
| 56 53 | 
             
                        service = OpenAIService()
         | 
| 57 | 
            -
                    
         | 
| 58 | 
            -
                    self.default_branch = default_branch if default_branch else Branch(name=default_branch_name, service=service, llmconfig=llmconfig)
         | 
| 59 | 
            -
                    self.default_branch_name = default_branch_name
         | 
| 60 | 
            -
                    if system:
         | 
| 61 | 
            -
                        self.default_branch.add_message(system=system, sender=sender)
         | 
| 62 | 
            -
                    if self.branches:
         | 
| 63 | 
            -
                        if self.default_branch_name not in self.branches.keys():
         | 
| 64 | 
            -
                            raise ValueError('default branch name is not in imported branches')
         | 
| 65 | 
            -
                        if self.default_branch is not self.branches[self.default_branch_name]:
         | 
| 66 | 
            -
                            raise ValueError(f'default branch does not match Branch object under {self.default_branch_name}')
         | 
| 67 | 
            -
                    if not self.branches:
         | 
| 68 | 
            -
                        self.branches[self.default_branch_name] = self.default_branch
         | 
| 69 | 
            -
                    if dir:
         | 
| 70 | 
            -
                        self.default_branch.dir = dir
         | 
| 71 54 |  | 
| 55 | 
            +
                    self._setup_default_branch(
         | 
| 56 | 
            +
                        default_branch, default_branch_name, service, llmconfig, system, sender)
         | 
| 57 | 
            +
                    
         | 
| 58 | 
            +
                    self._verify_default_branch()
         | 
| 72 59 | 
             
                    self.branch_manager = BranchManager(self.branches)
         | 
| 73 60 |  | 
| 74 61 | 
             
                def new_branch(
         | 
| @@ -83,20 +70,20 @@ class Session: | |
| 83 70 | 
             
                    llmconfig: Optional[Dict] = None,
         | 
| 84 71 | 
             
                ) -> None:
         | 
| 85 72 | 
             
                    """
         | 
| 86 | 
            -
                     | 
| 73 | 
            +
                    Creates a new branch within the session.
         | 
| 87 74 |  | 
| 88 75 | 
             
                    Args:
         | 
| 89 | 
            -
                        branch_name (str): Name  | 
| 90 | 
            -
                        dir (str | 
| 91 | 
            -
                        messages (Optional[pd.DataFrame]):  | 
| 92 | 
            -
                         | 
| 93 | 
            -
                         | 
| 94 | 
            -
                        sender (Optional[str] | 
| 95 | 
            -
                        service (OpenAIService | 
| 96 | 
            -
                        llmconfig (Dict[str, Any] | 
| 76 | 
            +
                        branch_name (str): Name for the new branch.
         | 
| 77 | 
            +
                        dir (Optional[str]): Path for storing branch-related logs.
         | 
| 78 | 
            +
                        messages (Optional[pd.DataFrame]): Initial set of messages for the branch.
         | 
| 79 | 
            +
                        tools (Optional[Union[Tool, List[Tool]]]): Tools to register in the new branch.
         | 
| 80 | 
            +
                        system (Optional[Union[str, System]]): System message or configuration for the branch.
         | 
| 81 | 
            +
                        sender (Optional[str]): Identifier for the sender of the initial message.
         | 
| 82 | 
            +
                        service (Optional[OpenAIService]): Service interface specific to the branch.
         | 
| 83 | 
            +
                        llmconfig (Optional[Dict[str, Any]]): Language model configurations for the branch.
         | 
| 97 84 |  | 
| 98 85 | 
             
                    Raises:
         | 
| 99 | 
            -
                        ValueError: If the branch name already exists  | 
| 86 | 
            +
                        ValueError: If the branch name already exists within the session.
         | 
| 100 87 | 
             
                    """
         | 
| 101 88 | 
             
                    if branch_name in self.branches.keys():
         | 
| 102 89 | 
             
                        raise ValueError(f'Invalid new branch name {branch_name}. Already existed.')
         | 
| @@ -116,18 +103,19 @@ class Session: | |
| 116 103 | 
             
                    get_name: bool = False
         | 
| 117 104 | 
             
                ) -> Union[Branch, Tuple[Branch, str]]:
         | 
| 118 105 | 
             
                    """
         | 
| 119 | 
            -
                     | 
| 106 | 
            +
                    Retrieves a branch from the session by name or as a Branch object.
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                    If no branch is specified, returns the default branch. Optionally, can also return the branch's name.
         | 
| 120 109 |  | 
| 121 110 | 
             
                    Args:
         | 
| 122 | 
            -
                        branch (Optional[Union[Branch, str]] | 
| 123 | 
            -
             | 
| 124 | 
            -
                        get_name (bool, optional): If True, returns the name of the branch along with the branch object.
         | 
| 111 | 
            +
                        branch (Optional[Union[Branch, str]]): The branch name or Branch object to retrieve. Defaults to None, which refers to the default branch.
         | 
| 112 | 
            +
                        get_name (bool): If True, also returns the name of the branch alongside the Branch object.
         | 
| 125 113 |  | 
| 126 114 | 
             
                    Returns:
         | 
| 127 | 
            -
                        Union[Branch, Tuple[Branch, str]]: The  | 
| 115 | 
            +
                        Union[Branch, Tuple[Branch, str]]: The requested Branch object, or a tuple of the Branch object and its name if `get_name` is True.
         | 
| 128 116 |  | 
| 129 117 | 
             
                    Raises:
         | 
| 130 | 
            -
                        ValueError: If the branch does not exist  | 
| 118 | 
            +
                        ValueError: If the specified branch does not exist within the session.
         | 
| 131 119 | 
             
                    """
         | 
| 132 120 | 
             
                    if isinstance(branch, str):
         | 
| 133 121 | 
             
                        if branch not in self.branches.keys():
         | 
| @@ -152,10 +140,10 @@ class Session: | |
| 152 140 |  | 
| 153 141 | 
             
                def change_default(self, branch: Union[str, Branch]) -> None:
         | 
| 154 142 | 
             
                    """
         | 
| 155 | 
            -
                     | 
| 143 | 
            +
                    Changes the default branch of the session.
         | 
| 156 144 |  | 
| 157 145 | 
             
                    Args:
         | 
| 158 | 
            -
                        branch (Union[str, Branch]): The branch or  | 
| 146 | 
            +
                        branch (Union[str, Branch]): The branch name or Branch object to set as the new default branch.
         | 
| 159 147 | 
             
                    """
         | 
| 160 148 | 
             
                    branch_, name_ = self.get_branch(branch, get_name=True)
         | 
| 161 149 | 
             
                    self.default_branch = branch_
         | 
| @@ -265,6 +253,45 @@ class Session: | |
| 265 253 | 
             
                        instruction=instruction, system=system, context=context,
         | 
| 266 254 | 
             
                        out=out, sender=sender, invoke=invoke, tools=tools, **kwargs)
         | 
| 267 255 |  | 
| 256 | 
            +
                async def ReAct(
         | 
| 257 | 
            +
                    self,
         | 
| 258 | 
            +
                    instruction: Union[Instruction, str],
         | 
| 259 | 
            +
                    context = None,
         | 
| 260 | 
            +
                    sender = None,
         | 
| 261 | 
            +
                    to_ = None,
         | 
| 262 | 
            +
                    system = None,
         | 
| 263 | 
            +
                    tools = None, 
         | 
| 264 | 
            +
                    num_rounds: int = 1,
         | 
| 265 | 
            +
                    fallback: Optional[Callable] = None,
         | 
| 266 | 
            +
                    fallback_kwargs: Optional[Dict] = None,
         | 
| 267 | 
            +
                    out=True,
         | 
| 268 | 
            +
                    **kwargs  
         | 
| 269 | 
            +
                ):
         | 
| 270 | 
            +
                    """
         | 
| 271 | 
            +
                    Performs a sequence of reasoning and action steps in a specified or default branch.
         | 
| 272 | 
            +
             | 
| 273 | 
            +
                    Args:
         | 
| 274 | 
            +
                        instruction (Union[Instruction, str]): Instruction to initiate the ReAct process.
         | 
| 275 | 
            +
                        context: Additional context for reasoning and action. Defaults to None.
         | 
| 276 | 
            +
                        sender: Identifier for the sender. Defaults to None.
         | 
| 277 | 
            +
                        to_: Target branch name or object for ReAct. Defaults to the default branch.
         | 
| 278 | 
            +
                        system: System message or configuration. Defaults to None.
         | 
| 279 | 
            +
                        tools: Tools to be used for actions. Defaults to None.
         | 
| 280 | 
            +
                        num_rounds (int): Number of reasoning-action cycles. Defaults to 1.
         | 
| 281 | 
            +
                        fallback (Optional[Callable]): Fallback function in case of an error. Defaults to None.
         | 
| 282 | 
            +
                        fallback_kwargs (Optional[Dict]): Arguments for the fallback function. Defaults to None.
         | 
| 283 | 
            +
                        out (bool): If True, outputs the result of the ReAct process. Defaults to True.
         | 
| 284 | 
            +
                        **kwargs: Arbitrary keyword arguments for additional customization.
         | 
| 285 | 
            +
             | 
| 286 | 
            +
                    Returns:
         | 
| 287 | 
            +
                        The outcome of the ReAct process, depending on the specified branch and instructions.
         | 
| 288 | 
            +
                    """
         | 
| 289 | 
            +
                    branch = self.get_branch(to_)
         | 
| 290 | 
            +
                    return await branch.ReAct(
         | 
| 291 | 
            +
                        instruction=instruction, context=context, sender=sender, system=system, tools=tools, 
         | 
| 292 | 
            +
                        num_rounds=num_rounds, fallback=fallback, fallback_kwargs=fallback_kwargs, 
         | 
| 293 | 
            +
                        out=out, **kwargs
         | 
| 294 | 
            +
                    )
         | 
| 268 295 |  | 
| 269 296 | 
             
                async def auto_followup(
         | 
| 270 297 | 
             
                    self,
         | 
| @@ -293,16 +320,8 @@ class Session: | |
| 293 320 | 
             
                    """
         | 
| 294 321 |  | 
| 295 322 | 
             
                    branch_ = self.get_branch(to_)
         | 
| 296 | 
            -
                    if fallback:
         | 
| 297 | 
            -
                        try:
         | 
| 298 | 
            -
                            return await branch_.auto_followup(
         | 
| 299 | 
            -
                                instruction=instruction, num=num, tools=tools,**kwargs
         | 
| 300 | 
            -
                            )
         | 
| 301 | 
            -
                        except:
         | 
| 302 | 
            -
                            return fallback(**fallback_kwargs)
         | 
| 303 | 
            -
                    
         | 
| 304 323 | 
             
                    return await branch_.auto_followup(
         | 
| 305 | 
            -
                        instruction=instruction, num=num, tools=tools | 
| 324 | 
            +
                        instruction=instruction, num=num, tools=tools, fallback=fallback, fallback_kwargs=fallback_kwargs, **kwargs
         | 
| 306 325 | 
             
                    )
         | 
| 307 326 |  | 
| 308 327 | 
             
                def change_first_system_message(self, system: Union[System, str]) -> None:
         | 
| @@ -369,7 +388,7 @@ class Session: | |
| 369 388 |  | 
| 370 389 | 
             
                def register_tools(self, tools: Union[Tool, List[Tool]]) -> None:
         | 
| 371 390 | 
             
                    """
         | 
| 372 | 
            -
                    Registers one or more tools to the current  | 
| 391 | 
            +
                    Registers one or more tools to the current default branch.
         | 
| 373 392 |  | 
| 374 393 | 
             
                    Args:
         | 
| 375 394 | 
             
                        tools (Union[Tool, List[Tool]]): The tool or list of tools to register.
         | 
| @@ -378,7 +397,7 @@ class Session: | |
| 378 397 |  | 
| 379 398 | 
             
                def delete_tool(self, name: str) -> bool:
         | 
| 380 399 | 
             
                    """
         | 
| 381 | 
            -
                    Deletes a tool from the current  | 
| 400 | 
            +
                    Deletes a tool from the current default branch.
         | 
| 382 401 |  | 
| 383 402 | 
             
                    Args:
         | 
| 384 403 | 
             
                        name (str): The name of the tool to delete.
         | 
| @@ -391,10 +410,10 @@ class Session: | |
| 391 410 | 
             
                @property
         | 
| 392 411 | 
             
                def describe(self) -> Dict[str, Any]:
         | 
| 393 412 | 
             
                    """
         | 
| 394 | 
            -
                    Generates a report of the current  | 
| 413 | 
            +
                    Generates a report of the current default branch.
         | 
| 395 414 |  | 
| 396 415 | 
             
                    Returns:
         | 
| 397 | 
            -
                        Dict[str, Any]: The report of the current  | 
| 416 | 
            +
                        Dict[str, Any]: The report of the current default branch.
         | 
| 398 417 | 
             
                    """
         | 
| 399 418 | 
             
                    return self.default_branch.describe
         | 
| 400 419 |  | 
| @@ -408,6 +427,60 @@ class Session: | |
| 408 427 | 
             
                    """
         | 
| 409 428 | 
             
                    return self.default_branch.messages
         | 
| 410 429 |  | 
| 430 | 
            +
             | 
| 431 | 
            +
                @property
         | 
| 432 | 
            +
                def first_system(self) -> pd.Series:
         | 
| 433 | 
            +
                    """
         | 
| 434 | 
            +
                    Get the first system message of the current default branch.
         | 
| 435 | 
            +
             | 
| 436 | 
            +
                    Returns:
         | 
| 437 | 
            +
                        pd.Series: The first system message of the current default branch.
         | 
| 438 | 
            +
                    """
         | 
| 439 | 
            +
                    return self.default_branch.first_system
         | 
| 440 | 
            +
                
         | 
| 441 | 
            +
                @property
         | 
| 442 | 
            +
                def last_response(self) -> pd.Series:
         | 
| 443 | 
            +
                    """
         | 
| 444 | 
            +
                    Get the last response message of the current default branch.
         | 
| 445 | 
            +
             | 
| 446 | 
            +
                    Returns:
         | 
| 447 | 
            +
                        pd.Series: The last response message of the current default branch.
         | 
| 448 | 
            +
                    """
         | 
| 449 | 
            +
                    return self.default_branch.last_response
         | 
| 450 | 
            +
             | 
| 451 | 
            +
             | 
| 452 | 
            +
                @property
         | 
| 453 | 
            +
                def last_response_content(self) -> Dict:
         | 
| 454 | 
            +
                    """
         | 
| 455 | 
            +
                    Get the last response content of the current default branch.
         | 
| 456 | 
            +
             | 
| 457 | 
            +
                    Returns:
         | 
| 458 | 
            +
                        Dict: The last response content of the current default branch.
         | 
| 459 | 
            +
                    """
         | 
| 460 | 
            +
                    return self.default_branch.last_response_content
         | 
| 461 | 
            +
             | 
| 462 | 
            +
             | 
| 463 | 
            +
                def _verify_default_branch(self):
         | 
| 464 | 
            +
                    if self.branches:
         | 
| 465 | 
            +
                        if self.default_branch_name not in self.branches.keys():
         | 
| 466 | 
            +
                            raise ValueError('default branch name is not in imported branches')
         | 
| 467 | 
            +
                        if self.default_branch is not self.branches[self.default_branch_name]:
         | 
| 468 | 
            +
                            raise ValueError(f'default branch does not match Branch object under {self.default_branch_name}')
         | 
| 469 | 
            +
                        
         | 
| 470 | 
            +
                    if not self.branches:
         | 
| 471 | 
            +
                        self.branches[self.default_branch_name] = self.default_branch
         | 
| 472 | 
            +
             | 
| 473 | 
            +
                def _setup_default_branch(
                    self, default_branch, default_branch_name, service, llmconfig, system, sender
                ):
                    """
                    Installs the default branch and mirrors its LLM configuration onto this object.

                    Args:
                        default_branch: Branch to use as-is when truthy; otherwise a new Branch is constructed.
                        default_branch_name: Stored as ``default_branch_name`` and used as the name of a newly constructed Branch.
                        service: Passed to the Branch constructor when a new Branch is built.
                        llmconfig: Passed to the Branch constructor when a new Branch is built.
                        system: System message; when truthy it is added to the default branch.
                        sender: Sender passed along with the system message.
                    """
                    if default_branch:
                        self.default_branch = default_branch
                    else:
                        self.default_branch = Branch(
                            name=default_branch_name, service=service, llmconfig=llmconfig
                        )
                    self.default_branch_name = default_branch_name

                    if system:
                        self.default_branch.add_message(system=system, sender=sender)

                    # The config is read back from the branch, not taken from the
                    # llmconfig argument, so a supplied branch keeps its own settings.
                    self.llmconfig = self.default_branch.llmconfig
         | 
| 411 484 | 
             
                # def add_instruction_set(self, name: str, instruction_set: InstructionSet) -> None:
         | 
| 412 485 | 
             
                #     """
         | 
| 413 486 | 
             
                #     Adds an instruction set to the current active branch.
         | 
    
        lionagi/tools/tool_manager.py
    CHANGED
    
    | @@ -1,50 +1,44 @@ | |
| 1 1 | 
             
            import json
         | 
| 2 2 | 
             
            import asyncio
         | 
| 3 3 | 
             
            from typing import Dict, Union, List, Tuple, Any
         | 
| 4 | 
            -
            from lionagi.utils.call_util import lcall, is_coroutine_func, _call_handler
         | 
| 4 | 
            +
            from lionagi.utils.call_util import lcall, is_coroutine_func, _call_handler, alcall
         | 
| 5 5 | 
             
            from lionagi.schema import BaseNode, Tool
         | 
| 6 6 |  | 
| 7 7 |  | 
| 8 8 | 
             
            class ToolManager(BaseNode):
         | 
| 9 9 | 
             
                """
         | 
| 10 | 
            -
                A manager class  | 
| 10 | 
            +
                A manager class for handling the registration and invocation of tools that are subclasses of Tool.
         | 
| 11 11 |  | 
| 12 | 
            +
                This class maintains a registry of tool instances, allowing for dynamic invocation based on
         | 
| 13 | 
            +
                tool name and provided arguments. It supports both synchronous and asynchronous tool function
         | 
| 14 | 
            +
                calls.
         | 
| 15 | 
            +
             | 
| 12 16 | 
             
                Attributes:
         | 
| 13 | 
            -
                    registry (Dict[str, Tool]): A dictionary to hold registered tools,  | 
| 17 | 
            +
                    registry (Dict[str, Tool]): A dictionary to hold registered tools, keyed by their names.
         | 
| 14 18 | 
             
                """
         | 
| 15 19 | 
             
                registry: Dict = {}
         | 
| 16 20 |  | 
| 17 21 | 
             
                def name_existed(self, name: str) -> bool:
                    """
                    Reports whether a tool is registered under the given name.

                    Args:
                        name (str): The tool name to look up in the registry.

                    Returns:
                        bool: True when the name is a key of the registry, False otherwise.
                    """
                    return name in self.registry
         | 
| 34 32 |  | 
| 35 33 | 
             
                def _register_tool(self, tool: Tool) -> None:
         | 
| 36 34 | 
             
                    """
         | 
| 37 | 
            -
                     | 
| 35 | 
            +
                    Registers a tool in the registry. Raises a TypeError if the object is not an instance of Tool.
         | 
| 38 36 |  | 
| 39 | 
            -
                     | 
| 37 | 
            +
                    Args:
         | 
| 40 38 | 
             
                        tool (Tool): The tool instance to register.
         | 
| 41 39 |  | 
| 42 40 | 
             
                    Raises:
         | 
| 43 | 
            -
                        TypeError: If the provided  | 
| 44 | 
            -
             | 
| 45 | 
            -
                    Examples:
         | 
| 46 | 
            -
                        >>> tool_manager._register_tool(Tool())
         | 
| 47 | 
            -
                        # Tool is registered without any output
         | 
| 41 | 
            +
                        TypeError: If the provided object is not an instance of Tool.
         | 
| 48 42 | 
             
                    """
         | 
| 49 43 | 
             
                    if not isinstance(tool, Tool):
         | 
| 50 44 | 
             
                        raise TypeError('Please register a Tool object.')
         | 
| @@ -53,20 +47,16 @@ class ToolManager(BaseNode): | |
| 53 47 |  | 
| 54 48 | 
             
                async def invoke(self, func_call: Tuple[str, Dict[str, Any]]) -> Any:
         | 
| 55 49 | 
             
                    """
         | 
| 56 | 
            -
                     | 
| 50 | 
            +
                    Invokes a registered tool's function with the given arguments. Supports both coroutine and regular functions.
         | 
| 57 51 |  | 
| 58 52 | 
             
                    Args:
         | 
| 59 | 
            -
                        func_call (Tuple[str, Dict[str, Any]]): A tuple containing the  | 
| 53 | 
            +
                        func_call (Tuple[str, Dict[str, Any]]): A tuple containing the function name and a dictionary of keyword arguments.
         | 
| 60 54 |  | 
| 61 55 | 
             
                    Returns:
         | 
| 62 | 
            -
                        Any: The result of the  | 
| 56 | 
            +
                        Any: The result of the function call.
         | 
| 63 57 |  | 
| 64 58 | 
             
                    Raises:
         | 
| 65 | 
            -
                        ValueError: If the function is not registered or an error  | 
| 66 | 
            -
             | 
| 67 | 
            -
                    Examples:
         | 
| 68 | 
            -
                        >>> await tool_manager.invoke(('registered_function', {'arg1': 'value1'}))
         | 
| 69 | 
            -
                        # Result of the registered_function with given arguments
         | 
| 59 | 
            +
                        ValueError: If the function name is not registered or if there's an error during function invocation.
         | 
| 70 60 | 
             
                    """
         | 
| 71 61 | 
             
                    name, kwargs = func_call
         | 
| 72 62 | 
             
                    if self.name_existed(name):
         | 
| @@ -74,10 +64,13 @@ class ToolManager(BaseNode): | |
| 74 64 | 
             
                        func = tool.func
         | 
| 75 65 | 
             
                        parser = tool.parser
         | 
| 76 66 | 
             
                        try:
         | 
| 77 | 
            -
                             | 
| 78 | 
            -
             | 
| 79 | 
            -
             | 
| 80 | 
            -
             | 
| 67 | 
            +
                            if is_coroutine_func(func):
         | 
| 68 | 
            +
                                tasks = [_call_handler(func, **kwargs)]
         | 
| 69 | 
            +
                                out = await asyncio.gather(*tasks)
         | 
| 70 | 
            +
                                return parser(out[0]) if parser else out[0]
         | 
| 71 | 
            +
                            else:
         | 
| 72 | 
            +
                                out = func(**kwargs)
         | 
| 73 | 
            +
                                return parser(out) if parser else out
         | 
| 81 74 | 
             
                        except Exception as e:
         | 
| 82 75 | 
             
                            raise ValueError(f"Error when invoking function {name} with arguments {kwargs} with error message {e}")
         | 
| 83 76 | 
             
                    else: 
         | 
| @@ -86,20 +79,16 @@ class ToolManager(BaseNode): | |
| 86 79 | 
             
                @staticmethod
         | 
| 87 80 | 
             
                def get_function_call(response: Dict) -> Tuple[str, Dict]:
         | 
| 88 81 | 
             
                    """
         | 
| 89 | 
            -
                     | 
| 82 | 
            +
                    Extracts a function call and arguments from a response dictionary.
         | 
| 90 83 |  | 
| 91 | 
            -
                     | 
| 92 | 
            -
                        response (Dict): The  | 
| 84 | 
            +
                    Args:
         | 
| 85 | 
            +
                        response (Dict): The response dictionary containing the function call information.
         | 
| 93 86 |  | 
| 94 87 | 
             
                    Returns:
         | 
| 95 | 
            -
                        Tuple[str, Dict]:  | 
| 88 | 
            +
                        Tuple[str, Dict]: A tuple containing the function name and a dictionary of arguments.
         | 
| 96 89 |  | 
| 97 90 | 
             
                    Raises:
         | 
| 98 | 
            -
                        ValueError: If the response  | 
| 99 | 
            -
             | 
| 100 | 
            -
                    Examples:
         | 
| 101 | 
            -
                        >>> ToolManager.get_function_call({"action": "execute_add", "arguments": '{"x":1, "y":2}'})
         | 
| 102 | 
            -
                        ('add', {'x': 1, 'y': 2})
         | 
| 91 | 
            +
                        ValueError: If the response does not contain valid function call information.
         | 
| 103 92 | 
             
                    """
         | 
| 104 93 | 
             
                    try:
         | 
| 105 94 | 
             
                        func = response['action'][7:]
         | 
| @@ -115,27 +104,20 @@ class ToolManager(BaseNode): | |
| 115 104 |  | 
| 116 105 | 
             
                def register_tools(self, tools: List[Tool]) -> None:
                    """
                    Registers several tools by handing each one to ``_register_tool``.

                    Args:
                        tools (List[Tool]): The tool instances to register.
                    """
                    # lcall is imported from lionagi.utils.call_util; presumably it applies
                    # the callable to every element of `tools` - confirm against its definition.
                    register_one = self._register_tool
                    lcall(tools, register_one)
         | 
| 128 113 |  | 
| 129 114 | 
             
                def to_tool_schema_list(self) -> List[Dict[str, Any]]:
         | 
| 130 115 | 
             
                    """
         | 
| 131 | 
            -
                     | 
| 116 | 
            +
                    Generates a list of schemas for all registered tools.
         | 
| 132 117 |  | 
| 133 118 | 
             
                    Returns:
         | 
| 134 119 | 
             
                        List[Dict[str, Any]]: A list of tool schemas.
         | 
| 135 | 
            -
             | 
| 136 | 
            -
                    Examples:
         | 
| 137 | 
            -
                        >>> tool_manager.to_tool_schema_list()
         | 
| 138 | 
            -
                        # Returns a list of registered tool schemas
         | 
| 120 | 
            +
                    
         | 
| 139 121 | 
             
                    """
         | 
| 140 122 | 
             
                    schema_list = []
         | 
| 141 123 | 
             
                    for tool in self.registry.values():
         | 
| @@ -144,21 +126,17 @@ class ToolManager(BaseNode): | |
| 144 126 |  | 
| 145 127 | 
             
                def _tool_parser(self, tools: Union[Dict, Tool, List[Tool], str, List[str], List[Dict]], **kwargs) -> Dict:
         | 
| 146 128 | 
             
                    """
         | 
| 147 | 
            -
                     | 
| 129 | 
            +
                    Parses tool information and generates a dictionary for tool invocation.
         | 
| 148 130 |  | 
| 149 | 
            -
                     | 
| 150 | 
            -
                        tools:  | 
| 131 | 
            +
                    Args:
         | 
| 132 | 
            +
                        tools: Tool information which can be a single Tool instance, a list of Tool instances, a tool name, or a list of tool names.
         | 
| 151 133 | 
             
                        **kwargs: Additional keyword arguments.
         | 
| 152 134 |  | 
| 153 135 | 
             
                    Returns:
         | 
| 154 | 
            -
                        Dict: A dictionary  | 
| 136 | 
            +
                        Dict: A dictionary containing tool schema information and any additional keyword arguments.
         | 
| 155 137 |  | 
| 156 138 | 
             
                    Raises:
         | 
| 157 | 
            -
                        ValueError: If a tool name  | 
| 158 | 
            -
             | 
| 159 | 
            -
                    Examples:
         | 
| 160 | 
            -
                        >>> tool_manager._tool_parser('registered_tool')
         | 
| 161 | 
            -
                        # Returns a dictionary containing the schema of the registered tool
         | 
| 139 | 
            +
                        ValueError: If a tool name is provided that is not registered.
         | 
| 162 140 | 
             
                    """
         | 
| 163 141 | 
             
                    def tool_check(tool):
         | 
| 164 142 | 
             
                        if isinstance(tool, dict):
         | 
| @@ -183,4 +161,3 @@ class ToolManager(BaseNode): | |
| 183 161 | 
             
                        kwargs = {**tool_kwarg, **kwargs}
         | 
| 184 162 |  | 
| 185 163 | 
             
                    return kwargs
         | 
| 186 | 
            -
                
         | 
    
        lionagi/utils/__init__.py
    CHANGED
    
    | @@ -1,7 +1,8 @@ | |
| 1 1 | 
             
            from .sys_util import (
         | 
| 2 2 | 
             
                get_timestamp, create_copy, create_path, split_path, 
         | 
| 3 3 | 
             
                get_bins, change_dict_key, str_to_num, create_id, 
         | 
| 4 | 
            -
                as_dict, is_package_installed, install_import | 
| 4 | 
            +
                as_dict, is_package_installed, install_import, to_df
         | 
| 5 | 
            +
                )
         | 
| 5 6 |  | 
| 6 7 | 
             
            from .nested_util import (
         | 
| 7 8 | 
             
                to_readable_dict, nfilter, nset, nget, 
         | 
| @@ -19,7 +20,7 @@ from .call_util import ( | |
| 19 20 |  | 
| 20 21 |  | 
| 21 22 | 
             
            __all__ = [
         | 
| 22 | 
            -
                "is_package_installed", "install_import",
         | 
| 23 | 
            +
                "is_package_installed", "install_import", "to_df",
         | 
| 23 24 | 
             
                'get_timestamp', 'create_copy', 'create_path', 'split_path',
         | 
| 24 25 | 
             
                'get_bins', 'change_dict_key', 'str_to_num', 'create_id',
         | 
| 25 26 | 
             
                'as_dict', 'to_list', 'to_readable_dict', 'nfilter', 'nset',
         | 
    
        lionagi/utils/call_util.py
    CHANGED
    
    | @@ -92,7 +92,7 @@ def is_coroutine_func(func: Callable) -> bool: | |
| 92 92 | 
             
                return asyncio.iscoroutinefunction(func)
         | 
| 93 93 |  | 
| 94 94 | 
             
            async def alcall(
                input: Any = None, func: Callable = None, flatten: bool = False, **kwargs
            )-> List[Any]:
                """
                Asynchronously apply a function to each element in the input.

                When ``input`` is None, ``func`` is called exactly once with only
                ``kwargs``. Any other value, including an empty list, ``0`` or ``""``,
                is treated as real input and mapped over; an empty input therefore
                schedules no calls instead of one argument-less call.

                Args:
                    input (Any): The value(s) to map over, or None to call ``func`` with keyword arguments only.
                    func (Callable): A callable returning an awaitable; its results are collected with ``asyncio.gather``.
                    flatten (bool): Forwarded to ``to_list`` when building the final result.
                    **kwargs: Extra keyword arguments passed to every call of ``func``.

                Returns:
                    List[Any]: The gathered results, passed through ``to_list``.

                Examples:
                    >>> asyncio.run(alcall([1, 2, 3], square))
                    [1, 4, 9]
                """
                # Compare against None, not truthiness: a falsy-but-present input such
                # as [] must not fall through to the "no input" branch.
                if input is not None:
                    # to_list is defined elsewhere in this module; presumably it
                    # normalizes a scalar or iterable into a list - confirm.
                    tasks = [func(item, **kwargs) for item in to_list(input=input)]
                else:
                    tasks = [func(**kwargs)]

                outs = await asyncio.gather(*tasks)
                return to_list(outs, flatten=flatten)
         | 
| 118 122 |  | 
| @@ -833,16 +837,14 @@ async def _call_handler( | |
| 833 837 | 
             
                            loop = asyncio.get_running_loop()
         | 
| 834 838 | 
             
                        except RuntimeError:  # No running event loop
         | 
| 835 839 | 
             
                            loop = asyncio.new_event_loop()
         | 
| 836 | 
            -
                            asyncio.set_event_loop(loop)
         | 
| 837 | 
            -
                            # Running the coroutine in the new loop
         | 
| 838 840 | 
             
                            result = loop.run_until_complete(func(*args, **kwargs))
         | 
| 841 | 
            +
                            
         | 
| 839 842 | 
             
                            loop.close()
         | 
| 840 843 | 
             
                            return result
         | 
| 841 844 |  | 
| 842 845 | 
             
                        if loop.is_running():
         | 
| 843 | 
            -
                            return asyncio.ensure_future(func(*args, **kwargs))
         | 
| 844 | 
            -
                        else:
         | 
| 845 846 | 
             
                            return await func(*args, **kwargs)
         | 
| 847 | 
            +
             | 
| 846 848 | 
             
                    else:
         | 
| 847 849 | 
             
                        return func(*args, **kwargs)
         | 
| 848 850 |  |