mysphinx-forge 0.2.3__tar.gz → 0.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mysphinx_forge-0.2.3/mysphinx_forge.egg-info → mysphinx_forge-0.3.0}/PKG-INFO +41 -1
- mysphinx_forge-0.2.3/PKG-INFO → mysphinx_forge-0.3.0/README.md +36 -37
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/cli.py +313 -1
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/config.py +2 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/model_testing.py +51 -1
- mysphinx_forge-0.3.0/mysphinx_forge/model_training.py +415 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/templates/mysphinx-forge.yaml +57 -1
- mysphinx_forge-0.2.3/README.md → mysphinx_forge-0.3.0/mysphinx_forge.egg-info/PKG-INFO +77 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge.egg-info/SOURCES.txt +4 -1
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge.egg-info/requires.txt +5 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/pyproject.toml +6 -1
- mysphinx_forge-0.3.0/tests/test_model_training.py +69 -0
- mysphinx_forge-0.3.0/tests/test_train_cli.py +151 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/__init__.py +0 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/cleaning.py +0 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/cluster_labeling.py +0 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/cluster_reporting.py +0 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/clustering.py +0 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/deduplication.py +0 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/embedding.py +0 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/env_utils.py +0 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/file_io.py +0 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/http_client.py +0 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/logging_utils.py +0 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/model_eval.py +0 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/openai_responses.py +0 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/progress.py +0 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/semantic_deduplication.py +0 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/sft_dataset.py +0 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/splitting.py +0 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/templates/__init__.py +0 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge.egg-info/dependency_links.txt +0 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge.egg-info/entry_points.txt +0 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge.egg-info/top_level.txt +0 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/setup.cfg +0 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/tests/test_cleaning.py +0 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/tests/test_cli.py +0 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/tests/test_cluster_labeling.py +0 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/tests/test_cluster_reporting.py +0 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/tests/test_clustering.py +0 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/tests/test_deduplication.py +0 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/tests/test_file_io.py +0 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/tests/test_http_client.py +0 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/tests/test_model_eval.py +0 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/tests/test_model_testing.py +0 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/tests/test_semantic_deduplication.py +0 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/tests/test_sft_cli.py +0 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/tests/test_sft_dataset.py +0 -0
- {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/tests/test_splitting.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: mysphinx-forge
|
|
3
|
-
Version: 0.2.3
|
|
3
|
+
Version: 0.3.0
|
|
4
4
|
Summary: Data and model workflow toolkit for cleaning, clustering, generation, and evaluation
|
|
5
5
|
Keywords: data-cleaning,deduplication,clustering,nlp,cli
|
|
6
6
|
Classifier: Development Status :: 3 - Alpha
|
|
@@ -27,6 +27,10 @@ Requires-Dist: sentence-transformers>=5.1.0; extra == "embeddings"
|
|
|
27
27
|
Provides-Extra: llm-local
|
|
28
28
|
Requires-Dist: torch>=2.8.0; extra == "llm-local"
|
|
29
29
|
Requires-Dist: transformers>=4.55.0; extra == "llm-local"
|
|
30
|
+
Provides-Extra: train
|
|
31
|
+
Requires-Dist: torch>=2.8.0; extra == "train"
|
|
32
|
+
Requires-Dist: transformers>=4.55.0; extra == "train"
|
|
33
|
+
Requires-Dist: scikit-learn>=1.7.0; extra == "train"
|
|
30
34
|
Provides-Extra: all
|
|
31
35
|
Requires-Dist: numpy>=2.2.0; extra == "all"
|
|
32
36
|
Requires-Dist: scikit-learn>=1.7.0; extra == "all"
|
|
@@ -66,6 +70,7 @@ Requires-Dist: transformers>=4.55.0; extra == "all"
|
|
|
66
70
|
| `split` | 切分 train / valid / test | `*_split_train.*` 等 |
|
|
67
71
|
| `model-test` | 批量执行模型推理或单条烟雾测试,含预期结果列时自动输出评估报告 | `*_model_tested.*` 或终端输出 |
|
|
68
72
|
| `convert-sft` | 转换表格数据为 SFT 数据(`alpaca` / `pa` 格式) | `*_alpaca.json` / `*_pa.jsonl`(超 10000 条自动切分) |
|
|
73
|
+
| `train` | 微调本地预训练分类基座(BERT 等)为意图分类模型,含验证集时自动输出评估报告 | `output_models/<基模名>-<日期>/` + `*_train_eval.csv` |
|
|
69
74
|
|
|
70
75
|
## 项目结构
|
|
71
76
|
|
|
@@ -128,6 +133,7 @@ uv sync --extra all --group dev
|
|
|
128
133
|
- `embeddings`:语义去重
|
|
129
134
|
- `ml + embeddings`:聚类
|
|
130
135
|
- `llm-local`:本地模型 `model-test`
|
|
136
|
+
- `train`:意图分类模型 `train`(BERT 等基座),以及分类模型的 `model-test`(`local` 模式自动识别)
|
|
131
137
|
- `all`:安装全部能力
|
|
132
138
|
|
|
133
139
|
### 从源码安装当前项目
|
|
@@ -137,6 +143,7 @@ uv pip install -e .
|
|
|
137
143
|
uv pip install -e '.[embeddings]'
|
|
138
144
|
uv pip install -e '.[ml,embeddings]'
|
|
139
145
|
uv pip install -e '.[llm-local]'
|
|
146
|
+
uv pip install -e '.[train]'
|
|
140
147
|
uv pip install -e '.[all]'
|
|
141
148
|
```
|
|
142
149
|
|
|
@@ -654,6 +661,38 @@ F1,0.75000,0.75000,0.75000,0.75000,8
|
|
|
654
661
|
|
|
655
662
|
本地批量测试会按可见 GPU 数自动分配 worker;没有 GPU 时自动退化为单 worker CPU 模式。
|
|
656
663
|
|
|
664
|
+
### 训练意图分类模型
|
|
665
|
+
|
|
666
|
+
`train` 基于一个本地预训练分类基座(默认 BERT,也兼容其它带 `*ForSequenceClassification` 实现的 encoder 架构),微调出意图分类模型。输入表需要同时包含文本列(默认自动探测 `text/用户问题/客户问题/用户输入`)和标签列(默认自动探测 `category/label/intent/.../预期结果`)。需要先安装训练依赖:
|
|
667
|
+
|
|
668
|
+
```bash
|
|
669
|
+
uv sync --extra train --group dev
|
|
670
|
+
```
|
|
671
|
+
|
|
672
|
+
```bash
|
|
673
|
+
# 直接在切分后的训练集上微调(默认基础模型 models/bert-base-chinese)
|
|
674
|
+
mysphinx-forge --action train --input-file data/input_split_train.xlsx --base-model-path models/bert-base-chinese
|
|
675
|
+
```
|
|
676
|
+
|
|
677
|
+
验证集的选取优先级:`--valid-file` 指定的独立文件 > Excel 中的 `valid` 注入表 > 按 `--validation-ratio` 从训练集随机切分。存在验证集时会输出评估报告:
|
|
678
|
+
|
|
679
|
+
```bash
|
|
680
|
+
mysphinx-forge --action train --input-file data/input_split_train.xlsx \
|
|
681
|
+
--valid-file data/input_split_valid.xlsx \
|
|
682
|
+
--label-column category --num-train-epochs 3 --learning-rate 2e-5 --train-batch-size 16
|
|
683
|
+
```
|
|
684
|
+
|
|
685
|
+
常用训练超参:`--num-train-epochs`、`--learning-rate`、`--train-batch-size`、`--max-length`、`--weight-decay`、`--warmup-ratio`、`--train-seed`。
|
|
686
|
+
|
|
687
|
+
默认情况下,训练产物会写到当前工作目录下的 `output_models/`,目录按「基模名-日期」命名(例如 `output_models/bert-base-chinese-20260624`);同一天重复训练会自动追加序号(`-2`、`-3`……,首次不带序号)。可用 `--output-dir` 显式指定模型目录路径覆盖该默认行为。
|
|
688
|
+
|
|
689
|
+
模型目录是一个标准的 transformers 模型目录(含权重、tokenizer 和 `label_map.json`),可直接用 `model-test`(`local` 模式)在测试集上复跑评估——检测到目录内的 `label_map.json` 时会自动按文本分类模型推理,无需额外指定:
|
|
690
|
+
|
|
691
|
+
```bash
|
|
692
|
+
mysphinx-forge --action model-test --input-file data/input_split_test.xlsx \
|
|
693
|
+
--test-model-path output_models/bert-base-chinese-20260624
|
|
694
|
+
```
|
|
695
|
+
|
|
657
696
|
## 输出文件规则
|
|
658
697
|
|
|
659
698
|
| Action | 主要输出 | 附加输出 |
|
|
@@ -665,6 +704,7 @@ F1,0.75000,0.75000,0.75000,0.75000,8
|
|
|
665
704
|
| `split` | `*_split_train.*`、`*_split_valid.*`、`*_split_test.*` | `*_split.meta.json`、`mysphinx-forge.log` |
|
|
666
705
|
| `model-test` 文件模式 | `*_model_tested.*` | 含 `预期结果` 列时额外生成 `*_model_tested_eval.csv`;同时写 `mysphinx-forge.log` |
|
|
667
706
|
| `model-test` 单条模式 | 终端输出 | 当前工作目录下的 `mysphinx-forge.log` |
|
|
707
|
+
| `train` | `output_models/<基模名>-<日期>/`(模型目录,可用 `--output-dir` 覆盖) | 含验证集时额外生成同级 `<模型目录名>_train_eval.csv`;同时写同级 `<模型目录名>.meta.json` 与 `output_models/mysphinx-forge.log` |
|
|
668
708
|
|
|
669
709
|
补充说明:
|
|
670
710
|
|
|
@@ -1,40 +1,3 @@
|
|
|
1
|
-
Metadata-Version: 2.4
|
|
2
|
-
Name: mysphinx-forge
|
|
3
|
-
Version: 0.2.3
|
|
4
|
-
Summary: Data and model workflow toolkit for cleaning, clustering, generation, and evaluation
|
|
5
|
-
Keywords: data-cleaning,deduplication,clustering,nlp,cli
|
|
6
|
-
Classifier: Development Status :: 3 - Alpha
|
|
7
|
-
Classifier: Intended Audience :: Developers
|
|
8
|
-
Classifier: Intended Audience :: Science/Research
|
|
9
|
-
Classifier: Programming Language :: Python :: 3
|
|
10
|
-
Classifier: Programming Language :: Python :: 3.12
|
|
11
|
-
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
12
|
-
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
13
|
-
Classifier: Environment :: Console
|
|
14
|
-
Requires-Python: >=3.12
|
|
15
|
-
Description-Content-Type: text/markdown
|
|
16
|
-
Requires-Dist: pandas>=2.2.3
|
|
17
|
-
Requires-Dist: tqdm>=4.67.1
|
|
18
|
-
Requires-Dist: openpyxl>=3.1.5
|
|
19
|
-
Requires-Dist: xlrd>=2.0.1
|
|
20
|
-
Requires-Dist: pyyaml>=6.0.0
|
|
21
|
-
Provides-Extra: ml
|
|
22
|
-
Requires-Dist: numpy>=2.2.0; extra == "ml"
|
|
23
|
-
Requires-Dist: scikit-learn>=1.7.0; extra == "ml"
|
|
24
|
-
Provides-Extra: embeddings
|
|
25
|
-
Requires-Dist: faiss-cpu>=1.11.0; extra == "embeddings"
|
|
26
|
-
Requires-Dist: sentence-transformers>=5.1.0; extra == "embeddings"
|
|
27
|
-
Provides-Extra: llm-local
|
|
28
|
-
Requires-Dist: torch>=2.8.0; extra == "llm-local"
|
|
29
|
-
Requires-Dist: transformers>=4.55.0; extra == "llm-local"
|
|
30
|
-
Provides-Extra: all
|
|
31
|
-
Requires-Dist: numpy>=2.2.0; extra == "all"
|
|
32
|
-
Requires-Dist: scikit-learn>=1.7.0; extra == "all"
|
|
33
|
-
Requires-Dist: faiss-cpu>=1.11.0; extra == "all"
|
|
34
|
-
Requires-Dist: sentence-transformers>=5.1.0; extra == "all"
|
|
35
|
-
Requires-Dist: torch>=2.8.0; extra == "all"
|
|
36
|
-
Requires-Dist: transformers>=4.55.0; extra == "all"
|
|
37
|
-
|
|
38
1
|
# MySphinx Forge
|
|
39
2
|
|
|
40
3
|
`MySphinx Forge` 是一个面向表格文本数据的 Python CLI 工具,重点解决语料预处理和模型验证这两类工作:
|
|
@@ -66,6 +29,7 @@ Requires-Dist: transformers>=4.55.0; extra == "all"
|
|
|
66
29
|
| `split` | 切分 train / valid / test | `*_split_train.*` 等 |
|
|
67
30
|
| `model-test` | 批量执行模型推理或单条烟雾测试,含预期结果列时自动输出评估报告 | `*_model_tested.*` 或终端输出 |
|
|
68
31
|
| `convert-sft` | 转换表格数据为 SFT 数据(`alpaca` / `pa` 格式) | `*_alpaca.json` / `*_pa.jsonl`(超 10000 条自动切分) |
|
|
32
|
+
| `train` | 微调本地预训练分类基座(BERT 等)为意图分类模型,含验证集时自动输出评估报告 | `output_models/<基模名>-<日期>/` + `*_train_eval.csv` |
|
|
69
33
|
|
|
70
34
|
## 项目结构
|
|
71
35
|
|
|
@@ -128,6 +92,7 @@ uv sync --extra all --group dev
|
|
|
128
92
|
- `embeddings`:语义去重
|
|
129
93
|
- `ml + embeddings`:聚类
|
|
130
94
|
- `llm-local`:本地模型 `model-test`
|
|
95
|
+
- `train`:意图分类模型 `train`(BERT 等基座),以及分类模型的 `model-test`(`local` 模式自动识别)
|
|
131
96
|
- `all`:安装全部能力
|
|
132
97
|
|
|
133
98
|
### 从源码安装当前项目
|
|
@@ -137,6 +102,7 @@ uv pip install -e .
|
|
|
137
102
|
uv pip install -e '.[embeddings]'
|
|
138
103
|
uv pip install -e '.[ml,embeddings]'
|
|
139
104
|
uv pip install -e '.[llm-local]'
|
|
105
|
+
uv pip install -e '.[train]'
|
|
140
106
|
uv pip install -e '.[all]'
|
|
141
107
|
```
|
|
142
108
|
|
|
@@ -654,6 +620,38 @@ F1,0.75000,0.75000,0.75000,0.75000,8
|
|
|
654
620
|
|
|
655
621
|
本地批量测试会按可见 GPU 数自动分配 worker;没有 GPU 时自动退化为单 worker CPU 模式。
|
|
656
622
|
|
|
623
|
+
### 训练意图分类模型
|
|
624
|
+
|
|
625
|
+
`train` 基于一个本地预训练分类基座(默认 BERT,也兼容其它带 `*ForSequenceClassification` 实现的 encoder 架构),微调出意图分类模型。输入表需要同时包含文本列(默认自动探测 `text/用户问题/客户问题/用户输入`)和标签列(默认自动探测 `category/label/intent/.../预期结果`)。需要先安装训练依赖:
|
|
626
|
+
|
|
627
|
+
```bash
|
|
628
|
+
uv sync --extra train --group dev
|
|
629
|
+
```
|
|
630
|
+
|
|
631
|
+
```bash
|
|
632
|
+
# 直接在切分后的训练集上微调(默认基础模型 models/bert-base-chinese)
|
|
633
|
+
mysphinx-forge --action train --input-file data/input_split_train.xlsx --base-model-path models/bert-base-chinese
|
|
634
|
+
```
|
|
635
|
+
|
|
636
|
+
验证集的选取优先级:`--valid-file` 指定的独立文件 > Excel 中的 `valid` 注入表 > 按 `--validation-ratio` 从训练集随机切分。存在验证集时会输出评估报告:
|
|
637
|
+
|
|
638
|
+
```bash
|
|
639
|
+
mysphinx-forge --action train --input-file data/input_split_train.xlsx \
|
|
640
|
+
--valid-file data/input_split_valid.xlsx \
|
|
641
|
+
--label-column category --num-train-epochs 3 --learning-rate 2e-5 --train-batch-size 16
|
|
642
|
+
```
|
|
643
|
+
|
|
644
|
+
常用训练超参:`--num-train-epochs`、`--learning-rate`、`--train-batch-size`、`--max-length`、`--weight-decay`、`--warmup-ratio`、`--train-seed`。
|
|
645
|
+
|
|
646
|
+
默认情况下,训练产物会写到当前工作目录下的 `output_models/`,目录按「基模名-日期」命名(例如 `output_models/bert-base-chinese-20260624`);同一天重复训练会自动追加序号(`-2`、`-3`……,首次不带序号)。可用 `--output-dir` 显式指定模型目录路径覆盖该默认行为。
|
|
647
|
+
|
|
648
|
+
模型目录是一个标准的 transformers 模型目录(含权重、tokenizer 和 `label_map.json`),可直接用 `model-test`(`local` 模式)在测试集上复跑评估——检测到目录内的 `label_map.json` 时会自动按文本分类模型推理,无需额外指定:
|
|
649
|
+
|
|
650
|
+
```bash
|
|
651
|
+
mysphinx-forge --action model-test --input-file data/input_split_test.xlsx \
|
|
652
|
+
--test-model-path output_models/bert-base-chinese-20260624
|
|
653
|
+
```
|
|
654
|
+
|
|
657
655
|
## 输出文件规则
|
|
658
656
|
|
|
659
657
|
| Action | 主要输出 | 附加输出 |
|
|
@@ -665,6 +663,7 @@ F1,0.75000,0.75000,0.75000,0.75000,8
|
|
|
665
663
|
| `split` | `*_split_train.*`、`*_split_valid.*`、`*_split_test.*` | `*_split.meta.json`、`mysphinx-forge.log` |
|
|
666
664
|
| `model-test` 文件模式 | `*_model_tested.*` | 含 `预期结果` 列时额外生成 `*_model_tested_eval.csv`;同时写 `mysphinx-forge.log` |
|
|
667
665
|
| `model-test` 单条模式 | 终端输出 | 当前工作目录下的 `mysphinx-forge.log` |
|
|
666
|
+
| `train` | `output_models/<基模名>-<日期>/`(模型目录,可用 `--output-dir` 覆盖) | 含验证集时额外生成同级 `<模型目录名>_train_eval.csv`;同时写同级 `<模型目录名>.meta.json` 与 `output_models/mysphinx-forge.log` |
|
|
668
667
|
|
|
669
668
|
补充说明:
|
|
670
669
|
|
|
@@ -56,6 +56,12 @@ from mysphinx_forge.model_testing import (
|
|
|
56
56
|
model_test_dataframe,
|
|
57
57
|
run_model_test,
|
|
58
58
|
)
|
|
59
|
+
from mysphinx_forge.model_training import (
|
|
60
|
+
ClassifierTrainingConfig,
|
|
61
|
+
DEFAULT_BASE_MODEL_PATH,
|
|
62
|
+
TextClassificationStats,
|
|
63
|
+
train_intent_classifier,
|
|
64
|
+
)
|
|
59
65
|
from mysphinx_forge.progress import ProgressBar, run_stage
|
|
60
66
|
from mysphinx_forge.semantic_deduplication import (
|
|
61
67
|
DEFAULT_EMBEDDING_MODEL_PATH,
|
|
@@ -87,6 +93,7 @@ _ACTION_CHOICES = [
|
|
|
87
93
|
"model-test",
|
|
88
94
|
"split",
|
|
89
95
|
"convert-sft",
|
|
96
|
+
"train",
|
|
90
97
|
]
|
|
91
98
|
|
|
92
99
|
|
|
@@ -171,6 +178,24 @@ def main() -> int:
|
|
|
171
178
|
if args.repetition_penalty <= 0:
|
|
172
179
|
print("--repetition-penalty 必须是大于 0 的数值。")
|
|
173
180
|
return 1
|
|
181
|
+
if args.num_train_epochs <= 0:
|
|
182
|
+
print("--num-train-epochs 必须是大于 0 的整数。")
|
|
183
|
+
return 1
|
|
184
|
+
if args.learning_rate <= 0:
|
|
185
|
+
print("--learning-rate 必须是大于 0 的数值。")
|
|
186
|
+
return 1
|
|
187
|
+
if args.train_batch_size <= 0:
|
|
188
|
+
print("--train-batch-size 必须是大于 0 的整数。")
|
|
189
|
+
return 1
|
|
190
|
+
if args.max_length <= 0:
|
|
191
|
+
print("--max-length 必须是大于 0 的整数。")
|
|
192
|
+
return 1
|
|
193
|
+
if args.weight_decay < 0:
|
|
194
|
+
print("--weight-decay 不能小于 0。")
|
|
195
|
+
return 1
|
|
196
|
+
if not 0 <= args.warmup_ratio < 1:
|
|
197
|
+
print("--warmup-ratio 必须在 0 到 1 之间,且不能等于 1。")
|
|
198
|
+
return 1
|
|
174
199
|
if args.split_random_seed < 0:
|
|
175
200
|
print("--split-random-seed 不能小于 0。")
|
|
176
201
|
return 1
|
|
@@ -315,6 +340,23 @@ def main() -> int:
|
|
|
315
340
|
args.sft_user_query_as_instruction,
|
|
316
341
|
args.sft_max_records_per_file,
|
|
317
342
|
)
|
|
343
|
+
if args.action == "train":
|
|
344
|
+
return _run_train(
|
|
345
|
+
input_file=args.input_file,
|
|
346
|
+
output_arg=args.output_dir,
|
|
347
|
+
target_column=args.target_column,
|
|
348
|
+
label_column=args.label_column,
|
|
349
|
+
valid_file=args.valid_file or None,
|
|
350
|
+
base_model_path=args.base_model_path,
|
|
351
|
+
num_train_epochs=args.num_train_epochs,
|
|
352
|
+
learning_rate=args.learning_rate,
|
|
353
|
+
train_batch_size=args.train_batch_size,
|
|
354
|
+
max_length=args.max_length,
|
|
355
|
+
weight_decay=args.weight_decay,
|
|
356
|
+
warmup_ratio=args.warmup_ratio,
|
|
357
|
+
train_seed=args.train_seed,
|
|
358
|
+
validation_ratio=args.validation_ratio,
|
|
359
|
+
)
|
|
318
360
|
|
|
319
361
|
parser.print_help()
|
|
320
362
|
return 1
|
|
@@ -416,6 +458,12 @@ def _build_parser(
|
|
|
416
458
|
default=config_defaults.get("output"),
|
|
417
459
|
help="输出文件路径。未指定时,默认在原文件旁生成 *_cleaned 文件。",
|
|
418
460
|
)
|
|
461
|
+
parser.add_argument(
|
|
462
|
+
"--output-dir",
|
|
463
|
+
dest="output_dir",
|
|
464
|
+
default=config_defaults.get("output_dir"),
|
|
465
|
+
help="train 模型输出目录。未指定时默认写到 cwd 下的 output_models/,按「基模名-日期」命名。",
|
|
466
|
+
)
|
|
419
467
|
parser.add_argument(
|
|
420
468
|
"--chunk-size",
|
|
421
469
|
type=int,
|
|
@@ -458,7 +506,7 @@ def _build_parser(
|
|
|
458
506
|
"--model-test-mode",
|
|
459
507
|
choices=["local", "openai", "http"],
|
|
460
508
|
default=config_defaults.get("model_test_mode", DEFAULT_MODEL_TEST_MODE),
|
|
461
|
-
help="模型测试模式。local 为本地模型推理,openai 为兼容 OpenAI Chat Completions 的接口调用,http 为通用 HTTP POST 接口调用。",
|
|
509
|
+
help="模型测试模式。local 为本地模型推理,openai 为兼容 OpenAI Chat Completions 的接口调用,http 为通用 HTTP POST 接口调用。local 模式下若模型目录含 label_map.json,会自动按文本分类模型推理。",
|
|
462
510
|
)
|
|
463
511
|
parser.add_argument(
|
|
464
512
|
"--test-model-path",
|
|
@@ -722,6 +770,64 @@ def _build_parser(
|
|
|
722
770
|
default=config_defaults.get("sft_max_records_per_file", PA_MAX_RECORDS_PER_FILE),
|
|
723
771
|
help=f"pa 格式每个 JSONL 文件最大记录数,超出时自动切分为多个文件,默认 {PA_MAX_RECORDS_PER_FILE}。",
|
|
724
772
|
)
|
|
773
|
+
parser.add_argument(
|
|
774
|
+
"--base-model-path",
|
|
775
|
+
default=config_defaults.get("base_model_path", DEFAULT_BASE_MODEL_PATH),
|
|
776
|
+
help=f"train 训练使用的本地预训练分类基座模型路径(BERT 等 encoder 架构),默认 {DEFAULT_BASE_MODEL_PATH}。",
|
|
777
|
+
)
|
|
778
|
+
parser.add_argument(
|
|
779
|
+
"--label-column",
|
|
780
|
+
default=config_defaults.get("label_column", ""),
|
|
781
|
+
help="train 训练使用的标签列名。未指定时自动探测 category/label/intent/.../预期结果。",
|
|
782
|
+
)
|
|
783
|
+
parser.add_argument(
|
|
784
|
+
"--valid-file",
|
|
785
|
+
dest="valid_file",
|
|
786
|
+
default=config_defaults.get("valid_file", ""),
|
|
787
|
+
help="train 训练使用的独立验证集文件。未指定时优先使用 valid 注入表,否则按 --validation-ratio 从训练集切分。",
|
|
788
|
+
)
|
|
789
|
+
parser.add_argument(
|
|
790
|
+
"--num-train-epochs",
|
|
791
|
+
type=int,
|
|
792
|
+
default=config_defaults.get("num_train_epochs", 3),
|
|
793
|
+
help="train 训练轮数,默认 3。",
|
|
794
|
+
)
|
|
795
|
+
parser.add_argument(
|
|
796
|
+
"--learning-rate",
|
|
797
|
+
type=float,
|
|
798
|
+
default=config_defaults.get("learning_rate", 2e-5),
|
|
799
|
+
help="train 训练学习率,默认 2e-5。",
|
|
800
|
+
)
|
|
801
|
+
parser.add_argument(
|
|
802
|
+
"--train-batch-size",
|
|
803
|
+
type=int,
|
|
804
|
+
default=config_defaults.get("train_batch_size", 16),
|
|
805
|
+
help="train 单设备训练批大小,默认 16。",
|
|
806
|
+
)
|
|
807
|
+
parser.add_argument(
|
|
808
|
+
"--max-length",
|
|
809
|
+
type=int,
|
|
810
|
+
default=config_defaults.get("max_length", 128),
|
|
811
|
+
help="train 文本分词的最大长度,默认 128。",
|
|
812
|
+
)
|
|
813
|
+
parser.add_argument(
|
|
814
|
+
"--weight-decay",
|
|
815
|
+
type=float,
|
|
816
|
+
default=config_defaults.get("weight_decay", 0.01),
|
|
817
|
+
help="train 训练权重衰减系数,默认 0.01。",
|
|
818
|
+
)
|
|
819
|
+
parser.add_argument(
|
|
820
|
+
"--warmup-ratio",
|
|
821
|
+
type=float,
|
|
822
|
+
default=config_defaults.get("warmup_ratio", 0.1),
|
|
823
|
+
help="train 训练学习率预热比例,默认 0.1。",
|
|
824
|
+
)
|
|
825
|
+
parser.add_argument(
|
|
826
|
+
"--train-seed",
|
|
827
|
+
type=int,
|
|
828
|
+
default=config_defaults.get("train_seed", 42),
|
|
829
|
+
help="train 训练随机种子,默认 42。",
|
|
830
|
+
)
|
|
725
831
|
return parser
|
|
726
832
|
|
|
727
833
|
|
|
@@ -1918,6 +2024,156 @@ def _run_convert_sft(
|
|
|
1918
2024
|
return 0
|
|
1919
2025
|
|
|
1920
2026
|
|
|
2027
|
+
def _run_train(
    input_file: str,
    output_arg: str | None,
    target_column: str,
    label_column: str,
    valid_file: str | None,
    base_model_path: str,
    num_train_epochs: int,
    learning_rate: float,
    train_batch_size: int,
    max_length: int,
    weight_decay: float,
    warmup_ratio: float,
    train_seed: int,
    validation_ratio: float,
) -> int:
    """Run the ``train`` action: fine-tune a classifier and write its artifacts.

    Loads the input table (plus optional validation data), fine-tunes
    ``base_model_path`` with ``train_intent_classifier`` into the resolved model
    directory, writes an evaluation CSV next to that directory when a validation
    report was produced, writes a ``.meta.json`` sidecar and prints a summary.

    Args:
        input_file: Training table; inject sheets are read via ``load_split_dataframes``.
        output_arg: Explicit model directory (``--output-dir``); ``None`` or empty
            selects the dated default chosen by ``_resolve_train_output_dir``.
        target_column: Text column name, forwarded to ``train_intent_classifier``.
        label_column: Label column name, forwarded to ``train_intent_classifier``.
        valid_file: Optional standalone validation table (``--valid-file``).
        base_model_path: Local pre-trained classification base model.
        num_train_epochs, learning_rate, train_batch_size, max_length,
        weight_decay, warmup_ratio, train_seed: Hyper-parameters copied into
            ``ClassifierTrainingConfig``; ``train_seed`` also seeds the random
            validation split.
        validation_ratio: Hold-out fraction, used only when no validation table
            is supplied.

    Returns:
        0 on success; 1 when reading the data, training, or writing results fails.
    """
    input_path = Path(input_file)
    output_dir = _resolve_train_output_dir(output_arg, base_model_path)
    # The eval report sits next to the model directory so that all training
    # artifacts stay together (under output_models/ by default).
    eval_csv_path = output_dir.with_name(f"{output_dir.name}_train_eval.csv")
    logger = configure_logger(_resolve_log_path(output_dir))
    logger.info(
        "开始执行 action=train input=%s output=%s base_model=%s",
        input_path,
        output_dir,
        base_model_path,
    )

    try:
        run_stage("读取文件", logger=logger)
        main_df, _train_inject_df, valid_inject_df, _test_inject_df = load_split_dataframes(input_file)
        valid_explicit_df = load_dataframe(valid_file) if valid_file else None
        train_df, valid_df = _resolve_train_valid_frames(
            main_df=main_df,
            valid_inject_df=valid_inject_df,
            valid_explicit_df=valid_explicit_df,
            validation_ratio=validation_ratio,
            seed=train_seed,
        )
    except (ValueError, OSError) as exc:
        # OSError covers I/O failures such as a --valid-file that cannot be
        # opened; without it the error escapes as a traceback and the log
        # handler is never closed.
        _emit_error(str(exc), logger)
        close_logger()
        return 1

    config = ClassifierTrainingConfig(
        base_model_path=base_model_path,
        num_train_epochs=num_train_epochs,
        learning_rate=learning_rate,
        per_device_batch_size=train_batch_size,
        max_length=max_length,
        weight_decay=weight_decay,
        warmup_ratio=warmup_ratio,
        seed=train_seed,
    )

    # Progress total assumes one callback per optimisation step (ceil(rows / batch) per epoch).
    total_steps = num_train_epochs * math.ceil(max(len(train_df), 1) / train_batch_size)
    progress_bar = ProgressBar(total=total_steps, description="训练模型", logger=logger)
    try:
        try:
            stats, eval_report = train_intent_classifier(
                train_df,
                valid_df,
                config=config,
                target_column=target_column,
                label_column=label_column,
                model_output_dir=output_dir,
                progress_callback=progress_bar.advance,
            )
        finally:
            # Close the bar exactly once, and before any error message is
            # emitted, so terminal output is not interleaved with the bar.
            progress_bar.close()
    except ValueError as exc:
        _emit_error(str(exc), logger)
        close_logger()
        return 1
    except Exception as exc:
        logger.exception("执行模型训练失败")
        _emit_error(f"执行模型训练失败:{type(exc).__name__}: {exc}", logger)
        close_logger()
        return 1

    run_stage("写出结果", logger=logger)
    has_eval_report = eval_report is not None
    try:
        if has_eval_report:
            eval_csv_path.write_text(eval_report.format_csv(), encoding="utf-8")
        _write_meta(
            output_path=output_dir,
            action="train",
            input_path=input_path,
            parameters={
                "target_column": stats.target_column,
                "label_column": stats.label_column,
                "valid_file": valid_file,
                "base_model_path": base_model_path,
                "num_train_epochs": num_train_epochs,
                "learning_rate": learning_rate,
                "train_batch_size": train_batch_size,
                "max_length": max_length,
                "weight_decay": weight_decay,
                "warmup_ratio": warmup_ratio,
                "train_seed": train_seed,
                "validation_ratio": validation_ratio,
            },
            training_stats=stats,
            extra_output_files={"eval_file": eval_csv_path} if has_eval_report else None,
        )
    except OSError as exc:
        # The model directory has already been written; report the failed
        # side-artifact write instead of crashing with the logger still open.
        logger.exception("写出训练结果失败")
        _emit_error(f"写出训练结果失败:{type(exc).__name__}: {exc}", logger)
        close_logger()
        return 1

    _print_training_stats(
        stats,
        eval_report,
        output_dir,
        eval_csv_path if has_eval_report else None,
        logger,
    )
    close_logger()
    return 0
|
|
2135
|
+
|
|
2136
|
+
|
|
2137
|
+
def _resolve_train_valid_frames(
|
|
2138
|
+
*,
|
|
2139
|
+
main_df: pd.DataFrame,
|
|
2140
|
+
valid_inject_df: pd.DataFrame | None,
|
|
2141
|
+
valid_explicit_df: pd.DataFrame | None,
|
|
2142
|
+
validation_ratio: float,
|
|
2143
|
+
seed: int,
|
|
2144
|
+
) -> tuple[pd.DataFrame, pd.DataFrame | None]:
|
|
2145
|
+
if valid_explicit_df is not None and not valid_explicit_df.empty:
|
|
2146
|
+
return main_df, valid_explicit_df
|
|
2147
|
+
if valid_inject_df is not None and not valid_inject_df.empty:
|
|
2148
|
+
return main_df, valid_inject_df
|
|
2149
|
+
if validation_ratio > 0 and len(main_df) >= 2:
|
|
2150
|
+
valid_df = main_df.sample(frac=validation_ratio, random_state=seed)
|
|
2151
|
+
if not valid_df.empty:
|
|
2152
|
+
train_df = main_df.drop(valid_df.index)
|
|
2153
|
+
return train_df, valid_df
|
|
2154
|
+
return main_df, None
|
|
2155
|
+
|
|
2156
|
+
|
|
2157
|
+
DEFAULT_TRAIN_OUTPUT_ROOT = "output_models"
|
|
2158
|
+
|
|
2159
|
+
|
|
2160
|
+
def _resolve_train_output_dir(output_arg: str | None, base_model_path: str) -> Path:
|
|
2161
|
+
if output_arg:
|
|
2162
|
+
return Path(output_arg)
|
|
2163
|
+
# 默认产物落在 cwd 下的 output_models/,按「基模名-日期」命名;
|
|
2164
|
+
# 同一天重复训练时追加序号 -2、-3……(首次不带序号)。
|
|
2165
|
+
base_name = Path(base_model_path).name or "model"
|
|
2166
|
+
date_str = datetime.now().strftime("%Y%m%d")
|
|
2167
|
+
root = Path(DEFAULT_TRAIN_OUTPUT_ROOT)
|
|
2168
|
+
candidate = root / f"{base_name}-{date_str}"
|
|
2169
|
+
if not candidate.exists():
|
|
2170
|
+
return candidate
|
|
2171
|
+
index = 2
|
|
2172
|
+
while (root / f"{base_name}-{date_str}-{index}").exists():
|
|
2173
|
+
index += 1
|
|
2174
|
+
return root / f"{base_name}-{date_str}-{index}"
|
|
2175
|
+
|
|
2176
|
+
|
|
1921
2177
|
def _resolve_output_path(input_path: Path, output_arg: str | None) -> Path:
|
|
1922
2178
|
if output_arg:
|
|
1923
2179
|
return Path(output_arg)
|
|
@@ -1987,6 +2243,10 @@ def _resolve_cluster_report_html_output_path(output_path: Path) -> Path:
|
|
|
1987
2243
|
|
|
1988
2244
|
|
|
1989
2245
|
def _resolve_meta_output_path(output_path: Path) -> Path:
    """Return the ``.meta.json`` sidecar path that sits next to *output_path*.

    For files the last suffix is dropped (``data_cleaned.xlsx`` ->
    ``data_cleaned.meta.json``). An existing directory (e.g. a ``train`` model
    directory) keeps its full name instead, because ``Path.stem`` would cut a
    dotted name such as ``Qwen2.5-0.5B-20260624`` at its last dot.
    """
    meta_base = output_path.name if output_path.is_dir() else output_path.stem
    return output_path.with_name(f"{meta_base}.meta.json")
|
|
1991
2251
|
|
|
1992
2252
|
|
|
@@ -2026,6 +2286,39 @@ def _print_deduplication_stats(
|
|
|
2026
2286
|
_emit_message(f"去重后总行数:{stats.total_after}", logger)
|
|
2027
2287
|
|
|
2028
2288
|
|
|
2289
|
+
def _print_training_stats(
    stats: TextClassificationStats,
    eval_report: EvalReport | None,
    output_dir: Path,
    eval_csv_path: Path | None,
    logger: Logger,
) -> None:
    """Report the outcome of a ``train`` run, one line per ``_emit_message`` call.

    Args:
        stats: Training summary returned by ``train_intent_classifier``.
        eval_report: Validation report; ``None`` when no validation set was used.
        output_dir: Model directory the run wrote to.
        eval_csv_path: Path of the written evaluation CSV, or ``None`` if none was written.
        logger: Logger handed to ``_emit_message`` for every line.
    """
    _emit_message(f"模型训练完成,输出目录:{output_dir}", logger)
    _emit_message(f"基础模型:{stats.base_model_path}", logger)
    _emit_message(f"使用目标列:{stats.target_column}", logger)
    _emit_message(f"使用标签列:{stats.label_column}", logger)
    _emit_message(f"标签数量:{stats.num_labels}", logger)
    # Labels can come from a numeric column (e.g. intent IDs read from Excel);
    # str() keeps the join from raising TypeError on non-str label names.
    label_list = ", ".join(str(name) for name in stats.label_names)
    _emit_message(f"标签列表:{label_list}", logger)
    _emit_message(f"训练样本数:{stats.train_rows}", logger)
    _emit_message(f"验证样本数:{stats.valid_rows}", logger)
    _emit_message(f"训练设备:{stats.device}", logger)
    _emit_message(
        "训练参数:"
        f"epochs={stats.num_train_epochs}, "
        f"batch_size={stats.per_device_batch_size}, "
        f"learning_rate={stats.learning_rate}, "
        f"max_length={stats.max_length}",
        logger,
    )
    _emit_message(f"最终训练损失:{stats.final_train_loss}", logger)
    if stats.best_metric is not None:
        _emit_message(f"验证集 Macro F1:{stats.best_metric}", logger)
    if eval_report is not None:
        _emit_message(eval_report.format_summary(), logger)
    if eval_csv_path is not None:
        _emit_message(f"评估报告(CSV):{eval_csv_path}", logger)
|
|
2320
|
+
|
|
2321
|
+
|
|
2029
2322
|
def _print_clustering_stats(
|
|
2030
2323
|
stats: ClusteringStats,
|
|
2031
2324
|
output_path: Path,
|
|
@@ -2179,6 +2472,7 @@ def _write_meta(
|
|
|
2179
2472
|
clustering_stats: ClusteringStats | None = None,
|
|
2180
2473
|
split_stats: SplitStats | None = None,
|
|
2181
2474
|
sft_conversion_stats: SftConversionStats | None = None,
|
|
2475
|
+
training_stats: TextClassificationStats | None = None,
|
|
2182
2476
|
match_output_path: Path | None = None,
|
|
2183
2477
|
cluster_summary_path: Path | None = None,
|
|
2184
2478
|
projection_path: Path | None = None,
|
|
@@ -2259,6 +2553,24 @@ def _write_meta(
|
|
|
2259
2553
|
"skipped_blank_output_rows": sft_conversion_stats.skipped_blank_output_rows,
|
|
2260
2554
|
"skipped_rows": sft_conversion_stats.skipped_rows,
|
|
2261
2555
|
}
|
|
2556
|
+
if training_stats is not None:
|
|
2557
|
+
meta["training_stats"] = {
|
|
2558
|
+
"base_model_path": training_stats.base_model_path,
|
|
2559
|
+
"model_output_dir": training_stats.model_output_dir,
|
|
2560
|
+
"target_column": training_stats.target_column,
|
|
2561
|
+
"label_column": training_stats.label_column,
|
|
2562
|
+
"num_labels": training_stats.num_labels,
|
|
2563
|
+
"label_names": training_stats.label_names,
|
|
2564
|
+
"train_rows": training_stats.train_rows,
|
|
2565
|
+
"valid_rows": training_stats.valid_rows,
|
|
2566
|
+
"num_train_epochs": training_stats.num_train_epochs,
|
|
2567
|
+
"per_device_batch_size": training_stats.per_device_batch_size,
|
|
2568
|
+
"learning_rate": training_stats.learning_rate,
|
|
2569
|
+
"max_length": training_stats.max_length,
|
|
2570
|
+
"device": training_stats.device,
|
|
2571
|
+
"final_train_loss": training_stats.final_train_loss,
|
|
2572
|
+
"best_metric": training_stats.best_metric,
|
|
2573
|
+
}
|
|
2262
2574
|
if match_output_path is not None and match_output_path.exists():
|
|
2263
2575
|
meta["match_file"] = str(match_output_path)
|
|
2264
2576
|
if cluster_summary_path is not None and cluster_summary_path.exists():
|
|
@@ -37,6 +37,7 @@ EXPECTED_RESULT_COLUMN = "预期结果"
|
|
|
37
37
|
MATCH_EXPECTED_COLUMN = "匹配预期"
|
|
38
38
|
MODEL_CALL_TIME_COLUMN = "模型调用时间"
|
|
39
39
|
DEFAULT_MODEL_TEST_MODE = "local"
|
|
40
|
+
# Tokenizer max_length passed to predict_intent when model-test detects a local
# fine-tuned text-classification model directory; equals the default of
# `train --max-length` (128).
# NOTE(review): a classifier trained with a different --max-length is still run
# with this value at model-test time; consider persisting the training
# max_length with the model and reading it back.
DEFAULT_CLASSIFICATION_MAX_LENGTH = 128
|
|
40
41
|
DEFAULT_MODEL_TEST_API_BASE_URL = "https://api.openai.com/v1"
|
|
41
42
|
DEFAULT_HTTP_API_KEY_ENV_VAR = "HTTP_API_KEY"
|
|
42
43
|
DEFAULT_HTTP_API_KEY_HEADER = "api_key"
|
|
@@ -496,7 +497,19 @@ def model_test_dataframe(
|
|
|
496
497
|
expected_results = (
|
|
497
498
|
dataframe[expected_result_column].tolist() if has_expected_result else [None] * len(dataframe)
|
|
498
499
|
)
|
|
499
|
-
|
|
500
|
+
effective_mode = runtime_config.mode
|
|
501
|
+
if effective_mode == "local" and _is_classification_model_dir(model_path):
|
|
502
|
+
effective_mode = "classification"
|
|
503
|
+
|
|
504
|
+
if effective_mode == "classification":
|
|
505
|
+
model_results, model_call_times, device_used = _run_classification_batches(
|
|
506
|
+
prompts=prompts,
|
|
507
|
+
model_path=model_path,
|
|
508
|
+
runtime_config=runtime_config,
|
|
509
|
+
progress_callback=progress_callback,
|
|
510
|
+
)
|
|
511
|
+
worker_count = 1
|
|
512
|
+
elif effective_mode in {"openai", "http"}:
|
|
500
513
|
model_results, model_call_times, device_used = _run_openai_batches(
|
|
501
514
|
prompts=prompts,
|
|
502
515
|
model_path=model_path,
|
|
@@ -597,6 +610,43 @@ def _run_openai_batches(
|
|
|
597
610
|
return model_results, model_call_times, tester.device
|
|
598
611
|
|
|
599
612
|
|
|
613
|
+
def _is_classification_model_dir(model_path: str | Path) -> bool:
|
|
614
|
+
# 延迟导入以避免与 model_training -> model_eval -> model_testing 形成循环依赖。
|
|
615
|
+
from mysphinx_forge.model_training import is_classification_model_dir
|
|
616
|
+
|
|
617
|
+
candidate = Path(model_path)
|
|
618
|
+
return candidate.is_dir() and is_classification_model_dir(candidate)
|
|
619
|
+
|
|
620
|
+
|
|
621
|
+
def _run_classification_batches(
    prompts: list[object],
    model_path: str | Path,
    runtime_config: ModelTestRuntimeConfig,
    progress_callback: Callable[[int], None] | None = None,
) -> tuple[list[str], list[float], str]:
    """Label *prompts* with a fine-tuned text-classification model.

    Returns ``(predictions, per_row_seconds, device)``. The timing list holds the
    same value for every row: total wall-clock time divided by the row count,
    rounded to 4 decimals (``0.0`` for an empty result).

    Raises:
        ValueError: when ``torch`` is not installed.
    """
    # Deferred imports keep the model_training -> model_eval -> model_testing
    # import chain acyclic and make torch optional until this path is used.
    from mysphinx_forge.model_training import _resolve_training_device, predict_intent

    try:
        import torch
    except ImportError as exc:
        raise ValueError("未安装分类推理所需依赖,请先执行 uv sync --extra train。") from exc

    device = _resolve_training_device(torch)
    t_start = time.perf_counter()
    # NOTE(review): max_length is the module default, not the value the model
    # was trained with.
    predictions = predict_intent(
        model_path,
        list(prompts),
        batch_size=runtime_config.batch_size,
        max_length=DEFAULT_CLASSIFICATION_MAX_LENGTH,
        device=device,
        progress_callback=progress_callback,
    )
    total_seconds = time.perf_counter() - t_start

    row_count = len(predictions)
    average_seconds = 0.0 if row_count == 0 else round(total_seconds / row_count, 4)
    return predictions, [average_seconds for _ in range(row_count)], device
|
|
648
|
+
|
|
649
|
+
|
|
600
650
|
def _build_model_tester(
|
|
601
651
|
model_path: str | Path,
|
|
602
652
|
runtime_config: ModelTestRuntimeConfig,
|