mysphinx-forge 0.2.1__tar.gz → 0.2.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/PKG-INFO +49 -8
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/README.md +48 -7
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/cli.py +113 -28
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/file_io.py +25 -10
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/sft_dataset.py +17 -7
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/splitting.py +3 -2
- mysphinx_forge-0.2.2/mysphinx_forge/templates/__init__.py +0 -0
- mysphinx_forge-0.2.2/mysphinx_forge/templates/mysphinx-forge.yaml +411 -0
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge.egg-info/PKG-INFO +49 -8
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge.egg-info/SOURCES.txt +2 -0
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/pyproject.toml +4 -1
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/tests/test_cli.py +54 -0
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/tests/test_file_io.py +38 -19
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/tests/test_sft_dataset.py +18 -5
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/tests/test_splitting.py +17 -0
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/__init__.py +0 -0
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/cleaning.py +0 -0
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/cluster_labeling.py +0 -0
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/cluster_reporting.py +0 -0
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/clustering.py +0 -0
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/config.py +0 -0
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/deduplication.py +0 -0
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/embedding.py +0 -0
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/env_utils.py +0 -0
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/http_client.py +0 -0
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/logging_utils.py +0 -0
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/model_eval.py +0 -0
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/model_testing.py +0 -0
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/openai_responses.py +0 -0
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/progress.py +0 -0
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge/semantic_deduplication.py +0 -0
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge.egg-info/dependency_links.txt +0 -0
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge.egg-info/entry_points.txt +0 -0
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge.egg-info/requires.txt +0 -0
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/mysphinx_forge.egg-info/top_level.txt +0 -0
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/setup.cfg +0 -0
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/tests/test_cleaning.py +0 -0
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/tests/test_cluster_labeling.py +0 -0
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/tests/test_cluster_reporting.py +0 -0
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/tests/test_clustering.py +0 -0
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/tests/test_deduplication.py +0 -0
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/tests/test_http_client.py +0 -0
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/tests/test_model_eval.py +0 -0
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/tests/test_model_testing.py +0 -0
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/tests/test_semantic_deduplication.py +0 -0
- {mysphinx_forge-0.2.1 → mysphinx_forge-0.2.2}/tests/test_sft_cli.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: mysphinx-forge
|
|
3
|
-
Version: 0.2.1
|
|
3
|
+
Version: 0.2.2
|
|
4
4
|
Summary: Data and model workflow toolkit for cleaning, clustering, generation, and evaluation
|
|
5
5
|
Keywords: data-cleaning,deduplication,clustering,nlp,cli
|
|
6
6
|
Classifier: Development Status :: 3 - Alpha
|
|
@@ -166,7 +166,25 @@ uv run python main.py ...
|
|
|
166
166
|
|
|
167
167
|
除了命令行参数,也可以在当前目录放一个 `mysphinx-forge.yaml`,或通过 `--config <path>` 显式指定配置文件。
|
|
168
168
|
|
|
169
|
-
|
|
169
|
+
在你的项目目录下运行以下命令,可以生成一份带完整注释的配置模版:
|
|
170
|
+
|
|
171
|
+
```bash
|
|
172
|
+
mysphinx-forge init
|
|
173
|
+
```
|
|
174
|
+
|
|
175
|
+
如需生成到指定路径:
|
|
176
|
+
|
|
177
|
+
```bash
|
|
178
|
+
mysphinx-forge init -o configs/mysphinx-forge.yaml
|
|
179
|
+
```
|
|
180
|
+
|
|
181
|
+
如需覆盖已有文件:
|
|
182
|
+
|
|
183
|
+
```bash
|
|
184
|
+
mysphinx-forge init --force
|
|
185
|
+
```
|
|
186
|
+
|
|
187
|
+
生成的模版包含所有参数的默认值和说明,修改 `action` 和 `input_file` 后即可直接运行。
|
|
170
188
|
|
|
171
189
|
优先级规则:
|
|
172
190
|
|
|
@@ -254,6 +272,19 @@ cp .env.example .env
|
|
|
254
272
|
|
|
255
273
|
## 常用命令
|
|
256
274
|
|
|
275
|
+
### 生成配置模版
|
|
276
|
+
|
|
277
|
+
```bash
|
|
278
|
+
mysphinx-forge init
|
|
279
|
+
```
|
|
280
|
+
|
|
281
|
+
在当前目录生成 `mysphinx-forge.yaml` 配置模版,包含所有参数的默认值和注释说明。修改 `action` 和 `input_file` 后即可直接运行 `mysphinx-forge`,无需每次手动输入命令行参数。
|
|
282
|
+
|
|
283
|
+
```bash
|
|
284
|
+
mysphinx-forge init -o configs/mysphinx-forge.yaml # 指定输出路径
|
|
285
|
+
mysphinx-forge init --force # 覆盖已有文件
|
|
286
|
+
```
|
|
287
|
+
|
|
257
288
|
### 数据清洗
|
|
258
289
|
|
|
259
290
|
`clean` 会删除目标列中的以下行:
|
|
@@ -449,13 +480,19 @@ data/input_deduplicated_split_train_pa.jsonl
|
|
|
449
480
|
mysphinx-forge --action convert-sft --sft-format pa --input-file data/input.xlsx --sft-system-prompt "你是证券领域用户意图识别专家。"
|
|
450
481
|
```
|
|
451
482
|
|
|
452
|
-
|
|
483
|
+
**自动切分**:当转换结果超过阈值(默认 10000)时,自动按阈值切分为多个 JSONL 文件,文件名末尾追加序号。例如输入文件 `input_deduplicated_split_train.xlsx` 包含 12000 条数据,输出为:
|
|
453
484
|
|
|
454
485
|
```
|
|
455
486
|
input_deduplicated_split_train_pa_1.jsonl # 前 10000 条
|
|
456
487
|
input_deduplicated_split_train_pa_2.jsonl # 后 2000 条
|
|
457
488
|
```
|
|
458
489
|
|
|
490
|
+
通过 `--sft-pa-max-records-per-file` 可自定义阈值(需为正整数):
|
|
491
|
+
|
|
492
|
+
```bash
|
|
493
|
+
mysphinx-forge --action convert-sft --sft-format pa --input-file data/input.xlsx --sft-pa-max-records-per-file 5000
|
|
494
|
+
```
|
|
495
|
+
|
|
459
496
|
说明:
|
|
460
497
|
|
|
461
498
|
- 支持 `alpaca`(默认)和 `pa` 两种格式,通过 `--sft-format` 切换
|
|
@@ -508,11 +545,15 @@ OPENAI_API_KEY=... mysphinx-forge --action cluster --input-file data/input_dedup
|
|
|
508
545
|
mysphinx-forge --action split --input-file data/input_deduplicated.xlsx
|
|
509
546
|
```
|
|
510
547
|
|
|
511
|
-
如果输入是 Excel,多 sheet
|
|
548
|
+
如果输入是 Excel,多 sheet 会默认合并。以下三个特殊 sheet 名(大小写不敏感)用于注入数据,不参与比例切分:
|
|
549
|
+
|
|
550
|
+
| sheet 名 | 注入目标 |
|
|
551
|
+
|----------|----------|
|
|
552
|
+
| `train` | 全量追加到训练集 |
|
|
553
|
+
| `valid` | 全量追加到验证集 |
|
|
554
|
+
| `test` | 全量追加到测试集 |
|
|
512
555
|
|
|
513
|
-
|
|
514
|
-
- `increment` 会全量追加到 train 和 valid
|
|
515
|
-
- `increment` 不会进入 test
|
|
556
|
+
三者均为可选,可以任意组合:只提供其中一个、两个,或三个同时提供。注意:只有测试集比例(test ratio)大于 0 时才会写出测试集文件,否则 `test` sheet 中的数据不会被写出。
|
|
516
557
|
|
|
517
558
|
显式分层切分:
|
|
518
559
|
|
|
@@ -800,7 +841,7 @@ mysphinx-forge --action convert-sft --input-file data/raw_deduplicated_split_val
|
|
|
800
841
|
|
|
801
842
|
- `convert-sft` 默认会自动探测输出列:`category`、`label`、`intent`、`output`、`response`、`answer`、`target`
|
|
802
843
|
- 如果输入是 Excel,多 sheet 会默认合并
|
|
803
|
-
-
|
|
844
|
+
- 如果存在名为 `train` / `valid` / `test` 的 sheet(大小写不敏感),它们在 `split` 时不参与比例切分,而是分别全量注入对应的集合
|
|
804
845
|
|
|
805
846
|
### 场景 2:收到一份原始 Excel,但只有问题文本,没有标签
|
|
806
847
|
|
|
@@ -129,7 +129,25 @@ uv run python main.py ...
|
|
|
129
129
|
|
|
130
130
|
除了命令行参数,也可以在当前目录放一个 `mysphinx-forge.yaml`,或通过 `--config <path>` 显式指定配置文件。
|
|
131
131
|
|
|
132
|
-
|
|
132
|
+
在你的项目目录下运行以下命令,可以生成一份带完整注释的配置模版:
|
|
133
|
+
|
|
134
|
+
```bash
|
|
135
|
+
mysphinx-forge init
|
|
136
|
+
```
|
|
137
|
+
|
|
138
|
+
如需生成到指定路径:
|
|
139
|
+
|
|
140
|
+
```bash
|
|
141
|
+
mysphinx-forge init -o configs/mysphinx-forge.yaml
|
|
142
|
+
```
|
|
143
|
+
|
|
144
|
+
如需覆盖已有文件:
|
|
145
|
+
|
|
146
|
+
```bash
|
|
147
|
+
mysphinx-forge init --force
|
|
148
|
+
```
|
|
149
|
+
|
|
150
|
+
生成的模版包含所有参数的默认值和说明,修改 `action` 和 `input_file` 后即可直接运行。
|
|
133
151
|
|
|
134
152
|
优先级规则:
|
|
135
153
|
|
|
@@ -217,6 +235,19 @@ cp .env.example .env
|
|
|
217
235
|
|
|
218
236
|
## 常用命令
|
|
219
237
|
|
|
238
|
+
### 生成配置模版
|
|
239
|
+
|
|
240
|
+
```bash
|
|
241
|
+
mysphinx-forge init
|
|
242
|
+
```
|
|
243
|
+
|
|
244
|
+
在当前目录生成 `mysphinx-forge.yaml` 配置模版,包含所有参数的默认值和注释说明。修改 `action` 和 `input_file` 后即可直接运行 `mysphinx-forge`,无需每次手动输入命令行参数。
|
|
245
|
+
|
|
246
|
+
```bash
|
|
247
|
+
mysphinx-forge init -o configs/mysphinx-forge.yaml # 指定输出路径
|
|
248
|
+
mysphinx-forge init --force # 覆盖已有文件
|
|
249
|
+
```
|
|
250
|
+
|
|
220
251
|
### 数据清洗
|
|
221
252
|
|
|
222
253
|
`clean` 会删除目标列中的以下行:
|
|
@@ -412,13 +443,19 @@ data/input_deduplicated_split_train_pa.jsonl
|
|
|
412
443
|
mysphinx-forge --action convert-sft --sft-format pa --input-file data/input.xlsx --sft-system-prompt "你是证券领域用户意图识别专家。"
|
|
413
444
|
```
|
|
414
445
|
|
|
415
|
-
|
|
446
|
+
**自动切分**:当转换结果超过阈值(默认 10000)时,自动按阈值切分为多个 JSONL 文件,文件名末尾追加序号。例如输入文件 `input_deduplicated_split_train.xlsx` 包含 12000 条数据,输出为:
|
|
416
447
|
|
|
417
448
|
```
|
|
418
449
|
input_deduplicated_split_train_pa_1.jsonl # 前 10000 条
|
|
419
450
|
input_deduplicated_split_train_pa_2.jsonl # 后 2000 条
|
|
420
451
|
```
|
|
421
452
|
|
|
453
|
+
通过 `--sft-pa-max-records-per-file` 可自定义阈值(需为正整数):
|
|
454
|
+
|
|
455
|
+
```bash
|
|
456
|
+
mysphinx-forge --action convert-sft --sft-format pa --input-file data/input.xlsx --sft-pa-max-records-per-file 5000
|
|
457
|
+
```
|
|
458
|
+
|
|
422
459
|
说明:
|
|
423
460
|
|
|
424
461
|
- 支持 `alpaca`(默认)和 `pa` 两种格式,通过 `--sft-format` 切换
|
|
@@ -471,11 +508,15 @@ OPENAI_API_KEY=... mysphinx-forge --action cluster --input-file data/input_dedup
|
|
|
471
508
|
mysphinx-forge --action split --input-file data/input_deduplicated.xlsx
|
|
472
509
|
```
|
|
473
510
|
|
|
474
|
-
如果输入是 Excel,多 sheet
|
|
511
|
+
如果输入是 Excel,多 sheet 会默认合并。以下三个特殊 sheet 名(大小写不敏感)用于注入数据,不参与比例切分:
|
|
512
|
+
|
|
513
|
+
| sheet 名 | 注入目标 |
|
|
514
|
+
|----------|----------|
|
|
515
|
+
| `train` | 全量追加到训练集 |
|
|
516
|
+
| `valid` | 全量追加到验证集 |
|
|
517
|
+
| `test` | 全量追加到测试集 |
|
|
475
518
|
|
|
476
|
-
|
|
477
|
-
- `increment` 会全量追加到 train 和 valid
|
|
478
|
-
- `increment` 不会进入 test
|
|
519
|
+
三者均为可选,可以任意组合:只提供其中一个、两个,或三个同时提供。注意:只有测试集比例(test ratio)大于 0 时才会写出测试集文件,否则 `test` sheet 中的数据不会被写出。
|
|
479
520
|
|
|
480
521
|
显式分层切分:
|
|
481
522
|
|
|
@@ -763,7 +804,7 @@ mysphinx-forge --action convert-sft --input-file data/raw_deduplicated_split_val
|
|
|
763
804
|
|
|
764
805
|
- `convert-sft` 默认会自动探测输出列:`category`、`label`、`intent`、`output`、`response`、`answer`、`target`
|
|
765
806
|
- 如果输入是 Excel,多 sheet 会默认合并
|
|
766
|
-
-
|
|
807
|
+
- 如果存在名为 `train` / `valid` / `test` 的 sheet(大小写不敏感),它们在 `split` 时不参与比例切分,而是分别全量注入对应的集合
|
|
767
808
|
|
|
768
809
|
### 场景 2:收到一份原始 Excel,但只有问题文本,没有标签
|
|
769
810
|
|
|
@@ -31,7 +31,9 @@ from mysphinx_forge.config import (
|
|
|
31
31
|
from mysphinx_forge.deduplication import DeduplicationStats, deduplicate_dataframe
|
|
32
32
|
from mysphinx_forge.env_utils import load_project_env_files
|
|
33
33
|
from mysphinx_forge.file_io import (
|
|
34
|
-
|
|
34
|
+
TEST_SHEET_NAME,
|
|
35
|
+
TRAIN_SHEET_NAME,
|
|
36
|
+
VALID_SHEET_NAME,
|
|
35
37
|
append_dataframe_chunk,
|
|
36
38
|
count_csv_rows,
|
|
37
39
|
iter_dataframes,
|
|
@@ -61,6 +63,7 @@ from mysphinx_forge.semantic_deduplication import (
|
|
|
61
63
|
)
|
|
62
64
|
from mysphinx_forge.sft_dataset import (
|
|
63
65
|
DEFAULT_SFT_FORMAT,
|
|
66
|
+
PA_MAX_RECORDS_PER_FILE,
|
|
64
67
|
PA_SFT_FORMAT,
|
|
65
68
|
SftConversionStats,
|
|
66
69
|
convert_dataframe_to_alpaca,
|
|
@@ -89,6 +92,10 @@ _ACTION_CHOICES = [
|
|
|
89
92
|
def main() -> int:
|
|
90
93
|
load_project_env_files()
|
|
91
94
|
raw_argv = sys.argv[1:]
|
|
95
|
+
|
|
96
|
+
if raw_argv and raw_argv[0] == "init":
|
|
97
|
+
return _run_init(raw_argv[1:])
|
|
98
|
+
|
|
92
99
|
bootstrap_args, _ = _build_bootstrap_parser().parse_known_args(raw_argv)
|
|
93
100
|
|
|
94
101
|
try:
|
|
@@ -305,12 +312,63 @@ def main() -> int:
|
|
|
305
312
|
resolved_sft_system_prompt,
|
|
306
313
|
args.sft_system_column,
|
|
307
314
|
args.sft_user_query_as_instruction,
|
|
315
|
+
args.sft_pa_max_records_per_file,
|
|
308
316
|
)
|
|
309
317
|
|
|
310
318
|
parser.print_help()
|
|
311
319
|
return 1
|
|
312
320
|
|
|
313
321
|
|
|
322
|
+
def _run_init(argv: list[str]) -> int:
|
|
323
|
+
init_parser = argparse.ArgumentParser(
|
|
324
|
+
prog="mysphinx-forge init",
|
|
325
|
+
description="在当前目录生成 mysphinx-forge.yaml 配置模版。",
|
|
326
|
+
)
|
|
327
|
+
init_parser.add_argument(
|
|
328
|
+
"--output",
|
|
329
|
+
"-o",
|
|
330
|
+
default=DEFAULT_CONFIG_FILE_NAME,
|
|
331
|
+
help=f"输出文件路径,默认为当前目录下的 {DEFAULT_CONFIG_FILE_NAME}。",
|
|
332
|
+
)
|
|
333
|
+
init_parser.add_argument(
|
|
334
|
+
"--force",
|
|
335
|
+
action="store_true",
|
|
336
|
+
help="若目标文件已存在,强制覆盖。",
|
|
337
|
+
)
|
|
338
|
+
args = init_parser.parse_args(argv)
|
|
339
|
+
|
|
340
|
+
output_path = Path(args.output)
|
|
341
|
+
if not output_path.is_absolute():
|
|
342
|
+
output_path = Path.cwd() / output_path
|
|
343
|
+
|
|
344
|
+
if output_path.exists() and not args.force:
|
|
345
|
+
print(f"文件已存在:{output_path}")
|
|
346
|
+
print("如需覆盖,请添加 --force 参数。")
|
|
347
|
+
return 1
|
|
348
|
+
|
|
349
|
+
try:
|
|
350
|
+
from importlib.resources import files as _pkg_files
|
|
351
|
+
template_text = (
|
|
352
|
+
_pkg_files("mysphinx_forge.templates")
|
|
353
|
+
.joinpath("mysphinx-forge.yaml")
|
|
354
|
+
.read_text(encoding="utf-8")
|
|
355
|
+
)
|
|
356
|
+
except Exception as exc:
|
|
357
|
+
print(f"读取配置模版失败:{exc}")
|
|
358
|
+
return 1
|
|
359
|
+
|
|
360
|
+
try:
|
|
361
|
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
|
362
|
+
output_path.write_text(template_text, encoding="utf-8")
|
|
363
|
+
except OSError as exc:
|
|
364
|
+
print(f"写出配置模版失败:{output_path},{type(exc).__name__}: {exc}")
|
|
365
|
+
return 1
|
|
366
|
+
|
|
367
|
+
print(f"已生成配置模版:{output_path}")
|
|
368
|
+
print("请编辑文件,将 action 和 input_file 替换为实际值后直接运行 mysphinx-forge。")
|
|
369
|
+
return 0
|
|
370
|
+
|
|
371
|
+
|
|
314
372
|
def _build_bootstrap_parser() -> argparse.ArgumentParser:
|
|
315
373
|
parser = argparse.ArgumentParser(add_help=False)
|
|
316
374
|
parser.add_argument("--config", default="")
|
|
@@ -656,6 +714,13 @@ def _build_parser(
|
|
|
656
714
|
default=config_defaults.get("sft_user_query_as_instruction", True),
|
|
657
715
|
help="为 true 时将用户输入作为 alpaca instruction 字段,input 字段留空;为 false 时保持原有行为(input 存用户输入,instruction 为固定文本)。默认 true。",
|
|
658
716
|
)
|
|
717
|
+
parser.add_argument(
|
|
718
|
+
"--sft-pa-max-records-per-file",
|
|
719
|
+
type=int,
|
|
720
|
+
dest="sft_pa_max_records_per_file",
|
|
721
|
+
default=config_defaults.get("sft_pa_max_records_per_file", PA_MAX_RECORDS_PER_FILE),
|
|
722
|
+
help=f"pa 格式每个 JSONL 文件最大记录数,超出时自动切分为多个文件,默认 {PA_MAX_RECORDS_PER_FILE}。",
|
|
723
|
+
)
|
|
659
724
|
return parser
|
|
660
725
|
|
|
661
726
|
|
|
@@ -1366,9 +1431,8 @@ def _run_split(
|
|
|
1366
1431
|
|
|
1367
1432
|
try:
|
|
1368
1433
|
run_stage("读取文件", logger=logger)
|
|
1369
|
-
dataframe,
|
|
1434
|
+
dataframe, train_inject_df, valid_inject_df, test_inject_df = load_split_dataframes(
|
|
1370
1435
|
input_file,
|
|
1371
|
-
increment_sheet_name=INCREMENT_SHEET_NAME,
|
|
1372
1436
|
)
|
|
1373
1437
|
resolved_split_mode, resolved_stratify_column = resolve_auto_split_mode(
|
|
1374
1438
|
dataframe,
|
|
@@ -1387,15 +1451,21 @@ def _run_split(
|
|
|
1387
1451
|
time_column=time_column,
|
|
1388
1452
|
time_ascending=time_ascending,
|
|
1389
1453
|
)
|
|
1390
|
-
if not
|
|
1391
|
-
train_df = pd.concat([train_df,
|
|
1392
|
-
|
|
1393
|
-
|
|
1454
|
+
if not train_inject_df.empty:
|
|
1455
|
+
train_df = pd.concat([train_df, train_inject_df], ignore_index=True)
|
|
1456
|
+
stats.inject_train_rows = len(train_inject_df)
|
|
1457
|
+
if not valid_inject_df.empty:
|
|
1458
|
+
validation_df = pd.concat([validation_df, valid_inject_df], ignore_index=True)
|
|
1459
|
+
stats.inject_valid_rows = len(valid_inject_df)
|
|
1460
|
+
if not test_inject_df.empty:
|
|
1461
|
+
test_df = pd.concat([test_df, test_inject_df], ignore_index=True)
|
|
1462
|
+
stats.inject_test_rows = len(test_inject_df)
|
|
1463
|
+
total_inject = stats.inject_train_rows + stats.inject_valid_rows + stats.inject_test_rows
|
|
1464
|
+
if total_inject > 0:
|
|
1465
|
+
stats.total_rows += total_inject
|
|
1394
1466
|
stats.train_rows = len(train_df)
|
|
1395
1467
|
stats.validation_rows = len(validation_df)
|
|
1396
1468
|
stats.test_rows = len(test_df)
|
|
1397
|
-
stats.increment_rows = len(increment_dataframe)
|
|
1398
|
-
stats.increment_sheet_name = INCREMENT_SHEET_NAME
|
|
1399
1469
|
except ValueError as exc:
|
|
1400
1470
|
_emit_error(str(exc), logger)
|
|
1401
1471
|
close_logger()
|
|
@@ -1405,8 +1475,15 @@ def _run_split(
|
|
|
1405
1475
|
write_dataframe(train_df, train_output_path)
|
|
1406
1476
|
run_stage("写出 valid", logger=logger)
|
|
1407
1477
|
write_dataframe(validation_df, validation_output_path)
|
|
1408
|
-
|
|
1409
|
-
|
|
1478
|
+
if test_ratio > 0:
|
|
1479
|
+
run_stage("写出 test", logger=logger)
|
|
1480
|
+
write_dataframe(test_df, test_output_path)
|
|
1481
|
+
extra_output_files: dict[str, Path] = {
|
|
1482
|
+
"train_file": train_output_path,
|
|
1483
|
+
"validation_file": validation_output_path,
|
|
1484
|
+
}
|
|
1485
|
+
if test_ratio > 0:
|
|
1486
|
+
extra_output_files["test_file"] = test_output_path
|
|
1410
1487
|
_write_meta(
|
|
1411
1488
|
output_path=base_output_path,
|
|
1412
1489
|
action="split",
|
|
@@ -1421,21 +1498,18 @@ def _run_split(
|
|
|
1421
1498
|
"group_column": group_column,
|
|
1422
1499
|
"time_column": time_column,
|
|
1423
1500
|
"time_order": "asc" if time_ascending else "desc",
|
|
1424
|
-
"
|
|
1425
|
-
"
|
|
1501
|
+
"inject_train_rows": stats.inject_train_rows,
|
|
1502
|
+
"inject_valid_rows": stats.inject_valid_rows,
|
|
1503
|
+
"inject_test_rows": stats.inject_test_rows,
|
|
1426
1504
|
},
|
|
1427
1505
|
split_stats=stats,
|
|
1428
|
-
extra_output_files=
|
|
1429
|
-
"train_file": train_output_path,
|
|
1430
|
-
"validation_file": validation_output_path,
|
|
1431
|
-
"test_file": test_output_path,
|
|
1432
|
-
},
|
|
1506
|
+
extra_output_files=extra_output_files,
|
|
1433
1507
|
)
|
|
1434
1508
|
_print_split_stats(
|
|
1435
1509
|
stats,
|
|
1436
1510
|
train_output_path=train_output_path,
|
|
1437
1511
|
validation_output_path=validation_output_path,
|
|
1438
|
-
test_output_path=test_output_path,
|
|
1512
|
+
test_output_path=test_output_path if test_ratio > 0 else None,
|
|
1439
1513
|
logger=logger,
|
|
1440
1514
|
)
|
|
1441
1515
|
close_logger()
|
|
@@ -1744,6 +1818,7 @@ def _run_convert_sft(
|
|
|
1744
1818
|
sft_system_prompt: str,
|
|
1745
1819
|
sft_system_column: str,
|
|
1746
1820
|
sft_user_query_as_instruction: bool = True,
|
|
1821
|
+
sft_pa_max_records_per_file: int = PA_MAX_RECORDS_PER_FILE,
|
|
1747
1822
|
) -> int:
|
|
1748
1823
|
input_path = Path(input_file)
|
|
1749
1824
|
output_path = _resolve_sft_output_path(input_path, output_arg, sft_format)
|
|
@@ -1764,6 +1839,7 @@ def _run_convert_sft(
|
|
|
1764
1839
|
dataframe,
|
|
1765
1840
|
target_column=target_column,
|
|
1766
1841
|
output_column=sft_output_column,
|
|
1842
|
+
instruction=sft_instruction,
|
|
1767
1843
|
system_prompt=sft_system_prompt,
|
|
1768
1844
|
system_column=sft_system_column,
|
|
1769
1845
|
)
|
|
@@ -1786,7 +1862,9 @@ def _run_convert_sft(
|
|
|
1786
1862
|
|
|
1787
1863
|
run_stage("写出结果", logger=logger)
|
|
1788
1864
|
if sft_format == PA_SFT_FORMAT:
|
|
1789
|
-
written_paths = write_pa_dataset(
|
|
1865
|
+
written_paths = write_pa_dataset(
|
|
1866
|
+
records, output_path, max_records_per_file=sft_pa_max_records_per_file
|
|
1867
|
+
)
|
|
1790
1868
|
else:
|
|
1791
1869
|
write_alpaca_dataset(records, output_path)
|
|
1792
1870
|
written_paths = [output_path]
|
|
@@ -1802,6 +1880,7 @@ def _run_convert_sft(
|
|
|
1802
1880
|
"sft_system_prompt": sft_system_prompt,
|
|
1803
1881
|
"sft_system_column": sft_system_column,
|
|
1804
1882
|
"sft_user_query_as_instruction": sft_user_query_as_instruction,
|
|
1883
|
+
"sft_pa_max_records_per_file": sft_pa_max_records_per_file,
|
|
1805
1884
|
},
|
|
1806
1885
|
sft_conversion_stats=stats,
|
|
1807
1886
|
extra_output_files={f"output_file_{i + 1}": p for i, p in enumerate(written_paths)}
|
|
@@ -2004,7 +2083,7 @@ def _print_split_stats(
|
|
|
2004
2083
|
*,
|
|
2005
2084
|
train_output_path: Path,
|
|
2006
2085
|
validation_output_path: Path,
|
|
2007
|
-
test_output_path: Path,
|
|
2086
|
+
test_output_path: Path | None,
|
|
2008
2087
|
logger: Logger,
|
|
2009
2088
|
) -> None:
|
|
2010
2089
|
_emit_message("数据切分完成", logger)
|
|
@@ -2016,17 +2095,22 @@ def _print_split_stats(
|
|
|
2016
2095
|
if stats.time_column:
|
|
2017
2096
|
_emit_message(f"时间列:{stats.time_column}", logger)
|
|
2018
2097
|
_emit_message(f"时间顺序:{'asc' if stats.time_ascending else 'desc'}", logger)
|
|
2019
|
-
if stats.
|
|
2020
|
-
_emit_message(f"
|
|
2021
|
-
|
|
2098
|
+
if stats.inject_train_rows > 0:
|
|
2099
|
+
_emit_message(f"注入训练集行数({TRAIN_SHEET_NAME!r} sheet):{stats.inject_train_rows}", logger)
|
|
2100
|
+
if stats.inject_valid_rows > 0:
|
|
2101
|
+
_emit_message(f"注入验证集行数({VALID_SHEET_NAME!r} sheet):{stats.inject_valid_rows}", logger)
|
|
2102
|
+
if stats.inject_test_rows > 0:
|
|
2103
|
+
_emit_message(f"注入测试集行数({TEST_SHEET_NAME!r} sheet):{stats.inject_test_rows}", logger)
|
|
2022
2104
|
_emit_message(f"随机种子:{stats.random_seed}", logger)
|
|
2023
2105
|
_emit_message(f"总行数:{stats.total_rows}", logger)
|
|
2024
2106
|
_emit_message(f"训练集行数:{stats.train_rows}", logger)
|
|
2025
2107
|
_emit_message(f"验证集行数:{stats.validation_rows}", logger)
|
|
2026
|
-
|
|
2108
|
+
if test_output_path is not None:
|
|
2109
|
+
_emit_message(f"测试集行数:{stats.test_rows}", logger)
|
|
2027
2110
|
_emit_message(f"训练集文件:{train_output_path}", logger)
|
|
2028
2111
|
_emit_message(f"验证集文件:{validation_output_path}", logger)
|
|
2029
|
-
|
|
2112
|
+
if test_output_path is not None:
|
|
2113
|
+
_emit_message(f"测试集文件:{test_output_path}", logger)
|
|
2030
2114
|
|
|
2031
2115
|
|
|
2032
2116
|
def _print_sft_conversion_stats(
|
|
@@ -2134,8 +2218,9 @@ def _write_meta(
|
|
|
2134
2218
|
"group_column": split_stats.group_column,
|
|
2135
2219
|
"time_column": split_stats.time_column,
|
|
2136
2220
|
"time_ascending": split_stats.time_ascending,
|
|
2137
|
-
"
|
|
2138
|
-
"
|
|
2221
|
+
"inject_train_rows": split_stats.inject_train_rows,
|
|
2222
|
+
"inject_valid_rows": split_stats.inject_valid_rows,
|
|
2223
|
+
"inject_test_rows": split_stats.inject_test_rows,
|
|
2139
2224
|
}
|
|
2140
2225
|
if sft_conversion_stats is not None:
|
|
2141
2226
|
meta["sft_conversion_stats"] = {
|
|
@@ -14,7 +14,10 @@ if TYPE_CHECKING:
|
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
SUPPORTED_EXTENSIONS = {".csv", ".xls", ".xlsx", ".xlsm"}
|
|
17
|
-
|
|
17
|
+
TRAIN_SHEET_NAME = "train"
|
|
18
|
+
VALID_SHEET_NAME = "valid"
|
|
19
|
+
TEST_SHEET_NAME = "test"
|
|
20
|
+
_INJECT_SHEET_NAMES = {TRAIN_SHEET_NAME, VALID_SHEET_NAME, TEST_SHEET_NAME}
|
|
18
21
|
DEFAULT_PROGRESS_COLOURS = [
|
|
19
22
|
"red",
|
|
20
23
|
"green",
|
|
@@ -44,28 +47,40 @@ def load_dataframe(file_path: str | Path) -> pd.DataFrame:
|
|
|
44
47
|
|
|
45
48
|
def load_split_dataframes(
|
|
46
49
|
file_path: str | Path,
|
|
47
|
-
|
|
48
|
-
increment_sheet_name: str = INCREMENT_SHEET_NAME,
|
|
49
|
-
) -> tuple[pd.DataFrame, pd.DataFrame]:
|
|
50
|
+
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
|
|
50
51
|
path = validate_tabular_file(file_path)
|
|
52
|
+
empty = pd.DataFrame()
|
|
51
53
|
if path.suffix.lower() == ".csv":
|
|
52
54
|
dataframe = pd.read_csv(path, skip_blank_lines=False)
|
|
53
|
-
return dataframe,
|
|
55
|
+
return dataframe, empty, empty, empty
|
|
54
56
|
|
|
55
57
|
regular_frames: list[pd.DataFrame] = []
|
|
56
|
-
|
|
58
|
+
train_frames: list[pd.DataFrame] = []
|
|
59
|
+
valid_frames: list[pd.DataFrame] = []
|
|
60
|
+
test_frames: list[pd.DataFrame] = []
|
|
57
61
|
for sheet_name, dataframe in _load_excel_sheets(path).items():
|
|
58
|
-
|
|
59
|
-
|
|
62
|
+
lower = sheet_name.lower()
|
|
63
|
+
if lower == TRAIN_SHEET_NAME:
|
|
64
|
+
train_frames.append(dataframe)
|
|
65
|
+
elif lower == VALID_SHEET_NAME:
|
|
66
|
+
valid_frames.append(dataframe)
|
|
67
|
+
elif lower == TEST_SHEET_NAME:
|
|
68
|
+
test_frames.append(dataframe)
|
|
60
69
|
else:
|
|
61
70
|
regular_frames.append(dataframe)
|
|
62
71
|
|
|
63
72
|
if not regular_frames:
|
|
73
|
+
inject_names = ", ".join(f"{n!r}" for n in sorted(_INJECT_SHEET_NAMES))
|
|
64
74
|
raise ValueError(
|
|
65
|
-
f"
|
|
75
|
+
f"除注入工作表({inject_names})外没有可切分的 Excel sheet。"
|
|
66
76
|
)
|
|
67
77
|
|
|
68
|
-
return
|
|
78
|
+
return (
|
|
79
|
+
_concat_excel_frames(regular_frames),
|
|
80
|
+
_concat_excel_frames(train_frames),
|
|
81
|
+
_concat_excel_frames(valid_frames),
|
|
82
|
+
_concat_excel_frames(test_frames),
|
|
83
|
+
)
|
|
69
84
|
|
|
70
85
|
|
|
71
86
|
def iter_dataframes(file_path: str | Path, chunksize: int = 50_000) -> Iterable[pd.DataFrame]:
|
|
@@ -114,6 +114,7 @@ def convert_dataframe_to_pa(
|
|
|
114
114
|
*,
|
|
115
115
|
target_column: str = "text",
|
|
116
116
|
output_column: str = "",
|
|
117
|
+
instruction: str = "",
|
|
117
118
|
system_prompt: str = "",
|
|
118
119
|
system_column: str = "",
|
|
119
120
|
) -> tuple[list[dict], SftConversionStats]:
|
|
@@ -121,6 +122,7 @@ def convert_dataframe_to_pa(
|
|
|
121
122
|
resolved_output_column = resolve_sft_output_column(dataframe, output_column)
|
|
122
123
|
resolved_system_column = _resolve_optional_column(dataframe, system_column)
|
|
123
124
|
|
|
125
|
+
fixed_instruction = instruction.strip()
|
|
124
126
|
final_system_prompt = system_prompt.strip()
|
|
125
127
|
records: list[dict] = []
|
|
126
128
|
skipped_blank_input_rows = 0
|
|
@@ -146,7 +148,8 @@ def convert_dataframe_to_pa(
|
|
|
146
148
|
if system_text:
|
|
147
149
|
conversations.append({"context": system_text, "role": "system"})
|
|
148
150
|
|
|
149
|
-
|
|
151
|
+
human_text = f"{fixed_instruction}\n{input_text}" if fixed_instruction else input_text
|
|
152
|
+
conversations.append({"context": human_text, "role": "human"})
|
|
150
153
|
conversations.append({"context": output_text, "role": "assistant"})
|
|
151
154
|
|
|
152
155
|
records.append({"conversations": conversations, "id": str(len(records) + 1)})
|
|
@@ -163,14 +166,21 @@ def convert_dataframe_to_pa(
|
|
|
163
166
|
return records, stats
|
|
164
167
|
|
|
165
168
|
|
|
166
|
-
def write_pa_dataset(
|
|
167
|
-
|
|
169
|
+
def write_pa_dataset(
|
|
170
|
+
records: list[dict],
|
|
171
|
+
output_path: str | Path,
|
|
172
|
+
*,
|
|
173
|
+
max_records_per_file: int = PA_MAX_RECORDS_PER_FILE,
|
|
174
|
+
) -> list[Path]:
|
|
175
|
+
"""Write PA-format records as one or more JSONL files.
|
|
168
176
|
|
|
169
|
-
|
|
177
|
+
When len(records) <= max_records_per_file a single file is written.
|
|
178
|
+
Otherwise the records are split into numbered chunks, e.g. *_1.jsonl,
|
|
179
|
+
*_2.jsonl, ... Returns the list of paths written.
|
|
170
180
|
"""
|
|
171
181
|
output_path = Path(output_path)
|
|
172
182
|
total = len(records)
|
|
173
|
-
if total <=
|
|
183
|
+
if total <= max_records_per_file:
|
|
174
184
|
_write_pa_jsonl(records, output_path)
|
|
175
185
|
return [output_path]
|
|
176
186
|
|
|
@@ -179,8 +189,8 @@ def write_pa_dataset(records: list[dict], output_path: str | Path) -> list[Path]
|
|
|
179
189
|
parent = output_path.parent
|
|
180
190
|
written: list[Path] = []
|
|
181
191
|
chunk_index = 1
|
|
182
|
-
for start in range(0, total,
|
|
183
|
-
chunk = records[start : start +
|
|
192
|
+
for start in range(0, total, max_records_per_file):
|
|
193
|
+
chunk = records[start : start + max_records_per_file]
|
|
184
194
|
chunk_path = parent / f"{stem}_{chunk_index}{suffix}"
|
|
185
195
|
_write_pa_jsonl(chunk, chunk_path)
|
|
186
196
|
written.append(chunk_path)
|
|
@@ -24,8 +24,9 @@ class SplitStats:
|
|
|
24
24
|
group_column: str | None = None
|
|
25
25
|
time_column: str | None = None
|
|
26
26
|
time_ascending: bool = True
|
|
27
|
-
|
|
28
|
-
|
|
27
|
+
inject_train_rows: int = 0
|
|
28
|
+
inject_valid_rows: int = 0
|
|
29
|
+
inject_test_rows: int = 0
|
|
29
30
|
|
|
30
31
|
|
|
31
32
|
def split_dataframe(
|
|
File without changes
|