dtflow 0.4.2__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dtflow/cli/sample.py ADDED
@@ -0,0 +1,294 @@
1
+ """
2
+ CLI 采样相关命令
3
+ """
4
+
5
+ from pathlib import Path
6
+ from typing import Any, Dict, List, Literal, Optional
7
+
8
+ import orjson
9
+
10
+ from ..storage.io import load_data, sample_file, save_data
11
+ from ..utils.field_path import get_field_with_spec
12
+ from .common import (
13
+ _check_file_format,
14
+ _get_file_row_count,
15
+ _parse_field_list,
16
+ _print_samples,
17
+ )
18
+
19
+
20
def sample(
    filename: str,
    num: int = 10,
    type: Literal["random", "head", "tail"] = "head",
    output: Optional[str] = None,
    seed: Optional[int] = None,
    by: Optional[str] = None,
    uniform: bool = False,
    fields: Optional[str] = None,
    raw: bool = False,
) -> None:
    """
    Sample records from a data file, then save, dump, or preview them.

    Every failure (missing file, load/sampling error, save error) is reported
    on stdout as a message starting with "错误:" and the function returns
    without raising.

    Args:
        filename: Input file path. Accepted formats are decided by
            `_check_file_format` (documented as csv/excel/jsonl/json/parquet/
            arrow/feather).
        num: Number of records to sample, default 10.
            - num > 0: sample that many records
            - num = 0: sample everything
            - num < 0: Python-slice style (e.g. -10 means the last 10 records);
              this is handled by `sample_file`. With `by` set, any
              non-positive value means "all records".
        type: Sampling mode, one of random/head/tail; default head.
        output: Output file path; when omitted the result goes to the console.
        seed: Random seed (only meaningful with type=random).
        by: Field to stratify on; records are grouped by its value and each
            group is sampled separately.
        uniform: Draw the same number of records from every group. Requires
            `by`.
        fields: Comma-separated fields to display. Only used by the preview
            output (ignored with `output` or `raw`).
        raw: Print each record as full, indented JSON instead of the
            preview rendering.

    Examples:
        dt sample data.jsonl 5
        dt sample data.csv 100 --type=head
        dt sample data.xlsx 50 --output=sampled.jsonl
        dt sample data.jsonl 0    # sample everything
        dt sample data.jsonl -10  # last 10 records
        dt sample data.jsonl 1000 --by=category           # proportional stratified sampling
        dt sample data.jsonl 1000 --by=category --uniform # uniform stratified sampling
        dt sample data.jsonl --fields=question,answer     # show selected fields only
    """
    filepath = Path(filename)

    if not filepath.exists():
        print(f"错误: 文件不存在 - {filename}")
        return

    # _check_file_format reports its own error when it rejects the file.
    if not _check_file_format(filepath):
        return

    # --uniform only makes sense together with --by.
    if uniform and not by:
        print("错误: --uniform 必须配合 --by 使用")
        return

    try:
        if by:
            # Stratified sampling.
            sampled = _stratified_sample(filepath, num, by, uniform, seed, type)
        else:
            # Plain sampling; output=None because saving is handled below.
            sampled = sample_file(
                str(filepath),
                num=num,
                sample_type=type,
                seed=seed,
                output=None,
            )
    except Exception as e:
        print(f"错误: {e}")
        return

    if output:
        # Report save failures the same way as load/sampling failures instead
        # of letting a traceback escape from the CLI.
        try:
            save_data(sampled, output)
        except Exception as e:
            print(f"错误: {e}")
            return
        print(f"已保存 {len(sampled)} 条数据到 {output}")
    elif raw:
        # Raw JSON output: one indented document per record.
        for item in sampled:
            print(orjson.dumps(item, option=orjson.OPT_INDENT_2).decode("utf-8"))
    else:
        # Preview output; the file's total row count is passed along for display.
        total_count = _get_file_row_count(filepath)
        field_list = _parse_field_list(fields) if fields else None
        _print_samples(sampled, filepath.name, total_count, field_list)
107
+
108
+
109
def _apportion(weights: Dict[Any, int], amount: int) -> Dict[Any, int]:
    """
    Split `amount` across the keys of `weights` in proportion to their weight.

    Uses the largest-remainder method with integer arithmetic only. As long as
    `amount <= sum(weights.values())`, the shares add up to exactly `amount`
    and no share exceeds its key's weight. Ties between equal remainders go to
    the key that appears first in `weights`.

    Args:
        weights: Non-negative integer weight per key.
        amount: Total number of units to hand out.

    Returns:
        A dict with the same keys as `weights` mapping to their share.
    """
    total_weight = sum(weights.values())
    if amount <= 0 or total_weight <= 0:
        return dict.fromkeys(weights, 0)

    shares: Dict[Any, int] = {}
    remainders: Dict[Any, int] = {}
    for key, weight in weights.items():
        shares[key], remainders[key] = divmod(amount * weight, total_weight)

    # Units lost to flooring go to the keys with the largest remainders.
    # A key with a non-zero remainder always has room for one more unit.
    leftover = amount - sum(shares.values())
    by_remainder = sorted(weights, key=lambda k: remainders[k], reverse=True)
    for key in by_remainder[:leftover]:
        shares[key] += 1
    return shares


def _stratified_sample(
    filepath: Path,
    num: int,
    stratify_field: str,
    uniform: bool,
    seed: Optional[int],
    sample_type: str,
) -> List[Dict]:
    """
    Sample records group by group, where groups are defined by a field value.

    Progress and a per-group summary are printed to stdout.

    Args:
        filepath: Path of the file to load with `load_data`.
        num: Target total number of records. Values <= 0 or larger than the
            dataset mean "all records".
        stratify_field: Field spec resolved per record by
            `get_field_with_spec`; records where it yields no value fall into
            the "__null__" group. The upstream documentation lists these
            forms (interpreted by that helper, not here):
            - meta.source          nested field
            - messages[0].role     array index
            - messages[-1].role    negative index
            - messages.#           array length
            - messages[*].role     expand all elements (optionally :join/:unique)
        uniform: If true, every group gets the same quota (the remainder goes
            to the largest groups). Otherwise quotas are proportional to group
            size, and every group gets at least one record whenever
            `num >= number of groups`.
        seed: Seed for random sampling. When given, a private generator is
            used so the process-wide `random` state is left untouched.
        sample_type: "random" or "head"; any other value is treated as tail.
            Applied within each group.

    Returns:
        The sampled records, concatenated group by group in order of each
        group's first appearance in the data.
    """
    import random
    from collections import defaultdict

    # A dedicated generator yields the same draws as random.seed(seed) would,
    # without reseeding the global RNG. Without a seed, keep using the global one.
    rng = random.Random(seed) if seed is not None else random

    data = load_data(str(filepath))
    total = len(data)

    if num <= 0 or num > total:
        num = total

    # Group records by the (possibly nested) field value.
    groups: Dict[Any, List[Dict]] = defaultdict(list)
    for item in data:
        key = get_field_with_spec(item, stratify_field, default="__null__")
        # Group keys must be hashable: lists become tuples, and anything still
        # unhashable (dicts, lists of dicts, ...) falls back to its repr.
        if isinstance(key, list):
            key = tuple(key)
        try:
            hash(key)
        except TypeError:
            key = repr(key)
        groups[key].append(item)

    num_groups = len(groups)
    # Group sizes, largest first (stable for ties); reused for every report below.
    sizes_desc: Dict[Any, int] = {
        key: len(groups[key])
        for key in sorted(groups, key=lambda k: len(groups[k]), reverse=True)
    }

    print(f"📊 分层采样: 字段={stratify_field}, 共 {num_groups} 组")
    for key, count in sizes_desc.items():
        pct = count / total * 100
        display_key = key if key != "__null__" else "[空值]"
        print(f" {display_key}: {count} 条 ({pct:.1f}%)")

    # Decide how many records to draw from each group.
    if uniform:
        # Same quota everywhere; guard against an empty dataset (no groups).
        per_group, remainder = divmod(num, num_groups) if num_groups else (0, 0)
        sample_counts = dict.fromkeys(sizes_desc, per_group)
        # The remainder goes to the largest groups.
        for key in list(sizes_desc)[:remainder]:
            sample_counts[key] += 1
    else:
        # Proportional quotas that sum to exactly `num` and never exceed a
        # group's size.
        sample_counts = _apportion(sizes_desc, num)
        if num >= num_groups:
            # Guarantee one record per group: groups that rounded down to zero
            # get one, paid for proportionally by groups holding more than one.
            starved = [key for key, count in sample_counts.items() if count == 0]
            if starved:
                surplus = {
                    key: count - 1
                    for key, count in sample_counts.items()
                    if count > 1
                }
                for key, taken in _apportion(surplus, len(starved)).items():
                    sample_counts[key] -= taken
                for key in starved:
                    sample_counts[key] = 1

    # Draw from each group, in order of first appearance.
    result: List[Dict] = []
    drawn: Dict[Any, int] = {}
    print("🔄 执行采样...")
    for key, group_data in groups.items():
        target = min(sample_counts[key], len(group_data))

        if target <= 0:
            continue

        if sample_type == "random":
            picked = rng.sample(group_data, target)
        elif sample_type == "head":
            picked = group_data[:target]
        else:  # tail
            picked = group_data[-target:]

        # Count by group key here rather than re-deriving the key from each
        # record afterwards, so nested/spec paths are reported correctly.
        drawn[key] = len(picked)
        result.extend(picked)

    print("\n📋 采样结果:")
    for key, orig in sizes_desc.items():
        sampled_count = drawn.get(key, 0)
        display_key = key if key != "__null__" else "[空值]"
        print(f" {display_key}: {orig} → {sampled_count}")

    print(f"\n✅ 总计: {total} → {len(result)} 条")

    return result
233
+
234
+
235
def head(
    filename: str,
    num: int = 10,
    output: Optional[str] = None,
    fields: Optional[str] = None,
    raw: bool = False,
) -> None:
    """
    Show the first N records of a file (shortcut for dt sample --type=head).

    All arguments are forwarded unchanged to `sample` with type="head".

    Args:
        filename: Input file path (documented as supporting csv/excel/jsonl/
            json/parquet/arrow/feather).
        num: Number of records to show, default 10.
            - num > 0: show that many records
            - num = 0: show everything
            - num < 0: Python-slice style (e.g. -10 means the last 10 records)
        output: Output file path; when omitted the result goes to the console.
        fields: Comma-separated fields to display; preview mode only.
        raw: Print full, untruncated JSON for each record.

    Examples:
        dt head data.jsonl              # first 10 records
        dt head data.jsonl 20           # first 20 records
        dt head data.csv 0              # all records
        dt head data.xlsx --output=head.jsonl
        dt head data.jsonl --fields=question,answer
        dt head data.jsonl 1 --raw      # full JSON output
    """
    sample(
        filename,
        num=num,
        type="head",
        output=output,
        fields=fields,
        raw=raw,
    )
264
+
265
+
266
def tail(
    filename: str,
    num: int = 10,
    output: Optional[str] = None,
    fields: Optional[str] = None,
    raw: bool = False,
) -> None:
    """
    Show the last N records of a file (shortcut for dt sample --type=tail).

    All arguments are forwarded unchanged to `sample` with type="tail".

    Args:
        filename: Input file path (documented as supporting csv/excel/jsonl/
            json/parquet/arrow/feather).
        num: Number of records to show, default 10.
            - num > 0: show that many records
            - num = 0: show everything
            - num < 0: Python-slice style (e.g. -10 means the last 10 records)
        output: Output file path; when omitted the result goes to the console.
        fields: Comma-separated fields to display; preview mode only.
        raw: Print full, untruncated JSON for each record.

    Examples:
        dt tail data.jsonl              # last 10 records
        dt tail data.jsonl 20           # last 20 records
        dt tail data.csv 0              # all records
        dt tail data.xlsx --output=tail.jsonl
        dt tail data.jsonl --fields=question,answer
        dt tail data.jsonl 1 --raw      # full JSON output
    """
    sample(
        filename,
        num=num,
        type="tail",
        output=output,
        fields=fields,
        raw=raw,
    )