structai 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- structai/__init__.py +30 -0
- structai/io.py +227 -0
- structai/llm_api.py +682 -0
- structai/mp.py +96 -0
- structai/openai_server.py +79 -0
- structai/utils.py +209 -0
- structai-0.1.0.dist-info/METADATA +365 -0
- structai-0.1.0.dist-info/RECORD +11 -0
- structai-0.1.0.dist-info/WHEEL +5 -0
- structai-0.1.0.dist-info/licenses/LICENSE +21 -0
- structai-0.1.0.dist-info/top_level.txt +1 -0
structai/llm_api.py
ADDED
|
@@ -0,0 +1,682 @@
|
|
|
1
|
+
from openai import OpenAI
|
|
2
|
+
from typing import Union
|
|
3
|
+
import Levenshtein
|
|
4
|
+
import time
|
|
5
|
+
from json_repair import repair_json
|
|
6
|
+
from PIL import Image
|
|
7
|
+
import io
|
|
8
|
+
import base64
|
|
9
|
+
import os
|
|
10
|
+
import string
|
|
11
|
+
import ipaddress
|
|
12
|
+
import ast
|
|
13
|
+
import json
|
|
14
|
+
from urllib.parse import urlparse
|
|
15
|
+
from .utils import run_with_timeout
|
|
16
|
+
from .mp import multi_thread
|
|
17
|
+
|
|
18
|
+
# Characters that survive sanitize_text(): ASCII letters, digits, common
# punctuation, the space character, newline and tab. Everything else is dropped.
_ALLOWED_CHARS = (
    set(string.ascii_letters)
    | set(string.digits)
    | set(" .,:;!?+-*/=<>|@#$%&()[]{}_'\"")
    | set("\n\t")
)


def sanitize_text(text: str) -> str:
    """
    Strip every character that is not on the ASCII whitelist.

    Intended for cleaning subprocess / tqdm / CLI output before it is
    JSON-serialized, fed to an LLM, or written to human-readable logs.

    Args:
        text (str): Text to clean. Falsy values (``""``, ``None``) are
            returned unchanged.

    Returns:
        str: ``text`` with only the characters in ``_ALLOWED_CHARS`` kept, in
        their original order. Control characters, escape bytes and
        non-ASCII characters are removed (not replaced).
    """
    if text:
        return "".join(filter(_ALLOWED_CHARS.__contains__, text))
    return text
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def str2dict(s: str) -> dict:
    """
    Extract and parse the dict embedded in a (possibly noisy) LLM reply.

    The text between the first ``{`` and the last ``}`` is parsed with three
    strategies, in order:

    1. ``ast.literal_eval`` (Python literal syntax, e.g. single quotes),
    2. ``json.loads`` on the output of ``repair_json``,
    3. the same, after passing the text through ``sanitize_text``.

    Args:
        s (str): Raw text that contains a dict / JSON object.

    Returns:
        dict: The parsed object.

    Raises:
        ValueError: If the parsed value is not a dict (``json.JSONDecodeError``
            from the last strategy is also a ``ValueError``).
    """
    start_index = s.find('{')
    if start_index != -1:
        end_index = s.rfind('}')
        if end_index > start_index:
            s = s[start_index:end_index + 1]
        else:
            # No closing brace after the opening one: the object is truncated.
            # Keep the tail instead of slicing it down to an empty string, so
            # the repair step still has something to work with.
            s = s[start_index:]

    try:
        d = ast.literal_eval(s)
    except Exception:
        try:
            d = json.loads(repair_json(s))
        except Exception:
            # Last resort: drop non-whitelisted characters and try again.
            # Errors from this attempt propagate to the caller.
            d = json.loads(repair_json(sanitize_text(s)))

    if not isinstance(d, dict):
        raise ValueError(
            f"[===ERROR===][structai][llm_api.py][str2dict] parsed value is not a dict: {type(d)}"
        )
    return d
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def str2list(s: str) -> list:
    """
    Extract and parse the list embedded in a (possibly noisy) LLM reply.

    The text between the first ``[`` and the last ``]`` is parsed with three
    strategies, in order:

    1. ``ast.literal_eval`` (Python literal syntax),
    2. ``json.loads`` on the output of ``repair_json``,
    3. the same, after passing the text through ``sanitize_text``.

    Args:
        s (str): Raw text that contains a list / JSON array.

    Returns:
        list: The parsed list. A tuple literal (e.g. ``"1, 2, 3"``) is
        converted to a list.

    Raises:
        ValueError: If the parsed value is neither a list nor a tuple
            (``json.JSONDecodeError`` from the last strategy is also a
            ``ValueError``).
    """
    start_index = s.find('[')
    if start_index != -1:
        end_index = s.rfind(']')
        if end_index > start_index:
            s = s[start_index:end_index + 1]
        else:
            # No closing bracket after the opening one: the list is truncated.
            # Keep the tail instead of slicing it down to an empty string, so
            # the repair step still has something to work with.
            s = s[start_index:]

    try:
        parsed = ast.literal_eval(s)
    except Exception:
        try:
            parsed = json.loads(repair_json(s))
        except Exception:
            # Last resort: drop non-whitelisted characters and try again.
            # Errors from this attempt propagate to the caller.
            parsed = json.loads(repair_json(sanitize_text(s)))

    if isinstance(parsed, tuple):
        parsed = list(parsed)
    if not isinstance(parsed, list):
        raise ValueError(
            f"[===ERROR===][structai][llm_api.py][str2list] parsed value is not a list: {type(parsed)}"
        )
    return parsed
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def read_image(image_path: str) -> Image.Image:
    """Open the image file at ``image_path`` with Pillow and return the image object."""
    image = Image.open(image_path)
    return image
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def encode_image(image_obj: Image.Image) -> str:
    """
    Serialize an image as PNG in memory and base64-encode the bytes.

    Args:
        image_obj (Image.Image): Image to encode.

    Returns:
        str: Base64 text (decoded as UTF-8) of the PNG-encoded image.
    """
    png_buffer = io.BytesIO()
    image_obj.save(png_buffer, format="PNG")
    png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode("utf-8")
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def add_no_proxy_if_private(url: str):
    """
    Exempt a non-public IP endpoint from proxying via ``no_proxy`` / ``NO_PROXY``.

    Nothing happens when ``url`` is empty, has no hostname, the hostname is
    not a literal IP address (DNS names are ignored), or the IP is globally
    routable. Otherwise the host is appended to both environment variables
    (unless already listed) and a notice is printed for each variable changed.

    Args:
        url (str): Endpoint URL, e.g. ``"http://192.168.1.10:8000/v1"``.

    Returns:
        None
    """
    host = urlparse(url).hostname if url else None
    if not host:
        return

    # Only literal IP addresses are handled; ip_address() rejects host names.
    try:
        is_public = ipaddress.ip_address(host).is_global
    except ValueError:
        return
    if is_public:
        return

    for env_name in ("no_proxy", "NO_PROXY"):
        raw_value = os.environ.get(env_name, "")
        current = [part.strip() for part in raw_value.split(",") if part.strip()]
        if host in current:
            continue

        os.environ[env_name] = ",".join(current + [host])
        print(f"[no_proxy] added {host} to {env_name}")
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def messages_to_responses_input(messages):
    """
    Translate Chat Completions style messages into Responses API input.

    System messages are pulled out and merged (newline-separated) into a
    single prompt. Every other message becomes one input block whose text
    parts are typed ``input_text`` for the ``user`` role and ``output_text``
    for any other role; ``image_url`` parts become ``input_image``. Content
    parts of any other type, and messages whose content is neither a string
    nor a list, are dropped.

    Args:
        messages (list): Message dicts, each with ``'role'`` and ``'content'``.

    Returns:
        tuple: ``(system_prompt_content, input_blocks)``
            - system_prompt_content (str or None): Merged system prompt, or
              ``None`` when there is no system message.
            - input_blocks (list): ``{"role", "content"}`` dicts for the
              Responses API.
    """
    system_text = None
    blocks = []

    for message in messages:
        role, body = message["role"], message["content"]

        if role == "system":
            # The system prompt is returned separately; several system
            # messages are concatenated in order.
            if system_text is None:
                system_text = body
            else:
                system_text += "\n" + body
            continue

        # User text is model input; text from any other role is treated as output.
        text_kind = "input_text" if role == "user" else "output_text"

        if isinstance(body, str):
            parts = [{"type": text_kind, "text": body}]
        elif isinstance(body, list):
            # Multimodal content: convert each supported part, skip the rest.
            parts = []
            for entry in body:
                if entry["type"] == "text":
                    parts.append({"type": text_kind, "text": entry["text"]})
                elif entry["type"] == "image_url":
                    parts.append({"type": "input_image", "image_url": entry["image_url"]["url"]})
        else:
            continue

        blocks.append({"role": role, "content": parts})

    return system_text, blocks
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def extract_text_outputs(result) -> list[str]:
    """
    Pull the generated text out of a Chat Completions or Responses API result.

    Lookup order:

    1. ``result.choices`` (Chat Completions): the non-empty
       ``choice.message.content`` of every choice.
    2. ``result.output_text`` (Responses API shortcut), when non-empty.
    3. ``result.output`` (Responses API, manual walk): for every item of type
       ``"message"``, the concatenated text of its ``"output_text"`` blocks.
       Items and blocks may be objects or plain dicts.

    Always returns: List[str] (empty when nothing could be extracted).
    """

    def lookup(obj, key, default):
        # Attribute access first; fall back to the dict key when the object
        # is a dict and the attribute lookup produced nothing.
        value = getattr(obj, key, default)
        if not value and isinstance(obj, dict):
            value = obj.get(key, default)
        return value

    # ---------- Chat Completions ----------
    if hasattr(result, "choices"):
        messages = (getattr(choice, "message", None) for choice in result.choices)
        return [m.content for m in messages if m and m.content]

    # ---------- Responses API (shortcut) ----------
    if hasattr(result, "output_text") and result.output_text:
        return [result.output_text]

    # ---------- Responses API (manual fallback) ----------
    collected = []
    if not hasattr(result, "output"):
        return collected

    for item in result.output:
        if lookup(item, "type", None) != "message":
            continue

        pieces = [
            lookup(block, "text", "")
            for block in lookup(item, "content", [])
            if lookup(block, "type", None) == "output_text"
        ]
        if pieces:
            collected.append("".join(pieces))

    return collected
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
class LLMAgent:
    """
    Wrapper around an OpenAI-compatible client with retries and reply validation.

    ``safe_api`` (also reachable by calling the instance) sends a query,
    retries on failure, and can parse each reply into a list or dict shaped
    like a caller-supplied example. ``llm_api`` is the single-attempt call
    that returns the raw reply strings.
    """

    def __init__(self,
                 api_key=None,
                 api_base=None,
                 model_version='gpt-4.1-mini',
                 system_prompt='You are a helpful assistant.',
                 max_tokens=None,
                 temperature=0,
                 http_client=None,
                 headers=None,
                 time_limit=5*60,
                 max_try=1,
                 use_responses_api=False
                 ):
        """
        Args:
            api_key: API key; read from ``LLM_API_KEY`` when ``None``.
            api_base: Base URL; read from ``LLM_BASE_URL`` when ``None``.
            model_version: Model name sent with every request.
            system_prompt: Default system prompt.
            max_tokens: Default output-token limit (``None`` = not set here).
            temperature: Default sampling temperature.
            http_client: Passed to ``OpenAI(http_client=...)``.
            headers: Passed to ``OpenAI(default_headers=...)``.
            time_limit: Timeout handed to ``run_with_timeout`` for one call
                (presumably seconds - confirm against ``utils.run_with_timeout``).
            max_try: Default number of attempts made by ``safe_api``.
            use_responses_api: Use the Responses API instead of Chat
                Completions by default.
        """
        # Load from environment if not provided
        if api_key is None:
            api_key = os.environ.get("LLM_API_KEY")
        if api_base is None:
            api_base = os.environ.get("LLM_BASE_URL")

        # Side effect: may extend no_proxy / NO_PROXY in os.environ.
        add_no_proxy_if_private(api_base)
        self.api_key = api_key
        self.api_base = api_base
        self.model_version = model_version
        self.system_prompt = system_prompt
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.time_limit = time_limit
        self.max_try = max_try
        self.use_responses_api = use_responses_api

        self.client = OpenAI(api_key=self.api_key, base_url=self.api_base, http_client=http_client, default_headers=headers)

    def _llm_api_impl(self, query, system_prompt=None, **kwargs):
        """
        Build the request and perform one API call.

        Recognized ``kwargs``: ``image_paths``, ``max_tokens``, ``temperature``,
        ``history``, ``n``, ``use_responses_api``. Other keys are ignored here.

        Returns:
            list[str]: Reply texts as produced by ``extract_text_outputs``.
        """
        image_paths = kwargs.get('image_paths', None)
        max_tokens = kwargs.get('max_tokens', self.max_tokens)
        temperature = kwargs.get('temperature', self.temperature)
        history = kwargs.get('history', None)
        n = kwargs.get('n', 1)
        use_responses_api = kwargs.get('use_responses_api', self.use_responses_api)
        if system_prompt is None:
            system_prompt = self.system_prompt

        if image_paths is None:  # without image
            content = query
        else:  # with image
            content = [{"type": "text", "text": query}]

            for image_path in image_paths:
                try:
                    image_b64 = encode_image(read_image(image_path))
                except Exception as e:
                    # Best effort: an unreadable image is skipped, but not silently.
                    print(f'[===ERROR===][structai][llm_api.py][LLMAgent._llm_api_impl] skip image {image_path}: {e}')
                    continue

                content.append({
                    "type": "image_url",
                    "image_url": {
                        # encode_image() always produces PNG bytes, so the
                        # media type must be image/png.
                        "url": f"data:image/png;base64,{image_b64}",
                    },
                })

        messages = [{"role": "system", "content": system_prompt}]
        if history is not None:
            messages.extend(history)
        messages.append({"role": "user", "content": content})

        if not use_responses_api:
            response = self.client.chat.completions.create(
                model=self.model_version,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
                n=n,
            )
            return extract_text_outputs(response)

        system_prompt_content, input_blocks = messages_to_responses_input(messages)

        # Prepare arguments for responses.create
        create_kwargs = {
            "model": self.model_version,
            "input": input_blocks,
            "max_output_tokens": max_tokens,
            "temperature": temperature,
        }
        if system_prompt_content:
            create_kwargs["instructions"] = system_prompt_content

        if n <= 1:
            response = self.client.responses.create(**create_kwargs)
            return extract_text_outputs(response)

        # The Responses API request built here has no ``n`` parameter, so the
        # same request is issued n times in parallel.
        # NOTE(review): assumes multi_thread() passes each dict to the callable
        # as keyword arguments - confirm against structai.mp.multi_thread.
        inp_list = [create_kwargs] * n
        responses = multi_thread(inp_list, self.client.responses.create, max_workers=min(n, 20), use_tqdm=False)

        assistant_responses = []
        for response in responses:
            if response:
                assistant_responses.extend(extract_text_outputs(response))
        return assistant_responses

    def llm_api(self, query, system_prompt=None, **kwargs):
        """Run one API call through ``run_with_timeout`` with ``self.time_limit`` as the timeout."""
        return run_with_timeout(
            self._llm_api_impl,
            args=(query, system_prompt),
            kwargs=kwargs,
            timeout=self.time_limit
        )

    @staticmethod
    def _parse_list_response(response, return_example, options):
        """Parse one reply into a list and validate length, item type and range.

        Raises ``ValueError`` (or whatever ``str2list`` / the comparisons raise)
        when the reply does not satisfy the constraints in ``options``.
        """
        result_list = str2list(response)

        list_len = options.get('list_len', None)
        if list_len is not None and len(result_list) != list_len:
            raise ValueError(f"[===ERROR===][structai][llm_api.py][LLMAgent.safe_api] LLM response does not match the required length: {len(result_list)} != {list_len}\nResponse: {result_list}")

        # type check: items must match the type of the first example item;
        # int and float are interchangeable.
        if len(return_example) > 0:
            expected = return_example[0]
            expects_number = isinstance(expected, (float, int))
            for result_item in result_list:
                if expects_number:
                    type_ok = isinstance(result_item, (float, int))
                else:
                    type_ok = type(result_item) is type(expected)
                if not type_ok:
                    raise ValueError(f"[===ERROR===][structai][llm_api.py][LLMAgent.safe_api] LLM response does not match the example list type: {type(result_item)} != {type(expected)}\nItem: {result_item}")

        # range check
        list_min = options.get('list_min', None)
        list_max = options.get('list_max', None)
        if list_min is not None or list_max is not None:
            for result_item in result_list:
                if list_min is not None and not (result_item >= list_min):
                    raise ValueError(f"[===ERROR===][structai][llm_api.py][LLMAgent.safe_api] LLM response {result_item} < list_min {list_min}")
                if list_max is not None and not (result_item <= list_max):
                    raise ValueError(f"[===ERROR===][structai][llm_api.py][LLMAgent.safe_api] LLM response {result_item} > list_max {list_max}")

        return result_list

    @staticmethod
    def _parse_dict_response(response, return_example, options):
        """Parse one reply into a dict and, unless ``check_keys`` is false, align its keys.

        With key checking on, the result contains exactly the keys of
        ``return_example``. A missing key longer than 5 characters may be
        filled from a reply key within Levenshtein distance 2
        (case-insensitive). Raises ``ValueError`` when a key cannot be found.
        """
        result_dict = str2dict(response)

        if not options.get('check_keys', True):
            return result_dict

        result_dict_correct = {}
        for k in return_example:
            if k in result_dict:
                result_dict_correct[k] = result_dict[k]
            else:
                for out_k in result_dict:
                    # Fuzzy matching only applies to string keys; short keys
                    # are excluded because 2 edits would match almost anything.
                    if (isinstance(k, str) and isinstance(out_k, str) and len(k) > 5
                            and Levenshtein.distance(out_k.lower(), k.lower()) <= 2):
                        result_dict_correct[k] = result_dict[out_k]
                        break

            if k not in result_dict_correct:
                raise ValueError(f"[===ERROR===][structai][llm_api.py][LLMAgent.safe_api] LLM response does not match the example dict: missing key {k}\nResponse: {result_dict}\n")

        return result_dict_correct

    def safe_api(self, query, system_prompt=None, return_example: Union[list, dict, str]=None, max_try=None, wait_time=0.0, **kwargs):
        """
        Query the model with retries and optional structured-output validation.

        Args:
            query: User query.
            system_prompt: Overrides the default system prompt when given.
            return_example: Shape of the desired result. ``None`` or a ``str``
                returns raw text; a ``list`` parses each reply with
                ``str2list``; a ``dict`` parses each reply with ``str2dict``.
            max_try: Number of attempts (defaults to ``self.max_try``).
            wait_time: Pause between failed attempts, passed to ``time.sleep``.
            **kwargs: Forwarded to ``llm_api``. Also read here: ``n`` (number
                of results wanted, default 1), ``list_len``, ``list_min``,
                ``list_max`` (list validation) and ``check_keys`` (dict key
                validation, default ``True``).

        Returns:
            ``None`` when no valid reply was collected; a single result when
            ``n == 1``; otherwise a list of at most ``n`` results.

        Raises:
            AssertionError: If ``return_example`` is not a list, dict or str.
        """
        if max_try is None:
            max_try = self.max_try

        # Tuple form works on every Python 3 version (isinstance with
        # typing.Union needs 3.10+); an explicit raise survives ``python -O``.
        if return_example is not None and not isinstance(return_example, (list, dict, str)):
            raise AssertionError(f"[===ERROR===][structai][llm_api.py][LLMAgent.safe_api] return_example should be list, dict or str: {type(return_example)}")

        n = kwargs.get('n', 1)
        response_list = []
        for try_idx in range(max_try):
            try:
                responses = self.llm_api(query, system_prompt, **kwargs)

                if return_example is None or isinstance(return_example, str):
                    response_list = response_list + responses
                elif isinstance(return_example, list):
                    for response in responses:
                        response_list.append(self._parse_list_response(response, return_example, kwargs))
                elif isinstance(return_example, dict):
                    for response in responses:
                        response_list.append(self._parse_dict_response(response, return_example, kwargs))

                if len(response_list) >= n:
                    response_list = response_list[:n]
                    break

            except Exception as e:
                # Any failure (API error, parse error, failed validation)
                # consumes one attempt; results gathered so far are kept.
                print(f'[===ERROR===][safe_api][{e}]')
                if try_idx < max_try - 1:
                    time.sleep(wait_time)

        if len(response_list) == 0:
            return None

        if n == 1:
            return response_list[0]
        return response_list

    def __call__(self, query, *args, **kwargs):
        """Shorthand for ``safe_api``."""
        return self.safe_api(query, *args, **kwargs)

    def __str__(self) -> str:
        """Model name with ``/`` replaced by ``_``."""
        return self.model_version.replace("/", "_")
|
|
444
|
+
|
|
445
|
+
|
|
446
|
+
if __name__ == '__main__':
    # Self-test entry point. Run with:
    #   python -m structai.llm_api
    # The helper checks and the mock-based validation checks (section 9) always
    # run; the live LLMAgent checks (1-8) run only when LLM_API_KEY is set.
    print("Testing llm_api.py...")

    # Test sanitize_text
    print("Testing sanitize_text...")
    assert sanitize_text("Hello World!") == "Hello World!", f"[===ERROR===][structai][llm_api.py][main] sanitize_text failed"
    assert sanitize_text("Hello\nWorld") == "Hello\nWorld", f"[===ERROR===][structai][llm_api.py][main] sanitize_text failed"
    assert sanitize_text("Hello 🌍") == "Hello ", f"[===ERROR===][structai][llm_api.py][main] sanitize_text failed"
    print("sanitize_text passed")

    # Test str2dict
    print("Testing str2dict...")
    assert str2dict('{"a": 1}') == {"a": 1}, f"[===ERROR===][structai][llm_api.py][main] str2dict failed"
    assert str2dict(' {"a": 1} ') == {"a": 1}, f"[===ERROR===][structai][llm_api.py][main] str2dict failed"
    assert str2dict("some text {'a': 1} more text") == {"a": 1}, f"[===ERROR===][structai][llm_api.py][main] str2dict failed"
    assert str2dict("{'a': 1,}") == {"a": 1}, f"[===ERROR===][structai][llm_api.py][main] str2dict failed"
    print("str2dict passed")

    # Test str2list
    print("Testing str2list...")
    assert str2list('[1, 2, 3]') == [1, 2, 3], f"[===ERROR===][structai][llm_api.py][main] str2list failed"
    assert str2list(' [1, 2, 3] ') == [1, 2, 3], f"[===ERROR===][structai][llm_api.py][main] str2list failed"
    assert str2list("text [1, 2] text") == [1, 2], f"[===ERROR===][structai][llm_api.py][main] str2list failed"
    print("str2list passed")

    # Test LLMAgent (live API calls; results are printed, not asserted)
    print("\nTesting LLMAgent...")
    if os.environ.get("LLM_API_KEY"):

        def run_test(name, func, **kwargs):
            # Run `func` once per API flavour; an exception in one flavour is
            # printed and does not prevent the other from running.
            print(f"\n[{name}]")
            # Test with use_responses_api=False
            print(" - use_responses_api=False:")
            try:
                func(use_responses_api=False, **kwargs)
            except Exception as e:
                print(f" Error: {e}")

            # Test with use_responses_api=True
            print(" - use_responses_api=True:")
            try:
                func(use_responses_api=True, **kwargs)
            except Exception as e:
                print(f" Error: {e}")

        # 1. Test gpt-4.1-mini
        def test_gpt_4_1_mini(use_responses_api):
            agent = LLMAgent(model_version='gpt-4.1-mini', max_try=1, use_responses_api=use_responses_api)
            res = agent("Say 'hello'", max_tokens=20)
            print(f" Result: {res}")
        run_test("Test 1: gpt-4.1-mini", test_gpt_4_1_mini)

        # 2. Test gpt-5.2-pro
        def test_gpt_5_2_pro(use_responses_api):
            agent = LLMAgent(model_version='gpt-5.2-pro', max_try=1, use_responses_api=use_responses_api)
            res = agent("Say 'hello'", max_tokens=20)
            print(f" Result: {res}")
        run_test("Test 2: gpt-5.2-pro", test_gpt_5_2_pro)

        # 3. Test time_limit=1, max_try=3
        def test_retry(use_responses_api):
            start_time = time.time()
            agent = LLMAgent(time_limit=1, max_try=3, use_responses_api=use_responses_api)
            res = agent("Write a 500 word essay about AI.", max_tokens=500)
            print(f" Result: {res}")
            print(f" Time taken: {time.time() - start_time:.2f}s")
        run_test("Test 3: time_limit=1, max_try=3 (expect retries)", test_retry)

        # 4. Test image_paths
        def test_image(use_responses_api):
            # Create dummy image (written to the current working directory)
            img = Image.new('RGB', (100, 100), color='red')
            img_path = "test_image.png"
            img.save(img_path)

            try:
                agent = LLMAgent(model_version='gpt-4o', max_try=1, use_responses_api=use_responses_api) # Assuming gpt-4o for vision
                res = agent("What color is this image?", image_paths=[img_path], max_tokens=20)
                print(f" Result: {res}")
            finally:
                # Always remove the temporary image, even if the call failed.
                if os.path.exists(img_path):
                    os.remove(img_path)
        run_test("Test 4: image_paths", test_image)

        # 5. Test history
        def test_history(use_responses_api):
            agent = LLMAgent(max_try=1, use_responses_api=use_responses_api)
            history = [
                {"role": "user", "content": "My name is Bob."},
                {"role": "assistant", "content": "Hello Bob."}
            ]
            res = agent("What is my name?", history=history, max_tokens=20)
            print(f" Result: {res}")
        run_test("Test 5: history", test_history)

        # 6. Test n=3
        def test_n_3(use_responses_api):
            agent = LLMAgent(max_try=1, use_responses_api=use_responses_api)
            res = agent("Generate a random number.", n=3, max_tokens=20)
            print(f" Result (len={len(res) if isinstance(res, list) else 'N/A'}): {res}")
        run_test("Test 6: n=3", test_n_3)

        # 7. Test return_example (list and dict)
        def test_return_example(use_responses_api):
            agent = LLMAgent(max_try=1, use_responses_api=use_responses_api)
            # List
            print(" - List:")
            res = agent("Return the list [1, 2, 3].", return_example=[1])
            print(f" Result: {res}")
            # Dict
            print(" - Dict:")
            res = agent("Return JSON {'a': 1}.", return_example={'a': 0})
            print(f" Result: {res}")
        run_test("Test 7: return_example", test_return_example)

        # 8. Test list_min
        def test_list_min(use_responses_api):
            agent = LLMAgent(max_try=1, use_responses_api=use_responses_api)
            res = agent("Return the list [10, 11, 12].", return_example=[1], list_min=5, list_max=20)
            print(f" Result: {res}")

            print(" - Testing failure case (list_min=20)...")
            res = agent("Return the list [10, 11, 12].", return_example=[1], list_min=20, max_try=2)
            print(f" Result (should be None): {res}")
        run_test("Test 8: list_min", test_list_min)

    else:
        print("Skipping LLMAgent test: LLM_API_KEY not set")

    # 9. Test Validation Logic with Mock (no network: llm_api is overridden)
    print("\nTesting Validation Logic with MockLLMAgent...")

    class MockLLMAgent(LLMAgent):
        def __init__(self, responses_map, **kwargs):
            # responses_map: query -> response string or list of strings
            super().__init__(api_key="dummy", **kwargs)
            self.responses_map = responses_map

        def llm_api(self, query, system_prompt=None, **kwargs):
            # Simple mock: return predefined response based on query
            # (unknown queries answer with "{}"); always returns a list.
            res = self.responses_map.get(query, "{}")
            if isinstance(res, str):
                return [res]
            return res

    # Define test cases
    mock_responses = {
        "dict_exact": '{"long_name": "Alice", "years_old": 25}',
        "dict_fuzzy": '{"long_nmae": "Alice", "years_oldd": 25}', # long_nmae->long_name (dist 2: swapped letters), years_oldd->years_old (dist 1)
        "dict_bad": '{"wrong_name": "Alice", "wrong_age": 25}',
        "dict_missing": '{"long_name": "Alice"}',
        "list_int": '[10, 20]',
        "list_str": '["10", "20"]',
        "list_len_3": '[1, 2, 3]',
        "list_len_2": '[1, 2]',
        "list_range_ok": '[1, 5, 9]',
        "list_range_bad_min": '[-1, 5]',
        "list_range_bad_max": '[1, 11]',
    }

    # Use max_try=1 so that failures return None immediately
    agent = MockLLMAgent(mock_responses, max_try=1)

    # 9.1 Dict Key Tests
    print(" - Dict Key Tests:")
    # Use keys with len > 5 to trigger Levenshtein check
    example_dict = {'long_name': 'John', 'years_old': 30}

    # Exact match
    res = agent("dict_exact", return_example=example_dict)
    assert res == {"long_name": "Alice", "years_old": 25}, f"[===ERROR===][structai][llm_api.py][main] Dict exact match failed: {res}"
    print(" [Pass] Exact match")

    # Fuzzy match (Levenshtein)
    res = agent("dict_fuzzy", return_example=example_dict)
    # Should be corrected to match keys
    assert res == {"long_name": "Alice", "years_old": 25}, f"[===ERROR===][structai][llm_api.py][main] Dict fuzzy match failed: {res}"
    print(" [Pass] Fuzzy match (Levenshtein)")

    # Bad keys
    res = agent("dict_bad", return_example=example_dict)
    assert res is None, f"[===ERROR===][structai][llm_api.py][main] Dict bad keys should fail: {res}"
    print(" [Pass] Bad keys")

    # Missing keys
    res = agent("dict_missing", return_example=example_dict)
    assert res is None, f"[===ERROR===][structai][llm_api.py][main] Dict missing keys should fail: {res}"
    print(" [Pass] Missing keys")

    # 9.2 List Type Tests
    print(" - List Type Tests:")
    example_list_int = [1]

    # Correct type
    res = agent("list_int", return_example=example_list_int)
    assert res == [10, 20], f"[===ERROR===][structai][llm_api.py][main] List type ok failed: {res}"
    print(" [Pass] Correct type")

    # Incorrect type
    res = agent("list_str", return_example=example_list_int)
    assert res is None, f"[===ERROR===][structai][llm_api.py][main] List type bad should fail: {res}"
    print(" [Pass] Incorrect type")

    # 9.3 List Length Tests
    print(" - List Length Tests:")

    # Correct length
    res = agent("list_len_3", return_example=[1], list_len=3)
    assert res == [1, 2, 3], f"[===ERROR===][structai][llm_api.py][main] List len ok failed: {res}"
    print(" [Pass] Correct length")

    # Incorrect length
    res = agent("list_len_2", return_example=[1], list_len=3)
    assert res is None, f"[===ERROR===][structai][llm_api.py][main] List len bad should fail: {res}"
    print(" [Pass] Incorrect length")

    # 9.4 List Range Tests
    print(" - List Range Tests:")

    # In range
    res = agent("list_range_ok", return_example=[1], list_min=0, list_max=10)
    assert res == [1, 5, 9], f"[===ERROR===][structai][llm_api.py][main] List range ok failed: {res}"
    print(" [Pass] In range")

    # Bad min
    res = agent("list_range_bad_min", return_example=[1], list_min=0)
    assert res is None, f"[===ERROR===][structai][llm_api.py][main] List range bad min should fail: {res}"
    print(" [Pass] Bad min")

    # Bad max
    res = agent("list_range_bad_max", return_example=[1], list_max=10)
    assert res is None, f"[===ERROR===][structai][llm_api.py][main] List range bad max should fail: {res}"
    print(" [Pass] Bad max")

    print("llm_api.py tests completed.")
    print("--------------------------------------------")
|