mysphinx-forge 0.2.1__tar.gz → 0.2.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/PKG-INFO +49 -8
  2. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/README.md +48 -7
  3. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/cli.py +113 -28
  4. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/file_io.py +25 -10
  5. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/sft_dataset.py +17 -7
  6. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/splitting.py +3 -2
  7. mysphinx_forge-0.2.2/mysphinx_forge/templates/__init__.py +0 -0
  8. mysphinx_forge-0.2.2/mysphinx_forge/templates/mysphinx-forge.yaml +411 -0
  9. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge.egg-info/PKG-INFO +49 -8
  10. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge.egg-info/SOURCES.txt +2 -0
  11. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/pyproject.toml +4 -1
  12. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/tests/test_cli.py +54 -0
  13. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/tests/test_file_io.py +38 -19
  14. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/tests/test_sft_dataset.py +18 -5
  15. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/tests/test_splitting.py +17 -0
  16. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/__init__.py +0 -0
  17. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/cleaning.py +0 -0
  18. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/cluster_labeling.py +0 -0
  19. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/cluster_reporting.py +0 -0
  20. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/clustering.py +0 -0
  21. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/config.py +0 -0
  22. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/deduplication.py +0 -0
  23. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/embedding.py +0 -0
  24. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/env_utils.py +0 -0
  25. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/http_client.py +0 -0
  26. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/logging_utils.py +0 -0
  27. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/model_eval.py +0 -0
  28. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/model_testing.py +0 -0
  29. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/openai_responses.py +0 -0
  30. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/progress.py +0 -0
  31. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/semantic_deduplication.py +0 -0
  32. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge.egg-info/dependency_links.txt +0 -0
  33. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge.egg-info/entry_points.txt +0 -0
  34. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge.egg-info/requires.txt +0 -0
  35. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge.egg-info/top_level.txt +0 -0
  36. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/setup.cfg +0 -0
  37. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/tests/test_cleaning.py +0 -0
  38. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/tests/test_cluster_labeling.py +0 -0
  39. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/tests/test_cluster_reporting.py +0 -0
  40. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/tests/test_clustering.py +0 -0
  41. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/tests/test_deduplication.py +0 -0
  42. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/tests/test_http_client.py +0 -0
  43. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/tests/test_model_eval.py +0 -0
  44. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/tests/test_model_testing.py +0 -0
  45. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/tests/test_semantic_deduplication.py +0 -0
  46. {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/tests/test_sft_cli.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mysphinx-forge
3
- Version: 0.2.1
3
+ Version: 0.2.2
4
4
  Summary: Data and model workflow toolkit for cleaning, clustering, generation, and evaluation
5
5
  Keywords: data-cleaning,deduplication,clustering,nlp,cli
6
6
  Classifier: Development Status :: 3 - Alpha
@@ -166,7 +166,25 @@ uv run python main.py ...
166
166
 
167
167
  除了命令行参数,也可以在当前目录放一个 `mysphinx-forge.yaml`,或通过 `--config <path>` 显式指定配置文件。
168
168
 
169
- 仓库根目录已经提供了一份带完整注释的 [mysphinx-forge.yaml](./mysphinx-forge.yaml),可以直接按需修改。
169
+ 在你的项目目录下运行以下命令,可以生成一份带完整注释的配置模版:
170
+
171
+ ```bash
172
+ mysphinx-forge init
173
+ ```
174
+
175
+ 如需生成到指定路径:
176
+
177
+ ```bash
178
+ mysphinx-forge init -o configs/mysphinx-forge.yaml
179
+ ```
180
+
181
+ 如需覆盖已有文件:
182
+
183
+ ```bash
184
+ mysphinx-forge init --force
185
+ ```
186
+
187
+ 生成的模版包含所有参数的默认值和说明,修改 `action` 和 `input_file` 后即可直接运行。
170
188
 
171
189
  优先级规则:
172
190
 
@@ -254,6 +272,19 @@ cp .env.example .env
254
272
 
255
273
  ## 常用命令
256
274
 
275
+ ### 生成配置模版
276
+
277
+ ```bash
278
+ mysphinx-forge init
279
+ ```
280
+
281
+ 在当前目录生成 `mysphinx-forge.yaml` 配置模版,包含所有参数的默认值和注释说明。修改 `action` 和 `input_file` 后即可直接运行 `mysphinx-forge`,无需每次手动输入命令行参数。
282
+
283
+ ```bash
284
+ mysphinx-forge init -o configs/mysphinx-forge.yaml # 指定输出路径
285
+ mysphinx-forge init --force # 覆盖已有文件
286
+ ```
287
+
257
288
  ### 数据清洗
258
289
 
259
290
  `clean` 会删除目标列中的以下行:
@@ -449,13 +480,19 @@ data/input_deduplicated_split_train_pa.jsonl
449
480
  mysphinx-forge --action convert-sft --sft-format pa --input-file data/input.xlsx --sft-system-prompt "你是证券领域用户意图识别专家。"
450
481
  ```
451
482
 
452
- **自动切分**:当转换结果超过 10000 条时,自动按 10000 条一份切分为多个 JSONL 文件,文件名末尾追加序号。例如输入文件 `input_deduplicated_split_train.xlsx` 包含 12000 条数据,输出为:
483
+ **自动切分**:当转换结果超过阈值(默认 10000)时,自动按阈值切分为多个 JSONL 文件,文件名末尾追加序号。例如输入文件 `input_deduplicated_split_train.xlsx` 包含 12000 条数据,输出为:
453
484
 
454
485
  ```
455
486
  input_deduplicated_split_train_pa_1.jsonl # 前 10000 条
456
487
  input_deduplicated_split_train_pa_2.jsonl # 后 2000 条
457
488
  ```
458
489
 
490
+ 通过 `--sft-pa-max-records-per-file` 可自定义阈值:
491
+
492
+ ```bash
493
+ mysphinx-forge --action convert-sft --sft-format pa --input-file data/input.xlsx --sft-pa-max-records-per-file 5000
494
+ ```
495
+
459
496
  说明:
460
497
 
461
498
  - 支持 `alpaca`(默认)和 `pa` 两种格式,通过 `--sft-format` 切换
@@ -508,11 +545,15 @@ OPENAI_API_KEY=... mysphinx-forge --action cluster --input-file data/input_dedup
508
545
  mysphinx-forge --action split --input-file data/input_deduplicated.xlsx
509
546
  ```
510
547
 
511
- 如果输入是 Excel,多 sheet 会默认合并;其中名字精确等于 `increment` 的 sheet 会被视为增量数据:
548
+ 如果输入是 Excel,多 sheet 会默认合并。以下三个特殊 sheet 名(大小写不敏感)用于注入数据,不参与比例切分:
549
+
550
+ | sheet 名 | 注入目标 |
551
+ |----------|----------|
552
+ | `train` | 全量追加到训练集 |
553
+ | `valid` | 全量追加到验证集 |
554
+ | `test` | 全量追加到测试集 |
512
555
 
513
- - `increment` 不参与 train / valid / test 比例切分
514
- - `increment` 会全量追加到 train 和 valid
515
- - `increment` 不会进入 test
556
+ 三者均为可选,可以同时存在,也可以只有其中一个或多个。
516
557
 
517
558
  显式分层切分:
518
559
 
@@ -800,7 +841,7 @@ mysphinx-forge --action convert-sft --input-file data/raw_deduplicated_split_val
800
841
 
801
842
  - `convert-sft` 默认会自动探测输出列:`category`、`label`、`intent`、`output`、`response`、`answer`、`target`
802
843
  - 如果输入是 Excel,多 sheet 会默认合并
803
- - 如果存在名字精确等于 `increment` sheet,它在 `split` 时不会参与比例切分,而是会全量注入 `train` `valid`
844
+ - 如果存在名为 `train` / `valid` / `test` sheet(大小写不敏感),它们在 `split` 时不参与比例切分,而是分别全量注入对应的集合
804
845
 
805
846
  ### 场景 2:收到一份原始 Excel,但只有问题文本,没有标签
806
847
 
@@ -129,7 +129,25 @@ uv run python main.py ...
129
129
 
130
130
  除了命令行参数,也可以在当前目录放一个 `mysphinx-forge.yaml`,或通过 `--config <path>` 显式指定配置文件。
131
131
 
132
- 仓库根目录已经提供了一份带完整注释的 [mysphinx-forge.yaml](./mysphinx-forge.yaml),可以直接按需修改。
132
+ 在你的项目目录下运行以下命令,可以生成一份带完整注释的配置模版:
133
+
134
+ ```bash
135
+ mysphinx-forge init
136
+ ```
137
+
138
+ 如需生成到指定路径:
139
+
140
+ ```bash
141
+ mysphinx-forge init -o configs/mysphinx-forge.yaml
142
+ ```
143
+
144
+ 如需覆盖已有文件:
145
+
146
+ ```bash
147
+ mysphinx-forge init --force
148
+ ```
149
+
150
+ 生成的模版包含所有参数的默认值和说明,修改 `action` 和 `input_file` 后即可直接运行。
133
151
 
134
152
  优先级规则:
135
153
 
@@ -217,6 +235,19 @@ cp .env.example .env
217
235
 
218
236
  ## 常用命令
219
237
 
238
+ ### 生成配置模版
239
+
240
+ ```bash
241
+ mysphinx-forge init
242
+ ```
243
+
244
+ 在当前目录生成 `mysphinx-forge.yaml` 配置模版,包含所有参数的默认值和注释说明。修改 `action` 和 `input_file` 后即可直接运行 `mysphinx-forge`,无需每次手动输入命令行参数。
245
+
246
+ ```bash
247
+ mysphinx-forge init -o configs/mysphinx-forge.yaml # 指定输出路径
248
+ mysphinx-forge init --force # 覆盖已有文件
249
+ ```
250
+
220
251
  ### 数据清洗
221
252
 
222
253
  `clean` 会删除目标列中的以下行:
@@ -412,13 +443,19 @@ data/input_deduplicated_split_train_pa.jsonl
412
443
  mysphinx-forge --action convert-sft --sft-format pa --input-file data/input.xlsx --sft-system-prompt "你是证券领域用户意图识别专家。"
413
444
  ```
414
445
 
415
- **自动切分**:当转换结果超过 10000 条时,自动按 10000 条一份切分为多个 JSONL 文件,文件名末尾追加序号。例如输入文件 `input_deduplicated_split_train.xlsx` 包含 12000 条数据,输出为:
446
+ **自动切分**:当转换结果超过阈值(默认 10000)时,自动按阈值切分为多个 JSONL 文件,文件名末尾追加序号。例如输入文件 `input_deduplicated_split_train.xlsx` 包含 12000 条数据,输出为:
416
447
 
417
448
  ```
418
449
  input_deduplicated_split_train_pa_1.jsonl # 前 10000 条
419
450
  input_deduplicated_split_train_pa_2.jsonl # 后 2000 条
420
451
  ```
421
452
 
453
+ 通过 `--sft-pa-max-records-per-file` 可自定义阈值:
454
+
455
+ ```bash
456
+ mysphinx-forge --action convert-sft --sft-format pa --input-file data/input.xlsx --sft-pa-max-records-per-file 5000
457
+ ```
458
+
422
459
  说明:
423
460
 
424
461
  - 支持 `alpaca`(默认)和 `pa` 两种格式,通过 `--sft-format` 切换
@@ -471,11 +508,15 @@ OPENAI_API_KEY=... mysphinx-forge --action cluster --input-file data/input_dedup
471
508
  mysphinx-forge --action split --input-file data/input_deduplicated.xlsx
472
509
  ```
473
510
 
474
- 如果输入是 Excel,多 sheet 会默认合并;其中名字精确等于 `increment` 的 sheet 会被视为增量数据:
511
+ 如果输入是 Excel,多 sheet 会默认合并。以下三个特殊 sheet 名(大小写不敏感)用于注入数据,不参与比例切分:
512
+
513
+ | sheet 名 | 注入目标 |
514
+ |----------|----------|
515
+ | `train` | 全量追加到训练集 |
516
+ | `valid` | 全量追加到验证集 |
517
+ | `test` | 全量追加到测试集 |
475
518
 
476
- - `increment` 不参与 train / valid / test 比例切分
477
- - `increment` 会全量追加到 train 和 valid
478
- - `increment` 不会进入 test
519
+ 三者均为可选,可以同时存在,也可以只有其中一个或多个。
479
520
 
480
521
  显式分层切分:
481
522
 
@@ -763,7 +804,7 @@ mysphinx-forge --action convert-sft --input-file data/raw_deduplicated_split_val
763
804
 
764
805
  - `convert-sft` 默认会自动探测输出列:`category`、`label`、`intent`、`output`、`response`、`answer`、`target`
765
806
  - 如果输入是 Excel,多 sheet 会默认合并
766
- - 如果存在名字精确等于 `increment` sheet,它在 `split` 时不会参与比例切分,而是会全量注入 `train` `valid`
807
+ - 如果存在名为 `train` / `valid` / `test` sheet(大小写不敏感),它们在 `split` 时不参与比例切分,而是分别全量注入对应的集合
767
808
 
768
809
  ### 场景 2:收到一份原始 Excel,但只有问题文本,没有标签
769
810
 
@@ -31,7 +31,9 @@ from mysphinx_forge.config import (
31
31
  from mysphinx_forge.deduplication import DeduplicationStats, deduplicate_dataframe
32
32
  from mysphinx_forge.env_utils import load_project_env_files
33
33
  from mysphinx_forge.file_io import (
34
- INCREMENT_SHEET_NAME,
34
+ TEST_SHEET_NAME,
35
+ TRAIN_SHEET_NAME,
36
+ VALID_SHEET_NAME,
35
37
  append_dataframe_chunk,
36
38
  count_csv_rows,
37
39
  iter_dataframes,
@@ -61,6 +63,7 @@ from mysphinx_forge.semantic_deduplication import (
61
63
  )
62
64
  from mysphinx_forge.sft_dataset import (
63
65
  DEFAULT_SFT_FORMAT,
66
+ PA_MAX_RECORDS_PER_FILE,
64
67
  PA_SFT_FORMAT,
65
68
  SftConversionStats,
66
69
  convert_dataframe_to_alpaca,
@@ -89,6 +92,10 @@ _ACTION_CHOICES = [
89
92
  def main() -> int:
90
93
  load_project_env_files()
91
94
  raw_argv = sys.argv[1:]
95
+
96
+ if raw_argv and raw_argv[0] == "init":
97
+ return _run_init(raw_argv[1:])
98
+
92
99
  bootstrap_args, _ = _build_bootstrap_parser().parse_known_args(raw_argv)
93
100
 
94
101
  try:
@@ -305,12 +312,63 @@ def main() -> int:
305
312
  resolved_sft_system_prompt,
306
313
  args.sft_system_column,
307
314
  args.sft_user_query_as_instruction,
315
+ args.sft_pa_max_records_per_file,
308
316
  )
309
317
 
310
318
  parser.print_help()
311
319
  return 1
312
320
 
313
321
 
322
+ def _run_init(argv: list[str]) -> int:
323
+ init_parser = argparse.ArgumentParser(
324
+ prog="mysphinx-forge init",
325
+ description="在当前目录生成 mysphinx-forge.yaml 配置模版。",
326
+ )
327
+ init_parser.add_argument(
328
+ "--output",
329
+ "-o",
330
+ default=DEFAULT_CONFIG_FILE_NAME,
331
+ help=f"输出文件路径,默认为当前目录下的 {DEFAULT_CONFIG_FILE_NAME}。",
332
+ )
333
+ init_parser.add_argument(
334
+ "--force",
335
+ action="store_true",
336
+ help="若目标文件已存在,强制覆盖。",
337
+ )
338
+ args = init_parser.parse_args(argv)
339
+
340
+ output_path = Path(args.output)
341
+ if not output_path.is_absolute():
342
+ output_path = Path.cwd() / output_path
343
+
344
+ if output_path.exists() and not args.force:
345
+ print(f"文件已存在:{output_path}")
346
+ print("如需覆盖,请添加 --force 参数。")
347
+ return 1
348
+
349
+ try:
350
+ from importlib.resources import files as _pkg_files
351
+ template_text = (
352
+ _pkg_files("mysphinx_forge.templates")
353
+ .joinpath("mysphinx-forge.yaml")
354
+ .read_text(encoding="utf-8")
355
+ )
356
+ except Exception as exc:
357
+ print(f"读取配置模版失败:{exc}")
358
+ return 1
359
+
360
+ try:
361
+ output_path.parent.mkdir(parents=True, exist_ok=True)
362
+ output_path.write_text(template_text, encoding="utf-8")
363
+ except OSError as exc:
364
+ print(f"写出配置模版失败:{output_path},{type(exc).__name__}: {exc}")
365
+ return 1
366
+
367
+ print(f"已生成配置模版:{output_path}")
368
+ print("请编辑文件,将 action 和 input_file 替换为实际值后直接运行 mysphinx-forge。")
369
+ return 0
370
+
371
+
314
372
  def _build_bootstrap_parser() -> argparse.ArgumentParser:
315
373
  parser = argparse.ArgumentParser(add_help=False)
316
374
  parser.add_argument("--config", default="")
@@ -656,6 +714,13 @@ def _build_parser(
656
714
  default=config_defaults.get("sft_user_query_as_instruction", True),
657
715
  help="为 true 时将用户输入作为 alpaca instruction 字段,input 字段留空;为 false 时保持原有行为(input 存用户输入,instruction 为固定文本)。默认 true。",
658
716
  )
717
+ parser.add_argument(
718
+ "--sft-pa-max-records-per-file",
719
+ type=int,
720
+ dest="sft_pa_max_records_per_file",
721
+ default=config_defaults.get("sft_pa_max_records_per_file", PA_MAX_RECORDS_PER_FILE),
722
+ help=f"pa 格式每个 JSONL 文件最大记录数,超出时自动切分为多个文件,默认 {PA_MAX_RECORDS_PER_FILE}。",
723
+ )
659
724
  return parser
660
725
 
661
726
 
@@ -1366,9 +1431,8 @@ def _run_split(
1366
1431
 
1367
1432
  try:
1368
1433
  run_stage("读取文件", logger=logger)
1369
- dataframe, increment_dataframe = load_split_dataframes(
1434
+ dataframe, train_inject_df, valid_inject_df, test_inject_df = load_split_dataframes(
1370
1435
  input_file,
1371
- increment_sheet_name=INCREMENT_SHEET_NAME,
1372
1436
  )
1373
1437
  resolved_split_mode, resolved_stratify_column = resolve_auto_split_mode(
1374
1438
  dataframe,
@@ -1387,15 +1451,21 @@ def _run_split(
1387
1451
  time_column=time_column,
1388
1452
  time_ascending=time_ascending,
1389
1453
  )
1390
- if not increment_dataframe.empty:
1391
- train_df = pd.concat([train_df, increment_dataframe], ignore_index=True)
1392
- validation_df = pd.concat([validation_df, increment_dataframe], ignore_index=True)
1393
- stats.total_rows += len(increment_dataframe)
1454
+ if not train_inject_df.empty:
1455
+ train_df = pd.concat([train_df, train_inject_df], ignore_index=True)
1456
+ stats.inject_train_rows = len(train_inject_df)
1457
+ if not valid_inject_df.empty:
1458
+ validation_df = pd.concat([validation_df, valid_inject_df], ignore_index=True)
1459
+ stats.inject_valid_rows = len(valid_inject_df)
1460
+ if not test_inject_df.empty:
1461
+ test_df = pd.concat([test_df, test_inject_df], ignore_index=True)
1462
+ stats.inject_test_rows = len(test_inject_df)
1463
+ total_inject = stats.inject_train_rows + stats.inject_valid_rows + stats.inject_test_rows
1464
+ if total_inject > 0:
1465
+ stats.total_rows += total_inject
1394
1466
  stats.train_rows = len(train_df)
1395
1467
  stats.validation_rows = len(validation_df)
1396
1468
  stats.test_rows = len(test_df)
1397
- stats.increment_rows = len(increment_dataframe)
1398
- stats.increment_sheet_name = INCREMENT_SHEET_NAME
1399
1469
  except ValueError as exc:
1400
1470
  _emit_error(str(exc), logger)
1401
1471
  close_logger()
@@ -1405,8 +1475,15 @@ def _run_split(
1405
1475
  write_dataframe(train_df, train_output_path)
1406
1476
  run_stage("写出 valid", logger=logger)
1407
1477
  write_dataframe(validation_df, validation_output_path)
1408
- run_stage("写出 test", logger=logger)
1409
- write_dataframe(test_df, test_output_path)
1478
+ if test_ratio > 0:
1479
+ run_stage("写出 test", logger=logger)
1480
+ write_dataframe(test_df, test_output_path)
1481
+ extra_output_files: dict[str, Path] = {
1482
+ "train_file": train_output_path,
1483
+ "validation_file": validation_output_path,
1484
+ }
1485
+ if test_ratio > 0:
1486
+ extra_output_files["test_file"] = test_output_path
1410
1487
  _write_meta(
1411
1488
  output_path=base_output_path,
1412
1489
  action="split",
@@ -1421,21 +1498,18 @@ def _run_split(
1421
1498
  "group_column": group_column,
1422
1499
  "time_column": time_column,
1423
1500
  "time_order": "asc" if time_ascending else "desc",
1424
- "increment_sheet_name": stats.increment_sheet_name,
1425
- "increment_rows": stats.increment_rows,
1501
+ "inject_train_rows": stats.inject_train_rows,
1502
+ "inject_valid_rows": stats.inject_valid_rows,
1503
+ "inject_test_rows": stats.inject_test_rows,
1426
1504
  },
1427
1505
  split_stats=stats,
1428
- extra_output_files={
1429
- "train_file": train_output_path,
1430
- "validation_file": validation_output_path,
1431
- "test_file": test_output_path,
1432
- },
1506
+ extra_output_files=extra_output_files,
1433
1507
  )
1434
1508
  _print_split_stats(
1435
1509
  stats,
1436
1510
  train_output_path=train_output_path,
1437
1511
  validation_output_path=validation_output_path,
1438
- test_output_path=test_output_path,
1512
+ test_output_path=test_output_path if test_ratio > 0 else None,
1439
1513
  logger=logger,
1440
1514
  )
1441
1515
  close_logger()
@@ -1744,6 +1818,7 @@ def _run_convert_sft(
1744
1818
  sft_system_prompt: str,
1745
1819
  sft_system_column: str,
1746
1820
  sft_user_query_as_instruction: bool = True,
1821
+ sft_pa_max_records_per_file: int = PA_MAX_RECORDS_PER_FILE,
1747
1822
  ) -> int:
1748
1823
  input_path = Path(input_file)
1749
1824
  output_path = _resolve_sft_output_path(input_path, output_arg, sft_format)
@@ -1764,6 +1839,7 @@ def _run_convert_sft(
1764
1839
  dataframe,
1765
1840
  target_column=target_column,
1766
1841
  output_column=sft_output_column,
1842
+ instruction=sft_instruction,
1767
1843
  system_prompt=sft_system_prompt,
1768
1844
  system_column=sft_system_column,
1769
1845
  )
@@ -1786,7 +1862,9 @@ def _run_convert_sft(
1786
1862
 
1787
1863
  run_stage("写出结果", logger=logger)
1788
1864
  if sft_format == PA_SFT_FORMAT:
1789
- written_paths = write_pa_dataset(records, output_path)
1865
+ written_paths = write_pa_dataset(
1866
+ records, output_path, max_records_per_file=sft_pa_max_records_per_file
1867
+ )
1790
1868
  else:
1791
1869
  write_alpaca_dataset(records, output_path)
1792
1870
  written_paths = [output_path]
@@ -1802,6 +1880,7 @@ def _run_convert_sft(
1802
1880
  "sft_system_prompt": sft_system_prompt,
1803
1881
  "sft_system_column": sft_system_column,
1804
1882
  "sft_user_query_as_instruction": sft_user_query_as_instruction,
1883
+ "sft_pa_max_records_per_file": sft_pa_max_records_per_file,
1805
1884
  },
1806
1885
  sft_conversion_stats=stats,
1807
1886
  extra_output_files={f"output_file_{i + 1}": p for i, p in enumerate(written_paths)}
@@ -2004,7 +2083,7 @@ def _print_split_stats(
2004
2083
  *,
2005
2084
  train_output_path: Path,
2006
2085
  validation_output_path: Path,
2007
- test_output_path: Path,
2086
+ test_output_path: Path | None,
2008
2087
  logger: Logger,
2009
2088
  ) -> None:
2010
2089
  _emit_message("数据切分完成", logger)
@@ -2016,17 +2095,22 @@ def _print_split_stats(
2016
2095
  if stats.time_column:
2017
2096
  _emit_message(f"时间列:{stats.time_column}", logger)
2018
2097
  _emit_message(f"时间顺序:{'asc' if stats.time_ascending else 'desc'}", logger)
2019
- if stats.increment_rows > 0:
2020
- _emit_message(f"增量工作表:{stats.increment_sheet_name}", logger)
2021
- _emit_message(f"增量注入行数:{stats.increment_rows}", logger)
2098
+ if stats.inject_train_rows > 0:
2099
+ _emit_message(f"注入训练集行数({TRAIN_SHEET_NAME!r} sheet):{stats.inject_train_rows}", logger)
2100
+ if stats.inject_valid_rows > 0:
2101
+ _emit_message(f"注入验证集行数({VALID_SHEET_NAME!r} sheet):{stats.inject_valid_rows}", logger)
2102
+ if stats.inject_test_rows > 0:
2103
+ _emit_message(f"注入测试集行数({TEST_SHEET_NAME!r} sheet):{stats.inject_test_rows}", logger)
2022
2104
  _emit_message(f"随机种子:{stats.random_seed}", logger)
2023
2105
  _emit_message(f"总行数:{stats.total_rows}", logger)
2024
2106
  _emit_message(f"训练集行数:{stats.train_rows}", logger)
2025
2107
  _emit_message(f"验证集行数:{stats.validation_rows}", logger)
2026
- _emit_message(f"测试集行数:{stats.test_rows}", logger)
2108
+ if test_output_path is not None:
2109
+ _emit_message(f"测试集行数:{stats.test_rows}", logger)
2027
2110
  _emit_message(f"训练集文件:{train_output_path}", logger)
2028
2111
  _emit_message(f"验证集文件:{validation_output_path}", logger)
2029
- _emit_message(f"测试集文件:{test_output_path}", logger)
2112
+ if test_output_path is not None:
2113
+ _emit_message(f"测试集文件:{test_output_path}", logger)
2030
2114
 
2031
2115
 
2032
2116
  def _print_sft_conversion_stats(
@@ -2134,8 +2218,9 @@ def _write_meta(
2134
2218
  "group_column": split_stats.group_column,
2135
2219
  "time_column": split_stats.time_column,
2136
2220
  "time_ascending": split_stats.time_ascending,
2137
- "increment_rows": split_stats.increment_rows,
2138
- "increment_sheet_name": split_stats.increment_sheet_name,
2221
+ "inject_train_rows": split_stats.inject_train_rows,
2222
+ "inject_valid_rows": split_stats.inject_valid_rows,
2223
+ "inject_test_rows": split_stats.inject_test_rows,
2139
2224
  }
2140
2225
  if sft_conversion_stats is not None:
2141
2226
  meta["sft_conversion_stats"] = {
@@ -14,7 +14,10 @@ if TYPE_CHECKING:
14
14
 
15
15
 
16
16
  SUPPORTED_EXTENSIONS = {".csv", ".xls", ".xlsx", ".xlsm"}
17
- INCREMENT_SHEET_NAME = "increment"
17
+ TRAIN_SHEET_NAME = "train"
18
+ VALID_SHEET_NAME = "valid"
19
+ TEST_SHEET_NAME = "test"
20
+ _INJECT_SHEET_NAMES = {TRAIN_SHEET_NAME, VALID_SHEET_NAME, TEST_SHEET_NAME}
18
21
  DEFAULT_PROGRESS_COLOURS = [
19
22
  "red",
20
23
  "green",
@@ -44,28 +47,40 @@ def load_dataframe(file_path: str | Path) -> pd.DataFrame:
44
47
 
45
48
  def load_split_dataframes(
46
49
  file_path: str | Path,
47
- *,
48
- increment_sheet_name: str = INCREMENT_SHEET_NAME,
49
- ) -> tuple[pd.DataFrame, pd.DataFrame]:
50
+ ) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
50
51
  path = validate_tabular_file(file_path)
52
+ empty = pd.DataFrame()
51
53
  if path.suffix.lower() == ".csv":
52
54
  dataframe = pd.read_csv(path, skip_blank_lines=False)
53
- return dataframe, dataframe.head(0).copy()
55
+ return dataframe, empty, empty, empty
54
56
 
55
57
  regular_frames: list[pd.DataFrame] = []
56
- increment_frames: list[pd.DataFrame] = []
58
+ train_frames: list[pd.DataFrame] = []
59
+ valid_frames: list[pd.DataFrame] = []
60
+ test_frames: list[pd.DataFrame] = []
57
61
  for sheet_name, dataframe in _load_excel_sheets(path).items():
58
- if sheet_name == increment_sheet_name:
59
- increment_frames.append(dataframe)
62
+ lower = sheet_name.lower()
63
+ if lower == TRAIN_SHEET_NAME:
64
+ train_frames.append(dataframe)
65
+ elif lower == VALID_SHEET_NAME:
66
+ valid_frames.append(dataframe)
67
+ elif lower == TEST_SHEET_NAME:
68
+ test_frames.append(dataframe)
60
69
  else:
61
70
  regular_frames.append(dataframe)
62
71
 
63
72
  if not regular_frames:
73
+ inject_names = ", ".join(f"{n!r}" for n in sorted(_INJECT_SHEET_NAMES))
64
74
  raise ValueError(
65
- f"除工作表 {increment_sheet_name!r} 外没有可切分的 Excel sheet。"
75
+ f"除注入工作表({inject_names})外没有可切分的 Excel sheet。"
66
76
  )
67
77
 
68
- return _concat_excel_frames(regular_frames), _concat_excel_frames(increment_frames)
78
+ return (
79
+ _concat_excel_frames(regular_frames),
80
+ _concat_excel_frames(train_frames),
81
+ _concat_excel_frames(valid_frames),
82
+ _concat_excel_frames(test_frames),
83
+ )
69
84
 
70
85
 
71
86
  def iter_dataframes(file_path: str | Path, chunksize: int = 50_000) -> Iterable[pd.DataFrame]:
@@ -114,6 +114,7 @@ def convert_dataframe_to_pa(
114
114
  *,
115
115
  target_column: str = "text",
116
116
  output_column: str = "",
117
+ instruction: str = "",
117
118
  system_prompt: str = "",
118
119
  system_column: str = "",
119
120
  ) -> tuple[list[dict], SftConversionStats]:
@@ -121,6 +122,7 @@ def convert_dataframe_to_pa(
121
122
  resolved_output_column = resolve_sft_output_column(dataframe, output_column)
122
123
  resolved_system_column = _resolve_optional_column(dataframe, system_column)
123
124
 
125
+ fixed_instruction = instruction.strip()
124
126
  final_system_prompt = system_prompt.strip()
125
127
  records: list[dict] = []
126
128
  skipped_blank_input_rows = 0
@@ -146,7 +148,8 @@ def convert_dataframe_to_pa(
146
148
  if system_text:
147
149
  conversations.append({"context": system_text, "role": "system"})
148
150
 
149
- conversations.append({"context": input_text, "role": "human"})
151
+ human_text = f"{fixed_instruction}\n{input_text}" if fixed_instruction else input_text
152
+ conversations.append({"context": human_text, "role": "human"})
150
153
  conversations.append({"context": output_text, "role": "assistant"})
151
154
 
152
155
  records.append({"conversations": conversations, "id": str(len(records) + 1)})
@@ -163,14 +166,21 @@ def convert_dataframe_to_pa(
163
166
  return records, stats
164
167
 
165
168
 
166
- def write_pa_dataset(records: list[dict], output_path: str | Path) -> list[Path]:
167
- """Write PA-format records as one or more JSONL files split at PA_MAX_RECORDS_PER_FILE.
169
+ def write_pa_dataset(
170
+ records: list[dict],
171
+ output_path: str | Path,
172
+ *,
173
+ max_records_per_file: int = PA_MAX_RECORDS_PER_FILE,
174
+ ) -> list[Path]:
175
+ """Write PA-format records as one or more JSONL files.
168
176
 
169
- Returns the list of paths written.
177
+ When len(records) <= max_records_per_file a single file is written.
178
+ Otherwise the records are split into numbered chunks, e.g. *_1.jsonl,
179
+ *_2.jsonl, ... Returns the list of paths written.
170
180
  """
171
181
  output_path = Path(output_path)
172
182
  total = len(records)
173
- if total <= PA_MAX_RECORDS_PER_FILE:
183
+ if total <= max_records_per_file:
174
184
  _write_pa_jsonl(records, output_path)
175
185
  return [output_path]
176
186
 
@@ -179,8 +189,8 @@ def write_pa_dataset(records: list[dict], output_path: str | Path) -> list[Path]
179
189
  parent = output_path.parent
180
190
  written: list[Path] = []
181
191
  chunk_index = 1
182
- for start in range(0, total, PA_MAX_RECORDS_PER_FILE):
183
- chunk = records[start : start + PA_MAX_RECORDS_PER_FILE]
192
+ for start in range(0, total, max_records_per_file):
193
+ chunk = records[start : start + max_records_per_file]
184
194
  chunk_path = parent / f"{stem}_{chunk_index}{suffix}"
185
195
  _write_pa_jsonl(chunk, chunk_path)
186
196
  written.append(chunk_path)
@@ -24,8 +24,9 @@ class SplitStats:
24
24
  group_column: str | None = None
25
25
  time_column: str | None = None
26
26
  time_ascending: bool = True
27
- increment_rows: int = 0
28
- increment_sheet_name: str | None = None
27
+ inject_train_rows: int = 0
28
+ inject_valid_rows: int = 0
29
+ inject_test_rows: int = 0
29
30
 
30
31
 
31
32
  def split_dataframe(