infinity-parser2 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- infinity_parser2/__init__.py +28 -0
- infinity_parser2/__main__.py +6 -0
- infinity_parser2/backends/__init__.py +13 -0
- infinity_parser2/backends/base.py +61 -0
- infinity_parser2/backends/transformers.py +159 -0
- infinity_parser2/backends/vllm_engine.py +117 -0
- infinity_parser2/backends/vllm_server.py +148 -0
- infinity_parser2/cli.py +207 -0
- infinity_parser2/parser.py +278 -0
- infinity_parser2/prompts.py +57 -0
- infinity_parser2/utils/__init__.py +43 -0
- infinity_parser2/utils/file.py +190 -0
- infinity_parser2/utils/image.py +99 -0
- infinity_parser2/utils/model.py +243 -0
- infinity_parser2/utils/pdf.py +46 -0
- infinity_parser2/utils/utils.py +159 -0
- infinity_parser2-0.1.0.dist-info/METADATA +310 -0
- infinity_parser2-0.1.0.dist-info/RECORD +25 -0
- infinity_parser2-0.1.0.dist-info/WHEEL +5 -0
- infinity_parser2-0.1.0.dist-info/entry_points.txt +2 -0
- infinity_parser2-0.1.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +1 -0
- tests/test_backends.py +490 -0
- tests/test_parser.py +464 -0
- tests/test_utils.py +689 -0
tests/test_parser.py
ADDED
|
@@ -0,0 +1,464 @@
|
|
|
1
|
+
"""Unit tests for InfinityParser2 main parser class."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import shutil
|
|
5
|
+
import tempfile
|
|
6
|
+
import unittest
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from unittest.mock import MagicMock, patch
|
|
9
|
+
|
|
10
|
+
from PIL import Image
|
|
11
|
+
|
|
12
|
+
from infinity_parser2 import InfinityParser2, SUPPORTED_TASK_TYPES
|
|
13
|
+
from infinity_parser2.backends import TransformersBackend
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class TestInfinityParser2Initialization(unittest.TestCase):
    """Construction-time behaviour of InfinityParser2: defaults, overrides and validation."""

    @staticmethod
    def _stub_cache(mock_get_cache, resolved_path):
        """Make the patched ``get_model_cache()`` return a cache resolving to ``resolved_path``."""
        cache = MagicMock()
        cache.resolve_model_path.return_value = resolved_path
        mock_get_cache.return_value = cache

    def _check_attrs(self, parser, expected):
        """Assert every ``attribute -> value`` pair of ``expected``, in insertion order."""
        for attr_name, wanted in expected.items():
            self.assertEqual(getattr(parser, attr_name), wanted)

    @patch("infinity_parser2.parser.get_model_cache")
    @patch("infinity_parser2.backends.vllm_engine.VLLMEngineBackend.__init__", return_value=None)
    def test_default_initialization(self, mock_backend_init, mock_get_cache):
        """A parser built with no arguments exposes the default configuration."""
        self._stub_cache(mock_get_cache, "/cached/path/infly_Infinity-Parser2-Pro")
        parser = InfinityParser2()
        self._check_attrs(
            parser,
            {
                "model_name": "infly/Infinity-Parser2-Pro",
                "backend_name": "vllm-engine",
                "device": "cuda",
                "api_url": "http://localhost:8000/v1/chat/completions",
                "api_key": "EMPTY",
                "min_pixels": 2048,
                "max_pixels": 16777216,
            },
        )
        self.assertIsInstance(parser.tensor_parallel_size, int)

    @patch("infinity_parser2.parser.get_model_cache")
    @patch("infinity_parser2.backends.transformers.TransformersBackend.__init__", return_value=None)
    def test_custom_initialization(self, mock_backend_init, mock_get_cache):
        """Explicit keyword arguments override the defaults; device stays 'cuda'."""
        self._stub_cache(mock_get_cache, "/cached/path/custom/model")
        overrides = {
            "model_name": "custom/model",
            "backend": "transformers",
            "tensor_parallel_size": 2,
            "api_url": "http://custom:8000/v1",
            "api_key": "test-key",
            "min_pixels": 1024,
            "max_pixels": 8192,
        }
        parser = InfinityParser2(**overrides)
        self._check_attrs(
            parser,
            {
                "model_name": "custom/model",
                "backend_name": "transformers",
                "tensor_parallel_size": 2,
                "device": "cuda",
                "api_url": "http://custom:8000/v1",
                "api_key": "test-key",
                "min_pixels": 1024,
                "max_pixels": 8192,
            },
        )

    def test_device_must_be_cuda(self):
        """Any device other than 'cuda' is rejected with a ValueError mentioning 'cuda'."""
        with self.assertRaises(ValueError) as ctx:
            InfinityParser2(device="cpu")
        message = str(ctx.exception)
        self.assertIn("cuda", message)

    @patch("infinity_parser2.parser.get_model_cache")
    @patch("infinity_parser2.backends.vllm_engine.VLLMEngineBackend.__init__", return_value=None)
    def test_backend_case_insensitive(self, mock_backend_init, mock_get_cache):
        """Backend names are matched case-insensitively, but not with a different spelling."""
        self._stub_cache(mock_get_cache, "/cached/path/model")
        # Build both parsers first, then check them, mirroring the original ordering.
        parsers = [InfinityParser2(backend=spelling) for spelling in ("VLLM-ENGINE", "vllm-engine")]
        for built in parsers:
            self.assertEqual(built.backend_name, "vllm-engine")

        with self.assertRaises(ValueError) as ctx:
            InfinityParser2(backend="VLLMEngine")
        self.assertIn("Unsupported backend", str(ctx.exception))

    def test_unsupported_backend_raises_error(self):
        """An unknown backend raises ValueError naming both the problem and the bad value."""
        with self.assertRaises(ValueError) as ctx:
            InfinityParser2(backend="unsupported-backend")
        message = str(ctx.exception)
        for fragment in ("Unsupported backend", "unsupported-backend"):
            self.assertIn(fragment, message)

    def test_backend_registry_contains_all_backends(self):
        """BACKEND_REGISTRY is keyed by exactly the three supported backend names."""
        from infinity_parser2.parser import BACKEND_REGISTRY

        registered = set(BACKEND_REGISTRY.keys())
        self.assertEqual(registered, {"transformers", "vllm-engine", "vllm-server"})
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class TestInfinityParser2BackendProperty(unittest.TestCase):
    """Tests for the backend object InfinityParser2 creates during ``__init__``."""

    def setUp(self):
        """Patch ``get_model_cache`` for every test in this class.

        The initialization tests patch ``infinity_parser2.parser.get_model_cache``
        for every construction that is expected to succeed. These tests also
        construct parsers successfully, so they get the same patch. That keeps them
        away from the real model cache; whether that cache would download weights
        is not visible from this file. ``resolve_model_path`` returns a plain string
        so the parser's path handling sees a realistic value rather than a MagicMock.
        """
        patcher = patch("infinity_parser2.parser.get_model_cache")
        mock_get_cache = patcher.start()
        # addCleanup runs even if the test fails, so the patch never leaks.
        self.addCleanup(patcher.stop)
        mock_get_cache.return_value.resolve_model_path.return_value = "/cached/path/model"

    @patch("infinity_parser2.backends.vllm_engine.LLM")
    def test_backend_is_initialized_on_init(self, mock_llm):
        """Test that backend is initialized during __init__."""
        mock_llm.return_value = MagicMock()
        parser = InfinityParser2(backend="vllm-engine")
        self.assertIsNotNone(parser._backend)

    @patch("infinity_parser2.backends.transformers.AutoModelForImageTextToText")
    @patch("infinity_parser2.backends.transformers.AutoProcessor")
    def test_backend_returns_correct_type(self, mock_processor, mock_model):
        """Test that backend returns correct backend instance."""
        mock_model.from_pretrained.return_value = MagicMock()
        mock_processor.from_pretrained.return_value = MagicMock()
        parser = InfinityParser2(backend="transformers")
        self.assertIsInstance(parser._backend, TransformersBackend)

    @patch("infinity_parser2.backends.vllm_engine.LLM")
    def test_backend_same_instance_on_multiple_accesses(self, mock_llm):
        """Test that backend returns the same instance."""
        mock_llm.return_value = MagicMock()
        parser = InfinityParser2(backend="vllm-engine")
        backend1 = parser._backend
        backend2 = parser._backend
        self.assertIs(backend1, backend2)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class TestInfinityParser2ParseInputValidation(unittest.TestCase):
    """Tests for parse() input validation: missing files, bad types, empty directories."""

    def setUp(self):
        """Create a scratch directory holding one valid 100x100 white PNG."""
        self.temp_dir = tempfile.mkdtemp()
        self.temp_file = os.path.join(self.temp_dir, "test.png")
        Image.new("RGB", (100, 100), color="white").save(self.temp_file)

    def tearDown(self):
        """Remove the scratch directory and everything created inside it."""
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    def _make_parser(self):
        """Return a vllm-engine parser whose backend has been replaced by a MagicMock.

        Construction runs under the same two patches the initialization tests use:
        ``get_model_cache`` and ``VLLMEngineBackend.__init__``. That way, building
        the parser never resolves a real model and never runs the real vLLM backend
        constructor. The backend is swapped for a mock straight afterwards anyway.
        """
        mock_cache = MagicMock()
        mock_cache.resolve_model_path.return_value = "/cached/path/model"
        with patch("infinity_parser2.parser.get_model_cache", return_value=mock_cache), patch(
            "infinity_parser2.backends.vllm_engine.VLLMEngineBackend.__init__",
            return_value=None,
        ):
            parser = InfinityParser2(backend="vllm-engine")
        parser._backend = MagicMock()
        return parser

    def test_parse_nonexistent_file_raises_error(self):
        """Test that parsing nonexistent file raises FileNotFoundError."""
        parser = self._make_parser()
        # Keep the path inside the scratch dir so it cannot collide with a real file in the CWD.
        missing = os.path.join(self.temp_dir, "nonexistent_file.pdf")
        with self.assertRaises(FileNotFoundError):
            parser.parse(missing)

    def test_parse_unsupported_file_raises_error(self):
        """Test that parsing unsupported file type raises ValueError."""
        txt_file = os.path.join(self.temp_dir, "test.txt")
        Path(txt_file).touch()
        parser = self._make_parser()
        with self.assertRaises(ValueError) as context:
            parser.parse(txt_file)
        self.assertIn("Unsupported file type", str(context.exception))

    def test_parse_list_with_invalid_item_raises_error(self):
        """Test that parsing list with invalid item raises TypeError."""
        parser = self._make_parser()
        with self.assertRaises(TypeError):
            parser.parse([123, "not_a_string"])

    def test_parse_list_with_nonexistent_file_raises_error(self):
        """Test that parsing list with nonexistent file raises FileNotFoundError."""
        parser = self._make_parser()
        missing = os.path.join(self.temp_dir, "nonexistent.pdf")
        with self.assertRaises(FileNotFoundError):
            parser.parse([self.temp_file, missing])

    def test_parse_list_with_unsupported_file_raises_error(self):
        """Test that parsing list with unsupported file raises ValueError."""
        txt_file = os.path.join(self.temp_dir, "test.txt")
        Path(txt_file).touch()
        parser = self._make_parser()
        with self.assertRaises(ValueError):
            parser.parse([self.temp_file, txt_file])

    def test_parse_unsupported_input_type_raises_error(self):
        """Test that parsing unsupported input type raises TypeError."""
        parser = self._make_parser()
        with self.assertRaises(TypeError) as context:
            parser.parse(12345)
        self.assertIn("Unsupported input type", str(context.exception))

    def test_parse_directory_with_no_supported_files_raises_error(self):
        """Test that parsing directory with no supported files raises ValueError."""
        # A fresh sub-directory of the scratch dir; tearDown removes it with everything else.
        txt_only_dir = tempfile.mkdtemp(dir=self.temp_dir)
        Path(txt_only_dir, "test.txt").touch()
        parser = self._make_parser()
        with self.assertRaises(ValueError) as context:
            parser.parse(txt_only_dir)
        self.assertIn("No supported files found", str(context.exception))
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
class TestInfinityParser2MockedParse(unittest.TestCase):
    """Tests for parse() output handling with a mocked backend."""

    def setUp(self):
        """Create a scratch directory; every file a test writes lives inside it."""
        self.temp_dir = tempfile.mkdtemp()

    def tearDown(self):
        """Remove the scratch directory and everything created inside it."""
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    def _make_parser(self):
        """Return a vllm-engine parser whose backend has been replaced by a MagicMock.

        Construction runs under the same two patches the initialization tests use:
        ``get_model_cache`` and ``VLLMEngineBackend.__init__``. That way, building
        the parser never resolves a real model and never runs the real vLLM backend
        constructor. The backend is swapped for a mock straight afterwards anyway.
        """
        mock_cache = MagicMock()
        mock_cache.resolve_model_path.return_value = "/cached/path/model"
        with patch("infinity_parser2.parser.get_model_cache", return_value=mock_cache), patch(
            "infinity_parser2.backends.vllm_engine.VLLMEngineBackend.__init__",
            return_value=None,
        ):
            parser = InfinityParser2(backend="vllm-engine")
        parser._backend = MagicMock()
        return parser

    def _make_png(self, name="test.png"):
        """Write a valid 100x100 white PNG named ``name`` into the scratch dir and return its path.

        The file is closed as soon as it is written and is removed by tearDown. No
        open handle is left behind, so cleanup also works on Windows.
        """
        path = os.path.join(self.temp_dir, name)
        Image.new("RGB", (100, 100), color="white").save(path)
        return path

    def test_parse_single_file_returns_string(self):
        """Test that parsing single file returns string."""
        parser = self._make_parser()
        parser._backend.parse_batch.return_value = ["Parsed content"]
        result = parser.parse(self._make_png())
        self.assertIsInstance(result, str)
        self.assertEqual(result, "Parsed content")

    def test_parse_single_image_doc2json_converts_to_markdown(self):
        """Test that DOC2JSON mode converts JSON to Markdown for single image."""
        parser = self._make_parser()
        parser._backend.parse_batch.return_value = ['[{"category": "title", "text": "# Document Title"}]']
        result = parser.parse(self._make_png())
        self.assertIsInstance(result, str)
        self.assertIn("# Document Title", result)
        self.assertNotIn("[{", result)

    def test_parse_list_returns_list(self):
        """Test that parsing list returns list of strings."""
        parser = self._make_parser()
        parser._backend.parse_batch.return_value = ["Result 1", "Result 2"]
        image_paths = [self._make_png(f"page{i}.png") for i in range(2)]
        result = parser.parse(image_paths)
        self.assertIsInstance(result, list)
        self.assertEqual(len(result), 2)

    def test_parse_pil_image_doc2json(self):
        """Test parsing PIL Image object in DOC2JSON mode converts JSON to Markdown."""
        parser = self._make_parser()
        parser._backend.parse_batch.return_value = ['[{"category": "text", "text": "Image content"}]']
        img = Image.new("RGB", (100, 100), color="white")
        result = parser.parse(img)
        self.assertIsInstance(result, str)
        self.assertIn("Image content", result)
        self.assertNotIn("[{", result)

    def test_parse_with_output_dir_creates_subdirectories(self):
        """Test that parsing with output_dir creates subdirectories."""
        parser = self._make_parser()
        parser._backend.parse_batch.return_value = ["Result content"]
        image_path = self._make_png("test.png")
        output_dir = tempfile.mkdtemp(dir=self.temp_dir)
        parser.parse(image_path, output_dir=output_dir)
        subdir = os.path.join(output_dir, "test.png")
        self.assertTrue(os.path.exists(subdir))
        result_file = os.path.join(subdir, "result.md")
        self.assertTrue(os.path.exists(result_file))

    def test_parse_batch_size_passed_to_backend(self):
        """Test that batch_size is passed to backend correctly."""
        parser = self._make_parser()
        parser._backend.parse_batch.return_value = ["Result"]
        parser.parse(self._make_png(), batch_size=4)
        parser._backend.parse_batch.assert_called_once()
        call_kwargs = parser._backend.parse_batch.call_args.kwargs
        self.assertEqual(call_kwargs["batch_size"], 4)

    def test_parse_with_task_type_doc2md(self):
        """Test parsing with doc2md task type."""
        parser = self._make_parser()
        parser._backend.parse_batch.return_value = ["# Title\n\nParagraph text"]
        result = parser.parse(self._make_png(), task_type="doc2md")
        self.assertIsInstance(result, str)
        self.assertIn("# Title", result)

    def test_parse_with_output_format_json(self):
        """Test parsing with output_format='json' returns raw JSON."""
        parser = self._make_parser()
        parser._backend.parse_batch.return_value = ['[{"bbox": [0,0,100,100], "category": "text", "text": "Hello"}]']
        result = parser.parse(self._make_png(), output_format="json")
        self.assertIsInstance(result, str)
        self.assertIn('"category": "text"', result)
        self.assertIn('"Hello"', result)

    def test_parse_with_output_format_invalid(self):
        """Test that invalid output_format raises ValueError."""
        parser = self._make_parser()
        parser._backend.parse_batch.return_value = ["Result"]
        image_path = self._make_png()
        with self.assertRaises(ValueError) as context:
            parser.parse(image_path, output_format="xml")
        self.assertIn("output_format must be one of", str(context.exception))

    def test_parse_doc2md_cannot_use_output_format_json(self):
        """Test that DOC2MD mode cannot use output_format='json'."""
        parser = self._make_parser()
        parser._backend.parse_batch.return_value = ["# Title"]
        image_path = self._make_png()
        with self.assertRaises(ValueError) as context:
            parser.parse(image_path, task_type="doc2md", output_format="json")
        self.assertIn("output_format='json' is only supported for doc2json tasks", str(context.exception))

    def test_parse_custom_prompt_cannot_use_output_format_json(self):
        """Test that custom prompt cannot use output_format='json'."""
        parser = self._make_parser()
        parser._backend.parse_batch.return_value = ["Custom result"]
        image_path = self._make_png()
        with self.assertRaises(ValueError) as context:
            parser.parse(image_path, task_type="custom", custom_prompt="Custom instruction", output_format="json")
        self.assertIn("output_format='json' is only supported for doc2json tasks", str(context.exception))

    def test_parse_with_output_dir_and_output_format_json(self):
        """Test parsing with output_dir and output_format='json' saves only JSON."""
        parser = self._make_parser()
        parser._backend.parse_batch.return_value = ['[{"bbox": [0,0,100,100], "category": "text", "text": "Hello"}]']
        image_path = self._make_png("test.png")
        output_dir = tempfile.mkdtemp(dir=self.temp_dir)
        parser.parse(image_path, output_dir=output_dir, output_format="json")
        subdir = os.path.join(output_dir, "test.png")
        self.assertTrue(os.path.exists(subdir))
        json_file = os.path.join(subdir, "result.json")
        self.assertTrue(os.path.exists(json_file))
        md_file = os.path.join(subdir, "result.md")
        self.assertFalse(os.path.exists(md_file))
        with open(json_file, "r") as f:
            content = f.read()
        self.assertIn('"category": "text"', content)

    def test_parse_with_output_dir_and_output_format_md(self):
        """Test parsing with output_dir and output_format='md' (default) saves only Markdown."""
        parser = self._make_parser()
        parser._backend.parse_batch.return_value = ['[{"bbox": [0,0,100,100], "category": "text", "text": "Hello"}]']
        image_path = self._make_png("test.png")
        output_dir = tempfile.mkdtemp(dir=self.temp_dir)
        parser.parse(image_path, output_dir=output_dir, output_format="md")
        subdir = os.path.join(output_dir, "test.png")
        self.assertTrue(os.path.exists(subdir))
        md_file = os.path.join(subdir, "result.md")
        self.assertTrue(os.path.exists(md_file))
        json_file = os.path.join(subdir, "result.json")
        self.assertFalse(os.path.exists(json_file))

    def test_parse_with_custom_prompt(self):
        """Test parsing with custom_prompt (task_type='custom')."""
        parser = self._make_parser()
        parser._backend.parse_batch.return_value = ["Custom result"]
        result = parser.parse(self._make_png(), task_type="custom", custom_prompt="Custom instruction")
        self.assertIsInstance(result, str)
        self.assertEqual(result, "Custom result")
        # The custom prompt is expected as the second positional argument of parse_batch.
        call_args = parser._backend.parse_batch.call_args
        self.assertEqual(call_args.args[1], "Custom instruction")

    def test_parse_directory(self):
        """Test parsing a directory of files."""
        parser = self._make_parser()
        parser._backend.parse_batch.return_value = ["Result1", "Result2"]
        dir_path = tempfile.mkdtemp(dir=self.temp_dir)
        Image.new("RGB", (100, 100), color="red").save(os.path.join(dir_path, "file1.png"))
        Image.new("RGB", (100, 100), color="blue").save(os.path.join(dir_path, "file2.png"))
        result = parser.parse(dir_path)
        self.assertIsInstance(result, dict)
        self.assertEqual(len(result), 2)
        for path, content in result.items():
            self.assertIsInstance(path, str)
            self.assertIsInstance(content, str)
|
|
429
|
+
|
|
430
|
+
|
|
431
|
+
class TestTaskType(unittest.TestCase):
    """Membership of SUPPORTED_TASK_TYPES and sanity checks on the exported prompts."""

    def _assert_supported(self, task_type):
        """Fail unless ``task_type`` is listed in SUPPORTED_TASK_TYPES."""
        self.assertIn(task_type, SUPPORTED_TASK_TYPES)

    def _assert_prompt(self, prompt, keyword):
        """Fail unless ``prompt`` is a str whose lower-cased text contains ``keyword``."""
        self.assertIsInstance(prompt, str)
        self.assertIn(keyword, prompt.lower())

    def test_supported_task_types(self):
        """SUPPORTED_TASK_TYPES is exactly doc2json, doc2md and custom, in that order."""
        expected = ["doc2json", "doc2md", "custom"]
        self.assertEqual(SUPPORTED_TASK_TYPES, expected)

    def test_task_type_doc2json(self):
        """'doc2json' is accepted as a task type."""
        self._assert_supported("doc2json")

    def test_task_type_doc2md(self):
        """'doc2md' is accepted as a task type."""
        self._assert_supported("doc2md")

    def test_task_type_custom(self):
        """'custom' is accepted as a task type."""
        self._assert_supported("custom")

    def test_prompt_doc2json_defined(self):
        """PROMPT_DOC2JSON is exported and talks about layout."""
        from infinity_parser2 import PROMPT_DOC2JSON

        self._assert_prompt(PROMPT_DOC2JSON, "layout")

    def test_prompt_doc2md_defined(self):
        """PROMPT_DOC2MD is exported and talks about markdown."""
        from infinity_parser2 import PROMPT_DOC2MD

        self._assert_prompt(PROMPT_DOC2MD, "markdown")
|
|
461
|
+
|
|
462
|
+
|
|
463
|
+
if __name__ == "__main__":
    # Allow running this module directly as a script; unittest.main() collects and
    # runs the TestCase classes defined in this module.
    unittest.main()
|