lionagi 0.0.305__py3-none-any.whl → 0.0.307__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (84) hide show
  1. lionagi/__init__.py +2 -5
  2. lionagi/core/__init__.py +7 -4
  3. lionagi/core/agent/__init__.py +3 -0
  4. lionagi/core/agent/base_agent.py +46 -0
  5. lionagi/core/branch/__init__.py +4 -0
  6. lionagi/core/branch/base/__init__.py +0 -0
  7. lionagi/core/branch/base_branch.py +100 -78
  8. lionagi/core/branch/branch.py +22 -34
  9. lionagi/core/branch/branch_flow_mixin.py +3 -7
  10. lionagi/core/branch/executable_branch.py +192 -0
  11. lionagi/core/branch/util.py +77 -162
  12. lionagi/core/direct/__init__.py +13 -0
  13. lionagi/core/direct/parallel_predict.py +127 -0
  14. lionagi/core/direct/parallel_react.py +0 -0
  15. lionagi/core/direct/parallel_score.py +0 -0
  16. lionagi/core/direct/parallel_select.py +0 -0
  17. lionagi/core/direct/parallel_sentiment.py +0 -0
  18. lionagi/core/direct/predict.py +174 -0
  19. lionagi/core/direct/react.py +33 -0
  20. lionagi/core/direct/score.py +163 -0
  21. lionagi/core/direct/select.py +144 -0
  22. lionagi/core/direct/sentiment.py +51 -0
  23. lionagi/core/direct/utils.py +83 -0
  24. lionagi/core/flow/__init__.py +0 -3
  25. lionagi/core/flow/monoflow/{mono_react.py → ReAct.py} +52 -9
  26. lionagi/core/flow/monoflow/__init__.py +9 -0
  27. lionagi/core/flow/monoflow/{mono_chat.py → chat.py} +11 -11
  28. lionagi/core/flow/monoflow/{mono_chat_mixin.py → chat_mixin.py} +33 -27
  29. lionagi/core/flow/monoflow/{mono_followup.py → followup.py} +7 -6
  30. lionagi/core/flow/polyflow/__init__.py +1 -0
  31. lionagi/core/flow/polyflow/{polychat.py → chat.py} +15 -3
  32. lionagi/core/mail/__init__.py +8 -0
  33. lionagi/core/mail/mail_manager.py +88 -40
  34. lionagi/core/mail/schema.py +32 -6
  35. lionagi/core/messages/__init__.py +3 -0
  36. lionagi/core/messages/schema.py +56 -25
  37. lionagi/core/prompt/__init__.py +0 -0
  38. lionagi/core/prompt/prompt_template.py +0 -0
  39. lionagi/core/schema/__init__.py +7 -5
  40. lionagi/core/schema/action_node.py +29 -0
  41. lionagi/core/schema/base_mixin.py +56 -59
  42. lionagi/core/schema/base_node.py +35 -38
  43. lionagi/core/schema/condition.py +24 -0
  44. lionagi/core/schema/data_logger.py +98 -98
  45. lionagi/core/schema/data_node.py +19 -19
  46. lionagi/core/schema/prompt_template.py +0 -0
  47. lionagi/core/schema/structure.py +293 -190
  48. lionagi/core/session/__init__.py +1 -3
  49. lionagi/core/session/session.py +196 -214
  50. lionagi/core/tool/tool_manager.py +95 -103
  51. lionagi/integrations/__init__.py +1 -3
  52. lionagi/integrations/bridge/langchain_/documents.py +17 -18
  53. lionagi/integrations/bridge/langchain_/langchain_bridge.py +14 -14
  54. lionagi/integrations/bridge/llamaindex_/llama_index_bridge.py +22 -22
  55. lionagi/integrations/bridge/llamaindex_/node_parser.py +12 -12
  56. lionagi/integrations/bridge/llamaindex_/reader.py +11 -11
  57. lionagi/integrations/bridge/llamaindex_/textnode.py +7 -7
  58. lionagi/integrations/config/openrouter_configs.py +0 -1
  59. lionagi/integrations/provider/oai.py +26 -26
  60. lionagi/integrations/provider/services.py +38 -38
  61. lionagi/libs/__init__.py +34 -1
  62. lionagi/libs/ln_api.py +211 -221
  63. lionagi/libs/ln_async.py +53 -60
  64. lionagi/libs/ln_convert.py +118 -120
  65. lionagi/libs/ln_dataframe.py +32 -33
  66. lionagi/libs/ln_func_call.py +334 -342
  67. lionagi/libs/ln_nested.py +99 -107
  68. lionagi/libs/ln_parse.py +175 -158
  69. lionagi/libs/sys_util.py +52 -52
  70. lionagi/tests/test_core/test_base_branch.py +427 -427
  71. lionagi/tests/test_core/test_branch.py +292 -292
  72. lionagi/tests/test_core/test_mail_manager.py +57 -57
  73. lionagi/tests/test_core/test_session.py +254 -266
  74. lionagi/tests/test_core/test_session_base_util.py +299 -300
  75. lionagi/tests/test_core/test_tool_manager.py +70 -74
  76. lionagi/tests/test_libs/test_nested.py +2 -7
  77. lionagi/tests/test_libs/test_parse.py +2 -2
  78. lionagi/version.py +1 -1
  79. {lionagi-0.0.305.dist-info → lionagi-0.0.307.dist-info}/METADATA +4 -2
  80. lionagi-0.0.307.dist-info/RECORD +115 -0
  81. lionagi-0.0.305.dist-info/RECORD +0 -94
  82. {lionagi-0.0.305.dist-info → lionagi-0.0.307.dist-info}/LICENSE +0 -0
  83. {lionagi-0.0.305.dist-info → lionagi-0.0.307.dist-info}/WHEEL +0 -0
  84. {lionagi-0.0.305.dist-info → lionagi-0.0.307.dist-info}/top_level.txt +0 -0
@@ -1,427 +1,427 @@
1
- from lionagi.core.branch.base_branch import BaseBranch
2
- from lionagi.core.branch.util import MessageUtil
3
- from lionagi.core.messages.schema import System
4
- from lionagi.core.schema import DataLogger
5
-
6
-
7
- import unittest
8
- from unittest.mock import patch, MagicMock
9
- import pandas as pd
10
- from datetime import datetime
11
- import json
12
-
13
-
14
- class TestBaseBranch(unittest.TestCase):
15
-
16
- def setUp(self):
17
- # Patching DataLogger to avoid filesystem interactions during tests
18
- self.patcher = patch(
19
- "lionagi.core.branch.base_branch.DataLogger", autospec=True
20
- )
21
- self.MockDataLogger = self.patcher.start()
22
- self.addCleanup(self.patcher.stop)
23
-
24
- self.branch = BaseBranch()
25
- # Example messages to populate the DataFrame for testing
26
- self.test_messages = [
27
- {
28
- "node_id": "1",
29
- "timestamp": "2021-01-01 00:00:00",
30
- "role": "system",
31
- "sender": "system",
32
- "content": json.dumps({"system_info": "System message"}),
33
- },
34
- {
35
- "node_id": "2",
36
- "timestamp": "2021-01-01 00:01:00",
37
- "role": "user",
38
- "sender": "user1",
39
- "content": json.dumps({"instruction": "User message"}),
40
- },
41
- {
42
- "node_id": "3",
43
- "timestamp": "2021-01-01 00:02:00",
44
- "role": "assistant",
45
- "sender": "assistant",
46
- "content": json.dumps({"response": "Assistant response"}),
47
- },
48
- {
49
- "node_id": "4",
50
- "timestamp": "2021-01-01 00:03:00",
51
- "role": "assistant",
52
- "sender": "action_request",
53
- "content": json.dumps({"action_request": "Action request"}),
54
- },
55
- {
56
- "node_id": "5",
57
- "timestamp": "2021-01-01 00:04:00",
58
- "role": "assistant",
59
- "sender": "action_response",
60
- "content": json.dumps({"action_response": "Action response"}),
61
- },
62
- ]
63
- self.branch = BaseBranch(messages=pd.DataFrame(self.test_messages))
64
-
65
- def test_init_with_empty_messages(self):
66
- """Test __init__ method with no messages provided."""
67
- branch = BaseBranch()
68
- self.assertTrue(branch.messages.empty)
69
-
70
- def test_init_with_given_messages(self):
71
- """Test __init__ method with provided messages DataFrame."""
72
- messages = pd.DataFrame(
73
- [
74
- [
75
- "0",
76
- datetime(2024, 1, 1),
77
- "system",
78
- "system",
79
- json.dumps({"system_info": "Hi"}),
80
- ]
81
- ],
82
- columns=["node_id", "timestamp", "role", "sender", "content"],
83
- )
84
- branch = BaseBranch(messages=messages)
85
- self.assertFalse(branch.messages.empty)
86
-
87
- def test_add_message(self):
88
- """Test adding a message."""
89
- message_info = {"info": "Test Message"}
90
- with patch.object(
91
- MessageUtil, "create_message", return_value=System(system=message_info)
92
- ) as mocked_create_message:
93
- self.branch.add_message(system=message_info)
94
- mocked_create_message.assert_called_once()
95
- self.assertEqual(len(self.branch.messages), 6)
96
-
97
- def test_chat_messages_without_sender(self):
98
- """Test chat_messages property without including sender information."""
99
- chat_messages = self.branch.chat_messages
100
- self.assertEqual(len(chat_messages), 5)
101
- self.assertNotIn("Sender", chat_messages[0]["content"])
102
-
103
- def test_chat_messages_with_sender(self):
104
- """Test retrieving chat messages with sender information included."""
105
- chat_messages_with_sender = self.branch.chat_messages_with_sender
106
- expected_prefixes = [
107
- "Sender system: ",
108
- "Sender user1: ",
109
- "Sender assistant: ",
110
- "Sender action_request: ",
111
- "Sender action_response: ",
112
- ]
113
-
114
- # Verify the number of messages returned matches the number added
115
- self.assertEqual(len(chat_messages_with_sender), len(self.test_messages))
116
-
117
- # Check each message for correct sender prefix in content
118
- for i, message_dict in enumerate(chat_messages_with_sender):
119
- self.assertTrue(
120
- message_dict["content"].startswith(expected_prefixes[i]),
121
- msg=f"Message content does not start with expected sender prefix. Found: {message_dict['content']}",
122
- )
123
-
124
- # Optionally, verify the content matches after removing the prefix
125
- for i, message_dict in enumerate(chat_messages_with_sender):
126
- prefix, content = message_dict["content"].split(": ", 1)
127
- self.assertEqual(
128
- content,
129
- self.test_messages[i]["content"],
130
- msg=f"Message content does not match after removing sender prefix. Found: {content}, Expected: {self.test_messages[i]['content']}",
131
- )
132
-
133
- def test_last_message(self):
134
- """Test retrieving the last message."""
135
- last_message = self.branch.last_message
136
- self.assertEqual(
137
- json.loads(last_message.iloc[0]["content"]),
138
- {"action_response": "Action response"},
139
- )
140
-
141
- def test_last_message_content(self):
142
- """Test retrieving the content of the last message."""
143
- content = self.branch.last_message_content
144
- self.assertEqual(content, {"action_response": "Action response"})
145
-
146
- def test_first_system_message(self):
147
- """Test retrieving the first 'system' message."""
148
- first_system_message = self.branch.first_system
149
- self.assertEqual(
150
- json.loads(first_system_message.iloc[0]["content"]),
151
- {"system_info": "System message"},
152
- )
153
-
154
- def test_last_response(self):
155
- """Test retrieving the last 'assistant' message."""
156
- last_response = self.branch.last_response
157
- self.assertEqual(last_response.iloc[0]["sender"], "action_response")
158
-
159
- def test_last_response_content(self):
160
- """Test extracting content of the last 'assistant' message."""
161
- content = self.branch.last_response_content
162
- self.assertEqual(content, {"action_response": "Action response"})
163
-
164
- def test_action_request_messages(self):
165
- """Test filtering messages sent by 'action_request'."""
166
- action_requests = self.branch.action_request
167
- self.assertTrue(all(action_requests["sender"] == "action_request"))
168
-
169
- def test_action_response_messages(self):
170
- """Test filtering messages sent by 'action_response'."""
171
- action_responses = self.branch.action_response
172
- self.assertTrue(all(action_responses["sender"] == "action_response"))
173
-
174
- def test_responses(self):
175
- """Test filtering of 'assistant' role messages."""
176
- responses = self.branch.responses
177
- # Verify that all returned messages have the 'assistant' role
178
- self.assertTrue(all(responses["role"] == "assistant"))
179
- # Optionally, check the count matches the expected number of 'assistant' messages
180
- expected_count = sum(
181
- 1 for msg in self.test_messages if msg["role"] == "assistant"
182
- )
183
- self.assertEqual(len(responses), expected_count)
184
-
185
- def test_assistant_responses(self):
186
- """Test filtering of 'assistant' messages excluding action requests/responses."""
187
- assistant_responses = self.branch.assistant_responses
188
- # Verify that no returned messages are from 'action_request' or 'action_response' senders
189
- self.assertTrue(all(assistant_responses["sender"] != "action_request"))
190
- self.assertTrue(all(assistant_responses["sender"] != "action_response"))
191
- # Verify that all returned messages have the 'assistant' role
192
- self.assertTrue(all(assistant_responses["role"] == "assistant"))
193
-
194
- def test_info(self):
195
- """Test summarization of message counts by role."""
196
- info = self.branch.info
197
- # Verify that the dictionary contains keys for each role and 'total'
198
- self.assertIn("assistant", info)
199
- self.assertIn("user", info)
200
- self.assertIn("system", info)
201
- self.assertIn("total", info)
202
- # Optionally, verify the counts match expected values
203
- for role in ["assistant", "user", "system"]:
204
- expected_count = sum(1 for msg in self.test_messages if msg["role"] == role)
205
- self.assertEqual(info[role], expected_count)
206
- self.assertEqual(info["total"], len(self.test_messages))
207
-
208
- def test_sender_info(self):
209
- """Test summarization of message counts by sender."""
210
- sender_info = self.branch.sender_info
211
- # Verify that the dictionary contains keys for each sender and the counts match
212
- for sender in set(msg["sender"] for msg in self.test_messages):
213
- expected_count = sum(
214
- 1 for msg in self.test_messages if msg["sender"] == sender
215
- )
216
- self.assertEqual(sender_info.get(sender, 0), expected_count)
217
-
218
- def test_describe(self):
219
- """Test detailed description of the branch."""
220
- description = self.branch.describe
221
- # Verify that the description contains expected keys
222
- self.assertIn("total_messages", description)
223
- self.assertIn("summary_by_role", description)
224
- self.assertIn("messages", description)
225
- # Optionally, verify the accuracy of the values
226
- self.assertEqual(description["total_messages"], len(self.test_messages))
227
- self.assertEqual(len(description["messages"]), min(5, len(self.test_messages)))
228
-
229
- @patch("lionagi.libs.ln_dataframe.read_csv")
230
- def test_from_csv(cls, mock_read_csv):
231
- # Define a mock return value for read_csv
232
- mock_messages_df = pd.DataFrame(
233
- {
234
- "node_id": ["1", "2"],
235
- "timestamp": [datetime(2021, 1, 1), datetime(2021, 1, 1)],
236
- "role": ["system", "user"],
237
- "sender": ["system", "user1"],
238
- "content": [
239
- json.dumps({"system_info": "System message"}),
240
- json.dumps({"instruction": "User message"}),
241
- ],
242
- }
243
- )
244
- mock_read_csv.return_value = mock_messages_df
245
-
246
- # Call the from_csv class method
247
- branch = BaseBranch.from_csv(filename="dummy.csv")
248
-
249
- # Verify that read_csv was called correctly
250
- mock_read_csv.assert_called_once_with("dummy.csv")
251
-
252
- # Verify that the branch instance contains the correct messages
253
- pd.testing.assert_frame_equal(branch.messages, mock_messages_df)
254
-
255
- @patch("lionagi.libs.ln_dataframe.read_json")
256
- def test_from_json(cls, mock_read_csv):
257
- # Define a mock return value for read_csv
258
- mock_messages_df = pd.DataFrame(
259
- {
260
- "node_id": ["1", "2"],
261
- "timestamp": [datetime(2021, 1, 1), datetime(2021, 1, 1)],
262
- "role": ["system", "user"],
263
- "sender": ["system", "user1"],
264
- "content": [
265
- json.dumps({"system_info": "System message"}),
266
- json.dumps({"instruction": "User message"}),
267
- ],
268
- }
269
- )
270
- mock_read_csv.return_value = mock_messages_df
271
-
272
- # Call the from_csv class method
273
- branch = BaseBranch.from_json_string(filename="dummy.json")
274
-
275
- # Verify that read_csv was called correctly
276
- mock_read_csv.assert_called_once_with("dummy.json")
277
-
278
- # Verify that the branch instance contains the correct messages
279
- pd.testing.assert_frame_equal(branch.messages, mock_messages_df)
280
-
281
- @patch(
282
- "lionagi.libs.sys_util.SysUtil.create_path",
283
- return_value="path/to/messages.csv",
284
- )
285
- @patch.object(pd.DataFrame, "to_csv")
286
- def test_to_csv(self, mock_to_csv, mock_create_path):
287
- self.branch.datalogger = MagicMock()
288
- self.branch.datalogger.persist_path = "data/logs/"
289
-
290
- self.branch.to_csv_file("messages.csv", verbose=False, clear=False)
291
-
292
- mock_create_path.assert_called_once_with(
293
- self.branch.datalogger.persist_path,
294
- "messages.csv",
295
- timestamp=True,
296
- dir_exist_ok=True,
297
- time_prefix=False,
298
- )
299
- mock_to_csv.assert_called_once_with("path/to/messages.csv")
300
-
301
- # Verify that messages are not cleared after exporting
302
- assert not self.branch.messages.empty
303
-
304
- @patch(
305
- "lionagi.libs.sys_util.SysUtil.create_path",
306
- return_value="path/to/messages.json",
307
- )
308
- @patch.object(pd.DataFrame, "to_json")
309
- def test_to_json(self, mock_to_json, mock_create_path):
310
- self.branch.datalogger = MagicMock()
311
- self.branch.datalogger.persist_path = "data/logs/"
312
-
313
- self.branch.to_json_file("messages.json", verbose=False, clear=False)
314
-
315
- mock_create_path.assert_called_once_with(
316
- self.branch.datalogger.persist_path,
317
- "messages.json",
318
- timestamp=True,
319
- dir_exist_ok=True,
320
- time_prefix=False,
321
- )
322
- mock_to_json.assert_called_once_with(
323
- "path/to/messages.json", orient="records", lines=True, date_format="iso"
324
- )
325
-
326
- # Verify that messages are not cleared after exporting
327
- assert not self.branch.messages.empty
328
-
329
- # def test_log_to_csv(self):
330
- # self.branch.log_to_csv('log.csv', verbose=False, clear=False)
331
-
332
- # self.branch.datalogger.to_csv_file.assert_called_once_with(filename='log.csv', dir_exist_ok=True, timestamp=True,
333
- # time_prefix=False, verbose=False, clear=False)
334
-
335
- # def test_log_to_json(self):
336
- # branch = BaseBranch()
337
- # branch.log_to_json('log.json', verbose=False, clear=False)
338
-
339
- # self.branch.datalogger.to_json_file.assert_called_once_with(filename='log.json', dir_exist_ok=True, timestamp=True,
340
- # time_prefix=False, verbose=False, clear=False)
341
-
342
- def test_remove_message(self):
343
- """Test removing a message from the branch based on its node ID."""
344
- initial_length = len(self.branch.messages)
345
- node_id_to_remove = "2"
346
- self.branch.remove_message(node_id_to_remove)
347
- final_length = len(self.branch.messages)
348
- self.assertNotIn(node_id_to_remove, self.branch.messages["node_id"].tolist())
349
- self.assertEqual(final_length, initial_length - 1)
350
-
351
- def test_update_message(self):
352
- """Test updating a specific message's column identified by node_id with a new value."""
353
- node_id_to_update = "2"
354
- new_value = "Updated content"
355
- self.branch.update_message(node_id_to_update, "content", new_value)
356
- updated_value = self.branch.messages.loc[
357
- self.branch.messages["node_id"] == node_id_to_update, "content"
358
- ].values[0]
359
- self.assertEqual(updated_value, new_value)
360
-
361
- # def test_change_first_system_message(self):
362
- # """Test updating the first system message with new content and/or sender."""
363
- # new_system_content = {"system_info": "Updated system message"}
364
- # self.branch.change_first_system_message(new_system_content['system_info'])
365
- # first_system_message_content = self.branch.messages.loc[self.branch.messages['role'] == 'system', 'content'].iloc[0]
366
- # self.assertIn(json.dumps({"system_info": "Updated system message"}), first_system_message_content)
367
-
368
- def test_rollback(self):
369
- """Test removing the last 'n' messages from the branch."""
370
- steps_to_remove = 2
371
- initial_length = len(self.branch.messages)
372
- self.branch.rollback(steps_to_remove)
373
- final_length = len(self.branch.messages)
374
- self.assertEqual(final_length, initial_length - steps_to_remove)
375
-
376
- def test_clear_messages(self):
377
- """Test clearing all messages from the branch."""
378
- self.branch.clear_messages()
379
- self.assertTrue(self.branch.messages.empty)
380
-
381
- def test_replace_keyword(self):
382
- """Test replacing occurrences of a specified keyword with a replacement string."""
383
- keyword = "Assistant response"
384
- replacement = "Helper feedback"
385
- self.branch.replace_keyword(keyword, replacement)
386
- self.assertTrue(
387
- any(replacement in message for message in self.branch.messages["content"])
388
- )
389
-
390
- def test_search_keywords(self):
391
- """Test filtering messages by a specified keyword or list of keywords."""
392
- keyword_to_search = "Assistant"
393
- filtered_messages = self.branch.search_keywords(keyword_to_search)
394
- self.assertTrue(
395
- all(
396
- keyword_to_search in content for content in filtered_messages["content"]
397
- )
398
- )
399
-
400
- def test_extend(self):
401
- """Test extending branch messages with additional messages."""
402
- additional_messages = pd.DataFrame(
403
- [
404
- {
405
- "node_id": "6",
406
- "timestamp": datetime(2021, 1, 1),
407
- "role": "user",
408
- "sender": "user2",
409
- "content": json.dumps({"instruction": "Another user message"}),
410
- }
411
- ]
412
- )
413
- initial_length = len(self.branch.messages)
414
- self.branch.extend(additional_messages)
415
- self.assertEqual(len(self.branch.messages), initial_length + 1)
416
-
417
- def test_filter_by(self):
418
- """Test filtering messages by various criteria."""
419
- # Filtering by role
420
- filtered_messages = self.branch.filter_by(role="user")
421
- self.assertTrue(
422
- all(msg["role"] == "user" for _, msg in filtered_messages.iterrows())
423
- )
424
-
425
-
426
- if __name__ == "__main__":
427
- unittest.main()
1
+ # from lionagi.core.branch.base_branch import BaseBranch
2
+ # from lionagi.core.branch.util import MessageUtil
3
+ # from lionagi.core.messages.schema import System
4
+ # from lionagi.core.schema import DataLogger
5
+
6
+
7
+ # import unittest
8
+ # from unittest.mock import patch, MagicMock
9
+ # import pandas as pd
10
+ # from datetime import datetime
11
+ # import json
12
+
13
+
14
+ # class TestBaseBranch(unittest.TestCase):
15
+
16
+ # def setUp(self):
17
+ # # Patching DataLogger to avoid filesystem interactions during tests
18
+ # self.patcher = patch(
19
+ # "lionagi.core.branch.base_branch.DataLogger", autospec=True
20
+ # )
21
+ # self.MockDataLogger = self.patcher.start()
22
+ # self.addCleanup(self.patcher.stop)
23
+
24
+ # self.branch = BaseBranch()
25
+ # # Example messages to populate the DataFrame for testing
26
+ # self.test_messages = [
27
+ # {
28
+ # "node_id": "1",
29
+ # "timestamp": "2021-01-01 00:00:00",
30
+ # "role": "system",
31
+ # "sender": "system",
32
+ # "content": json.dumps({"system_info": "System message"}),
33
+ # },
34
+ # {
35
+ # "node_id": "2",
36
+ # "timestamp": "2021-01-01 00:01:00",
37
+ # "role": "user",
38
+ # "sender": "user1",
39
+ # "content": json.dumps({"instruction": "User message"}),
40
+ # },
41
+ # {
42
+ # "node_id": "3",
43
+ # "timestamp": "2021-01-01 00:02:00",
44
+ # "role": "assistant",
45
+ # "sender": "assistant",
46
+ # "content": json.dumps({"response": "Assistant response"}),
47
+ # },
48
+ # {
49
+ # "node_id": "4",
50
+ # "timestamp": "2021-01-01 00:03:00",
51
+ # "role": "assistant",
52
+ # "sender": "action_request",
53
+ # "content": json.dumps({"action_request": "Action request"}),
54
+ # },
55
+ # {
56
+ # "node_id": "5",
57
+ # "timestamp": "2021-01-01 00:04:00",
58
+ # "role": "assistant",
59
+ # "sender": "action_response",
60
+ # "content": json.dumps({"action_response": "Action response"}),
61
+ # },
62
+ # ]
63
+ # self.branch = BaseBranch(messages=pd.DataFrame(self.test_messages))
64
+
65
+ # def test_init_with_empty_messages(self):
66
+ # """Test __init__ method with no messages provided."""
67
+ # branch = BaseBranch()
68
+ # self.assertTrue(branch.messages.empty)
69
+
70
+ # def test_init_with_given_messages(self):
71
+ # """Test __init__ method with provided messages DataFrame."""
72
+ # messages = pd.DataFrame(
73
+ # [
74
+ # [
75
+ # "0",
76
+ # datetime(2024, 1, 1),
77
+ # "system",
78
+ # "system",
79
+ # json.dumps({"system_info": "Hi"}),
80
+ # ]
81
+ # ],
82
+ # columns=["node_id", "timestamp", "role", "sender", "content"],
83
+ # )
84
+ # branch = BaseBranch(messages=messages)
85
+ # self.assertFalse(branch.messages.empty)
86
+
87
+ # def test_add_message(self):
88
+ # """Test adding a message."""
89
+ # message_info = {"info": "Test Message"}
90
+ # with patch.object(
91
+ # MessageUtil, "create_message", return_value=System(system=message_info)
92
+ # ) as mocked_create_message:
93
+ # self.branch.add_message(system=message_info)
94
+ # mocked_create_message.assert_called_once()
95
+ # self.assertEqual(len(self.branch.messages), 6)
96
+
97
+ # def test_chat_messages_without_sender(self):
98
+ # """Test chat_messages property without including sender information."""
99
+ # chat_messages = self.branch.chat_messages
100
+ # self.assertEqual(len(chat_messages), 5)
101
+ # self.assertNotIn("Sender", chat_messages[0]["content"])
102
+
103
+ # def test_chat_messages_with_sender(self):
104
+ # """Test retrieving chat messages with sender information included."""
105
+ # chat_messages_with_sender = self.branch.chat_messages_with_sender
106
+ # expected_prefixes = [
107
+ # "Sender system: ",
108
+ # "Sender user1: ",
109
+ # "Sender assistant: ",
110
+ # "Sender action_request: ",
111
+ # "Sender action_response: ",
112
+ # ]
113
+
114
+ # # Verify the number of messages returned matches the number added
115
+ # self.assertEqual(len(chat_messages_with_sender), len(self.test_messages))
116
+
117
+ # # Check each message for correct sender prefix in content
118
+ # for i, message_dict in enumerate(chat_messages_with_sender):
119
+ # self.assertTrue(
120
+ # message_dict["content"].startswith(expected_prefixes[i]),
121
+ # msg=f"Message content does not start with expected sender prefix. Found: {message_dict['content']}",
122
+ # )
123
+
124
+ # # Optionally, verify the content matches after removing the prefix
125
+ # for i, message_dict in enumerate(chat_messages_with_sender):
126
+ # prefix, content = message_dict["content"].split(": ", 1)
127
+ # self.assertEqual(
128
+ # content,
129
+ # self.test_messages[i]["content"],
130
+ # msg=f"Message content does not match after removing sender prefix. Found: {content}, Expected: {self.test_messages[i]['content']}",
131
+ # )
132
+
133
+ # def test_last_message(self):
134
+ # """Test retrieving the last message."""
135
+ # last_message = self.branch.last_message
136
+ # self.assertEqual(
137
+ # json.loads(last_message.iloc[0]["content"]),
138
+ # {"action_response": "Action response"},
139
+ # )
140
+
141
+ # def test_last_message_content(self):
142
+ # """Test retrieving the content of the last message."""
143
+ # content = self.branch.last_message_content
144
+ # self.assertEqual(content, {"action_response": "Action response"})
145
+
146
+ # def test_first_system_message(self):
147
+ # """Test retrieving the first 'system' message."""
148
+ # first_system_message = self.branch.first_system
149
+ # self.assertEqual(
150
+ # json.loads(first_system_message.iloc[0]["content"]),
151
+ # {"system_info": "System message"},
152
+ # )
153
+
154
+ # def test_last_response(self):
155
+ # """Test retrieving the last 'assistant' message."""
156
+ # last_response = self.branch.last_response
157
+ # self.assertEqual(last_response.iloc[0]["sender"], "action_response")
158
+
159
+ # def test_last_response_content(self):
160
+ # """Test extracting content of the last 'assistant' message."""
161
+ # content = self.branch.last_response_content
162
+ # self.assertEqual(content, {"action_response": "Action response"})
163
+
164
+ # def test_action_request_messages(self):
165
+ # """Test filtering messages sent by 'action_request'."""
166
+ # action_requests = self.branch.action_request
167
+ # self.assertTrue(all(action_requests["sender"] == "action_request"))
168
+
169
+ # def test_action_response_messages(self):
170
+ # """Test filtering messages sent by 'action_response'."""
171
+ # action_responses = self.branch.action_response
172
+ # self.assertTrue(all(action_responses["sender"] == "action_response"))
173
+
174
+ # def test_responses(self):
175
+ # """Test filtering of 'assistant' role messages."""
176
+ # responses = self.branch.responses
177
+ # # Verify that all returned messages have the 'assistant' role
178
+ # self.assertTrue(all(responses["role"] == "assistant"))
179
+ # # Optionally, check the count matches the expected number of 'assistant' messages
180
+ # expected_count = sum(
181
+ # 1 for msg in self.test_messages if msg["role"] == "assistant"
182
+ # )
183
+ # self.assertEqual(len(responses), expected_count)
184
+
185
+ # def test_assistant_responses(self):
186
+ # """Test filtering of 'assistant' messages excluding action requests/responses."""
187
+ # assistant_responses = self.branch.assistant_responses
188
+ # # Verify that no returned messages are from 'action_request' or 'action_response' senders
189
+ # self.assertTrue(all(assistant_responses["sender"] != "action_request"))
190
+ # self.assertTrue(all(assistant_responses["sender"] != "action_response"))
191
+ # # Verify that all returned messages have the 'assistant' role
192
+ # self.assertTrue(all(assistant_responses["role"] == "assistant"))
193
+
194
+ # def test_info(self):
195
+ # """Test summarization of message counts by role."""
196
+ # info = self.branch.info
197
+ # # Verify that the dictionary contains keys for each role and 'total'
198
+ # self.assertIn("assistant", info)
199
+ # self.assertIn("user", info)
200
+ # self.assertIn("system", info)
201
+ # self.assertIn("total", info)
202
+ # # Optionally, verify the counts match expected values
203
+ # for role in ["assistant", "user", "system"]:
204
+ # expected_count = sum(1 for msg in self.test_messages if msg["role"] == role)
205
+ # self.assertEqual(info[role], expected_count)
206
+ # self.assertEqual(info["total"], len(self.test_messages))
207
+
208
+ # def test_sender_info(self):
209
+ # """Test summarization of message counts by sender."""
210
+ # sender_info = self.branch.sender_info
211
+ # # Verify that the dictionary contains keys for each sender and the counts match
212
+ # for sender in set(msg["sender"] for msg in self.test_messages):
213
+ # expected_count = sum(
214
+ # 1 for msg in self.test_messages if msg["sender"] == sender
215
+ # )
216
+ # self.assertEqual(sender_info.get(sender, 0), expected_count)
217
+
218
+ # def test_describe(self):
219
+ # """Test detailed description of the branch."""
220
+ # description = self.branch.describe
221
+ # # Verify that the description contains expected keys
222
+ # self.assertIn("total_messages", description)
223
+ # self.assertIn("summary_by_role", description)
224
+ # self.assertIn("messages", description)
225
+ # # Optionally, verify the accuracy of the values
226
+ # self.assertEqual(description["total_messages"], len(self.test_messages))
227
+ # self.assertEqual(len(description["messages"]), min(5, len(self.test_messages)))
228
+
229
+ # @patch("lionagi.libs.ln_dataframe.read_csv")
230
+ # def test_from_csv(cls, mock_read_csv):
231
+ # # Define a mock return value for read_csv
232
+ # mock_messages_df = pd.DataFrame(
233
+ # {
234
+ # "node_id": ["1", "2"],
235
+ # "timestamp": [datetime(2021, 1, 1), datetime(2021, 1, 1)],
236
+ # "role": ["system", "user"],
237
+ # "sender": ["system", "user1"],
238
+ # "content": [
239
+ # json.dumps({"system_info": "System message"}),
240
+ # json.dumps({"instruction": "User message"}),
241
+ # ],
242
+ # }
243
+ # )
244
+ # mock_read_csv.return_value = mock_messages_df
245
+
246
+ # # Call the from_csv class method
247
+ # branch = BaseBranch.from_csv(filename="dummy.csv")
248
+
249
+ # # Verify that read_csv was called correctly
250
+ # mock_read_csv.assert_called_once_with("dummy.csv")
251
+
252
+ # # Verify that the branch instance contains the correct messages
253
+ # pd.testing.assert_frame_equal(branch.messages, mock_messages_df)
254
+
255
+ # @patch("lionagi.libs.ln_dataframe.read_json")
256
+ # def test_from_json(cls, mock_read_csv):
257
+ # # Define a mock return value for read_csv
258
+ # mock_messages_df = pd.DataFrame(
259
+ # {
260
+ # "node_id": ["1", "2"],
261
+ # "timestamp": [datetime(2021, 1, 1), datetime(2021, 1, 1)],
262
+ # "role": ["system", "user"],
263
+ # "sender": ["system", "user1"],
264
+ # "content": [
265
+ # json.dumps({"system_info": "System message"}),
266
+ # json.dumps({"instruction": "User message"}),
267
+ # ],
268
+ # }
269
+ # )
270
+ # mock_read_csv.return_value = mock_messages_df
271
+
272
+ # # Call the from_csv class method
273
+ # branch = BaseBranch.from_json_string(filename="dummy.json")
274
+
275
+ # # Verify that read_csv was called correctly
276
+ # mock_read_csv.assert_called_once_with("dummy.json")
277
+
278
+ # # Verify that the branch instance contains the correct messages
279
+ # pd.testing.assert_frame_equal(branch.messages, mock_messages_df)
280
+
281
+ # @patch(
282
+ # "lionagi.libs.sys_util.SysUtil.create_path",
283
+ # return_value="path/to/messages.csv",
284
+ # )
285
+ # @patch.object(pd.DataFrame, "to_csv")
286
+ # def test_to_csv(self, mock_to_csv, mock_create_path):
287
+ # self.branch.datalogger = MagicMock()
288
+ # self.branch.datalogger.persist_path = "data/logs/"
289
+
290
+ # self.branch.to_csv_file("messages.csv", verbose=False, clear=False)
291
+
292
+ # mock_create_path.assert_called_once_with(
293
+ # self.branch.datalogger.persist_path,
294
+ # "messages.csv",
295
+ # timestamp=True,
296
+ # dir_exist_ok=True,
297
+ # time_prefix=False,
298
+ # )
299
+ # mock_to_csv.assert_called_once_with("path/to/messages.csv")
300
+
301
+ # # Verify that messages are not cleared after exporting
302
+ # assert not self.branch.messages.empty
303
+
304
+ # @patch(
305
+ # "lionagi.libs.sys_util.SysUtil.create_path",
306
+ # return_value="path/to/messages.json",
307
+ # )
308
+ # @patch.object(pd.DataFrame, "to_json")
309
+ # def test_to_json(self, mock_to_json, mock_create_path):
310
+ # self.branch.datalogger = MagicMock()
311
+ # self.branch.datalogger.persist_path = "data/logs/"
312
+
313
+ # self.branch.to_json_file("messages.json", verbose=False, clear=False)
314
+
315
+ # mock_create_path.assert_called_once_with(
316
+ # self.branch.datalogger.persist_path,
317
+ # "messages.json",
318
+ # timestamp=True,
319
+ # dir_exist_ok=True,
320
+ # time_prefix=False,
321
+ # )
322
+ # mock_to_json.assert_called_once_with(
323
+ # "path/to/messages.json", orient="records", lines=True, date_format="iso"
324
+ # )
325
+
326
+ # # Verify that messages are not cleared after exporting
327
+ # assert not self.branch.messages.empty
328
+
329
+ # # def test_log_to_csv(self):
330
+ # # self.branch.log_to_csv('log.csv', verbose=False, clear=False)
331
+
332
+ # # self.branch.datalogger.to_csv_file.assert_called_once_with(filename='log.csv', dir_exist_ok=True, timestamp=True,
333
+ # # time_prefix=False, verbose=False, clear=False)
334
+
335
+ # # def test_log_to_json(self):
336
+ # # branch = BaseBranch()
337
+ # # branch.log_to_json('log.json', verbose=False, clear=False)
338
+
339
+ # # self.branch.datalogger.to_json_file.assert_called_once_with(filename='log.json', dir_exist_ok=True, timestamp=True,
340
+ # # time_prefix=False, verbose=False, clear=False)
341
+
342
+ # def test_remove_message(self):
343
+ # """Test removing a message from the branch based on its node ID."""
344
+ # initial_length = len(self.branch.messages)
345
+ # node_id_to_remove = "2"
346
+ # self.branch.remove_message(node_id_to_remove)
347
+ # final_length = len(self.branch.messages)
348
+ # self.assertNotIn(node_id_to_remove, self.branch.messages["node_id"].tolist())
349
+ # self.assertEqual(final_length, initial_length - 1)
350
+
351
+ # def test_update_message(self):
352
+ # """Test updating a specific message's column identified by node_id with a new value."""
353
+ # node_id_to_update = "2"
354
+ # new_value = "Updated content"
355
+ # self.branch.update_message(node_id_to_update, "content", new_value)
356
+ # updated_value = self.branch.messages.loc[
357
+ # self.branch.messages["node_id"] == node_id_to_update, "content"
358
+ # ].values[0]
359
+ # self.assertEqual(updated_value, new_value)
360
+
361
+ # # def test_change_first_system_message(self):
362
+ # # """Test updating the first system message with new content and/or sender."""
363
+ # # new_system_content = {"system_info": "Updated system message"}
364
+ # # self.branch.change_first_system_message(new_system_content['system_info'])
365
+ # # first_system_message_content = self.branch.messages.loc[self.branch.messages['role'] == 'system', 'content'].iloc[0]
366
+ # # self.assertIn(json.dumps({"system_info": "Updated system message"}), first_system_message_content)
367
+
368
+ # def test_rollback(self):
369
+ # """Test removing the last 'n' messages from the branch."""
370
+ # steps_to_remove = 2
371
+ # initial_length = len(self.branch.messages)
372
+ # self.branch.rollback(steps_to_remove)
373
+ # final_length = len(self.branch.messages)
374
+ # self.assertEqual(final_length, initial_length - steps_to_remove)
375
+
376
+ # def test_clear_messages(self):
377
+ # """Test clearing all messages from the branch."""
378
+ # self.branch.clear_messages()
379
+ # self.assertTrue(self.branch.messages.empty)
380
+
381
+ # def test_replace_keyword(self):
382
+ # """Test replacing occurrences of a specified keyword with a replacement string."""
383
+ # keyword = "Assistant response"
384
+ # replacement = "Helper feedback"
385
+ # self.branch.replace_keyword(keyword, replacement)
386
+ # self.assertTrue(
387
+ # any(replacement in message for message in self.branch.messages["content"])
388
+ # )
389
+
390
+ # def test_search_keywords(self):
391
+ # """Test filtering messages by a specified keyword or list of keywords."""
392
+ # keyword_to_search = "Assistant"
393
+ # filtered_messages = self.branch.search_keywords(keyword_to_search)
394
+ # self.assertTrue(
395
+ # all(
396
+ # keyword_to_search in content for content in filtered_messages["content"]
397
+ # )
398
+ # )
399
+
400
+ # def test_extend(self):
401
+ # """Test extending branch messages with additional messages."""
402
+ # additional_messages = pd.DataFrame(
403
+ # [
404
+ # {
405
+ # "node_id": "6",
406
+ # "timestamp": datetime(2021, 1, 1),
407
+ # "role": "user",
408
+ # "sender": "user2",
409
+ # "content": json.dumps({"instruction": "Another user message"}),
410
+ # }
411
+ # ]
412
+ # )
413
+ # initial_length = len(self.branch.messages)
414
+ # self.branch.extend(additional_messages)
415
+ # self.assertEqual(len(self.branch.messages), initial_length + 1)
416
+
417
+ # def test_filter_by(self):
418
+ # """Test filtering messages by various criteria."""
419
+ # # Filtering by role
420
+ # filtered_messages = self.branch.filter_by(role="user")
421
+ # self.assertTrue(
422
+ # all(msg["role"] == "user" for _, msg in filtered_messages.iterrows())
423
+ # )
424
+
425
+
426
+ # if __name__ == "__main__":
427
+ # unittest.main()