lionagi 0.0.115__py3-none-any.whl → 0.0.204__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (123) hide show
  1. lionagi/__init__.py +1 -2
  2. lionagi/_services/__init__.py +5 -0
  3. lionagi/_services/anthropic.py +79 -0
  4. lionagi/_services/base_service.py +414 -0
  5. lionagi/_services/oai.py +98 -0
  6. lionagi/_services/openrouter.py +44 -0
  7. lionagi/_services/services.py +91 -0
  8. lionagi/_services/transformers.py +46 -0
  9. lionagi/bridge/langchain.py +26 -16
  10. lionagi/bridge/llama_index.py +50 -20
  11. lionagi/configs/oai_configs.py +2 -14
  12. lionagi/configs/openrouter_configs.py +2 -2
  13. lionagi/core/__init__.py +7 -8
  14. lionagi/core/branch/branch.py +589 -0
  15. lionagi/core/branch/branch_manager.py +139 -0
  16. lionagi/core/branch/conversation.py +484 -0
  17. lionagi/core/core_util.py +59 -0
  18. lionagi/core/flow/flow.py +19 -0
  19. lionagi/core/flow/flow_util.py +62 -0
  20. lionagi/core/instruction_set/__init__.py +0 -5
  21. lionagi/core/instruction_set/instruction_set.py +343 -0
  22. lionagi/core/messages/messages.py +176 -0
  23. lionagi/core/sessions/__init__.py +0 -5
  24. lionagi/core/sessions/session.py +428 -0
  25. lionagi/loaders/chunker.py +51 -47
  26. lionagi/loaders/load_util.py +2 -2
  27. lionagi/loaders/reader.py +45 -39
  28. lionagi/models/imodel.py +53 -0
  29. lionagi/schema/async_queue.py +158 -0
  30. lionagi/schema/base_node.py +318 -147
  31. lionagi/schema/base_tool.py +31 -1
  32. lionagi/schema/data_logger.py +74 -38
  33. lionagi/schema/data_node.py +57 -6
  34. lionagi/structures/graph.py +132 -10
  35. lionagi/structures/relationship.py +58 -20
  36. lionagi/structures/structure.py +36 -25
  37. lionagi/tests/test_utils/test_api_util.py +219 -0
  38. lionagi/tests/test_utils/test_call_util.py +785 -0
  39. lionagi/tests/test_utils/test_encrypt_util.py +323 -0
  40. lionagi/tests/test_utils/test_io_util.py +238 -0
  41. lionagi/tests/test_utils/test_nested_util.py +338 -0
  42. lionagi/tests/test_utils/test_sys_util.py +358 -0
  43. lionagi/tools/tool_manager.py +186 -0
  44. lionagi/tools/tool_util.py +266 -3
  45. lionagi/utils/__init__.py +21 -61
  46. lionagi/utils/api_util.py +359 -71
  47. lionagi/utils/call_util.py +839 -264
  48. lionagi/utils/encrypt_util.py +283 -16
  49. lionagi/utils/io_util.py +178 -93
  50. lionagi/utils/nested_util.py +672 -0
  51. lionagi/utils/pd_util.py +57 -0
  52. lionagi/utils/sys_util.py +284 -156
  53. lionagi/utils/url_util.py +55 -0
  54. lionagi/version.py +1 -1
  55. {lionagi-0.0.115.dist-info → lionagi-0.0.204.dist-info}/METADATA +21 -17
  56. lionagi-0.0.204.dist-info/RECORD +106 -0
  57. lionagi/core/conversations/__init__.py +0 -5
  58. lionagi/core/conversations/conversation.py +0 -107
  59. lionagi/core/flows/__init__.py +0 -8
  60. lionagi/core/flows/flow.py +0 -8
  61. lionagi/core/flows/flow_util.py +0 -62
  62. lionagi/core/instruction_set/instruction_sets.py +0 -7
  63. lionagi/core/sessions/sessions.py +0 -185
  64. lionagi/endpoints/__init__.py +0 -5
  65. lionagi/endpoints/audio.py +0 -17
  66. lionagi/endpoints/chatcompletion.py +0 -54
  67. lionagi/messages/__init__.py +0 -11
  68. lionagi/messages/instruction.py +0 -15
  69. lionagi/messages/message.py +0 -110
  70. lionagi/messages/response.py +0 -33
  71. lionagi/messages/system.py +0 -12
  72. lionagi/objs/__init__.py +0 -11
  73. lionagi/objs/abc_objs.py +0 -39
  74. lionagi/objs/async_queue.py +0 -135
  75. lionagi/objs/messenger.py +0 -85
  76. lionagi/objs/tool_manager.py +0 -253
  77. lionagi/services/__init__.py +0 -11
  78. lionagi/services/base_api_service.py +0 -230
  79. lionagi/services/oai.py +0 -34
  80. lionagi/services/openrouter.py +0 -31
  81. lionagi/tests/test_api_util.py +0 -46
  82. lionagi/tests/test_call_util.py +0 -115
  83. lionagi/tests/test_convert_util.py +0 -202
  84. lionagi/tests/test_encrypt_util.py +0 -33
  85. lionagi/tests/test_flat_util.py +0 -426
  86. lionagi/tests/test_sys_util.py +0 -0
  87. lionagi/utils/convert_util.py +0 -229
  88. lionagi/utils/flat_util.py +0 -599
  89. lionagi-0.0.115.dist-info/RECORD +0 -110
  90. /lionagi/{services → _services}/anyscale.py +0 -0
  91. /lionagi/{services → _services}/azure.py +0 -0
  92. /lionagi/{services → _services}/bedrock.py +0 -0
  93. /lionagi/{services → _services}/everlyai.py +0 -0
  94. /lionagi/{services → _services}/gemini.py +0 -0
  95. /lionagi/{services → _services}/gpt4all.py +0 -0
  96. /lionagi/{services → _services}/huggingface.py +0 -0
  97. /lionagi/{services → _services}/litellm.py +0 -0
  98. /lionagi/{services → _services}/localai.py +0 -0
  99. /lionagi/{services → _services}/mistralai.py +0 -0
  100. /lionagi/{services → _services}/ollama.py +0 -0
  101. /lionagi/{services → _services}/openllm.py +0 -0
  102. /lionagi/{services → _services}/perplexity.py +0 -0
  103. /lionagi/{services → _services}/predibase.py +0 -0
  104. /lionagi/{services → _services}/rungpt.py +0 -0
  105. /lionagi/{services → _services}/vllm.py +0 -0
  106. /lionagi/{services → _services}/xinference.py +0 -0
  107. /lionagi/{endpoints/assistants.py → agents/__init__.py} +0 -0
  108. /lionagi/{tools → agents}/planner.py +0 -0
  109. /lionagi/{tools → agents}/prompter.py +0 -0
  110. /lionagi/{tools → agents}/scorer.py +0 -0
  111. /lionagi/{tools → agents}/summarizer.py +0 -0
  112. /lionagi/{tools → agents}/validator.py +0 -0
  113. /lionagi/{endpoints/embeddings.py → core/branch/__init__.py} +0 -0
  114. /lionagi/{services/anthropic.py → core/branch/cluster.py} +0 -0
  115. /lionagi/{endpoints/finetune.py → core/flow/__init__.py} +0 -0
  116. /lionagi/{endpoints/image.py → core/messages/__init__.py} +0 -0
  117. /lionagi/{endpoints/moderation.py → models/__init__.py} +0 -0
  118. /lionagi/{endpoints/vision.py → models/base_model.py} +0 -0
  119. /lionagi/{objs → schema}/status_tracker.py +0 -0
  120. /lionagi/tests/{test_io_util.py → test_utils/__init__.py} +0 -0
  121. {lionagi-0.0.115.dist-info → lionagi-0.0.204.dist-info}/LICENSE +0 -0
  122. {lionagi-0.0.115.dist-info → lionagi-0.0.204.dist-info}/WHEEL +0 -0
  123. {lionagi-0.0.115.dist-info → lionagi-0.0.204.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,323 @@
1
+ import unittest
2
+ import os
3
+ import tempfile
4
+ import zipfile
5
+ import hashlib
6
+ from cryptography.fernet import InvalidToken
7
+ from lionagi.utils.encrypt_util import EncrytionUtil
8
+
9
+
10
class TestPasswordStrengthChecker(unittest.TestCase):
    """Exercise EncrytionUtil.password_strength_checker on weak and strong inputs."""

    def _assert_all(self, expected, candidates):
        # One sub-test per candidate so a failure names the offending password.
        for candidate in candidates:
            with self.subTest(password=candidate):
                verdict = EncrytionUtil.password_strength_checker(candidate)
                if expected:
                    self.assertTrue(verdict)
                else:
                    self.assertFalse(verdict)

    def test_short_passwords(self):
        self._assert_all(False, ("Short1", "A1"))

    def test_passwords_without_digits(self):
        self._assert_all(False, ("NoDigitsHere", "Onlyletters!"))

    def test_passwords_without_uppercase(self):
        self._assert_all(False, ("alllowercase1", "nouppercase1!"))

    def test_strong_passwords(self):
        self._assert_all(True, ("ValidPass1", "AnotherGood1"))
27
+
28
+
29
class TestGenerateEncryptionKey(unittest.TestCase):
    """Check the return type and shape of EncrytionUtil.generate_encryption_key."""

    # Length of a Fernet key: 32 bytes, url-safe base64 encoded.
    FERNET_KEY_LENGTH = 44

    def setUp(self):
        # A password that should satisfy the strength checker, plus a fixed salt.
        self.strong_password = "StrongPass1"
        self.salt = b'0123456789abcdef'

    def test_with_strong_password_and_provided_salt(self):
        derived = EncrytionUtil.generate_encryption_key(
            password=self.strong_password, salt=self.salt
        )
        self.assertIsInstance(derived, str)

    def test_with_strong_password_and_no_salt(self):
        derived = EncrytionUtil.generate_encryption_key(password=self.strong_password)
        self.assertIsInstance(derived, str)

    def test_with_weak_password(self):
        self.assertRaises(
            ValueError, EncrytionUtil.generate_encryption_key, password="weak"
        )

    def test_with_no_password(self):
        generated = EncrytionUtil.generate_encryption_key()
        self.assertIsInstance(generated, str)
        self.assertEqual(len(generated), self.FERNET_KEY_LENGTH)
52
+
53
+
54
class TestEncrypt(unittest.TestCase):
    """Tests for EncrytionUtil.encrypt."""

    def setUp(self):
        # Key derived from a password that satisfies the strength checker.
        self.valid_key = EncrytionUtil.generate_encryption_key("StrongPass1")
        # Too short to be a 44-character Fernet key; expected to be rejected.
        self.invalid_key = "invalidkey"
        self.test_data = "This is a test string."

    def test_valid_encryption(self):
        encrypted_data = EncrytionUtil.encrypt(self.test_data, self.valid_key)
        self.assertIsInstance(encrypted_data, str)
        self.assertNotEqual(encrypted_data, self.test_data)

    def test_with_invalid_key(self):
        # Same expectation as TestDecrypt.test_decryption_with_invalid_key.
        # Catching only ValueError keeps this test from passing on unrelated
        # failures (NameError, TypeError, ...) the old bare Exception accepted.
        with self.assertRaises(ValueError):
            EncrytionUtil.encrypt(self.test_data, self.invalid_key)

    def test_with_non_string_data(self):
        # NOTE(review): AttributeError presumably comes from calling a str
        # method (e.g. ``.encode()``) on an int -- confirm against encrypt().
        with self.assertRaises(AttributeError):
            EncrytionUtil.encrypt(12345, self.valid_key)

    def test_with_empty_string(self):
        # Even empty plaintext must produce a non-empty token.
        encrypted_data = EncrytionUtil.encrypt("", self.valid_key)
        self.assertIsInstance(encrypted_data, str)
        self.assertNotEqual(encrypted_data, "")
78
+
79
+
80
class TestDecrypt(unittest.TestCase):
    """Tests for EncrytionUtil.decrypt, including round-tripping encrypt output."""

    def setUp(self):
        self.valid_key = EncrytionUtil.generate_encryption_key("StrongPass1")
        self.invalid_key = "invalidkey"
        self.test_data = "This is a test string."
        # Ciphertext produced with the valid key; input for the decrypt tests.
        self.encrypted_data = EncrytionUtil.encrypt(self.test_data, self.valid_key)

    def test_valid_decryption(self):
        plaintext = EncrytionUtil.decrypt(self.encrypted_data, self.valid_key)
        self.assertIsInstance(plaintext, str)
        self.assertEqual(plaintext, self.test_data)

    def test_decryption_with_invalid_key(self):
        self.assertRaises(
            ValueError, EncrytionUtil.decrypt, self.encrypted_data, self.invalid_key
        )

    def test_with_non_string_data(self):
        self.assertRaises(AttributeError, EncrytionUtil.decrypt, 12345, self.valid_key)

    def test_decryption_of_non_encrypted_string(self):
        # Text never produced by encrypt() must be rejected as an invalid token.
        self.assertRaises(
            InvalidToken, EncrytionUtil.decrypt, "plain text", self.valid_key
        )
104
+
105
+
106
class TestEncryptFile(unittest.TestCase):
    """Tests for EncrytionUtil.encrypt_file."""

    def setUp(self):
        self.valid_key = EncrytionUtil.generate_encryption_key("StrongPass1")
        # Small plaintext file that survives closing its handle (delete=False).
        handle = tempfile.NamedTemporaryFile(delete=False)
        self.temp_file = handle
        handle.write(b"This is a test file.")
        handle.close()

    def tearDown(self):
        plain_path = self.temp_file.name
        encrypted_path = plain_path + '.enc'
        os.remove(plain_path)
        if os.path.exists(encrypted_path):
            os.remove(encrypted_path)

    def test_encrypting_valid_file(self):
        source = self.temp_file.name
        EncrytionUtil.encrypt_file(source, self.valid_key)
        # The encrypted output is expected beside the input with a '.enc' suffix.
        self.assertTrue(os.path.exists(source + '.enc'))

    def test_with_non_existent_file_path(self):
        self.assertRaises(
            FileNotFoundError,
            EncrytionUtil.encrypt_file,
            "non_existent_file.txt",
            self.valid_key,
        )
128
+
129
+
130
class TestDecryptFile(unittest.TestCase):
    """Tests for EncrytionUtil.decrypt_file."""

    PLAINTEXT = b"This is a test file."

    def setUp(self):
        self.valid_key = EncrytionUtil.generate_encryption_key("StrongPass1")
        # Create a plaintext temp file and produce its '.enc' sibling.
        self.temp_file = tempfile.NamedTemporaryFile(delete=False)
        self.temp_file.write(self.PLAINTEXT)
        self.temp_file.close()
        EncrytionUtil.encrypt_file(self.temp_file.name, self.valid_key)
        self.encrypted_path = self.temp_file.name + '.enc'
        # Output location of decrypt_file: the '.enc' path with its suffix
        # stripped, i.e. the same path as the original plaintext file.
        # NOTE(review): assumes decrypt_file strips '.enc' -- confirm.
        self.decrypted_path = self.encrypted_path[:-len('.enc')]

    def tearDown(self):
        # Tolerate paths a test already removed or that were never created;
        # the set de-duplicates temp_file.name == decrypted_path.
        for path in {self.temp_file.name, self.encrypted_path, self.decrypted_path}:
            if os.path.exists(path):
                os.remove(path)

    def test_decrypting_valid_encrypted_file(self):
        # The decrypted path coincides with the plaintext file setUp left on
        # disk, so without removing it first the existence check would pass
        # even if decrypt_file wrote nothing.
        os.remove(self.decrypted_path)
        EncrytionUtil.decrypt_file(self.encrypted_path, self.valid_key)
        self.assertTrue(os.path.exists(self.decrypted_path))
        with open(self.decrypted_path, 'rb') as handle:
            self.assertEqual(handle.read(), self.PLAINTEXT)

    def test_with_non_existent_encrypted_file_path(self):
        with self.assertRaises(FileNotFoundError):
            EncrytionUtil.decrypt_file("non_existent_file.txt.enc", self.valid_key)

    def test_decrypting_with_invalid_key(self):
        with self.assertRaises(ValueError):
            EncrytionUtil.decrypt_file(self.encrypted_path, 'invalidkey')
161
+
162
+
163
class TestIsEncrypted(unittest.TestCase):
    """Tests for EncrytionUtil.is_encrypted on encrypted vs. plain files."""

    @staticmethod
    def _write_temp(payload):
        # Closed temp file that persists until explicitly removed.
        handle = tempfile.NamedTemporaryFile(delete=False)
        handle.write(payload)
        handle.close()
        return handle

    def setUp(self):
        self.valid_key = EncrytionUtil.generate_encryption_key("StrongPass1")
        # After encrypt_file runs, the encrypted copy is at <temp_file>.enc.
        self.temp_file = self._write_temp(b"This is a test file.")
        EncrytionUtil.encrypt_file(self.temp_file.name, self.valid_key)
        # A file that is never encrypted.
        self.non_encrypted_file = self._write_temp(b"This is a non-encrypted test file.")

    def tearDown(self):
        os.remove(self.temp_file.name)
        encrypted_path = self.temp_file.name + '.enc'
        if os.path.exists(encrypted_path):
            os.remove(encrypted_path)
        os.remove(self.non_encrypted_file.name)

    def test_with_encrypted_file(self):
        encrypted_path = self.temp_file.name + '.enc'
        self.assertTrue(EncrytionUtil.is_encrypted(encrypted_path, self.valid_key))

    def test_with_non_encrypted_file(self):
        plain_path = self.non_encrypted_file.name
        self.assertFalse(EncrytionUtil.is_encrypted(plain_path, self.valid_key))
190
+
191
+
192
class TestDecompressFile(unittest.TestCase):
    """Tests for EncrytionUtil.decompress_file."""

    def setUp(self):
        # TemporaryDirectory removes itself and everything extracted into it.
        # Registering cleanup via addCleanup means it also runs if setUp
        # fails partway, unlike a tearDown-based os.walk removal.
        tmp = tempfile.TemporaryDirectory()
        self.addCleanup(tmp.cleanup)
        self.temp_dir = tmp.name

        # A zip archive with a single text member.
        self.temp_zip_file = os.path.join(self.temp_dir, 'test.zip')
        with zipfile.ZipFile(self.temp_zip_file, 'w') as zipf:
            zipf.writestr('test.txt', 'This is a test file.')

    def test_decompressing_valid_zip_file(self):
        EncrytionUtil.decompress_file(self.temp_zip_file, self.temp_dir)
        extracted = os.path.join(self.temp_dir, 'test.txt')
        self.assertTrue(os.path.exists(extracted))
        # Verify content too, not just that some file appeared.
        with open(extracted, encoding='utf-8') as handle:
            self.assertEqual(handle.read(), 'This is a test file.')

    def test_with_non_existent_zip_file_path(self):
        with self.assertRaises(FileNotFoundError):
            EncrytionUtil.decompress_file("non_existent_file.zip", self.temp_dir)

    def test_with_invalid_zip_file(self):
        invalid_zip_file = os.path.join(self.temp_dir, 'invalid.zip')
        with open(invalid_zip_file, 'w') as f:
            f.write("This is not a zip file.")
        with self.assertRaises(zipfile.BadZipFile):
            EncrytionUtil.decompress_file(invalid_zip_file, self.temp_dir)
226
+
227
+
228
class TestCompressFile(unittest.TestCase):
    """Tests for EncrytionUtil.compress_file."""

    def setUp(self):
        # Plaintext temp file kept on disk after closing (delete=False).
        self.temp_file = tempfile.NamedTemporaryFile(delete=False)
        self.temp_file.write(b"This is a test file.")
        self.temp_file.close()

    def _zip_path(self):
        # The archive is expected beside the input with a '.zip' suffix.
        return self.temp_file.name + '.zip'

    def tearDown(self):
        os.remove(self.temp_file.name)
        archive = self._zip_path()
        if os.path.exists(archive):
            os.remove(archive)

    def test_compressing_valid_file(self):
        EncrytionUtil.compress_file(self.temp_file.name)
        self.assertTrue(os.path.exists(self._zip_path()))

    def test_with_non_existent_file_path(self):
        self.assertRaises(
            FileNotFoundError, EncrytionUtil.compress_file, "non_existent_file.txt"
        )
249
+
250
+
251
class TestBinaryToHex(unittest.TestCase):
    """Tests for EncrytionUtil.binary_to_hex."""

    # (input bytes, expected lowercase hex string) pairs.
    CASES = (
        (b'\x00\x0F', '000f'),
        (b'hello', '68656c6c6f'),
        (b'\xff\xfe\xfd\xfc', 'fffefdfc'),
    )

    def test_with_valid_binary_data(self):
        for raw, expected_hex in self.CASES:
            self.assertEqual(EncrytionUtil.binary_to_hex(raw), expected_hex)

    def test_with_empty_bytes(self):
        self.assertEqual(EncrytionUtil.binary_to_hex(b''), '')
261
+
262
+
263
class TestCreateHash(unittest.TestCase):
    """Compare EncrytionUtil.create_hash against hashlib directly."""

    SAMPLE = "test"

    def test_hashing_with_default_algorithm(self):
        # With no algorithm argument the digest must match SHA-256.
        reference = hashlib.sha256(self.SAMPLE.encode()).hexdigest()
        self.assertEqual(EncrytionUtil.create_hash(self.SAMPLE), reference)

    def test_hashing_with_different_algorithms(self):
        for algo in ('sha1', 'sha224', 'sha384', 'sha512'):
            with self.subTest(algorithm=algo):
                reference = hashlib.new(algo, self.SAMPLE.encode()).hexdigest()
                self.assertEqual(EncrytionUtil.create_hash(self.SAMPLE, algo), reference)

    def test_with_unsupported_algorithm(self):
        self.assertRaises(
            ValueError, EncrytionUtil.create_hash, "test", "unsupported_algo"
        )
281
+
282
+
283
class TestDecodeBase64(unittest.TestCase):
    """Tests for EncrytionUtil.decode_base64."""

    def test_with_valid_base64_encoded_strings(self):
        # (encoded, expected decoded text) pairs.
        test_cases = [
            ("SGVsbG8sIFdvcmxkIQ==", "Hello, World!"),
            ("VGhpcyBpcyBhIHRlc3Q=", "This is a test"),
            ("c29tZSBieXRlcw==", "some bytes"),
        ]
        for encoded, original in test_cases:
            with self.subTest(encoded=encoded):
                self.assertEqual(EncrytionUtil.decode_base64(encoded), original)

    def test_with_invalid_base64_string(self):
        invalid_data = "This is not base64!"
        # Malformed base64 presumably surfaces as binascii.Error, and a
        # non-UTF-8 payload as UnicodeDecodeError; both subclass ValueError.
        # Narrowing from bare Exception stops the test passing on unrelated
        # errors such as TypeError or AttributeError.
        with self.assertRaises(ValueError):
            EncrytionUtil.decode_base64(invalid_data)

    def test_with_empty_string(self):
        self.assertEqual(EncrytionUtil.decode_base64(""), "")
303
+
304
+
305
class TestEncodeBase64(unittest.TestCase):
    """Tests for EncrytionUtil.encode_base64."""

    # (plain text, expected base64) pairs.
    CASES = (
        ("Hello, World!", "SGVsbG8sIFdvcmxkIQ=="),
        ("This is a test", "VGhpcyBpcyBhIHRlc3Q="),
        ("some bytes", "c29tZSBieXRlcw=="),
    )

    def test_with_valid_strings(self):
        for text, expected in self.CASES:
            with self.subTest(original=text):
                actual = EncrytionUtil.encode_base64(text)
                self.assertEqual(actual, expected)

    def test_with_empty_string(self):
        self.assertEqual(EncrytionUtil.encode_base64(""), "")
320
+
321
+
322
if __name__ == "__main__":
    # Discover and run every TestCase in this module when executed directly.
    unittest.main()
@@ -0,0 +1,238 @@
1
+ import unittest
2
+ import json
3
+ import os
4
+ from unittest.mock import mock_open, patch
5
+ from io import StringIO
6
+ from lionagi.utils.io_util import IOUtil
7
+
8
+
9
class NonClosingStringIO(StringIO):
    """In-memory text buffer whose ``close`` is a no-op.

    Code under test may close the handle, e.g. when leaving a
    ``with open(...)`` block; the test can still read what was written.
    """

    def close(self):
        """Ignore close requests so the buffer stays readable."""
        return None
13
+
14
+
15
class TestIOUtil(unittest.TestCase):
    """Tests for the file read/write helpers on IOUtil.

    ``builtins.open`` is mocked wherever the code under test would otherwise
    create files in the current working directory.
    """

    def setUp(self):
        self.valid_data = [{'name': 'Alice', 'age': 30}, {'name': 'Bob', 'age': 25}]
        # Kept under separate names. A single ``empty_data`` attribute used to
        # be assigned twice ([] then ""), so the "empty list" JSON test was
        # really writing an empty string.
        self.empty_list = []
        self.empty_csv_data = ""
        # read_csv returns every field as a string (see the varied-types test).
        self.expected_output = [{'name': 'Alice', 'age': '30'}, {'name': 'Bob', 'age': '25'}]
        self.expected_valid_output = [{'name': 'Alice', 'age': 30}, {'name': 'Bob', 'age': 25}]
        self.csv_data1 = "name,age\nAlice,30\nBob,25"
        self.csv_data2 = "name,score\nAlice,85\nBob,90"
        self.expected_csv = "name,age\nAlice,30\nBob,25\n"

    # ----------------------------------------------------------- read_csv

    @patch("builtins.open", new_callable=mock_open, read_data="name,age\nAlice,30\nBob,25")
    def test_read_csv_valid_file(self, mock_file):
        result = IOUtil.read_csv("dummy.csv")
        self.assertEqual(result, self.expected_output)

    @patch("builtins.open", side_effect=FileNotFoundError)
    def test_read_csv_nonexistent_file(self, mock_file):
        with self.assertRaises(FileNotFoundError):
            IOUtil.read_csv("nonexistent.csv")

    @patch("builtins.open", new_callable=mock_open, read_data="")
    def test_read_csv_empty_file(self, mock_file):
        result = IOUtil.read_csv("empty.csv")
        self.assertEqual(result, [])

    @patch("builtins.open", new_callable=mock_open, read_data="name,age\nAlice,30\nBob,25.5\nCharlie")
    def test_read_csv_inconsistent_columns(self, mock_file):
        result = IOUtil.read_csv("inconsistent.csv")
        self.assertEqual(len(result), 3)
        self.assertIn('Charlie', result[-1].values())

    @patch("builtins.open", new_callable=mock_open, read_data="name,age,score\nAlice,30,85.5\nBob,25,90")
    def test_read_csv_varied_data_types(self, mock_file):
        expected_output = [
            {'name': 'Alice', 'age': '30', 'score': '85.5'},
            {'name': 'Bob', 'age': '25', 'score': '90'},
        ]
        result = IOUtil.read_csv("varied_types.csv")
        self.assertEqual(result, expected_output)

    # ---------------------------------------------------------- read_jsonl

    @patch("builtins.open", new_callable=mock_open,
           read_data='{"name": "Alice", "age": 30}\n{"name": "Bob", "age": 25}\n')
    def test_read_jsonl_valid_file(self, mock_file):
        result = IOUtil.read_jsonl("dummy.jsonl")
        self.assertEqual(result, self.expected_valid_output)

    @patch("builtins.open", side_effect=FileNotFoundError)
    def test_read_jsonl_nonexistent_file(self, mock_file):
        with self.assertRaises(FileNotFoundError):
            IOUtil.read_jsonl("nonexistent.jsonl")

    @patch("builtins.open", new_callable=mock_open, read_data="")
    def test_read_jsonl_empty_file(self, mock_file):
        result = IOUtil.read_jsonl("empty.jsonl")
        self.assertEqual(result, [])

    @patch("builtins.open", new_callable=mock_open, read_data='{"name": "Alice", "age": 30}\n{"name": "Bob"}\n')
    def test_read_jsonl_mixed_data_types(self, mock_file):
        expected_output = [{'name': 'Alice', 'age': 30}, {'name': 'Bob'}]
        result = IOUtil.read_jsonl("mixed_types.jsonl")
        self.assertEqual(result, expected_output)

    @patch("builtins.open", new_callable=mock_open,
           read_data='{"name": "Alice", "age": 30}\nInvalid JSON\n{"name": "Bob", "age": 25}\n')
    def test_read_jsonl_invalid_json_format(self, mock_file):
        with self.assertRaises(json.JSONDecodeError):
            IOUtil.read_jsonl("invalid.jsonl")

    # ---------------------------------------------------------- write_json

    def test_write_json_valid_data(self):
        mock_file = NonClosingStringIO()
        with patch("builtins.open", return_value=mock_file):
            IOUtil.write_json(self.valid_data, "test.json")
        # close() is a no-op on NonClosingStringIO, so the buffer is still readable.
        written_data = mock_file.getvalue()
        expected_json = json.dumps(self.valid_data, indent=4)
        self.assertEqual(written_data, expected_json)

    def test_write_json_empty_list(self):
        with patch("builtins.open", mock_open()) as mocked_file:
            IOUtil.write_json(self.empty_list, "test.json")
            mocked_file.assert_called_once_with("test.json", 'w')
            mocked_file().write.assert_called_once_with(json.dumps(self.empty_list, indent=4))

    def test_write_json_non_serializable_data(self):
        non_serializable_data = [{'name': 'Alice', 'age': 30}, {'name': 'Bob', 'age': lambda x: x}]
        # Mock open so a partial write cannot leave a stray test.json in the CWD.
        with patch("builtins.open", mock_open()):
            with self.assertRaises(TypeError):
                IOUtil.write_json(non_serializable_data, "test.json")

    # ----------------------------------------------------------- read_json

    @patch("builtins.open", new_callable=mock_open, read_data='{"name": "Alice", "age": 30}')
    def test_read_json_valid_file(self, mock_file):
        expected_output = {"name": "Alice", "age": 30}
        result = IOUtil.read_json("valid.json")
        self.assertEqual(result, expected_output)

    @patch("builtins.open", side_effect=FileNotFoundError)
    def test_read_json_nonexistent_file(self, mock_file):
        with self.assertRaises(FileNotFoundError):
            IOUtil.read_json("nonexistent.json")

    @patch("builtins.open", new_callable=mock_open, read_data='{name: Alice, age: 30}')
    def test_read_json_invalid_format(self, mock_file):
        with self.assertRaises(json.JSONDecodeError):
            IOUtil.read_json("invalid.json")

    # ------------------------------------------------------ merge_csv_files

    def open_mock(self, file, mode='r', newline=None):
        """Stand-in for ``open`` returning canned CSV handles by file name.

        Unknown names (e.g. the merge output path) get a fresh writable mock.
        """
        mock_files = {
            'file1.csv': mock_open(read_data=self.csv_data1).return_value,
            'file2.csv': mock_open(read_data=self.csv_data2).return_value,
            'empty1.csv': mock_open(read_data=self.empty_csv_data).return_value,
            'empty2.csv': mock_open(read_data=self.empty_csv_data).return_value,
        }
        return mock_files.get(file, mock_open().return_value)

    @staticmethod
    def _opened_paths(mocked_open):
        """Return the first positional argument of every recorded open() call."""
        return [args[0] for args, _ in mocked_open.call_args_list if args]

    def test_merge_csv_files_valid(self):
        with patch("builtins.open", side_effect=self.open_mock) as mocked_open:
            IOUtil.merge_csv_files(['file1.csv', 'file2.csv'], 'merged.csv')
        opened = self._opened_paths(mocked_open)
        self.assertIn('file1.csv', opened)
        self.assertIn('file2.csv', opened)
        # TODO: also validate the merged rows written to 'merged.csv'.

    def test_merge_csv_files_nonexistent(self):
        with patch("builtins.open", side_effect=FileNotFoundError):
            with self.assertRaises(FileNotFoundError):
                IOUtil.merge_csv_files(['nonexistent.csv'], 'output.csv')

    def test_merge_csv_files_different_columns(self):
        with patch("builtins.open", side_effect=self.open_mock) as mocked_open:
            IOUtil.merge_csv_files(['file1.csv', 'file2.csv'], 'merged_diff_cols.csv')
        opened = self._opened_paths(mocked_open)
        self.assertIn('file1.csv', opened)
        self.assertIn('file2.csv', opened)
        # TODO: also validate the column union written to the output.

    def test_merge_csv_files_empty_files(self):
        with patch("builtins.open", side_effect=self.open_mock) as mocked_open:
            IOUtil.merge_csv_files(['empty1.csv', 'empty2.csv'], 'merged_empty.csv')
        opened = self._opened_paths(mocked_open)
        self.assertIn('empty1.csv', opened)
        self.assertIn('empty2.csv', opened)
        # TODO: also assert what (if anything) is written to the output.

    # -------------------------------------------------------------- to_csv

    def test_to_csv_valid_data(self):
        mock_file = mock_open()
        with patch("builtins.open", mock_file):
            IOUtil.to_csv(self.valid_data, "test.csv")

        # Concatenate every chunk passed to write().
        write_calls = mock_file().write.call_args_list
        written_data = ''.join(call_arg[0][0] for call_arg in write_calls)

        # The csv module writes '\r\n' line endings by default; normalize them.
        written_data_normalized = written_data.replace('\r\n', '\n')
        self.assertEqual(written_data_normalized, self.expected_csv)

    @patch("os.path.exists", return_value=False)
    def test_to_csv_nonexistent_dir_file_exist_ok_false(self, mock_exists):
        with self.assertRaises(FileNotFoundError):
            IOUtil.to_csv(self.valid_data, "nonexistent_dir/test.csv", file_exist_ok=False)

    @patch("os.makedirs")
    @patch("os.path.exists", return_value=False)
    def test_to_csv_nonexistent_dir_file_exist_ok_true(self, mock_exists, mock_makedirs):
        mock_file = mock_open()
        with patch("builtins.open", mock_file):
            IOUtil.to_csv(self.valid_data, "nonexistent_dir/test.csv", file_exist_ok=True)
        mock_makedirs.assert_called_once_with(os.path.dirname("nonexistent_dir/test.csv"), exist_ok=True)

    def test_to_csv_empty_list(self):
        mock_file = mock_open()
        with patch("builtins.open", mock_file):
            IOUtil.to_csv([], "empty.csv")

        # open must never be called: to_csv should return early for an empty list.
        mock_file.assert_not_called()

    # ------------------------------------------------------ append_to_jsonl

    def test_append_to_jsonl_valid_data(self):
        mock_file = mock_open()
        data = {'name': 'Alice', 'age': 30}
        with patch("builtins.open", mock_file, create=True):
            IOUtil.append_to_jsonl(data, "test.jsonl")
        mock_file().write.assert_called_once_with(json.dumps(data) + "\n")

    def test_append_to_jsonl_new_file(self):
        mock_file = mock_open()
        data = {'name': 'Bob', 'age': 25}
        with patch("builtins.open", mock_file, create=True):
            IOUtil.append_to_jsonl(data, "new_file.jsonl")
        mock_file.assert_called_once_with("new_file.jsonl", "a")
        mock_file().write.assert_called_once_with(json.dumps(data) + "\n")

    def test_append_to_jsonl_non_serializable_data(self):
        class NonSerializable:
            pass

        non_serializable_data = NonSerializable()
        # Mock open so a failed append cannot leave a stray test.jsonl in the CWD.
        with patch("builtins.open", mock_open()):
            with self.assertRaises(TypeError):
                IOUtil.append_to_jsonl(non_serializable_data, "test.jsonl")

    # -------------------------------------------------------------- to_temp

    def test_to_temp_with_valid_string(self):
        test_input = "Test String"
        temp_file = IOUtil.to_temp(test_input)
        # Register removal immediately so the file is deleted even if an assertion fails.
        self.addCleanup(os.remove, temp_file.name)
        with open(temp_file.name, 'r') as file:
            content = json.load(file)
        self.assertEqual(content, [test_input])

    def test_to_temp_with_valid_iterable(self):
        test_input = ["Test", "String"]
        temp_file = IOUtil.to_temp(test_input)
        self.addCleanup(os.remove, temp_file.name)
        with open(temp_file.name, 'r') as file:
            content = json.load(file)
        self.assertEqual(content, test_input)

    def test_to_temp_with_non_serializable_data(self):
        class NonSerializable:
            pass

        non_serializable_data = NonSerializable()
        # NOTE(review): if to_temp creates its temp file before serializing,
        # this may leave a file behind in the temp dir -- confirm.
        with self.assertRaises(TypeError):
            IOUtil.to_temp(non_serializable_data)
235
+
236
+
237
if __name__ == "__main__":
    # Run the IOUtil test suite when this module is executed as a script.
    unittest.main()