infinity-parser2 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tests/test_backends.py ADDED
@@ -0,0 +1,490 @@
1
+ """Unit tests for InfinityParser2 backend classes."""
2
+
3
+ import tempfile
4
+ import unittest
5
+ from unittest.mock import MagicMock, patch
6
+
7
+ from PIL import Image
8
+
9
+ from infinity_parser2.backends import (
10
+ BaseBackend,
11
+ TransformersBackend,
12
+ VLLMEngineBackend,
13
+ VLLMServerBackend,
14
+ )
15
+
16
+
17
class TestBaseBackend(unittest.TestCase):
    """Checks on the ``BaseBackend`` base class, mostly through a minimal subclass."""

    @staticmethod
    def _minimal_subclass():
        """Build a throwaway ``BaseBackend`` subclass.

        The subclass overrides ``init`` (no-op) and ``parse_batch`` (returns an
        empty list) so that it can be instantiated by the tests below.
        """

        class ConcreteBackend(BaseBackend):
            def init(self):
                pass

            def parse_batch(self, input_data, prompt, batch_size=1, **kwargs):
                return []

        return ConcreteBackend

    def test_base_backend_is_abstract(self):
        """Calling ``BaseBackend()`` with no arguments must raise ``TypeError``."""
        self.assertRaises(TypeError, BaseBackend)

    def test_base_backend_has_abstract_methods(self):
        """``BaseBackend`` exposes both ``init`` and ``parse_batch`` attributes."""
        for required_name in ("init", "parse_batch"):
            self.assertTrue(hasattr(BaseBackend, required_name))

    def test_base_backend_subclass_interface(self):
        """A subclass providing both methods can be built and keeps its arguments."""
        backend_cls = self._minimal_subclass()
        instance = backend_cls(model_name="test/model", device="cuda")

        self.assertEqual(instance.model_name, "test/model")
        self.assertEqual(instance.device, "cuda")

    def test_base_backend_init_parameters(self):
        """Named arguments land on attributes; extra keywords end up in ``kwargs``."""
        backend_cls = self._minimal_subclass()
        instance = backend_cls(
            model_name="custom/model",
            device="cpu",
            custom_arg="value",
        )

        self.assertEqual(instance.model_name, "custom/model")
        self.assertEqual(instance.device, "cpu")
        self.assertEqual(instance.kwargs.get("custom_arg"), "value")
60
+
61
+
62
class TestTransformersBackend(unittest.TestCase):
    """Tests for TransformersBackend with its heavyweight collaborators mocked.

    Each test replaces ``AutoModelForImageTextToText`` and ``AutoProcessor`` in
    ``infinity_parser2.backends.transformers`` with mocks; the input-processing
    tests additionally replace ``process_vision_info``. Patches and temporary
    image files are registered with ``addCleanup`` so they are undone even when
    a test fails part-way through.
    """

    _MODEL_TARGET = "infinity_parser2.backends.transformers.AutoModelForImageTextToText"
    _PROCESSOR_TARGET = "infinity_parser2.backends.transformers.AutoProcessor"
    _VISION_TARGET = "infinity_parser2.backends.transformers.process_vision_info"

    def _patch(self, target):
        """Patch *target* for the remainder of the current test and return the mock."""
        patcher = patch(target)
        replacement = patcher.start()
        self.addCleanup(patcher.stop)
        return replacement

    def _patch_model_and_processor(self):
        """Patch the model and processor classes; return ``(model_cls, processor_cls)``."""
        model_cls = self._patch(self._MODEL_TARGET)
        processor_cls = self._patch(self._PROCESSOR_TARGET)
        return model_cls, processor_cls

    def _patch_input_pipeline(self):
        """Install the mocks shared by the ``_process_inputs`` tests.

        Returns:
            The mock that ``AutoProcessor.from_pretrained`` hands back.
        """
        model_cls, processor_cls = self._patch_model_and_processor()
        vision_fn = self._patch(self._VISION_TARGET)

        model_cls.from_pretrained.return_value = MagicMock()
        processor = MagicMock()
        processor_cls.from_pretrained.return_value = processor
        processor.apply_chat_template.return_value = "processed"
        processor.batch_decode.return_value = ["decoded"]
        # Calling the processor mock itself yields the tokenised-input dict.
        processor.return_value = {"input_ids": MagicMock()}
        vision_fn.return_value = ([MagicMock()], None)
        return processor

    def _temp_png(self, color):
        """Write a 100x100 RGB PNG filled with *color* and return its path.

        The file handle is closed before the image is saved (an open handle
        would be leaked and blocks deletion on Windows), and removal of the
        file is scheduled before saving so a failed save cannot leave it behind.
        """
        import os  # local import: the module-level import block lies outside this class

        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as handle:
            path = handle.name
        self.addCleanup(os.unlink, path)
        Image.new("RGB", (100, 100), color=color).save(path)
        return path

    def test_transformers_backend_initialization_params(self):
        """Explicit constructor arguments are readable back from the backend."""
        model_cls, processor_cls = self._patch_model_and_processor()
        model_cls.from_pretrained.return_value = MagicMock()
        processor_cls.from_pretrained.return_value = MagicMock()

        backend = TransformersBackend(
            model_name="test/model",
            device="cuda",
            torch_dtype="float16",
            min_pixels=1024,
            max_pixels=4096,
        )

        self.assertEqual(backend.model_name, "test/model")
        self.assertEqual(backend.device, "cuda")
        self.assertEqual(backend.min_pixels, 1024)
        self.assertEqual(backend.max_pixels, 4096)

    def test_transformers_backend_min_max_pixels_defaults(self):
        """Without overrides, ``min_pixels`` is 2048 and ``max_pixels`` is 16777216."""
        model_cls, processor_cls = self._patch_model_and_processor()
        model_cls.from_pretrained.return_value = MagicMock()
        processor_cls.from_pretrained.return_value = MagicMock()

        backend = TransformersBackend(model_name="test/model")

        self.assertEqual(backend.min_pixels, 2048)
        self.assertEqual(backend.max_pixels, 16777216)

    def test_transformers_backend_process_inputs(self):
        """``_process_inputs`` on one image path returns a dict with ``input_ids``."""
        self._patch_input_pipeline()
        backend = TransformersBackend(model_name="test/model")
        image_path = self._temp_png("red")

        result = backend._process_inputs([image_path], "Test prompt")

        self.assertIsInstance(result, dict)
        self.assertIn("input_ids", result)

    def test_transformers_backend_process_multiple_inputs(self):
        """``_process_inputs`` on three image paths returns a dict with ``input_ids``."""
        self._patch_input_pipeline()
        backend = TransformersBackend(model_name="test/model")
        image_paths = [self._temp_png("blue") for _ in range(3)]

        result = backend._process_inputs(image_paths, "Test prompt")

        self.assertIsInstance(result, dict)
        self.assertIn("input_ids", result)

    def test_transformers_backend_generate_output_format(self):
        """``_generate`` returns a list."""
        model_cls, processor_cls = self._patch_model_and_processor()
        model = MagicMock()
        model_cls.from_pretrained.return_value = model
        model.device = "cuda"

        processor = MagicMock()
        processor_cls.from_pretrained.return_value = processor
        processor.apply_chat_template.return_value = "processed text"
        processor.batch_decode.return_value = ["Generated text"]

        backend = TransformersBackend(model_name="test/model")
        backend._model = model

        # NOTE(review): setting ``__call__`` on a mock *instance* probably does not
        # change what calling the processor returns, because call dispatch goes
        # through the type. Kept as in the original; confirm whether it can go.
        with patch.object(backend._processor, '__call__') as mock_call:
            input_ids = MagicMock()
            mock_call.return_value = {"input_ids": input_ids}

            output_ids = MagicMock()
            model.generate.return_value = [output_ids]
            output_ids.__getitem__ = MagicMock(return_value=[1, 2, 3])

            results = backend._generate({"input_ids": input_ids})
            self.assertIsInstance(results, list)

    def test_transformers_backend_parse_batch(self):
        """``parse_batch`` on a single in-memory image returns a one-element list."""
        model_cls, processor_cls = self._patch_model_and_processor()
        vision_fn = self._patch(self._VISION_TARGET)

        model = MagicMock()
        model_cls.from_pretrained.return_value = model
        model.device = "cuda"

        processor = MagicMock()
        processor_cls.from_pretrained.return_value = processor
        processor.apply_chat_template.return_value = "processed"
        processor.batch_decode.return_value = ["Result"]
        processor.return_value = {"input_ids": MagicMock()}
        vision_fn.return_value = ([MagicMock()], None)

        backend = TransformersBackend(model_name="test/model")
        backend._model = model

        # NOTE(review): see the remark in the ``_generate`` test. This patch likely
        # has no effect on the processor call; ``processor.return_value`` above is
        # presumably what the backend receives. Confirm before removing.
        with patch.object(backend._processor, '__call__') as mock_call:
            input_ids = MagicMock()
            mock_call.return_value = {"input_ids": input_ids}

            output_ids = MagicMock()
            model.generate.return_value = [output_ids]
            output_ids.__getitem__ = MagicMock(return_value=[1, 2, 3])

            results = backend.parse_batch(
                [Image.new("RGB", (100, 100))],
                "Test prompt",
            )
            self.assertIsInstance(results, list)
            self.assertEqual(len(results), 1)
223
+
224
+
225
class TestVLLMEngineBackend(unittest.TestCase):
    """Cover VLLMEngineBackend with the ``LLM`` class swapped for a mock."""

    def setUp(self):
        """Patch ``LLM`` in the engine module and expose the class and instance mocks."""
        llm_patcher = patch("infinity_parser2.backends.vllm_engine.LLM")
        self.llm_cls = llm_patcher.start()
        self.addCleanup(llm_patcher.stop)
        self.llm = MagicMock()
        self.llm_cls.return_value = self.llm

    def _stub_chat_output(self):
        """Make ``chat`` on the mocked engine return one output whose text is "Parsed result"."""
        generation = MagicMock()
        generation.outputs = [MagicMock(text="Parsed result")]
        self.llm.chat.return_value = [generation]

    @staticmethod
    def _green_png():
        """Write a 100x100 green PNG to a temporary file and return its path.

        The caller is responsible for deleting the file.
        """
        tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
        picture = Image.new("RGB", (100, 100), color="green")
        picture.save(tmp.name)
        tmp.close()
        return tmp.name

    def test_vllm_engine_backend_initialization(self):
        """Explicit constructor arguments are readable back from the backend."""
        backend = VLLMEngineBackend(
            model_name="test/model",
            device="cuda",
            tensor_parallel_size=2,
            min_pixels=1024,
            max_pixels=4096,
        )

        self.assertEqual(backend.model_name, "test/model")
        self.assertEqual(backend.device, "cuda")
        self.assertEqual(backend.tensor_parallel_size, 2)
        self.assertEqual(backend.min_pixels, 1024)
        self.assertEqual(backend.max_pixels, 4096)

    def test_vllm_engine_backend_min_max_pixels_defaults(self):
        """Without overrides, ``min_pixels`` is 2048 and ``max_pixels`` is 16777216."""
        backend = VLLMEngineBackend(model_name="test/model")

        self.assertEqual(backend.min_pixels, 2048)
        self.assertEqual(backend.max_pixels, 16777216)

    def test_vllm_engine_build_messages(self):
        """``_build_messages`` yields one user message: an image URL part, then a text part."""
        backend = VLLMEngineBackend(model_name="test/model")
        messages = backend._build_messages("base64data", "image/png", "Test prompt")

        self.assertIsInstance(messages, list)
        self.assertEqual(len(messages), 1)

        only_message = messages[0]
        self.assertEqual(only_message["role"], "user")
        self.assertIsInstance(only_message["content"], list)
        self.assertEqual(len(only_message["content"]), 2)

        image_part = only_message["content"][0]
        self.assertEqual(image_part["type"], "image_url")
        self.assertIn("data:image/png;base64,base64data", image_part["image_url"]["url"])

        text_part = only_message["content"][1]
        self.assertEqual(text_part["type"], "text")
        self.assertEqual(text_part["text"], "Test prompt")

    def test_vllm_engine_parse_batch_returns_list(self):
        """``parse_batch`` on one image path returns a one-element list."""
        import os

        self._stub_chat_output()
        backend = VLLMEngineBackend(model_name="test/model")
        image_path = self._green_png()

        try:
            results = backend.parse_batch([image_path], "Test prompt")
            self.assertIsInstance(results, list)
            self.assertEqual(len(results), 1)
        finally:
            os.unlink(image_path)

    def test_vllm_engine_parse_batch(self):
        """``parse_batch`` returns a one-element list and goes through the engine's ``chat``."""
        import os

        self._stub_chat_output()
        backend = VLLMEngineBackend(model_name="test/model")
        image_path = self._green_png()

        try:
            results = backend.parse_batch([image_path], "Test prompt")
            self.assertIsInstance(results, list)
            self.assertEqual(len(results), 1)
            self.llm.chat.assert_called()
        finally:
            os.unlink(image_path)
332
+
333
+
334
class TestVLLMServerBackend(unittest.TestCase):
    """Cover VLLMServerBackend with the ``OpenAI`` client class swapped for a mock."""

    _API_URL = "http://localhost:8000/v1/chat/completions"

    def setUp(self):
        """Patch ``OpenAI`` in the server module and expose the class and client mocks."""
        openai_patcher = patch("infinity_parser2.backends.vllm_server.OpenAI")
        self.openai_cls = openai_patcher.start()
        self.addCleanup(openai_patcher.stop)
        self.client = MagicMock()
        self.openai_cls.return_value = self.client

    def _reply_with(self, text):
        """Make ``chat.completions.create`` return a response whose first choice says *text*."""
        response = MagicMock()
        message = MagicMock()
        message.content = text
        response.choices = [MagicMock(message=message)]
        self.client.chat.completions.create.return_value = response

    @staticmethod
    def _png(color):
        """Write a 100x100 PNG filled with *color* to a temporary file and return its path.

        The caller is responsible for deleting the file.
        """
        tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
        picture = Image.new("RGB", (100, 100), color=color)
        picture.save(tmp.name)
        tmp.close()
        return tmp.name

    def test_vllm_server_backend_initialization(self):
        """Explicit constructor arguments are readable back and a client is set."""
        backend = VLLMServerBackend(
            model_name="test/model",
            api_url=self._API_URL,
            api_key="test-key",
            timeout=60,
            min_pixels=1024,
            max_pixels=4096,
        )

        self.assertEqual(backend.model_name, "test/model")
        self.assertEqual(backend.api_url, "http://localhost:8000/v1/chat/completions")
        self.assertEqual(backend.api_key, "test-key")
        self.assertEqual(backend.timeout, 60)
        self.assertEqual(backend.min_pixels, 1024)
        self.assertEqual(backend.max_pixels, 4096)
        self.assertIsNotNone(backend.client)

    def test_vllm_server_backend_min_max_pixels_defaults(self):
        """Without overrides, ``min_pixels`` is 2048 and ``max_pixels`` is 16777216."""
        backend = VLLMServerBackend(api_url=self._API_URL)

        self.assertEqual(backend.min_pixels, 2048)
        self.assertEqual(backend.max_pixels, 16777216)

    def test_vllm_server_connection_check(self):
        """Constructing the backend issues exactly one ``chat.completions.create`` call."""
        VLLMServerBackend(api_url=self._API_URL)

        self.client.chat.completions.create.assert_called_once()

    def test_vllm_server_connection_failure(self):
        """A failing ``create`` call during construction surfaces as ``RuntimeError``."""
        self.client.chat.completions.create.side_effect = Exception("Connection refused")

        with self.assertRaises(RuntimeError) as caught:
            VLLMServerBackend(api_url=self._API_URL)

        self.assertIn("Cannot connect to vLLM server", str(caught.exception))

    def test_vllm_server_parse_batch_empty_input(self):
        """``parse_batch`` on an empty input list returns an empty list."""
        backend = VLLMServerBackend(api_url=self._API_URL)

        self.assertEqual(backend.parse_batch([], "Test prompt"), [])

    def test_vllm_server_parse_batch_success(self):
        """``parse_batch`` returns the message content of the mocked response."""
        import os

        self._reply_with("Parsed content")
        backend = VLLMServerBackend(api_url=self._API_URL)
        image_path = self._png("yellow")

        try:
            results = backend.parse_batch([image_path], "Test prompt")
            self.assertIsInstance(results, list)
            self.assertEqual(len(results), 1)
            self.assertEqual(results[0], "Parsed content")
        finally:
            os.unlink(image_path)

    def test_vllm_server_extra_body(self):
        """The most recent ``create`` call carries the expected keyword arguments."""
        import os

        self._reply_with("Result")
        backend = VLLMServerBackend(
            api_url=self._API_URL,
            api_key="my-secret-key",
        )
        image_path = self._png("cyan")

        try:
            backend.parse_batch([image_path], "Test prompt")
            # ``call_args`` reflects only the latest call; index 1 is its kwargs dict.
            sent = self.client.chat.completions.create.call_args[1]
            self.assertEqual(sent["model"], "infly/Infinity-Parser2-Pro")
            self.assertIn("messages", sent)
            self.assertEqual(sent["max_tokens"], 32768)
            self.assertEqual(sent["temperature"], 0.0)
            self.assertEqual(sent["top_p"], 1.0)
            self.assertEqual(
                sent["extra_body"],
                {"chat_template_kwargs": {"enable_thinking": False}},
            )
        finally:
            os.unlink(image_path)
466
+
467
+
468
class TestBackendRegistry(unittest.TestCase):
    """Tests for the ``BACKEND_REGISTRY`` name-to-class mapping in ``infinity_parser2.parser``."""

    def test_backend_registry_keys(self):
        """``BACKEND_REGISTRY`` contains the three expected backend names."""
        # The backend classes are not needed here, so only the registry is imported.
        from infinity_parser2.parser import BACKEND_REGISTRY

        for backend_name in ("transformers", "vllm-engine", "vllm-server"):
            self.assertIn(backend_name, BACKEND_REGISTRY)

    def test_backend_registry_values(self):
        """Each registry name maps to the matching backend class."""
        from infinity_parser2.parser import BACKEND_REGISTRY

        expected = {
            "transformers": TransformersBackend,
            "vllm-engine": VLLMEngineBackend,
            "vllm-server": VLLMServerBackend,
        }
        for backend_name, backend_cls in expected.items():
            self.assertEqual(BACKEND_REGISTRY[backend_name], backend_cls)
487
+
488
+
489
if __name__ == "__main__":
    # Run as a script: hand control to unittest's command-line runner, which
    # collects and executes the TestCase classes defined in this module.
    unittest.main()