mysphinx-forge 0.2.3__tar.gz → 0.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. {mysphinx_forge-0.2.3/mysphinx_forge.egg-info → mysphinx_forge-0.3.0}/PKG-INFO +41 -1
  2. mysphinx_forge-0.2.3/PKG-INFO → mysphinx_forge-0.3.0/README.md +36 -37
  3. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/cli.py +313 -1
  4. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/config.py +2 -0
  5. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/model_testing.py +51 -1
  6. mysphinx_forge-0.3.0/mysphinx_forge/model_training.py +415 -0
  7. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/templates/mysphinx-forge.yaml +57 -1
  8. mysphinx_forge-0.2.3/README.md → mysphinx_forge-0.3.0/mysphinx_forge.egg-info/PKG-INFO +77 -0
  9. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge.egg-info/SOURCES.txt +4 -1
  10. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge.egg-info/requires.txt +5 -0
  11. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/pyproject.toml +6 -1
  12. mysphinx_forge-0.3.0/tests/test_model_training.py +69 -0
  13. mysphinx_forge-0.3.0/tests/test_train_cli.py +151 -0
  14. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/__init__.py +0 -0
  15. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/cleaning.py +0 -0
  16. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/cluster_labeling.py +0 -0
  17. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/cluster_reporting.py +0 -0
  18. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/clustering.py +0 -0
  19. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/deduplication.py +0 -0
  20. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/embedding.py +0 -0
  21. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/env_utils.py +0 -0
  22. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/file_io.py +0 -0
  23. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/http_client.py +0 -0
  24. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/logging_utils.py +0 -0
  25. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/model_eval.py +0 -0
  26. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/openai_responses.py +0 -0
  27. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/progress.py +0 -0
  28. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/semantic_deduplication.py +0 -0
  29. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/sft_dataset.py +0 -0
  30. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/splitting.py +0 -0
  31. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge/templates/__init__.py +0 -0
  32. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge.egg-info/dependency_links.txt +0 -0
  33. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge.egg-info/entry_points.txt +0 -0
  34. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/mysphinx_forge.egg-info/top_level.txt +0 -0
  35. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/setup.cfg +0 -0
  36. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/tests/test_cleaning.py +0 -0
  37. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/tests/test_cli.py +0 -0
  38. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/tests/test_cluster_labeling.py +0 -0
  39. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/tests/test_cluster_reporting.py +0 -0
  40. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/tests/test_clustering.py +0 -0
  41. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/tests/test_deduplication.py +0 -0
  42. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/tests/test_file_io.py +0 -0
  43. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/tests/test_http_client.py +0 -0
  44. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/tests/test_model_eval.py +0 -0
  45. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/tests/test_model_testing.py +0 -0
  46. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/tests/test_semantic_deduplication.py +0 -0
  47. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/tests/test_sft_cli.py +0 -0
  48. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/tests/test_sft_dataset.py +0 -0
  49. {mysphinx_forge-0.2.3 → mysphinx_forge-0.3.0}/tests/test_splitting.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mysphinx-forge
3
- Version: 0.2.3
3
+ Version: 0.3.0
4
4
  Summary: Data and model workflow toolkit for cleaning, clustering, generation, and evaluation
5
5
  Keywords: data-cleaning,deduplication,clustering,nlp,cli
6
6
  Classifier: Development Status :: 3 - Alpha
@@ -27,6 +27,10 @@ Requires-Dist: sentence-transformers>=5.1.0; extra == "embeddings"
27
27
  Provides-Extra: llm-local
28
28
  Requires-Dist: torch>=2.8.0; extra == "llm-local"
29
29
  Requires-Dist: transformers>=4.55.0; extra == "llm-local"
30
+ Provides-Extra: train
31
+ Requires-Dist: torch>=2.8.0; extra == "train"
32
+ Requires-Dist: transformers>=4.55.0; extra == "train"
33
+ Requires-Dist: scikit-learn>=1.7.0; extra == "train"
30
34
  Provides-Extra: all
31
35
  Requires-Dist: numpy>=2.2.0; extra == "all"
32
36
  Requires-Dist: scikit-learn>=1.7.0; extra == "all"
@@ -66,6 +70,7 @@ Requires-Dist: transformers>=4.55.0; extra == "all"
66
70
  | `split` | 切分 train / valid / test | `*_split_train.*` 等 |
67
71
  | `model-test` | 批量执行模型推理或单条烟雾测试,含预期结果列时自动输出评估报告 | `*_model_tested.*` 或终端输出 |
68
72
  | `convert-sft` | 转换表格数据为 SFT 数据(`alpaca` / `pa` 格式) | `*_alpaca.json` / `*_pa.jsonl`(超 10000 条自动切分) |
73
+ | `train` | 微调本地预训练分类基座(BERT 等)为意图分类模型,含验证集时自动输出评估报告 | `output_models/<基模名>-<日期>/` + `*_train_eval.csv` |
69
74
 
70
75
  ## 项目结构
71
76
 
@@ -128,6 +133,7 @@ uv sync --extra all --group dev
128
133
  - `embeddings`:语义去重
129
134
  - `ml + embeddings`:聚类
130
135
  - `llm-local`:本地模型 `model-test`
136
+ - `train`:意图分类模型 `train`(BERT 等基座),以及分类模型的 `model-test`(`local` 模式自动识别)
131
137
  - `all`:安装全部能力
132
138
 
133
139
  ### 从源码安装当前项目
@@ -137,6 +143,7 @@ uv pip install -e .
137
143
  uv pip install -e '.[embeddings]'
138
144
  uv pip install -e '.[ml,embeddings]'
139
145
  uv pip install -e '.[llm-local]'
146
+ uv pip install -e '.[train]'
140
147
  uv pip install -e '.[all]'
141
148
  ```
142
149
 
@@ -654,6 +661,38 @@ F1,0.75000,0.75000,0.75000,0.75000,8
654
661
 
655
662
  本地批量测试会按可见 GPU 数自动分配 worker;没有 GPU 时自动退化为单 worker CPU 模式。
656
663
 
664
+ ### 训练意图分类模型
665
+
666
+ `train` 基于一个本地预训练分类基座(默认 BERT,也兼容其它带 `*ForSequenceClassification` 实现的 encoder 架构),微调出意图分类模型。输入表需要同时包含文本列(默认自动探测 `text/用户问题/客户问题/用户输入`)和标签列(默认自动探测 `category/label/intent/.../预期结果`)。需要先安装训练依赖:
667
+
668
+ ```bash
669
+ uv sync --extra train --group dev
670
+ ```
671
+
672
+ ```bash
673
+ # 直接在切分后的训练集上微调(默认基础模型 models/bert-base-chinese)
674
+ mysphinx-forge --action train --input-file data/input_split_train.xlsx --base-model-path models/bert-base-chinese
675
+ ```
676
+
677
+ 验证集的选取优先级:`--valid-file` 指定的独立文件 > Excel 中的 `valid` 注入表 > 按 `--validation-ratio` 从训练集随机切分。存在验证集时会输出评估报告:
678
+
679
+ ```bash
680
+ mysphinx-forge --action train --input-file data/input_split_train.xlsx \
681
+ --valid-file data/input_split_valid.xlsx \
682
+ --label-column category --num-train-epochs 3 --learning-rate 2e-5 --train-batch-size 16
683
+ ```
684
+
685
+ 常用训练超参:`--num-train-epochs`、`--learning-rate`、`--train-batch-size`、`--max-length`、`--weight-decay`、`--warmup-ratio`、`--train-seed`。
686
+
687
+ 默认情况下,训练产物会写到当前工作目录下的 `output_models/`,目录按「基模名-日期」命名(例如 `output_models/bert-base-chinese-20260624`);同一天重复训练会自动追加序号(`-2`、`-3`……,首次不带序号)。可用 `--output-dir` 显式指定模型目录路径覆盖该默认行为。
688
+
689
+ 模型目录是一个标准的 transformers 模型目录(含权重、tokenizer 和 `label_map.json`),可直接用 `model-test`(`local` 模式)在测试集上复跑评估——检测到目录内的 `label_map.json` 时会自动按文本分类模型推理,无需额外指定:
690
+
691
+ ```bash
692
+ mysphinx-forge --action model-test --input-file data/input_split_test.xlsx \
693
+ --test-model-path output_models/bert-base-chinese-20260624
694
+ ```
695
+
657
696
  ## 输出文件规则
658
697
 
659
698
  | Action | 主要输出 | 附加输出 |
@@ -665,6 +704,7 @@ F1,0.75000,0.75000,0.75000,0.75000,8
665
704
  | `split` | `*_split_train.*`、`*_split_valid.*`、`*_split_test.*` | `*_split.meta.json`、`mysphinx-forge.log` |
666
705
  | `model-test` 文件模式 | `*_model_tested.*` | 含 `预期结果` 列时额外生成 `*_model_tested_eval.csv`;同时写 `mysphinx-forge.log` |
667
706
  | `model-test` 单条模式 | 终端输出 | 当前工作目录下的 `mysphinx-forge.log` |
707
+ | `train` | `output_models/<基模名>-<日期>/`(模型目录,可用 `--output-dir` 覆盖) | 含验证集时额外生成同级 `<模型目录名>_train_eval.csv`;同时写同级 `<模型目录名>.meta.json` 与 `output_models/mysphinx-forge.log` |
668
708
 
669
709
  补充说明:
670
710
 
@@ -1,40 +1,3 @@
1
- Metadata-Version: 2.4
2
- Name: mysphinx-forge
3
- Version: 0.2.3
4
- Summary: Data and model workflow toolkit for cleaning, clustering, generation, and evaluation
5
- Keywords: data-cleaning,deduplication,clustering,nlp,cli
6
- Classifier: Development Status :: 3 - Alpha
7
- Classifier: Intended Audience :: Developers
8
- Classifier: Intended Audience :: Science/Research
9
- Classifier: Programming Language :: Python :: 3
10
- Classifier: Programming Language :: Python :: 3.12
11
- Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
12
- Classifier: Topic :: Software Development :: Libraries :: Python Modules
13
- Classifier: Environment :: Console
14
- Requires-Python: >=3.12
15
- Description-Content-Type: text/markdown
16
- Requires-Dist: pandas>=2.2.3
17
- Requires-Dist: tqdm>=4.67.1
18
- Requires-Dist: openpyxl>=3.1.5
19
- Requires-Dist: xlrd>=2.0.1
20
- Requires-Dist: pyyaml>=6.0.0
21
- Provides-Extra: ml
22
- Requires-Dist: numpy>=2.2.0; extra == "ml"
23
- Requires-Dist: scikit-learn>=1.7.0; extra == "ml"
24
- Provides-Extra: embeddings
25
- Requires-Dist: faiss-cpu>=1.11.0; extra == "embeddings"
26
- Requires-Dist: sentence-transformers>=5.1.0; extra == "embeddings"
27
- Provides-Extra: llm-local
28
- Requires-Dist: torch>=2.8.0; extra == "llm-local"
29
- Requires-Dist: transformers>=4.55.0; extra == "llm-local"
30
- Provides-Extra: all
31
- Requires-Dist: numpy>=2.2.0; extra == "all"
32
- Requires-Dist: scikit-learn>=1.7.0; extra == "all"
33
- Requires-Dist: faiss-cpu>=1.11.0; extra == "all"
34
- Requires-Dist: sentence-transformers>=5.1.0; extra == "all"
35
- Requires-Dist: torch>=2.8.0; extra == "all"
36
- Requires-Dist: transformers>=4.55.0; extra == "all"
37
-
38
1
  # MySphinx Forge
39
2
 
40
3
  `MySphinx Forge` 是一个面向表格文本数据的 Python CLI 工具,重点解决语料预处理和模型验证这两类工作:
@@ -66,6 +29,7 @@ Requires-Dist: transformers>=4.55.0; extra == "all"
66
29
  | `split` | 切分 train / valid / test | `*_split_train.*` 等 |
67
30
  | `model-test` | 批量执行模型推理或单条烟雾测试,含预期结果列时自动输出评估报告 | `*_model_tested.*` 或终端输出 |
68
31
  | `convert-sft` | 转换表格数据为 SFT 数据(`alpaca` / `pa` 格式) | `*_alpaca.json` / `*_pa.jsonl`(超 10000 条自动切分) |
32
+ | `train` | 微调本地预训练分类基座(BERT 等)为意图分类模型,含验证集时自动输出评估报告 | `output_models/<基模名>-<日期>/` + `*_train_eval.csv` |
69
33
 
70
34
  ## 项目结构
71
35
 
@@ -128,6 +92,7 @@ uv sync --extra all --group dev
128
92
  - `embeddings`:语义去重
129
93
  - `ml + embeddings`:聚类
130
94
  - `llm-local`:本地模型 `model-test`
95
+ - `train`:意图分类模型 `train`(BERT 等基座),以及分类模型的 `model-test`(`local` 模式自动识别)
131
96
  - `all`:安装全部能力
132
97
 
133
98
  ### 从源码安装当前项目
@@ -137,6 +102,7 @@ uv pip install -e .
137
102
  uv pip install -e '.[embeddings]'
138
103
  uv pip install -e '.[ml,embeddings]'
139
104
  uv pip install -e '.[llm-local]'
105
+ uv pip install -e '.[train]'
140
106
  uv pip install -e '.[all]'
141
107
  ```
142
108
 
@@ -654,6 +620,38 @@ F1,0.75000,0.75000,0.75000,0.75000,8
654
620
 
655
621
  本地批量测试会按可见 GPU 数自动分配 worker;没有 GPU 时自动退化为单 worker CPU 模式。
656
622
 
623
+ ### 训练意图分类模型
624
+
625
+ `train` 基于一个本地预训练分类基座(默认 BERT,也兼容其它带 `*ForSequenceClassification` 实现的 encoder 架构),微调出意图分类模型。输入表需要同时包含文本列(默认自动探测 `text/用户问题/客户问题/用户输入`)和标签列(默认自动探测 `category/label/intent/.../预期结果`)。需要先安装训练依赖:
626
+
627
+ ```bash
628
+ uv sync --extra train --group dev
629
+ ```
630
+
631
+ ```bash
632
+ # 直接在切分后的训练集上微调(默认基础模型 models/bert-base-chinese)
633
+ mysphinx-forge --action train --input-file data/input_split_train.xlsx --base-model-path models/bert-base-chinese
634
+ ```
635
+
636
+ 验证集的选取优先级:`--valid-file` 指定的独立文件 > Excel 中的 `valid` 注入表 > 按 `--validation-ratio` 从训练集随机切分。存在验证集时会输出评估报告:
637
+
638
+ ```bash
639
+ mysphinx-forge --action train --input-file data/input_split_train.xlsx \
640
+ --valid-file data/input_split_valid.xlsx \
641
+ --label-column category --num-train-epochs 3 --learning-rate 2e-5 --train-batch-size 16
642
+ ```
643
+
644
+ 常用训练超参:`--num-train-epochs`、`--learning-rate`、`--train-batch-size`、`--max-length`、`--weight-decay`、`--warmup-ratio`、`--train-seed`。
645
+
646
+ 默认情况下,训练产物会写到当前工作目录下的 `output_models/`,目录按「基模名-日期」命名(例如 `output_models/bert-base-chinese-20260624`);同一天重复训练会自动追加序号(`-2`、`-3`……,首次不带序号)。可用 `--output-dir` 显式指定模型目录路径覆盖该默认行为。
647
+
648
+ 模型目录是一个标准的 transformers 模型目录(含权重、tokenizer 和 `label_map.json`),可直接用 `model-test`(`local` 模式)在测试集上复跑评估——检测到目录内的 `label_map.json` 时会自动按文本分类模型推理,无需额外指定:
649
+
650
+ ```bash
651
+ mysphinx-forge --action model-test --input-file data/input_split_test.xlsx \
652
+ --test-model-path output_models/bert-base-chinese-20260624
653
+ ```
654
+
657
655
  ## 输出文件规则
658
656
 
659
657
  | Action | 主要输出 | 附加输出 |
@@ -665,6 +663,7 @@ F1,0.75000,0.75000,0.75000,0.75000,8
665
663
  | `split` | `*_split_train.*`、`*_split_valid.*`、`*_split_test.*` | `*_split.meta.json`、`mysphinx-forge.log` |
666
664
  | `model-test` 文件模式 | `*_model_tested.*` | 含 `预期结果` 列时额外生成 `*_model_tested_eval.csv`;同时写 `mysphinx-forge.log` |
667
665
  | `model-test` 单条模式 | 终端输出 | 当前工作目录下的 `mysphinx-forge.log` |
666
+ | `train` | `output_models/<基模名>-<日期>/`(模型目录,可用 `--output-dir` 覆盖) | 含验证集时额外生成同级 `<模型目录名>_train_eval.csv`;同时写同级 `<模型目录名>.meta.json` 与 `output_models/mysphinx-forge.log` |
668
667
 
669
668
  补充说明:
670
669
 
@@ -56,6 +56,12 @@ from mysphinx_forge.model_testing import (
56
56
  model_test_dataframe,
57
57
  run_model_test,
58
58
  )
59
+ from mysphinx_forge.model_training import (
60
+ ClassifierTrainingConfig,
61
+ DEFAULT_BASE_MODEL_PATH,
62
+ TextClassificationStats,
63
+ train_intent_classifier,
64
+ )
59
65
  from mysphinx_forge.progress import ProgressBar, run_stage
60
66
  from mysphinx_forge.semantic_deduplication import (
61
67
  DEFAULT_EMBEDDING_MODEL_PATH,
@@ -87,6 +93,7 @@ _ACTION_CHOICES = [
87
93
  "model-test",
88
94
  "split",
89
95
  "convert-sft",
96
+ "train",
90
97
  ]
91
98
 
92
99
 
@@ -171,6 +178,24 @@ def main() -> int:
171
178
  if args.repetition_penalty <= 0:
172
179
  print("--repetition-penalty 必须是大于 0 的数值。")
173
180
  return 1
181
+ if args.num_train_epochs <= 0:
182
+ print("--num-train-epochs 必须是大于 0 的整数。")
183
+ return 1
184
+ if args.learning_rate <= 0:
185
+ print("--learning-rate 必须是大于 0 的数值。")
186
+ return 1
187
+ if args.train_batch_size <= 0:
188
+ print("--train-batch-size 必须是大于 0 的整数。")
189
+ return 1
190
+ if args.max_length <= 0:
191
+ print("--max-length 必须是大于 0 的整数。")
192
+ return 1
193
+ if args.weight_decay < 0:
194
+ print("--weight-decay 不能小于 0。")
195
+ return 1
196
+ if not 0 <= args.warmup_ratio < 1:
197
+ print("--warmup-ratio 必须在 0 到 1 之间,且不能等于 1。")
198
+ return 1
174
199
  if args.split_random_seed < 0:
175
200
  print("--split-random-seed 不能小于 0。")
176
201
  return 1
@@ -315,6 +340,23 @@ def main() -> int:
315
340
  args.sft_user_query_as_instruction,
316
341
  args.sft_max_records_per_file,
317
342
  )
343
+ if args.action == "train":
344
+ return _run_train(
345
+ input_file=args.input_file,
346
+ output_arg=args.output_dir,
347
+ target_column=args.target_column,
348
+ label_column=args.label_column,
349
+ valid_file=args.valid_file or None,
350
+ base_model_path=args.base_model_path,
351
+ num_train_epochs=args.num_train_epochs,
352
+ learning_rate=args.learning_rate,
353
+ train_batch_size=args.train_batch_size,
354
+ max_length=args.max_length,
355
+ weight_decay=args.weight_decay,
356
+ warmup_ratio=args.warmup_ratio,
357
+ train_seed=args.train_seed,
358
+ validation_ratio=args.validation_ratio,
359
+ )
318
360
 
319
361
  parser.print_help()
320
362
  return 1
@@ -416,6 +458,12 @@ def _build_parser(
416
458
  default=config_defaults.get("output"),
417
459
  help="输出文件路径。未指定时,默认在原文件旁生成 *_cleaned 文件。",
418
460
  )
461
+ parser.add_argument(
462
+ "--output-dir",
463
+ dest="output_dir",
464
+ default=config_defaults.get("output_dir"),
465
+ help="train 模型输出目录。未指定时默认写到 cwd 下的 output_models/,按「基模名-日期」命名。",
466
+ )
419
467
  parser.add_argument(
420
468
  "--chunk-size",
421
469
  type=int,
@@ -458,7 +506,7 @@ def _build_parser(
458
506
  "--model-test-mode",
459
507
  choices=["local", "openai", "http"],
460
508
  default=config_defaults.get("model_test_mode", DEFAULT_MODEL_TEST_MODE),
461
- help="模型测试模式。local 为本地模型推理,openai 为兼容 OpenAI Chat Completions 的接口调用,http 为通用 HTTP POST 接口调用。",
509
+ help="模型测试模式。local 为本地模型推理,openai 为兼容 OpenAI Chat Completions 的接口调用,http 为通用 HTTP POST 接口调用。local 模式下若模型目录含 label_map.json,会自动按文本分类模型推理。",
462
510
  )
463
511
  parser.add_argument(
464
512
  "--test-model-path",
@@ -722,6 +770,64 @@ def _build_parser(
722
770
  default=config_defaults.get("sft_max_records_per_file", PA_MAX_RECORDS_PER_FILE),
723
771
  help=f"pa 格式每个 JSONL 文件最大记录数,超出时自动切分为多个文件,默认 {PA_MAX_RECORDS_PER_FILE}。",
724
772
  )
773
+ parser.add_argument(
774
+ "--base-model-path",
775
+ default=config_defaults.get("base_model_path", DEFAULT_BASE_MODEL_PATH),
776
+ help=f"train 训练使用的本地预训练分类基座模型路径(BERT 等 encoder 架构),默认 {DEFAULT_BASE_MODEL_PATH}。",
777
+ )
778
+ parser.add_argument(
779
+ "--label-column",
780
+ default=config_defaults.get("label_column", ""),
781
+ help="train 训练使用的标签列名。未指定时自动探测 category/label/intent/.../预期结果。",
782
+ )
783
+ parser.add_argument(
784
+ "--valid-file",
785
+ dest="valid_file",
786
+ default=config_defaults.get("valid_file", ""),
787
+ help="train 训练使用的独立验证集文件。未指定时优先使用 valid 注入表,否则按 --validation-ratio 从训练集切分。",
788
+ )
789
+ parser.add_argument(
790
+ "--num-train-epochs",
791
+ type=int,
792
+ default=config_defaults.get("num_train_epochs", 3),
793
+ help="train 训练轮数,默认 3。",
794
+ )
795
+ parser.add_argument(
796
+ "--learning-rate",
797
+ type=float,
798
+ default=config_defaults.get("learning_rate", 2e-5),
799
+ help="train 训练学习率,默认 2e-5。",
800
+ )
801
+ parser.add_argument(
802
+ "--train-batch-size",
803
+ type=int,
804
+ default=config_defaults.get("train_batch_size", 16),
805
+ help="train 单设备训练批大小,默认 16。",
806
+ )
807
+ parser.add_argument(
808
+ "--max-length",
809
+ type=int,
810
+ default=config_defaults.get("max_length", 128),
811
+ help="train 文本分词的最大长度,默认 128。",
812
+ )
813
+ parser.add_argument(
814
+ "--weight-decay",
815
+ type=float,
816
+ default=config_defaults.get("weight_decay", 0.01),
817
+ help="train 训练权重衰减系数,默认 0.01。",
818
+ )
819
+ parser.add_argument(
820
+ "--warmup-ratio",
821
+ type=float,
822
+ default=config_defaults.get("warmup_ratio", 0.1),
823
+ help="train 训练学习率预热比例,默认 0.1。",
824
+ )
825
+ parser.add_argument(
826
+ "--train-seed",
827
+ type=int,
828
+ default=config_defaults.get("train_seed", 42),
829
+ help="train 训练随机种子,默认 42。",
830
+ )
725
831
  return parser
726
832
 
727
833
 
@@ -1918,6 +2024,156 @@ def _run_convert_sft(
1918
2024
  return 0
1919
2025
 
1920
2026
 
2027
+ def _run_train(
2028
+ input_file: str,
2029
+ output_arg: str | None,
2030
+ target_column: str,
2031
+ label_column: str,
2032
+ valid_file: str | None,
2033
+ base_model_path: str,
2034
+ num_train_epochs: int,
2035
+ learning_rate: float,
2036
+ train_batch_size: int,
2037
+ max_length: int,
2038
+ weight_decay: float,
2039
+ warmup_ratio: float,
2040
+ train_seed: int,
2041
+ validation_ratio: float,
2042
+ ) -> int:
2043
+ input_path = Path(input_file)
2044
+ output_dir = _resolve_train_output_dir(output_arg, base_model_path)
2045
+ # 评估报告与模型目录同级,保证整套训练产物自成一体地落在 output_models/ 下。
2046
+ eval_csv_path = output_dir.with_name(f"{output_dir.name}_train_eval.csv")
2047
+ logger = configure_logger(_resolve_log_path(output_dir))
2048
+ logger.info(
2049
+ "开始执行 action=train input=%s output=%s base_model=%s",
2050
+ input_path,
2051
+ output_dir,
2052
+ base_model_path,
2053
+ )
2054
+
2055
+ try:
2056
+ run_stage("读取文件", logger=logger)
2057
+ main_df, _train_inject_df, valid_inject_df, _test_inject_df = load_split_dataframes(input_file)
2058
+ valid_explicit_df = load_dataframe(valid_file) if valid_file else None
2059
+ train_df, valid_df = _resolve_train_valid_frames(
2060
+ main_df=main_df,
2061
+ valid_inject_df=valid_inject_df,
2062
+ valid_explicit_df=valid_explicit_df,
2063
+ validation_ratio=validation_ratio,
2064
+ seed=train_seed,
2065
+ )
2066
+ except ValueError as exc:
2067
+ _emit_error(str(exc), logger)
2068
+ close_logger()
2069
+ return 1
2070
+
2071
+ config = ClassifierTrainingConfig(
2072
+ base_model_path=base_model_path,
2073
+ num_train_epochs=num_train_epochs,
2074
+ learning_rate=learning_rate,
2075
+ per_device_batch_size=train_batch_size,
2076
+ max_length=max_length,
2077
+ weight_decay=weight_decay,
2078
+ warmup_ratio=warmup_ratio,
2079
+ seed=train_seed,
2080
+ )
2081
+
2082
+ total_steps = num_train_epochs * math.ceil(max(len(train_df), 1) / train_batch_size)
2083
+ progress_bar = ProgressBar(total=total_steps, description="训练模型", logger=logger)
2084
+ try:
2085
+ stats, eval_report = train_intent_classifier(
2086
+ train_df,
2087
+ valid_df,
2088
+ config=config,
2089
+ target_column=target_column,
2090
+ label_column=label_column,
2091
+ model_output_dir=output_dir,
2092
+ progress_callback=progress_bar.advance,
2093
+ )
2094
+ except ValueError as exc:
2095
+ progress_bar.close()
2096
+ _emit_error(str(exc), logger)
2097
+ close_logger()
2098
+ return 1
2099
+ except Exception as exc:
2100
+ progress_bar.close()
2101
+ logger.exception("执行模型训练失败")
2102
+ _emit_error(f"执行模型训练失败:{type(exc).__name__}: {exc}", logger)
2103
+ close_logger()
2104
+ return 1
2105
+ finally:
2106
+ progress_bar.close()
2107
+
2108
+ run_stage("写出结果", logger=logger)
2109
+ if eval_report is not None:
2110
+ eval_csv_path.write_text(eval_report.format_csv(), encoding="utf-8")
2111
+ _write_meta(
2112
+ output_path=output_dir,
2113
+ action="train",
2114
+ input_path=input_path,
2115
+ parameters={
2116
+ "target_column": stats.target_column,
2117
+ "label_column": stats.label_column,
2118
+ "valid_file": valid_file,
2119
+ "base_model_path": base_model_path,
2120
+ "num_train_epochs": num_train_epochs,
2121
+ "learning_rate": learning_rate,
2122
+ "train_batch_size": train_batch_size,
2123
+ "max_length": max_length,
2124
+ "weight_decay": weight_decay,
2125
+ "warmup_ratio": warmup_ratio,
2126
+ "train_seed": train_seed,
2127
+ "validation_ratio": validation_ratio,
2128
+ },
2129
+ training_stats=stats,
2130
+ extra_output_files={"eval_file": eval_csv_path} if eval_report is not None else None,
2131
+ )
2132
+ _print_training_stats(stats, eval_report, output_dir, eval_csv_path if eval_report else None, logger)
2133
+ close_logger()
2134
+ return 0
2135
+
2136
+
2137
+ def _resolve_train_valid_frames(
2138
+ *,
2139
+ main_df: pd.DataFrame,
2140
+ valid_inject_df: pd.DataFrame | None,
2141
+ valid_explicit_df: pd.DataFrame | None,
2142
+ validation_ratio: float,
2143
+ seed: int,
2144
+ ) -> tuple[pd.DataFrame, pd.DataFrame | None]:
2145
+ if valid_explicit_df is not None and not valid_explicit_df.empty:
2146
+ return main_df, valid_explicit_df
2147
+ if valid_inject_df is not None and not valid_inject_df.empty:
2148
+ return main_df, valid_inject_df
2149
+ if validation_ratio > 0 and len(main_df) >= 2:
2150
+ valid_df = main_df.sample(frac=validation_ratio, random_state=seed)
2151
+ if not valid_df.empty:
2152
+ train_df = main_df.drop(valid_df.index)
2153
+ return train_df, valid_df
2154
+ return main_df, None
2155
+
2156
+
2157
+ DEFAULT_TRAIN_OUTPUT_ROOT = "output_models"
2158
+
2159
+
2160
+ def _resolve_train_output_dir(output_arg: str | None, base_model_path: str) -> Path:
2161
+ if output_arg:
2162
+ return Path(output_arg)
2163
+ # 默认产物落在 cwd 下的 output_models/,按「基模名-日期」命名;
2164
+ # 同一天重复训练时追加序号 -2、-3……(首次不带序号)。
2165
+ base_name = Path(base_model_path).name or "model"
2166
+ date_str = datetime.now().strftime("%Y%m%d")
2167
+ root = Path(DEFAULT_TRAIN_OUTPUT_ROOT)
2168
+ candidate = root / f"{base_name}-{date_str}"
2169
+ if not candidate.exists():
2170
+ return candidate
2171
+ index = 2
2172
+ while (root / f"{base_name}-{date_str}-{index}").exists():
2173
+ index += 1
2174
+ return root / f"{base_name}-{date_str}-{index}"
2175
+
2176
+
1921
2177
  def _resolve_output_path(input_path: Path, output_arg: str | None) -> Path:
1922
2178
  if output_arg:
1923
2179
  return Path(output_arg)
@@ -1987,6 +2243,10 @@ def _resolve_cluster_report_html_output_path(output_path: Path) -> Path:
1987
2243
 
1988
2244
 
1989
2245
  def _resolve_meta_output_path(output_path: Path) -> Path:
2246
+ # 目录型产物(如 train 的模型目录)用完整名派生,避免 .stem 把
2247
+ # 含点的名字(如 Qwen2.5-0.5B-20260624)从最后一个点处截断。
2248
+ if output_path.is_dir():
2249
+ return output_path.with_name(f"{output_path.name}.meta.json")
1990
2250
  return output_path.with_name(f"{output_path.stem}.meta.json")
1991
2251
 
1992
2252
 
@@ -2026,6 +2286,39 @@ def _print_deduplication_stats(
2026
2286
  _emit_message(f"去重后总行数:{stats.total_after}", logger)
2027
2287
 
2028
2288
 
2289
+ def _print_training_stats(
2290
+ stats: TextClassificationStats,
2291
+ eval_report: EvalReport | None,
2292
+ output_dir: Path,
2293
+ eval_csv_path: Path | None,
2294
+ logger: Logger,
2295
+ ) -> None:
2296
+ _emit_message(f"模型训练完成,输出目录:{output_dir}", logger)
2297
+ _emit_message(f"基础模型:{stats.base_model_path}", logger)
2298
+ _emit_message(f"使用目标列:{stats.target_column}", logger)
2299
+ _emit_message(f"使用标签列:{stats.label_column}", logger)
2300
+ _emit_message(f"标签数量:{stats.num_labels}", logger)
2301
+ _emit_message(f"标签列表:{', '.join(stats.label_names)}", logger)
2302
+ _emit_message(f"训练样本数:{stats.train_rows}", logger)
2303
+ _emit_message(f"验证样本数:{stats.valid_rows}", logger)
2304
+ _emit_message(f"训练设备:{stats.device}", logger)
2305
+ _emit_message(
2306
+ "训练参数:"
2307
+ f"epochs={stats.num_train_epochs}, "
2308
+ f"batch_size={stats.per_device_batch_size}, "
2309
+ f"learning_rate={stats.learning_rate}, "
2310
+ f"max_length={stats.max_length}",
2311
+ logger,
2312
+ )
2313
+ _emit_message(f"最终训练损失:{stats.final_train_loss}", logger)
2314
+ if stats.best_metric is not None:
2315
+ _emit_message(f"验证集 Macro F1:{stats.best_metric}", logger)
2316
+ if eval_report is not None:
2317
+ _emit_message(eval_report.format_summary(), logger)
2318
+ if eval_csv_path is not None:
2319
+ _emit_message(f"评估报告(CSV):{eval_csv_path}", logger)
2320
+
2321
+
2029
2322
  def _print_clustering_stats(
2030
2323
  stats: ClusteringStats,
2031
2324
  output_path: Path,
@@ -2179,6 +2472,7 @@ def _write_meta(
2179
2472
  clustering_stats: ClusteringStats | None = None,
2180
2473
  split_stats: SplitStats | None = None,
2181
2474
  sft_conversion_stats: SftConversionStats | None = None,
2475
+ training_stats: TextClassificationStats | None = None,
2182
2476
  match_output_path: Path | None = None,
2183
2477
  cluster_summary_path: Path | None = None,
2184
2478
  projection_path: Path | None = None,
@@ -2259,6 +2553,24 @@ def _write_meta(
2259
2553
  "skipped_blank_output_rows": sft_conversion_stats.skipped_blank_output_rows,
2260
2554
  "skipped_rows": sft_conversion_stats.skipped_rows,
2261
2555
  }
2556
+ if training_stats is not None:
2557
+ meta["training_stats"] = {
2558
+ "base_model_path": training_stats.base_model_path,
2559
+ "model_output_dir": training_stats.model_output_dir,
2560
+ "target_column": training_stats.target_column,
2561
+ "label_column": training_stats.label_column,
2562
+ "num_labels": training_stats.num_labels,
2563
+ "label_names": training_stats.label_names,
2564
+ "train_rows": training_stats.train_rows,
2565
+ "valid_rows": training_stats.valid_rows,
2566
+ "num_train_epochs": training_stats.num_train_epochs,
2567
+ "per_device_batch_size": training_stats.per_device_batch_size,
2568
+ "learning_rate": training_stats.learning_rate,
2569
+ "max_length": training_stats.max_length,
2570
+ "device": training_stats.device,
2571
+ "final_train_loss": training_stats.final_train_loss,
2572
+ "best_metric": training_stats.best_metric,
2573
+ }
2262
2574
  if match_output_path is not None and match_output_path.exists():
2263
2575
  meta["match_file"] = str(match_output_path)
2264
2576
  if cluster_summary_path is not None and cluster_summary_path.exists():
@@ -22,6 +22,8 @@ _PATH_LIKE_KEYS = {
22
22
  "embedding_model_path",
23
23
  "train_model_path",
24
24
  "test_model_path",
25
+ "base_model_path",
26
+ "valid_file",
25
27
  "system_prompt_file",
26
28
  "sft_system_prompt_file",
27
29
  }
@@ -37,6 +37,7 @@ EXPECTED_RESULT_COLUMN = "预期结果"
37
37
  MATCH_EXPECTED_COLUMN = "匹配预期"
38
38
  MODEL_CALL_TIME_COLUMN = "模型调用时间"
39
39
  DEFAULT_MODEL_TEST_MODE = "local"
40
+ DEFAULT_CLASSIFICATION_MAX_LENGTH = 128
40
41
  DEFAULT_MODEL_TEST_API_BASE_URL = "https://api.openai.com/v1"
41
42
  DEFAULT_HTTP_API_KEY_ENV_VAR = "HTTP_API_KEY"
42
43
  DEFAULT_HTTP_API_KEY_HEADER = "api_key"
@@ -496,7 +497,19 @@ def model_test_dataframe(
496
497
  expected_results = (
497
498
  dataframe[expected_result_column].tolist() if has_expected_result else [None] * len(dataframe)
498
499
  )
499
- if runtime_config.mode in {"openai", "http"}:
500
+ effective_mode = runtime_config.mode
501
+ if effective_mode == "local" and _is_classification_model_dir(model_path):
502
+ effective_mode = "classification"
503
+
504
+ if effective_mode == "classification":
505
+ model_results, model_call_times, device_used = _run_classification_batches(
506
+ prompts=prompts,
507
+ model_path=model_path,
508
+ runtime_config=runtime_config,
509
+ progress_callback=progress_callback,
510
+ )
511
+ worker_count = 1
512
+ elif effective_mode in {"openai", "http"}:
500
513
  model_results, model_call_times, device_used = _run_openai_batches(
501
514
  prompts=prompts,
502
515
  model_path=model_path,
@@ -597,6 +610,43 @@ def _run_openai_batches(
597
610
  return model_results, model_call_times, tester.device
598
611
 
599
612
 
613
+ def _is_classification_model_dir(model_path: str | Path) -> bool:
614
+ # 延迟导入以避免与 model_training -> model_eval -> model_testing 形成循环依赖。
615
+ from mysphinx_forge.model_training import is_classification_model_dir
616
+
617
+ candidate = Path(model_path)
618
+ return candidate.is_dir() and is_classification_model_dir(candidate)
619
+
620
+
621
+ def _run_classification_batches(
622
+ prompts: list[object],
623
+ model_path: str | Path,
624
+ runtime_config: ModelTestRuntimeConfig,
625
+ progress_callback: Callable[[int], None] | None = None,
626
+ ) -> tuple[list[str], list[float], str]:
627
+ from mysphinx_forge.model_training import _resolve_training_device, predict_intent
628
+
629
+ try:
630
+ import torch
631
+ except ImportError as exc:
632
+ raise ValueError("未安装分类推理所需依赖,请先执行 uv sync --extra train。") from exc
633
+
634
+ device = _resolve_training_device(torch)
635
+ started_at = time.perf_counter()
636
+ model_results = predict_intent(
637
+ model_path,
638
+ list(prompts),
639
+ batch_size=runtime_config.batch_size,
640
+ max_length=DEFAULT_CLASSIFICATION_MAX_LENGTH,
641
+ device=device,
642
+ progress_callback=progress_callback,
643
+ )
644
+ elapsed_seconds = time.perf_counter() - started_at
645
+ per_row_seconds = round(elapsed_seconds / len(model_results), 4) if model_results else 0.0
646
+ model_call_times = [per_row_seconds] * len(model_results)
647
+ return model_results, model_call_times, device
648
+
649
+
600
650
  def _build_model_tester(
601
651
  model_path: str | Path,
602
652
  runtime_config: ModelTestRuntimeConfig,