lionagi 0.0.306__py3-none-any.whl → 0.0.307__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- lionagi/__init__.py +2 -5
- lionagi/core/__init__.py +7 -5
- lionagi/core/agent/__init__.py +3 -0
- lionagi/core/agent/base_agent.py +10 -12
- lionagi/core/branch/__init__.py +4 -0
- lionagi/core/branch/base_branch.py +81 -81
- lionagi/core/branch/branch.py +16 -28
- lionagi/core/branch/branch_flow_mixin.py +3 -7
- lionagi/core/branch/executable_branch.py +86 -56
- lionagi/core/branch/util.py +77 -162
- lionagi/core/{flow/direct → direct}/__init__.py +1 -1
- lionagi/core/{flow/direct/predict.py → direct/parallel_predict.py} +39 -17
- lionagi/core/direct/parallel_react.py +0 -0
- lionagi/core/direct/parallel_score.py +0 -0
- lionagi/core/direct/parallel_select.py +0 -0
- lionagi/core/direct/parallel_sentiment.py +0 -0
- lionagi/core/direct/predict.py +174 -0
- lionagi/core/{flow/direct → direct}/react.py +2 -2
- lionagi/core/{flow/direct → direct}/score.py +28 -23
- lionagi/core/{flow/direct → direct}/select.py +48 -45
- lionagi/core/direct/utils.py +83 -0
- lionagi/core/flow/monoflow/ReAct.py +6 -5
- lionagi/core/flow/monoflow/__init__.py +9 -0
- lionagi/core/flow/monoflow/chat.py +10 -10
- lionagi/core/flow/monoflow/chat_mixin.py +11 -10
- lionagi/core/flow/monoflow/followup.py +6 -5
- lionagi/core/flow/polyflow/__init__.py +1 -0
- lionagi/core/flow/polyflow/chat.py +15 -3
- lionagi/core/mail/mail_manager.py +18 -19
- lionagi/core/mail/schema.py +5 -4
- lionagi/core/messages/schema.py +18 -20
- lionagi/core/prompt/__init__.py +0 -0
- lionagi/core/prompt/prompt_template.py +0 -0
- lionagi/core/schema/__init__.py +2 -2
- lionagi/core/schema/action_node.py +11 -3
- lionagi/core/schema/base_mixin.py +56 -59
- lionagi/core/schema/base_node.py +35 -38
- lionagi/core/schema/condition.py +24 -0
- lionagi/core/schema/data_logger.py +96 -99
- lionagi/core/schema/data_node.py +19 -19
- lionagi/core/schema/prompt_template.py +0 -0
- lionagi/core/schema/structure.py +171 -169
- lionagi/core/session/__init__.py +1 -3
- lionagi/core/session/session.py +196 -214
- lionagi/core/tool/tool_manager.py +95 -103
- lionagi/integrations/__init__.py +1 -3
- lionagi/integrations/bridge/langchain_/documents.py +17 -18
- lionagi/integrations/bridge/langchain_/langchain_bridge.py +14 -14
- lionagi/integrations/bridge/llamaindex_/llama_index_bridge.py +22 -22
- lionagi/integrations/bridge/llamaindex_/node_parser.py +12 -12
- lionagi/integrations/bridge/llamaindex_/reader.py +11 -11
- lionagi/integrations/bridge/llamaindex_/textnode.py +7 -7
- lionagi/integrations/config/openrouter_configs.py +0 -1
- lionagi/integrations/provider/oai.py +26 -26
- lionagi/integrations/provider/services.py +38 -38
- lionagi/libs/__init__.py +34 -1
- lionagi/libs/ln_api.py +211 -221
- lionagi/libs/ln_async.py +53 -60
- lionagi/libs/ln_convert.py +118 -120
- lionagi/libs/ln_dataframe.py +32 -33
- lionagi/libs/ln_func_call.py +334 -342
- lionagi/libs/ln_nested.py +99 -107
- lionagi/libs/ln_parse.py +161 -165
- lionagi/libs/sys_util.py +52 -52
- lionagi/tests/test_core/test_session.py +254 -266
- lionagi/tests/test_core/test_session_base_util.py +299 -300
- lionagi/tests/test_core/test_tool_manager.py +70 -74
- lionagi/tests/test_libs/test_nested.py +2 -7
- lionagi/tests/test_libs/test_parse.py +2 -2
- lionagi/version.py +1 -1
- {lionagi-0.0.306.dist-info → lionagi-0.0.307.dist-info}/METADATA +4 -2
- lionagi-0.0.307.dist-info/RECORD +115 -0
- lionagi/core/flow/direct/utils.py +0 -43
- lionagi-0.0.306.dist-info/RECORD +0 -106
- /lionagi/core/{flow/direct → direct}/sentiment.py +0 -0
- {lionagi-0.0.306.dist-info → lionagi-0.0.307.dist-info}/LICENSE +0 -0
- {lionagi-0.0.306.dist-info → lionagi-0.0.307.dist-info}/WHEEL +0 -0
- {lionagi-0.0.306.dist-info → lionagi-0.0.307.dist-info}/top_level.txt +0 -0
lionagi/__init__.py
CHANGED
@@ -6,11 +6,8 @@ import logging
|
|
6
6
|
from .version import __version__
|
7
7
|
from dotenv import load_dotenv
|
8
8
|
|
9
|
-
|
10
|
-
from .
|
11
|
-
from .core import *
|
12
|
-
from .integrations import *
|
13
|
-
|
9
|
+
from .core import direct, Branch, Session, Structure, Tool, BaseAgent
|
10
|
+
from .integrations.provider.services import Services
|
14
11
|
|
15
12
|
logger = logging.getLogger(__name__)
|
16
13
|
logger.setLevel(logging.INFO)
|
lionagi/core/__init__.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1
|
-
from .
|
2
|
-
from .session import Branch, Session
|
3
|
-
|
4
|
-
from .flow import direct
|
1
|
+
from . import *
|
5
2
|
|
6
|
-
|
3
|
+
from .branch import Branch, ExecutableBranch
|
4
|
+
from .session import Session
|
5
|
+
from .schema import Tool, Structure, ActionNode, Relationship
|
6
|
+
from .agent import BaseAgent
|
7
|
+
from .messages import Instruction, System, Response
|
8
|
+
from .tool import func_to_tool
|
lionagi/core/agent/__init__.py
CHANGED
lionagi/core/agent/base_agent.py
CHANGED
@@ -1,27 +1,20 @@
|
|
1
|
-
from collections import deque
|
2
|
-
|
3
1
|
from lionagi.core.mail.schema import StartMail
|
4
2
|
from lionagi.core.schema.base_node import BaseRelatableNode
|
5
3
|
from lionagi.core.mail.mail_manager import MailManager
|
6
4
|
|
7
|
-
|
8
|
-
from lionagi.libs.ln_async import AsyncUtil
|
5
|
+
from lionagi.libs import func_call, AsyncUtil
|
9
6
|
|
10
7
|
|
11
8
|
class BaseAgent(BaseRelatableNode):
|
12
|
-
def __init__(
|
13
|
-
|
14
|
-
structure,
|
15
|
-
executable_class,
|
16
|
-
output_parser=None,
|
17
|
-
executable_class_kwargs={},
|
18
|
-
) -> None:
|
9
|
+
def __init__(self, structure, executable_obj, output_parser=None) -> None:
|
10
|
+
|
19
11
|
super().__init__()
|
20
12
|
self.structure = structure
|
21
|
-
self.executable =
|
13
|
+
self.executable = executable_obj
|
22
14
|
self.start = StartMail()
|
23
15
|
self.mailManager = MailManager([self.structure, self.executable, self.start])
|
24
16
|
self.output_parser = output_parser
|
17
|
+
self.start_context = None
|
25
18
|
|
26
19
|
async def mail_manager_control(self, refresh_time=1):
|
27
20
|
while not self.structure.execute_stop or not self.executable.execute_stop:
|
@@ -29,6 +22,7 @@ class BaseAgent(BaseRelatableNode):
|
|
29
22
|
self.mailManager.execute_stop = True
|
30
23
|
|
31
24
|
async def execute(self, context=None):
|
25
|
+
self.start_context = context
|
32
26
|
self.start.trigger(
|
33
27
|
context=context,
|
34
28
|
structure_id=self.structure.id_,
|
@@ -44,5 +38,9 @@ class BaseAgent(BaseRelatableNode):
|
|
44
38
|
],
|
45
39
|
)
|
46
40
|
|
41
|
+
self.structure.execute_stop = False
|
42
|
+
self.executable.execute_stop = False
|
43
|
+
self.mailManager.execute_stop = False
|
44
|
+
|
47
45
|
if self.output_parser:
|
48
46
|
return self.output_parser(self)
|
lionagi/core/branch/__init__.py
CHANGED
@@ -1,22 +1,19 @@
|
|
1
1
|
from abc import ABC
|
2
2
|
from typing import Any
|
3
3
|
|
4
|
-
from lionagi.libs.sys_util import
|
4
|
+
from lionagi.libs.sys_util import PATH_TYPE
|
5
|
+
from lionagi.libs import convert, dataframe, SysUtil
|
5
6
|
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
from lionagi.core.schema.base_node import BaseRelatableNode
|
10
|
-
from lionagi.core.schema.data_logger import DataLogger, DLog
|
11
|
-
from lionagi.core.messages.schema import (
|
7
|
+
from ..schema.base_node import BaseRelatableNode
|
8
|
+
from ..schema.data_logger import DataLogger, DLog
|
9
|
+
from ..messages.schema import (
|
12
10
|
BranchColumns,
|
13
11
|
System,
|
14
12
|
Response,
|
15
13
|
Instruction,
|
16
14
|
BaseMessage,
|
17
15
|
)
|
18
|
-
from
|
19
|
-
from lionagi.libs.ln_parse import ParseUtil
|
16
|
+
from .util import MessageUtil
|
20
17
|
|
21
18
|
|
22
19
|
class BaseBranch(BaseRelatableNode, ABC):
|
@@ -25,9 +22,9 @@ class BaseBranch(BaseRelatableNode, ABC):
|
|
25
22
|
and logging functionality.
|
26
23
|
|
27
24
|
Attributes:
|
28
|
-
|
29
|
-
|
30
|
-
|
25
|
+
messages (dataframe.ln_DataFrame): Holds the messages in the branch.
|
26
|
+
datalogger (DataLogger): Logs data related to the branch's operation.
|
27
|
+
persist_path (PATH_TYPE): Filesystem path for data persistence.
|
31
28
|
"""
|
32
29
|
|
33
30
|
_columns: list[str] = BranchColumns.COLUMNS.value
|
@@ -49,9 +46,7 @@ class BaseBranch(BaseRelatableNode, ABC):
|
|
49
46
|
else:
|
50
47
|
self.messages = dataframe.ln_DataFrame(columns=self._columns)
|
51
48
|
|
52
|
-
self.datalogger = (
|
53
|
-
datalogger if datalogger else DataLogger(persist_path=persist_path)
|
54
|
-
)
|
49
|
+
self.datalogger = datalogger or DataLogger(persist_path=persist_path)
|
55
50
|
self.name = name
|
56
51
|
|
57
52
|
def add_message(
|
@@ -68,11 +63,11 @@ class BaseBranch(BaseRelatableNode, ABC):
|
|
68
63
|
Adds a message to the branch.
|
69
64
|
|
70
65
|
Args:
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
66
|
+
system: Information for creating a System message.
|
67
|
+
instruction: Information for creating an Instruction message.
|
68
|
+
context: Context information for the message.
|
69
|
+
response: Response data for creating a message.
|
70
|
+
**kwargs: Additional keyword arguments for message creation.
|
76
71
|
"""
|
77
72
|
_msg = MessageUtil.create_message(
|
78
73
|
system=system,
|
@@ -87,15 +82,20 @@ class BaseBranch(BaseRelatableNode, ABC):
|
|
87
82
|
if isinstance(_msg, System):
|
88
83
|
self.system_node = _msg
|
89
84
|
|
85
|
+
# sourcery skip: merge-nested-ifs
|
90
86
|
if isinstance(_msg, Instruction):
|
91
87
|
if recipient is None and self.name is not None:
|
92
88
|
_msg.recipient = self.name
|
93
89
|
|
94
90
|
if isinstance(_msg, Response):
|
95
91
|
if "action_response" in _msg.content.keys():
|
96
|
-
|
97
|
-
|
98
|
-
|
92
|
+
if recipient is None and self.name is not None:
|
93
|
+
_msg.recipient = self.name
|
94
|
+
if recipient is not None and self.name is None:
|
95
|
+
_msg.recipient = recipient
|
96
|
+
if "response" in _msg.content.keys():
|
97
|
+
if self.name is not None:
|
98
|
+
_msg.sender = self.name
|
99
99
|
|
100
100
|
_msg.content = _msg.msg_content
|
101
101
|
self.messages.loc[len(self.messages)] = _msg.to_pd_series()
|
@@ -108,11 +108,11 @@ class BaseBranch(BaseRelatableNode, ABC):
|
|
108
108
|
optionally including sender information.
|
109
109
|
|
110
110
|
Args:
|
111
|
-
|
111
|
+
with_sender: Flag to include sender information in the output.
|
112
112
|
|
113
113
|
Returns:
|
114
|
-
|
115
|
-
|
114
|
+
A list of message dictionaries, each with 'role' and 'content' keys,
|
115
|
+
and optionally prefixed by 'Sender' if with_sender is True.
|
116
116
|
"""
|
117
117
|
|
118
118
|
message = []
|
@@ -143,7 +143,7 @@ class BaseBranch(BaseRelatableNode, ABC):
|
|
143
143
|
Retrieves all chat messages without sender information.
|
144
144
|
|
145
145
|
Returns:
|
146
|
-
|
146
|
+
A list of dictionaries representing chat messages.
|
147
147
|
"""
|
148
148
|
|
149
149
|
return self._to_chatcompletion_message()
|
@@ -154,7 +154,7 @@ class BaseBranch(BaseRelatableNode, ABC):
|
|
154
154
|
Retrieves all chat messages, including sender information.
|
155
155
|
|
156
156
|
Returns:
|
157
|
-
|
157
|
+
A list of dictionaries representing chat messages, each prefixed with its sender.
|
158
158
|
"""
|
159
159
|
|
160
160
|
return self._to_chatcompletion_message(with_sender=True)
|
@@ -165,7 +165,7 @@ class BaseBranch(BaseRelatableNode, ABC):
|
|
165
165
|
Retrieves the last message from the branch as a pandas Series.
|
166
166
|
|
167
167
|
Returns:
|
168
|
-
|
168
|
+
A pandas Series representing the last message in the branch.
|
169
169
|
"""
|
170
170
|
|
171
171
|
return MessageUtil.get_message_rows(self.messages, n=1, from_="last")
|
@@ -176,7 +176,7 @@ class BaseBranch(BaseRelatableNode, ABC):
|
|
176
176
|
Extracts the content of the last message in the branch.
|
177
177
|
|
178
178
|
Returns:
|
179
|
-
|
179
|
+
A dictionary representing the content of the last message.
|
180
180
|
"""
|
181
181
|
|
182
182
|
return convert.to_dict(self.messages.content.iloc[-1])
|
@@ -187,7 +187,7 @@ class BaseBranch(BaseRelatableNode, ABC):
|
|
187
187
|
Retrieves the first message marked with the 'system' role.
|
188
188
|
|
189
189
|
Returns:
|
190
|
-
|
190
|
+
A pandas Series representing the first 'system' message in the branch.
|
191
191
|
"""
|
192
192
|
|
193
193
|
return MessageUtil.get_message_rows(
|
@@ -200,7 +200,7 @@ class BaseBranch(BaseRelatableNode, ABC):
|
|
200
200
|
Retrieves the last message marked with the 'assistant' role.
|
201
201
|
|
202
202
|
Returns:
|
203
|
-
|
203
|
+
A pandas Series representing the last 'assistant' (response) message in the branch.
|
204
204
|
"""
|
205
205
|
|
206
206
|
return MessageUtil.get_message_rows(
|
@@ -213,7 +213,7 @@ class BaseBranch(BaseRelatableNode, ABC):
|
|
213
213
|
Extracts the content of the last 'assistant' (response) message.
|
214
214
|
|
215
215
|
Returns:
|
216
|
-
|
216
|
+
A dictionary representing the content of the last 'assistant' message.
|
217
217
|
"""
|
218
218
|
|
219
219
|
return convert.to_dict(self.last_response.content.iloc[-1])
|
@@ -224,7 +224,7 @@ class BaseBranch(BaseRelatableNode, ABC):
|
|
224
224
|
Filters and retrieves all messages sent by 'action_request'.
|
225
225
|
|
226
226
|
Returns:
|
227
|
-
|
227
|
+
A pandas DataFrame containing all 'action_request' messages.
|
228
228
|
"""
|
229
229
|
|
230
230
|
return convert.to_df(self.messages[self.messages.sender == "action_request"])
|
@@ -235,7 +235,7 @@ class BaseBranch(BaseRelatableNode, ABC):
|
|
235
235
|
Filters and retrieves all messages sent by 'action_response'.
|
236
236
|
|
237
237
|
Returns:
|
238
|
-
|
238
|
+
A pandas DataFrame containing all 'action_response' messages.
|
239
239
|
"""
|
240
240
|
|
241
241
|
return convert.to_df(self.messages[self.messages.sender == "action_response"])
|
@@ -246,7 +246,7 @@ class BaseBranch(BaseRelatableNode, ABC):
|
|
246
246
|
Retrieves all messages marked with the 'assistant' role.
|
247
247
|
|
248
248
|
Returns:
|
249
|
-
|
249
|
+
A pandas DataFrame containing all messages with an 'assistant' role.
|
250
250
|
"""
|
251
251
|
|
252
252
|
return convert.to_df(self.messages[self.messages.role == "assistant"])
|
@@ -257,7 +257,7 @@ class BaseBranch(BaseRelatableNode, ABC):
|
|
257
257
|
Filters 'assistant' role messages excluding 'action_request' and 'action_response'.
|
258
258
|
|
259
259
|
Returns:
|
260
|
-
|
260
|
+
A pandas DataFrame of 'assistant' messages excluding action requests/responses.
|
261
261
|
"""
|
262
262
|
|
263
263
|
a_responses = self.responses[self.responses.sender != "action_response"]
|
@@ -274,7 +274,7 @@ class BaseBranch(BaseRelatableNode, ABC):
|
|
274
274
|
Summarizes branch information, including message counts by role.
|
275
275
|
|
276
276
|
Returns:
|
277
|
-
|
277
|
+
A dictionary containing counts of messages categorized by their role.
|
278
278
|
"""
|
279
279
|
|
280
280
|
return self._info()
|
@@ -285,7 +285,7 @@ class BaseBranch(BaseRelatableNode, ABC):
|
|
285
285
|
Provides a summary of message counts categorized by sender.
|
286
286
|
|
287
287
|
Returns:
|
288
|
-
|
288
|
+
A dictionary with senders as keys and counts of their messages as values.
|
289
289
|
"""
|
290
290
|
|
291
291
|
return self._info(use_sender=True)
|
@@ -296,8 +296,8 @@ class BaseBranch(BaseRelatableNode, ABC):
|
|
296
296
|
Provides a detailed description of the branch, including a summary of messages.
|
297
297
|
|
298
298
|
Returns:
|
299
|
-
|
300
|
-
|
299
|
+
A dictionary with a summary of total messages, a breakdown by role, and
|
300
|
+
a preview of the first five messages.
|
301
301
|
"""
|
302
302
|
|
303
303
|
return {
|
@@ -344,13 +344,13 @@ class BaseBranch(BaseRelatableNode, ABC):
|
|
344
344
|
Exports the branch messages to a CSV file.
|
345
345
|
|
346
346
|
Args:
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
347
|
+
filepath: Destination path for the CSV file. Defaults to 'messages.csv'.
|
348
|
+
dir_exist_ok: If False, an error is raised if the directory exists. Defaults to True.
|
349
|
+
timestamp: If True, appends a timestamp to the filename. Defaults to True.
|
350
|
+
time_prefix: If True, prefixes the filename with a timestamp. Defaults to False.
|
351
|
+
verbose: If True, prints a message upon successful export. Defaults to True.
|
352
|
+
clear: If True, clears the messages after exporting. Defaults to True.
|
353
|
+
**kwargs: Additional keyword arguments for pandas.DataFrame.to_csv().
|
354
354
|
"""
|
355
355
|
|
356
356
|
if not filename.endswith(".csv"):
|
@@ -371,7 +371,7 @@ class BaseBranch(BaseRelatableNode, ABC):
|
|
371
371
|
if clear:
|
372
372
|
self.clear_messages()
|
373
373
|
except Exception as e:
|
374
|
-
raise ValueError(f"Error in saving to csv: {e}")
|
374
|
+
raise ValueError(f"Error in saving to csv: {e}") from e
|
375
375
|
|
376
376
|
def to_json_file(
|
377
377
|
self,
|
@@ -387,13 +387,13 @@ class BaseBranch(BaseRelatableNode, ABC):
|
|
387
387
|
Exports the branch messages to a JSON file.
|
388
388
|
|
389
389
|
Args:
|
390
|
-
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
390
|
+
filename: Destination path for the JSON file. Defaults to 'messages.json'.
|
391
|
+
dir_exist_ok: If False, an error is raised if the dirctory exists. Defaults to True.
|
392
|
+
timestamp: If True, appends a timestamp to the filename. Defaults to True.
|
393
|
+
time_prefix: If True, prefixes the filename with a timestamp. Defaults to False.
|
394
|
+
verbose: If True, prints a message upon successful export. Defaults to True.
|
395
|
+
clear: If True, clears the messages after exporting. Defaults to True.
|
396
|
+
**kwargs: Additional keyword arguments for pandas.DataFrame.to_json().
|
397
397
|
"""
|
398
398
|
|
399
399
|
if not filename.endswith(".json"):
|
@@ -416,7 +416,7 @@ class BaseBranch(BaseRelatableNode, ABC):
|
|
416
416
|
if clear:
|
417
417
|
self.clear_messages()
|
418
418
|
except Exception as e:
|
419
|
-
raise ValueError(f"Error in saving to json: {e}")
|
419
|
+
raise ValueError(f"Error in saving to json: {e}") from e
|
420
420
|
|
421
421
|
def log_to_csv(
|
422
422
|
self,
|
@@ -434,13 +434,13 @@ class BaseBranch(BaseRelatableNode, ABC):
|
|
434
434
|
Exports the data logger contents to a CSV file.
|
435
435
|
|
436
436
|
Args:
|
437
|
-
|
438
|
-
|
439
|
-
|
440
|
-
|
441
|
-
|
442
|
-
|
443
|
-
|
437
|
+
filename: Destination path for the CSV file. Defaults to 'log.csv'.
|
438
|
+
dir_exist_ok: If False, an error is raised if the directory exists. Defaults to True.
|
439
|
+
timestamp: If True, appends a timestamp to the filename. Defaults to True.
|
440
|
+
time_prefix: If True, prefixes the filename with a timestamp. Defaults to False.
|
441
|
+
verbose: If True, prints a message upon successful export. Defaults to True.
|
442
|
+
clear: If True, clears the logger after exporting. Defaults to True.
|
443
|
+
**kwargs: Additional keyword arguments for pandas.DataFrame.to_csv().
|
444
444
|
"""
|
445
445
|
self.datalogger.to_csv_file(
|
446
446
|
filename=filename,
|
@@ -470,13 +470,13 @@ class BaseBranch(BaseRelatableNode, ABC):
|
|
470
470
|
Exports the data logger contents to a JSON file.
|
471
471
|
|
472
472
|
Args:
|
473
|
-
|
474
|
-
|
475
|
-
|
476
|
-
|
477
|
-
|
478
|
-
|
479
|
-
|
473
|
+
filename: Destination path for the JSON file. Defaults to 'log.json'.
|
474
|
+
dir_exist_ok: If False, an error is raised if the directory exists. Defaults to True.
|
475
|
+
timestamp: If True, appends a timestamp to the filename. Defaults to True.
|
476
|
+
time_prefix: If True, prefixes the filename with a timestamp. Defaults to False.
|
477
|
+
verbose: If True, prints a message upon successful export. Defaults to True.
|
478
|
+
clear: If True, clears the logger after exporting. Defaults to True.
|
479
|
+
**kwargs: Additional keyword arguments for pandas.DataFrame.to_json().
|
480
480
|
"""
|
481
481
|
|
482
482
|
self.datalogger.to_json_file(
|
@@ -513,14 +513,14 @@ class BaseBranch(BaseRelatableNode, ABC):
|
|
513
513
|
if verbose:
|
514
514
|
print(f"Loaded {len(df)} logs from {filename}")
|
515
515
|
except Exception as e:
|
516
|
-
raise ValueError(f"Error in loading log: {e}")
|
516
|
+
raise ValueError(f"Error in loading log: {e}") from e
|
517
517
|
|
518
518
|
def remove_message(self, node_id: str) -> None:
|
519
519
|
"""
|
520
520
|
Removes a message from the branch based on its node ID.
|
521
521
|
|
522
522
|
Args:
|
523
|
-
|
523
|
+
node_id: The unique identifier of the message to be removed.
|
524
524
|
"""
|
525
525
|
MessageUtil.remove_message(self.messages, node_id)
|
526
526
|
|
@@ -529,9 +529,9 @@ class BaseBranch(BaseRelatableNode, ABC):
|
|
529
529
|
Updates a specific column of a message identified by node_id with a new value.
|
530
530
|
|
531
531
|
Args:
|
532
|
-
|
533
|
-
|
534
|
-
|
532
|
+
value: The new value to update the message with.
|
533
|
+
node_id: The unique identifier of the message to update.
|
534
|
+
column: The column of the message to update.
|
535
535
|
"""
|
536
536
|
|
537
537
|
index = self.messages[self.messages["node_id"] == node_id].index[0]
|
@@ -547,8 +547,8 @@ class BaseBranch(BaseRelatableNode, ABC):
|
|
547
547
|
Updates the first system message with new content and/or sender.
|
548
548
|
|
549
549
|
Args:
|
550
|
-
|
551
|
-
|
550
|
+
system: The new system message content or a System object.
|
551
|
+
sender: The identifier of the sender for the system message.
|
552
552
|
"""
|
553
553
|
|
554
554
|
if len(self.messages[self.messages["role"] == "system"]) == 0:
|
@@ -570,7 +570,7 @@ class BaseBranch(BaseRelatableNode, ABC):
|
|
570
570
|
Removes the last 'n' messages from the branch.
|
571
571
|
|
572
572
|
Args:
|
573
|
-
|
573
|
+
steps: The number of messages to remove from the end.
|
574
574
|
"""
|
575
575
|
|
576
576
|
self.messages = dataframe.remove_last_n_rows(self.messages, steps)
|
@@ -642,10 +642,10 @@ class BaseBranch(BaseRelatableNode, ABC):
|
|
642
642
|
Helper method to generate summaries of messages either by role or sender.
|
643
643
|
|
644
644
|
Args:
|
645
|
-
|
645
|
+
use_sender: If True, summary is categorized by sender. Otherwise, by role.
|
646
646
|
|
647
647
|
Returns:
|
648
|
-
|
648
|
+
A dictionary summarizing the count of messages either by role or sender.
|
649
649
|
"""
|
650
650
|
|
651
651
|
messages = self.messages["sender"] if use_sender else self.messages["role"]
|
lionagi/core/branch/branch.py
CHANGED
@@ -2,27 +2,22 @@ from collections import deque
|
|
2
2
|
from typing import Any, Union, TypeVar, Callable
|
3
3
|
|
4
4
|
from lionagi.libs.sys_util import PATH_TYPE
|
5
|
-
from lionagi.libs
|
6
|
-
from lionagi.libs import ln_convert as convert
|
7
|
-
from lionagi.libs import ln_dataframe as dataframe
|
5
|
+
from lionagi.libs import StatusTracker, BaseService, convert, dataframe
|
8
6
|
|
9
|
-
from
|
10
|
-
from
|
11
|
-
from lionagi.core.tool.tool_manager import ToolManager, func_to_tool
|
7
|
+
from ..schema import TOOL_TYPE, Tool, DataLogger
|
8
|
+
from ..tool import ToolManager, func_to_tool
|
12
9
|
|
13
|
-
from
|
14
|
-
from
|
15
|
-
from lionagi.core.mail.schema import BaseMail
|
16
|
-
|
17
|
-
from lionagi.core.branch.util import MessageUtil
|
10
|
+
from ..messages import System
|
11
|
+
from ..mail import BaseMail
|
18
12
|
|
13
|
+
from .util import MessageUtil
|
14
|
+
from .base_branch import BaseBranch
|
19
15
|
from .branch_flow_mixin import BranchFlowMixin
|
20
16
|
|
21
17
|
from dotenv import load_dotenv
|
22
18
|
|
23
19
|
load_dotenv()
|
24
20
|
|
25
|
-
|
26
21
|
T = TypeVar("T", bound=Tool)
|
27
22
|
|
28
23
|
|
@@ -38,8 +33,7 @@ class Branch(BaseBranch, BranchFlowMixin):
|
|
38
33
|
llmconfig: dict[str, str | int | dict] | None = None,
|
39
34
|
tools: list[Callable | Tool] | None = None,
|
40
35
|
datalogger: None | DataLogger = None,
|
41
|
-
persist_path: PATH_TYPE | None = None,
|
42
|
-
# instruction_sets=None,
|
36
|
+
persist_path: PATH_TYPE | None = None, # instruction_sets=None,
|
43
37
|
tool_manager: ToolManager | None = None,
|
44
38
|
**kwargs,
|
45
39
|
):
|
@@ -57,7 +51,7 @@ class Branch(BaseBranch, BranchFlowMixin):
|
|
57
51
|
self.sender = sender or "system"
|
58
52
|
|
59
53
|
# add tool manager and register tools
|
60
|
-
self.tool_manager = tool_manager
|
54
|
+
self.tool_manager = tool_manager or ToolManager()
|
61
55
|
if tools:
|
62
56
|
try:
|
63
57
|
tools_ = []
|
@@ -70,7 +64,7 @@ class Branch(BaseBranch, BranchFlowMixin):
|
|
70
64
|
|
71
65
|
self.register_tools(tools_)
|
72
66
|
except Exception as e:
|
73
|
-
raise TypeError(f"Error in registering tools: {e}")
|
67
|
+
raise TypeError(f"Error in registering tools: {e}") from e
|
74
68
|
|
75
69
|
# add service and llmconfig
|
76
70
|
self.service, self.llmconfig = self._add_service(service, llmconfig)
|
@@ -96,14 +90,13 @@ class Branch(BaseBranch, BranchFlowMixin):
|
|
96
90
|
llmconfig: dict[str, str | int | dict] | None = None,
|
97
91
|
tools: TOOL_TYPE | None = None,
|
98
92
|
datalogger: None | DataLogger = None,
|
99
|
-
persist_path: PATH_TYPE | None = None,
|
100
|
-
# instruction_sets=None,
|
93
|
+
persist_path: PATH_TYPE | None = None, # instruction_sets=None,
|
101
94
|
tool_manager: ToolManager | None = None,
|
102
95
|
read_kwargs=None,
|
103
96
|
**kwargs,
|
104
97
|
):
|
105
98
|
|
106
|
-
|
99
|
+
return cls._from_csv(
|
107
100
|
filepath=filepath,
|
108
101
|
read_kwargs=read_kwargs,
|
109
102
|
name=name,
|
@@ -117,8 +110,6 @@ class Branch(BaseBranch, BranchFlowMixin):
|
|
117
110
|
**kwargs,
|
118
111
|
)
|
119
112
|
|
120
|
-
return self
|
121
|
-
|
122
113
|
@classmethod
|
123
114
|
def from_json_string(
|
124
115
|
cls,
|
@@ -128,14 +119,13 @@ class Branch(BaseBranch, BranchFlowMixin):
|
|
128
119
|
llmconfig: dict[str, str | int | dict] | None = None,
|
129
120
|
tools: TOOL_TYPE | None = None,
|
130
121
|
datalogger: None | DataLogger = None,
|
131
|
-
persist_path: PATH_TYPE | None = None,
|
132
|
-
# instruction_sets=None,
|
122
|
+
persist_path: PATH_TYPE | None = None, # instruction_sets=None,
|
133
123
|
tool_manager: ToolManager | None = None,
|
134
124
|
read_kwargs=None,
|
135
125
|
**kwargs,
|
136
126
|
):
|
137
127
|
|
138
|
-
|
128
|
+
return cls._from_json(
|
139
129
|
filepath=filepath,
|
140
130
|
read_kwargs=read_kwargs,
|
141
131
|
name=name,
|
@@ -149,8 +139,6 @@ class Branch(BaseBranch, BranchFlowMixin):
|
|
149
139
|
**kwargs,
|
150
140
|
)
|
151
141
|
|
152
|
-
return self
|
153
|
-
|
154
142
|
def messages_describe(self) -> dict[str, Any]:
|
155
143
|
|
156
144
|
return dict(
|
@@ -301,7 +289,7 @@ class Branch(BaseBranch, BranchFlowMixin):
|
|
301
289
|
Check if the conversation has been invoked with an action response.
|
302
290
|
|
303
291
|
Returns:
|
304
|
-
|
292
|
+
bool: True if the conversation has been invoked, False otherwise.
|
305
293
|
|
306
294
|
"""
|
307
295
|
content = self.messages.iloc[-1]["content"]
|
@@ -312,5 +300,5 @@ class Branch(BaseBranch, BranchFlowMixin):
|
|
312
300
|
"output",
|
313
301
|
}:
|
314
302
|
return True
|
315
|
-
except:
|
303
|
+
except Exception:
|
316
304
|
return False
|
@@ -1,12 +1,9 @@
|
|
1
1
|
from abc import ABC
|
2
2
|
from typing import Any, Optional, Union, TypeVar
|
3
3
|
|
4
|
-
from
|
5
|
-
from
|
6
|
-
from
|
7
|
-
from lionagi.core.flow.monoflow.ReAct import MonoReAct
|
8
|
-
|
9
|
-
from lionagi.core.messages.schema import Instruction, System
|
4
|
+
from ..schema import TOOL_TYPE, Tool
|
5
|
+
from ..messages import Instruction, System
|
6
|
+
from ..flow.monoflow import MonoChat, MonoFollowup, MonoReAct
|
10
7
|
|
11
8
|
T = TypeVar("T", bound=Tool)
|
12
9
|
|
@@ -25,7 +22,6 @@ class BranchFlowMixin(ABC):
|
|
25
22
|
output_fields=None,
|
26
23
|
**kwargs,
|
27
24
|
) -> Any:
|
28
|
-
|
29
25
|
flow = MonoChat(self)
|
30
26
|
return await flow.chat(
|
31
27
|
instruction=instruction,
|