lionagi 0.0.305__py3-none-any.whl → 0.0.307__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84) hide show
  1. lionagi/__init__.py +2 -5
  2. lionagi/core/__init__.py +7 -4
  3. lionagi/core/agent/__init__.py +3 -0
  4. lionagi/core/agent/base_agent.py +46 -0
  5. lionagi/core/branch/__init__.py +4 -0
  6. lionagi/core/branch/base/__init__.py +0 -0
  7. lionagi/core/branch/base_branch.py +100 -78
  8. lionagi/core/branch/branch.py +22 -34
  9. lionagi/core/branch/branch_flow_mixin.py +3 -7
  10. lionagi/core/branch/executable_branch.py +192 -0
  11. lionagi/core/branch/util.py +77 -162
  12. lionagi/core/direct/__init__.py +13 -0
  13. lionagi/core/direct/parallel_predict.py +127 -0
  14. lionagi/core/direct/parallel_react.py +0 -0
  15. lionagi/core/direct/parallel_score.py +0 -0
  16. lionagi/core/direct/parallel_select.py +0 -0
  17. lionagi/core/direct/parallel_sentiment.py +0 -0
  18. lionagi/core/direct/predict.py +174 -0
  19. lionagi/core/direct/react.py +33 -0
  20. lionagi/core/direct/score.py +163 -0
  21. lionagi/core/direct/select.py +144 -0
  22. lionagi/core/direct/sentiment.py +51 -0
  23. lionagi/core/direct/utils.py +83 -0
  24. lionagi/core/flow/__init__.py +0 -3
  25. lionagi/core/flow/monoflow/{mono_react.py → ReAct.py} +52 -9
  26. lionagi/core/flow/monoflow/__init__.py +9 -0
  27. lionagi/core/flow/monoflow/{mono_chat.py → chat.py} +11 -11
  28. lionagi/core/flow/monoflow/{mono_chat_mixin.py → chat_mixin.py} +33 -27
  29. lionagi/core/flow/monoflow/{mono_followup.py → followup.py} +7 -6
  30. lionagi/core/flow/polyflow/__init__.py +1 -0
  31. lionagi/core/flow/polyflow/{polychat.py → chat.py} +15 -3
  32. lionagi/core/mail/__init__.py +8 -0
  33. lionagi/core/mail/mail_manager.py +88 -40
  34. lionagi/core/mail/schema.py +32 -6
  35. lionagi/core/messages/__init__.py +3 -0
  36. lionagi/core/messages/schema.py +56 -25
  37. lionagi/core/prompt/__init__.py +0 -0
  38. lionagi/core/prompt/prompt_template.py +0 -0
  39. lionagi/core/schema/__init__.py +7 -5
  40. lionagi/core/schema/action_node.py +29 -0
  41. lionagi/core/schema/base_mixin.py +56 -59
  42. lionagi/core/schema/base_node.py +35 -38
  43. lionagi/core/schema/condition.py +24 -0
  44. lionagi/core/schema/data_logger.py +98 -98
  45. lionagi/core/schema/data_node.py +19 -19
  46. lionagi/core/schema/prompt_template.py +0 -0
  47. lionagi/core/schema/structure.py +293 -190
  48. lionagi/core/session/__init__.py +1 -3
  49. lionagi/core/session/session.py +196 -214
  50. lionagi/core/tool/tool_manager.py +95 -103
  51. lionagi/integrations/__init__.py +1 -3
  52. lionagi/integrations/bridge/langchain_/documents.py +17 -18
  53. lionagi/integrations/bridge/langchain_/langchain_bridge.py +14 -14
  54. lionagi/integrations/bridge/llamaindex_/llama_index_bridge.py +22 -22
  55. lionagi/integrations/bridge/llamaindex_/node_parser.py +12 -12
  56. lionagi/integrations/bridge/llamaindex_/reader.py +11 -11
  57. lionagi/integrations/bridge/llamaindex_/textnode.py +7 -7
  58. lionagi/integrations/config/openrouter_configs.py +0 -1
  59. lionagi/integrations/provider/oai.py +26 -26
  60. lionagi/integrations/provider/services.py +38 -38
  61. lionagi/libs/__init__.py +34 -1
  62. lionagi/libs/ln_api.py +211 -221
  63. lionagi/libs/ln_async.py +53 -60
  64. lionagi/libs/ln_convert.py +118 -120
  65. lionagi/libs/ln_dataframe.py +32 -33
  66. lionagi/libs/ln_func_call.py +334 -342
  67. lionagi/libs/ln_nested.py +99 -107
  68. lionagi/libs/ln_parse.py +175 -158
  69. lionagi/libs/sys_util.py +52 -52
  70. lionagi/tests/test_core/test_base_branch.py +427 -427
  71. lionagi/tests/test_core/test_branch.py +292 -292
  72. lionagi/tests/test_core/test_mail_manager.py +57 -57
  73. lionagi/tests/test_core/test_session.py +254 -266
  74. lionagi/tests/test_core/test_session_base_util.py +299 -300
  75. lionagi/tests/test_core/test_tool_manager.py +70 -74
  76. lionagi/tests/test_libs/test_nested.py +2 -7
  77. lionagi/tests/test_libs/test_parse.py +2 -2
  78. lionagi/version.py +1 -1
  79. {lionagi-0.0.305.dist-info → lionagi-0.0.307.dist-info}/METADATA +4 -2
  80. lionagi-0.0.307.dist-info/RECORD +115 -0
  81. lionagi-0.0.305.dist-info/RECORD +0 -94
  82. {lionagi-0.0.305.dist-info → lionagi-0.0.307.dist-info}/LICENSE +0 -0
  83. {lionagi-0.0.305.dist-info → lionagi-0.0.307.dist-info}/WHEEL +0 -0
  84. {lionagi-0.0.305.dist-info → lionagi-0.0.307.dist-info}/top_level.txt +0 -0
@@ -1,427 +1,427 @@
1
- from lionagi.core.branch.base_branch import BaseBranch
2
- from lionagi.core.branch.util import MessageUtil
3
- from lionagi.core.messages.schema import System
4
- from lionagi.core.schema import DataLogger
5
-
6
-
7
- import unittest
8
- from unittest.mock import patch, MagicMock
9
- import pandas as pd
10
- from datetime import datetime
11
- import json
12
-
13
-
14
- class TestBaseBranch(unittest.TestCase):
15
-
16
- def setUp(self):
17
- # Patching DataLogger to avoid filesystem interactions during tests
18
- self.patcher = patch(
19
- "lionagi.core.branch.base_branch.DataLogger", autospec=True
20
- )
21
- self.MockDataLogger = self.patcher.start()
22
- self.addCleanup(self.patcher.stop)
23
-
24
- self.branch = BaseBranch()
25
- # Example messages to populate the DataFrame for testing
26
- self.test_messages = [
27
- {
28
- "node_id": "1",
29
- "timestamp": "2021-01-01 00:00:00",
30
- "role": "system",
31
- "sender": "system",
32
- "content": json.dumps({"system_info": "System message"}),
33
- },
34
- {
35
- "node_id": "2",
36
- "timestamp": "2021-01-01 00:01:00",
37
- "role": "user",
38
- "sender": "user1",
39
- "content": json.dumps({"instruction": "User message"}),
40
- },
41
- {
42
- "node_id": "3",
43
- "timestamp": "2021-01-01 00:02:00",
44
- "role": "assistant",
45
- "sender": "assistant",
46
- "content": json.dumps({"response": "Assistant response"}),
47
- },
48
- {
49
- "node_id": "4",
50
- "timestamp": "2021-01-01 00:03:00",
51
- "role": "assistant",
52
- "sender": "action_request",
53
- "content": json.dumps({"action_request": "Action request"}),
54
- },
55
- {
56
- "node_id": "5",
57
- "timestamp": "2021-01-01 00:04:00",
58
- "role": "assistant",
59
- "sender": "action_response",
60
- "content": json.dumps({"action_response": "Action response"}),
61
- },
62
- ]
63
- self.branch = BaseBranch(messages=pd.DataFrame(self.test_messages))
64
-
65
- def test_init_with_empty_messages(self):
66
- """Test __init__ method with no messages provided."""
67
- branch = BaseBranch()
68
- self.assertTrue(branch.messages.empty)
69
-
70
- def test_init_with_given_messages(self):
71
- """Test __init__ method with provided messages DataFrame."""
72
- messages = pd.DataFrame(
73
- [
74
- [
75
- "0",
76
- datetime(2024, 1, 1),
77
- "system",
78
- "system",
79
- json.dumps({"system_info": "Hi"}),
80
- ]
81
- ],
82
- columns=["node_id", "timestamp", "role", "sender", "content"],
83
- )
84
- branch = BaseBranch(messages=messages)
85
- self.assertFalse(branch.messages.empty)
86
-
87
- def test_add_message(self):
88
- """Test adding a message."""
89
- message_info = {"info": "Test Message"}
90
- with patch.object(
91
- MessageUtil, "create_message", return_value=System(system=message_info)
92
- ) as mocked_create_message:
93
- self.branch.add_message(system=message_info)
94
- mocked_create_message.assert_called_once()
95
- self.assertEqual(len(self.branch.messages), 6)
96
-
97
- def test_chat_messages_without_sender(self):
98
- """Test chat_messages property without including sender information."""
99
- chat_messages = self.branch.chat_messages
100
- self.assertEqual(len(chat_messages), 5)
101
- self.assertNotIn("Sender", chat_messages[0]["content"])
102
-
103
- def test_chat_messages_with_sender(self):
104
- """Test retrieving chat messages with sender information included."""
105
- chat_messages_with_sender = self.branch.chat_messages_with_sender
106
- expected_prefixes = [
107
- "Sender system: ",
108
- "Sender user1: ",
109
- "Sender assistant: ",
110
- "Sender action_request: ",
111
- "Sender action_response: ",
112
- ]
113
-
114
- # Verify the number of messages returned matches the number added
115
- self.assertEqual(len(chat_messages_with_sender), len(self.test_messages))
116
-
117
- # Check each message for correct sender prefix in content
118
- for i, message_dict in enumerate(chat_messages_with_sender):
119
- self.assertTrue(
120
- message_dict["content"].startswith(expected_prefixes[i]),
121
- msg=f"Message content does not start with expected sender prefix. Found: {message_dict['content']}",
122
- )
123
-
124
- # Optionally, verify the content matches after removing the prefix
125
- for i, message_dict in enumerate(chat_messages_with_sender):
126
- prefix, content = message_dict["content"].split(": ", 1)
127
- self.assertEqual(
128
- content,
129
- self.test_messages[i]["content"],
130
- msg=f"Message content does not match after removing sender prefix. Found: {content}, Expected: {self.test_messages[i]['content']}",
131
- )
132
-
133
- def test_last_message(self):
134
- """Test retrieving the last message."""
135
- last_message = self.branch.last_message
136
- self.assertEqual(
137
- json.loads(last_message.iloc[0]["content"]),
138
- {"action_response": "Action response"},
139
- )
140
-
141
- def test_last_message_content(self):
142
- """Test retrieving the content of the last message."""
143
- content = self.branch.last_message_content
144
- self.assertEqual(content, {"action_response": "Action response"})
145
-
146
- def test_first_system_message(self):
147
- """Test retrieving the first 'system' message."""
148
- first_system_message = self.branch.first_system
149
- self.assertEqual(
150
- json.loads(first_system_message.iloc[0]["content"]),
151
- {"system_info": "System message"},
152
- )
153
-
154
- def test_last_response(self):
155
- """Test retrieving the last 'assistant' message."""
156
- last_response = self.branch.last_response
157
- self.assertEqual(last_response.iloc[0]["sender"], "action_response")
158
-
159
- def test_last_response_content(self):
160
- """Test extracting content of the last 'assistant' message."""
161
- content = self.branch.last_response_content
162
- self.assertEqual(content, {"action_response": "Action response"})
163
-
164
- def test_action_request_messages(self):
165
- """Test filtering messages sent by 'action_request'."""
166
- action_requests = self.branch.action_request
167
- self.assertTrue(all(action_requests["sender"] == "action_request"))
168
-
169
- def test_action_response_messages(self):
170
- """Test filtering messages sent by 'action_response'."""
171
- action_responses = self.branch.action_response
172
- self.assertTrue(all(action_responses["sender"] == "action_response"))
173
-
174
- def test_responses(self):
175
- """Test filtering of 'assistant' role messages."""
176
- responses = self.branch.responses
177
- # Verify that all returned messages have the 'assistant' role
178
- self.assertTrue(all(responses["role"] == "assistant"))
179
- # Optionally, check the count matches the expected number of 'assistant' messages
180
- expected_count = sum(
181
- 1 for msg in self.test_messages if msg["role"] == "assistant"
182
- )
183
- self.assertEqual(len(responses), expected_count)
184
-
185
- def test_assistant_responses(self):
186
- """Test filtering of 'assistant' messages excluding action requests/responses."""
187
- assistant_responses = self.branch.assistant_responses
188
- # Verify that no returned messages are from 'action_request' or 'action_response' senders
189
- self.assertTrue(all(assistant_responses["sender"] != "action_request"))
190
- self.assertTrue(all(assistant_responses["sender"] != "action_response"))
191
- # Verify that all returned messages have the 'assistant' role
192
- self.assertTrue(all(assistant_responses["role"] == "assistant"))
193
-
194
- def test_info(self):
195
- """Test summarization of message counts by role."""
196
- info = self.branch.info
197
- # Verify that the dictionary contains keys for each role and 'total'
198
- self.assertIn("assistant", info)
199
- self.assertIn("user", info)
200
- self.assertIn("system", info)
201
- self.assertIn("total", info)
202
- # Optionally, verify the counts match expected values
203
- for role in ["assistant", "user", "system"]:
204
- expected_count = sum(1 for msg in self.test_messages if msg["role"] == role)
205
- self.assertEqual(info[role], expected_count)
206
- self.assertEqual(info["total"], len(self.test_messages))
207
-
208
- def test_sender_info(self):
209
- """Test summarization of message counts by sender."""
210
- sender_info = self.branch.sender_info
211
- # Verify that the dictionary contains keys for each sender and the counts match
212
- for sender in set(msg["sender"] for msg in self.test_messages):
213
- expected_count = sum(
214
- 1 for msg in self.test_messages if msg["sender"] == sender
215
- )
216
- self.assertEqual(sender_info.get(sender, 0), expected_count)
217
-
218
- def test_describe(self):
219
- """Test detailed description of the branch."""
220
- description = self.branch.describe
221
- # Verify that the description contains expected keys
222
- self.assertIn("total_messages", description)
223
- self.assertIn("summary_by_role", description)
224
- self.assertIn("messages", description)
225
- # Optionally, verify the accuracy of the values
226
- self.assertEqual(description["total_messages"], len(self.test_messages))
227
- self.assertEqual(len(description["messages"]), min(5, len(self.test_messages)))
228
-
229
- @patch("lionagi.libs.ln_dataframe.read_csv")
230
- def test_from_csv(cls, mock_read_csv):
231
- # Define a mock return value for read_csv
232
- mock_messages_df = pd.DataFrame(
233
- {
234
- "node_id": ["1", "2"],
235
- "timestamp": [datetime(2021, 1, 1), datetime(2021, 1, 1)],
236
- "role": ["system", "user"],
237
- "sender": ["system", "user1"],
238
- "content": [
239
- json.dumps({"system_info": "System message"}),
240
- json.dumps({"instruction": "User message"}),
241
- ],
242
- }
243
- )
244
- mock_read_csv.return_value = mock_messages_df
245
-
246
- # Call the from_csv class method
247
- branch = BaseBranch.from_csv(filename="dummy.csv")
248
-
249
- # Verify that read_csv was called correctly
250
- mock_read_csv.assert_called_once_with("dummy.csv")
251
-
252
- # Verify that the branch instance contains the correct messages
253
- pd.testing.assert_frame_equal(branch.messages, mock_messages_df)
254
-
255
- @patch("lionagi.libs.ln_dataframe.read_json")
256
- def test_from_json(cls, mock_read_csv):
257
- # Define a mock return value for read_csv
258
- mock_messages_df = pd.DataFrame(
259
- {
260
- "node_id": ["1", "2"],
261
- "timestamp": [datetime(2021, 1, 1), datetime(2021, 1, 1)],
262
- "role": ["system", "user"],
263
- "sender": ["system", "user1"],
264
- "content": [
265
- json.dumps({"system_info": "System message"}),
266
- json.dumps({"instruction": "User message"}),
267
- ],
268
- }
269
- )
270
- mock_read_csv.return_value = mock_messages_df
271
-
272
- # Call the from_csv class method
273
- branch = BaseBranch.from_json_string(filename="dummy.json")
274
-
275
- # Verify that read_csv was called correctly
276
- mock_read_csv.assert_called_once_with("dummy.json")
277
-
278
- # Verify that the branch instance contains the correct messages
279
- pd.testing.assert_frame_equal(branch.messages, mock_messages_df)
280
-
281
- @patch(
282
- "lionagi.libs.sys_util.SysUtil.create_path",
283
- return_value="path/to/messages.csv",
284
- )
285
- @patch.object(pd.DataFrame, "to_csv")
286
- def test_to_csv(self, mock_to_csv, mock_create_path):
287
- self.branch.datalogger = MagicMock()
288
- self.branch.datalogger.persist_path = "data/logs/"
289
-
290
- self.branch.to_csv_file("messages.csv", verbose=False, clear=False)
291
-
292
- mock_create_path.assert_called_once_with(
293
- self.branch.datalogger.persist_path,
294
- "messages.csv",
295
- timestamp=True,
296
- dir_exist_ok=True,
297
- time_prefix=False,
298
- )
299
- mock_to_csv.assert_called_once_with("path/to/messages.csv")
300
-
301
- # Verify that messages are not cleared after exporting
302
- assert not self.branch.messages.empty
303
-
304
- @patch(
305
- "lionagi.libs.sys_util.SysUtil.create_path",
306
- return_value="path/to/messages.json",
307
- )
308
- @patch.object(pd.DataFrame, "to_json")
309
- def test_to_json(self, mock_to_json, mock_create_path):
310
- self.branch.datalogger = MagicMock()
311
- self.branch.datalogger.persist_path = "data/logs/"
312
-
313
- self.branch.to_json_file("messages.json", verbose=False, clear=False)
314
-
315
- mock_create_path.assert_called_once_with(
316
- self.branch.datalogger.persist_path,
317
- "messages.json",
318
- timestamp=True,
319
- dir_exist_ok=True,
320
- time_prefix=False,
321
- )
322
- mock_to_json.assert_called_once_with(
323
- "path/to/messages.json", orient="records", lines=True, date_format="iso"
324
- )
325
-
326
- # Verify that messages are not cleared after exporting
327
- assert not self.branch.messages.empty
328
-
329
- # def test_log_to_csv(self):
330
- # self.branch.log_to_csv('log.csv', verbose=False, clear=False)
331
-
332
- # self.branch.datalogger.to_csv_file.assert_called_once_with(filename='log.csv', dir_exist_ok=True, timestamp=True,
333
- # time_prefix=False, verbose=False, clear=False)
334
-
335
- # def test_log_to_json(self):
336
- # branch = BaseBranch()
337
- # branch.log_to_json('log.json', verbose=False, clear=False)
338
-
339
- # self.branch.datalogger.to_json_file.assert_called_once_with(filename='log.json', dir_exist_ok=True, timestamp=True,
340
- # time_prefix=False, verbose=False, clear=False)
341
-
342
- def test_remove_message(self):
343
- """Test removing a message from the branch based on its node ID."""
344
- initial_length = len(self.branch.messages)
345
- node_id_to_remove = "2"
346
- self.branch.remove_message(node_id_to_remove)
347
- final_length = len(self.branch.messages)
348
- self.assertNotIn(node_id_to_remove, self.branch.messages["node_id"].tolist())
349
- self.assertEqual(final_length, initial_length - 1)
350
-
351
- def test_update_message(self):
352
- """Test updating a specific message's column identified by node_id with a new value."""
353
- node_id_to_update = "2"
354
- new_value = "Updated content"
355
- self.branch.update_message(node_id_to_update, "content", new_value)
356
- updated_value = self.branch.messages.loc[
357
- self.branch.messages["node_id"] == node_id_to_update, "content"
358
- ].values[0]
359
- self.assertEqual(updated_value, new_value)
360
-
361
- # def test_change_first_system_message(self):
362
- # """Test updating the first system message with new content and/or sender."""
363
- # new_system_content = {"system_info": "Updated system message"}
364
- # self.branch.change_first_system_message(new_system_content['system_info'])
365
- # first_system_message_content = self.branch.messages.loc[self.branch.messages['role'] == 'system', 'content'].iloc[0]
366
- # self.assertIn(json.dumps({"system_info": "Updated system message"}), first_system_message_content)
367
-
368
- def test_rollback(self):
369
- """Test removing the last 'n' messages from the branch."""
370
- steps_to_remove = 2
371
- initial_length = len(self.branch.messages)
372
- self.branch.rollback(steps_to_remove)
373
- final_length = len(self.branch.messages)
374
- self.assertEqual(final_length, initial_length - steps_to_remove)
375
-
376
- def test_clear_messages(self):
377
- """Test clearing all messages from the branch."""
378
- self.branch.clear_messages()
379
- self.assertTrue(self.branch.messages.empty)
380
-
381
- def test_replace_keyword(self):
382
- """Test replacing occurrences of a specified keyword with a replacement string."""
383
- keyword = "Assistant response"
384
- replacement = "Helper feedback"
385
- self.branch.replace_keyword(keyword, replacement)
386
- self.assertTrue(
387
- any(replacement in message for message in self.branch.messages["content"])
388
- )
389
-
390
- def test_search_keywords(self):
391
- """Test filtering messages by a specified keyword or list of keywords."""
392
- keyword_to_search = "Assistant"
393
- filtered_messages = self.branch.search_keywords(keyword_to_search)
394
- self.assertTrue(
395
- all(
396
- keyword_to_search in content for content in filtered_messages["content"]
397
- )
398
- )
399
-
400
- def test_extend(self):
401
- """Test extending branch messages with additional messages."""
402
- additional_messages = pd.DataFrame(
403
- [
404
- {
405
- "node_id": "6",
406
- "timestamp": datetime(2021, 1, 1),
407
- "role": "user",
408
- "sender": "user2",
409
- "content": json.dumps({"instruction": "Another user message"}),
410
- }
411
- ]
412
- )
413
- initial_length = len(self.branch.messages)
414
- self.branch.extend(additional_messages)
415
- self.assertEqual(len(self.branch.messages), initial_length + 1)
416
-
417
- def test_filter_by(self):
418
- """Test filtering messages by various criteria."""
419
- # Filtering by role
420
- filtered_messages = self.branch.filter_by(role="user")
421
- self.assertTrue(
422
- all(msg["role"] == "user" for _, msg in filtered_messages.iterrows())
423
- )
424
-
425
-
426
- if __name__ == "__main__":
427
- unittest.main()
1
+ # from lionagi.core.branch.base_branch import BaseBranch
2
+ # from lionagi.core.branch.util import MessageUtil
3
+ # from lionagi.core.messages.schema import System
4
+ # from lionagi.core.schema import DataLogger
5
+
6
+
7
+ # import unittest
8
+ # from unittest.mock import patch, MagicMock
9
+ # import pandas as pd
10
+ # from datetime import datetime
11
+ # import json
12
+
13
+
14
+ # class TestBaseBranch(unittest.TestCase):
15
+
16
+ # def setUp(self):
17
+ # # Patching DataLogger to avoid filesystem interactions during tests
18
+ # self.patcher = patch(
19
+ # "lionagi.core.branch.base_branch.DataLogger", autospec=True
20
+ # )
21
+ # self.MockDataLogger = self.patcher.start()
22
+ # self.addCleanup(self.patcher.stop)
23
+
24
+ # self.branch = BaseBranch()
25
+ # # Example messages to populate the DataFrame for testing
26
+ # self.test_messages = [
27
+ # {
28
+ # "node_id": "1",
29
+ # "timestamp": "2021-01-01 00:00:00",
30
+ # "role": "system",
31
+ # "sender": "system",
32
+ # "content": json.dumps({"system_info": "System message"}),
33
+ # },
34
+ # {
35
+ # "node_id": "2",
36
+ # "timestamp": "2021-01-01 00:01:00",
37
+ # "role": "user",
38
+ # "sender": "user1",
39
+ # "content": json.dumps({"instruction": "User message"}),
40
+ # },
41
+ # {
42
+ # "node_id": "3",
43
+ # "timestamp": "2021-01-01 00:02:00",
44
+ # "role": "assistant",
45
+ # "sender": "assistant",
46
+ # "content": json.dumps({"response": "Assistant response"}),
47
+ # },
48
+ # {
49
+ # "node_id": "4",
50
+ # "timestamp": "2021-01-01 00:03:00",
51
+ # "role": "assistant",
52
+ # "sender": "action_request",
53
+ # "content": json.dumps({"action_request": "Action request"}),
54
+ # },
55
+ # {
56
+ # "node_id": "5",
57
+ # "timestamp": "2021-01-01 00:04:00",
58
+ # "role": "assistant",
59
+ # "sender": "action_response",
60
+ # "content": json.dumps({"action_response": "Action response"}),
61
+ # },
62
+ # ]
63
+ # self.branch = BaseBranch(messages=pd.DataFrame(self.test_messages))
64
+
65
+ # def test_init_with_empty_messages(self):
66
+ # """Test __init__ method with no messages provided."""
67
+ # branch = BaseBranch()
68
+ # self.assertTrue(branch.messages.empty)
69
+
70
+ # def test_init_with_given_messages(self):
71
+ # """Test __init__ method with provided messages DataFrame."""
72
+ # messages = pd.DataFrame(
73
+ # [
74
+ # [
75
+ # "0",
76
+ # datetime(2024, 1, 1),
77
+ # "system",
78
+ # "system",
79
+ # json.dumps({"system_info": "Hi"}),
80
+ # ]
81
+ # ],
82
+ # columns=["node_id", "timestamp", "role", "sender", "content"],
83
+ # )
84
+ # branch = BaseBranch(messages=messages)
85
+ # self.assertFalse(branch.messages.empty)
86
+
87
+ # def test_add_message(self):
88
+ # """Test adding a message."""
89
+ # message_info = {"info": "Test Message"}
90
+ # with patch.object(
91
+ # MessageUtil, "create_message", return_value=System(system=message_info)
92
+ # ) as mocked_create_message:
93
+ # self.branch.add_message(system=message_info)
94
+ # mocked_create_message.assert_called_once()
95
+ # self.assertEqual(len(self.branch.messages), 6)
96
+
97
+ # def test_chat_messages_without_sender(self):
98
+ # """Test chat_messages property without including sender information."""
99
+ # chat_messages = self.branch.chat_messages
100
+ # self.assertEqual(len(chat_messages), 5)
101
+ # self.assertNotIn("Sender", chat_messages[0]["content"])
102
+
103
+ # def test_chat_messages_with_sender(self):
104
+ # """Test retrieving chat messages with sender information included."""
105
+ # chat_messages_with_sender = self.branch.chat_messages_with_sender
106
+ # expected_prefixes = [
107
+ # "Sender system: ",
108
+ # "Sender user1: ",
109
+ # "Sender assistant: ",
110
+ # "Sender action_request: ",
111
+ # "Sender action_response: ",
112
+ # ]
113
+
114
+ # # Verify the number of messages returned matches the number added
115
+ # self.assertEqual(len(chat_messages_with_sender), len(self.test_messages))
116
+
117
+ # # Check each message for correct sender prefix in content
118
+ # for i, message_dict in enumerate(chat_messages_with_sender):
119
+ # self.assertTrue(
120
+ # message_dict["content"].startswith(expected_prefixes[i]),
121
+ # msg=f"Message content does not start with expected sender prefix. Found: {message_dict['content']}",
122
+ # )
123
+
124
+ # # Optionally, verify the content matches after removing the prefix
125
+ # for i, message_dict in enumerate(chat_messages_with_sender):
126
+ # prefix, content = message_dict["content"].split(": ", 1)
127
+ # self.assertEqual(
128
+ # content,
129
+ # self.test_messages[i]["content"],
130
+ # msg=f"Message content does not match after removing sender prefix. Found: {content}, Expected: {self.test_messages[i]['content']}",
131
+ # )
132
+
133
+ # def test_last_message(self):
134
+ # """Test retrieving the last message."""
135
+ # last_message = self.branch.last_message
136
+ # self.assertEqual(
137
+ # json.loads(last_message.iloc[0]["content"]),
138
+ # {"action_response": "Action response"},
139
+ # )
140
+
141
+ # def test_last_message_content(self):
142
+ # """Test retrieving the content of the last message."""
143
+ # content = self.branch.last_message_content
144
+ # self.assertEqual(content, {"action_response": "Action response"})
145
+
146
+ # def test_first_system_message(self):
147
+ # """Test retrieving the first 'system' message."""
148
+ # first_system_message = self.branch.first_system
149
+ # self.assertEqual(
150
+ # json.loads(first_system_message.iloc[0]["content"]),
151
+ # {"system_info": "System message"},
152
+ # )
153
+
154
+ # def test_last_response(self):
155
+ # """Test retrieving the last 'assistant' message."""
156
+ # last_response = self.branch.last_response
157
+ # self.assertEqual(last_response.iloc[0]["sender"], "action_response")
158
+
159
+ # def test_last_response_content(self):
160
+ # """Test extracting content of the last 'assistant' message."""
161
+ # content = self.branch.last_response_content
162
+ # self.assertEqual(content, {"action_response": "Action response"})
163
+
164
+ # def test_action_request_messages(self):
165
+ # """Test filtering messages sent by 'action_request'."""
166
+ # action_requests = self.branch.action_request
167
+ # self.assertTrue(all(action_requests["sender"] == "action_request"))
168
+
169
+ # def test_action_response_messages(self):
170
+ # """Test filtering messages sent by 'action_response'."""
171
+ # action_responses = self.branch.action_response
172
+ # self.assertTrue(all(action_responses["sender"] == "action_response"))
173
+
174
+ # def test_responses(self):
175
+ # """Test filtering of 'assistant' role messages."""
176
+ # responses = self.branch.responses
177
+ # # Verify that all returned messages have the 'assistant' role
178
+ # self.assertTrue(all(responses["role"] == "assistant"))
179
+ # # Optionally, check the count matches the expected number of 'assistant' messages
180
+ # expected_count = sum(
181
+ # 1 for msg in self.test_messages if msg["role"] == "assistant"
182
+ # )
183
+ # self.assertEqual(len(responses), expected_count)
184
+
185
+ # def test_assistant_responses(self):
186
+ # """Test filtering of 'assistant' messages excluding action requests/responses."""
187
+ # assistant_responses = self.branch.assistant_responses
188
+ # # Verify that no returned messages are from 'action_request' or 'action_response' senders
189
+ # self.assertTrue(all(assistant_responses["sender"] != "action_request"))
190
+ # self.assertTrue(all(assistant_responses["sender"] != "action_response"))
191
+ # # Verify that all returned messages have the 'assistant' role
192
+ # self.assertTrue(all(assistant_responses["role"] == "assistant"))
193
+
194
+ # def test_info(self):
195
+ # """Test summarization of message counts by role."""
196
+ # info = self.branch.info
197
+ # # Verify that the dictionary contains keys for each role and 'total'
198
+ # self.assertIn("assistant", info)
199
+ # self.assertIn("user", info)
200
+ # self.assertIn("system", info)
201
+ # self.assertIn("total", info)
202
+ # # Optionally, verify the counts match expected values
203
+ # for role in ["assistant", "user", "system"]:
204
+ # expected_count = sum(1 for msg in self.test_messages if msg["role"] == role)
205
+ # self.assertEqual(info[role], expected_count)
206
+ # self.assertEqual(info["total"], len(self.test_messages))
207
+
208
+ # def test_sender_info(self):
209
+ # """Test summarization of message counts by sender."""
210
+ # sender_info = self.branch.sender_info
211
+ # # Verify that the dictionary contains keys for each sender and the counts match
212
+ # for sender in set(msg["sender"] for msg in self.test_messages):
213
+ # expected_count = sum(
214
+ # 1 for msg in self.test_messages if msg["sender"] == sender
215
+ # )
216
+ # self.assertEqual(sender_info.get(sender, 0), expected_count)
217
+
218
+ # def test_describe(self):
219
+ # """Test detailed description of the branch."""
220
+ # description = self.branch.describe
221
+ # # Verify that the description contains expected keys
222
+ # self.assertIn("total_messages", description)
223
+ # self.assertIn("summary_by_role", description)
224
+ # self.assertIn("messages", description)
225
+ # # Optionally, verify the accuracy of the values
226
+ # self.assertEqual(description["total_messages"], len(self.test_messages))
227
+ # self.assertEqual(len(description["messages"]), min(5, len(self.test_messages)))
228
+
229
+ # @patch("lionagi.libs.ln_dataframe.read_csv")
230
+ # def test_from_csv(cls, mock_read_csv):
231
+ # # Define a mock return value for read_csv
232
+ # mock_messages_df = pd.DataFrame(
233
+ # {
234
+ # "node_id": ["1", "2"],
235
+ # "timestamp": [datetime(2021, 1, 1), datetime(2021, 1, 1)],
236
+ # "role": ["system", "user"],
237
+ # "sender": ["system", "user1"],
238
+ # "content": [
239
+ # json.dumps({"system_info": "System message"}),
240
+ # json.dumps({"instruction": "User message"}),
241
+ # ],
242
+ # }
243
+ # )
244
+ # mock_read_csv.return_value = mock_messages_df
245
+
246
+ # # Call the from_csv class method
247
+ # branch = BaseBranch.from_csv(filename="dummy.csv")
248
+
249
+ # # Verify that read_csv was called correctly
250
+ # mock_read_csv.assert_called_once_with("dummy.csv")
251
+
252
+ # # Verify that the branch instance contains the correct messages
253
+ # pd.testing.assert_frame_equal(branch.messages, mock_messages_df)
254
+
255
+ # @patch("lionagi.libs.ln_dataframe.read_json")
256
+ # def test_from_json(cls, mock_read_csv):
257
+ # # Define a mock return value for read_csv
258
+ # mock_messages_df = pd.DataFrame(
259
+ # {
260
+ # "node_id": ["1", "2"],
261
+ # "timestamp": [datetime(2021, 1, 1), datetime(2021, 1, 1)],
262
+ # "role": ["system", "user"],
263
+ # "sender": ["system", "user1"],
264
+ # "content": [
265
+ # json.dumps({"system_info": "System message"}),
266
+ # json.dumps({"instruction": "User message"}),
267
+ # ],
268
+ # }
269
+ # )
270
+ # mock_read_csv.return_value = mock_messages_df
271
+
272
+ # # Call the from_csv class method
273
+ # branch = BaseBranch.from_json_string(filename="dummy.json")
274
+
275
+ # # Verify that read_csv was called correctly
276
+ # mock_read_csv.assert_called_once_with("dummy.json")
277
+
278
+ # # Verify that the branch instance contains the correct messages
279
+ # pd.testing.assert_frame_equal(branch.messages, mock_messages_df)
280
+
281
+ # @patch(
282
+ # "lionagi.libs.sys_util.SysUtil.create_path",
283
+ # return_value="path/to/messages.csv",
284
+ # )
285
+ # @patch.object(pd.DataFrame, "to_csv")
286
+ # def test_to_csv(self, mock_to_csv, mock_create_path):
287
+ # self.branch.datalogger = MagicMock()
288
+ # self.branch.datalogger.persist_path = "data/logs/"
289
+
290
+ # self.branch.to_csv_file("messages.csv", verbose=False, clear=False)
291
+
292
+ # mock_create_path.assert_called_once_with(
293
+ # self.branch.datalogger.persist_path,
294
+ # "messages.csv",
295
+ # timestamp=True,
296
+ # dir_exist_ok=True,
297
+ # time_prefix=False,
298
+ # )
299
+ # mock_to_csv.assert_called_once_with("path/to/messages.csv")
300
+
301
+ # # Verify that messages are not cleared after exporting
302
+ # assert not self.branch.messages.empty
303
+
304
+ # @patch(
305
+ # "lionagi.libs.sys_util.SysUtil.create_path",
306
+ # return_value="path/to/messages.json",
307
+ # )
308
+ # @patch.object(pd.DataFrame, "to_json")
309
+ # def test_to_json(self, mock_to_json, mock_create_path):
310
+ # self.branch.datalogger = MagicMock()
311
+ # self.branch.datalogger.persist_path = "data/logs/"
312
+
313
+ # self.branch.to_json_file("messages.json", verbose=False, clear=False)
314
+
315
+ # mock_create_path.assert_called_once_with(
316
+ # self.branch.datalogger.persist_path,
317
+ # "messages.json",
318
+ # timestamp=True,
319
+ # dir_exist_ok=True,
320
+ # time_prefix=False,
321
+ # )
322
+ # mock_to_json.assert_called_once_with(
323
+ # "path/to/messages.json", orient="records", lines=True, date_format="iso"
324
+ # )
325
+
326
+ # # Verify that messages are not cleared after exporting
327
+ # assert not self.branch.messages.empty
328
+
329
+ # # def test_log_to_csv(self):
330
+ # # self.branch.log_to_csv('log.csv', verbose=False, clear=False)
331
+
332
+ # # self.branch.datalogger.to_csv_file.assert_called_once_with(filename='log.csv', dir_exist_ok=True, timestamp=True,
333
+ # # time_prefix=False, verbose=False, clear=False)
334
+
335
+ # # def test_log_to_json(self):
336
+ # # branch = BaseBranch()
337
+ # # branch.log_to_json('log.json', verbose=False, clear=False)
338
+
339
+ # # self.branch.datalogger.to_json_file.assert_called_once_with(filename='log.json', dir_exist_ok=True, timestamp=True,
340
+ # # time_prefix=False, verbose=False, clear=False)
341
+
342
+ # def test_remove_message(self):
343
+ # """Test removing a message from the branch based on its node ID."""
344
+ # initial_length = len(self.branch.messages)
345
+ # node_id_to_remove = "2"
346
+ # self.branch.remove_message(node_id_to_remove)
347
+ # final_length = len(self.branch.messages)
348
+ # self.assertNotIn(node_id_to_remove, self.branch.messages["node_id"].tolist())
349
+ # self.assertEqual(final_length, initial_length - 1)
350
+
351
+ # def test_update_message(self):
352
+ # """Test updating a specific message's column identified by node_id with a new value."""
353
+ # node_id_to_update = "2"
354
+ # new_value = "Updated content"
355
+ # self.branch.update_message(node_id_to_update, "content", new_value)
356
+ # updated_value = self.branch.messages.loc[
357
+ # self.branch.messages["node_id"] == node_id_to_update, "content"
358
+ # ].values[0]
359
+ # self.assertEqual(updated_value, new_value)
360
+
361
+ # # def test_change_first_system_message(self):
362
+ # # """Test updating the first system message with new content and/or sender."""
363
+ # # new_system_content = {"system_info": "Updated system message"}
364
+ # # self.branch.change_first_system_message(new_system_content['system_info'])
365
+ # # first_system_message_content = self.branch.messages.loc[self.branch.messages['role'] == 'system', 'content'].iloc[0]
366
+ # # self.assertIn(json.dumps({"system_info": "Updated system message"}), first_system_message_content)
367
+
368
+ # def test_rollback(self):
369
+ # """Test removing the last 'n' messages from the branch."""
370
+ # steps_to_remove = 2
371
+ # initial_length = len(self.branch.messages)
372
+ # self.branch.rollback(steps_to_remove)
373
+ # final_length = len(self.branch.messages)
374
+ # self.assertEqual(final_length, initial_length - steps_to_remove)
375
+
376
+ # def test_clear_messages(self):
377
+ # """Test clearing all messages from the branch."""
378
+ # self.branch.clear_messages()
379
+ # self.assertTrue(self.branch.messages.empty)
380
+
381
+ # def test_replace_keyword(self):
382
+ # """Test replacing occurrences of a specified keyword with a replacement string."""
383
+ # keyword = "Assistant response"
384
+ # replacement = "Helper feedback"
385
+ # self.branch.replace_keyword(keyword, replacement)
386
+ # self.assertTrue(
387
+ # any(replacement in message for message in self.branch.messages["content"])
388
+ # )
389
+
390
+ # def test_search_keywords(self):
391
+ # """Test filtering messages by a specified keyword or list of keywords."""
392
+ # keyword_to_search = "Assistant"
393
+ # filtered_messages = self.branch.search_keywords(keyword_to_search)
394
+ # self.assertTrue(
395
+ # all(
396
+ # keyword_to_search in content for content in filtered_messages["content"]
397
+ # )
398
+ # )
399
+
400
+ # def test_extend(self):
401
+ # """Test extending branch messages with additional messages."""
402
+ # additional_messages = pd.DataFrame(
403
+ # [
404
+ # {
405
+ # "node_id": "6",
406
+ # "timestamp": datetime(2021, 1, 1),
407
+ # "role": "user",
408
+ # "sender": "user2",
409
+ # "content": json.dumps({"instruction": "Another user message"}),
410
+ # }
411
+ # ]
412
+ # )
413
+ # initial_length = len(self.branch.messages)
414
+ # self.branch.extend(additional_messages)
415
+ # self.assertEqual(len(self.branch.messages), initial_length + 1)
416
+
417
+ # def test_filter_by(self):
418
+ # """Test filtering messages by various criteria."""
419
+ # # Filtering by role
420
+ # filtered_messages = self.branch.filter_by(role="user")
421
+ # self.assertTrue(
422
+ # all(msg["role"] == "user" for _, msg in filtered_messages.iterrows())
423
+ # )
424
+
425
+
426
+ # if __name__ == "__main__":
427
+ # unittest.main()