mysphinx-forge 0.2.2__tar.gz → 0.2.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/PKG-INFO +5 -3
  2. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/README.md +4 -2
  3. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/mysphinx_forge/cli.py +44 -18
  4. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/mysphinx_forge/file_io.py +27 -0
  5. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/mysphinx_forge/templates/mysphinx-forge.yaml +4 -0
  6. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/mysphinx_forge.egg-info/PKG-INFO +5 -3
  7. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/pyproject.toml +1 -1
  8. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/tests/test_cli.py +201 -14
  9. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/mysphinx_forge/__init__.py +0 -0
  10. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/mysphinx_forge/cleaning.py +0 -0
  11. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/mysphinx_forge/cluster_labeling.py +0 -0
  12. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/mysphinx_forge/cluster_reporting.py +0 -0
  13. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/mysphinx_forge/clustering.py +0 -0
  14. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/mysphinx_forge/config.py +0 -0
  15. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/mysphinx_forge/deduplication.py +0 -0
  16. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/mysphinx_forge/embedding.py +0 -0
  17. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/mysphinx_forge/env_utils.py +0 -0
  18. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/mysphinx_forge/http_client.py +0 -0
  19. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/mysphinx_forge/logging_utils.py +0 -0
  20. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/mysphinx_forge/model_eval.py +0 -0
  21. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/mysphinx_forge/model_testing.py +0 -0
  22. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/mysphinx_forge/openai_responses.py +0 -0
  23. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/mysphinx_forge/progress.py +0 -0
  24. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/mysphinx_forge/semantic_deduplication.py +0 -0
  25. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/mysphinx_forge/sft_dataset.py +0 -0
  26. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/mysphinx_forge/splitting.py +0 -0
  27. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/mysphinx_forge/templates/__init__.py +0 -0
  28. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/mysphinx_forge.egg-info/SOURCES.txt +0 -0
  29. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/mysphinx_forge.egg-info/dependency_links.txt +0 -0
  30. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/mysphinx_forge.egg-info/entry_points.txt +0 -0
  31. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/mysphinx_forge.egg-info/requires.txt +0 -0
  32. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/mysphinx_forge.egg-info/top_level.txt +0 -0
  33. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/setup.cfg +0 -0
  34. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/tests/test_cleaning.py +0 -0
  35. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/tests/test_cluster_labeling.py +0 -0
  36. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/tests/test_cluster_reporting.py +0 -0
  37. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/tests/test_clustering.py +0 -0
  38. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/tests/test_deduplication.py +0 -0
  39. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/tests/test_file_io.py +0 -0
  40. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/tests/test_http_client.py +0 -0
  41. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/tests/test_model_eval.py +0 -0
  42. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/tests/test_model_testing.py +0 -0
  43. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/tests/test_semantic_deduplication.py +0 -0
  44. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/tests/test_sft_cli.py +0 -0
  45. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/tests/test_sft_dataset.py +0 -0
  46. {mysphinx_forge-0.2.2 → mysphinx_forge-0.2.3}/tests/test_splitting.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mysphinx-forge
3
- Version: 0.2.2
3
+ Version: 0.2.3
4
4
  Summary: Data and model workflow toolkit for cleaning, clustering, generation, and evaluation
5
5
  Keywords: data-cleaning,deduplication,clustering,nlp,cli
6
6
  Classifier: Development Status :: 3 - Alpha
@@ -487,10 +487,10 @@ input_deduplicated_split_train_pa_1.jsonl # 前 10000 条
487
487
  input_deduplicated_split_train_pa_2.jsonl # 后 2000 条
488
488
  ```
489
489
 
490
- 通过 `--sft-pa-max-records-per-file` 可自定义阈值:
490
+ 通过 `--sft-max-records-per-file` 可自定义阈值:
491
491
 
492
492
  ```bash
493
- mysphinx-forge --action convert-sft --sft-format pa --input-file data/input.xlsx --sft-pa-max-records-per-file 5000
493
+ mysphinx-forge --action convert-sft --sft-format pa --input-file data/input.xlsx --sft-max-records-per-file 5000
494
494
  ```
495
495
 
496
496
  说明:
@@ -555,6 +555,8 @@ mysphinx-forge --action split --input-file data/input_deduplicated.xlsx
555
555
 
556
556
  三者均为可选,可以同时存在,也可以只有其中一个或多个。
557
557
 
558
+ `clean`、`deduplicate`、`clean-deduplicate`、`cluster` 这几个 `split` 之前的步骤会原样保留这三个特殊 sheet(不参与清洗/去重/聚类处理),并在输出文件中继续以独立 sheet 的形式存在,确保依次执行整条流水线后,`split` 仍能正确识别并注入这些数据。
559
+
558
560
  显式分层切分:
559
561
 
560
562
  ```bash
@@ -450,10 +450,10 @@ input_deduplicated_split_train_pa_1.jsonl # 前 10000 条
450
450
  input_deduplicated_split_train_pa_2.jsonl # 后 2000 条
451
451
  ```
452
452
 
453
- 通过 `--sft-pa-max-records-per-file` 可自定义阈值:
453
+ 通过 `--sft-max-records-per-file` 可自定义阈值:
454
454
 
455
455
  ```bash
456
- mysphinx-forge --action convert-sft --sft-format pa --input-file data/input.xlsx --sft-pa-max-records-per-file 5000
456
+ mysphinx-forge --action convert-sft --sft-format pa --input-file data/input.xlsx --sft-max-records-per-file 5000
457
457
  ```
458
458
 
459
459
  说明:
@@ -518,6 +518,8 @@ mysphinx-forge --action split --input-file data/input_deduplicated.xlsx
518
518
 
519
519
  三者均为可选,可以同时存在,也可以只有其中一个或多个。
520
520
 
521
+ `clean`、`deduplicate`、`clean-deduplicate`、`cluster` 这几个 `split` 之前的步骤会原样保留这三个特殊 sheet(不参与清洗/去重/聚类处理),并在输出文件中继续以独立 sheet 的形式存在,确保依次执行整条流水线后,`split` 仍能正确识别并注入这些数据。
522
+
521
523
  显式分层切分:
522
524
 
523
525
  ```bash
@@ -40,6 +40,7 @@ from mysphinx_forge.file_io import (
40
40
  load_dataframe,
41
41
  load_split_dataframes,
42
42
  write_dataframe,
43
+ write_dataframe_with_injection_sheets,
43
44
  write_match_rows,
44
45
  )
45
46
  from mysphinx_forge import __version__
@@ -312,7 +313,7 @@ def main() -> int:
312
313
  resolved_sft_system_prompt,
313
314
  args.sft_system_column,
314
315
  args.sft_user_query_as_instruction,
315
- args.sft_pa_max_records_per_file,
316
+ args.sft_max_records_per_file,
316
317
  )
317
318
 
318
319
  parser.print_help()
@@ -715,10 +716,10 @@ def _build_parser(
715
716
  help="为 true 时将用户输入作为 alpaca instruction 字段,input 字段留空;为 false 时保持原有行为(input 存用户输入,instruction 为固定文本)。默认 true。",
716
717
  )
717
718
  parser.add_argument(
718
- "--sft-pa-max-records-per-file",
719
+ "--sft-max-records-per-file",
719
720
  type=int,
720
- dest="sft_pa_max_records_per_file",
721
- default=config_defaults.get("sft_pa_max_records_per_file", PA_MAX_RECORDS_PER_FILE),
721
+ dest="sft_max_records_per_file",
722
+ default=config_defaults.get("sft_max_records_per_file", PA_MAX_RECORDS_PER_FILE),
722
723
  help=f"pa 格式每个 JSONL 文件最大记录数,超出时自动切分为多个文件,默认 {PA_MAX_RECORDS_PER_FILE}。",
723
724
  )
724
725
  return parser
@@ -758,7 +759,7 @@ def _run_clean(
758
759
 
759
760
  try:
760
761
  run_stage("读取文件", logger=logger)
761
- dataframe = load_dataframe(input_file)
762
+ dataframe, train_inject_df, valid_inject_df, test_inject_df = load_split_dataframes(input_file)
762
763
  except ValueError as exc:
763
764
  _emit_error(str(exc), logger)
764
765
  close_logger()
@@ -784,7 +785,13 @@ def _run_clean(
784
785
  progress_bar.close()
785
786
 
786
787
  run_stage("写出结果", logger=logger)
787
- write_dataframe(cleaned, output_path)
788
+ write_dataframe_with_injection_sheets(
789
+ cleaned,
790
+ output_path,
791
+ train_inject=train_inject_df,
792
+ valid_inject=valid_inject_df,
793
+ test_inject=test_inject_df,
794
+ )
788
795
  _write_meta(
789
796
  output_path=output_path,
790
797
  action="clean",
@@ -872,7 +879,7 @@ def _run_deduplicate(
872
879
 
873
880
  try:
874
881
  run_stage("读取文件", logger=logger)
875
- dataframe = load_dataframe(input_file)
882
+ dataframe, train_inject_df, valid_inject_df, test_inject_df = load_split_dataframes(input_file)
876
883
  progress_bar = ProgressBar(total=len(dataframe), description="执行去重", logger=logger)
877
884
  try:
878
885
  deduplicated, stats, match_rows = _deduplicate_dataframe(
@@ -904,7 +911,13 @@ def _run_deduplicate(
904
911
  return 1
905
912
 
906
913
  run_stage("写出结果", logger=logger)
907
- write_dataframe(deduplicated, output_path)
914
+ write_dataframe_with_injection_sheets(
915
+ deduplicated,
916
+ output_path,
917
+ train_inject=train_inject_df,
918
+ valid_inject=valid_inject_df,
919
+ test_inject=test_inject_df,
920
+ )
908
921
  write_match_rows(
909
922
  match_rows,
910
923
  _resolve_match_output_path(output_path),
@@ -979,7 +992,7 @@ def _run_clean_deduplicate(
979
992
 
980
993
  try:
981
994
  run_stage("读取文件", logger=logger)
982
- dataframe = load_dataframe(input_file)
995
+ dataframe, train_inject_df, valid_inject_df, test_inject_df = load_split_dataframes(input_file)
983
996
  clean_bar = ProgressBar(total=len(dataframe), description="清洗数据", logger=logger)
984
997
  try:
985
998
  cleaned, clean_stats = clean_dataframe(
@@ -1030,7 +1043,13 @@ def _run_clean_deduplicate(
1030
1043
  return 1
1031
1044
 
1032
1045
  run_stage("写出结果", logger=logger)
1033
- write_dataframe(deduplicated, output_path)
1046
+ write_dataframe_with_injection_sheets(
1047
+ deduplicated,
1048
+ output_path,
1049
+ train_inject=train_inject_df,
1050
+ valid_inject=valid_inject_df,
1051
+ test_inject=test_inject_df,
1052
+ )
1034
1053
  write_match_rows(
1035
1054
  match_rows,
1036
1055
  _resolve_match_output_path(output_path),
@@ -1103,7 +1122,7 @@ def _run_cluster(
1103
1122
 
1104
1123
  try:
1105
1124
  run_stage("读取文件", logger=logger)
1106
- dataframe = load_dataframe(input_file)
1125
+ dataframe, train_inject_df, valid_inject_df, test_inject_df = load_split_dataframes(input_file)
1107
1126
  progress_bar = ProgressBar(total=len(dataframe), description="执行聚类", logger=logger)
1108
1127
  try:
1109
1128
  clustered, cluster_summary, projection, stats = clusterer.cluster_dataframe(
@@ -1133,7 +1152,13 @@ def _run_cluster(
1133
1152
 
1134
1153
  run_stage("写出结果", logger=logger)
1135
1154
  analysis_report = build_cluster_analysis_report(cluster_summary, stats)
1136
- write_dataframe(clustered, output_path)
1155
+ write_dataframe_with_injection_sheets(
1156
+ clustered,
1157
+ output_path,
1158
+ train_inject=train_inject_df,
1159
+ valid_inject=valid_inject_df,
1160
+ test_inject=test_inject_df,
1161
+ )
1137
1162
  write_dataframe(cluster_summary, cluster_summary_path)
1138
1163
  write_dataframe(projection, projection_path)
1139
1164
  write_dataframe(analysis_report, analysis_path)
@@ -1471,18 +1496,19 @@ def _run_split(
1471
1496
  close_logger()
1472
1497
  return 1
1473
1498
 
1499
+ has_test_output = test_ratio > 0 or stats.inject_test_rows > 0
1474
1500
  run_stage("写出 train", logger=logger)
1475
1501
  write_dataframe(train_df, train_output_path)
1476
1502
  run_stage("写出 valid", logger=logger)
1477
1503
  write_dataframe(validation_df, validation_output_path)
1478
- if test_ratio > 0:
1504
+ if has_test_output:
1479
1505
  run_stage("写出 test", logger=logger)
1480
1506
  write_dataframe(test_df, test_output_path)
1481
1507
  extra_output_files: dict[str, Path] = {
1482
1508
  "train_file": train_output_path,
1483
1509
  "validation_file": validation_output_path,
1484
1510
  }
1485
- if test_ratio > 0:
1511
+ if has_test_output:
1486
1512
  extra_output_files["test_file"] = test_output_path
1487
1513
  _write_meta(
1488
1514
  output_path=base_output_path,
@@ -1509,7 +1535,7 @@ def _run_split(
1509
1535
  stats,
1510
1536
  train_output_path=train_output_path,
1511
1537
  validation_output_path=validation_output_path,
1512
- test_output_path=test_output_path if test_ratio > 0 else None,
1538
+ test_output_path=test_output_path if has_test_output else None,
1513
1539
  logger=logger,
1514
1540
  )
1515
1541
  close_logger()
@@ -1818,7 +1844,7 @@ def _run_convert_sft(
1818
1844
  sft_system_prompt: str,
1819
1845
  sft_system_column: str,
1820
1846
  sft_user_query_as_instruction: bool = True,
1821
- sft_pa_max_records_per_file: int = PA_MAX_RECORDS_PER_FILE,
1847
+ sft_max_records_per_file: int = PA_MAX_RECORDS_PER_FILE,
1822
1848
  ) -> int:
1823
1849
  input_path = Path(input_file)
1824
1850
  output_path = _resolve_sft_output_path(input_path, output_arg, sft_format)
@@ -1863,7 +1889,7 @@ def _run_convert_sft(
1863
1889
  run_stage("写出结果", logger=logger)
1864
1890
  if sft_format == PA_SFT_FORMAT:
1865
1891
  written_paths = write_pa_dataset(
1866
- records, output_path, max_records_per_file=sft_pa_max_records_per_file
1892
+ records, output_path, max_records_per_file=sft_max_records_per_file
1867
1893
  )
1868
1894
  else:
1869
1895
  write_alpaca_dataset(records, output_path)
@@ -1880,7 +1906,7 @@ def _run_convert_sft(
1880
1906
  "sft_system_prompt": sft_system_prompt,
1881
1907
  "sft_system_column": sft_system_column,
1882
1908
  "sft_user_query_as_instruction": sft_user_query_as_instruction,
1883
- "sft_pa_max_records_per_file": sft_pa_max_records_per_file,
1909
+ "sft_max_records_per_file": sft_max_records_per_file,
1884
1910
  },
1885
1911
  sft_conversion_stats=stats,
1886
1912
  extra_output_files={f"output_file_{i + 1}": p for i, p in enumerate(written_paths)}
@@ -129,6 +129,33 @@ def write_dataframe(dataframe: pd.DataFrame, output_path: str | Path) -> None:
129
129
  dataframe.to_excel(path, index=False)
130
130
 
131
131
 
132
+ def write_dataframe_with_injection_sheets(
133
+ dataframe: pd.DataFrame,
134
+ output_path: str | Path,
135
+ *,
136
+ train_inject: pd.DataFrame | None = None,
137
+ valid_inject: pd.DataFrame | None = None,
138
+ test_inject: pd.DataFrame | None = None,
139
+ ) -> None:
140
+ path = Path(output_path)
141
+ if path.suffix.lower() == ".csv":
142
+ dataframe.to_csv(path, index=False)
143
+ return
144
+
145
+ sheets = {"Sheet1": dataframe}
146
+ for sheet_name, frame in (
147
+ (TRAIN_SHEET_NAME, train_inject),
148
+ (VALID_SHEET_NAME, valid_inject),
149
+ (TEST_SHEET_NAME, test_inject),
150
+ ):
151
+ if frame is not None and not frame.empty:
152
+ sheets[sheet_name] = frame
153
+
154
+ with pd.ExcelWriter(path) as writer:
155
+ for sheet_name, frame in sheets.items():
156
+ frame.to_excel(writer, sheet_name=sheet_name, index=False)
157
+
158
+
132
159
  def write_progress_message(message: str, *, stream: TextIO | None = None) -> None:
133
160
  output_stream = stream or sys.stderr
134
161
  with tqdm.external_write_mode(file=output_stream):
@@ -409,3 +409,7 @@ convert-sft:
409
409
  # 按行读取的 system 列。
410
410
  # 若提供且该行非空,则优先使用该列值覆盖 sft_system_prompt。
411
411
  sft_system_column: ""
412
+
413
+ # pa 格式每个 JSONL 文件最大记录数,超出时自动切分为多个文件。
414
+ # 仅 sft_format=pa 时生效。
415
+ sft_max_records_per_file: 10000
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mysphinx-forge
3
- Version: 0.2.2
3
+ Version: 0.2.3
4
4
  Summary: Data and model workflow toolkit for cleaning, clustering, generation, and evaluation
5
5
  Keywords: data-cleaning,deduplication,clustering,nlp,cli
6
6
  Classifier: Development Status :: 3 - Alpha
@@ -487,10 +487,10 @@ input_deduplicated_split_train_pa_1.jsonl # 前 10000 条
487
487
  input_deduplicated_split_train_pa_2.jsonl # 后 2000 条
488
488
  ```
489
489
 
490
- 通过 `--sft-pa-max-records-per-file` 可自定义阈值:
490
+ 通过 `--sft-max-records-per-file` 可自定义阈值:
491
491
 
492
492
  ```bash
493
- mysphinx-forge --action convert-sft --sft-format pa --input-file data/input.xlsx --sft-pa-max-records-per-file 5000
493
+ mysphinx-forge --action convert-sft --sft-format pa --input-file data/input.xlsx --sft-max-records-per-file 5000
494
494
  ```
495
495
 
496
496
  说明:
@@ -555,6 +555,8 @@ mysphinx-forge --action split --input-file data/input_deduplicated.xlsx
555
555
 
556
556
  三者均为可选,可以同时存在,也可以只有其中一个或多个。
557
557
 
558
+ `clean`、`deduplicate`、`clean-deduplicate`、`cluster` 这几个 `split` 之前的步骤会原样保留这三个特殊 sheet(不参与清洗/去重/聚类处理),并在输出文件中继续以独立 sheet 的形式存在,确保依次执行整条流水线后,`split` 仍能正确识别并注入这些数据。
559
+
558
560
  显式分层切分:
559
561
 
560
562
  ```bash
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "mysphinx-forge"
7
- version = "0.2.2"
7
+ version = "0.2.3"
8
8
  description = "Data and model workflow toolkit for cleaning, clustering, generation, and evaluation"
9
9
  readme = "README.md"
10
10
  requires-python = ">=3.12"
@@ -387,7 +387,7 @@ def test_main_cli_overrides_config_values(tmp_path, monkeypatch, capsys) -> None
387
387
  assert meta["parameters"]["test_ratio"] == 0.1
388
388
 
389
389
 
390
- def test_main_split_injects_increment_sheet_into_train_and_valid(
390
+ def test_main_split_injects_train_valid_test_sheets_exclusively(
391
391
  tmp_path, monkeypatch, capsys
392
392
  ) -> None:
393
393
  input_file = tmp_path / "input.xlsx"
@@ -400,10 +400,22 @@ def test_main_split_injects_increment_sheet_into_train_and_valid(
400
400
  ).to_excel(writer, sheet_name="base_a", index=False)
401
401
  pd.DataFrame(
402
402
  {
403
- "text": ["增量问题1", "增量问题2"],
404
- "category": ["增量", "增量"],
403
+ "text": ["训练注入1"],
404
+ "category": ["增量"],
405
405
  }
406
- ).to_excel(writer, sheet_name="increment", index=False)
406
+ ).to_excel(writer, sheet_name="Train", index=False)
407
+ pd.DataFrame(
408
+ {
409
+ "text": ["验证注入1"],
410
+ "category": ["增量"],
411
+ }
412
+ ).to_excel(writer, sheet_name="valid", index=False)
413
+ pd.DataFrame(
414
+ {
415
+ "text": ["测试注入1"],
416
+ "category": ["增量"],
417
+ }
418
+ ).to_excel(writer, sheet_name="TEST", index=False)
407
419
 
408
420
  monkeypatch.setattr(
409
421
  sys,
@@ -425,23 +437,198 @@ def test_main_split_injects_increment_sheet_into_train_and_valid(
425
437
  captured = capsys.readouterr()
426
438
 
427
439
  assert exit_code == 0
428
- assert "增量工作表:increment" in captured.out
429
- assert "增量注入行数:2" in captured.out
440
+ assert "注入训练集行数('train' sheet):1" in captured.out
441
+ assert "注入验证集行数('valid' sheet):1" in captured.out
442
+ assert "注入测试集行数('test' sheet):1" in captured.out
430
443
 
431
444
  train = pd.read_excel(tmp_path / "input_split_train.xlsx")
432
445
  valid = pd.read_excel(tmp_path / "input_split_valid.xlsx")
433
446
  test = pd.read_excel(tmp_path / "input_split_test.xlsx")
434
447
 
435
- assert set(train["text"].tolist()) >= {"增量问题1", "增量问题2"}
436
- assert set(valid["text"].tolist()) >= {"增量问题1", "增量问题2"}
437
- assert "增量问题1" not in test["text"].tolist()
438
- assert "增量问题2" not in test["text"].tolist()
448
+ assert "训练注入1" in train["text"].tolist()
449
+ assert "训练注入1" not in valid["text"].tolist()
450
+ assert "训练注入1" not in test["text"].tolist()
451
+
452
+ assert "验证注入1" in valid["text"].tolist()
453
+ assert "验证注入1" not in train["text"].tolist()
454
+ assert "验证注入1" not in test["text"].tolist()
455
+
456
+ assert "测试注入1" in test["text"].tolist()
457
+ assert "测试注入1" not in train["text"].tolist()
458
+ assert "测试注入1" not in valid["text"].tolist()
459
+
460
+ meta = json.loads((tmp_path / "input_split.meta.json").read_text(encoding="utf-8"))
461
+ assert meta["parameters"]["inject_train_rows"] == 1
462
+ assert meta["parameters"]["inject_valid_rows"] == 1
463
+ assert meta["parameters"]["inject_test_rows"] == 1
464
+ assert meta["split_stats"]["inject_train_rows"] == 1
465
+ assert meta["split_stats"]["inject_valid_rows"] == 1
466
+ assert meta["split_stats"]["inject_test_rows"] == 1
467
+
468
+
469
+ def test_main_split_writes_test_file_when_test_ratio_zero_but_test_sheet_injected(
470
+ tmp_path, monkeypatch, capsys
471
+ ) -> None:
472
+ input_file = tmp_path / "input.xlsx"
473
+ with pd.ExcelWriter(input_file) as writer:
474
+ pd.DataFrame(
475
+ {
476
+ "text": [f"问题{i}" for i in range(6)],
477
+ "category": ["基金"] * 3 + ["股票"] * 3,
478
+ }
479
+ ).to_excel(writer, sheet_name="base_a", index=False)
480
+ pd.DataFrame(
481
+ {
482
+ "text": ["测试注入1"],
483
+ "category": ["增量"],
484
+ }
485
+ ).to_excel(writer, sheet_name="test", index=False)
486
+
487
+ monkeypatch.setattr(
488
+ sys,
489
+ "argv",
490
+ [
491
+ "main.py",
492
+ "--action",
493
+ "split",
494
+ "--input-file",
495
+ str(input_file),
496
+ "--validation-ratio",
497
+ "0.2",
498
+ "--test-ratio",
499
+ "0",
500
+ ],
501
+ )
502
+
503
+ exit_code = main()
504
+ captured = capsys.readouterr()
505
+
506
+ assert exit_code == 0
507
+ assert "注入测试集行数('test' sheet):1" in captured.out
508
+
509
+ test_output_path = tmp_path / "input_split_test.xlsx"
510
+ assert test_output_path.exists()
511
+ test = pd.read_excel(test_output_path)
512
+ assert "测试注入1" in test["text"].tolist()
439
513
 
440
514
  meta = json.loads((tmp_path / "input_split.meta.json").read_text(encoding="utf-8"))
441
- assert meta["parameters"]["increment_sheet_name"] == "increment"
442
- assert meta["parameters"]["increment_rows"] == 2
443
- assert meta["split_stats"]["increment_sheet_name"] == "increment"
444
- assert meta["split_stats"]["increment_rows"] == 2
515
+ assert meta["parameters"]["inject_test_rows"] == 1
516
+ assert "test_file" in meta["output_files"]
517
+
518
+
519
+ def test_main_clean_preserves_injection_sheets_for_downstream_split(
520
+ tmp_path, monkeypatch, capsys
521
+ ) -> None:
522
+ input_file = tmp_path / "input.xlsx"
523
+ with pd.ExcelWriter(input_file) as writer:
524
+ pd.DataFrame(
525
+ {
526
+ "text": [f"问题{i}" for i in range(6)] + ["!!!"],
527
+ "category": ["基金"] * 3 + ["股票"] * 3 + ["噪音"],
528
+ }
529
+ ).to_excel(writer, sheet_name="base_a", index=False)
530
+ pd.DataFrame({"text": ["训练注入1"], "category": ["增量"]}).to_excel(
531
+ writer, sheet_name="train", index=False
532
+ )
533
+ pd.DataFrame({"text": ["测试注入1"], "category": ["增量"]}).to_excel(
534
+ writer, sheet_name="test", index=False
535
+ )
536
+
537
+ monkeypatch.setattr(
538
+ sys,
539
+ "argv",
540
+ ["main.py", "--action", "clean", "--input-file", str(input_file)],
541
+ )
542
+
543
+ exit_code = main()
544
+ captured = capsys.readouterr()
545
+ assert exit_code == 0
546
+ assert "清洗后总行数:6" in captured.out
547
+
548
+ output_file = tmp_path / "input_cleaned.xlsx"
549
+ sheets = pd.read_excel(output_file, sheet_name=None)
550
+ assert "!!!" not in sheets["Sheet1"]["text"].tolist()
551
+ assert sheets["train"]["text"].tolist() == ["训练注入1"]
552
+ assert sheets["test"]["text"].tolist() == ["测试注入1"]
553
+
554
+
555
+ def test_main_deduplicate_preserves_injection_sheets_for_downstream_split(
556
+ tmp_path, monkeypatch, capsys
557
+ ) -> None:
558
+ input_file = tmp_path / "input.xlsx"
559
+ with pd.ExcelWriter(input_file) as writer:
560
+ pd.DataFrame(
561
+ {
562
+ "text": ["问题1", "问题1", "问题2"],
563
+ "category": ["基金", "基金", "股票"],
564
+ }
565
+ ).to_excel(writer, sheet_name="base_a", index=False)
566
+ pd.DataFrame({"text": ["测试注入1"], "category": ["增量"]}).to_excel(
567
+ writer, sheet_name="test", index=False
568
+ )
569
+
570
+ monkeypatch.setattr(
571
+ sys,
572
+ "argv",
573
+ ["main.py", "--action", "deduplicate", "--input-file", str(input_file)],
574
+ )
575
+
576
+ exit_code = main()
577
+ captured = capsys.readouterr()
578
+ assert exit_code == 0
579
+
580
+ output_file = tmp_path / "input_deduplicated.xlsx"
581
+ sheets = pd.read_excel(output_file, sheet_name=None)
582
+ assert sorted(sheets["Sheet1"]["text"].tolist()) == ["问题1", "问题2"]
583
+ assert sheets["test"]["text"].tolist() == ["测试注入1"]
584
+
585
+
586
+ def test_main_split_after_clean_includes_test_sheet_rows_in_test_output(
587
+ tmp_path, monkeypatch, capsys
588
+ ) -> None:
589
+ input_file = tmp_path / "input.xlsx"
590
+ with pd.ExcelWriter(input_file) as writer:
591
+ pd.DataFrame(
592
+ {
593
+ "text": [f"问题{i}" for i in range(8)],
594
+ "category": ["基金"] * 4 + ["股票"] * 4,
595
+ }
596
+ ).to_excel(writer, sheet_name="base_a", index=False)
597
+ pd.DataFrame({"text": ["测试注入1"], "category": ["增量"]}).to_excel(
598
+ writer, sheet_name="test", index=False
599
+ )
600
+
601
+ monkeypatch.setattr(
602
+ sys,
603
+ "argv",
604
+ ["main.py", "--action", "clean", "--input-file", str(input_file)],
605
+ )
606
+ assert main() == 0
607
+ capsys.readouterr()
608
+
609
+ cleaned_file = tmp_path / "input_cleaned.xlsx"
610
+ monkeypatch.setattr(
611
+ sys,
612
+ "argv",
613
+ [
614
+ "main.py",
615
+ "--action",
616
+ "split",
617
+ "--input-file",
618
+ str(cleaned_file),
619
+ "--validation-ratio",
620
+ "0.2",
621
+ "--test-ratio",
622
+ "0.2",
623
+ ],
624
+ )
625
+ exit_code = main()
626
+ captured = capsys.readouterr()
627
+
628
+ assert exit_code == 0
629
+ assert "注入测试集行数('test' sheet):1" in captured.out
630
+ test_output = pd.read_excel(tmp_path / "input_cleaned_split_test.xlsx")
631
+ assert "测试注入1" in test_output["text"].tolist()
445
632
 
446
633
 
447
634
  def test_main_split_rejects_missing_group_column(tmp_path, monkeypatch, capsys) -> None:
File without changes