dtflow 0.5.5__py3-none-any.whl → 0.5.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dtflow/SKILL.md ADDED
@@ -0,0 +1,225 @@
1
+ ---
2
+ name: dtflow
3
+ description: 数据文件处理(JSONL/CSV/Parquet)- 去重/采样/统计/过滤/转换/Schema验证/训练框架导出
4
+ ---
5
+
6
+ # dtflow - 机器学习训练数据格式转换工具
7
+
8
+ ## 设计理念
9
+
10
+ - **函数式优于类继承**:直接用 lambda/函数做转换,不需要 OOP 抽象
11
+ - **KISS 原则**:一个 `DataTransformer` 类搞定所有操作
12
+ - **链式 API**:`dt.filter(...).to(...).save(...)`
13
+
14
+ ## Python API
15
+
16
+ ```python
17
+ from dtflow import DataTransformer
18
+
19
+ # 加载数据(支持 JSONL/JSON/CSV/Parquet/Arrow,使用 Polars 引擎)
20
+ dt = DataTransformer.load("data.jsonl")
21
+
22
+ # 链式操作
23
+ (dt.filter(lambda x: x.score > 0.8)
24
+ .to(lambda x: {"q": x.question, "a": x.answer})
25
+ .dedupe("text")
26
+ .save("output.jsonl"))
27
+ ```
28
+
29
+ ### 数据过滤
30
+
31
+ ```python
32
+ dt.filter(lambda x: x.score > 0.8)
33
+ dt.filter(lambda x: x.language == "zh")
34
+ ```
35
+
36
+ ### 数据验证
37
+
38
+ ```python
39
+ # 简单验证
40
+ errors = dt.validate(lambda x: len(x.messages) >= 2)
41
+
42
+ # Schema 验证
43
+ from dtflow import Schema, Field, openai_chat_schema
44
+
45
+ result = dt.validate_schema(openai_chat_schema) # 预设 Schema
46
+ valid_dt = dt.validate_schema(schema, filter_invalid=True) # 过滤无效数据
47
+ ```
48
+
49
+ **预设 Schema**:`openai_chat_schema`、`alpaca_schema`、`sharegpt_schema`、`dpo_schema`
50
+
51
+ ### 数据转换
52
+
53
+ ```python
54
+ # 自定义转换
55
+ dt.to(lambda x: {"question": x.q, "answer": x.a})
56
+
57
+ # 使用预设模板
58
+ dt.to(preset="openai_chat", user_field="q", assistant_field="a")
59
+ ```
60
+
61
+ **预设模板**:`openai_chat`、`alpaca`、`sharegpt`、`dpo_pair`、`simple_qa`
62
+
63
+ ### Token 统计
64
+
65
+ ```python
66
+ from dtflow import count_tokens, token_counter, token_filter, token_stats
67
+
68
+ count = count_tokens("Hello world", model="gpt-4")
69
+ dt.transform(token_counter("text")).save("with_tokens.jsonl")
70
+ dt.filter(token_filter("text", max_tokens=2048))
71
+
72
+ # Messages Token 统计(多轮对话)
73
+ from dtflow import messages_token_counter, messages_token_filter
74
+ dt.transform(messages_token_counter(model="gpt-4", detailed=True))
75
+ dt.filter(messages_token_filter(min_turns=2, max_turns=10))
76
+ ```
77
+
78
+ ### 格式转换器
79
+
80
+ ```python
81
+ from dtflow import (
82
+ to_hf_dataset, from_hf_dataset, # HuggingFace Dataset
83
+ to_openai_batch, from_openai_batch, # OpenAI Batch API
84
+ to_llama_factory, to_llama_factory_sharegpt, # LLaMA-Factory
85
+ to_swift_messages, to_swift_query_response, # ms-swift
86
+ messages_to_text, # messages 转纯文本
87
+ )
88
+ ```
89
+
90
+ ### 训练框架导出
91
+
92
+ ```python
93
+ # 检查兼容性
94
+ result = dt.check_compatibility("llama-factory")
95
+
96
+ # 一键导出
97
+ files = dt.export_for("llama-factory", "./output/") # 生成 data.json + dataset_info.json + train_args.yaml
98
+ files = dt.export_for("swift", "./output/") # 生成 data.jsonl + train_swift.sh
99
+ files = dt.export_for("axolotl", "./output/") # 生成 data.jsonl + config.yaml
100
+ ```
101
+
102
+ ### 大文件流式处理
103
+
104
+ ```python
105
+ from dtflow import load_stream, load_sharded
106
+
107
+ # O(1) 内存,100GB 文件也能处理
108
+ (load_stream("huge.jsonl")
109
+ .filter(lambda x: x["score"] > 0.5)
110
+ .save("output.jsonl"))
111
+
112
+ # 分片文件加载
113
+ (load_sharded("data/train_*.parquet")
114
+ .filter(lambda x: len(x["text"]) > 10)
115
+ .save("merged.jsonl"))
116
+
117
+ # 分片保存
118
+ load_stream("huge.jsonl").save_sharded("output/", shard_size=100000)
119
+ ```
120
+
121
+ ### 其他操作
122
+
123
+ ```python
124
+ dt.sample(100) # 随机采样
125
+ dt.head(10) / dt.tail(10) # 取前/后 N 条
126
+ train, test = dt.split(ratio=0.8) # 分割
127
+ dt.shuffle(seed=42) # 打乱
128
+ dt.stats() # 统计
129
+ ```
130
+
131
+ ## CLI 命令
132
+
133
+ ```bash
134
+ # 统计(推荐首先使用)
135
+ dt stats data.jsonl # 基本统计(文件大小、条数、字段)
136
+ dt stats data.jsonl --full # 完整模式:值分布、唯一值、非空率
137
+ dt stats data.jsonl --full -n 20 # 显示 Top 20 值分布
138
+
139
+ # Token 统计
140
+ dt token-stats data.jsonl # 默认统计 messages 字段
141
+ dt token-stats data.jsonl -f text # 指定统计字段
142
+ dt token-stats data.jsonl -m qwen2.5 # 指定分词器 (cl100k_base/qwen2.5/llama3)
143
+ dt token-stats data.jsonl --detailed # 显示详细统计
144
+
145
+ # 采样(支持字段路径语法)
146
+ dt sample data.jsonl 100 # 随机采样 100 条
147
+ dt sample data.jsonl 100 -t head # 取前 100 条 (head/tail/random)
148
+ dt sample data.jsonl 1000 --by=category # 分层采样
149
+ dt sample data.jsonl 1000 --by=category --uniform # 均匀分层采样
150
+ dt sample data.jsonl --where="messages.#>=2" # 条件筛选
151
+ dt sample data.jsonl 10 -f input,output # 只显示指定字段
152
+ dt sample data.jsonl 10 --raw # 输出原始 JSON(不截断)
153
+ dt sample data.jsonl 100 --seed=42 -o out.jsonl # 固定随机种子并保存
154
+
155
+ # 去重
156
+ dt dedupe data.jsonl --key=text # 精确去重
157
+ dt dedupe data.jsonl --key=meta.id # 按嵌套字段去重
158
+ dt dedupe data.jsonl --key=text --similar=0.8 # 相似度去重
159
+ dt dedupe data.jsonl --key=text -o deduped.jsonl # 指定输出文件
160
+
161
+ # 清洗
162
+ dt clean data.jsonl --drop-empty=text,answer # 删除空值记录
163
+ dt clean data.jsonl --min-len=text:10 # 最小长度过滤
164
+ dt clean data.jsonl --max-len=text:2000 # 最大长度过滤
165
+ dt clean data.jsonl --min-len=messages.#:2 # 最少 2 条消息
166
+ dt clean data.jsonl --keep=question,answer # 只保留指定字段
167
+ dt clean data.jsonl --drop=metadata # 删除指定字段
168
+ dt clean data.jsonl --strip # 去除字符串首尾空白
169
+ dt clean data.jsonl --strip --drop-empty=input -o cleaned.jsonl # 组合使用
170
+
171
+ # 验证
172
+ dt validate data.jsonl --preset=openai_chat # 预设: openai_chat/alpaca/dpo/sharegpt
173
+ dt validate data.jsonl -p alpaca -f -o valid.jsonl # 过滤无效数据并保存
174
+ dt validate data.jsonl -p openai_chat -v # 显示详细信息
175
+ dt validate data.jsonl -p openai_chat --max-errors=50 # 最多显示 50 条错误
176
+
177
+ # 转换
178
+ dt transform data.jsonl --preset=openai_chat
179
+ dt transform data.jsonl # 交互式生成配置文件
180
+
181
+ # 合并与对比
182
+ dt concat a.jsonl b.jsonl -o merged.jsonl # 合并文件
183
+ dt concat a.jsonl b.jsonl -o merged.jsonl --strict # 严格模式(字段必须一致)
184
+ dt diff a.jsonl b.jsonl --key=id # 对比差异
185
+ dt diff a.jsonl b.jsonl --key=id -o report.md # 输出对比报告
186
+
187
+ # 查看数据
188
+ dt head data.jsonl 10 # 前 10 条
189
+ dt head data.jsonl 10 -f input,output # 只显示指定字段
190
+ dt head data.jsonl 10 --raw # 输出完整 JSON(不截断)
191
+ dt tail data.jsonl 10 # 后 10 条
192
+
193
+ # 其他
194
+ dt run pipeline.yaml # Pipeline 执行
195
+ dt history processed.jsonl # 数据血缘
196
+ dt install-skill # 安装 Claude Code skill
197
+ ```
198
+
199
+ ## 字段路径语法
200
+
201
+ | 语法 | 含义 | 示例 |
202
+ |------|------|------|
203
+ | `a.b.c` | 嵌套字段 | `meta.source` |
204
+ | `a[0].b` | 数组索引 | `messages[0].role` |
205
+ | `a[-1].b` | 负索引 | `messages[-1].content` |
206
+ | `a.#` | 数组长度 | `messages.#` |
207
+ | `a[*].b` | 展开所有元素 | `messages[*].role` |
208
+
209
+ ## Pipeline 配置
210
+
211
+ ```yaml
212
+ # pipeline.yaml
213
+ version: "1.0"
214
+ seed: 42
215
+ input: raw_data.jsonl
216
+ output: processed.jsonl
217
+
218
+ steps:
219
+ - type: filter
220
+ condition: "score > 0.5"
221
+ - type: transform
222
+ preset: openai_chat
223
+ - type: dedupe
224
+ key: text
225
+ ```
dtflow/__init__.py CHANGED
@@ -60,7 +60,7 @@ from .tokenizers import (
60
60
  token_stats,
61
61
  )
62
62
 
63
- __version__ = "0.5.5"
63
+ __version__ = "0.5.7"
64
64
 
65
65
  __all__ = [
66
66
  # core
dtflow/__main__.py CHANGED
@@ -6,21 +6,21 @@ Usage:
6
6
  dt --install-completion # 安装 shell 自动补全
7
7
 
8
8
  Commands:
9
- sample 从数据文件中采样
10
- head 显示文件的前 N 条数据
11
- tail 显示文件的后 N 条数据
12
- transform 转换数据格式(核心命令)
13
- stats 显示数据文件的统计信息
14
- token-stats Token 统计
15
- diff 数据集对比
16
- dedupe 数据去重
17
- concat 拼接多个数据文件
18
- clean 数据清洗
19
- run 执行 Pipeline 配置文件
20
- history 显示数据血缘历史
21
- validate 使用 Schema 验证数据格式
22
- mcp MCP 服务管理(install/uninstall/status)
23
- logs 日志查看工具使用说明
9
+ sample 从数据文件中采样
10
+ head 显示文件的前 N 条数据
11
+ tail 显示文件的后 N 条数据
12
+ transform 转换数据格式(核心命令)
13
+ stats 显示数据文件的统计信息
14
+ token-stats Token 统计
15
+ diff 数据集对比
16
+ dedupe 数据去重
17
+ concat 拼接多个数据文件
18
+ clean 数据清洗
19
+ run 执行 Pipeline 配置文件
20
+ history 显示数据血缘历史
21
+ validate 使用 Schema 验证数据格式
22
+ logs 日志查看工具使用说明
23
+ install-skill 安装 dtflow skill 到 Claude Code
24
24
  """
25
25
 
26
26
  import os
@@ -35,12 +35,15 @@ from .cli.commands import dedupe as _dedupe
35
35
  from .cli.commands import diff as _diff
36
36
  from .cli.commands import head as _head
37
37
  from .cli.commands import history as _history
38
+ from .cli.commands import install_skill as _install_skill
38
39
  from .cli.commands import run as _run
39
40
  from .cli.commands import sample as _sample
41
+ from .cli.commands import skill_status as _skill_status
40
42
  from .cli.commands import stats as _stats
41
43
  from .cli.commands import tail as _tail
42
44
  from .cli.commands import token_stats as _token_stats
43
45
  from .cli.commands import transform as _transform
46
+ from .cli.commands import uninstall_skill as _uninstall_skill
44
47
  from .cli.commands import validate as _validate
45
48
 
46
49
  # 创建主应用
@@ -67,10 +70,11 @@ def sample(
67
70
  uniform: bool = typer.Option(False, "--uniform", help="均匀采样模式"),
68
71
  fields: Optional[str] = typer.Option(None, "--fields", "-f", help="只显示指定字段(逗号分隔)"),
69
72
  raw: bool = typer.Option(False, "--raw", "-r", help="输出原始 JSON(不截断)"),
73
+ where: Optional[List[str]] = typer.Option(None, "--where", "-w", help="筛选条件 (可多次使用)"),
70
74
  ):
71
75
  """从数据文件中采样指定数量的数据"""
72
76
  actual_num = num_arg if num_arg is not None else num
73
- _sample(filename, actual_num, type, output, seed, by, uniform, fields, raw)
77
+ _sample(filename, actual_num, type, output, seed, by, uniform, fields, raw, where)
74
78
 
75
79
 
76
80
  @app.command()
@@ -262,48 +266,25 @@ dtflow 内置了 toolong 日志查看器,安装后可直接使用 tl 命令:
262
266
  print(help_text)
263
267
 
264
268
 
265
- # ============ MCP 子命令 ============
269
+ # ============ Skill 命令 ============
266
270
 
267
- mcp_app = typer.Typer(help="MCP 服务管理")
268
- app.add_typer(mcp_app, name="mcp")
269
271
 
272
+ @app.command("install-skill")
273
+ def install_skill():
274
+ """安装 dtflow skill 到 Claude Code"""
275
+ _install_skill()
270
276
 
271
- @mcp_app.command()
272
- def install(
273
- name: str = typer.Option("datatron", "--name", "-n", help="MCP 服务名称"),
274
- target: str = typer.Option("code", "--target", "-t", help="安装目标: desktop/code/all"),
275
- ):
276
- """安装 Datatron MCP 服务"""
277
- from .mcp.cli import MCPCommands
278
-
279
- MCPCommands().install(name, target)
280
-
281
-
282
- @mcp_app.command()
283
- def uninstall(
284
- name: str = typer.Option("datatron", "--name", "-n", help="MCP 服务名称"),
285
- target: str = typer.Option("all", "--target", "-t", help="移除目标: desktop/code/all"),
286
- ):
287
- """移除 Datatron MCP 服务"""
288
- from .mcp.cli import MCPCommands
289
-
290
- MCPCommands().uninstall(name, target)
291
-
292
-
293
- @mcp_app.command()
294
- def status():
295
- """查看 MCP 服务安装状态"""
296
- from .mcp.cli import MCPCommands
297
-
298
- MCPCommands().status()
299
277
 
278
+ @app.command("uninstall-skill")
279
+ def uninstall_skill():
280
+ """卸载 dtflow skill"""
281
+ _uninstall_skill()
300
282
 
301
- @mcp_app.command()
302
- def test():
303
- """测试 MCP 服务是否正常"""
304
- from .mcp.cli import MCPCommands
305
283
 
306
- MCPCommands().test()
284
+ @app.command("skill-status")
285
+ def skill_status():
286
+ """查看 skill 安装状态"""
287
+ _skill_status()
307
288
 
308
289
 
309
290
  def _show_completion_hint():
dtflow/cli/commands.py CHANGED
@@ -13,25 +13,27 @@ CLI 命令统一导出入口
13
13
  """
14
14
 
15
15
  # 采样命令
16
- from .sample import head, sample, tail
17
-
18
- # 转换命令
19
- from .transform import transform
20
-
21
- # 统计命令
22
- from .stats import stats, token_stats
23
-
24
16
  # 清洗命令
25
17
  from .clean import clean, dedupe
26
18
 
27
19
  # IO 操作命令
28
20
  from .io_ops import concat, diff
29
21
 
22
+ # 血缘追踪命令
23
+ from .lineage import history
24
+
30
25
  # Pipeline 命令
31
26
  from .pipeline import run
27
+ from .sample import head, sample, tail
32
28
 
33
- # 血缘追踪命令
34
- from .lineage import history
29
+ # Skill 命令
30
+ from .skill import install_skill, skill_status, uninstall_skill
31
+
32
+ # 统计命令
33
+ from .stats import stats, token_stats
34
+
35
+ # 转换命令
36
+ from .transform import transform
35
37
 
36
38
  # 验证命令
37
39
  from .validate import validate
@@ -58,4 +60,8 @@ __all__ = [
58
60
  "history",
59
61
  # 验证
60
62
  "validate",
63
+ # Skill
64
+ "install_skill",
65
+ "uninstall_skill",
66
+ "skill_status",
61
67
  ]
dtflow/cli/sample.py CHANGED
@@ -2,8 +2,9 @@
2
2
  CLI 采样相关命令
3
3
  """
4
4
 
5
+ import re
5
6
  from pathlib import Path
6
- from typing import Any, Dict, List, Literal, Optional
7
+ from typing import Any, Callable, Dict, List, Literal, Optional
7
8
 
8
9
  import orjson
9
10
 
@@ -16,6 +17,122 @@ from .common import (
16
17
  _print_samples,
17
18
  )
18
19
 
20
+ # where 条件解析正则:field op value
21
+ _WHERE_PATTERN = re.compile(r"^(.+?)(!=|~=|>=|<=|>|<|=)(.*)$")
22
+
23
+
24
+ def _parse_where(condition: str) -> Callable[[dict], bool]:
25
+ """
26
+ 解析 where 条件字符串,返回筛选函数。
27
+
28
+ 支持的操作符:
29
+ = 等于
30
+ != 不等于
31
+ ~= 包含(字符串)
32
+ > 大于
33
+ >= 大于等于
34
+ < 小于
35
+ <= 小于等于
36
+
37
+ Examples:
38
+ _parse_where("category=tech")
39
+ _parse_where("meta.source!=wiki")
40
+ _parse_where("content~=机器学习")
41
+ _parse_where("messages.#>=2")
42
+ """
43
+ match = _WHERE_PATTERN.match(condition)
44
+ if not match:
45
+ raise ValueError(f"无效的 where 条件: {condition}")
46
+
47
+ field, op, value = match.groups()
48
+
49
+ # 尝试转换 value 为数值
50
+ def parse_value(v: str) -> Any:
51
+ if v.lower() == "true":
52
+ return True
53
+ if v.lower() == "false":
54
+ return False
55
+ try:
56
+ return int(v)
57
+ except ValueError:
58
+ try:
59
+ return float(v)
60
+ except ValueError:
61
+ return v
62
+
63
+ parsed_value = parse_value(value)
64
+
65
+ def filter_fn(item: dict) -> bool:
66
+ field_value = get_field_with_spec(item, field)
67
+
68
+ if op == "=":
69
+ # 字符串比较或数值比较
70
+ if field_value is None:
71
+ return value == "" or value.lower() == "none"
72
+ return str(field_value) == value or field_value == parsed_value
73
+ elif op == "!=":
74
+ if field_value is None:
75
+ return value != "" and value.lower() != "none"
76
+ return str(field_value) != value and field_value != parsed_value
77
+ elif op == "~=":
78
+ # 包含
79
+ if field_value is None:
80
+ return False
81
+ return value in str(field_value)
82
+ elif op in (">", ">=", "<", "<="):
83
+ # 数值比较
84
+ if field_value is None:
85
+ return False
86
+ try:
87
+ num_field = float(field_value)
88
+ num_value = float(value)
89
+ if op == ">":
90
+ return num_field > num_value
91
+ elif op == ">=":
92
+ return num_field >= num_value
93
+ elif op == "<":
94
+ return num_field < num_value
95
+ else: # <=
96
+ return num_field <= num_value
97
+ except (ValueError, TypeError):
98
+ return False
99
+ return False
100
+
101
+ return filter_fn
102
+
103
+
104
+ def _apply_where_filters(data: List[Dict], where_conditions: List[str]) -> List[Dict]:
105
+ """应用多个 where 条件(AND 关系)"""
106
+ if not where_conditions:
107
+ return data
108
+
109
+ filters = [_parse_where(cond) for cond in where_conditions]
110
+ return [item for item in data if all(f(item) for f in filters)]
111
+
112
+
113
+ def _sample_from_list(
114
+ data: List[Dict],
115
+ num: int,
116
+ sample_type: str,
117
+ seed: Optional[int] = None,
118
+ ) -> List[Dict]:
119
+ """从列表中采样"""
120
+ import random
121
+
122
+ if seed is not None:
123
+ random.seed(seed)
124
+
125
+ total = len(data)
126
+ if num <= 0 or num > total:
127
+ num = total
128
+
129
+ if sample_type == "random":
130
+ return random.sample(data, num)
131
+ elif sample_type == "head":
132
+ return data[:num]
133
+ else: # tail
134
+ return data[-num:]
135
+
19
136
 
20
137
  def sample(
21
138
  filename: str,
@@ -27,6 +144,7 @@ def sample(
27
144
  uniform: bool = False,
28
145
  fields: Optional[str] = None,
29
146
  raw: bool = False,
147
+ where: Optional[List[str]] = None,
30
148
  ) -> None:
31
149
  """
32
150
  从数据文件中采样指定数量的数据。
@@ -44,6 +162,7 @@ def sample(
44
162
  uniform: 均匀采样模式(需配合 --by 使用),各组采样相同数量
45
163
  fields: 只显示指定字段(逗号分隔),仅在预览模式下有效
46
164
  raw: 输出原始 JSON 格式(不截断,完整显示所有内容)
165
+ where: 筛选条件列表,支持 =, !=, ~=, >, >=, <, <= 操作符
47
166
 
48
167
  Examples:
49
168
  dt sample data.jsonl 5
@@ -54,6 +173,9 @@ def sample(
54
173
  dt sample data.jsonl 1000 --by=category # 按比例分层采样
55
174
  dt sample data.jsonl 1000 --by=category --uniform # 均匀分层采样
56
175
  dt sample data.jsonl --fields=question,answer # 只显示指定字段
176
+ dt sample data.jsonl --where="category=tech" # 筛选 category 为 tech 的数据
177
+ dt sample data.jsonl --where="meta.source~=wiki" # 筛选 meta.source 包含 wiki
178
+ dt sample data.jsonl --where="messages.#>=2" # 筛选消息数量 >= 2
57
179
  """
58
180
  filepath = Path(filename)
59
181
 
@@ -69,23 +191,46 @@ def sample(
69
191
  print("错误: --uniform 必须配合 --by 使用")
70
192
  return
71
193
 
194
+ # 处理 where 筛选
195
+ where_conditions = where or []
196
+ filtered_data = None
197
+ original_count = None
198
+
199
+ if where_conditions:
200
+ # 有 where 条件时,先加载全部数据再筛选
201
+ try:
202
+ all_data = load_data(str(filepath))
203
+ original_count = len(all_data)
204
+ filtered_data = _apply_where_filters(all_data, where_conditions)
205
+ print(f"🔍 筛选: {original_count} → {len(filtered_data)} 条")
206
+ if not filtered_data:
207
+ print("⚠️ 筛选后无数据")
208
+ return
209
+ except ValueError as e:
210
+ print(f"错误: {e}")
211
+ return
212
+
72
213
  # 分层采样模式
73
214
  if by:
74
215
  try:
75
- sampled = _stratified_sample(filepath, num, by, uniform, seed, type)
216
+ sampled = _stratified_sample(filepath, num, by, uniform, seed, type, data=filtered_data)
76
217
  except Exception as e:
77
218
  print(f"错误: {e}")
78
219
  return
79
220
  else:
80
221
  # 普通采样
81
222
  try:
82
- sampled = sample_file(
83
- str(filepath),
84
- num=num,
85
- sample_type=type,
86
- seed=seed,
87
- output=None, # 先不保存,统一在最后处理
88
- )
223
+ if filtered_data is not None:
224
+ # 已筛选的数据,直接采样
225
+ sampled = _sample_from_list(filtered_data, num, type, seed)
226
+ else:
227
+ sampled = sample_file(
228
+ str(filepath),
229
+ num=num,
230
+ sample_type=type,
231
+ seed=seed,
232
+ output=None, # 先不保存,统一在最后处理
233
+ )
89
234
  except Exception as e:
90
235
  print(f"错误: {e}")
91
236
  return
@@ -117,6 +262,7 @@ def _stratified_sample(
117
262
  uniform: bool,
118
263
  seed: Optional[int],
119
264
  sample_type: str,
265
+ data: Optional[List[Dict]] = None,
120
266
  ) -> List[Dict]:
121
267
  """
122
268
  分层采样实现。
@@ -133,6 +279,7 @@ def _stratified_sample(
133
279
  uniform: 是否均匀采样(各组相同数量)
134
280
  seed: 随机种子
135
281
  sample_type: 采样方式(用于组内采样)
282
+ data: 预筛选的数据(可选,如果提供则不从文件加载)
136
283
 
137
284
  Returns:
138
285
  采样后的数据列表
@@ -143,8 +290,9 @@ def _stratified_sample(
143
290
  if seed is not None:
144
291
  random.seed(seed)
145
292
 
146
- # 加载数据
147
- data = load_data(str(filepath))
293
+ # 加载数据(如果没有预筛选数据)
294
+ if data is None:
295
+ data = load_data(str(filepath))
148
296
  total = len(data)
149
297
 
150
298
  if num <= 0 or num > total: