structai 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
structai/llm_api.py ADDED
@@ -0,0 +1,713 @@
1
+ from openai import OpenAI
2
+ from typing import Union
3
+ import Levenshtein
4
+ import time
5
+ from json_repair import repair_json
6
+ from PIL import Image
7
+ import io
8
+ import re
9
+ import base64
10
+ import os
11
+ import string
12
+ import ipaddress
13
+ import ast
14
+ import json
15
+ from urllib.parse import urlparse
16
+ from .utils import run_with_timeout
17
+ from .mp import multi_thread
18
+
19
# Characters kept by sanitize_text(): ASCII letters and digits, space, most
# ASCII punctuation (backslash, backtick, caret and tilde are NOT included),
# newline and tab. Everything else -- including all non-ASCII text -- is dropped.
_ALLOWED_CHARS = set(
    string.ascii_letters
    + string.digits
    + " .,:;!?+-*/=<>|@#$%&()[]{}_'\""
    + "\n\t"
)

# Backslash escape sequences that appear *literally* in captured output
# (e.g. the two characters "\n", or the text "\x1b[31m").
# The ANSI alternative must come before the generic \xHH alternative:
# regex alternation is ordered, and "x[0-9a-fA-F]{2}" would otherwise
# consume just the "\x1b" prefix and leave "[31m" behind.
_ESCAPED_CTRL_RE = re.compile(
    r"""
    \\(
        x1b\[[0-9;]*[A-Za-z]  |  # ANSI escaped, e.g. \x1b[31m
        [btnrfv]              |  # \b \t \n \r \f \v
        x[0-9a-fA-F]{2}       |  # \x08 \x1b
        u[0-9a-fA-F]{4}       |  # \u0008
        U[0-9a-fA-F]{8}          # \U00000008
    )
    """,
    re.VERBOSE,
)

# Real (unescaped) ANSI/VT100 CSI sequences: ESC '[' params intermediates final,
# e.g. "\x1b[31m" colours or tqdm's cursor-up "\x1b[A". Without this, only the
# ESC byte would be filtered out and the "[31m" tail would survive.
_ANSI_CSI_RE = re.compile(r"\x1b\[[0-?]*[ -/]*[@-~]")


def sanitize_text(text: str) -> str:
    """
    Sanitize subprocess / tqdm / CLI output for:
    - JSON serialization
    - LLM input
    - Human-readable logs

    Removes real ANSI CSI sequences, then literal backslash escape sequences
    (see _ESCAPED_CTRL_RE), then every character not in _ALLOWED_CHARS.
    Falsy input ("" or None) is returned unchanged.
    """
    if not text:
        return text

    text = _ANSI_CSI_RE.sub("", text)
    text = _ESCAPED_CTRL_RE.sub("", text)

    return "".join(ch for ch in text if ch in _ALLOWED_CHARS)
52
+
53
def filter_excessive_repeats(text, threshold=5):
    """
    Delete runs of a repeated one-character or two-character unit.

    Single-character runs are removed first, then runs of a repeated
    two-character substring. A run is removed when the unit occurs at least
    ``threshold`` times in a row. Note that ``.`` does not match newlines, so
    runs involving "\\n" are left alone.

    Args:
        text (str): The input string to be processed.
        threshold (int): The minimum number of consecutive repetitions to trigger removal.

    Returns:
        str: The processed string with excessive repetitions removed.
    """
    extra_repeats = str(threshold - 1)
    # One pass per unit width; the order (1-char, then 2-char) matters.
    for unit in (r'(.)', r'(.{2})'):
        text = re.sub(unit + r'\1{' + extra_repeats + r',}', '', text)
    return text
71
+
72
+
73
def _slice_bracketed(s: str, open_ch: str, close_ch: str) -> str:
    """Return *s* trimmed to the span from the first *open_ch* to the last *close_ch*.

    If *open_ch* does not occur, *s* is returned unchanged. If no *close_ch*
    occurs after it (e.g. an LLM reply cut off by a token limit), everything
    from *open_ch* onward is kept so that repair_json() can still close the
    structure instead of being handed an empty string.
    """
    start = s.find(open_ch)
    if start == -1:
        return s
    end = s.rfind(close_ch)
    if end < start:
        return s[start:]
    return s[start:end + 1]


def _parse_structured(s: str):
    """Parse *s* as a Python literal, falling back to repaired JSON.

    Tries, in order: ast.literal_eval (handles single quotes / trailing
    commas), json.loads(repair_json(s)), and json.loads(repair_json(...)) on
    sanitize_text(s). The last attempt's exception propagates.
    """
    try:
        return ast.literal_eval(s)
    except Exception:
        try:
            return json.loads(repair_json(s))
        except Exception:
            return json.loads(repair_json(sanitize_text(s)))


def str2dict(s: str) -> dict:
    """Extract and parse the ``{...}`` object embedded in an LLM reply.

    The result is whatever the parsers produce; its type is not checked here.
    """
    return _parse_structured(_slice_bracketed(s, '{', '}'))


def str2list(s: str) -> list:
    """Extract and parse the ``[...]`` list embedded in an LLM reply.

    The result is whatever the parsers produce; its type is not checked here.
    """
    return _parse_structured(_slice_bracketed(s, '[', ']'))
101
+
102
+
103
def read_image(image_path: str) -> Image.Image:
    """Open *image_path* and load its pixel data eagerly.

    Image.open() is lazy and keeps the underlying file open. Loading inside a
    ``with`` block reads the pixels now and closes the file on exit, so
    reading many images in a loop does not accumulate open file handles. The
    returned image keeps its loaded data and can still be saved/encoded.
    For multi-frame files only the current (first) frame is loaded.
    """
    with Image.open(image_path) as img:
        img.load()
    return img
105
+
106
+
107
def encode_image(image_obj: Image.Image) -> str:
    """Serialize *image_obj* as PNG and return the bytes as base64 text."""
    with io.BytesIO() as png_buffer:
        image_obj.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode("utf-8")
111
+
112
+
113
def add_no_proxy_if_private(url: str):
    """Exempt a non-public IP-literal endpoint from HTTP(S) proxying.

    If the host of *url* is an IP address that is not globally routable
    (``ipaddress`` ``is_global`` is False), it is appended to both
    ``no_proxy`` and ``NO_PROXY`` in ``os.environ`` unless already listed,
    and a line is printed. Empty URLs, URLs without a host, DNS hostnames and
    public IPs are left alone. Always returns None.
    """
    host = urlparse(url).hostname if url else None
    if not host:
        return

    try:
        is_public = ipaddress.ip_address(host).is_global
    except ValueError:
        # Not an IP literal (e.g. a DNS name): proxy settings stay untouched.
        return
    if is_public:
        return

    for env_name in ("no_proxy", "NO_PROXY"):
        current = [part.strip() for part in os.environ.get(env_name, "").split(",")]
        current = [part for part in current if part]
        if host in current:
            continue
        os.environ[env_name] = ",".join(current + [host])
        print(f"[no_proxy] added {host} to {env_name}")
139
+
140
+
141
def _chat_content_text(content):
    """Return *content* as plain text; list-form content joins its text parts with newlines."""
    if isinstance(content, list):
        return "\n".join(
            part["text"] for part in content
            if isinstance(part, dict) and part.get("type") == "text"
        )
    return content


def _chat_part_to_responses(item, text_type):
    """Map one Chat Completions content part to a Responses API part, or None if unsupported."""
    part_type = item.get("type")
    if part_type == "text":
        return {"type": text_type, "text": item["text"]}
    if part_type == "image_url":
        image = item["image_url"]
        # Accept a bare URL string as well as the nested {"url", "detail"} form.
        if isinstance(image, str):
            return {"type": "input_image", "image_url": image}
        part = {"type": "input_image", "image_url": image["url"]}
        if image.get("detail") is not None:
            part["detail"] = image["detail"]
        return part
    # Other part types are not converted.
    return None


def messages_to_responses_input(messages):
    """
    Convert Chat Completions messages format to Responses API input format.

    System messages are merged (newline-separated) into one string meant for
    the Responses API top-level instructions. Every other message with str or
    list content becomes an input block. Text parts are typed ``output_text``
    for assistant messages and ``input_text`` for every other role (user,
    developer, ...), because only assistant turns are model output. Messages
    whose content is neither str nor list are skipped.

    Args:
        messages (list): List of message dictionaries with 'role' and 'content'.

    Returns:
        tuple: (system_prompt_content, input_blocks)
            - system_prompt_content (str or None): The system prompt content.
            - input_blocks (list): List of input blocks for Responses API.
    """
    system_prompt_content = None
    input_blocks = []

    for msg in messages:
        role = msg["role"]
        content = msg["content"]

        if role == "system":
            text = _chat_content_text(content)
            if system_prompt_content is None:
                system_prompt_content = text
            else:
                system_prompt_content += "\n" + text
            continue

        text_type = "output_text" if role == "assistant" else "input_text"

        if isinstance(content, str):
            parts = [{"type": text_type, "text": content}]
        elif isinstance(content, list):
            converted = (_chat_part_to_responses(item, text_type) for item in content)
            parts = [part for part in converted if part is not None]
        else:
            continue

        input_blocks.append({"role": role, "content": parts})

    return system_prompt_content, input_blocks
196
+
197
+
198
def _field(obj, name, default=None):
    """Read *name* from a plain dict (by key) or any other object (by attribute)."""
    if isinstance(obj, dict):
        return obj.get(name, default)
    return getattr(obj, name, default)


def extract_text_outputs(result) -> list[str]:
    """
    Unified extractor for:
    - Chat Completions API
    - Responses API

    Works on SDK objects and on plain dicts at every nesting level.
    Missing or None fields are treated as empty rather than raising.

    Always returns: List[str]
    """

    # ---------- Chat Completions ----------
    # result.choices[i].message.content; choices with empty/None content are skipped.
    choices = _field(result, "choices")
    if choices is not None:
        outputs = []
        for choice in choices:
            text = _field(_field(choice, "message"), "content")
            if text:
                outputs.append(text)
        return outputs

    # ---------- Responses API ----------
    # result.output_text (SDK convenience aggregate), when non-empty.
    output_text = _field(result, "output_text")
    if output_text:
        return [output_text]

    # ---------- Responses API (manual fallback) ----------
    # One string per "message" item: the concatenation of its output_text
    # blocks. Other item types (e.g. reasoning) and block types are ignored.
    outputs = []
    for item in _field(result, "output") or []:
        if _field(item, "type") != "message":
            continue
        texts = [
            _field(block, "text") or ""
            for block in (_field(item, "content") or [])
            if _field(block, "type") == "output_text"
        ]
        if texts:
            outputs.append("".join(texts))

    return outputs
258
+
259
+
260
class LLMAgent:
    """Callable wrapper around an OpenAI-compatible client.

    Builds chat messages (optionally with images and prior history), sends them
    through either the Chat Completions or the Responses API under a time
    limit, and -- via ``safe_api`` / ``__call__`` -- retries and parses replies
    into a str, list or dict shaped like ``return_example``.
    """

    def __init__(self,
                 api_key=None,
                 api_base=None,
                 model_version='gpt-4.1-mini',
                 system_prompt='You are a helpful assistant.',
                 max_tokens=None,
                 temperature=0,
                 http_client=None,
                 headers=None,
                 time_limit=5*60,
                 max_try=1,
                 use_responses_api=False
                 ):
        """
        Args:
            api_key: API key; falls back to the LLM_API_KEY environment variable.
            api_base: Base URL; falls back to LLM_BASE_URL. A non-public IP
                literal host is added to no_proxy / NO_PROXY.
            model_version: Model name sent with every request.
            system_prompt: Default system prompt.
            max_tokens: Default token limit (overridable per call).
            temperature: Default sampling temperature (overridable per call).
            http_client: Passed to OpenAI as http_client.
            headers: Passed to OpenAI as default_headers.
            time_limit: Timeout given to run_with_timeout for each llm_api call
                (presumably seconds; default 5*60).
            max_try: Default number of attempts in safe_api.
            use_responses_api: Default to the Responses API instead of Chat Completions.
        """
        # Load from environment if not provided
        if api_key is None:
            api_key = os.environ.get("LLM_API_KEY")
        if api_base is None:
            api_base = os.environ.get("LLM_BASE_URL")

        add_no_proxy_if_private(api_base)
        self.api_key = api_key
        self.api_base = api_base
        self.model_version = model_version
        self.system_prompt = system_prompt
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.time_limit = time_limit
        self.max_try = max_try
        self.use_responses_api = use_responses_api

        self.client = OpenAI(api_key=self.api_key, base_url=self.api_base, http_client=http_client, default_headers=headers)

    def _llm_api_impl(self, query, system_prompt=None, **kwargs):
        """Send one request and return the list of assistant text outputs.

        Recognised kwargs: image_paths (a path or an iterable of paths),
        max_tokens, temperature, history (list of prior messages), n,
        use_responses_api. Any other kwargs are ignored here.
        """
        image_paths = kwargs.get('image_paths', None)
        max_tokens = kwargs.get('max_tokens', self.max_tokens)
        temperature = kwargs.get('temperature', self.temperature)
        history = kwargs.get('history', None)
        n = kwargs.get('n', 1)
        if system_prompt is None:
            system_prompt = self.system_prompt

        if image_paths is None:  # without image
            content = query
        else:  # with image: one text part followed by one part per image
            # A bare path would otherwise be iterated character by character.
            if isinstance(image_paths, (str, os.PathLike)):
                image_paths = [image_paths]
            content = [{"type": "text", "text": query}]
            for image_path in image_paths:
                try:
                    img_str = encode_image(read_image(image_path))
                except Exception as e:
                    # Best effort: skip unreadable images, but report it.
                    print(f"[===WARNING===][structai][llm_api.py][LLMAgent._llm_api_impl] skipping image {image_path}: {e}")
                    continue
                content.append({
                    "type": "image_url",
                    # encode_image() emits PNG bytes, so the data URL must say PNG.
                    "image_url": {"url": f"data:image/png;base64,{img_str}"},
                })

        messages = [{"role": "system", "content": system_prompt}]
        if history is not None:
            messages.extend(history)
        messages.append({"role": "user", "content": content})

        use_responses_api = kwargs.get('use_responses_api', self.use_responses_api)
        if use_responses_api:
            system_prompt_content, input_blocks = messages_to_responses_input(messages)

            # Prepare arguments for responses.create
            create_kwargs = {
                "model": self.model_version,
                "input": input_blocks,
                "max_output_tokens": max_tokens,
                "temperature": temperature,
            }
            if system_prompt_content:
                create_kwargs["instructions"] = system_prompt_content

            if n > 1:
                # This code emulates n by issuing n identical requests in parallel.
                inp_list = [create_kwargs] * n
                responses = multi_thread(inp_list, self.client.responses.create, max_workers=min(n, 20), use_tqdm=False)

                assistant_responses = []
                for response in responses:
                    # NOTE(review): presumably falsy entries are failed calls -- confirm multi_thread semantics.
                    if response:
                        assistant_responses.extend(extract_text_outputs(response))
            else:
                response = self.client.responses.create(**create_kwargs)
                assistant_responses = extract_text_outputs(response)
        else:
            response = self.client.chat.completions.create(
                model=self.model_version,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
                n=n,
            )
            assistant_responses = extract_text_outputs(response)

        return assistant_responses

    def llm_api(self, query, system_prompt=None, **kwargs):
        """Run _llm_api_impl under run_with_timeout with self.time_limit."""
        return run_with_timeout(
            self._llm_api_impl,
            args=(query, system_prompt),
            kwargs=kwargs,
            timeout=self.time_limit
        )

    @staticmethod
    def _validate_list(result_list, return_example, options):
        """Raise if *result_list* violates the list constraints.

        Checks, in order: ``options['list_len']`` (exact length), element type
        against ``return_example[0]`` (int and float are interchangeable), and
        ``options['list_min']`` / ``options['list_max']`` bounds. Explicit
        raises (not ``assert``) keep the checks active under ``python -O``.

        Raises:
            ValueError: wrong length or out-of-range element.
            TypeError: element type differs from the example's.
        """
        list_len = options.get('list_len', None)
        if list_len is not None and len(result_list) != list_len:
            raise ValueError(f"[===ERROR===][structai][llm_api.py][LLMAgent.safe_api] LLM response does not match the required length: {len(result_list)} != {list_len}\nResponse: {result_list}")

        # type check
        if len(return_example) > 0:
            expected = return_example[0]
            numeric = (float, int)
            for result_item in result_list:
                if isinstance(expected, numeric):
                    type_ok = isinstance(result_item, numeric)
                else:
                    type_ok = type(result_item) is type(expected)
                if not type_ok:
                    raise TypeError(f"[===ERROR===][structai][llm_api.py][LLMAgent.safe_api] LLM response does not match the example list type: {type(result_item)} != {type(expected)}\nItem: {result_item}")

        # range check; `not (a >= b)` (rather than `a < b`) also rejects NaN.
        list_min = options.get('list_min', None)
        list_max = options.get('list_max', None)
        if list_min is not None or list_max is not None:
            for result_item in result_list:
                if list_min is not None and not (result_item >= list_min):
                    raise ValueError(f"[===ERROR===][structai][llm_api.py][LLMAgent.safe_api] LLM response {result_item} < list_min {list_min}")
                if list_max is not None and not (result_item <= list_max):
                    raise ValueError(f"[===ERROR===][structai][llm_api.py][LLMAgent.safe_api] LLM response {result_item} > list_max {list_max}")

    @staticmethod
    def _match_dict_keys(result_dict, return_example):
        """Return a dict with exactly *return_example*'s keys taken from *result_dict*.

        An exact key match wins. Otherwise, for example keys longer than 5
        characters, the first output key within case-insensitive Levenshtein
        distance 2 is used.

        Raises:
            ValueError: an example key has no exact or fuzzy match.
        """
        matched = {}
        for k in return_example.keys():
            if k in result_dict:
                matched[k] = result_dict[k]
            else:
                for out_k in result_dict.keys():
                    if len(k) > 5 and Levenshtein.distance(out_k.lower(), k.lower()) <= 2:
                        matched[k] = result_dict[out_k]
                        break
            if k not in matched:
                raise ValueError(f"[===ERROR===][structai][llm_api.py][LLMAgent.safe_api] LLM response does not match the example dict: missing key {k}\nResponse: {result_dict}\n")
        return matched

    def safe_api(self, query, system_prompt=None, return_example: Union[list, dict, str]=None, max_try=None, wait_time=0.0, **kwargs):
        """Call the model with retries and parse/validate the replies.

        Args:
            query: User message text.
            system_prompt: Overrides the default system prompt when not None.
            return_example: Shape of the expected answer. None or a str returns
                raw text. A list parses each reply with str2list and validates
                it (see _validate_list; kwargs list_len, list_min, list_max). A
                dict parses with str2dict and, unless kwargs check_keys is
                False, keeps exactly the example's keys (see _match_dict_keys).
            max_try: Attempts before giving up (defaults to self.max_try).
            wait_time: Seconds to sleep between failed attempts.
            **kwargs: Forwarded to llm_api (e.g. n, image_paths, history, max_tokens).

        Valid results accumulate across attempts until n are collected.

        Returns:
            The single result when n == 1, otherwise a list of at most n
            results; None if no valid result was obtained.

        Raises:
            TypeError: return_example is not None, list, dict or str.
        """
        if max_try is None:
            max_try = self.max_try

        # Tuple form works on every Python version (isinstance with Union needs 3.10+).
        if return_example is not None and not isinstance(return_example, (list, dict, str)):
            raise TypeError(f"[===ERROR===][structai][llm_api.py][LLMAgent.safe_api] return_example should be list, dict or str: {type(return_example)}")

        n = kwargs.get('n', 1)
        response_list = []
        for try_idx in range(max_try):
            try:
                responses = self.llm_api(query, system_prompt, **kwargs)

                # str
                if return_example is None or isinstance(return_example, str):
                    response_list = response_list + responses

                # list
                elif isinstance(return_example, list):
                    for response in responses:
                        result_list = str2list(response)
                        self._validate_list(result_list, return_example, kwargs)
                        response_list.append(result_list)

                # dict
                else:
                    for response in responses:
                        result_dict = str2dict(response)
                        if kwargs.get('check_keys', True):
                            result_dict = self._match_dict_keys(result_dict, return_example)
                        response_list.append(result_dict)

                if len(response_list) >= n:
                    response_list = response_list[:n]
                    break

            except Exception as e:
                print(f'[===ERROR===][safe_api][{e}]')
                if try_idx < max_try - 1:
                    time.sleep(wait_time)

        if len(response_list) == 0:
            return None

        if n == 1:
            return response_list[0]
        return response_list

    def __call__(self, query, *args, **kwargs):
        """Alias for safe_api."""
        return self.safe_api(query, *args, **kwargs)

    def __str__(self) -> str:
        """Model name with '/' replaced by '_'."""
        return self.model_version.replace("/", "_")
475
+
476
+
477
if __name__ == '__main__':
    # python -m structai.llm_api
    # Self-test script. The helper and MockLLMAgent checks run offline; the live
    # LLMAgent checks run only when LLM_API_KEY is set and need network access.
    import tempfile

    print("Testing llm_api.py...")

    # Test sanitize_text
    print("Testing sanitize_text...")
    assert sanitize_text("Hello World!") == "Hello World!", "[===ERROR===][structai][llm_api.py][main] sanitize_text failed"
    assert sanitize_text("Hello\nWorld") == "Hello\nWorld", "[===ERROR===][structai][llm_api.py][main] sanitize_text failed"
    assert sanitize_text("Hello 🌍") == "Hello ", "[===ERROR===][structai][llm_api.py][main] sanitize_text failed"
    print("sanitize_text passed")

    # Test str2dict
    print("Testing str2dict...")
    assert str2dict('{"a": 1}') == {"a": 1}, "[===ERROR===][structai][llm_api.py][main] str2dict failed"
    assert str2dict(' {"a": 1} ') == {"a": 1}, "[===ERROR===][structai][llm_api.py][main] str2dict failed"
    assert str2dict("some text {'a': 1} more text") == {"a": 1}, "[===ERROR===][structai][llm_api.py][main] str2dict failed"
    assert str2dict("{'a': 1,}") == {"a": 1}, "[===ERROR===][structai][llm_api.py][main] str2dict failed"
    print("str2dict passed")

    # Test str2list
    print("Testing str2list...")
    assert str2list('[1, 2, 3]') == [1, 2, 3], "[===ERROR===][structai][llm_api.py][main] str2list failed"
    assert str2list(' [1, 2, 3] ') == [1, 2, 3], "[===ERROR===][structai][llm_api.py][main] str2list failed"
    assert str2list("text [1, 2] text") == [1, 2], "[===ERROR===][structai][llm_api.py][main] str2list failed"
    print("str2list passed")

    # Test LLMAgent
    print("\nTesting LLMAgent...")
    if os.environ.get("LLM_API_KEY"):

        def run_test(name, func, **kwargs):
            """Run *func* once with each API flavour, printing (not raising) errors."""
            print(f"\n[{name}]")
            for flag in (False, True):
                print(f" - use_responses_api={flag}:")
                try:
                    func(use_responses_api=flag, **kwargs)
                except Exception as e:
                    print(f" Error: {e}")

        # 1. Test gpt-4.1-mini
        def test_gpt_4_1_mini(use_responses_api):
            agent = LLMAgent(model_version='gpt-4.1-mini', max_try=1, use_responses_api=use_responses_api)
            res = agent("Say 'hello'", max_tokens=20)
            print(f" Result: {res}")
        run_test("Test 1: gpt-4.1-mini", test_gpt_4_1_mini)

        # 2. Test gpt-5.2-pro
        def test_gpt_5_2_pro(use_responses_api):
            agent = LLMAgent(model_version='gpt-5.2-pro', max_try=1, use_responses_api=use_responses_api)
            res = agent("Say 'hello'", max_tokens=20)
            print(f" Result: {res}")
        run_test("Test 2: gpt-5.2-pro", test_gpt_5_2_pro)

        # 3. Test time_limit=1, max_try=3
        def test_retry(use_responses_api):
            start_time = time.time()
            agent = LLMAgent(time_limit=1, max_try=3, use_responses_api=use_responses_api)
            res = agent("Write a 500 word essay about AI.", max_tokens=500)
            print(f" Result: {res}")
            print(f" Time taken: {time.time() - start_time:.2f}s")
        run_test("Test 3: time_limit=1, max_try=3 (expect retries)", test_retry)

        # 4. Test image_paths
        def test_image(use_responses_api):
            # Use a fresh temp file so an existing "test_image.png" in the
            # working directory is never overwritten or deleted.
            fd, img_path = tempfile.mkstemp(suffix=".png")
            os.close(fd)
            try:
                Image.new('RGB', (100, 100), color='red').save(img_path)
                agent = LLMAgent(model_version='gpt-4o', max_try=1, use_responses_api=use_responses_api)  # Assuming gpt-4o for vision
                res = agent("What color is this image?", image_paths=[img_path], max_tokens=20)
                print(f" Result: {res}")
            finally:
                if os.path.exists(img_path):
                    os.remove(img_path)
        run_test("Test 4: image_paths", test_image)

        # 5. Test history
        def test_history(use_responses_api):
            agent = LLMAgent(max_try=1, use_responses_api=use_responses_api)
            history = [
                {"role": "user", "content": "My name is Bob."},
                {"role": "assistant", "content": "Hello Bob."}
            ]
            res = agent("What is my name?", history=history, max_tokens=20)
            print(f" Result: {res}")
        run_test("Test 5: history", test_history)

        # 6. Test n=3
        def test_n_3(use_responses_api):
            agent = LLMAgent(max_try=1, use_responses_api=use_responses_api)
            res = agent("Generate a random number.", n=3, max_tokens=20)
            print(f" Result (len={len(res) if isinstance(res, list) else 'N/A'}): {res}")
        run_test("Test 6: n=3", test_n_3)

        # 7. Test return_example (list and dict)
        def test_return_example(use_responses_api):
            agent = LLMAgent(max_try=1, use_responses_api=use_responses_api)
            # List
            print(" - List:")
            res = agent("Return the list [1, 2, 3].", return_example=[1])
            print(f" Result: {res}")
            # Dict
            print(" - Dict:")
            res = agent("Return JSON {'a': 1}.", return_example={'a': 0})
            print(f" Result: {res}")
        run_test("Test 7: return_example", test_return_example)

        # 8. Test list_min
        def test_list_min(use_responses_api):
            agent = LLMAgent(max_try=1, use_responses_api=use_responses_api)
            res = agent("Return the list [10, 11, 12].", return_example=[1], list_min=5, list_max=20)
            print(f" Result: {res}")

            print(" - Testing failure case (list_min=20)...")
            res = agent("Return the list [10, 11, 12].", return_example=[1], list_min=20, max_try=2)
            print(f" Result (should be None): {res}")
        run_test("Test 8: list_min", test_list_min)

    else:
        print("Skipping LLMAgent test: LLM_API_KEY not set")

    # 9. Test Validation Logic with Mock
    print("\nTesting Validation Logic with MockLLMAgent...")

    class MockLLMAgent(LLMAgent):
        """LLMAgent whose llm_api returns canned replies keyed by the query text."""

        def __init__(self, responses_map, **kwargs):
            # responses_map: query -> response string or list of strings
            super().__init__(api_key="dummy", **kwargs)
            self.responses_map = responses_map

        def llm_api(self, query, system_prompt=None, **kwargs):
            # Simple mock: return predefined response based on query
            res = self.responses_map.get(query, "{}")
            if isinstance(res, str):
                return [res]
            return res

    # Define test cases
    mock_responses = {
        "dict_exact": '{"long_name": "Alice", "years_old": 25}',
        "dict_fuzzy": '{"long_nmae": "Alice", "years_oldd": 25}',  # long_nmae->long_name (dist 1), years_oldd->years_old (dist 1)
        "dict_bad": '{"wrong_name": "Alice", "wrong_age": 25}',
        "dict_missing": '{"long_name": "Alice"}',
        "list_int": '[10, 20]',
        "list_str": '["10", "20"]',
        "list_len_3": '[1, 2, 3]',
        "list_len_2": '[1, 2]',
        "list_range_ok": '[1, 5, 9]',
        "list_range_bad_min": '[-1, 5]',
        "list_range_bad_max": '[1, 11]',
    }

    # Use max_try=1 so that failures return None immediately
    agent = MockLLMAgent(mock_responses, max_try=1)

    # 9.1 Dict Key Tests
    print(" - Dict Key Tests:")
    # Use keys with len > 5 to trigger Levenshtein check
    example_dict = {'long_name': 'John', 'years_old': 30}

    # Exact match
    res = agent("dict_exact", return_example=example_dict)
    assert res == {"long_name": "Alice", "years_old": 25}, f"[===ERROR===][structai][llm_api.py][main] Dict exact match failed: {res}"
    print(" [Pass] Exact match")

    # Fuzzy match (Levenshtein)
    res = agent("dict_fuzzy", return_example=example_dict)
    # Should be corrected to match keys
    assert res == {"long_name": "Alice", "years_old": 25}, f"[===ERROR===][structai][llm_api.py][main] Dict fuzzy match failed: {res}"
    print(" [Pass] Fuzzy match (Levenshtein)")

    # Bad keys
    res = agent("dict_bad", return_example=example_dict)
    assert res is None, f"[===ERROR===][structai][llm_api.py][main] Dict bad keys should fail: {res}"
    print(" [Pass] Bad keys")

    # Missing keys
    res = agent("dict_missing", return_example=example_dict)
    assert res is None, f"[===ERROR===][structai][llm_api.py][main] Dict missing keys should fail: {res}"
    print(" [Pass] Missing keys")

    # 9.2 List Type Tests
    print(" - List Type Tests:")
    example_list_int = [1]

    # Correct type
    res = agent("list_int", return_example=example_list_int)
    assert res == [10, 20], f"[===ERROR===][structai][llm_api.py][main] List type ok failed: {res}"
    print(" [Pass] Correct type")

    # Incorrect type
    res = agent("list_str", return_example=example_list_int)
    assert res is None, f"[===ERROR===][structai][llm_api.py][main] List type bad should fail: {res}"
    print(" [Pass] Incorrect type")

    # 9.3 List Length Tests
    print(" - List Length Tests:")

    # Correct length
    res = agent("list_len_3", return_example=[1], list_len=3)
    assert res == [1, 2, 3], f"[===ERROR===][structai][llm_api.py][main] List len ok failed: {res}"
    print(" [Pass] Correct length")

    # Incorrect length
    res = agent("list_len_2", return_example=[1], list_len=3)
    assert res is None, f"[===ERROR===][structai][llm_api.py][main] List len bad should fail: {res}"
    print(" [Pass] Incorrect length")

    # 9.4 List Range Tests
    print(" - List Range Tests:")

    # In range
    res = agent("list_range_ok", return_example=[1], list_min=0, list_max=10)
    assert res == [1, 5, 9], f"[===ERROR===][structai][llm_api.py][main] List range ok failed: {res}"
    print(" [Pass] In range")

    # Bad min
    res = agent("list_range_bad_min", return_example=[1], list_min=0)
    assert res is None, f"[===ERROR===][structai][llm_api.py][main] List range bad min should fail: {res}"
    print(" [Pass] Bad min")

    # Bad max
    res = agent("list_range_bad_max", return_example=[1], list_max=10)
    assert res is None, f"[===ERROR===][structai][llm_api.py][main] List range bad max should fail: {res}"
    print(" [Pass] Bad max")

    print("llm_api.py tests completed.")
    print("--------------------------------------------")