lionagi 0.0.115__py3-none-any.whl → 0.0.204__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- lionagi/__init__.py +1 -2
- lionagi/_services/__init__.py +5 -0
- lionagi/_services/anthropic.py +79 -0
- lionagi/_services/base_service.py +414 -0
- lionagi/_services/oai.py +98 -0
- lionagi/_services/openrouter.py +44 -0
- lionagi/_services/services.py +91 -0
- lionagi/_services/transformers.py +46 -0
- lionagi/bridge/langchain.py +26 -16
- lionagi/bridge/llama_index.py +50 -20
- lionagi/configs/oai_configs.py +2 -14
- lionagi/configs/openrouter_configs.py +2 -2
- lionagi/core/__init__.py +7 -8
- lionagi/core/branch/branch.py +589 -0
- lionagi/core/branch/branch_manager.py +139 -0
- lionagi/core/branch/conversation.py +484 -0
- lionagi/core/core_util.py +59 -0
- lionagi/core/flow/flow.py +19 -0
- lionagi/core/flow/flow_util.py +62 -0
- lionagi/core/instruction_set/__init__.py +0 -5
- lionagi/core/instruction_set/instruction_set.py +343 -0
- lionagi/core/messages/messages.py +176 -0
- lionagi/core/sessions/__init__.py +0 -5
- lionagi/core/sessions/session.py +428 -0
- lionagi/loaders/chunker.py +51 -47
- lionagi/loaders/load_util.py +2 -2
- lionagi/loaders/reader.py +45 -39
- lionagi/models/imodel.py +53 -0
- lionagi/schema/async_queue.py +158 -0
- lionagi/schema/base_node.py +318 -147
- lionagi/schema/base_tool.py +31 -1
- lionagi/schema/data_logger.py +74 -38
- lionagi/schema/data_node.py +57 -6
- lionagi/structures/graph.py +132 -10
- lionagi/structures/relationship.py +58 -20
- lionagi/structures/structure.py +36 -25
- lionagi/tests/test_utils/test_api_util.py +219 -0
- lionagi/tests/test_utils/test_call_util.py +785 -0
- lionagi/tests/test_utils/test_encrypt_util.py +323 -0
- lionagi/tests/test_utils/test_io_util.py +238 -0
- lionagi/tests/test_utils/test_nested_util.py +338 -0
- lionagi/tests/test_utils/test_sys_util.py +358 -0
- lionagi/tools/tool_manager.py +186 -0
- lionagi/tools/tool_util.py +266 -3
- lionagi/utils/__init__.py +21 -61
- lionagi/utils/api_util.py +359 -71
- lionagi/utils/call_util.py +839 -264
- lionagi/utils/encrypt_util.py +283 -16
- lionagi/utils/io_util.py +178 -93
- lionagi/utils/nested_util.py +672 -0
- lionagi/utils/pd_util.py +57 -0
- lionagi/utils/sys_util.py +284 -156
- lionagi/utils/url_util.py +55 -0
- lionagi/version.py +1 -1
- {lionagi-0.0.115.dist-info → lionagi-0.0.204.dist-info}/METADATA +21 -17
- lionagi-0.0.204.dist-info/RECORD +106 -0
- lionagi/core/conversations/__init__.py +0 -5
- lionagi/core/conversations/conversation.py +0 -107
- lionagi/core/flows/__init__.py +0 -8
- lionagi/core/flows/flow.py +0 -8
- lionagi/core/flows/flow_util.py +0 -62
- lionagi/core/instruction_set/instruction_sets.py +0 -7
- lionagi/core/sessions/sessions.py +0 -185
- lionagi/endpoints/__init__.py +0 -5
- lionagi/endpoints/audio.py +0 -17
- lionagi/endpoints/chatcompletion.py +0 -54
- lionagi/messages/__init__.py +0 -11
- lionagi/messages/instruction.py +0 -15
- lionagi/messages/message.py +0 -110
- lionagi/messages/response.py +0 -33
- lionagi/messages/system.py +0 -12
- lionagi/objs/__init__.py +0 -11
- lionagi/objs/abc_objs.py +0 -39
- lionagi/objs/async_queue.py +0 -135
- lionagi/objs/messenger.py +0 -85
- lionagi/objs/tool_manager.py +0 -253
- lionagi/services/__init__.py +0 -11
- lionagi/services/base_api_service.py +0 -230
- lionagi/services/oai.py +0 -34
- lionagi/services/openrouter.py +0 -31
- lionagi/tests/test_api_util.py +0 -46
- lionagi/tests/test_call_util.py +0 -115
- lionagi/tests/test_convert_util.py +0 -202
- lionagi/tests/test_encrypt_util.py +0 -33
- lionagi/tests/test_flat_util.py +0 -426
- lionagi/tests/test_sys_util.py +0 -0
- lionagi/utils/convert_util.py +0 -229
- lionagi/utils/flat_util.py +0 -599
- lionagi-0.0.115.dist-info/RECORD +0 -110
- /lionagi/{services → _services}/anyscale.py +0 -0
- /lionagi/{services → _services}/azure.py +0 -0
- /lionagi/{services → _services}/bedrock.py +0 -0
- /lionagi/{services → _services}/everlyai.py +0 -0
- /lionagi/{services → _services}/gemini.py +0 -0
- /lionagi/{services → _services}/gpt4all.py +0 -0
- /lionagi/{services → _services}/huggingface.py +0 -0
- /lionagi/{services → _services}/litellm.py +0 -0
- /lionagi/{services → _services}/localai.py +0 -0
- /lionagi/{services → _services}/mistralai.py +0 -0
- /lionagi/{services → _services}/ollama.py +0 -0
- /lionagi/{services → _services}/openllm.py +0 -0
- /lionagi/{services → _services}/perplexity.py +0 -0
- /lionagi/{services → _services}/predibase.py +0 -0
- /lionagi/{services → _services}/rungpt.py +0 -0
- /lionagi/{services → _services}/vllm.py +0 -0
- /lionagi/{services → _services}/xinference.py +0 -0
- /lionagi/{endpoints/assistants.py → agents/__init__.py} +0 -0
- /lionagi/{tools → agents}/planner.py +0 -0
- /lionagi/{tools → agents}/prompter.py +0 -0
- /lionagi/{tools → agents}/scorer.py +0 -0
- /lionagi/{tools → agents}/summarizer.py +0 -0
- /lionagi/{tools → agents}/validator.py +0 -0
- /lionagi/{endpoints/embeddings.py → core/branch/__init__.py} +0 -0
- /lionagi/{services/anthropic.py → core/branch/cluster.py} +0 -0
- /lionagi/{endpoints/finetune.py → core/flow/__init__.py} +0 -0
- /lionagi/{endpoints/image.py → core/messages/__init__.py} +0 -0
- /lionagi/{endpoints/moderation.py → models/__init__.py} +0 -0
- /lionagi/{endpoints/vision.py → models/base_model.py} +0 -0
- /lionagi/{objs → schema}/status_tracker.py +0 -0
- /lionagi/tests/{test_io_util.py → test_utils/__init__.py} +0 -0
- {lionagi-0.0.115.dist-info → lionagi-0.0.204.dist-info}/LICENSE +0 -0
- {lionagi-0.0.115.dist-info → lionagi-0.0.204.dist-info}/WHEEL +0 -0
- {lionagi-0.0.115.dist-info → lionagi-0.0.204.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,323 @@
|
|
1
|
+
import unittest
|
2
|
+
import os
|
3
|
+
import tempfile
|
4
|
+
import zipfile
|
5
|
+
import hashlib
|
6
|
+
from cryptography.fernet import InvalidToken
|
7
|
+
from lionagi.utils.encrypt_util import EncrytionUtil
|
8
|
+
|
9
|
+
|
10
|
+
class TestPasswordStrengthChecker(unittest.TestCase):
    """Exercise ``EncrytionUtil.password_strength_checker`` on weak and strong inputs."""

    def _assert_all_rejected(self, candidates):
        # Every candidate in ``candidates`` must be reported as too weak.
        for candidate in candidates:
            self.assertFalse(EncrytionUtil.password_strength_checker(candidate))

    def test_short_passwords(self):
        self._assert_all_rejected(("Short1", "A1"))

    def test_passwords_without_digits(self):
        self._assert_all_rejected(("NoDigitsHere", "Onlyletters!"))

    def test_passwords_without_uppercase(self):
        self._assert_all_rejected(("alllowercase1", "nouppercase1!"))

    def test_strong_passwords(self):
        for candidate in ("ValidPass1", "AnotherGood1"):
            self.assertTrue(EncrytionUtil.password_strength_checker(candidate))
|
27
|
+
|
28
|
+
|
29
|
+
class TestGenerateEncryptionKey(unittest.TestCase):
    """Tests for ``EncrytionUtil.generate_encryption_key``."""

    def setUp(self):
        # A password expected to pass the strength check, plus a fixed 16-byte salt.
        self.strong_password = "StrongPass1"
        self.salt = b'0123456789abcdef'

    def test_with_strong_password_and_provided_salt(self):
        derived = EncrytionUtil.generate_encryption_key(
            password=self.strong_password, salt=self.salt
        )
        self.assertIsInstance(derived, str)

    def test_with_strong_password_and_no_salt(self):
        derived = EncrytionUtil.generate_encryption_key(password=self.strong_password)
        self.assertIsInstance(derived, str)

    def test_with_weak_password(self):
        # A password that fails the strength check must be rejected.
        self.assertRaises(
            ValueError, EncrytionUtil.generate_encryption_key, password="weak"
        )

    def test_with_no_password(self):
        derived = EncrytionUtil.generate_encryption_key()
        self.assertIsInstance(derived, str)
        # A Fernet key is 32 bytes, url-safe base64 encoded -> 44 characters.
        self.assertEqual(len(derived), 44)
|
52
|
+
|
53
|
+
|
54
|
+
class TestEncrypt(unittest.TestCase):
    """Tests for ``EncrytionUtil.encrypt``."""

    def setUp(self):
        self.valid_key = EncrytionUtil.generate_encryption_key("StrongPass1")
        # Not a well-formed key, so encrypting with it must fail.
        self.invalid_key = "invalidkey"
        self.test_data = "This is a test string."

    def test_valid_encryption(self):
        ciphertext = EncrytionUtil.encrypt(self.test_data, self.valid_key)
        self.assertIsInstance(ciphertext, str)
        self.assertNotEqual(ciphertext, self.test_data)

    def test_with_invalid_key(self):
        self.assertRaises(
            Exception, EncrytionUtil.encrypt, self.test_data, self.invalid_key
        )

    def test_with_non_string_data(self):
        # NOTE(review): AttributeError assumes encrypt() calls a str method
        # (e.g. .encode()) on its input -- confirm against the implementation.
        self.assertRaises(AttributeError, EncrytionUtil.encrypt, 12345, self.valid_key)

    def test_with_empty_string(self):
        ciphertext = EncrytionUtil.encrypt("", self.valid_key)
        self.assertIsInstance(ciphertext, str)
        # Even an empty plaintext must produce a non-empty token.
        self.assertNotEqual(ciphertext, "")
|
78
|
+
|
79
|
+
|
80
|
+
class TestDecrypt(unittest.TestCase):
    """Tests for ``EncrytionUtil.decrypt``."""

    def setUp(self):
        self.valid_key = EncrytionUtil.generate_encryption_key("StrongPass1")
        # Deliberately malformed key for the negative-path test.
        self.invalid_key = "invalidkey"
        self.test_data = "This is a test string."
        # Ciphertext produced with the valid key, to be round-tripped back.
        self.encrypted_data = EncrytionUtil.encrypt(self.test_data, self.valid_key)

    def test_valid_decryption(self):
        plaintext = EncrytionUtil.decrypt(self.encrypted_data, self.valid_key)
        self.assertIsInstance(plaintext, str)
        self.assertEqual(plaintext, self.test_data)

    def test_decryption_with_invalid_key(self):
        self.assertRaises(
            ValueError, EncrytionUtil.decrypt, self.encrypted_data, self.invalid_key
        )

    def test_with_non_string_data(self):
        self.assertRaises(AttributeError, EncrytionUtil.decrypt, 12345, self.valid_key)

    def test_decryption_of_non_encrypted_string(self):
        # Text that is not a token produced by encrypt() must be rejected.
        self.assertRaises(
            InvalidToken, EncrytionUtil.decrypt, "plain text", self.valid_key
        )
|
104
|
+
|
105
|
+
|
106
|
+
class TestEncryptFile(unittest.TestCase):
    """Tests for ``EncrytionUtil.encrypt_file``."""

    def setUp(self):
        self.valid_key = EncrytionUtil.generate_encryption_key("StrongPass1")
        # Plaintext fixture in a temp file that survives close().
        handle = tempfile.NamedTemporaryFile(delete=False)
        self.temp_file = handle
        handle.write(b"This is a test file.")
        handle.close()

    def tearDown(self):
        source_path = self.temp_file.name
        os.remove(source_path)
        encrypted_path = source_path + '.enc'
        # The .enc output only exists when a test actually ran encrypt_file.
        if os.path.exists(encrypted_path):
            os.remove(encrypted_path)

    def test_encrypting_valid_file(self):
        source_path = self.temp_file.name
        EncrytionUtil.encrypt_file(source_path, self.valid_key)
        self.assertTrue(os.path.exists(source_path + '.enc'))

    def test_with_non_existent_file_path(self):
        self.assertRaises(
            FileNotFoundError,
            EncrytionUtil.encrypt_file,
            "non_existent_file.txt",
            self.valid_key,
        )
|
128
|
+
|
129
|
+
|
130
|
+
class TestDecryptFile(unittest.TestCase):
    """Tests for ``EncrytionUtil.decrypt_file``.

    setUp writes a plaintext temp file and encrypts it to ``<name>.enc``;
    the tests then decrypt that ``.enc`` file.
    """

    PLAINTEXT = b"This is a test file."

    def setUp(self):
        self.valid_key = EncrytionUtil.generate_encryption_key("StrongPass1")
        # Create a temporary plaintext file and encrypt it alongside itself.
        self.temp_file = tempfile.NamedTemporaryFile(delete=False)
        self.temp_file.write(self.PLAINTEXT)
        self.temp_file.close()
        EncrytionUtil.encrypt_file(self.temp_file.name, self.valid_key)
        self.encrypted_path = self.temp_file.name + '.enc'

    def tearDown(self):
        # Remove everything the tests may have produced. Files that are already
        # gone are tolerated so a failing test does not also error in tearDown.
        for path in (self.temp_file.name, self.encrypted_path):
            if os.path.exists(path):
                os.remove(path)

    def test_decrypting_valid_encrypted_file(self):
        # Delete the original plaintext first. Otherwise the existence check
        # below would pass even if decrypt_file wrote nothing.
        os.remove(self.temp_file.name)
        EncrytionUtil.decrypt_file(self.encrypted_path, self.valid_key)
        # Presumably decrypt_file writes to the encrypted path minus '.enc'
        # (the path the original test intended to check) -- confirm.
        decrypted_file_path = self.encrypted_path[:-len('.enc')]
        self.assertTrue(os.path.exists(decrypted_file_path))
        with open(decrypted_file_path, 'rb') as decrypted:
            self.assertEqual(decrypted.read(), self.PLAINTEXT)

    def test_with_non_existent_encrypted_file_path(self):
        with self.assertRaises(FileNotFoundError):
            EncrytionUtil.decrypt_file("non_existent_file.txt.enc", self.valid_key)

    def test_decrypting_with_invalid_key(self):
        with self.assertRaises(ValueError):
            EncrytionUtil.decrypt_file(self.encrypted_path, 'invalidkey')
|
161
|
+
|
162
|
+
|
163
|
+
class TestIsEncrypted(unittest.TestCase):
    """Tests for ``EncrytionUtil.is_encrypted`` on encrypted and plain files."""

    @staticmethod
    def _write_temp(payload):
        # Persist ``payload`` to a closed temp file that is not auto-deleted.
        handle = tempfile.NamedTemporaryFile(delete=False)
        handle.write(payload)
        handle.close()
        return handle

    def setUp(self):
        self.valid_key = EncrytionUtil.generate_encryption_key("StrongPass1")
        # Encrypted fixture: encrypt_file produces "<temp_file.name>.enc".
        self.temp_file = self._write_temp(b"This is a test file.")
        EncrytionUtil.encrypt_file(self.temp_file.name, self.valid_key)
        # Plain fixture that is never encrypted.
        self.non_encrypted_file = self._write_temp(b"This is a non-encrypted test file.")

    def tearDown(self):
        source_path = self.temp_file.name
        os.remove(source_path)
        if os.path.exists(source_path + '.enc'):
            os.remove(source_path + '.enc')
        os.remove(self.non_encrypted_file.name)

    def test_with_encrypted_file(self):
        encrypted_path = self.temp_file.name + '.enc'
        self.assertTrue(EncrytionUtil.is_encrypted(encrypted_path, self.valid_key))

    def test_with_non_encrypted_file(self):
        plain_path = self.non_encrypted_file.name
        self.assertFalse(EncrytionUtil.is_encrypted(plain_path, self.valid_key))
|
190
|
+
|
191
|
+
|
192
|
+
class TestDecompressFile(unittest.TestCase):
    """Tests for ``EncrytionUtil.decompress_file``."""

    def setUp(self):
        import shutil

        # Create a scratch directory and register its recursive removal right
        # away. tearDown is skipped when setUp raises, so a cleanup registered
        # here is the only way to avoid leaking the directory in that case.
        self.temp_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, self.temp_dir)

        # Create a zip archive containing a single text member.
        self.temp_zip_file = os.path.join(self.temp_dir, 'test.zip')
        with zipfile.ZipFile(self.temp_zip_file, 'w') as zipf:
            zipf.writestr('test.txt', 'This is a test file.')

    def test_decompressing_valid_zip_file(self):
        EncrytionUtil.decompress_file(self.temp_zip_file, self.temp_dir)
        self.assertTrue(os.path.exists(os.path.join(self.temp_dir, 'test.txt')))

    def test_with_non_existent_zip_file_path(self):
        with self.assertRaises(FileNotFoundError):
            EncrytionUtil.decompress_file("non_existent_file.zip", self.temp_dir)

    def test_with_invalid_zip_file(self):
        # A plain text file with a .zip name must be rejected as a bad archive.
        invalid_zip_file = os.path.join(self.temp_dir, 'invalid.zip')
        with open(invalid_zip_file, 'w') as f:
            f.write("This is not a zip file.")
        with self.assertRaises(zipfile.BadZipFile):
            EncrytionUtil.decompress_file(invalid_zip_file, self.temp_dir)
|
226
|
+
|
227
|
+
|
228
|
+
class TestCompressFile(unittest.TestCase):
    """Tests for ``EncrytionUtil.compress_file``."""

    def setUp(self):
        # Small plaintext fixture in a temp file that survives close().
        handle = tempfile.NamedTemporaryFile(delete=False)
        self.temp_file = handle
        handle.write(b"This is a test file.")
        handle.close()

    def tearDown(self):
        source_path = self.temp_file.name
        os.remove(source_path)
        archive_path = source_path + '.zip'
        # The archive exists only if compress_file ran successfully.
        if os.path.exists(archive_path):
            os.remove(archive_path)

    def test_compressing_valid_file(self):
        source_path = self.temp_file.name
        EncrytionUtil.compress_file(source_path)
        self.assertTrue(os.path.exists(source_path + '.zip'))

    def test_with_non_existent_file_path(self):
        self.assertRaises(
            FileNotFoundError, EncrytionUtil.compress_file, "non_existent_file.txt"
        )
|
249
|
+
|
250
|
+
|
251
|
+
class TestBinaryToHex(unittest.TestCase):
    """Tests for ``EncrytionUtil.binary_to_hex``."""

    def test_with_valid_binary_data(self):
        # (input bytes, expected lowercase hex) covering low, ASCII and high bytes.
        cases = (
            (b'\x00\x0F', '000f'),
            (b'hello', '68656c6c6f'),
            (b'\xff\xfe\xfd\xfc', 'fffefdfc'),
        )
        for raw, expected_hex in cases:
            self.assertEqual(EncrytionUtil.binary_to_hex(raw), expected_hex)

    def test_with_empty_bytes(self):
        self.assertEqual(EncrytionUtil.binary_to_hex(b''), '')
|
261
|
+
|
262
|
+
|
263
|
+
class TestCreateHash(unittest.TestCase):
    """Tests for ``EncrytionUtil.create_hash`` against ``hashlib`` reference digests."""

    def test_hashing_with_default_algorithm(self):
        message = "test"
        # The default algorithm is expected to be SHA-256.
        reference = hashlib.sha256(message.encode()).hexdigest()
        self.assertEqual(EncrytionUtil.create_hash(message), reference)

    def test_hashing_with_different_algorithms(self):
        message = "test"
        for name in ('sha1', 'sha224', 'sha384', 'sha512'):
            with self.subTest(algorithm=name):
                reference = hashlib.new(name, message.encode()).hexdigest()
                self.assertEqual(EncrytionUtil.create_hash(message, name), reference)

    def test_with_unsupported_algorithm(self):
        self.assertRaises(
            ValueError, EncrytionUtil.create_hash, "test", "unsupported_algo"
        )
|
281
|
+
|
282
|
+
|
283
|
+
class TestDecodeBase64(unittest.TestCase):
    """Tests for ``EncrytionUtil.decode_base64``."""

    def test_with_valid_base64_encoded_strings(self):
        # base64 input -> expected decoded text.
        expectations = {
            "SGVsbG8sIFdvcmxkIQ==": "Hello, World!",
            "VGhpcyBpcyBhIHRlc3Q=": "This is a test",
            "c29tZSBieXRlcw==": "some bytes",
        }
        for encoded in expectations:
            with self.subTest(encoded=encoded):
                self.assertEqual(
                    EncrytionUtil.decode_base64(encoded), expectations[encoded]
                )

    def test_with_invalid_base64_string(self):
        # NOTE(review): narrow Exception once decode_base64's concrete error is known.
        self.assertRaises(Exception, EncrytionUtil.decode_base64, "This is not base64!")

    def test_with_empty_string(self):
        self.assertEqual(EncrytionUtil.decode_base64(""), "")
|
303
|
+
|
304
|
+
|
305
|
+
class TestEncodeBase64(unittest.TestCase):
    """Tests for ``EncrytionUtil.encode_base64``."""

    def test_with_valid_strings(self):
        # plain text -> expected base64 encoding.
        expectations = {
            "Hello, World!": "SGVsbG8sIFdvcmxkIQ==",
            "This is a test": "VGhpcyBpcyBhIHRlc3Q=",
            "some bytes": "c29tZSBieXRlcw==",
        }
        for plain, encoded in expectations.items():
            with self.subTest(original=plain):
                self.assertEqual(EncrytionUtil.encode_base64(plain), encoded)

    def test_with_empty_string(self):
        self.assertEqual(EncrytionUtil.encode_base64(""), "")
|
320
|
+
|
321
|
+
|
322
|
+
if __name__ == '__main__':
    # Run every TestCase in this module when it is executed as a script.
    unittest.main()
|
@@ -0,0 +1,238 @@
|
|
1
|
+
import unittest
|
2
|
+
import json
|
3
|
+
import os
|
4
|
+
from unittest.mock import mock_open, patch
|
5
|
+
from io import StringIO
|
6
|
+
from lionagi.utils.io_util import IOUtil
|
7
|
+
|
8
|
+
|
9
|
+
class NonClosingStringIO(StringIO):
    """In-memory text buffer whose ``close`` does nothing.

    A ``with`` block over this object calls ``close`` on exit. Because that
    call is a no-op, a test can still read ``getvalue()`` afterwards.
    """

    def close(self):
        # Intentionally a no-op so the buffer stays readable after use.
        return None
|
13
|
+
|
14
|
+
|
15
|
+
class TestIOUtil(unittest.TestCase):
    """Unit tests for the ``IOUtil`` read/write helpers.

    File access is mocked through ``builtins.open`` wherever possible so that
    no test touches the working directory. The ``to_temp`` tests use real
    temporary files and remove them with ``addCleanup``.
    """

    def setUp(self):
        self.valid_data = [{'name': 'Alice', 'age': 30}, {'name': 'Bob', 'age': 25}]
        # Empty JSON payload for test_write_json_empty_list. Kept separate from
        # ``empty_data`` below, which must be a string for mock_open.
        self.empty_list = []
        # Empty CSV content served by open_mock for empty1.csv / empty2.csv.
        self.empty_data = ""
        # read_csv is expected to return every field as a string.
        self.expected_output = [{'name': 'Alice', 'age': '30'}, {'name': 'Bob', 'age': '25'}]
        # read_jsonl is expected to preserve JSON types (ints stay ints).
        self.expected_valid_output = [{'name': 'Alice', 'age': 30}, {'name': 'Bob', 'age': 25}]
        self.csv_data1 = "name,age\nAlice,30\nBob,25"
        self.csv_data2 = "name,score\nAlice,85\nBob,90"
        self.expected_csv = "name,age\nAlice,30\nBob,25\n"

    @patch("builtins.open", new_callable=mock_open, read_data="name,age\nAlice,30\nBob,25")
    def test_read_csv_valid_file(self, mock_file):
        result = IOUtil.read_csv("dummy.csv")
        self.assertEqual(result, self.expected_output)

    @patch("builtins.open", side_effect=FileNotFoundError)
    def test_read_csv_nonexistent_file(self, mock_file):
        with self.assertRaises(FileNotFoundError):
            IOUtil.read_csv("nonexistent.csv")

    @patch("builtins.open", new_callable=mock_open, read_data="")
    def test_read_csv_empty_file(self, mock_file):
        result = IOUtil.read_csv("empty.csv")
        self.assertEqual(result, [])

    @patch("builtins.open", new_callable=mock_open, read_data="name,age\nAlice,30\nBob,25.5\nCharlie")
    def test_read_csv_inconsistent_columns(self, mock_file):
        result = IOUtil.read_csv("inconsistent.csv")
        # The short trailing row must still yield a record.
        self.assertEqual(len(result), 3)
        self.assertIn('Charlie', result[-1].values())

    # Additional test for different data types
    @patch("builtins.open", new_callable=mock_open, read_data="name,age,score\nAlice,30,85.5\nBob,25,90")
    def test_read_csv_varied_data_types(self, mock_file):
        expected_output = [
            {'name': 'Alice', 'age': '30', 'score': '85.5'},
            {'name': 'Bob', 'age': '25', 'score': '90'},
        ]
        result = IOUtil.read_csv("varied_types.csv")
        self.assertEqual(result, expected_output)

    @patch("builtins.open", new_callable=mock_open,
           read_data='{"name": "Alice", "age": 30}\n{"name": "Bob", "age": 25}\n')
    def test_read_jsonl_valid_file(self, mock_file):
        result = IOUtil.read_jsonl("dummy.jsonl")
        self.assertEqual(result, self.expected_valid_output)

    @patch("builtins.open", side_effect=FileNotFoundError)
    def test_read_jsonl_nonexistent_file(self, mock_file):
        with self.assertRaises(FileNotFoundError):
            IOUtil.read_jsonl("nonexistent.jsonl")

    @patch("builtins.open", new_callable=mock_open, read_data="")
    def test_read_jsonl_empty_file(self, mock_file):
        result = IOUtil.read_jsonl("empty.jsonl")
        self.assertEqual(result, [])

    @patch("builtins.open", new_callable=mock_open, read_data='{"name": "Alice", "age": 30}\n{"name": "Bob"}\n')
    def test_read_jsonl_mixed_data_types(self, mock_file):
        expected_output = [{'name': 'Alice', 'age': 30}, {'name': 'Bob'}]
        result = IOUtil.read_jsonl("mixed_types.jsonl")
        self.assertEqual(result, expected_output)

    @patch("builtins.open", new_callable=mock_open,
           read_data='{"name": "Alice", "age": 30}\nInvalid JSON\n{"name": "Bob", "age": 25}\n')
    def test_read_jsonl_invalid_json_format(self, mock_file):
        with self.assertRaises(json.JSONDecodeError):
            IOUtil.read_jsonl("invalid.jsonl")

    def test_write_json_valid_data(self):
        mock_file = NonClosingStringIO()
        with patch("builtins.open", return_value=mock_file):
            IOUtil.write_json(self.valid_data, "test.json")
        # close() is a no-op on NonClosingStringIO, so the buffer is still readable.
        written_data = mock_file.getvalue()
        expected_json = json.dumps(self.valid_data, indent=4)
        self.assertEqual(written_data, expected_json)

    def test_write_json_empty_list(self):
        with patch("builtins.open", mock_open()) as mocked_file:
            IOUtil.write_json(self.empty_list, "test.json")
        mocked_file.assert_called_once_with("test.json", 'w')
        mocked_file().write.assert_called_once_with(json.dumps(self.empty_list, indent=4))

    def test_write_json_non_serializable_data(self):
        non_serializable_data = [{'name': 'Alice', 'age': 30}, {'name': 'Bob', 'age': lambda x: x}]
        # Patch open so a partially written or truncated "test.json" never lands
        # in the real working directory.
        with patch("builtins.open", mock_open()):
            with self.assertRaises(TypeError):
                IOUtil.write_json(non_serializable_data, "test.json")

    @patch("builtins.open", new_callable=mock_open, read_data='{"name": "Alice", "age": 30}')
    def test_read_json_valid_file(self, mock_file):
        expected_output = {"name": "Alice", "age": 30}
        result = IOUtil.read_json("valid.json")
        self.assertEqual(result, expected_output)

    @patch("builtins.open", side_effect=FileNotFoundError)
    def test_read_json_nonexistent_file(self, mock_file):
        with self.assertRaises(FileNotFoundError):
            IOUtil.read_json("nonexistent.json")

    @patch("builtins.open", new_callable=mock_open, read_data='{name: Alice, age: 30}')
    def test_read_json_invalid_format(self, mock_file):
        with self.assertRaises(json.JSONDecodeError):
            IOUtil.read_json("invalid.json")

    def open_mock(self, file, mode='r', *args, **kwargs):
        """``side_effect`` for ``open`` that serves canned CSV content by file name.

        Any further positional/keyword arguments to ``open`` (``newline``,
        ``encoding``, ...) are accepted and ignored. Unknown file names get a
        fresh empty mock handle, which is used for the merged output file.
        """
        mock_files = {
            'file1.csv': mock_open(read_data=self.csv_data1).return_value,
            'file2.csv': mock_open(read_data=self.csv_data2).return_value,
            'empty1.csv': mock_open(read_data=self.empty_data).return_value,
            'empty2.csv': mock_open(read_data=self.empty_data).return_value,
        }
        return mock_files.get(file, mock_open().return_value)

    def test_merge_csv_files_valid(self):
        with patch("builtins.open", side_effect=self.open_mock):
            IOUtil.merge_csv_files(['file1.csv', 'file2.csv'], 'merged.csv')
        # TODO: assert on the merged output once merge_csv_files' format is pinned down.

    def test_merge_csv_files_nonexistent(self):
        with patch("builtins.open", side_effect=FileNotFoundError):
            with self.assertRaises(FileNotFoundError):
                IOUtil.merge_csv_files(['nonexistent.csv'], 'output.csv')

    def test_merge_csv_files_different_columns(self):
        with patch("builtins.open", side_effect=self.open_mock):
            IOUtil.merge_csv_files(['file1.csv', 'file2.csv'], 'merged_diff_cols.csv')
        # TODO: assert on the merged output once merge_csv_files' format is pinned down.

    def test_merge_csv_files_empty_files(self):
        with patch("builtins.open", side_effect=self.open_mock):
            IOUtil.merge_csv_files(['empty1.csv', 'empty2.csv'], 'merged_empty.csv')
        # TODO: assert that the output file is created and is empty.

    def test_to_csv_valid_data(self):
        mock_file = mock_open()
        with patch("builtins.open", mock_file):
            IOUtil.to_csv(self.valid_data, "test.csv")

        # Aggregate all calls to write and concatenate their arguments.
        write_calls = mock_file().write.call_args_list
        written_data = ''.join(call_arg[0][0] for call_arg in write_calls)

        # Normalize line endings to Unix style for comparison.
        written_data_normalized = written_data.replace('\r\n', '\n')
        self.assertEqual(written_data_normalized, self.expected_csv)

    @patch("os.path.exists", return_value=False)
    def test_to_csv_nonexistent_dir_file_exist_ok_false(self, mock_exists):
        with self.assertRaises(FileNotFoundError):
            IOUtil.to_csv(self.valid_data, "nonexistent_dir/test.csv", file_exist_ok=False)

    @patch("os.makedirs")
    @patch("os.path.exists", return_value=False)
    def test_to_csv_nonexistent_dir_file_exist_ok_true(self, mock_exists, mock_makedirs):
        mock_file = mock_open()
        with patch("builtins.open", mock_file):
            IOUtil.to_csv(self.valid_data, "nonexistent_dir/test.csv", file_exist_ok=True)
        mock_makedirs.assert_called_once_with(os.path.dirname("nonexistent_dir/test.csv"), exist_ok=True)

    def test_to_csv_empty_list(self):
        mock_file = mock_open()
        with patch("builtins.open", mock_file):
            IOUtil.to_csv([], "empty.csv")

        # open must never be called: to_csv should return early for an empty list.
        mock_file.assert_not_called()

    def test_append_to_jsonl_valid_data(self):
        mock_file = mock_open()
        data = {'name': 'Alice', 'age': 30}
        with patch("builtins.open", mock_file, create=True):
            IOUtil.append_to_jsonl(data, "test.jsonl")
        mock_file().write.assert_called_once_with(json.dumps(data) + "\n")

    def test_append_to_jsonl_new_file(self):
        mock_file = mock_open()
        data = {'name': 'Bob', 'age': 25}
        with patch("builtins.open", mock_file, create=True):
            IOUtil.append_to_jsonl(data, "new_file.jsonl")
        mock_file.assert_called_once_with("new_file.jsonl", "a")
        mock_file().write.assert_called_once_with(json.dumps(data) + "\n")

    def test_append_to_jsonl_non_serializable_data(self):
        class NonSerializable:
            pass

        non_serializable_data = NonSerializable()
        # Patch open so no real "test.jsonl" is created or appended to in the CWD.
        with patch("builtins.open", mock_open()):
            with self.assertRaises(TypeError):
                IOUtil.append_to_jsonl(non_serializable_data, "test.jsonl")

    def test_to_temp_with_valid_string(self):
        test_input = "Test String"
        temp_file = IOUtil.to_temp(test_input)
        # Registered before reading/asserting so the file is removed even on
        # failure, and only after the read handle below has been closed.
        self.addCleanup(os.remove, temp_file.name)
        with open(temp_file.name, 'r') as file:
            content = json.load(file)
        # A bare string is expected to be wrapped in a one-element list.
        self.assertEqual(content, [test_input])

    def test_to_temp_with_valid_iterable(self):
        test_input = ["Test", "String"]
        temp_file = IOUtil.to_temp(test_input)
        self.addCleanup(os.remove, temp_file.name)
        with open(temp_file.name, 'r') as file:
            content = json.load(file)
        self.assertEqual(content, test_input)

    def test_to_temp_with_non_serializable_data(self):
        class NonSerializable:
            pass

        non_serializable_data = NonSerializable()
        # NOTE(review): if to_temp creates its temp file before serialising,
        # this test may leave that file behind -- confirm and clean up if so.
        with self.assertRaises(TypeError):
            IOUtil.to_temp(non_serializable_data)
|
235
|
+
|
236
|
+
|
237
|
+
if __name__ == '__main__':
    # Run every TestCase in this module when it is executed as a script.
    unittest.main()
|