ponychart-classifier 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25) hide show
  1. ponychart_classifier-0.1.0/PKG-INFO +194 -0
  2. ponychart_classifier-0.1.0/README.md +163 -0
  3. ponychart_classifier-0.1.0/pyproject.toml +62 -0
  4. ponychart_classifier-0.1.0/setup.cfg +4 -0
  5. ponychart_classifier-0.1.0/src/ponychart_classifier/__init__.py +27 -0
  6. ponychart_classifier-0.1.0/src/ponychart_classifier/inference.py +114 -0
  7. ponychart_classifier-0.1.0/src/ponychart_classifier/model.onnx +60784 -119
  8. ponychart_classifier-0.1.0/src/ponychart_classifier/model_spec.py +38 -0
  9. ponychart_classifier-0.1.0/src/ponychart_classifier/py.typed +0 -0
  10. ponychart_classifier-0.1.0/src/ponychart_classifier/thresholds.json +8 -0
  11. ponychart_classifier-0.1.0/src/ponychart_classifier/training/__init__.py +141 -0
  12. ponychart_classifier-0.1.0/src/ponychart_classifier/training/constants.py +70 -0
  13. ponychart_classifier-0.1.0/src/ponychart_classifier/training/dataset.py +265 -0
  14. ponychart_classifier-0.1.0/src/ponychart_classifier/training/device.py +37 -0
  15. ponychart_classifier-0.1.0/src/ponychart_classifier/training/export.py +46 -0
  16. ponychart_classifier-0.1.0/src/ponychart_classifier/training/log_helpers.py +25 -0
  17. ponychart_classifier-0.1.0/src/ponychart_classifier/training/model.py +156 -0
  18. ponychart_classifier-0.1.0/src/ponychart_classifier/training/sampling.py +154 -0
  19. ponychart_classifier-0.1.0/src/ponychart_classifier/training/splitting.py +142 -0
  20. ponychart_classifier-0.1.0/src/ponychart_classifier/training/training.py +380 -0
  21. ponychart_classifier-0.1.0/src/ponychart_classifier.egg-info/PKG-INFO +194 -0
  22. ponychart_classifier-0.1.0/src/ponychart_classifier.egg-info/SOURCES.txt +23 -0
  23. ponychart_classifier-0.1.0/src/ponychart_classifier.egg-info/dependency_links.txt +1 -0
  24. ponychart_classifier-0.1.0/src/ponychart_classifier.egg-info/requires.txt +17 -0
  25. ponychart_classifier-0.1.0/src/ponychart_classifier.egg-info/top_level.txt +1 -0
@@ -0,0 +1,194 @@
1
+ Metadata-Version: 2.4
2
+ Name: ponychart-classifier
3
+ Version: 0.1.0
4
+ Summary: Multi-label image classifier for PonyChart character identification.
5
+ Author: Kuan-Lun Wang
6
+ License: GNU Affero General Public License v3
7
+ Classifier: Development Status :: 3 - Alpha
8
+ Classifier: Intended Audience :: Developers
9
+ Classifier: Operating System :: OS Independent
10
+ Classifier: License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)
11
+ Classifier: Programming Language :: Python :: 3.11
12
+ Classifier: Programming Language :: Python :: 3.12
13
+ Classifier: Programming Language :: Python :: 3.13
14
+ Requires-Python: <3.14,>=3.11
15
+ Description-Content-Type: text/markdown
16
+ Requires-Dist: numpy>=2.2.6
17
+ Requires-Dist: opencv-python>=4.12.0.88
18
+ Requires-Dist: onnxruntime>=1.15.0
19
+ Provides-Extra: train
20
+ Requires-Dist: torch>=2.0; extra == "train"
21
+ Requires-Dist: torchvision>=0.15; extra == "train"
22
+ Requires-Dist: scikit-learn>=1.3; extra == "train"
23
+ Requires-Dist: onnx>=1.20.1; extra == "train"
24
+ Requires-Dist: onnxscript>=0.6.2; extra == "train"
25
+ Requires-Dist: flask>=3.0; extra == "train"
26
+ Requires-Dist: psutil>=5.9; extra == "train"
27
+ Requires-Dist: Pillow>=10.0; extra == "train"
28
+ Provides-Extra: publish
29
+ Requires-Dist: build; extra == "publish"
30
+ Requires-Dist: twine; extra == "publish"
31
+
32
+ # PonyChart Classifier
33
+
34
+ PonyChart 角色辨識模型,用於自動辨識 HentaiVerse 戰鬥中出現的 PonyChart 圖片中的角色。
35
+
36
+ ## 目錄結構
37
+
38
+ ```
39
+ ponychart-classifier/
40
+ ├── src/
41
+ │ └── ponychart_classifier/ # PyPI 套件
42
+ │ ├── __init__.py # 公開 API (re-export model_spec)
43
+ │ ├── model_spec.py # 推論常數 + select_predictions()
44
+ │ ├── model.onnx # ONNX 模型 (推論用)
45
+ │ ├── thresholds.json # 各角色的分類閾值 (推論用)
46
+ │ └── training/ # 訓練函式庫
47
+ │ ├── __init__.py # Re-export 所有 symbol
48
+ │ ├── constants.py # 常數與訓練超參數 (single source of truth)
49
+ │ ├── device.py # 裝置偵測
50
+ │ ├── dataset.py # 資料載入、Dataset、transforms
51
+ │ ├── model.py # Backbone registry + build_model()
52
+ │ ├── training.py # 訓練迴圈、evaluate、threshold 優化
53
+ │ ├── sampling.py # 樣本載入與平衡
54
+ │ ├── splitting.py # Hash-based group splitting
55
+ │ ├── log_helpers.py # 日誌輔助
56
+ │ └── export.py # ONNX 匯出
57
+ ├── scripts/ # 開發用腳本 (不隨套件發佈)
58
+ │ ├── train.py # 模型訓練腳本
59
+ │ ├── label_images.py # 圖片標註工具 (GUI)
60
+ │ ├── inspect_checkpoint.py # Checkpoint 資訊檢視
61
+ │ ├── compare_backbones.py # Backbone 架構比較
62
+ │ ├── compare_crops.py # 裁切圖片效果分析
63
+ │ ├── compare_pos_weight.py # pos_weight 效果比較
64
+ │ ├── compare_resolution.py # 輸入解析度比較
65
+ │ ├── compare_resume_scratch.py # Resume vs from-scratch 分析
66
+ │ ├── evaluate_holdout.py # Holdout 評估
67
+ │ ├── analyze_augmentations.py # 資料增強 ablation study
68
+ │ ├── analyze_distribution.py # 標籤分布互動式視覺化 (Flask)
69
+ │ ├── learning_curve.py # Learning curve 分析 + power-law 外推
70
+ │ ├── search_batch_lr.py # LR 超參數搜尋
71
+ │ └── profile_dataloader.py # DataLoader 效能分析
72
+ ├── rawimage/ # 訓練用原始圖片 (PNG)
73
+ ├── labels.json # 標註資料 {"rawimage/filename.png": [1,3]}
74
+ ├── checkpoint.pt # PyTorch checkpoint (resume 訓練用)
75
+ ├── pyproject.toml
76
+ └── README.md
77
+ ```
78
+
79
+ ## 標籤對照
80
+
81
+ | 編號 | 角色 |
82
+ |------|------|
83
+ | 1 | Twilight Sparkle |
84
+ | 2 | Rarity |
85
+ | 3 | Fluttershy |
86
+ | 4 | Rainbow Dash |
87
+ | 5 | Pinkie Pie |
88
+ | 6 | Applejack |
89
+
90
+ ## 安裝
91
+
92
+ ```bash
93
+ # 推論用 (hbrowser 會自動安裝)
94
+ pip install ponychart-classifier
95
+
96
+ # 開發用 (包含訓練依賴)
97
+ pip install -e ".[train]"
98
+ ```
99
+
100
+ ## 工作流程
101
+
102
+ ### 1. 收集圖片
103
+
104
+ 將新的 PonyChart 截圖 (PNG) 放入 `rawimage/` 資料夾。
105
+
106
+ ### 2. 標註圖片
107
+
108
+ ```bash
109
+ uv run python scripts/label_images.py
110
+ ```
111
+
112
+ 操作方式:
113
+ - `1`~`6`: 加/取消對應角色標籤
114
+ - `A` / `D`: 上一張 / 下一張
115
+ - `S`: 儲存目前標籤
116
+
117
+ 標註結果會即時更新到 `labels.json`。
118
+
119
+ ### 3. 訓練模型
120
+
121
+ ```bash
122
+ # 安裝訓練依賴 (只需一次)
123
+ uv pip install -e ".[train]"
124
+
125
+ # 執行訓練 (若存在 checkpoint.pt 則自動從上次結果繼續訓練)
126
+ uv run python scripts/train.py
127
+
128
+ # 強制從頭訓練 (忽略 checkpoint,從 ImageNet 預訓練權重開始)
129
+ uv run python scripts/train.py --from-scratch
130
+ ```
131
+
132
+ 訓練完成後會覆寫 `model.onnx`、`thresholds.json` 和 `checkpoint.pt`,下次推論自動使用新模型。
133
+
134
+ ### Resume 訓練
135
+
136
+ 新增圖片並標註後,直接執行 `train.py` 即可。腳本會自動偵測 `checkpoint.pt`:
137
+ - **有 checkpoint**: 載入之前的模型權重,跳過 Phase 1 (head-only),直接進入 Phase 2 fine-tuning,收斂更快
138
+ - **無 checkpoint**: 從 ImageNet 預訓練權重開始完整兩階段訓練
139
+
140
+ ### 訓練超參數
141
+
142
+ 所有超參數集中於 `src/ponychart_classifier/training/constants.py`,修改後對所有腳本生效:
143
+
144
+ | 參數 | 預設值 | 說明 |
145
+ |------|--------|------|
146
+ | `BACKBONE` | `efficientnet_b0` | 見下方支援的 backbone |
147
+ | `BATCH_SIZE` | 32 | 批次大小 |
148
+ | `SEED` | 42 | 隨機種子 |
149
+ | `PHASE1_EPOCHS` | 10 | Phase 1 (head-only) 訓練輪數 |
150
+ | `PHASE2_EPOCHS` | 100 | Phase 2 (full fine-tuning) 最大訓練輪數 |
151
+ | `PHASE2_PATIENCE` | 12 | Phase 2 early stopping patience |
152
+
153
+ ## 支援的 Backbone
154
+
155
+ | Backbone | 參數量 | ONNX 大小 | 說明 |
156
+ |----------|--------|-----------|------|
157
+ | `mobilenet_v3_small` | 2.5M | ~4MB | 輕量快速 |
158
+ | `mobilenet_v3_large` | 5.4M | ~9MB | 精度最高 |
159
+ | `efficientnet_b0` | 5.3M | ~11MB | 預設,精度接近 Large,但訓練較慢 |
160
+
161
+ 所有 backbone 都使用 ImageNet 預訓練權重 + transfer learning。
162
+ 推論端使用 ONNX Runtime,backbone 更換後只需重新匯出 `model.onnx`,推論程式碼不需改動。
163
+
164
+ ## 分析腳本
165
+
166
+ 分析腳本使用 `training/constants.py` 中的超參數設定:
167
+
168
+ ```bash
169
+ # 比較三種 backbone 的效果
170
+ uv run python scripts/compare_backbones.py
171
+
172
+ # 分析裁切圖片的影響
173
+ uv run python scripts/compare_crops.py
174
+
175
+ # 資料增強 ablation study
176
+ uv run python scripts/analyze_augmentations.py
177
+
178
+ # 標籤分布互動式視覺化 (Flask web UI)
179
+ uv run python scripts/analyze_distribution.py
180
+
181
+ # Learning curve 分析 (估算增加資料的邊際效益)
182
+ uv run python scripts/learning_curve.py
183
+
184
+ # LR 超參數搜尋
185
+ uv run python scripts/search_batch_lr.py
186
+ ```
187
+
188
+ ## 模型架構
189
+
190
+ - **Backbone**: 可選 MobileNetV3-Small/Large 或 EfficientNet-B0 (預設 EfficientNet-B0,ImageNet 預訓練)
191
+ - **訓練策略**: Phase 1 head-only + Phase 2 full fine-tuning,支援從 checkpoint 繼續訓練
192
+ - **輸出**: 6 個 sigmoid 節點 (多標籤分類)
193
+ - **推論引擎**: ONNX Runtime (CPU)
194
+ - **推論速度**: 3-21ms / 張
@@ -0,0 +1,163 @@
1
+ # PonyChart Classifier
2
+
3
+ PonyChart 角色辨識模型,用於自動辨識 HentaiVerse 戰鬥中出現的 PonyChart 圖片中的角色。
4
+
5
+ ## 目錄結構
6
+
7
+ ```
8
+ ponychart-classifier/
9
+ ├── src/
10
+ │ └── ponychart_classifier/ # PyPI 套件
11
+ │ ├── __init__.py # 公開 API (re-export model_spec)
12
+ │ ├── model_spec.py # 推論常數 + select_predictions()
13
+ │ ├── model.onnx # ONNX 模型 (推論用)
14
+ │ ├── thresholds.json # 各角色的分類閾值 (推論用)
15
+ │ └── training/ # 訓練函式庫
16
+ │ ├── __init__.py # Re-export 所有 symbol
17
+ │ ├── constants.py # 常數與訓練超參數 (single source of truth)
18
+ │ ├── device.py # 裝置偵測
19
+ │ ├── dataset.py # 資料載入、Dataset、transforms
20
+ │ ├── model.py # Backbone registry + build_model()
21
+ │ ├── training.py # 訓練迴圈、evaluate、threshold 優化
22
+ │ ├── sampling.py # 樣本載入與平衡
23
+ │ ├── splitting.py # Hash-based group splitting
24
+ │ ├── log_helpers.py # 日誌輔助
25
+ │ └── export.py # ONNX 匯出
26
+ ├── scripts/ # 開發用腳本 (不隨套件發佈)
27
+ │ ├── train.py # 模型訓練腳本
28
+ │ ├── label_images.py # 圖片標註工具 (GUI)
29
+ │ ├── inspect_checkpoint.py # Checkpoint 資訊檢視
30
+ │ ├── compare_backbones.py # Backbone 架構比較
31
+ │ ├── compare_crops.py # 裁切圖片效果分析
32
+ │ ├── compare_pos_weight.py # pos_weight 效果比較
33
+ │ ├── compare_resolution.py # 輸入解析度比較
34
+ │ ├── compare_resume_scratch.py # Resume vs from-scratch 分析
35
+ │ ├── evaluate_holdout.py # Holdout 評估
36
+ │ ├── analyze_augmentations.py # 資料增強 ablation study
37
+ │ ├── analyze_distribution.py # 標籤分布互動式視覺化 (Flask)
38
+ │ ├── learning_curve.py # Learning curve 分析 + power-law 外推
39
+ │ ├── search_batch_lr.py # LR 超參數搜尋
40
+ │ └── profile_dataloader.py # DataLoader 效能分析
41
+ ├── rawimage/ # 訓練用原始圖片 (PNG)
42
+ ├── labels.json # 標註資料 {"rawimage/filename.png": [1,3]}
43
+ ├── checkpoint.pt # PyTorch checkpoint (resume 訓練用)
44
+ ├── pyproject.toml
45
+ └── README.md
46
+ ```
47
+
48
+ ## 標籤對照
49
+
50
+ | 編號 | 角色 |
51
+ |------|------|
52
+ | 1 | Twilight Sparkle |
53
+ | 2 | Rarity |
54
+ | 3 | Fluttershy |
55
+ | 4 | Rainbow Dash |
56
+ | 5 | Pinkie Pie |
57
+ | 6 | Applejack |
58
+
59
+ ## 安裝
60
+
61
+ ```bash
62
+ # 推論用 (hbrowser 會自動安裝)
63
+ pip install ponychart-classifier
64
+
65
+ # 開發用 (包含訓練依賴)
66
+ pip install -e ".[train]"
67
+ ```
68
+
69
+ ## 工作流程
70
+
71
+ ### 1. 收集圖片
72
+
73
+ 將新的 PonyChart 截圖 (PNG) 放入 `rawimage/` 資料夾。
74
+
75
+ ### 2. 標註圖片
76
+
77
+ ```bash
78
+ uv run python scripts/label_images.py
79
+ ```
80
+
81
+ 操作方式:
82
+ - `1`~`6`: 加/取消對應角色標籤
83
+ - `A` / `D`: 上一張 / 下一張
84
+ - `S`: 儲存目前標籤
85
+
86
+ 標註結果會即時更新到 `labels.json`。
87
+
88
+ ### 3. 訓練模型
89
+
90
+ ```bash
91
+ # 安裝訓練依賴 (只需一次)
92
+ uv pip install -e ".[train]"
93
+
94
+ # 執行訓練 (若存在 checkpoint.pt 則自動從上次結果繼續訓練)
95
+ uv run python scripts/train.py
96
+
97
+ # 強制從頭訓練 (忽略 checkpoint,從 ImageNet 預訓練權重開始)
98
+ uv run python scripts/train.py --from-scratch
99
+ ```
100
+
101
+ 訓練完成後會覆寫 `model.onnx`、`thresholds.json` 和 `checkpoint.pt`,下次推論自動使用新模型。
102
+
103
+ ### Resume 訓練
104
+
105
+ 新增圖片並標註後,直接執行 `train.py` 即可。腳本會自動偵測 `checkpoint.pt`:
106
+ - **有 checkpoint**: 載入之前的模型權重,跳過 Phase 1 (head-only),直接進入 Phase 2 fine-tuning,收斂更快
107
+ - **無 checkpoint**: 從 ImageNet 預訓練權重開始完整兩階段訓練
108
+
109
+ ### 訓練超參數
110
+
111
+ 所有超參數集中於 `src/ponychart_classifier/training/constants.py`,修改後對所有腳本生效:
112
+
113
+ | 參數 | 預設值 | 說明 |
114
+ |------|--------|------|
115
+ | `BACKBONE` | `efficientnet_b0` | 見下方支援的 backbone |
116
+ | `BATCH_SIZE` | 32 | 批次大小 |
117
+ | `SEED` | 42 | 隨機種子 |
118
+ | `PHASE1_EPOCHS` | 10 | Phase 1 (head-only) 訓練輪數 |
119
+ | `PHASE2_EPOCHS` | 100 | Phase 2 (full fine-tuning) 最大訓練輪數 |
120
+ | `PHASE2_PATIENCE` | 12 | Phase 2 early stopping patience |
121
+
122
+ ## 支援的 Backbone
123
+
124
+ | Backbone | 參數量 | ONNX 大小 | 說明 |
125
+ |----------|--------|-----------|------|
126
+ | `mobilenet_v3_small` | 2.5M | ~4MB | 輕量快速 |
127
+ | `mobilenet_v3_large` | 5.4M | ~9MB | 精度最高 |
128
+ | `efficientnet_b0` | 5.3M | ~11MB | 預設,精度接近 Large,但訓練較慢 |
129
+
130
+ 所有 backbone 都使用 ImageNet 預訓練權重 + transfer learning。
131
+ 推論端使用 ONNX Runtime,backbone 更換後只需重新匯出 `model.onnx`,推論程式碼不需改動。
132
+
133
+ ## 分析腳本
134
+
135
+ 分析腳本使用 `training/constants.py` 中的超參數設定:
136
+
137
+ ```bash
138
+ # 比較三種 backbone 的效果
139
+ uv run python scripts/compare_backbones.py
140
+
141
+ # 分析裁切圖片的影響
142
+ uv run python scripts/compare_crops.py
143
+
144
+ # 資料增強 ablation study
145
+ uv run python scripts/analyze_augmentations.py
146
+
147
+ # 標籤分布互動式視覺化 (Flask web UI)
148
+ uv run python scripts/analyze_distribution.py
149
+
150
+ # Learning curve 分析 (估算增加資料的邊際效益)
151
+ uv run python scripts/learning_curve.py
152
+
153
+ # LR 超參數搜尋
154
+ uv run python scripts/search_batch_lr.py
155
+ ```
156
+
157
+ ## 模型架構
158
+
159
+ - **Backbone**: 可選 MobileNetV3-Small/Large 或 EfficientNet-B0 (預設 EfficientNet-B0,ImageNet 預訓練)
160
+ - **訓練策略**: Phase 1 head-only + Phase 2 full fine-tuning,支援從 checkpoint 繼續訓練
161
+ - **輸出**: 6 個 sigmoid 節點 (多標籤分類)
162
+ - **推論引擎**: ONNX Runtime (CPU)
163
+ - **推論速度**: 3-21ms / 張
@@ -0,0 +1,62 @@
1
+ [build-system]
2
+ requires = ["setuptools", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "ponychart-classifier"
7
+ version = "0.1.0"
8
+ description = "Multi-label image classifier for PonyChart character identification."
9
+ readme = "README.md"
10
+ authors = [{ name = "Kuan-Lun Wang" }]
11
+ license = { text = "GNU Affero General Public License v3" }
12
+ requires-python = ">=3.11,<3.14"
13
+ dependencies = [
14
+ "numpy>=2.2.6",
15
+ "opencv-python>=4.12.0.88",
16
+ "onnxruntime>=1.15.0",
17
+ ]
18
+
19
+ classifiers = [
20
+ "Development Status :: 3 - Alpha",
21
+ "Intended Audience :: Developers",
22
+ "Operating System :: OS Independent",
23
+ "License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)",
24
+ "Programming Language :: Python :: 3.11",
25
+ "Programming Language :: Python :: 3.12",
26
+ "Programming Language :: Python :: 3.13",
27
+ ]
28
+
29
+ [project.optional-dependencies]
30
+ train = [
31
+ "torch>=2.0",
32
+ "torchvision>=0.15",
33
+ "scikit-learn>=1.3",
34
+ "onnx>=1.20.1",
35
+ "onnxscript>=0.6.2",
36
+ "flask>=3.0",
37
+ "psutil>=5.9",
38
+ "Pillow>=10.0",
39
+ ]
40
+ publish = [
41
+ "build",
42
+ "twine",
43
+ ]
44
+
45
+ [tool.setuptools]
46
+ include-package-data = true
47
+
48
+ [tool.setuptools.packages.find]
49
+ where = ["src"]
50
+
51
+ [tool.setuptools.package-data]
52
+ "ponychart_classifier" = [
53
+ "py.typed",
54
+ "model.onnx",
55
+ "thresholds.json",
56
+ ]
57
+
58
+ [tool.ruff]
59
+ line-length = 88
60
+
61
+ [tool.ruff.lint]
62
+ select = ["E", "F", "I", "UP"]
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,27 @@
1
+ """PonyChart classifier -- inference constants and prediction utilities."""
2
+
3
+ from .inference import PonyChartClassifier, predict, preload
4
+ from .model_spec import (
5
+ CLASS_NAMES,
6
+ IMAGENET_MEAN,
7
+ IMAGENET_STD,
8
+ INPUT_SIZE,
9
+ MAX_K,
10
+ NUM_CLASSES,
11
+ PRE_RESIZE,
12
+ select_predictions,
13
+ )
14
+
15
# Explicit public API: model-spec constants/utilities plus the high-level
# inference entry points re-exported above.
__all__ = [
    "CLASS_NAMES",
    "IMAGENET_MEAN",
    "IMAGENET_STD",
    "INPUT_SIZE",
    "MAX_K",
    "NUM_CLASSES",
    "PRE_RESIZE",
    "PonyChartClassifier",
    "predict",
    "preload",
    "select_predictions",
]
@@ -0,0 +1,114 @@
1
+ """High-level inference API for PonyChart character classification."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import os
7
+ import sys
8
+ from typing import Any
9
+
10
+ import cv2 as cv
11
+ import numpy as np
12
+ import onnxruntime as ort
13
+
14
+ from .model_spec import (
15
+ CLASS_NAMES,
16
+ IMAGENET_MEAN,
17
+ IMAGENET_STD,
18
+ INPUT_SIZE,
19
+ PRE_RESIZE,
20
+ select_predictions,
21
+ )
22
+
23
# ImageNet normalization statistics pre-converted once to float32 ndarrays,
# so the per-image normalization in ``PonyChartClassifier._preprocess`` does
# not rebuild them on every call.
_IMAGENET_MEAN = np.array(IMAGENET_MEAN, dtype=np.float32)
_IMAGENET_STD = np.array(IMAGENET_STD, dtype=np.float32)
25
+
26
+
27
+ def _package_dir() -> str:
28
+ return os.path.dirname(__file__)
29
+
30
+
31
class PonyChartClassifier:
    """ONNX-backed multi-label classifier for PonyChart screenshots.

    The bundled model and threshold files are loaded lazily on first use,
    so constructing an instance is cheap.
    """

    def __init__(self) -> None:
        # _loaded guards against re-initializing the ONNX session.
        self._loaded = False
        self._session: Any = None
        self._classes: list[str] = list(CLASS_NAMES)
        self._thresholds: dict[str, float] = {}

    def load(self) -> None:
        """Load the ONNX model and thresholds. Safe to call multiple times."""
        if self._loaded:
            return

        base = _package_dir()
        self._session = ort.InferenceSession(
            os.path.join(base, "model.onnx"), providers=["CPUExecutionProvider"]
        )
        with open(os.path.join(base, "thresholds.json"), encoding="utf-8") as fh:
            self._thresholds = json.load(fh)
        self._loaded = True

    def _preprocess(self, bgr: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]:
        """Turn a BGR image into the NCHW float32 tensor the model expects.

        Mirrors the training-time transforms: resize to ``PRE_RESIZE``,
        center-crop to ``INPUT_SIZE``, convert to RGB, scale to [0, 1] and
        apply ImageNet normalization.
        """
        square = cv.resize(bgr, (PRE_RESIZE, PRE_RESIZE), interpolation=cv.INTER_AREA)
        margin = (PRE_RESIZE - INPUT_SIZE) // 2
        window = slice(margin, margin + INPUT_SIZE)
        center = square[window, window]
        rgb = cv.cvtColor(center, cv.COLOR_BGR2RGB).astype(np.float32) / 255.0
        normalized = (rgb - _IMAGENET_MEAN) / _IMAGENET_STD
        # HWC -> CHW, then prepend the batch axis.
        chw = np.transpose(normalized, (2, 0, 1))
        return np.expand_dims(chw, 0).astype(np.float32)

    def predict(
        self, img_path: str, min_k: int = 1, max_k: int = 3
    ) -> tuple[list[str], dict[str, float]]:
        """Predict characters in a PonyChart image.

        Returns ``(picked_names, scores)`` where *picked_names* is a list of
        selected character names and *scores* maps every class name to its
        sigmoid probability.

        Raises ``RuntimeError`` when the image cannot be read.
        """
        self.load()
        image = cv.imread(img_path, cv.IMREAD_COLOR)
        if image is None:
            raise RuntimeError(f"Cannot read image: {img_path}")

        feed_name: str = self._session.get_inputs()[0].name
        logits = self._session.run(None, {feed_name: self._preprocess(image)})[0]
        probs = 1.0 / (1.0 + np.exp(-logits[0]))

        scores = {name: float(probs[i]) for i, name in enumerate(self._classes)}
        cutoffs = [self._thresholds.get(name, 0.5) for name in self._classes]
        chosen = select_predictions(list(probs), cutoffs, min_k=min_k, max_k=max_k)
        return [self._classes[i] for i in chosen], scores
89
+
90
+
91
# Shared module-level singleton backing the top-level ``predict`` and
# ``preload`` helpers; its ONNX session is only created on first use.
_default_classifier = PonyChartClassifier()
92
+
93
+
94
def predict(
    img_path: str, min_k: int = 1, max_k: int = 3
) -> tuple[list[str], dict[str, float]]:
    """Predict characters using the module-wide default classifier.

    Thin convenience wrapper around ``PonyChartClassifier.predict``.
    """
    picked, scores = _default_classifier.predict(img_path, min_k=min_k, max_k=max_k)
    return picked, scores
99
+
100
+
101
def preload() -> None:
    """Eagerly load the ONNX model so dependency problems surface early.

    Wraps an ``ImportError`` from loading in a ``RuntimeError`` carrying
    platform-specific installation hints.
    """
    try:
        _default_classifier.load()
    except ImportError as e:
        msg = "onnxruntime failed to load."
        # Windows DLL failures usually mean the VC++ runtime is missing.
        dll_failure = sys.platform == "win32" and "DLL load failed" in str(e)
        if dll_failure:
            msg += (
                "\nPossible cause: missing Microsoft Visual C++ Redistributable."
                "\nDownload from https://aka.ms/vs/17/release/vc_redist.x64.exe"
            )
        else:
            msg += "\nPlease install: pip install onnxruntime"
        raise RuntimeError(msg) from e