infinity-parser2 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tests/test_utils.py ADDED
@@ -0,0 +1,689 @@
1
+ """Unit tests for InfinityParser2 utility functions."""
2
+
3
+ import io
4
+ import json
5
+ import os
6
+ import tempfile
7
+ import unittest
8
+ from pathlib import Path
9
+ from unittest.mock import MagicMock, patch
10
+
11
+ from PIL import Image
12
+
13
+ from infinity_parser2.utils import (
14
+ convert_pdf_to_images,
15
+ encode_file_to_base64,
16
+ extract_json_content,
17
+ load_image,
18
+ truncate_last_incomplete_element,
19
+ obtain_origin_hw,
20
+ restore_abs_bbox_coordinates,
21
+ convert_json_to_markdown,
22
+ postprocess_doc2json_result,
23
+ save_results,
24
+ )
25
+ from infinity_parser2.utils.model import ModelCache, _resolve_hf_endpoint
26
+
27
+
28
class TestLoadImage(unittest.TestCase):
    """Exercise ``load_image`` with file paths, PIL images and invalid inputs."""

    def setUp(self):
        """Write a 100x100 red PNG into a fresh scratch directory."""
        self.temp_dir = tempfile.mkdtemp()
        self.test_image = Image.new("RGB", (100, 100), color="red")
        self.test_image_path = os.path.join(self.temp_dir, "test.png")
        self.test_image.save(self.test_image_path)

    def tearDown(self):
        """Delete the scratch directory and its contents."""
        from shutil import rmtree

        rmtree(self.temp_dir, ignore_errors=True)

    def test_load_image_from_path(self):
        """A path on disk is opened and handed back as an RGB PIL image."""
        image = load_image(self.test_image_path)
        self.assertIsInstance(image, Image.Image)
        self.assertEqual(image.size, (100, 100))
        self.assertEqual(image.mode, "RGB")

    def test_load_image_from_pil_image(self):
        """An in-memory PIL image is accepted and comes back as RGB."""
        image = load_image(Image.new("RGB", (50, 50), color="blue"))
        self.assertIsInstance(image, Image.Image)
        self.assertEqual(image.size, (50, 50))
        self.assertEqual(image.mode, "RGB")

    def test_load_image_converts_to_rgb(self):
        """Non-RGB input (here RGBA) is normalized to RGB."""
        translucent = Image.new("RGBA", (100, 100), color=(255, 0, 0, 128))
        self.assertEqual(load_image(translucent).mode, "RGB")

    def test_load_image_unsupported_type_raises_error(self):
        """Inputs that are neither paths nor PIL images raise TypeError."""
        for bad_input in (12345, [1, 2, 3]):
            with self.subTest(bad_input=bad_input):
                with self.assertRaises(TypeError):
                    load_image(bad_input)
70
+
71
+
72
class TestEncodeFileToBase64(unittest.TestCase):
    """Tests for encode_file_to_base64 utility function."""

    def setUp(self):
        """Create a scratch directory for the image files each test writes."""
        self.temp_dir = tempfile.mkdtemp()

    def tearDown(self):
        """Remove the scratch directory."""
        import shutil
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    def test_encode_png_file(self):
        """Test encoding PNG file to base64."""
        png_path = os.path.join(self.temp_dir, "test.png")
        Image.new("RGB", (100, 100), color="green").save(png_path)
        base64_str, mime_type = encode_file_to_base64(png_path)
        self.assertIsInstance(base64_str, str)
        # assertGreater reports the actual length on failure; assertTrue does not.
        self.assertGreater(len(base64_str), 0)
        self.assertEqual(mime_type, "image/png")

    def test_encode_jpg_file(self):
        """Test encoding JPG file to base64."""
        jpg_path = os.path.join(self.temp_dir, "test.jpg")
        Image.new("RGB", (100, 100), color="yellow").save(jpg_path, "JPEG")
        base64_str, mime_type = encode_file_to_base64(jpg_path)
        self.assertIsInstance(base64_str, str)
        self.assertGreater(len(base64_str), 0)
        self.assertEqual(mime_type, "image/jpeg")

    def test_encode_jpeg_file(self):
        """Test encoding JPEG file with .jpeg extension."""
        jpeg_path = os.path.join(self.temp_dir, "test.jpeg")
        Image.new("RGB", (100, 100), color="orange").save(jpeg_path, "JPEG")
        _, mime_type = encode_file_to_base64(jpeg_path)
        self.assertEqual(mime_type, "image/jpeg")

    def test_encode_webp_file(self):
        """Test encoding WebP file."""
        webp_path = os.path.join(self.temp_dir, "test.webp")
        Image.new("RGB", (100, 100), color="purple").save(webp_path, "WEBP")
        _, mime_type = encode_file_to_base64(webp_path)
        self.assertEqual(mime_type, "image/webp")

    def test_encode_bmp_file(self):
        """Test encoding BMP file."""
        bmp_path = os.path.join(self.temp_dir, "test.bmp")
        Image.new("RGB", (100, 100), color="cyan").save(bmp_path, "BMP")
        _, mime_type = encode_file_to_base64(bmp_path)
        self.assertEqual(mime_type, "image/bmp")

    def test_encode_pil_image(self):
        """Test encoding PIL Image object."""
        img = Image.new("RGB", (100, 100), color="magenta")
        base64_str, mime_type = encode_file_to_base64(img)
        self.assertIsInstance(base64_str, str)
        self.assertGreater(len(base64_str), 0)
        self.assertEqual(mime_type, "image/jpeg")  # Default for PIL without format

    def test_encode_pil_image_with_format(self):
        """Test that a PIL Image's ``format`` attribute drives the MIME type.

        A freshly created image has no format and falls back to image/jpeg;
        an image opened from a PNG file reports format "PNG" and maps to image/png.
        """
        img = Image.new("RGB", (100, 100), color="white")
        self.assertIsNone(img.format)
        _, mime_type = encode_file_to_base64(img)
        self.assertEqual(mime_type, "image/jpeg")

        png_path = os.path.join(self.temp_dir, "test.png")
        Image.new("RGB", (100, 100), color="white").save(png_path)
        with Image.open(png_path) as loaded_img:
            self.assertEqual(loaded_img.format, "PNG")
            _, mime_type = encode_file_to_base64(loaded_img)
        self.assertEqual(mime_type, "image/png")

    def test_encode_with_custom_min_max_pixels(self):
        """Test encoding with custom min_pixels and max_pixels parameters."""
        large_path = os.path.join(self.temp_dir, "large.png")
        Image.new("RGB", (1000, 1000), color="blue").save(large_path)
        base64_str, _ = encode_file_to_base64(large_path, min_pixels=100, max_pixels=50000)
        self.assertIsInstance(base64_str, str)
        self.assertGreater(len(base64_str), 0)

    def test_encode_unknown_extension_defaults_to_jpeg(self):
        """Test that unknown extension defaults to image/jpeg."""
        unknown_path = os.path.join(self.temp_dir, "test.unknown")
        Image.new("RGB", (100, 100), color="gray").save(unknown_path, format="PNG")
        _, mime_type = encode_file_to_base64(unknown_path)
        self.assertEqual(mime_type, "image/jpeg")

    def test_base64_decoding(self):
        """Test that the base64 payload decodes to a fully readable image."""
        import base64
        png_path = os.path.join(self.temp_dir, "test.png")
        Image.new("RGB", (100, 100), color="red").save(png_path)
        base64_str, _ = encode_file_to_base64(png_path)
        decoded_bytes = base64.b64decode(base64_str)
        # Image.open is lazy and only parses the header. load() forces a full
        # decode so that corrupt pixel data fails the test. The context
        # manager releases the image afterwards.
        with Image.open(io.BytesIO(decoded_bytes)) as decoded_image:
            decoded_image.load()
            self.assertIsInstance(decoded_image, Image.Image)
170
+
171
+
172
class TestConvertPdfToImages(unittest.TestCase):
    """Tests for convert_pdf_to_images utility function."""

    def setUp(self):
        """Create a scratch directory for the generated PDF."""
        self.temp_dir = tempfile.mkdtemp()

    def tearDown(self):
        """Remove the scratch directory."""
        import shutil
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    def test_convert_pdf_returns_list(self):
        """Test that convert_pdf_to_images returns one list entry per page."""
        result = convert_pdf_to_images(self._create_simple_pdf())
        self.assertIsInstance(result, list)
        self.assertEqual(len(result), 1)

    def test_convert_pdf_returns_pil_images(self):
        """Test that result contains PIL Image objects."""
        for img in self._convert_single_page_pdf():
            self.assertIsInstance(img, Image.Image)

    def test_convert_pdf_returns_rgb_images(self):
        """Test that returned images are in RGB mode."""
        for img in self._convert_single_page_pdf():
            self.assertEqual(img.mode, "RGB")

    def _convert_single_page_pdf(self):
        """Convert the one-page fixture PDF and assert exactly one image came back.

        Without the length check, the per-image loops above would pass
        vacuously if the conversion returned an empty list.
        """
        result = convert_pdf_to_images(self._create_simple_pdf())
        self.assertEqual(len(result), 1)
        return result

    def _create_simple_pdf(self):
        """Write a one-page 595x842 PDF containing "Test PDF" and return its path.

        Skips the calling test when PyMuPDF (``fitz``) is not installed.
        """
        # Only the import is guarded. An ImportError raised later, while building
        # the PDF, should surface as an error rather than a silent skip.
        try:
            import fitz
        except ImportError:
            self.skipTest("PyMuPDF not available for creating test PDF")
        pdf_path = os.path.join(self.temp_dir, "test.pdf")
        doc = fitz.open()
        try:
            page = doc.new_page(width=595, height=842)
            page.insert_text((100, 100), "Test PDF", fontsize=12)
            doc.save(pdf_path)
        finally:
            doc.close()
        return pdf_path
217
+
218
+
219
class TestExtractJsonContent(unittest.TestCase):
    """Check how ``extract_json_content`` peels JSON out of model output."""

    def test_extract_from_markdown_block(self):
        """A complete ```json fence is reduced to the payload inside it."""
        fenced = '```json\n{"key": "value"}\n```'
        self.assertEqual(extract_json_content(fenced), '{"key": "value"}')

    def test_extract_partial_markdown_block(self):
        """An opening fence with no closing fence still yields the payload."""
        unterminated = '```json\n{"key": "value"'
        self.assertEqual(extract_json_content(unterminated), '{"key": "value"')

    def test_return_plain_text_if_no_block(self):
        """Text without any fence passes through unchanged."""
        raw = '{"key": "value"}'
        self.assertEqual(extract_json_content(raw), raw)
239
+
240
+
241
class TestTruncateLastIncompleteElement(unittest.TestCase):
    """Tests for truncate_last_incomplete_element utility function."""

    def test_no_truncation_for_short_text(self):
        """Test that a short, complete JSON array (ending with "]") is returned unchanged."""
        text = '[{"bbox": [0,0,100,100], "text": "hello"}]'
        result, was_truncated = truncate_last_incomplete_element(text)
        self.assertEqual(result, text)
        self.assertFalse(was_truncated)

    def test_truncate_incomplete_element(self):
        """Test truncating incomplete element when text does not end with ]."""
        # The text must NOT end with "]" to trigger truncation. The first
        # element is complete. The second is cut off part-way through its
        # "text" string value, so only the first element should survive.
        text = '[{"bbox": [0,0,100,100], "text": "hello"}, {"bbox": [0,0,200,200], "text": "incomplete'
        result, was_truncated = truncate_last_incomplete_element(text)
        self.assertTrue(was_truncated)
        self.assertIn('[{"bbox": [0,0,100,100], "text": "hello"}]', result)

    def test_no_truncation_for_single_element(self):
        """Test that an incomplete array holding a single element is not truncated."""
        text = '[{"bbox": [0,0,100,100]'
        _, was_truncated = truncate_last_incomplete_element(text)
        self.assertFalse(was_truncated)
265
+
266
+
267
class TestObtainOriginHw(unittest.TestCase):
    """Tests for obtain_origin_hw utility function (returns ``(height, width)``)."""

    def test_from_pil_image(self):
        """Test getting dimensions from PIL Image."""
        img = Image.new("RGB", (800, 600))  # PIL size is (width, height)
        h, w = obtain_origin_hw(img)
        self.assertEqual(h, 600)  # height
        self.assertEqual(w, 800)  # width

    def test_from_file_path(self):
        """Test getting dimensions from file path."""
        # TemporaryDirectory guarantees cleanup without a manual try/finally.
        with tempfile.TemporaryDirectory() as temp_dir:
            img_path = os.path.join(temp_dir, "test.png")
            Image.new("RGB", (1024, 768)).save(img_path)
            h, w = obtain_origin_hw(img_path)
        self.assertEqual(h, 768)
        self.assertEqual(w, 1024)

    def test_fallback_on_error(self):
        """Test that an unreadable path falls back to 1000x1000."""
        h, w = obtain_origin_hw("/nonexistent/file.png")
        self.assertEqual(h, 1000)
        self.assertEqual(w, 1000)
295
+
296
+
297
class TestRestoreAbsBboxCoordinates(unittest.TestCase):
    """Check scaling of normalized [0-1000] bboxes back to pixel coordinates."""

    SAMPLE = '[{"bbox": [0, 0, 500, 500], "text": "hello"}]'

    def _first_bbox(self, *image_dims):
        """Restore SAMPLE with the given dimension arguments; return the first bbox."""
        restored = restore_abs_bbox_coordinates(self.SAMPLE, *image_dims)
        return json.loads(restored)[0]["bbox"]

    def test_convert_normalized_to_absolute(self):
        """A 1000x1000 image leaves normalized coordinates unchanged."""
        self.assertEqual(self._first_bbox(1000, 1000), [0, 0, 500, 500])

    def test_convert_with_actual_dimensions(self):
        """Coordinates are rescaled against the supplied image dimensions."""
        self.assertEqual(self._first_bbox(2000, 3000), [0, 0, 1500, 1000])

    def test_invalid_json_unchanged(self):
        """Input that is not JSON is returned verbatim."""
        garbage = "not valid json"
        self.assertEqual(restore_abs_bbox_coordinates(garbage, 1000, 1000), garbage)
319
+
320
+
321
class TestConvertJsonToMarkdown(unittest.TestCase):
    """Check rendering of layout JSON into markdown text."""

    @staticmethod
    def _page_with_header_and_footer():
        """Serialize a three-element layout: header, body text, footer."""
        elements = [
            {"category": "header", "text": "Header content"},
            {"category": "text", "text": "Main content"},
            {"category": "footer", "text": "Footer content"},
        ]
        return json.dumps(elements)

    def test_convert_layout_to_markdown(self):
        """Title and paragraph text both appear in the rendered markdown."""
        layout = [
            {"category": "title", "text": "# Document Title"},
            {"category": "text", "text": "Paragraph content"},
            {"category": "figure", "text": ""},
        ]
        markdown = convert_json_to_markdown(json.dumps(layout))
        self.assertIn("# Document Title", markdown)
        self.assertIn("Paragraph content", markdown)

    def test_strip_header_footer_by_default(self):
        """Header and footer elements are dropped unless asked for."""
        markdown = convert_json_to_markdown(self._page_with_header_and_footer())
        self.assertNotIn("Header content", markdown)
        self.assertIn("Main content", markdown)
        self.assertNotIn("Footer content", markdown)

    def test_keep_header_footer_when_requested(self):
        """keep_header_footer=True retains header and footer text."""
        markdown = convert_json_to_markdown(
            self._page_with_header_and_footer(), keep_header_footer=True
        )
        for fragment in ("Header content", "Main content", "Footer content"):
            self.assertIn(fragment, markdown)

    def test_invalid_json_returned_unchanged(self):
        """Input that is not JSON is returned verbatim."""
        garbage = "not valid json"
        self.assertEqual(convert_json_to_markdown(garbage), garbage)
364
+
365
+
366
class TestPostprocessDoc2JsonResult(unittest.TestCase):
    """End-to-end check of ``postprocess_doc2json_result``."""

    def test_full_postprocess_pipeline(self):
        """Fenced model output for a 1000x1000 page comes back as parseable JSON."""
        model_output = '```json\n[{"bbox": [0, 0, 500, 500], "category": "text", "text": "Test"}]\n```'
        page = Image.new("RGB", (1000, 1000))
        elements = json.loads(postprocess_doc2json_result(model_output, page))
        self.assertEqual(len(elements), 1)
        first = elements[0]
        self.assertEqual(first["category"], "text")
377
+
378
+
379
class TestImageMimeTypes(unittest.TestCase):
    """Pin down the contents of the ``IMAGE_MIME_TYPES`` mapping."""

    def test_mime_types_defined(self):
        """The mapping holds exactly the expected extension -> MIME pairs."""
        from infinity_parser2.utils.image import IMAGE_MIME_TYPES

        # Extensions that share a MIME type are grouped so aliases stand out.
        expected = {ext: "image/jpeg" for ext in (".jpg", ".jpeg")}
        expected.update({ext: "image/tiff" for ext in (".tiff", ".tif")})
        expected.update(
            {
                ".png": "image/png",
                ".webp": "image/webp",
                ".bmp": "image/bmp",
                ".gif": "image/gif",
            }
        )
        self.assertEqual(IMAGE_MIME_TYPES, expected)
396
+
397
+
398
class TestIsSupportedFile(unittest.TestCase):
    """Tests for is_supported_file utility function."""

    def _assert_support(self, paths, expected):
        """Assert ``bool(is_supported_file(path)) == expected`` for every path.

        Each path runs in its own subTest, so one failure does not hide the
        rest and the report names the offending path.
        """
        from infinity_parser2.utils import is_supported_file
        for path in paths:
            with self.subTest(path=path):
                self.assertEqual(bool(is_supported_file(path)), expected)

    def test_supported_image_extensions(self):
        """Test that all expected image extensions are supported."""
        # A sorted tuple (not a set) keeps iteration order deterministic.
        expected_extensions = (".bmp", ".jpeg", ".jpg", ".png", ".tiff", ".webp")
        self._assert_support([f"test/file{ext}" for ext in expected_extensions], True)

    def test_supported_pdf(self):
        """Test that PDF extension is supported."""
        self._assert_support(["test/file.pdf", "test/file.PDF"], True)

    def test_case_insensitive(self):
        """Test that file extension check is case-insensitive."""
        self._assert_support(["test/file.PNG", "test/file.JPEG", "test/file.TIFF"], True)

    def test_unsupported_files(self):
        """Test that unsupported files return False."""
        self._assert_support(["test/file.txt", "test/file.doc", "test/file.xlsx"], False)
427
+
428
+
429
class TestGetFilesFromDirectory(unittest.TestCase):
    """Tests for get_files_from_directory utility function."""

    def setUp(self):
        """Create a scratch directory holding one empty file per extension.

        ``test.txt`` is the negative case that should be filtered out.
        """
        self.temp_dir = tempfile.mkdtemp()
        self.test_files = []
        for ext in [".pdf", ".png", ".jpg", ".txt"]:
            filepath = os.path.join(self.temp_dir, f"test{ext}")
            Path(filepath).touch()
            self.test_files.append(filepath)

    def tearDown(self):
        """Clean up temporary test directory."""
        import shutil
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    def test_get_files_from_directory(self):
        """Test getting supported files from directory."""
        from infinity_parser2.utils import get_files_from_directory
        files = get_files_from_directory(self.temp_dir)
        self.assertEqual(len(files), 3)  # .pdf, .png, .jpg (not .txt)
        # Compare full names rather than suffixes, so a wrong file with a
        # supported extension (or a duplicate) cannot slip through.
        self.assertEqual(
            {os.path.basename(f) for f in files},
            {"test.pdf", "test.png", "test.jpg"},
        )

    def test_files_sorted(self):
        """Test that files are returned in sorted order."""
        from infinity_parser2.utils import get_files_from_directory
        files = get_files_from_directory(self.temp_dir)
        self.assertEqual(files, sorted(files))

    def test_empty_directory(self):
        """Test getting files from empty directory."""
        from infinity_parser2.utils import get_files_from_directory
        with tempfile.TemporaryDirectory() as empty_dir:
            files = get_files_from_directory(empty_dir)
        self.assertEqual(len(files), 0)

    def test_nested_directory(self):
        """Test getting files from nested directories."""
        from infinity_parser2.utils import get_files_from_directory
        nested_dir = os.path.join(self.temp_dir, "nested")
        os.makedirs(nested_dir)
        Path(nested_dir, "nested.pdf").touch()

        files = get_files_from_directory(self.temp_dir)
        self.assertTrue(any(f.endswith("nested.pdf") for f in files))
482
+
483
+
484
class TestSaveResults(unittest.TestCase):
    """Tests for save_results utility function.

    Signature under test:
        save_results(inputs, results, output_dir, task_type="doc2json", output_format="md")
    Returns None and writes files to disk. Each input key gets its own
    sub-directory of output_dir. output_format controls which file is
    saved: 'md' saves result.md, and 'json' saves result.json (doc2json only).
    """

    def setUp(self):
        """Create a scratch output directory."""
        self.temp_dir = tempfile.mkdtemp()

    def tearDown(self):
        """Clean up temporary files."""
        import shutil
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    def _read(self, *parts):
        """Read a file under temp_dir as UTF-8 text (independent of the platform locale)."""
        with open(os.path.join(self.temp_dir, *parts), "r", encoding="utf-8") as f:
            return f.read()

    def test_save_results_returns_none(self):
        """Test that save_results returns None."""
        result = save_results(["test_key"], ["Test result content"], self.temp_dir, task_type="doc2md")
        self.assertIsNone(result)

    def test_save_results_creates_directory(self):
        """Test that save_results creates a per-key sub-directory."""
        save_results(["test_key"], ["Test result content"], self.temp_dir, task_type="doc2md")
        self.assertTrue(os.path.exists(os.path.join(self.temp_dir, "test_key")))

    def test_save_results_writes_md_file(self):
        """Test that save_results writes result.md for non-doc2json mode."""
        save_results(["test_key"], ["Test result content"], self.temp_dir, task_type="doc2md")
        result_path = os.path.join(self.temp_dir, "test_key", "result.md")
        self.assertTrue(os.path.exists(result_path))
        self.assertEqual(self._read("test_key", "result.md"), "Test result content")

    def test_save_results_doc2json_mode_md(self):
        """Test that save_results creates result.md for doc2json mode with output_format='md'."""
        json_result = json.dumps([{"bbox": [0, 0, 100, 100], "category": "text", "text": "Hello"}])
        save_results(["test_key"], [json_result], self.temp_dir, task_type="doc2json", output_format="md")

        json_path = os.path.join(self.temp_dir, "test_key", "result.json")
        md_path = os.path.join(self.temp_dir, "test_key", "result.md")
        self.assertFalse(os.path.exists(json_path))
        self.assertTrue(os.path.exists(md_path))
        # The markdown must carry the element's text, not just exist.
        self.assertIn("Hello", self._read("test_key", "result.md"))

    def test_save_results_doc2json_mode_json(self):
        """Test that save_results creates result.json for doc2json mode with output_format='json'."""
        json_result = json.dumps([{"bbox": [0, 0, 100, 100], "category": "text", "text": "Hello"}])
        save_results(["test_key"], [json_result], self.temp_dir, task_type="doc2json", output_format="json")

        json_path = os.path.join(self.temp_dir, "test_key", "result.json")
        md_path = os.path.join(self.temp_dir, "test_key", "result.md")
        self.assertTrue(os.path.exists(json_path))
        self.assertFalse(os.path.exists(md_path))
        self.assertEqual(self._read("test_key", "result.json"), json_result)

    def test_save_results_handles_multiple_keys(self):
        """Test saving multiple results."""
        keys = ["key1", "key2", "key3"]
        results = ["Result 1", "Result 2", "Result 3"]
        save_results(keys, results, self.temp_dir, task_type="doc2md")
        for key in keys:
            with self.subTest(key=key):
                self.assertTrue(os.path.exists(os.path.join(self.temp_dir, key, "result.md")))

    def test_save_results_output_dir_already_exists(self):
        """Test that save_results works when output dir already exists."""
        # mkdtemp already created temp_dir. This call states the precondition explicitly.
        os.makedirs(self.temp_dir, exist_ok=True)
        save_results(["test_key"], ["Test content"], self.temp_dir, task_type="doc2md")
        result_path = os.path.join(self.temp_dir, "test_key", "result.md")
        self.assertTrue(os.path.exists(result_path))
569
+
570
+
571
class TestModelCache(unittest.TestCase):
    """Tests for ModelCache utility class."""

    def setUp(self):
        """Create a cache rooted at a scratch directory, plus a fake model dir inside it."""
        self.temp_dir = tempfile.mkdtemp()
        self.model_dir = os.path.join(self.temp_dir, "test_model")
        os.makedirs(self.model_dir)
        self.cache = ModelCache(cache_dir=self.temp_dir)

    def tearDown(self):
        """Clean up temporary files."""
        import shutil
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    def test_cache_dir_creation(self):
        """Test that ModelCache initializes on an existing cache directory.

        setUp already created the directory (mkdtemp), so this test does not
        exercise directory creation. It checks only that construction succeeds
        and that ``models_file`` is exposed as a string path.
        """
        self.assertTrue(os.path.exists(self.temp_dir))
        # models_file is created lazily on first save
        self.assertIsInstance(self.cache.models_file, str)

    def test_is_cached_returns_false_initially(self):
        """Test that is_cached returns False for uncached models."""
        self.assertFalse(self.cache.is_cached("uncached/model"))

    def test_cache_model(self):
        """Test caching a model."""
        self.cache.cache_model("test/model", self.model_dir)
        self.assertTrue(self.cache.is_cached("test/model"))

    def test_get_cached_path(self):
        """Test getting cached model path."""
        self.cache.cache_model("test/model", self.model_dir)
        self.assertEqual(self.cache.get_cached_path("test/model"), self.model_dir)

    def test_get_cached_path_returns_none_for_uncached(self):
        """Test that get_cached_path returns None for uncached models."""
        self.assertIsNone(self.cache.get_cached_path("uncached/model"))

    def test_resolve_local_path(self):
        """Test resolving a local path directly."""
        result = self.cache.resolve_model_path(self.model_dir)
        self.assertEqual(result, self.model_dir)

    def test_resolve_cached_model(self):
        """Test resolving a cached model returns cached path."""
        self.cache.cache_model("cached/model", self.model_dir)
        result = self.cache.resolve_model_path("cached/model")
        self.assertEqual(result, self.model_dir)

    def test_resolve_nonexistent_model_triggers_download(self):
        """Test that resolving a nonexistent model delegates to download_and_cache."""
        with patch.object(self.cache, "download_and_cache") as mock_download:
            mock_download.return_value = "/downloaded/path"
            result = self.cache.resolve_model_path("nonexistent/model")
            mock_download.assert_called_once_with("nonexistent/model")
        self.assertEqual(result, "/downloaded/path")

    def test_cache_persistence(self):
        """Test that a second ModelCache on the same directory sees earlier entries."""
        self.cache.cache_model("test/model", self.model_dir)
        new_cache = ModelCache(cache_dir=self.temp_dir)
        self.assertTrue(new_cache.is_cached("test/model"))
        self.assertEqual(new_cache.get_cached_path("test/model"), self.model_dir)

    def test_invalid_json_cache_file(self):
        """Test that a corrupt cache file is treated as empty instead of raising."""
        cache_file = os.path.join(self.temp_dir, "models_cache.json")
        with open(cache_file, "w") as f:
            f.write("invalid json")
        new_cache = ModelCache(cache_dir=self.temp_dir)
        self.assertFalse(new_cache.is_cached("any/model"))
643
+
644
+
645
class TestResolveHFEndpoint(unittest.TestCase):
    """Check which Hugging Face endpoint ``_resolve_hf_endpoint`` selects."""

    # Dotted path of the reachability probe stubbed out in every test.
    _PROBE = "infinity_parser2.utils.model._check_endpoint_reachable"

    def test_resolves_default_endpoint(self):
        """The official hub URL is chosen when the probe reports it reachable."""
        with patch(self._PROBE) as probe:
            probe.return_value = True
            self.assertEqual(_resolve_hf_endpoint(), "https://huggingface.co")

    def test_falls_back_to_mirror(self):
        """The mirror URL is chosen when the probe reports failure."""
        with patch(self._PROBE) as probe:
            probe.return_value = False
            self.assertEqual(_resolve_hf_endpoint(), "https://hf-mirror.com")
661
+
662
+
663
class TestCheckEndpointReachable(unittest.TestCase):
    """Check ``_check_endpoint_reachable`` against a stubbed ``urlopen``."""

    def test_returns_true_on_200(self):
        """An HTTP 200 response counts as reachable."""
        with patch("urllib.request.urlopen") as fake_urlopen:
            from infinity_parser2.utils.model import _check_endpoint_reachable

            response = MagicMock()
            response.status = 200
            # Make the mock usable as a context manager that yields itself,
            # in case the code under test wraps urlopen in a `with` statement.
            response.__enter__.return_value = response
            response.__exit__.return_value = False
            fake_urlopen.return_value = response
            self.assertTrue(_check_endpoint_reachable("https://example.com"))

    def test_returns_false_on_error(self):
        """A socket timeout during the request counts as unreachable."""
        with patch("urllib.request.urlopen") as fake_urlopen:
            from infinity_parser2.utils.model import _check_endpoint_reachable
            import socket

            fake_urlopen.side_effect = socket.timeout()
            self.assertFalse(_check_endpoint_reachable("https://example.com"))
686
+
687
+
688
if __name__ == "__main__":
    # Allow running this module directly instead of through a test runner.
    unittest.main()