lionagi 0.0.206__py3-none-any.whl → 0.0.208__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lionagi/_services/ollama.py +2 -2
- lionagi/core/branch/branch.py +517 -265
- lionagi/core/branch/branch_manager.py +0 -1
- lionagi/core/branch/conversation.py +640 -337
- lionagi/core/core_util.py +0 -59
- lionagi/core/sessions/session.py +137 -64
- lionagi/tools/tool_manager.py +39 -62
- lionagi/utils/__init__.py +3 -2
- lionagi/utils/call_util.py +9 -7
- lionagi/utils/sys_util.py +287 -255
- lionagi/version.py +1 -1
- {lionagi-0.0.206.dist-info → lionagi-0.0.208.dist-info}/METADATA +1 -1
- {lionagi-0.0.206.dist-info → lionagi-0.0.208.dist-info}/RECORD +16 -17
- lionagi/utils/pd_util.py +0 -57
- {lionagi-0.0.206.dist-info → lionagi-0.0.208.dist-info}/LICENSE +0 -0
- {lionagi-0.0.206.dist-info → lionagi-0.0.208.dist-info}/WHEEL +0 -0
- {lionagi-0.0.206.dist-info → lionagi-0.0.208.dist-info}/top_level.txt +0 -0
lionagi/core/core_util.py
CHANGED
@@ -1,59 +0,0 @@
|
|
1
|
-
import json
|
2
|
-
from ..utils.sys_util import strip_lower
|
3
|
-
|
4
|
-
|
5
|
-
def sign_message(messages, sender: str):
|
6
|
-
"""
|
7
|
-
Sign messages with a sender identifier.
|
8
|
-
|
9
|
-
Args:
|
10
|
-
messages (pd.DataFrame): A DataFrame containing messages with columns 'node_id', 'role', 'sender', 'timestamp', and 'content'.
|
11
|
-
sender (str): The sender identifier to be added to the messages.
|
12
|
-
|
13
|
-
Returns:
|
14
|
-
pd.DataFrame: A new DataFrame with the sender identifier added to each message.
|
15
|
-
|
16
|
-
Raises:
|
17
|
-
ValueError: If the 'sender' is None or 'None'.
|
18
|
-
"""
|
19
|
-
if sender is None or strip_lower(sender) == 'none':
|
20
|
-
raise ValueError("sender cannot be None")
|
21
|
-
df = messages.copy()
|
22
|
-
|
23
|
-
for i in df.index:
|
24
|
-
if not df.loc[i, 'content'].startswith('Sender'):
|
25
|
-
df.loc[i, 'content'] = f"Sender {sender}: {df.loc[i, 'content']}"
|
26
|
-
else:
|
27
|
-
content = df.loc[i, 'content'].split(':', 1)[1]
|
28
|
-
df.loc[i, 'content'] = f"Sender {sender}: {content}"
|
29
|
-
return df
|
30
|
-
|
31
|
-
|
32
|
-
def validate_messages(messages):
|
33
|
-
"""
|
34
|
-
Validate the structure and content of a messages DataFrame.
|
35
|
-
|
36
|
-
Args:
|
37
|
-
messages (pd.DataFrame): A DataFrame containing messages with columns 'node_id', 'role', 'sender', 'timestamp', and 'content'.
|
38
|
-
|
39
|
-
Returns:
|
40
|
-
bool: True if the messages DataFrame is valid; otherwise, raises a ValueError.
|
41
|
-
|
42
|
-
Raises:
|
43
|
-
ValueError: If the DataFrame structure is invalid or if it contains null values, roles other than ["system", "user", "assistant"],
|
44
|
-
or content that cannot be parsed as JSON strings.
|
45
|
-
"""
|
46
|
-
if list(messages.columns) != ['node_id', 'role', 'sender', 'timestamp', 'content']:
|
47
|
-
raise ValueError('Invalid messages dataframe. Unmatched columns.')
|
48
|
-
if messages.isnull().values.any():
|
49
|
-
raise ValueError('Invalid messages dataframe. Cannot have null.')
|
50
|
-
if not all(role in ['system', 'user', 'assistant'] for role in messages['role'].unique()):
|
51
|
-
raise ValueError('Invalid messages dataframe. Cannot have role other than ["system", "user", "assistant"].')
|
52
|
-
for cont in messages['content']:
|
53
|
-
if cont.startswith('Sender'):
|
54
|
-
cont = cont.split(':', 1)[1]
|
55
|
-
try:
|
56
|
-
json.loads(cont)
|
57
|
-
except:
|
58
|
-
raise ValueError('Invalid messages dataframe. Content expect json string.')
|
59
|
-
return True
|
lionagi/core/sessions/session.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1
1
|
import pandas as pd
|
2
|
+
|
2
3
|
from typing import Any, List, Union, Dict, Optional, Callable, Tuple
|
3
4
|
from dotenv import load_dotenv
|
4
5
|
|
@@ -8,30 +9,26 @@ from ..messages.messages import System, Instruction
|
|
8
9
|
from ..branch.branch import Branch
|
9
10
|
from ..branch.branch_manager import BranchManager
|
10
11
|
|
11
|
-
|
12
12
|
load_dotenv()
|
13
13
|
|
14
14
|
|
15
15
|
class Session:
|
16
16
|
"""
|
17
|
-
|
17
|
+
Manages sessions with conversation branches, tool management, and interaction logging.
|
18
18
|
|
19
|
-
This class
|
20
|
-
messages, instruction sets, and tools. It also handles logging and interactions with an external service.
|
19
|
+
This class orchestrates the handling of different conversation branches, enabling distinct conversational contexts to coexist within a single session. It facilitates the integration with external services for processing chat completions, tool management, and the logging of session activities.
|
21
20
|
|
22
21
|
Attributes:
|
23
|
-
branches (Dict[str, Branch]):
|
24
|
-
default_branch (Branch): The
|
25
|
-
default_branch_name (str):
|
26
|
-
llmconfig (Dict[str, Any]):
|
27
|
-
|
28
|
-
service (OpenAIService): Service used for handling chat completions and other operations.
|
22
|
+
branches (Dict[str, Branch]): Maps branch names to Branch instances.
|
23
|
+
default_branch (Branch): The primary branch for the session.
|
24
|
+
default_branch_name (str): Identifier for the default branch.
|
25
|
+
llmconfig (Dict[str, Any]): Configurations for language model interactions.
|
26
|
+
service (OpenAIService): Interface for external service interactions.
|
29
27
|
"""
|
30
28
|
def __init__(
|
31
29
|
self,
|
32
30
|
system: Optional[Union[str, System]] = None,
|
33
31
|
sender: Optional[str] = None,
|
34
|
-
dir: Optional[str] = None,
|
35
32
|
llmconfig: Optional[Dict[str, Any]] = None,
|
36
33
|
service: OpenAIService = None,
|
37
34
|
branches: Optional[Dict[str, Branch]] = None,
|
@@ -39,36 +36,26 @@ class Session:
|
|
39
36
|
default_branch_name: str = 'main',
|
40
37
|
):
|
41
38
|
"""
|
42
|
-
|
39
|
+
Initializes a session with optional settings for branches, service, and language model configurations.
|
43
40
|
|
44
41
|
Args:
|
45
|
-
system (Union[str, System]): Initial system message or
|
46
|
-
|
47
|
-
llmconfig (Dict[str, Any]
|
48
|
-
service (OpenAIService
|
49
|
-
branches (Dict[str, Branch]
|
50
|
-
default_branch (Branch
|
51
|
-
default_branch_name (str
|
42
|
+
system (Optional[Union[str, System]]): Initial system message or configuration.
|
43
|
+
sender (Optional[str]): Identifier for the sender of the system message.
|
44
|
+
llmconfig (Optional[Dict[str, Any]]): Language model configuration settings.
|
45
|
+
service (OpenAIService): External service for chat completions and other operations.
|
46
|
+
branches (Optional[Dict[str, Branch]]): Predefined conversation branches.
|
47
|
+
default_branch (Optional[Branch]): Preselected default branch for the session.
|
48
|
+
default_branch_name (str): Name for the default branch, defaults to 'main'.
|
52
49
|
"""
|
53
50
|
|
54
51
|
self.branches = branches if isinstance(branches, dict) else {}
|
55
52
|
if service is None:
|
56
53
|
service = OpenAIService()
|
57
|
-
|
58
|
-
self.default_branch = default_branch if default_branch else Branch(name=default_branch_name, service=service, llmconfig=llmconfig)
|
59
|
-
self.default_branch_name = default_branch_name
|
60
|
-
if system:
|
61
|
-
self.default_branch.add_message(system=system, sender=sender)
|
62
|
-
if self.branches:
|
63
|
-
if self.default_branch_name not in self.branches.keys():
|
64
|
-
raise ValueError('default branch name is not in imported branches')
|
65
|
-
if self.default_branch is not self.branches[self.default_branch_name]:
|
66
|
-
raise ValueError(f'default branch does not match Branch object under {self.default_branch_name}')
|
67
|
-
if not self.branches:
|
68
|
-
self.branches[self.default_branch_name] = self.default_branch
|
69
|
-
if dir:
|
70
|
-
self.default_branch.dir = dir
|
71
54
|
|
55
|
+
self._setup_default_branch(
|
56
|
+
default_branch, default_branch_name, service, llmconfig, system, sender)
|
57
|
+
|
58
|
+
self._verify_default_branch()
|
72
59
|
self.branch_manager = BranchManager(self.branches)
|
73
60
|
|
74
61
|
def new_branch(
|
@@ -83,20 +70,20 @@ class Session:
|
|
83
70
|
llmconfig: Optional[Dict] = None,
|
84
71
|
) -> None:
|
85
72
|
"""
|
86
|
-
|
73
|
+
Creates a new branch within the session.
|
87
74
|
|
88
75
|
Args:
|
89
|
-
branch_name (str): Name
|
90
|
-
dir (str
|
91
|
-
messages (Optional[pd.DataFrame]):
|
92
|
-
|
93
|
-
|
94
|
-
sender (Optional[str]
|
95
|
-
service (OpenAIService
|
96
|
-
llmconfig (Dict[str, Any]
|
76
|
+
branch_name (str): Name for the new branch.
|
77
|
+
dir (Optional[str]): Path for storing branch-related logs.
|
78
|
+
messages (Optional[pd.DataFrame]): Initial set of messages for the branch.
|
79
|
+
tools (Optional[Union[Tool, List[Tool]]]): Tools to register in the new branch.
|
80
|
+
system (Optional[Union[str, System]]): System message or configuration for the branch.
|
81
|
+
sender (Optional[str]): Identifier for the sender of the initial message.
|
82
|
+
service (Optional[OpenAIService]): Service interface specific to the branch.
|
83
|
+
llmconfig (Optional[Dict[str, Any]]): Language model configurations for the branch.
|
97
84
|
|
98
85
|
Raises:
|
99
|
-
ValueError: If the branch name already exists
|
86
|
+
ValueError: If the branch name already exists within the session.
|
100
87
|
"""
|
101
88
|
if branch_name in self.branches.keys():
|
102
89
|
raise ValueError(f'Invalid new branch name {branch_name}. Already existed.')
|
@@ -116,18 +103,19 @@ class Session:
|
|
116
103
|
get_name: bool = False
|
117
104
|
) -> Union[Branch, Tuple[Branch, str]]:
|
118
105
|
"""
|
119
|
-
|
106
|
+
Retrieves a branch from the session by name or as a Branch object.
|
107
|
+
|
108
|
+
If no branch is specified, returns the default branch. Optionally, can also return the branch's name.
|
120
109
|
|
121
110
|
Args:
|
122
|
-
branch (Optional[Union[Branch, str]]
|
123
|
-
|
124
|
-
get_name (bool, optional): If True, returns the name of the branch along with the branch object.
|
111
|
+
branch (Optional[Union[Branch, str]]): The branch name or Branch object to retrieve. Defaults to None, which refers to the default branch.
|
112
|
+
get_name (bool): If True, also returns the name of the branch alongside the Branch object.
|
125
113
|
|
126
114
|
Returns:
|
127
|
-
Union[Branch, Tuple[Branch, str]]: The
|
115
|
+
Union[Branch, Tuple[Branch, str]]: The requested Branch object, or a tuple of the Branch object and its name if `get_name` is True.
|
128
116
|
|
129
117
|
Raises:
|
130
|
-
ValueError: If the branch does not exist
|
118
|
+
ValueError: If the specified branch does not exist within the session.
|
131
119
|
"""
|
132
120
|
if isinstance(branch, str):
|
133
121
|
if branch not in self.branches.keys():
|
@@ -152,10 +140,10 @@ class Session:
|
|
152
140
|
|
153
141
|
def change_default(self, branch: Union[str, Branch]) -> None:
|
154
142
|
"""
|
155
|
-
|
143
|
+
Changes the default branch of the session.
|
156
144
|
|
157
145
|
Args:
|
158
|
-
branch (Union[str, Branch]): The branch or
|
146
|
+
branch (Union[str, Branch]): The branch name or Branch object to set as the new default branch.
|
159
147
|
"""
|
160
148
|
branch_, name_ = self.get_branch(branch, get_name=True)
|
161
149
|
self.default_branch = branch_
|
@@ -265,6 +253,45 @@ class Session:
|
|
265
253
|
instruction=instruction, system=system, context=context,
|
266
254
|
out=out, sender=sender, invoke=invoke, tools=tools, **kwargs)
|
267
255
|
|
256
|
+
async def ReAct(
|
257
|
+
self,
|
258
|
+
instruction: Union[Instruction, str],
|
259
|
+
context = None,
|
260
|
+
sender = None,
|
261
|
+
to_ = None,
|
262
|
+
system = None,
|
263
|
+
tools = None,
|
264
|
+
num_rounds: int = 1,
|
265
|
+
fallback: Optional[Callable] = None,
|
266
|
+
fallback_kwargs: Optional[Dict] = None,
|
267
|
+
out=True,
|
268
|
+
**kwargs
|
269
|
+
):
|
270
|
+
"""
|
271
|
+
Performs a sequence of reasoning and action steps in a specified or default branch.
|
272
|
+
|
273
|
+
Args:
|
274
|
+
instruction (Union[Instruction, str]): Instruction to initiate the ReAct process.
|
275
|
+
context: Additional context for reasoning and action. Defaults to None.
|
276
|
+
sender: Identifier for the sender. Defaults to None.
|
277
|
+
to_: Target branch name or object for ReAct. Defaults to the default branch.
|
278
|
+
system: System message or configuration. Defaults to None.
|
279
|
+
tools: Tools to be used for actions. Defaults to None.
|
280
|
+
num_rounds (int): Number of reasoning-action cycles. Defaults to 1.
|
281
|
+
fallback (Optional[Callable]): Fallback function in case of an error. Defaults to None.
|
282
|
+
fallback_kwargs (Optional[Dict]): Arguments for the fallback function. Defaults to None.
|
283
|
+
out (bool): If True, outputs the result of the ReAct process. Defaults to True.
|
284
|
+
**kwargs: Arbitrary keyword arguments for additional customization.
|
285
|
+
|
286
|
+
Returns:
|
287
|
+
The outcome of the ReAct process, depending on the specified branch and instructions.
|
288
|
+
"""
|
289
|
+
branch = self.get_branch(to_)
|
290
|
+
return await branch.ReAct(
|
291
|
+
instruction=instruction, context=context, sender=sender, system=system, tools=tools,
|
292
|
+
num_rounds=num_rounds, fallback=fallback, fallback_kwargs=fallback_kwargs,
|
293
|
+
out=out, **kwargs
|
294
|
+
)
|
268
295
|
|
269
296
|
async def auto_followup(
|
270
297
|
self,
|
@@ -293,16 +320,8 @@ class Session:
|
|
293
320
|
"""
|
294
321
|
|
295
322
|
branch_ = self.get_branch(to_)
|
296
|
-
if fallback:
|
297
|
-
try:
|
298
|
-
return await branch_.auto_followup(
|
299
|
-
instruction=instruction, num=num, tools=tools,**kwargs
|
300
|
-
)
|
301
|
-
except:
|
302
|
-
return fallback(**fallback_kwargs)
|
303
|
-
|
304
323
|
return await branch_.auto_followup(
|
305
|
-
instruction=instruction, num=num, tools=tools
|
324
|
+
instruction=instruction, num=num, tools=tools, fallback=fallback, fallback_kwargs=fallback_kwargs, **kwargs
|
306
325
|
)
|
307
326
|
|
308
327
|
def change_first_system_message(self, system: Union[System, str]) -> None:
|
@@ -369,7 +388,7 @@ class Session:
|
|
369
388
|
|
370
389
|
def register_tools(self, tools: Union[Tool, List[Tool]]) -> None:
|
371
390
|
"""
|
372
|
-
Registers one or more tools to the current
|
391
|
+
Registers one or more tools to the current default branch.
|
373
392
|
|
374
393
|
Args:
|
375
394
|
tools (Union[Tool, List[Tool]]): The tool or list of tools to register.
|
@@ -378,7 +397,7 @@ class Session:
|
|
378
397
|
|
379
398
|
def delete_tool(self, name: str) -> bool:
|
380
399
|
"""
|
381
|
-
Deletes a tool from the current
|
400
|
+
Deletes a tool from the current default branch.
|
382
401
|
|
383
402
|
Args:
|
384
403
|
name (str): The name of the tool to delete.
|
@@ -391,10 +410,10 @@ class Session:
|
|
391
410
|
@property
|
392
411
|
def describe(self) -> Dict[str, Any]:
|
393
412
|
"""
|
394
|
-
Generates a report of the current
|
413
|
+
Generates a report of the current default branch.
|
395
414
|
|
396
415
|
Returns:
|
397
|
-
Dict[str, Any]: The report of the current
|
416
|
+
Dict[str, Any]: The report of the current default branch.
|
398
417
|
"""
|
399
418
|
return self.default_branch.describe
|
400
419
|
|
@@ -408,6 +427,60 @@ class Session:
|
|
408
427
|
"""
|
409
428
|
return self.default_branch.messages
|
410
429
|
|
430
|
+
|
431
|
+
@property
|
432
|
+
def first_system(self) -> pd.Series:
|
433
|
+
"""
|
434
|
+
Get the first system message of the current default branch.
|
435
|
+
|
436
|
+
Returns:
|
437
|
+
System: The first system message of the current default branch.
|
438
|
+
"""
|
439
|
+
return self.default_branch.first_system
|
440
|
+
|
441
|
+
@property
|
442
|
+
def last_response(self) -> pd.Series:
|
443
|
+
"""
|
444
|
+
Get the last response message of the current default branch.
|
445
|
+
|
446
|
+
Returns:
|
447
|
+
str: The last response message of the current default branch.
|
448
|
+
"""
|
449
|
+
return self.default_branch.last_response
|
450
|
+
|
451
|
+
|
452
|
+
@property
|
453
|
+
def last_response_content(self) -> Dict:
|
454
|
+
"""
|
455
|
+
Get the last response content of the current default branch.
|
456
|
+
|
457
|
+
Returns:
|
458
|
+
Dict: The last response content of the current default branch.
|
459
|
+
"""
|
460
|
+
return self.default_branch.last_response_content
|
461
|
+
|
462
|
+
|
463
|
+
def _verify_default_branch(self):
|
464
|
+
if self.branches:
|
465
|
+
if self.default_branch_name not in self.branches.keys():
|
466
|
+
raise ValueError('default branch name is not in imported branches')
|
467
|
+
if self.default_branch is not self.branches[self.default_branch_name]:
|
468
|
+
raise ValueError(f'default branch does not match Branch object under {self.default_branch_name}')
|
469
|
+
|
470
|
+
if not self.branches:
|
471
|
+
self.branches[self.default_branch_name] = self.default_branch
|
472
|
+
|
473
|
+
def _setup_default_branch(
|
474
|
+
self, default_branch, default_branch_name, service, llmconfig, system, sender
|
475
|
+
):
|
476
|
+
self.default_branch = default_branch if default_branch else Branch(
|
477
|
+
name=default_branch_name, service=service, llmconfig=llmconfig
|
478
|
+
)
|
479
|
+
self.default_branch_name = default_branch_name
|
480
|
+
if system:
|
481
|
+
self.default_branch.add_message(system=system, sender=sender)
|
482
|
+
|
483
|
+
self.llmconfig = self.default_branch.llmconfig
|
411
484
|
# def add_instruction_set(self, name: str, instruction_set: InstructionSet) -> None:
|
412
485
|
# """
|
413
486
|
# Adds an instruction set to the current active branch.
|
lionagi/tools/tool_manager.py
CHANGED
@@ -1,50 +1,44 @@
|
|
1
1
|
import json
|
2
2
|
import asyncio
|
3
3
|
from typing import Dict, Union, List, Tuple, Any
|
4
|
-
from lionagi.utils.call_util import lcall, is_coroutine_func, _call_handler
|
4
|
+
from lionagi.utils.call_util import lcall, is_coroutine_func, _call_handler, alcall
|
5
5
|
from lionagi.schema import BaseNode, Tool
|
6
6
|
|
7
7
|
|
8
8
|
class ToolManager(BaseNode):
|
9
9
|
"""
|
10
|
-
A manager class
|
10
|
+
A manager class for handling the registration and invocation of tools that are subclasses of Tool.
|
11
11
|
|
12
|
+
This class maintains a registry of tool instances, allowing for dynamic invocation based on
|
13
|
+
tool name and provided arguments. It supports both synchronous and asynchronous tool function
|
14
|
+
calls.
|
15
|
+
|
12
16
|
Attributes:
|
13
|
-
registry (Dict[str, Tool]): A dictionary to hold registered tools,
|
17
|
+
registry (Dict[str, Tool]): A dictionary to hold registered tools, keyed by their names.
|
14
18
|
"""
|
15
19
|
registry: Dict = {}
|
16
20
|
|
17
21
|
def name_existed(self, name: str) -> bool:
|
18
22
|
"""
|
19
|
-
|
23
|
+
Checks if a tool name already exists in the registry.
|
20
24
|
|
21
|
-
|
25
|
+
Args:
|
22
26
|
name (str): The name of the tool to check.
|
23
27
|
|
24
28
|
Returns:
|
25
29
|
bool: True if the name exists, False otherwise.
|
26
|
-
|
27
|
-
Examples:
|
28
|
-
>>> tool_manager.name_existed('existing_tool')
|
29
|
-
True
|
30
|
-
>>> tool_manager.name_existed('nonexistent_tool')
|
31
|
-
False
|
32
30
|
"""
|
33
31
|
return True if name in self.registry.keys() else False
|
34
32
|
|
35
33
|
def _register_tool(self, tool: Tool) -> None:
|
36
34
|
"""
|
37
|
-
|
35
|
+
Registers a tool in the registry. Raises a TypeError if the object is not an instance of Tool.
|
38
36
|
|
39
|
-
|
37
|
+
Args:
|
40
38
|
tool (Tool): The tool instance to register.
|
41
39
|
|
42
40
|
Raises:
|
43
|
-
TypeError: If the provided
|
44
|
-
|
45
|
-
Examples:
|
46
|
-
>>> tool_manager._register_tool(Tool())
|
47
|
-
# Tool is registered without any output
|
41
|
+
TypeError: If the provided object is not an instance of Tool.
|
48
42
|
"""
|
49
43
|
if not isinstance(tool, Tool):
|
50
44
|
raise TypeError('Please register a Tool object.')
|
@@ -53,20 +47,16 @@ class ToolManager(BaseNode):
|
|
53
47
|
|
54
48
|
async def invoke(self, func_call: Tuple[str, Dict[str, Any]]) -> Any:
|
55
49
|
"""
|
56
|
-
|
50
|
+
Invokes a registered tool's function with the given arguments. Supports both coroutine and regular functions.
|
57
51
|
|
58
52
|
Args:
|
59
|
-
func_call (Tuple[str, Dict[str, Any]]): A tuple containing the
|
53
|
+
func_call (Tuple[str, Dict[str, Any]]): A tuple containing the function name and a dictionary of keyword arguments.
|
60
54
|
|
61
55
|
Returns:
|
62
|
-
Any: The result of the
|
56
|
+
Any: The result of the function call.
|
63
57
|
|
64
58
|
Raises:
|
65
|
-
ValueError: If the function is not registered or an error
|
66
|
-
|
67
|
-
Examples:
|
68
|
-
>>> await tool_manager.invoke(('registered_function', {'arg1': 'value1'}))
|
69
|
-
# Result of the registered_function with given arguments
|
59
|
+
ValueError: If the function name is not registered or if there's an error during function invocation.
|
70
60
|
"""
|
71
61
|
name, kwargs = func_call
|
72
62
|
if self.name_existed(name):
|
@@ -74,10 +64,13 @@ class ToolManager(BaseNode):
|
|
74
64
|
func = tool.func
|
75
65
|
parser = tool.parser
|
76
66
|
try:
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
67
|
+
if is_coroutine_func(func):
|
68
|
+
tasks = [_call_handler(func, **kwargs)]
|
69
|
+
out = await asyncio.gather(*tasks)
|
70
|
+
return parser(out[0]) if parser else out[0]
|
71
|
+
else:
|
72
|
+
out = func(**kwargs)
|
73
|
+
return parser(out) if parser else out
|
81
74
|
except Exception as e:
|
82
75
|
raise ValueError(f"Error when invoking function {name} with arguments {kwargs} with error message {e}")
|
83
76
|
else:
|
@@ -86,20 +79,16 @@ class ToolManager(BaseNode):
|
|
86
79
|
@staticmethod
|
87
80
|
def get_function_call(response: Dict) -> Tuple[str, Dict]:
|
88
81
|
"""
|
89
|
-
|
82
|
+
Extracts a function call and arguments from a response dictionary.
|
90
83
|
|
91
|
-
|
92
|
-
response (Dict): The
|
84
|
+
Args:
|
85
|
+
response (Dict): The response dictionary containing the function call information.
|
93
86
|
|
94
87
|
Returns:
|
95
|
-
Tuple[str, Dict]:
|
88
|
+
Tuple[str, Dict]: A tuple containing the function name and a dictionary of arguments.
|
96
89
|
|
97
90
|
Raises:
|
98
|
-
ValueError: If the response
|
99
|
-
|
100
|
-
Examples:
|
101
|
-
>>> ToolManager.get_function_call({"action": "execute_add", "arguments": '{"x":1, "y":2}'})
|
102
|
-
('add', {'x': 1, 'y': 2})
|
91
|
+
ValueError: If the response does not contain valid function call information.
|
103
92
|
"""
|
104
93
|
try:
|
105
94
|
func = response['action'][7:]
|
@@ -115,27 +104,20 @@ class ToolManager(BaseNode):
|
|
115
104
|
|
116
105
|
def register_tools(self, tools: List[Tool]) -> None:
|
117
106
|
"""
|
118
|
-
|
107
|
+
Registers multiple tools in the registry.
|
119
108
|
|
120
|
-
|
121
|
-
tools (List[Tool]): A list of
|
122
|
-
|
123
|
-
Examples:
|
124
|
-
>>> tool_manager.register_tools([Tool(), Tool()])
|
125
|
-
# Multiple Tool instances registered
|
109
|
+
Args:
|
110
|
+
tools (List[Tool]): A list of tool instances to register.
|
126
111
|
"""
|
127
|
-
lcall(tools, self._register_tool)
|
112
|
+
lcall(tools, self._register_tool)
|
128
113
|
|
129
114
|
def to_tool_schema_list(self) -> List[Dict[str, Any]]:
|
130
115
|
"""
|
131
|
-
|
116
|
+
Generates a list of schemas for all registered tools.
|
132
117
|
|
133
118
|
Returns:
|
134
119
|
List[Dict[str, Any]]: A list of tool schemas.
|
135
|
-
|
136
|
-
Examples:
|
137
|
-
>>> tool_manager.to_tool_schema_list()
|
138
|
-
# Returns a list of registered tool schemas
|
120
|
+
|
139
121
|
"""
|
140
122
|
schema_list = []
|
141
123
|
for tool in self.registry.values():
|
@@ -144,21 +126,17 @@ class ToolManager(BaseNode):
|
|
144
126
|
|
145
127
|
def _tool_parser(self, tools: Union[Dict, Tool, List[Tool], str, List[str], List[Dict]], **kwargs) -> Dict:
|
146
128
|
"""
|
147
|
-
|
129
|
+
Parses tool information and generates a dictionary for tool invocation.
|
148
130
|
|
149
|
-
|
150
|
-
tools:
|
131
|
+
Args:
|
132
|
+
tools: Tool information which can be a single Tool instance, a list of Tool instances, a tool name, or a list of tool names.
|
151
133
|
**kwargs: Additional keyword arguments.
|
152
134
|
|
153
135
|
Returns:
|
154
|
-
Dict: A dictionary
|
136
|
+
Dict: A dictionary containing tool schema information and any additional keyword arguments.
|
155
137
|
|
156
138
|
Raises:
|
157
|
-
ValueError: If a tool name
|
158
|
-
|
159
|
-
Examples:
|
160
|
-
>>> tool_manager._tool_parser('registered_tool')
|
161
|
-
# Returns a dictionary containing the schema of the registered tool
|
139
|
+
ValueError: If a tool name is provided that is not registered.
|
162
140
|
"""
|
163
141
|
def tool_check(tool):
|
164
142
|
if isinstance(tool, dict):
|
@@ -183,4 +161,3 @@ class ToolManager(BaseNode):
|
|
183
161
|
kwargs = {**tool_kwarg, **kwargs}
|
184
162
|
|
185
163
|
return kwargs
|
186
|
-
|
lionagi/utils/__init__.py
CHANGED
@@ -1,7 +1,8 @@
|
|
1
1
|
from .sys_util import (
|
2
2
|
get_timestamp, create_copy, create_path, split_path,
|
3
3
|
get_bins, change_dict_key, str_to_num, create_id,
|
4
|
-
as_dict, is_package_installed, install_import
|
4
|
+
as_dict, is_package_installed, install_import, to_df
|
5
|
+
)
|
5
6
|
|
6
7
|
from .nested_util import (
|
7
8
|
to_readable_dict, nfilter, nset, nget,
|
@@ -19,7 +20,7 @@ from .call_util import (
|
|
19
20
|
|
20
21
|
|
21
22
|
__all__ = [
|
22
|
-
"is_package_installed", "install_import",
|
23
|
+
"is_package_installed", "install_import", "to_df",
|
23
24
|
'get_timestamp', 'create_copy', 'create_path', 'split_path',
|
24
25
|
'get_bins', 'change_dict_key', 'str_to_num', 'create_id',
|
25
26
|
'as_dict', 'to_list', 'to_readable_dict', 'nfilter', 'nset',
|
lionagi/utils/call_util.py
CHANGED
@@ -92,7 +92,7 @@ def is_coroutine_func(func: Callable) -> bool:
|
|
92
92
|
return asyncio.iscoroutinefunction(func)
|
93
93
|
|
94
94
|
async def alcall(
|
95
|
-
input: Any, func: Callable, flatten: bool = False, **kwargs
|
95
|
+
input: Any = None, func: Callable = None, flatten: bool = False, **kwargs
|
96
96
|
)-> List[Any]:
|
97
97
|
"""
|
98
98
|
Asynchronously apply a function to each element in the input.
|
@@ -111,8 +111,12 @@ async def alcall(
|
|
111
111
|
>>> asyncio.run(alcall([1, 2, 3], square))
|
112
112
|
[1, 4, 9]
|
113
113
|
"""
|
114
|
-
|
115
|
-
|
114
|
+
if input:
|
115
|
+
lst = to_list(input=input)
|
116
|
+
tasks = [func(i, **kwargs) for i in lst]
|
117
|
+
else:
|
118
|
+
tasks = [func(**kwargs)]
|
119
|
+
|
116
120
|
outs = await asyncio.gather(*tasks)
|
117
121
|
return to_list(outs, flatten=flatten)
|
118
122
|
|
@@ -833,16 +837,14 @@ async def _call_handler(
|
|
833
837
|
loop = asyncio.get_running_loop()
|
834
838
|
except RuntimeError: # No running event loop
|
835
839
|
loop = asyncio.new_event_loop()
|
836
|
-
asyncio.set_event_loop(loop)
|
837
|
-
# Running the coroutine in the new loop
|
838
840
|
result = loop.run_until_complete(func(*args, **kwargs))
|
841
|
+
|
839
842
|
loop.close()
|
840
843
|
return result
|
841
844
|
|
842
845
|
if loop.is_running():
|
843
|
-
return asyncio.ensure_future(func(*args, **kwargs))
|
844
|
-
else:
|
845
846
|
return await func(*args, **kwargs)
|
847
|
+
|
846
848
|
else:
|
847
849
|
return func(*args, **kwargs)
|
848
850
|
|