lionagi 0.0.305__py3-none-any.whl → 0.0.307__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lionagi/__init__.py +2 -5
- lionagi/core/__init__.py +7 -4
- lionagi/core/agent/__init__.py +3 -0
- lionagi/core/agent/base_agent.py +46 -0
- lionagi/core/branch/__init__.py +4 -0
- lionagi/core/branch/base/__init__.py +0 -0
- lionagi/core/branch/base_branch.py +100 -78
- lionagi/core/branch/branch.py +22 -34
- lionagi/core/branch/branch_flow_mixin.py +3 -7
- lionagi/core/branch/executable_branch.py +192 -0
- lionagi/core/branch/util.py +77 -162
- lionagi/core/direct/__init__.py +13 -0
- lionagi/core/direct/parallel_predict.py +127 -0
- lionagi/core/direct/parallel_react.py +0 -0
- lionagi/core/direct/parallel_score.py +0 -0
- lionagi/core/direct/parallel_select.py +0 -0
- lionagi/core/direct/parallel_sentiment.py +0 -0
- lionagi/core/direct/predict.py +174 -0
- lionagi/core/direct/react.py +33 -0
- lionagi/core/direct/score.py +163 -0
- lionagi/core/direct/select.py +144 -0
- lionagi/core/direct/sentiment.py +51 -0
- lionagi/core/direct/utils.py +83 -0
- lionagi/core/flow/__init__.py +0 -3
- lionagi/core/flow/monoflow/{mono_react.py → ReAct.py} +52 -9
- lionagi/core/flow/monoflow/__init__.py +9 -0
- lionagi/core/flow/monoflow/{mono_chat.py → chat.py} +11 -11
- lionagi/core/flow/monoflow/{mono_chat_mixin.py → chat_mixin.py} +33 -27
- lionagi/core/flow/monoflow/{mono_followup.py → followup.py} +7 -6
- lionagi/core/flow/polyflow/__init__.py +1 -0
- lionagi/core/flow/polyflow/{polychat.py → chat.py} +15 -3
- lionagi/core/mail/__init__.py +8 -0
- lionagi/core/mail/mail_manager.py +88 -40
- lionagi/core/mail/schema.py +32 -6
- lionagi/core/messages/__init__.py +3 -0
- lionagi/core/messages/schema.py +56 -25
- lionagi/core/prompt/__init__.py +0 -0
- lionagi/core/prompt/prompt_template.py +0 -0
- lionagi/core/schema/__init__.py +7 -5
- lionagi/core/schema/action_node.py +29 -0
- lionagi/core/schema/base_mixin.py +56 -59
- lionagi/core/schema/base_node.py +35 -38
- lionagi/core/schema/condition.py +24 -0
- lionagi/core/schema/data_logger.py +98 -98
- lionagi/core/schema/data_node.py +19 -19
- lionagi/core/schema/prompt_template.py +0 -0
- lionagi/core/schema/structure.py +293 -190
- lionagi/core/session/__init__.py +1 -3
- lionagi/core/session/session.py +196 -214
- lionagi/core/tool/tool_manager.py +95 -103
- lionagi/integrations/__init__.py +1 -3
- lionagi/integrations/bridge/langchain_/documents.py +17 -18
- lionagi/integrations/bridge/langchain_/langchain_bridge.py +14 -14
- lionagi/integrations/bridge/llamaindex_/llama_index_bridge.py +22 -22
- lionagi/integrations/bridge/llamaindex_/node_parser.py +12 -12
- lionagi/integrations/bridge/llamaindex_/reader.py +11 -11
- lionagi/integrations/bridge/llamaindex_/textnode.py +7 -7
- lionagi/integrations/config/openrouter_configs.py +0 -1
- lionagi/integrations/provider/oai.py +26 -26
- lionagi/integrations/provider/services.py +38 -38
- lionagi/libs/__init__.py +34 -1
- lionagi/libs/ln_api.py +211 -221
- lionagi/libs/ln_async.py +53 -60
- lionagi/libs/ln_convert.py +118 -120
- lionagi/libs/ln_dataframe.py +32 -33
- lionagi/libs/ln_func_call.py +334 -342
- lionagi/libs/ln_nested.py +99 -107
- lionagi/libs/ln_parse.py +175 -158
- lionagi/libs/sys_util.py +52 -52
- lionagi/tests/test_core/test_base_branch.py +427 -427
- lionagi/tests/test_core/test_branch.py +292 -292
- lionagi/tests/test_core/test_mail_manager.py +57 -57
- lionagi/tests/test_core/test_session.py +254 -266
- lionagi/tests/test_core/test_session_base_util.py +299 -300
- lionagi/tests/test_core/test_tool_manager.py +70 -74
- lionagi/tests/test_libs/test_nested.py +2 -7
- lionagi/tests/test_libs/test_parse.py +2 -2
- lionagi/version.py +1 -1
- {lionagi-0.0.305.dist-info → lionagi-0.0.307.dist-info}/METADATA +4 -2
- lionagi-0.0.307.dist-info/RECORD +115 -0
- lionagi-0.0.305.dist-info/RECORD +0 -94
- {lionagi-0.0.305.dist-info → lionagi-0.0.307.dist-info}/LICENSE +0 -0
- {lionagi-0.0.305.dist-info → lionagi-0.0.307.dist-info}/WHEEL +0 -0
- {lionagi-0.0.305.dist-info → lionagi-0.0.307.dist-info}/top_level.txt +0 -0
@@ -1,427 +1,427 @@
|
|
1
|
-
from lionagi.core.branch.base_branch import BaseBranch
|
2
|
-
from lionagi.core.branch.util import MessageUtil
|
3
|
-
from lionagi.core.messages.schema import System
|
4
|
-
from lionagi.core.schema import DataLogger
|
5
|
-
|
6
|
-
|
7
|
-
import unittest
|
8
|
-
from unittest.mock import patch, MagicMock
|
9
|
-
import pandas as pd
|
10
|
-
from datetime import datetime
|
11
|
-
import json
|
12
|
-
|
13
|
-
|
14
|
-
class TestBaseBranch(unittest.TestCase):
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
|
362
|
-
|
363
|
-
|
364
|
-
|
365
|
-
|
366
|
-
|
367
|
-
|
368
|
-
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
if __name__ == "__main__":
|
427
|
-
|
1
|
+
# from lionagi.core.branch.base_branch import BaseBranch
|
2
|
+
# from lionagi.core.branch.util import MessageUtil
|
3
|
+
# from lionagi.core.messages.schema import System
|
4
|
+
# from lionagi.core.schema import DataLogger
|
5
|
+
|
6
|
+
|
7
|
+
# import unittest
|
8
|
+
# from unittest.mock import patch, MagicMock
|
9
|
+
# import pandas as pd
|
10
|
+
# from datetime import datetime
|
11
|
+
# import json
|
12
|
+
|
13
|
+
|
14
|
+
# class TestBaseBranch(unittest.TestCase):
|
15
|
+
|
16
|
+
# def setUp(self):
|
17
|
+
# # Patching DataLogger to avoid filesystem interactions during tests
|
18
|
+
# self.patcher = patch(
|
19
|
+
# "lionagi.core.branch.base_branch.DataLogger", autospec=True
|
20
|
+
# )
|
21
|
+
# self.MockDataLogger = self.patcher.start()
|
22
|
+
# self.addCleanup(self.patcher.stop)
|
23
|
+
|
24
|
+
# self.branch = BaseBranch()
|
25
|
+
# # Example messages to populate the DataFrame for testing
|
26
|
+
# self.test_messages = [
|
27
|
+
# {
|
28
|
+
# "node_id": "1",
|
29
|
+
# "timestamp": "2021-01-01 00:00:00",
|
30
|
+
# "role": "system",
|
31
|
+
# "sender": "system",
|
32
|
+
# "content": json.dumps({"system_info": "System message"}),
|
33
|
+
# },
|
34
|
+
# {
|
35
|
+
# "node_id": "2",
|
36
|
+
# "timestamp": "2021-01-01 00:01:00",
|
37
|
+
# "role": "user",
|
38
|
+
# "sender": "user1",
|
39
|
+
# "content": json.dumps({"instruction": "User message"}),
|
40
|
+
# },
|
41
|
+
# {
|
42
|
+
# "node_id": "3",
|
43
|
+
# "timestamp": "2021-01-01 00:02:00",
|
44
|
+
# "role": "assistant",
|
45
|
+
# "sender": "assistant",
|
46
|
+
# "content": json.dumps({"response": "Assistant response"}),
|
47
|
+
# },
|
48
|
+
# {
|
49
|
+
# "node_id": "4",
|
50
|
+
# "timestamp": "2021-01-01 00:03:00",
|
51
|
+
# "role": "assistant",
|
52
|
+
# "sender": "action_request",
|
53
|
+
# "content": json.dumps({"action_request": "Action request"}),
|
54
|
+
# },
|
55
|
+
# {
|
56
|
+
# "node_id": "5",
|
57
|
+
# "timestamp": "2021-01-01 00:04:00",
|
58
|
+
# "role": "assistant",
|
59
|
+
# "sender": "action_response",
|
60
|
+
# "content": json.dumps({"action_response": "Action response"}),
|
61
|
+
# },
|
62
|
+
# ]
|
63
|
+
# self.branch = BaseBranch(messages=pd.DataFrame(self.test_messages))
|
64
|
+
|
65
|
+
# def test_init_with_empty_messages(self):
|
66
|
+
# """Test __init__ method with no messages provided."""
|
67
|
+
# branch = BaseBranch()
|
68
|
+
# self.assertTrue(branch.messages.empty)
|
69
|
+
|
70
|
+
# def test_init_with_given_messages(self):
|
71
|
+
# """Test __init__ method with provided messages DataFrame."""
|
72
|
+
# messages = pd.DataFrame(
|
73
|
+
# [
|
74
|
+
# [
|
75
|
+
# "0",
|
76
|
+
# datetime(2024, 1, 1),
|
77
|
+
# "system",
|
78
|
+
# "system",
|
79
|
+
# json.dumps({"system_info": "Hi"}),
|
80
|
+
# ]
|
81
|
+
# ],
|
82
|
+
# columns=["node_id", "timestamp", "role", "sender", "content"],
|
83
|
+
# )
|
84
|
+
# branch = BaseBranch(messages=messages)
|
85
|
+
# self.assertFalse(branch.messages.empty)
|
86
|
+
|
87
|
+
# def test_add_message(self):
|
88
|
+
# """Test adding a message."""
|
89
|
+
# message_info = {"info": "Test Message"}
|
90
|
+
# with patch.object(
|
91
|
+
# MessageUtil, "create_message", return_value=System(system=message_info)
|
92
|
+
# ) as mocked_create_message:
|
93
|
+
# self.branch.add_message(system=message_info)
|
94
|
+
# mocked_create_message.assert_called_once()
|
95
|
+
# self.assertEqual(len(self.branch.messages), 6)
|
96
|
+
|
97
|
+
# def test_chat_messages_without_sender(self):
|
98
|
+
# """Test chat_messages property without including sender information."""
|
99
|
+
# chat_messages = self.branch.chat_messages
|
100
|
+
# self.assertEqual(len(chat_messages), 5)
|
101
|
+
# self.assertNotIn("Sender", chat_messages[0]["content"])
|
102
|
+
|
103
|
+
# def test_chat_messages_with_sender(self):
|
104
|
+
# """Test retrieving chat messages with sender information included."""
|
105
|
+
# chat_messages_with_sender = self.branch.chat_messages_with_sender
|
106
|
+
# expected_prefixes = [
|
107
|
+
# "Sender system: ",
|
108
|
+
# "Sender user1: ",
|
109
|
+
# "Sender assistant: ",
|
110
|
+
# "Sender action_request: ",
|
111
|
+
# "Sender action_response: ",
|
112
|
+
# ]
|
113
|
+
|
114
|
+
# # Verify the number of messages returned matches the number added
|
115
|
+
# self.assertEqual(len(chat_messages_with_sender), len(self.test_messages))
|
116
|
+
|
117
|
+
# # Check each message for correct sender prefix in content
|
118
|
+
# for i, message_dict in enumerate(chat_messages_with_sender):
|
119
|
+
# self.assertTrue(
|
120
|
+
# message_dict["content"].startswith(expected_prefixes[i]),
|
121
|
+
# msg=f"Message content does not start with expected sender prefix. Found: {message_dict['content']}",
|
122
|
+
# )
|
123
|
+
|
124
|
+
# # Optionally, verify the content matches after removing the prefix
|
125
|
+
# for i, message_dict in enumerate(chat_messages_with_sender):
|
126
|
+
# prefix, content = message_dict["content"].split(": ", 1)
|
127
|
+
# self.assertEqual(
|
128
|
+
# content,
|
129
|
+
# self.test_messages[i]["content"],
|
130
|
+
# msg=f"Message content does not match after removing sender prefix. Found: {content}, Expected: {self.test_messages[i]['content']}",
|
131
|
+
# )
|
132
|
+
|
133
|
+
# def test_last_message(self):
|
134
|
+
# """Test retrieving the last message."""
|
135
|
+
# last_message = self.branch.last_message
|
136
|
+
# self.assertEqual(
|
137
|
+
# json.loads(last_message.iloc[0]["content"]),
|
138
|
+
# {"action_response": "Action response"},
|
139
|
+
# )
|
140
|
+
|
141
|
+
# def test_last_message_content(self):
|
142
|
+
# """Test retrieving the content of the last message."""
|
143
|
+
# content = self.branch.last_message_content
|
144
|
+
# self.assertEqual(content, {"action_response": "Action response"})
|
145
|
+
|
146
|
+
# def test_first_system_message(self):
|
147
|
+
# """Test retrieving the first 'system' message."""
|
148
|
+
# first_system_message = self.branch.first_system
|
149
|
+
# self.assertEqual(
|
150
|
+
# json.loads(first_system_message.iloc[0]["content"]),
|
151
|
+
# {"system_info": "System message"},
|
152
|
+
# )
|
153
|
+
|
154
|
+
# def test_last_response(self):
|
155
|
+
# """Test retrieving the last 'assistant' message."""
|
156
|
+
# last_response = self.branch.last_response
|
157
|
+
# self.assertEqual(last_response.iloc[0]["sender"], "action_response")
|
158
|
+
|
159
|
+
# def test_last_response_content(self):
|
160
|
+
# """Test extracting content of the last 'assistant' message."""
|
161
|
+
# content = self.branch.last_response_content
|
162
|
+
# self.assertEqual(content, {"action_response": "Action response"})
|
163
|
+
|
164
|
+
# def test_action_request_messages(self):
|
165
|
+
# """Test filtering messages sent by 'action_request'."""
|
166
|
+
# action_requests = self.branch.action_request
|
167
|
+
# self.assertTrue(all(action_requests["sender"] == "action_request"))
|
168
|
+
|
169
|
+
# def test_action_response_messages(self):
|
170
|
+
# """Test filtering messages sent by 'action_response'."""
|
171
|
+
# action_responses = self.branch.action_response
|
172
|
+
# self.assertTrue(all(action_responses["sender"] == "action_response"))
|
173
|
+
|
174
|
+
# def test_responses(self):
|
175
|
+
# """Test filtering of 'assistant' role messages."""
|
176
|
+
# responses = self.branch.responses
|
177
|
+
# # Verify that all returned messages have the 'assistant' role
|
178
|
+
# self.assertTrue(all(responses["role"] == "assistant"))
|
179
|
+
# # Optionally, check the count matches the expected number of 'assistant' messages
|
180
|
+
# expected_count = sum(
|
181
|
+
# 1 for msg in self.test_messages if msg["role"] == "assistant"
|
182
|
+
# )
|
183
|
+
# self.assertEqual(len(responses), expected_count)
|
184
|
+
|
185
|
+
# def test_assistant_responses(self):
|
186
|
+
# """Test filtering of 'assistant' messages excluding action requests/responses."""
|
187
|
+
# assistant_responses = self.branch.assistant_responses
|
188
|
+
# # Verify that no returned messages are from 'action_request' or 'action_response' senders
|
189
|
+
# self.assertTrue(all(assistant_responses["sender"] != "action_request"))
|
190
|
+
# self.assertTrue(all(assistant_responses["sender"] != "action_response"))
|
191
|
+
# # Verify that all returned messages have the 'assistant' role
|
192
|
+
# self.assertTrue(all(assistant_responses["role"] == "assistant"))
|
193
|
+
|
194
|
+
# def test_info(self):
|
195
|
+
# """Test summarization of message counts by role."""
|
196
|
+
# info = self.branch.info
|
197
|
+
# # Verify that the dictionary contains keys for each role and 'total'
|
198
|
+
# self.assertIn("assistant", info)
|
199
|
+
# self.assertIn("user", info)
|
200
|
+
# self.assertIn("system", info)
|
201
|
+
# self.assertIn("total", info)
|
202
|
+
# # Optionally, verify the counts match expected values
|
203
|
+
# for role in ["assistant", "user", "system"]:
|
204
|
+
# expected_count = sum(1 for msg in self.test_messages if msg["role"] == role)
|
205
|
+
# self.assertEqual(info[role], expected_count)
|
206
|
+
# self.assertEqual(info["total"], len(self.test_messages))
|
207
|
+
|
208
|
+
# def test_sender_info(self):
|
209
|
+
# """Test summarization of message counts by sender."""
|
210
|
+
# sender_info = self.branch.sender_info
|
211
|
+
# # Verify that the dictionary contains keys for each sender and the counts match
|
212
|
+
# for sender in set(msg["sender"] for msg in self.test_messages):
|
213
|
+
# expected_count = sum(
|
214
|
+
# 1 for msg in self.test_messages if msg["sender"] == sender
|
215
|
+
# )
|
216
|
+
# self.assertEqual(sender_info.get(sender, 0), expected_count)
|
217
|
+
|
218
|
+
# def test_describe(self):
|
219
|
+
# """Test detailed description of the branch."""
|
220
|
+
# description = self.branch.describe
|
221
|
+
# # Verify that the description contains expected keys
|
222
|
+
# self.assertIn("total_messages", description)
|
223
|
+
# self.assertIn("summary_by_role", description)
|
224
|
+
# self.assertIn("messages", description)
|
225
|
+
# # Optionally, verify the accuracy of the values
|
226
|
+
# self.assertEqual(description["total_messages"], len(self.test_messages))
|
227
|
+
# self.assertEqual(len(description["messages"]), min(5, len(self.test_messages)))
|
228
|
+
|
229
|
+
# @patch("lionagi.libs.ln_dataframe.read_csv")
|
230
|
+
# def test_from_csv(cls, mock_read_csv):
|
231
|
+
# # Define a mock return value for read_csv
|
232
|
+
# mock_messages_df = pd.DataFrame(
|
233
|
+
# {
|
234
|
+
# "node_id": ["1", "2"],
|
235
|
+
# "timestamp": [datetime(2021, 1, 1), datetime(2021, 1, 1)],
|
236
|
+
# "role": ["system", "user"],
|
237
|
+
# "sender": ["system", "user1"],
|
238
|
+
# "content": [
|
239
|
+
# json.dumps({"system_info": "System message"}),
|
240
|
+
# json.dumps({"instruction": "User message"}),
|
241
|
+
# ],
|
242
|
+
# }
|
243
|
+
# )
|
244
|
+
# mock_read_csv.return_value = mock_messages_df
|
245
|
+
|
246
|
+
# # Call the from_csv class method
|
247
|
+
# branch = BaseBranch.from_csv(filename="dummy.csv")
|
248
|
+
|
249
|
+
# # Verify that read_csv was called correctly
|
250
|
+
# mock_read_csv.assert_called_once_with("dummy.csv")
|
251
|
+
|
252
|
+
# # Verify that the branch instance contains the correct messages
|
253
|
+
# pd.testing.assert_frame_equal(branch.messages, mock_messages_df)
|
254
|
+
|
255
|
+
# @patch("lionagi.libs.ln_dataframe.read_json")
|
256
|
+
# def test_from_json(cls, mock_read_csv):
|
257
|
+
# # Define a mock return value for read_csv
|
258
|
+
# mock_messages_df = pd.DataFrame(
|
259
|
+
# {
|
260
|
+
# "node_id": ["1", "2"],
|
261
|
+
# "timestamp": [datetime(2021, 1, 1), datetime(2021, 1, 1)],
|
262
|
+
# "role": ["system", "user"],
|
263
|
+
# "sender": ["system", "user1"],
|
264
|
+
# "content": [
|
265
|
+
# json.dumps({"system_info": "System message"}),
|
266
|
+
# json.dumps({"instruction": "User message"}),
|
267
|
+
# ],
|
268
|
+
# }
|
269
|
+
# )
|
270
|
+
# mock_read_csv.return_value = mock_messages_df
|
271
|
+
|
272
|
+
# # Call the from_csv class method
|
273
|
+
# branch = BaseBranch.from_json_string(filename="dummy.json")
|
274
|
+
|
275
|
+
# # Verify that read_csv was called correctly
|
276
|
+
# mock_read_csv.assert_called_once_with("dummy.json")
|
277
|
+
|
278
|
+
# # Verify that the branch instance contains the correct messages
|
279
|
+
# pd.testing.assert_frame_equal(branch.messages, mock_messages_df)
|
280
|
+
|
281
|
+
# @patch(
|
282
|
+
# "lionagi.libs.sys_util.SysUtil.create_path",
|
283
|
+
# return_value="path/to/messages.csv",
|
284
|
+
# )
|
285
|
+
# @patch.object(pd.DataFrame, "to_csv")
|
286
|
+
# def test_to_csv(self, mock_to_csv, mock_create_path):
|
287
|
+
# self.branch.datalogger = MagicMock()
|
288
|
+
# self.branch.datalogger.persist_path = "data/logs/"
|
289
|
+
|
290
|
+
# self.branch.to_csv_file("messages.csv", verbose=False, clear=False)
|
291
|
+
|
292
|
+
# mock_create_path.assert_called_once_with(
|
293
|
+
# self.branch.datalogger.persist_path,
|
294
|
+
# "messages.csv",
|
295
|
+
# timestamp=True,
|
296
|
+
# dir_exist_ok=True,
|
297
|
+
# time_prefix=False,
|
298
|
+
# )
|
299
|
+
# mock_to_csv.assert_called_once_with("path/to/messages.csv")
|
300
|
+
|
301
|
+
# # Verify that messages are not cleared after exporting
|
302
|
+
# assert not self.branch.messages.empty
|
303
|
+
|
304
|
+
# @patch(
|
305
|
+
# "lionagi.libs.sys_util.SysUtil.create_path",
|
306
|
+
# return_value="path/to/messages.json",
|
307
|
+
# )
|
308
|
+
# @patch.object(pd.DataFrame, "to_json")
|
309
|
+
# def test_to_json(self, mock_to_json, mock_create_path):
|
310
|
+
# self.branch.datalogger = MagicMock()
|
311
|
+
# self.branch.datalogger.persist_path = "data/logs/"
|
312
|
+
|
313
|
+
# self.branch.to_json_file("messages.json", verbose=False, clear=False)
|
314
|
+
|
315
|
+
# mock_create_path.assert_called_once_with(
|
316
|
+
# self.branch.datalogger.persist_path,
|
317
|
+
# "messages.json",
|
318
|
+
# timestamp=True,
|
319
|
+
# dir_exist_ok=True,
|
320
|
+
# time_prefix=False,
|
321
|
+
# )
|
322
|
+
# mock_to_json.assert_called_once_with(
|
323
|
+
# "path/to/messages.json", orient="records", lines=True, date_format="iso"
|
324
|
+
# )
|
325
|
+
|
326
|
+
# # Verify that messages are not cleared after exporting
|
327
|
+
# assert not self.branch.messages.empty
|
328
|
+
|
329
|
+
# # def test_log_to_csv(self):
|
330
|
+
# # self.branch.log_to_csv('log.csv', verbose=False, clear=False)
|
331
|
+
|
332
|
+
# # self.branch.datalogger.to_csv_file.assert_called_once_with(filename='log.csv', dir_exist_ok=True, timestamp=True,
|
333
|
+
# # time_prefix=False, verbose=False, clear=False)
|
334
|
+
|
335
|
+
# # def test_log_to_json(self):
|
336
|
+
# # branch = BaseBranch()
|
337
|
+
# # branch.log_to_json('log.json', verbose=False, clear=False)
|
338
|
+
|
339
|
+
# # self.branch.datalogger.to_json_file.assert_called_once_with(filename='log.json', dir_exist_ok=True, timestamp=True,
|
340
|
+
# # time_prefix=False, verbose=False, clear=False)
|
341
|
+
|
342
|
+
# def test_remove_message(self):
|
343
|
+
# """Test removing a message from the branch based on its node ID."""
|
344
|
+
# initial_length = len(self.branch.messages)
|
345
|
+
# node_id_to_remove = "2"
|
346
|
+
# self.branch.remove_message(node_id_to_remove)
|
347
|
+
# final_length = len(self.branch.messages)
|
348
|
+
# self.assertNotIn(node_id_to_remove, self.branch.messages["node_id"].tolist())
|
349
|
+
# self.assertEqual(final_length, initial_length - 1)
|
350
|
+
|
351
|
+
# def test_update_message(self):
|
352
|
+
# """Test updating a specific message's column identified by node_id with a new value."""
|
353
|
+
# node_id_to_update = "2"
|
354
|
+
# new_value = "Updated content"
|
355
|
+
# self.branch.update_message(node_id_to_update, "content", new_value)
|
356
|
+
# updated_value = self.branch.messages.loc[
|
357
|
+
# self.branch.messages["node_id"] == node_id_to_update, "content"
|
358
|
+
# ].values[0]
|
359
|
+
# self.assertEqual(updated_value, new_value)
|
360
|
+
|
361
|
+
# # def test_change_first_system_message(self):
|
362
|
+
# # """Test updating the first system message with new content and/or sender."""
|
363
|
+
# # new_system_content = {"system_info": "Updated system message"}
|
364
|
+
# # self.branch.change_first_system_message(new_system_content['system_info'])
|
365
|
+
# # first_system_message_content = self.branch.messages.loc[self.branch.messages['role'] == 'system', 'content'].iloc[0]
|
366
|
+
# # self.assertIn(json.dumps({"system_info": "Updated system message"}), first_system_message_content)
|
367
|
+
|
368
|
+
# def test_rollback(self):
|
369
|
+
# """Test removing the last 'n' messages from the branch."""
|
370
|
+
# steps_to_remove = 2
|
371
|
+
# initial_length = len(self.branch.messages)
|
372
|
+
# self.branch.rollback(steps_to_remove)
|
373
|
+
# final_length = len(self.branch.messages)
|
374
|
+
# self.assertEqual(final_length, initial_length - steps_to_remove)
|
375
|
+
|
376
|
+
# def test_clear_messages(self):
|
377
|
+
# """Test clearing all messages from the branch."""
|
378
|
+
# self.branch.clear_messages()
|
379
|
+
# self.assertTrue(self.branch.messages.empty)
|
380
|
+
|
381
|
+
# def test_replace_keyword(self):
|
382
|
+
# """Test replacing occurrences of a specified keyword with a replacement string."""
|
383
|
+
# keyword = "Assistant response"
|
384
|
+
# replacement = "Helper feedback"
|
385
|
+
# self.branch.replace_keyword(keyword, replacement)
|
386
|
+
# self.assertTrue(
|
387
|
+
# any(replacement in message for message in self.branch.messages["content"])
|
388
|
+
# )
|
389
|
+
|
390
|
+
# def test_search_keywords(self):
|
391
|
+
# """Test filtering messages by a specified keyword or list of keywords."""
|
392
|
+
# keyword_to_search = "Assistant"
|
393
|
+
# filtered_messages = self.branch.search_keywords(keyword_to_search)
|
394
|
+
# self.assertTrue(
|
395
|
+
# all(
|
396
|
+
# keyword_to_search in content for content in filtered_messages["content"]
|
397
|
+
# )
|
398
|
+
# )
|
399
|
+
|
400
|
+
# def test_extend(self):
|
401
|
+
# """Test extending branch messages with additional messages."""
|
402
|
+
# additional_messages = pd.DataFrame(
|
403
|
+
# [
|
404
|
+
# {
|
405
|
+
# "node_id": "6",
|
406
|
+
# "timestamp": datetime(2021, 1, 1),
|
407
|
+
# "role": "user",
|
408
|
+
# "sender": "user2",
|
409
|
+
# "content": json.dumps({"instruction": "Another user message"}),
|
410
|
+
# }
|
411
|
+
# ]
|
412
|
+
# )
|
413
|
+
# initial_length = len(self.branch.messages)
|
414
|
+
# self.branch.extend(additional_messages)
|
415
|
+
# self.assertEqual(len(self.branch.messages), initial_length + 1)
|
416
|
+
|
417
|
+
# def test_filter_by(self):
|
418
|
+
# """Test filtering messages by various criteria."""
|
419
|
+
# # Filtering by role
|
420
|
+
# filtered_messages = self.branch.filter_by(role="user")
|
421
|
+
# self.assertTrue(
|
422
|
+
# all(msg["role"] == "user" for _, msg in filtered_messages.iterrows())
|
423
|
+
# )
|
424
|
+
|
425
|
+
|
426
|
+
# if __name__ == "__main__":
|
427
|
+
# unittest.main()
|