paddleocr-skills 1.0.0 → 1.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +220 -220
- package/bin/paddleocr-skills.js +33 -20
- package/lib/copy.js +39 -39
- package/lib/installer.js +76 -70
- package/lib/prompts.js +67 -67
- package/lib/python.js +75 -75
- package/lib/verify.js +121 -121
- package/package.json +42 -42
- package/templates/.env.example +12 -12
- package/templates/{paddleocr-vl/references/paddleocr-vl → paddleocr-vl-1.5/references/paddleocr-vl-1.5}/layout_schema.md +64 -64
- package/templates/{paddleocr-vl/references/paddleocr-vl → paddleocr-vl-1.5/references/paddleocr-vl-1.5}/output_format.md +154 -154
- package/templates/{paddleocr-vl/references/paddleocr-vl → paddleocr-vl-1.5/references/paddleocr-vl-1.5}/vl_model_spec.md +157 -157
- package/templates/{paddleocr-vl/scripts/paddleocr-vl → paddleocr-vl-1.5/scripts/paddleocr-vl-1.5}/_lib.py +780 -780
- package/templates/{paddleocr-vl/scripts/paddleocr-vl → paddleocr-vl-1.5/scripts/paddleocr-vl-1.5}/configure.py +270 -270
- package/templates/{paddleocr-vl/scripts/paddleocr-vl → paddleocr-vl-1.5/scripts/paddleocr-vl-1.5}/optimize_file.py +226 -226
- package/templates/{paddleocr-vl/scripts/paddleocr-vl → paddleocr-vl-1.5/scripts/paddleocr-vl-1.5}/requirements-optimize.txt +8 -8
- package/templates/{paddleocr-vl/scripts/paddleocr-vl → paddleocr-vl-1.5/scripts/paddleocr-vl-1.5}/requirements.txt +7 -7
- package/templates/{paddleocr-vl/scripts/paddleocr-vl → paddleocr-vl-1.5/scripts/paddleocr-vl-1.5}/smoke_test.py +199 -199
- package/templates/{paddleocr-vl/scripts/paddleocr-vl → paddleocr-vl-1.5/scripts/paddleocr-vl-1.5}/vl_caller.py +232 -232
- package/templates/{paddleocr-vl/skills/paddleocr-vl → paddleocr-vl-1.5/skills/paddleocr-vl-1.5}/SKILL.md +481 -481
- package/templates/ppocrv5/references/ppocrv5/agent_policy.md +258 -258
- package/templates/ppocrv5/references/ppocrv5/normalized_schema.md +257 -257
- package/templates/ppocrv5/references/ppocrv5/provider_api.md +140 -140
- package/templates/ppocrv5/scripts/ppocrv5/_lib.py +635 -635
- package/templates/ppocrv5/scripts/ppocrv5/configure.py +346 -346
- package/templates/ppocrv5/scripts/ppocrv5/ocr_caller.py +684 -684
- package/templates/ppocrv5/scripts/ppocrv5/requirements.txt +4 -4
- package/templates/ppocrv5/scripts/ppocrv5/smoke_test.py +139 -139
- package/templates/ppocrv5/skills/ppocrv5/SKILL.md +272 -272
|
@@ -1,635 +1,635 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Core library for PP-OCRv5 API Skill
|
|
3
|
-
- ProviderClient: HTTP client for Paddle AI Studio API
|
|
4
|
-
- Mapper: snake_case <-> camelCase conversion
|
|
5
|
-
- Normalizer: Provider response -> normalized output
|
|
6
|
-
- QualityEvaluator: Quality scoring for auto mode
|
|
7
|
-
"""
|
|
8
|
-
|
|
9
|
-
import hashlib
|
|
10
|
-
import json
|
|
11
|
-
import logging
|
|
12
|
-
import math
|
|
13
|
-
import os
|
|
14
|
-
import re
|
|
15
|
-
import sys
|
|
16
|
-
import time
|
|
17
|
-
from pathlib import Path
|
|
18
|
-
from typing import Any, Dict, List, Optional, Tuple
|
|
19
|
-
|
|
20
|
-
import httpx
|
|
21
|
-
|
|
22
|
-
logger = logging.getLogger(__name__)
|
|
23
|
-
|
|
24
|
-
# =============================================================================
|
|
25
|
-
# Constants
|
|
26
|
-
# =============================================================================
|
|
27
|
-
|
|
28
|
-
# Quality scoring weights
|
|
29
|
-
QUALITY_TEXT_COUNT_WEIGHT = 0.6
|
|
30
|
-
QUALITY_CONFIDENCE_WEIGHT = 0.4
|
|
31
|
-
QUALITY_THRESHOLD_DEFAULT = 0.72
|
|
32
|
-
|
|
33
|
-
# Retry and timeout
|
|
34
|
-
DEFAULT_TIMEOUT_MS = 25000
|
|
35
|
-
DEFAULT_MAX_RETRY = 2
|
|
36
|
-
DEFAULT_CACHE_TTL_SEC = 600
|
|
37
|
-
|
|
38
|
-
# Normalization constants
|
|
39
|
-
NORM_REFERENCE_COUNT = 50 # Reference count for normalization
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
# =============================================================================
|
|
43
|
-
# Environment & Config
|
|
44
|
-
# =============================================================================
|
|
45
|
-
|
|
46
|
-
class Config:
|
|
47
|
-
"""
|
|
48
|
-
Configuration manager - reads from .env file and environment variables
|
|
49
|
-
|
|
50
|
-
Priority order:
|
|
51
|
-
1. Environment variables (already set in shell)
|
|
52
|
-
2. .env file in project root
|
|
53
|
-
3. Raise error if not found
|
|
54
|
-
"""
|
|
55
|
-
|
|
56
|
-
_env_loaded = False
|
|
57
|
-
|
|
58
|
-
@staticmethod
|
|
59
|
-
def load_env():
|
|
60
|
-
"""Load .env file using python-dotenv"""
|
|
61
|
-
if Config._env_loaded:
|
|
62
|
-
return
|
|
63
|
-
|
|
64
|
-
try:
|
|
65
|
-
from dotenv import load_dotenv
|
|
66
|
-
|
|
67
|
-
# Find .env file (in project root, which is parent of scripts/)
|
|
68
|
-
project_root = Path(__file__).parent.parent
|
|
69
|
-
env_file = project_root / ".env"
|
|
70
|
-
|
|
71
|
-
if env_file.exists():
|
|
72
|
-
load_dotenv(env_file)
|
|
73
|
-
logger.debug(f"Loaded .env from {env_file}")
|
|
74
|
-
else:
|
|
75
|
-
logger.debug(f".env file not found at {env_file}")
|
|
76
|
-
|
|
77
|
-
Config._env_loaded = True
|
|
78
|
-
|
|
79
|
-
except ImportError:
|
|
80
|
-
logger.warning("python-dotenv not installed, skipping .env file loading")
|
|
81
|
-
logger.warning("Install with: pip install python-dotenv")
|
|
82
|
-
Config._env_loaded = True
|
|
83
|
-
|
|
84
|
-
@staticmethod
|
|
85
|
-
def get_api_url() -> str:
|
|
86
|
-
"""
|
|
87
|
-
Get API URL from environment.
|
|
88
|
-
|
|
89
|
-
Priority:
|
|
90
|
-
1. API_URL environment variable
|
|
91
|
-
2. AISTUDIO_HOST environment variable (legacy)
|
|
92
|
-
|
|
93
|
-
Returns:
|
|
94
|
-
Full API URL (https://...com/ocr)
|
|
95
|
-
"""
|
|
96
|
-
Config.load_env()
|
|
97
|
-
|
|
98
|
-
# Priority 1: Direct API_URL
|
|
99
|
-
api_url = os.getenv("API_URL", "").strip()
|
|
100
|
-
if api_url:
|
|
101
|
-
# Normalize: ensure it starts with https:// and ends with /ocr
|
|
102
|
-
api_url = re.sub(r'^https?://', '', api_url) # Remove protocol
|
|
103
|
-
api_url = re.sub(r'/ocr$', '', api_url) # Remove /ocr if exists
|
|
104
|
-
return f"https://{api_url}/ocr"
|
|
105
|
-
|
|
106
|
-
# Priority 2: Legacy AISTUDIO_HOST
|
|
107
|
-
host = os.getenv("AISTUDIO_HOST", "").strip()
|
|
108
|
-
if host:
|
|
109
|
-
host = Config.normalize_host(host)
|
|
110
|
-
return f"https://{host}/ocr"
|
|
111
|
-
|
|
112
|
-
# Not found
|
|
113
|
-
raise ValueError(
|
|
114
|
-
"API not configured. Get your API at: https://aistudio.baidu.com/paddleocr/task"
|
|
115
|
-
)
|
|
116
|
-
|
|
117
|
-
@staticmethod
|
|
118
|
-
def normalize_host(host: str) -> str:
|
|
119
|
-
"""
|
|
120
|
-
Normalize host to bare hostname without protocol or path.
|
|
121
|
-
Examples:
|
|
122
|
-
- your-subdomain.aistudio-app.com -> your-subdomain.aistudio-app.com
|
|
123
|
-
- https://your-subdomain.aistudio-app.com -> your-subdomain.aistudio-app.com
|
|
124
|
-
- https://your-subdomain.aistudio-app.com/ocr -> your-subdomain.aistudio-app.com
|
|
125
|
-
"""
|
|
126
|
-
# Remove http:// or https://
|
|
127
|
-
host = re.sub(r'^https?://', '', host)
|
|
128
|
-
# Remove trailing path (e.g., /ocr or /)
|
|
129
|
-
host = re.sub(r'/.*$', '', host)
|
|
130
|
-
return host.strip()
|
|
131
|
-
|
|
132
|
-
@staticmethod
|
|
133
|
-
def get_token() -> str:
|
|
134
|
-
"""
|
|
135
|
-
Get token from environment.
|
|
136
|
-
|
|
137
|
-
Priority:
|
|
138
|
-
1. PADDLE_OCR_TOKEN environment variable
|
|
139
|
-
2. PADDLE_OCR_TOKEN_FALLBACK key
|
|
140
|
-
3. COZE_PP_OCRV5_* prefix scan
|
|
141
|
-
"""
|
|
142
|
-
Config.load_env()
|
|
143
|
-
|
|
144
|
-
# Priority 1: Direct token
|
|
145
|
-
token = os.getenv("PADDLE_OCR_TOKEN", "").strip()
|
|
146
|
-
if token:
|
|
147
|
-
return token
|
|
148
|
-
|
|
149
|
-
# Priority 2: Fallback env key
|
|
150
|
-
fallback_key = os.getenv("PADDLE_OCR_TOKEN_FALLBACK", "").strip()
|
|
151
|
-
if fallback_key:
|
|
152
|
-
token = os.getenv(fallback_key, "").strip()
|
|
153
|
-
if token:
|
|
154
|
-
logger.info(f"Using token from fallback key: {fallback_key}")
|
|
155
|
-
return token
|
|
156
|
-
|
|
157
|
-
# Priority 3: Scan for COZE_PP_OCRV5_ prefix
|
|
158
|
-
for key, value in os.environ.items():
|
|
159
|
-
if key.startswith("COZE_PP_OCRV5_"):
|
|
160
|
-
logger.info(f"Using token from auto-detected key: {key}")
|
|
161
|
-
return value.strip()
|
|
162
|
-
|
|
163
|
-
# Not found
|
|
164
|
-
raise ValueError(
|
|
165
|
-
"TOKEN not configured. Get your API at: https://aistudio.baidu.com/paddleocr/task"
|
|
166
|
-
)
|
|
167
|
-
|
|
168
|
-
@staticmethod
|
|
169
|
-
def get_timeout_ms() -> int:
|
|
170
|
-
return int(os.getenv("PADDLE_OCR_TIMEOUT_MS", str(DEFAULT_TIMEOUT_MS)))
|
|
171
|
-
|
|
172
|
-
@staticmethod
|
|
173
|
-
def get_max_retry() -> int:
|
|
174
|
-
return int(os.getenv("PADDLE_OCR_MAX_RETRY", str(DEFAULT_MAX_RETRY)))
|
|
175
|
-
|
|
176
|
-
@staticmethod
|
|
177
|
-
def get_cache_ttl_sec() -> int:
|
|
178
|
-
return int(os.getenv("PADDLE_OCR_CACHE_TTL_SEC", str(DEFAULT_CACHE_TTL_SEC)))
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
# =============================================================================
|
|
182
|
-
# Mapper: snake_case <-> camelCase
|
|
183
|
-
# =============================================================================
|
|
184
|
-
|
|
185
|
-
class Mapper:
|
|
186
|
-
"""Convert between snake_case (Python) and camelCase (Provider API)"""
|
|
187
|
-
|
|
188
|
-
@staticmethod
|
|
189
|
-
def snake_to_camel(name: str) -> str:
|
|
190
|
-
"""Convert snake_case to camelCase"""
|
|
191
|
-
components = name.split('_')
|
|
192
|
-
return components[0] + ''.join(x.title() for x in components[1:])
|
|
193
|
-
|
|
194
|
-
@staticmethod
|
|
195
|
-
def camel_to_snake(name: str) -> str:
|
|
196
|
-
"""Convert camelCase to snake_case"""
|
|
197
|
-
s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
|
|
198
|
-
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
|
|
199
|
-
|
|
200
|
-
@staticmethod
|
|
201
|
-
def dict_to_camel(data: Dict[str, Any]) -> Dict[str, Any]:
|
|
202
|
-
"""Convert dict keys from snake_case to camelCase, drop None values"""
|
|
203
|
-
result = {}
|
|
204
|
-
for k, v in data.items():
|
|
205
|
-
if v is None:
|
|
206
|
-
continue # Drop None values
|
|
207
|
-
camel_key = Mapper.snake_to_camel(k)
|
|
208
|
-
result[camel_key] = v
|
|
209
|
-
return result
|
|
210
|
-
|
|
211
|
-
@staticmethod
|
|
212
|
-
def dict_to_snake(data: Dict[str, Any]) -> Dict[str, Any]:
|
|
213
|
-
"""Convert dict keys from camelCase to snake_case"""
|
|
214
|
-
result = {}
|
|
215
|
-
for k, v in data.items():
|
|
216
|
-
snake_key = Mapper.camel_to_snake(k)
|
|
217
|
-
result[snake_key] = v
|
|
218
|
-
return result
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
# =============================================================================
|
|
222
|
-
# Quality Evaluator
|
|
223
|
-
# =============================================================================
|
|
224
|
-
|
|
225
|
-
class QualityEvaluator:
|
|
226
|
-
"""Evaluate OCR quality based on text items and scores"""
|
|
227
|
-
|
|
228
|
-
@staticmethod
|
|
229
|
-
def norm(n: int, max_n: int = NORM_REFERENCE_COUNT) -> float:
|
|
230
|
-
"""Normalize text count: min(1, log(1+n)/log(1+max_n))"""
|
|
231
|
-
if n <= 0:
|
|
232
|
-
return 0.0
|
|
233
|
-
return min(1.0, math.log(1 + n) / math.log(1 + max_n))
|
|
234
|
-
|
|
235
|
-
@staticmethod
|
|
236
|
-
def evaluate(rec_texts: List[str], rec_scores: Optional[List[float]] = None) -> Dict[str, Any]:
|
|
237
|
-
"""
|
|
238
|
-
Evaluate quality from provider's prunedResult
|
|
239
|
-
Returns: {
|
|
240
|
-
"quality_score": float,
|
|
241
|
-
"avg_rec_score": float,
|
|
242
|
-
"text_items": int,
|
|
243
|
-
"warnings": List[str]
|
|
244
|
-
}
|
|
245
|
-
"""
|
|
246
|
-
text_items = len(rec_texts) if rec_texts else 0
|
|
247
|
-
warnings = []
|
|
248
|
-
|
|
249
|
-
# Average recognition score
|
|
250
|
-
if rec_scores and len(rec_scores) > 0:
|
|
251
|
-
avg_rec_score = sum(rec_scores) / len(rec_scores)
|
|
252
|
-
else:
|
|
253
|
-
avg_rec_score = 0.5 # Default if missing
|
|
254
|
-
if text_items > 0:
|
|
255
|
-
warnings.append("rec_scores missing, using default 0.5")
|
|
256
|
-
|
|
257
|
-
# Quality score: weighted combination of text count and confidence
|
|
258
|
-
if text_items == 0:
|
|
259
|
-
quality_score = 0.0
|
|
260
|
-
warnings.append("No text items detected")
|
|
261
|
-
else:
|
|
262
|
-
norm_count = QualityEvaluator.norm(text_items)
|
|
263
|
-
quality_score = QUALITY_TEXT_COUNT_WEIGHT * norm_count + QUALITY_CONFIDENCE_WEIGHT * avg_rec_score
|
|
264
|
-
|
|
265
|
-
return {
|
|
266
|
-
"quality_score": round(quality_score, 4),
|
|
267
|
-
"avg_rec_score": round(avg_rec_score, 4),
|
|
268
|
-
"text_items": text_items,
|
|
269
|
-
"warnings": warnings
|
|
270
|
-
}
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
# =============================================================================
|
|
274
|
-
# Normalizer
|
|
275
|
-
# =============================================================================
|
|
276
|
-
|
|
277
|
-
class Normalizer:
|
|
278
|
-
"""Normalize provider response to unified output schema"""
|
|
279
|
-
|
|
280
|
-
@staticmethod
|
|
281
|
-
def normalize_response(
|
|
282
|
-
provider_response: Dict[str, Any],
|
|
283
|
-
request_id: str,
|
|
284
|
-
api_url: str,
|
|
285
|
-
status_code: int,
|
|
286
|
-
mode: str,
|
|
287
|
-
selected_attempt: int,
|
|
288
|
-
attempts_history: List[Dict[str, Any]],
|
|
289
|
-
return_raw: bool = False
|
|
290
|
-
) -> Dict[str, Any]:
|
|
291
|
-
"""
|
|
292
|
-
Convert provider response to normalized output.
|
|
293
|
-
|
|
294
|
-
Args:
|
|
295
|
-
provider_response: Raw response from provider
|
|
296
|
-
request_id: Unique request ID
|
|
297
|
-
api_url: API endpoint used
|
|
298
|
-
status_code: HTTP status code
|
|
299
|
-
mode: fast/quality/auto
|
|
300
|
-
selected_attempt: Which attempt was selected (1-indexed)
|
|
301
|
-
attempts_history: List of attempt details
|
|
302
|
-
return_raw: Whether to include raw_provider in output
|
|
303
|
-
|
|
304
|
-
Returns:
|
|
305
|
-
Normalized output dict
|
|
306
|
-
"""
|
|
307
|
-
error_code = provider_response.get("errorCode", -1)
|
|
308
|
-
|
|
309
|
-
if error_code != 0:
|
|
310
|
-
# Error response
|
|
311
|
-
error_msg = provider_response.get("errorMsg", "Unknown error")
|
|
312
|
-
return {
|
|
313
|
-
"ok": False,
|
|
314
|
-
"request_id": request_id,
|
|
315
|
-
"provider": {
|
|
316
|
-
"api_url": api_url,
|
|
317
|
-
"status_code": status_code,
|
|
318
|
-
"log_id": provider_response.get("logId")
|
|
319
|
-
},
|
|
320
|
-
"result": None,
|
|
321
|
-
"quality": None,
|
|
322
|
-
"agent_trace": {
|
|
323
|
-
"mode": mode,
|
|
324
|
-
"selected_attempt": selected_attempt,
|
|
325
|
-
"attempts": attempts_history
|
|
326
|
-
},
|
|
327
|
-
"raw_provider": provider_response if return_raw else None,
|
|
328
|
-
"error": {
|
|
329
|
-
"code": Normalizer._map_error_code(error_code, status_code),
|
|
330
|
-
"message": error_msg,
|
|
331
|
-
"details": {
|
|
332
|
-
"error_code": error_code,
|
|
333
|
-
"status_code": status_code
|
|
334
|
-
}
|
|
335
|
-
}
|
|
336
|
-
}
|
|
337
|
-
|
|
338
|
-
# Success response
|
|
339
|
-
result = provider_response.get("result", {})
|
|
340
|
-
ocr_results = result.get("ocrResults", [])
|
|
341
|
-
|
|
342
|
-
pages = []
|
|
343
|
-
all_texts = []
|
|
344
|
-
total_items = 0
|
|
345
|
-
total_scores_sum = 0.0
|
|
346
|
-
total_scores_count = 0
|
|
347
|
-
|
|
348
|
-
for page_idx, ocr_res in enumerate(ocr_results):
|
|
349
|
-
pruned = ocr_res.get("prunedResult", {})
|
|
350
|
-
rec_texts = pruned.get("rec_texts", [])
|
|
351
|
-
rec_scores = pruned.get("rec_scores", [])
|
|
352
|
-
rec_boxes = pruned.get("rec_boxes", [])
|
|
353
|
-
rec_polys = pruned.get("rec_polys", [])
|
|
354
|
-
|
|
355
|
-
items = []
|
|
356
|
-
page_text_lines = []
|
|
357
|
-
|
|
358
|
-
for i, text in enumerate(rec_texts):
|
|
359
|
-
score = rec_scores[i] if i < len(rec_scores) else None
|
|
360
|
-
box = None
|
|
361
|
-
if i < len(rec_boxes):
|
|
362
|
-
box = rec_boxes[i]
|
|
363
|
-
elif i < len(rec_polys):
|
|
364
|
-
# Flatten polygon to box (simplified)
|
|
365
|
-
box = rec_polys[i]
|
|
366
|
-
|
|
367
|
-
item = {"text": text}
|
|
368
|
-
if score is not None:
|
|
369
|
-
item["score"] = round(score, 4)
|
|
370
|
-
total_scores_sum += score
|
|
371
|
-
total_scores_count += 1
|
|
372
|
-
if box is not None:
|
|
373
|
-
item["box"] = box
|
|
374
|
-
|
|
375
|
-
items.append(item)
|
|
376
|
-
page_text_lines.append(text)
|
|
377
|
-
total_items += 1
|
|
378
|
-
|
|
379
|
-
page_text = "\n".join(page_text_lines)
|
|
380
|
-
all_texts.append(page_text)
|
|
381
|
-
|
|
382
|
-
page_avg_conf = 0.0
|
|
383
|
-
if total_scores_count > 0:
|
|
384
|
-
page_avg_conf = total_scores_sum / total_scores_count
|
|
385
|
-
|
|
386
|
-
pages.append({
|
|
387
|
-
"page_index": page_idx,
|
|
388
|
-
"text": page_text,
|
|
389
|
-
"avg_confidence": round(page_avg_conf, 4) if items else 0.0,
|
|
390
|
-
"items": items
|
|
391
|
-
})
|
|
392
|
-
|
|
393
|
-
full_text = "\n\n".join(all_texts)
|
|
394
|
-
|
|
395
|
-
# Get quality from last attempt (selected one)
|
|
396
|
-
quality_info = None
|
|
397
|
-
if attempts_history and selected_attempt <= len(attempts_history):
|
|
398
|
-
last_attempt = attempts_history[selected_attempt - 1]
|
|
399
|
-
quality_info = {
|
|
400
|
-
"quality_score": last_attempt.get("quality_score", 0.0),
|
|
401
|
-
"avg_rec_score": last_attempt.get("avg_rec_score", 0.0),
|
|
402
|
-
"text_items": total_items,
|
|
403
|
-
"warnings": last_attempt.get("warnings", [])
|
|
404
|
-
}
|
|
405
|
-
else:
|
|
406
|
-
# Fallback: compute quality on the fly
|
|
407
|
-
quality_info = QualityEvaluator.evaluate(
|
|
408
|
-
rec_texts=[item["text"] for page in pages for item in page["items"]],
|
|
409
|
-
rec_scores=[item.get("score") for page in pages for item in page["items"] if "score" in item]
|
|
410
|
-
)
|
|
411
|
-
|
|
412
|
-
return {
|
|
413
|
-
"ok": True,
|
|
414
|
-
"request_id": request_id,
|
|
415
|
-
"provider": {
|
|
416
|
-
"api_url": api_url,
|
|
417
|
-
"status_code": status_code,
|
|
418
|
-
"log_id": provider_response.get("logId")
|
|
419
|
-
},
|
|
420
|
-
"result": {
|
|
421
|
-
"pages": pages,
|
|
422
|
-
"full_text": full_text
|
|
423
|
-
},
|
|
424
|
-
"quality": quality_info,
|
|
425
|
-
"agent_trace": {
|
|
426
|
-
"mode": mode,
|
|
427
|
-
"selected_attempt": selected_attempt,
|
|
428
|
-
"attempts": attempts_history
|
|
429
|
-
},
|
|
430
|
-
"raw_provider": provider_response if return_raw else None,
|
|
431
|
-
"error": None
|
|
432
|
-
}
|
|
433
|
-
|
|
434
|
-
@staticmethod
|
|
435
|
-
def _map_error_code(error_code: int, status_code: int) -> str:
|
|
436
|
-
"""Map provider error code to unified error code"""
|
|
437
|
-
if status_code == 403:
|
|
438
|
-
return "PROVIDER_AUTH_ERROR"
|
|
439
|
-
elif status_code == 429:
|
|
440
|
-
return "PROVIDER_QUOTA_EXCEEDED"
|
|
441
|
-
elif status_code == 503:
|
|
442
|
-
return "PROVIDER_OVERLOADED"
|
|
443
|
-
elif status_code == 504:
|
|
444
|
-
return "PROVIDER_TIMEOUT"
|
|
445
|
-
elif error_code == 500:
|
|
446
|
-
return "PROVIDER_BAD_REQUEST"
|
|
447
|
-
else:
|
|
448
|
-
return "PROVIDER_ERROR"
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
# =============================================================================
|
|
452
|
-
# Provider Client
|
|
453
|
-
# =============================================================================
|
|
454
|
-
|
|
455
|
-
class ProviderClient:
|
|
456
|
-
"""HTTP client for Paddle AI Studio PP-OCRv5 API"""
|
|
457
|
-
|
|
458
|
-
def __init__(
|
|
459
|
-
self,
|
|
460
|
-
api_url: str,
|
|
461
|
-
token: str,
|
|
462
|
-
timeout_ms: int = 25000,
|
|
463
|
-
max_retry: int = 2
|
|
464
|
-
):
|
|
465
|
-
self.api_url = api_url
|
|
466
|
-
self.token = token
|
|
467
|
-
self.timeout_ms = timeout_ms
|
|
468
|
-
self.max_retry = max_retry
|
|
469
|
-
self.client = httpx.Client(timeout=timeout_ms / 1000.0)
|
|
470
|
-
|
|
471
|
-
def call(self, payload: Dict[str, Any]) -> Tuple[Dict[str, Any], int, float]:
|
|
472
|
-
"""
|
|
473
|
-
Call provider API with retry on 503/504.
|
|
474
|
-
|
|
475
|
-
Returns:
|
|
476
|
-
(response_json, status_code, elapsed_ms)
|
|
477
|
-
"""
|
|
478
|
-
headers = {
|
|
479
|
-
"Authorization": f"token {self.token}",
|
|
480
|
-
"Content-Type": "application/json"
|
|
481
|
-
}
|
|
482
|
-
|
|
483
|
-
attempt = 0
|
|
484
|
-
while attempt <= self.max_retry:
|
|
485
|
-
start_time = time.time()
|
|
486
|
-
try:
|
|
487
|
-
resp = self.client.post(self.api_url, json=payload, headers=headers)
|
|
488
|
-
elapsed_ms = (time.time() - start_time) * 1000
|
|
489
|
-
|
|
490
|
-
# Parse response
|
|
491
|
-
try:
|
|
492
|
-
resp_json = resp.json()
|
|
493
|
-
except (json.JSONDecodeError, ValueError) as e:
|
|
494
|
-
logger.warning(f"Failed to parse JSON response: {e}")
|
|
495
|
-
resp_json = {"errorCode": -1, "errorMsg": "Invalid JSON response"}
|
|
496
|
-
|
|
497
|
-
# Retry on 503/504
|
|
498
|
-
if resp.status_code in [503, 504] and attempt < self.max_retry:
|
|
499
|
-
logger.warning(f"Attempt {attempt + 1} failed with {resp.status_code}, retrying...")
|
|
500
|
-
backoff_ms = 200 * (4 ** attempt) + (hash(str(time.time())) % 100)
|
|
501
|
-
time.sleep(backoff_ms / 1000.0)
|
|
502
|
-
attempt += 1
|
|
503
|
-
continue
|
|
504
|
-
|
|
505
|
-
return resp_json, resp.status_code, elapsed_ms
|
|
506
|
-
|
|
507
|
-
except httpx.TimeoutException:
|
|
508
|
-
elapsed_ms = (time.time() - start_time) * 1000
|
|
509
|
-
if attempt < self.max_retry:
|
|
510
|
-
logger.warning(f"Attempt {attempt + 1} timed out, retrying...")
|
|
511
|
-
backoff_ms = 200 * (4 ** attempt) + (hash(str(time.time())) % 100)
|
|
512
|
-
time.sleep(backoff_ms / 1000.0)
|
|
513
|
-
attempt += 1
|
|
514
|
-
continue
|
|
515
|
-
else:
|
|
516
|
-
return {
|
|
517
|
-
"errorCode": 504,
|
|
518
|
-
"errorMsg": "Request timed out"
|
|
519
|
-
}, 504, elapsed_ms
|
|
520
|
-
|
|
521
|
-
except Exception as e:
|
|
522
|
-
elapsed_ms = (time.time() - start_time) * 1000
|
|
523
|
-
logger.error(f"Request failed: {e}")
|
|
524
|
-
return {
|
|
525
|
-
"errorCode": -1,
|
|
526
|
-
"errorMsg": f"Request failed: {str(e)}"
|
|
527
|
-
}, 500, elapsed_ms
|
|
528
|
-
|
|
529
|
-
# Should not reach here
|
|
530
|
-
return {"errorCode": -1, "errorMsg": "Max retries exceeded"}, 500, 0.0
|
|
531
|
-
|
|
532
|
-
def close(self):
|
|
533
|
-
"""Close HTTP client"""
|
|
534
|
-
self.client.close()
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
# =============================================================================
|
|
538
|
-
# Cache
|
|
539
|
-
# =============================================================================
|
|
540
|
-
|
|
541
|
-
class SimpleCache:
|
|
542
|
-
"""In-memory TTL cache for normalized results"""
|
|
543
|
-
|
|
544
|
-
def __init__(self, ttl_sec: int = 600):
|
|
545
|
-
self.ttl_sec = ttl_sec
|
|
546
|
-
self._cache: Dict[str, Tuple[Any, float]] = {}
|
|
547
|
-
|
|
548
|
-
def get(self, key: str) -> Optional[Any]:
|
|
549
|
-
"""Get cached value if not expired"""
|
|
550
|
-
if key in self._cache:
|
|
551
|
-
value, expiry = self._cache[key]
|
|
552
|
-
if time.time() < expiry:
|
|
553
|
-
return value
|
|
554
|
-
else:
|
|
555
|
-
del self._cache[key]
|
|
556
|
-
return None
|
|
557
|
-
|
|
558
|
-
def set(self, key: str, value: Any):
|
|
559
|
-
"""Set cache value with TTL"""
|
|
560
|
-
expiry = time.time() + self.ttl_sec
|
|
561
|
-
self._cache[key] = (value, expiry)
|
|
562
|
-
|
|
563
|
-
@staticmethod
|
|
564
|
-
def make_key(file_input: str, options: Dict[str, Any]) -> str:
|
|
565
|
-
"""
|
|
566
|
-
Generate cache key from file and options.
|
|
567
|
-
For performance, only hash first 1KB of large inputs.
|
|
568
|
-
"""
|
|
569
|
-
# For large inputs (base64 encoded files), only hash first 1KB
|
|
570
|
-
input_sample = file_input[:1024] if len(file_input) > 1024 else file_input
|
|
571
|
-
file_hash = hashlib.sha256(input_sample.encode()).hexdigest()[:16]
|
|
572
|
-
|
|
573
|
-
options_str = json.dumps(options, sort_keys=True)
|
|
574
|
-
options_hash = hashlib.sha256(options_str.encode()).hexdigest()[:16]
|
|
575
|
-
return f"{file_hash}_{options_hash}"
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
# =============================================================================
|
|
579
|
-
# Agent Policy
|
|
580
|
-
# =============================================================================
|
|
581
|
-
|
|
582
|
-
class AgentPolicy:
|
|
583
|
-
"""Generate attempt strategies for auto mode"""
|
|
584
|
-
|
|
585
|
-
@staticmethod
|
|
586
|
-
def get_attempts_config(mode: str, max_attempts: int = 3) -> List[Dict[str, Any]]:
|
|
587
|
-
"""
|
|
588
|
-
Get list of attempt configurations based on mode.
|
|
589
|
-
|
|
590
|
-
Args:
|
|
591
|
-
mode: 'fast', 'quality', or 'auto'
|
|
592
|
-
max_attempts: Max attempts for auto mode
|
|
593
|
-
|
|
594
|
-
Returns:
|
|
595
|
-
List of option dicts
|
|
596
|
-
"""
|
|
597
|
-
if mode == "fast":
|
|
598
|
-
return [{
|
|
599
|
-
"use_doc_orientation_classify": False,
|
|
600
|
-
"use_doc_unwarping": False,
|
|
601
|
-
"use_textline_orientation": False
|
|
602
|
-
}]
|
|
603
|
-
|
|
604
|
-
elif mode == "quality":
|
|
605
|
-
return [{
|
|
606
|
-
"use_doc_orientation_classify": True,
|
|
607
|
-
"use_doc_unwarping": True,
|
|
608
|
-
"use_textline_orientation": False
|
|
609
|
-
}]
|
|
610
|
-
|
|
611
|
-
elif mode == "auto":
|
|
612
|
-
attempts = [
|
|
613
|
-
# Attempt 1: fast path
|
|
614
|
-
{
|
|
615
|
-
"use_doc_orientation_classify": False,
|
|
616
|
-
"use_doc_unwarping": False,
|
|
617
|
-
"use_textline_orientation": False
|
|
618
|
-
},
|
|
619
|
-
# Attempt 2: orientation fix
|
|
620
|
-
{
|
|
621
|
-
"use_doc_orientation_classify": True,
|
|
622
|
-
"use_doc_unwarping": False,
|
|
623
|
-
"use_textline_orientation": False
|
|
624
|
-
},
|
|
625
|
-
# Attempt 3: unwarping fix
|
|
626
|
-
{
|
|
627
|
-
"use_doc_orientation_classify": True,
|
|
628
|
-
"use_doc_unwarping": True,
|
|
629
|
-
"use_textline_orientation": False
|
|
630
|
-
}
|
|
631
|
-
]
|
|
632
|
-
return attempts[:max_attempts]
|
|
633
|
-
|
|
634
|
-
else:
|
|
635
|
-
raise ValueError(f"Unknown mode: {mode}")
|
|
1
|
+
"""
|
|
2
|
+
Core library for PP-OCRv5 API Skill
|
|
3
|
+
- ProviderClient: HTTP client for Paddle AI Studio API
|
|
4
|
+
- Mapper: snake_case <-> camelCase conversion
|
|
5
|
+
- Normalizer: Provider response -> normalized output
|
|
6
|
+
- QualityEvaluator: Quality scoring for auto mode
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import hashlib
|
|
10
|
+
import json
|
|
11
|
+
import logging
|
|
12
|
+
import math
|
|
13
|
+
import os
|
|
14
|
+
import re
|
|
15
|
+
import sys
|
|
16
|
+
import time
|
|
17
|
+
from pathlib import Path
|
|
18
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
19
|
+
|
|
20
|
+
import httpx
|
|
21
|
+
|
|
22
|
+
logger = logging.getLogger(__name__)
|
|
23
|
+
|
|
24
|
+
# =============================================================================
|
|
25
|
+
# Constants
|
|
26
|
+
# =============================================================================
|
|
27
|
+
|
|
28
|
+
# Quality scoring weights
|
|
29
|
+
QUALITY_TEXT_COUNT_WEIGHT = 0.6
|
|
30
|
+
QUALITY_CONFIDENCE_WEIGHT = 0.4
|
|
31
|
+
QUALITY_THRESHOLD_DEFAULT = 0.72
|
|
32
|
+
|
|
33
|
+
# Retry and timeout
|
|
34
|
+
DEFAULT_TIMEOUT_MS = 25000
|
|
35
|
+
DEFAULT_MAX_RETRY = 2
|
|
36
|
+
DEFAULT_CACHE_TTL_SEC = 600
|
|
37
|
+
|
|
38
|
+
# Normalization constants
|
|
39
|
+
NORM_REFERENCE_COUNT = 50 # Reference count for normalization
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# =============================================================================
|
|
43
|
+
# Environment & Config
|
|
44
|
+
# =============================================================================
|
|
45
|
+
|
|
46
|
+
class Config:
|
|
47
|
+
"""
|
|
48
|
+
Configuration manager - reads from .env file and environment variables
|
|
49
|
+
|
|
50
|
+
Priority order:
|
|
51
|
+
1. Environment variables (already set in shell)
|
|
52
|
+
2. .env file in project root
|
|
53
|
+
3. Raise error if not found
|
|
54
|
+
"""
|
|
55
|
+
|
|
56
|
+
_env_loaded = False
|
|
57
|
+
|
|
58
|
+
@staticmethod
|
|
59
|
+
def load_env():
|
|
60
|
+
"""Load .env file using python-dotenv"""
|
|
61
|
+
if Config._env_loaded:
|
|
62
|
+
return
|
|
63
|
+
|
|
64
|
+
try:
|
|
65
|
+
from dotenv import load_dotenv
|
|
66
|
+
|
|
67
|
+
# Find .env file (in project root, which is parent of scripts/)
|
|
68
|
+
project_root = Path(__file__).parent.parent
|
|
69
|
+
env_file = project_root / ".env"
|
|
70
|
+
|
|
71
|
+
if env_file.exists():
|
|
72
|
+
load_dotenv(env_file)
|
|
73
|
+
logger.debug(f"Loaded .env from {env_file}")
|
|
74
|
+
else:
|
|
75
|
+
logger.debug(f".env file not found at {env_file}")
|
|
76
|
+
|
|
77
|
+
Config._env_loaded = True
|
|
78
|
+
|
|
79
|
+
except ImportError:
|
|
80
|
+
logger.warning("python-dotenv not installed, skipping .env file loading")
|
|
81
|
+
logger.warning("Install with: pip install python-dotenv")
|
|
82
|
+
Config._env_loaded = True
|
|
83
|
+
|
|
84
|
+
@staticmethod
|
|
85
|
+
def get_api_url() -> str:
|
|
86
|
+
"""
|
|
87
|
+
Get API URL from environment.
|
|
88
|
+
|
|
89
|
+
Priority:
|
|
90
|
+
1. API_URL environment variable
|
|
91
|
+
2. AISTUDIO_HOST environment variable (legacy)
|
|
92
|
+
|
|
93
|
+
Returns:
|
|
94
|
+
Full API URL (https://...com/ocr)
|
|
95
|
+
"""
|
|
96
|
+
Config.load_env()
|
|
97
|
+
|
|
98
|
+
# Priority 1: Direct API_URL
|
|
99
|
+
api_url = os.getenv("API_URL", "").strip()
|
|
100
|
+
if api_url:
|
|
101
|
+
# Normalize: ensure it starts with https:// and ends with /ocr
|
|
102
|
+
api_url = re.sub(r'^https?://', '', api_url) # Remove protocol
|
|
103
|
+
api_url = re.sub(r'/ocr$', '', api_url) # Remove /ocr if exists
|
|
104
|
+
return f"https://{api_url}/ocr"
|
|
105
|
+
|
|
106
|
+
# Priority 2: Legacy AISTUDIO_HOST
|
|
107
|
+
host = os.getenv("AISTUDIO_HOST", "").strip()
|
|
108
|
+
if host:
|
|
109
|
+
host = Config.normalize_host(host)
|
|
110
|
+
return f"https://{host}/ocr"
|
|
111
|
+
|
|
112
|
+
# Not found
|
|
113
|
+
raise ValueError(
|
|
114
|
+
"API not configured. Get your API at: https://aistudio.baidu.com/paddleocr/task"
|
|
115
|
+
)
|
|
116
|
+
|
|
117
|
+
@staticmethod
|
|
118
|
+
def normalize_host(host: str) -> str:
|
|
119
|
+
"""
|
|
120
|
+
Normalize host to bare hostname without protocol or path.
|
|
121
|
+
Examples:
|
|
122
|
+
- your-subdomain.aistudio-app.com -> your-subdomain.aistudio-app.com
|
|
123
|
+
- https://your-subdomain.aistudio-app.com -> your-subdomain.aistudio-app.com
|
|
124
|
+
- https://your-subdomain.aistudio-app.com/ocr -> your-subdomain.aistudio-app.com
|
|
125
|
+
"""
|
|
126
|
+
# Remove http:// or https://
|
|
127
|
+
host = re.sub(r'^https?://', '', host)
|
|
128
|
+
# Remove trailing path (e.g., /ocr or /)
|
|
129
|
+
host = re.sub(r'/.*$', '', host)
|
|
130
|
+
return host.strip()
|
|
131
|
+
|
|
132
|
+
@staticmethod
|
|
133
|
+
def get_token() -> str:
|
|
134
|
+
"""
|
|
135
|
+
Get token from environment.
|
|
136
|
+
|
|
137
|
+
Priority:
|
|
138
|
+
1. PADDLE_OCR_TOKEN environment variable
|
|
139
|
+
2. PADDLE_OCR_TOKEN_FALLBACK key
|
|
140
|
+
3. COZE_PP_OCRV5_* prefix scan
|
|
141
|
+
"""
|
|
142
|
+
Config.load_env()
|
|
143
|
+
|
|
144
|
+
# Priority 1: Direct token
|
|
145
|
+
token = os.getenv("PADDLE_OCR_TOKEN", "").strip()
|
|
146
|
+
if token:
|
|
147
|
+
return token
|
|
148
|
+
|
|
149
|
+
# Priority 2: Fallback env key
|
|
150
|
+
fallback_key = os.getenv("PADDLE_OCR_TOKEN_FALLBACK", "").strip()
|
|
151
|
+
if fallback_key:
|
|
152
|
+
token = os.getenv(fallback_key, "").strip()
|
|
153
|
+
if token:
|
|
154
|
+
logger.info(f"Using token from fallback key: {fallback_key}")
|
|
155
|
+
return token
|
|
156
|
+
|
|
157
|
+
# Priority 3: Scan for COZE_PP_OCRV5_ prefix
|
|
158
|
+
for key, value in os.environ.items():
|
|
159
|
+
if key.startswith("COZE_PP_OCRV5_"):
|
|
160
|
+
logger.info(f"Using token from auto-detected key: {key}")
|
|
161
|
+
return value.strip()
|
|
162
|
+
|
|
163
|
+
# Not found
|
|
164
|
+
raise ValueError(
|
|
165
|
+
"TOKEN not configured. Get your API at: https://aistudio.baidu.com/paddleocr/task"
|
|
166
|
+
)
|
|
167
|
+
|
|
168
|
+
@staticmethod
def get_timeout_ms() -> int:
    """Request timeout in milliseconds (env override: PADDLE_OCR_TIMEOUT_MS)."""
    raw = os.getenv("PADDLE_OCR_TIMEOUT_MS", str(DEFAULT_TIMEOUT_MS))
    return int(raw)
|
|
171
|
+
|
|
172
|
+
@staticmethod
def get_max_retry() -> int:
    """Maximum retry count (env override: PADDLE_OCR_MAX_RETRY)."""
    raw = os.getenv("PADDLE_OCR_MAX_RETRY", str(DEFAULT_MAX_RETRY))
    return int(raw)
|
|
175
|
+
|
|
176
|
+
@staticmethod
def get_cache_ttl_sec() -> int:
    """Cache TTL in seconds (env override: PADDLE_OCR_CACHE_TTL_SEC)."""
    raw = os.getenv("PADDLE_OCR_CACHE_TTL_SEC", str(DEFAULT_CACHE_TTL_SEC))
    return int(raw)
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
# =============================================================================
|
|
182
|
+
# Mapper: snake_case <-> camelCase
|
|
183
|
+
# =============================================================================
|
|
184
|
+
|
|
185
|
+
class Mapper:
    """Convert between snake_case (Python) and camelCase (Provider API)"""

    @staticmethod
    def snake_to_camel(name: str) -> str:
        """Convert snake_case to camelCase"""
        head, *tail = name.split('_')
        return head + ''.join(segment.title() for segment in tail)

    @staticmethod
    def camel_to_snake(name: str) -> str:
        """Convert camelCase to snake_case"""
        # First pass separates runs like "HTTPResponse" -> "HTTP_Response",
        # second pass splits plain lower->upper boundaries; lowercase at the end.
        spaced = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
        return re.sub('([a-z0-9])([A-Z])', r'\1_\2', spaced).lower()

    @staticmethod
    def dict_to_camel(data: Dict[str, Any]) -> Dict[str, Any]:
        """Convert dict keys from snake_case to camelCase, drop None values"""
        return {
            Mapper.snake_to_camel(key): value
            for key, value in data.items()
            if value is not None  # None means "unset" — never send it upstream
        }

    @staticmethod
    def dict_to_snake(data: Dict[str, Any]) -> Dict[str, Any]:
        """Convert dict keys from camelCase to snake_case"""
        return {Mapper.camel_to_snake(key): value for key, value in data.items()}
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
# =============================================================================
|
|
222
|
+
# Quality Evaluator
|
|
223
|
+
# =============================================================================
|
|
224
|
+
|
|
225
|
+
class QualityEvaluator:
    """Evaluate OCR quality based on text items and scores"""

    @staticmethod
    def norm(n: int, max_n: int = NORM_REFERENCE_COUNT) -> float:
        """Normalize text count: min(1, log(1+n)/log(1+max_n))"""
        if n <= 0:
            return 0.0
        scaled = math.log(1 + n) / math.log(1 + max_n)
        return min(1.0, scaled)

    @staticmethod
    def evaluate(rec_texts: List[str], rec_scores: Optional[List[float]] = None) -> Dict[str, Any]:
        """
        Evaluate quality from provider's prunedResult

        Returns: {
            "quality_score": float,
            "avg_rec_score": float,
            "text_items": int,
            "warnings": List[str]
        }
        """
        item_count = len(rec_texts) if rec_texts else 0
        notes: List[str] = []

        # Mean recognition confidence; fall back to a neutral 0.5 when the
        # provider omitted per-item scores.
        if rec_scores:
            mean_score = sum(rec_scores) / len(rec_scores)
        else:
            mean_score = 0.5
            if item_count > 0:
                notes.append("rec_scores missing, using default 0.5")

        # Overall score: weighted blend of (log-normalized) item count and
        # mean confidence; zero when nothing was detected at all.
        if item_count == 0:
            combined = 0.0
            notes.append("No text items detected")
        else:
            combined = (
                QUALITY_TEXT_COUNT_WEIGHT * QualityEvaluator.norm(item_count)
                + QUALITY_CONFIDENCE_WEIGHT * mean_score
            )

        return {
            "quality_score": round(combined, 4),
            "avg_rec_score": round(mean_score, 4),
            "text_items": item_count,
            "warnings": notes,
        }
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
# =============================================================================
|
|
274
|
+
# Normalizer
|
|
275
|
+
# =============================================================================
|
|
276
|
+
|
|
277
|
+
class Normalizer:
    """Normalize provider response to unified output schema"""

    @staticmethod
    def normalize_response(
        provider_response: Dict[str, Any],
        request_id: str,
        api_url: str,
        status_code: int,
        mode: str,
        selected_attempt: int,
        attempts_history: List[Dict[str, Any]],
        return_raw: bool = False
    ) -> Dict[str, Any]:
        """
        Convert provider response to normalized output.

        Args:
            provider_response: Raw response from provider
            request_id: Unique request ID
            api_url: API endpoint used
            status_code: HTTP status code
            mode: fast/quality/auto
            selected_attempt: Which attempt was selected (1-indexed)
            attempts_history: List of attempt details
            return_raw: Whether to include raw_provider in output

        Returns:
            Normalized output dict
        """
        error_code = provider_response.get("errorCode", -1)

        if error_code != 0:
            # Error response: no result/quality, error block populated instead
            error_msg = provider_response.get("errorMsg", "Unknown error")
            return {
                "ok": False,
                "request_id": request_id,
                "provider": {
                    "api_url": api_url,
                    "status_code": status_code,
                    "log_id": provider_response.get("logId")
                },
                "result": None,
                "quality": None,
                "agent_trace": {
                    "mode": mode,
                    "selected_attempt": selected_attempt,
                    "attempts": attempts_history
                },
                "raw_provider": provider_response if return_raw else None,
                "error": {
                    "code": Normalizer._map_error_code(error_code, status_code),
                    "message": error_msg,
                    "details": {
                        "error_code": error_code,
                        "status_code": status_code
                    }
                }
            }

        # Success response
        result = provider_response.get("result", {})
        ocr_results = result.get("ocrResults", [])

        pages = []
        all_texts = []
        total_items = 0

        for page_idx, ocr_res in enumerate(ocr_results):
            pruned = ocr_res.get("prunedResult", {})
            rec_texts = pruned.get("rec_texts", [])
            rec_scores = pruned.get("rec_scores", [])
            rec_boxes = pruned.get("rec_boxes", [])
            rec_polys = pruned.get("rec_polys", [])

            items = []
            page_text_lines = []
            # BUG FIX: confidence accumulators are per-page. Previously they
            # accumulated across ALL pages, so page N's avg_confidence was
            # polluted by earlier pages' scores (and a page with no scores
            # inherited a nonzero average from its predecessors).
            page_scores_sum = 0.0
            page_scores_count = 0

            for i, text in enumerate(rec_texts):
                # Score/box arrays may be shorter than rec_texts; treat
                # missing entries as absent rather than failing.
                score = rec_scores[i] if i < len(rec_scores) else None
                box = None
                if i < len(rec_boxes):
                    box = rec_boxes[i]
                elif i < len(rec_polys):
                    # Fall back to the polygon when no axis-aligned box exists
                    box = rec_polys[i]

                item = {"text": text}
                if score is not None:
                    item["score"] = round(score, 4)
                    page_scores_sum += score
                    page_scores_count += 1
                if box is not None:
                    item["box"] = box

                items.append(item)
                page_text_lines.append(text)
                total_items += 1

            page_text = "\n".join(page_text_lines)
            all_texts.append(page_text)

            page_avg_conf = 0.0
            if page_scores_count > 0:
                page_avg_conf = page_scores_sum / page_scores_count

            pages.append({
                "page_index": page_idx,
                "text": page_text,
                "avg_confidence": round(page_avg_conf, 4) if items else 0.0,
                "items": items
            })

        # Pages are separated by a blank line in the combined text
        full_text = "\n\n".join(all_texts)

        # Get quality from last attempt (selected one)
        quality_info = None
        if attempts_history and selected_attempt <= len(attempts_history):
            last_attempt = attempts_history[selected_attempt - 1]
            quality_info = {
                "quality_score": last_attempt.get("quality_score", 0.0),
                "avg_rec_score": last_attempt.get("avg_rec_score", 0.0),
                "text_items": total_items,
                "warnings": last_attempt.get("warnings", [])
            }
        else:
            # Fallback: compute quality on the fly from the normalized items
            quality_info = QualityEvaluator.evaluate(
                rec_texts=[item["text"] for page in pages for item in page["items"]],
                rec_scores=[item.get("score") for page in pages for item in page["items"] if "score" in item]
            )

        return {
            "ok": True,
            "request_id": request_id,
            "provider": {
                "api_url": api_url,
                "status_code": status_code,
                "log_id": provider_response.get("logId")
            },
            "result": {
                "pages": pages,
                "full_text": full_text
            },
            "quality": quality_info,
            "agent_trace": {
                "mode": mode,
                "selected_attempt": selected_attempt,
                "attempts": attempts_history
            },
            "raw_provider": provider_response if return_raw else None,
            "error": None
        }

    @staticmethod
    def _map_error_code(error_code: int, status_code: int) -> str:
        """Map provider error code to unified error code.

        HTTP status takes precedence over the provider's own errorCode.
        """
        if status_code == 403:
            return "PROVIDER_AUTH_ERROR"
        elif status_code == 429:
            return "PROVIDER_QUOTA_EXCEEDED"
        elif status_code == 503:
            return "PROVIDER_OVERLOADED"
        elif status_code == 504:
            return "PROVIDER_TIMEOUT"
        elif error_code == 500:
            return "PROVIDER_BAD_REQUEST"
        else:
            return "PROVIDER_ERROR"
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
# =============================================================================
|
|
452
|
+
# Provider Client
|
|
453
|
+
# =============================================================================
|
|
454
|
+
|
|
455
|
+
class ProviderClient:
    """HTTP client for Paddle AI Studio PP-OCRv5 API.

    Can be used as a context manager so the underlying HTTP connection
    pool is always released:

        with ProviderClient(url, token) as client:
            client.call(payload)
    """

    def __init__(
        self,
        api_url: str,
        token: str,
        timeout_ms: int = 25000,
        max_retry: int = 2
    ):
        """
        Args:
            api_url: Full endpoint URL to POST payloads to.
            token: API token sent in the Authorization header.
            timeout_ms: Per-request timeout in milliseconds.
            max_retry: Extra attempts after the first one (on 503/504/timeout).
        """
        self.api_url = api_url
        self.token = token
        self.timeout_ms = timeout_ms
        self.max_retry = max_retry
        self.client = httpx.Client(timeout=timeout_ms / 1000.0)

    @staticmethod
    def _backoff_ms(attempt: int) -> float:
        """Backoff delay for a retry: 200ms * 4^attempt plus 0-99ms of jitter.

        Extracted so both the HTTP-status and timeout retry paths share one
        formula (previously duplicated).
        """
        return 200 * (4 ** attempt) + (hash(str(time.time())) % 100)

    def call(self, payload: Dict[str, Any]) -> Tuple[Dict[str, Any], int, float]:
        """
        Call provider API with retry on 503/504.

        Args:
            payload: JSON-serializable request body.

        Returns:
            (response_json, status_code, elapsed_ms)
        """
        headers = {
            "Authorization": f"token {self.token}",
            "Content-Type": "application/json"
        }

        attempt = 0
        while attempt <= self.max_retry:
            start_time = time.time()
            try:
                resp = self.client.post(self.api_url, json=payload, headers=headers)
                elapsed_ms = (time.time() - start_time) * 1000

                # Parse response; tolerate non-JSON bodies (e.g. gateway errors)
                try:
                    resp_json = resp.json()
                except (json.JSONDecodeError, ValueError) as e:
                    logger.warning(f"Failed to parse JSON response: {e}")
                    resp_json = {"errorCode": -1, "errorMsg": "Invalid JSON response"}

                # Retry on 503/504 — transient overload / gateway timeout
                if resp.status_code in [503, 504] and attempt < self.max_retry:
                    logger.warning(f"Attempt {attempt + 1} failed with {resp.status_code}, retrying...")
                    time.sleep(self._backoff_ms(attempt) / 1000.0)
                    attempt += 1
                    continue

                return resp_json, resp.status_code, elapsed_ms

            except httpx.TimeoutException:
                elapsed_ms = (time.time() - start_time) * 1000
                if attempt < self.max_retry:
                    logger.warning(f"Attempt {attempt + 1} timed out, retrying...")
                    time.sleep(self._backoff_ms(attempt) / 1000.0)
                    attempt += 1
                    continue
                else:
                    return {
                        "errorCode": 504,
                        "errorMsg": "Request timed out"
                    }, 504, elapsed_ms

            except Exception as e:
                # Any other transport failure is reported as a generic 500
                elapsed_ms = (time.time() - start_time) * 1000
                logger.error(f"Request failed: {e}")
                return {
                    "errorCode": -1,
                    "errorMsg": f"Request failed: {str(e)}"
                }, 500, elapsed_ms

        # Defensive fallback; every loop iteration returns or continues
        return {"errorCode": -1, "errorMsg": "Max retries exceeded"}, 500, 0.0

    def close(self):
        """Close HTTP client"""
        self.client.close()

    def __enter__(self) -> "ProviderClient":
        """Context-manager entry: return self unchanged."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Context-manager exit: always release the HTTP client."""
        self.close()
|
|
535
|
+
|
|
536
|
+
|
|
537
|
+
# =============================================================================
|
|
538
|
+
# Cache
|
|
539
|
+
# =============================================================================
|
|
540
|
+
|
|
541
|
+
class SimpleCache:
    """In-memory TTL cache for normalized results"""

    def __init__(self, ttl_sec: int = 600):
        """
        Args:
            ttl_sec: Time-to-live applied to every entry, in seconds.
        """
        self.ttl_sec = ttl_sec
        # key -> (value, absolute expiry timestamp)
        self._cache: Dict[str, Tuple[Any, float]] = {}

    def get(self, key: str) -> Optional[Any]:
        """Get cached value if not expired; expired entries are evicted."""
        if key in self._cache:
            value, expiry = self._cache[key]
            if time.time() < expiry:
                return value
            else:
                del self._cache[key]
        return None

    def set(self, key: str, value: Any):
        """Set cache value with TTL"""
        expiry = time.time() + self.ttl_sec
        self._cache[key] = (value, expiry)

    @staticmethod
    def make_key(file_input: str, options: Dict[str, Any]) -> str:
        """
        Generate cache key from file and options.

        For performance, large inputs are not hashed in full: the key is
        derived from the input length plus 1KB samples from each end.

        BUG FIX: previously only the FIRST 1KB was hashed, so two distinct
        large inputs sharing a 1KB prefix (common for base64-encoded files
        with identical format headers) collided and returned each other's
        cached results. Mixing in the length and a tail sample makes such
        collisions practically impossible.
        """
        if len(file_input) > 1024:
            sample = f"{len(file_input)}:{file_input[:1024]}:{file_input[-1024:]}"
        else:
            sample = file_input
        file_hash = hashlib.sha256(sample.encode()).hexdigest()[:16]

        # sort_keys makes the options hash independent of dict insertion order
        options_str = json.dumps(options, sort_keys=True)
        options_hash = hashlib.sha256(options_str.encode()).hexdigest()[:16]
        return f"{file_hash}_{options_hash}"
|
|
576
|
+
|
|
577
|
+
|
|
578
|
+
# =============================================================================
|
|
579
|
+
# Agent Policy
|
|
580
|
+
# =============================================================================
|
|
581
|
+
|
|
582
|
+
class AgentPolicy:
    """Generate attempt strategies for auto mode"""

    @staticmethod
    def get_attempts_config(mode: str, max_attempts: int = 3) -> List[Dict[str, Any]]:
        """
        Get list of attempt configurations based on mode.

        Args:
            mode: 'fast', 'quality', or 'auto'
            max_attempts: Max attempts for auto mode

        Returns:
            List of option dicts

        Raises:
            ValueError: for an unrecognized mode.
        """
        # Cheapest path: skip all preprocessing
        fast_opts = {
            "use_doc_orientation_classify": False,
            "use_doc_unwarping": False,
            "use_textline_orientation": False
        }
        # Middle rung: correct document orientation only
        orientation_opts = {
            "use_doc_orientation_classify": True,
            "use_doc_unwarping": False,
            "use_textline_orientation": False
        }
        # Most thorough: orientation correction plus unwarping
        quality_opts = {
            "use_doc_orientation_classify": True,
            "use_doc_unwarping": True,
            "use_textline_orientation": False
        }

        if mode == "fast":
            return [fast_opts]
        if mode == "quality":
            return [quality_opts]
        if mode == "auto":
            # Escalation ladder: fast -> orientation fix -> full unwarping,
            # truncated to the caller's attempt budget.
            ladder = [fast_opts, orientation_opts, quality_opts]
            return ladder[:max_attempts]
        raise ValueError(f"Unknown mode: {mode}")
|