dtflow 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dtflow/__init__.py +70 -43
- dtflow/__main__.py +301 -239
- dtflow/cli/__init__.py +29 -2
- dtflow/cli/commands.py +1112 -113
- dtflow/converters.py +39 -23
- dtflow/core.py +140 -72
- dtflow/lineage.py +410 -0
- dtflow/mcp/__init__.py +1 -0
- dtflow/mcp/__main__.py +2 -0
- dtflow/mcp/cli.py +35 -17
- dtflow/mcp/docs.py +0 -5
- dtflow/pipeline.py +460 -0
- dtflow/presets.py +24 -22
- dtflow/storage/__init__.py +11 -10
- dtflow/storage/io.py +384 -369
- dtflow/streaming.py +656 -0
- dtflow/tokenizers.py +212 -57
- dtflow/utils/__init__.py +2 -1
- dtflow/utils/display.py +28 -27
- {dtflow-0.3.0.dist-info → dtflow-0.3.2.dist-info}/METADATA +153 -7
- dtflow-0.3.2.dist-info/RECORD +24 -0
- dtflow-0.3.0.dist-info/RECORD +0 -21
- {dtflow-0.3.0.dist-info → dtflow-0.3.2.dist-info}/WHEEL +0 -0
- {dtflow-0.3.0.dist-info → dtflow-0.3.2.dist-info}/entry_points.txt +0 -0
dtflow/converters.py
CHANGED
|
@@ -3,7 +3,8 @@
|
|
|
3
3
|
|
|
4
4
|
提供与 HuggingFace datasets 等常用格式的互转功能。
|
|
5
5
|
"""
|
|
6
|
-
|
|
6
|
+
|
|
7
|
+
from typing import Any, Callable, Dict, List, Optional, Union
|
|
7
8
|
|
|
8
9
|
|
|
9
10
|
def to_hf_dataset(data: List[Dict[str, Any]]):
|
|
@@ -44,7 +45,7 @@ def from_hf_dataset(dataset, split: Optional[str] = None) -> List[Dict[str, Any]
|
|
|
44
45
|
>>> data = from_hf_dataset(my_dataset, split="train")
|
|
45
46
|
"""
|
|
46
47
|
try:
|
|
47
|
-
from datasets import
|
|
48
|
+
from datasets import Dataset, DatasetDict, load_dataset
|
|
48
49
|
except ImportError:
|
|
49
50
|
raise ImportError("需要安装 datasets: pip install datasets")
|
|
50
51
|
|
|
@@ -53,7 +54,7 @@ def from_hf_dataset(dataset, split: Optional[str] = None) -> List[Dict[str, Any]
|
|
|
53
54
|
dataset = load_dataset(dataset, split=split)
|
|
54
55
|
|
|
55
56
|
# 处理 DatasetDict
|
|
56
|
-
if hasattr(dataset,
|
|
57
|
+
if hasattr(dataset, "keys"): # DatasetDict
|
|
57
58
|
if split:
|
|
58
59
|
dataset = dataset[split]
|
|
59
60
|
else:
|
|
@@ -83,8 +84,9 @@ def to_hf_chat_format(
|
|
|
83
84
|
Examples:
|
|
84
85
|
>>> dt.transform(to_hf_chat_format())
|
|
85
86
|
"""
|
|
87
|
+
|
|
86
88
|
def transform(item) -> dict:
|
|
87
|
-
messages = item.get(messages_field, []) if hasattr(item,
|
|
89
|
+
messages = item.get(messages_field, []) if hasattr(item, "get") else item[messages_field]
|
|
88
90
|
result = {"messages": messages}
|
|
89
91
|
if add_generation_prompt:
|
|
90
92
|
result["add_generation_prompt"] = True
|
|
@@ -110,12 +112,14 @@ def from_openai_batch(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
|
|
110
112
|
for item in data:
|
|
111
113
|
if item.get("response", {}).get("status_code") == 200:
|
|
112
114
|
body = item["response"]["body"]
|
|
113
|
-
results.append(
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
115
|
+
results.append(
|
|
116
|
+
{
|
|
117
|
+
"custom_id": item.get("custom_id"),
|
|
118
|
+
"content": body["choices"][0]["message"]["content"],
|
|
119
|
+
"model": body.get("model"),
|
|
120
|
+
"usage": body.get("usage"),
|
|
121
|
+
}
|
|
122
|
+
)
|
|
119
123
|
return results
|
|
120
124
|
|
|
121
125
|
|
|
@@ -138,11 +142,12 @@ def to_openai_batch(
|
|
|
138
142
|
Examples:
|
|
139
143
|
>>> batch_input = dt.to(to_openai_batch(model="gpt-4o"))
|
|
140
144
|
"""
|
|
145
|
+
|
|
141
146
|
def transform(item, idx=[0]) -> dict:
|
|
142
|
-
messages = item.get(messages_field, []) if hasattr(item,
|
|
147
|
+
messages = item.get(messages_field, []) if hasattr(item, "get") else item[messages_field]
|
|
143
148
|
|
|
144
149
|
if custom_id_field:
|
|
145
|
-
custom_id = item.get(custom_id_field) if hasattr(item,
|
|
150
|
+
custom_id = item.get(custom_id_field) if hasattr(item, "get") else item[custom_id_field]
|
|
146
151
|
else:
|
|
147
152
|
custom_id = f"request-{idx[0]}"
|
|
148
153
|
idx[0] += 1
|
|
@@ -154,7 +159,7 @@ def to_openai_batch(
|
|
|
154
159
|
"body": {
|
|
155
160
|
"model": model,
|
|
156
161
|
"messages": messages,
|
|
157
|
-
}
|
|
162
|
+
},
|
|
158
163
|
}
|
|
159
164
|
|
|
160
165
|
return transform
|
|
@@ -189,8 +194,9 @@ def to_llama_factory(
|
|
|
189
194
|
Returns:
|
|
190
195
|
转换函数
|
|
191
196
|
"""
|
|
197
|
+
|
|
192
198
|
def transform(item) -> dict:
|
|
193
|
-
get = lambda f: (item.get(f, "") if hasattr(item,
|
|
199
|
+
get = lambda f: (item.get(f, "") if hasattr(item, "get") else item.get(f, ""))
|
|
194
200
|
|
|
195
201
|
result = {
|
|
196
202
|
"instruction": get(instruction_field),
|
|
@@ -237,8 +243,13 @@ def to_axolotl(
|
|
|
237
243
|
Returns:
|
|
238
244
|
转换函数
|
|
239
245
|
"""
|
|
246
|
+
|
|
240
247
|
def transform(item) -> dict:
|
|
241
|
-
conversations =
|
|
248
|
+
conversations = (
|
|
249
|
+
item.get(conversations_field, [])
|
|
250
|
+
if hasattr(item, "get")
|
|
251
|
+
else item.get(conversations_field, [])
|
|
252
|
+
)
|
|
242
253
|
|
|
243
254
|
# 如果已经是正确格式,直接返回
|
|
244
255
|
if conversations and isinstance(conversations[0], dict):
|
|
@@ -246,11 +257,14 @@ def to_axolotl(
|
|
|
246
257
|
return {"conversations": conversations}
|
|
247
258
|
|
|
248
259
|
# 尝试从 messages 格式转换
|
|
249
|
-
messages = item.get("messages", []) if hasattr(item,
|
|
260
|
+
messages = item.get("messages", []) if hasattr(item, "get") else item.get("messages", [])
|
|
250
261
|
if messages:
|
|
251
262
|
role_map = {"user": "human", "assistant": "gpt", "system": "system"}
|
|
252
263
|
conversations = [
|
|
253
|
-
{
|
|
264
|
+
{
|
|
265
|
+
from_key: role_map.get(m.get("role", ""), m.get("role", "")),
|
|
266
|
+
value_key: m.get("content", ""),
|
|
267
|
+
}
|
|
254
268
|
for m in messages
|
|
255
269
|
]
|
|
256
270
|
|
|
@@ -541,10 +555,12 @@ def to_swift_messages(
|
|
|
541
555
|
|
|
542
556
|
for msg in messages:
|
|
543
557
|
# 标准化格式
|
|
544
|
-
result_messages.append(
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
558
|
+
result_messages.append(
|
|
559
|
+
{
|
|
560
|
+
"role": msg.get("role", "user"),
|
|
561
|
+
"content": msg.get("content", ""),
|
|
562
|
+
}
|
|
563
|
+
)
|
|
548
564
|
|
|
549
565
|
return {"messages": result_messages}
|
|
550
566
|
|
|
@@ -749,8 +765,8 @@ def messages_to_text(
|
|
|
749
765
|
fmt = templates[template]
|
|
750
766
|
|
|
751
767
|
def transform(item) -> dict:
|
|
752
|
-
result = item.to_dict() if hasattr(item,
|
|
753
|
-
messages = item.get(messages_field, []) if hasattr(item,
|
|
768
|
+
result = item.to_dict() if hasattr(item, "to_dict") else dict(item)
|
|
769
|
+
messages = item.get(messages_field, []) if hasattr(item, "get") else item[messages_field]
|
|
754
770
|
|
|
755
771
|
parts = []
|
|
756
772
|
for msg in messages:
|
dtflow/core.py
CHANGED
|
@@ -3,42 +3,32 @@ DataTransformer 核心模块
|
|
|
3
3
|
|
|
4
4
|
专注于数据格式转换,提供简洁的 API。
|
|
5
5
|
"""
|
|
6
|
-
|
|
6
|
+
|
|
7
7
|
from copy import deepcopy
|
|
8
8
|
from dataclasses import dataclass
|
|
9
|
-
import
|
|
9
|
+
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
|
|
10
10
|
|
|
11
|
-
|
|
11
|
+
import orjson
|
|
12
12
|
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
import orjson
|
|
16
|
-
_HAS_ORJSON = True
|
|
17
|
-
except ImportError:
|
|
18
|
-
_HAS_ORJSON = False
|
|
13
|
+
from .lineage import LineageTracker
|
|
14
|
+
from .storage.io import load_data, save_data
|
|
19
15
|
|
|
20
16
|
|
|
21
17
|
def _fast_json_dumps(obj: Any) -> str:
|
|
22
|
-
"""
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
orjson 比标准 json 快约 10 倍,特别适合大量数据的序列化场景。
|
|
26
|
-
"""
|
|
27
|
-
if _HAS_ORJSON:
|
|
28
|
-
# orjson.dumps 返回 bytes,需要 decode
|
|
29
|
-
return orjson.dumps(obj, option=orjson.OPT_SORT_KEYS).decode('utf-8')
|
|
30
|
-
else:
|
|
31
|
-
return json.dumps(obj, sort_keys=True, ensure_ascii=False)
|
|
18
|
+
"""快速 JSON 序列化(使用 orjson,比标准 json 快约 10 倍)"""
|
|
19
|
+
return orjson.dumps(obj, option=orjson.OPT_SORT_KEYS).decode("utf-8")
|
|
32
20
|
|
|
33
21
|
|
|
34
22
|
# ============ 错误处理 ============
|
|
35
23
|
|
|
24
|
+
|
|
36
25
|
@dataclass
|
|
37
26
|
class TransformError:
|
|
38
27
|
"""转换错误信息"""
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
28
|
+
|
|
29
|
+
index: int # 原始数据索引
|
|
30
|
+
item: Dict # 原始数据项
|
|
31
|
+
error: Exception # 异常对象
|
|
42
32
|
|
|
43
33
|
def __repr__(self) -> str:
|
|
44
34
|
return f"TransformError(index={self.index}, error={self.error!r})"
|
|
@@ -61,9 +51,11 @@ class TransformErrors(Exception):
|
|
|
61
51
|
def _build_message(self) -> str:
|
|
62
52
|
if len(self.errors) == 1:
|
|
63
53
|
return str(self.errors[0])
|
|
64
|
-
return
|
|
65
|
-
f"
|
|
66
|
-
|
|
54
|
+
return (
|
|
55
|
+
f"转换失败 {len(self.errors)} 条记录:\n"
|
|
56
|
+
+ "\n".join(f" [{e.index}] {e.error}" for e in self.errors[:5])
|
|
57
|
+
+ (f"\n ... 还有 {len(self.errors) - 5} 条错误" if len(self.errors) > 5 else "")
|
|
58
|
+
)
|
|
67
59
|
|
|
68
60
|
def __iter__(self):
|
|
69
61
|
return iter(self.errors)
|
|
@@ -102,8 +94,15 @@ class DataTransformer:
|
|
|
102
94
|
- fields/stats: 数据信息
|
|
103
95
|
"""
|
|
104
96
|
|
|
105
|
-
def __init__(
|
|
97
|
+
def __init__(
|
|
98
|
+
self,
|
|
99
|
+
data: Optional[List[Dict[str, Any]]] = None,
|
|
100
|
+
_source_path: Optional[str] = None,
|
|
101
|
+
_lineage_tracker: Optional[LineageTracker] = None,
|
|
102
|
+
):
|
|
106
103
|
self._data = data if data is not None else []
|
|
104
|
+
self._source_path = _source_path
|
|
105
|
+
self._lineage_tracker = _lineage_tracker
|
|
107
106
|
|
|
108
107
|
@property
|
|
109
108
|
def data(self) -> List[Dict[str, Any]]:
|
|
@@ -122,23 +121,39 @@ class DataTransformer:
|
|
|
122
121
|
# ============ 加载/保存 ============
|
|
123
122
|
|
|
124
123
|
@classmethod
|
|
125
|
-
def load(cls, filepath: str) ->
|
|
124
|
+
def load(cls, filepath: str, track_lineage: bool = False) -> "DataTransformer":
|
|
126
125
|
"""
|
|
127
126
|
从文件加载数据。
|
|
128
127
|
|
|
129
128
|
支持格式: jsonl, json, csv, parquet(自动检测)
|
|
129
|
+
|
|
130
|
+
Args:
|
|
131
|
+
filepath: 文件路径
|
|
132
|
+
track_lineage: 是否追踪血缘(默认 False)
|
|
130
133
|
"""
|
|
131
134
|
data = load_data(filepath)
|
|
132
|
-
|
|
135
|
+
tracker = LineageTracker(filepath) if track_lineage else None
|
|
136
|
+
return cls(data, _source_path=filepath, _lineage_tracker=tracker)
|
|
133
137
|
|
|
134
|
-
def save(self, filepath: str) -> None:
|
|
138
|
+
def save(self, filepath: str, lineage: bool = False) -> None:
|
|
135
139
|
"""
|
|
136
140
|
保存数据到文件。
|
|
137
141
|
|
|
138
142
|
支持格式: jsonl, json, csv, parquet(根据扩展名)
|
|
143
|
+
|
|
144
|
+
Args:
|
|
145
|
+
filepath: 文件路径
|
|
146
|
+
lineage: 是否保存血缘元数据(默认 False)
|
|
139
147
|
"""
|
|
140
148
|
save_data(self._data, filepath)
|
|
141
149
|
|
|
150
|
+
# 保存血缘记录
|
|
151
|
+
if lineage and self._lineage_tracker:
|
|
152
|
+
lineage_path = self._lineage_tracker.save(filepath, len(self._data))
|
|
153
|
+
import sys
|
|
154
|
+
|
|
155
|
+
print(f"📜 血缘记录已保存: {lineage_path}", file=sys.stderr)
|
|
156
|
+
|
|
142
157
|
# ============ 核心转换 ============
|
|
143
158
|
|
|
144
159
|
def to(
|
|
@@ -215,7 +230,7 @@ class DataTransformer:
|
|
|
215
230
|
func: Callable[[Any], Any],
|
|
216
231
|
on_error: Literal["skip", "raise", "null"] = "skip",
|
|
217
232
|
raw: bool = False,
|
|
218
|
-
) ->
|
|
233
|
+
) -> "DataTransformer":
|
|
219
234
|
"""
|
|
220
235
|
转换数据并返回新的 DataTransformer(支持链式调用)。
|
|
221
236
|
|
|
@@ -230,7 +245,16 @@ class DataTransformer:
|
|
|
230
245
|
>>> # 原始模式(大数据集推荐)
|
|
231
246
|
>>> dt.transform(lambda x: {"q": x["q"]}, raw=True).save("output.jsonl")
|
|
232
247
|
"""
|
|
233
|
-
|
|
248
|
+
input_count = len(self._data)
|
|
249
|
+
result = self.to(func, on_error=on_error, raw=raw)
|
|
250
|
+
output_count = len(result)
|
|
251
|
+
|
|
252
|
+
# 传递血缘追踪器并记录操作
|
|
253
|
+
tracker = self._lineage_tracker
|
|
254
|
+
if tracker:
|
|
255
|
+
tracker.record("transform", {"func": func}, input_count, output_count)
|
|
256
|
+
|
|
257
|
+
return DataTransformer(result, _lineage_tracker=tracker)
|
|
234
258
|
|
|
235
259
|
# ============ 数据筛选 ============
|
|
236
260
|
|
|
@@ -239,7 +263,7 @@ class DataTransformer:
|
|
|
239
263
|
func: Callable[[Any], bool],
|
|
240
264
|
on_error: Literal["skip", "raise", "keep"] = "skip",
|
|
241
265
|
raw: bool = False,
|
|
242
|
-
) ->
|
|
266
|
+
) -> "DataTransformer":
|
|
243
267
|
"""
|
|
244
268
|
筛选数据。
|
|
245
269
|
|
|
@@ -281,9 +305,14 @@ class DataTransformer:
|
|
|
281
305
|
if errors:
|
|
282
306
|
_print_error_summary(errors, len(self._data))
|
|
283
307
|
|
|
284
|
-
|
|
308
|
+
# 传递血缘追踪器并记录操作
|
|
309
|
+
tracker = self._lineage_tracker
|
|
310
|
+
if tracker:
|
|
311
|
+
tracker.record("filter", {"func": func}, len(self._data), len(filtered))
|
|
285
312
|
|
|
286
|
-
|
|
313
|
+
return DataTransformer(filtered, _lineage_tracker=tracker)
|
|
314
|
+
|
|
315
|
+
def sample(self, n: int, seed: Optional[int] = None) -> "DataTransformer":
|
|
287
316
|
"""
|
|
288
317
|
随机采样 n 条数据。
|
|
289
318
|
|
|
@@ -292,24 +321,39 @@ class DataTransformer:
|
|
|
292
321
|
seed: 随机种子
|
|
293
322
|
"""
|
|
294
323
|
import random
|
|
324
|
+
|
|
295
325
|
if seed is not None:
|
|
296
326
|
random.seed(seed)
|
|
297
327
|
|
|
328
|
+
input_count = len(self._data)
|
|
298
329
|
data = self._data[:] if n >= len(self._data) else random.sample(self._data, n)
|
|
299
|
-
return DataTransformer(data)
|
|
300
330
|
|
|
301
|
-
|
|
331
|
+
tracker = self._lineage_tracker
|
|
332
|
+
if tracker:
|
|
333
|
+
tracker.record("sample", {"n": n, "seed": seed}, input_count, len(data))
|
|
334
|
+
|
|
335
|
+
return DataTransformer(data, _lineage_tracker=tracker)
|
|
336
|
+
|
|
337
|
+
def head(self, n: int = 10) -> "DataTransformer":
|
|
302
338
|
"""取前 n 条"""
|
|
303
|
-
|
|
339
|
+
data = self._data[:n]
|
|
340
|
+
tracker = self._lineage_tracker
|
|
341
|
+
if tracker:
|
|
342
|
+
tracker.record("head", {"n": n}, len(self._data), len(data))
|
|
343
|
+
return DataTransformer(data, _lineage_tracker=tracker)
|
|
304
344
|
|
|
305
|
-
def tail(self, n: int = 10) ->
|
|
345
|
+
def tail(self, n: int = 10) -> "DataTransformer":
|
|
306
346
|
"""取后 n 条"""
|
|
307
|
-
|
|
347
|
+
data = self._data[-n:]
|
|
348
|
+
tracker = self._lineage_tracker
|
|
349
|
+
if tracker:
|
|
350
|
+
tracker.record("tail", {"n": n}, len(self._data), len(data))
|
|
351
|
+
return DataTransformer(data, _lineage_tracker=tracker)
|
|
308
352
|
|
|
309
353
|
def dedupe(
|
|
310
354
|
self,
|
|
311
355
|
key: Union[None, str, List[str], Callable[[Any], Any]] = None,
|
|
312
|
-
) ->
|
|
356
|
+
) -> "DataTransformer":
|
|
313
357
|
"""
|
|
314
358
|
数据去重。
|
|
315
359
|
|
|
@@ -338,7 +382,11 @@ class DataTransformer:
|
|
|
338
382
|
seen.add(k)
|
|
339
383
|
result.append(item)
|
|
340
384
|
|
|
341
|
-
|
|
385
|
+
tracker = self._lineage_tracker
|
|
386
|
+
if tracker:
|
|
387
|
+
tracker.record("dedupe", {"key": key}, len(self._data), len(result))
|
|
388
|
+
|
|
389
|
+
return DataTransformer(result, _lineage_tracker=tracker)
|
|
342
390
|
|
|
343
391
|
def _get_dedupe_key(
|
|
344
392
|
self,
|
|
@@ -367,7 +415,7 @@ class DataTransformer:
|
|
|
367
415
|
threshold: float = 0.8,
|
|
368
416
|
num_perm: int = 128,
|
|
369
417
|
ngram: int = 3,
|
|
370
|
-
) ->
|
|
418
|
+
) -> "DataTransformer":
|
|
371
419
|
"""
|
|
372
420
|
基于 MinHash + LSH 的相似度去重。
|
|
373
421
|
|
|
@@ -388,9 +436,7 @@ class DataTransformer:
|
|
|
388
436
|
try:
|
|
389
437
|
from datasketch import MinHash, MinHashLSH
|
|
390
438
|
except ImportError:
|
|
391
|
-
raise ImportError(
|
|
392
|
-
"相似度去重需要 datasketch 库,请安装: pip install datasketch"
|
|
393
|
-
)
|
|
439
|
+
raise ImportError("相似度去重需要 datasketch 库,请安装: pip install datasketch")
|
|
394
440
|
|
|
395
441
|
if not self._data:
|
|
396
442
|
return DataTransformer([])
|
|
@@ -400,10 +446,11 @@ class DataTransformer:
|
|
|
400
446
|
# threshold=0.99 需要 num_perm>=512,threshold>=0.999 会需要极大的值(4096+)
|
|
401
447
|
if threshold >= 0.999:
|
|
402
448
|
import warnings
|
|
449
|
+
|
|
403
450
|
warnings.warn(
|
|
404
451
|
f"阈值 {threshold} 过高,已自动调整为 0.99。"
|
|
405
452
|
f"如需更高精度,建议使用 dedupe() 精确去重。",
|
|
406
|
-
UserWarning
|
|
453
|
+
UserWarning,
|
|
407
454
|
)
|
|
408
455
|
threshold = 0.99
|
|
409
456
|
|
|
@@ -442,7 +489,17 @@ class DataTransformer:
|
|
|
442
489
|
|
|
443
490
|
# 按原顺序保留数据
|
|
444
491
|
result = [self._data[i] for i in sorted(keep_indices)]
|
|
445
|
-
|
|
492
|
+
|
|
493
|
+
tracker = self._lineage_tracker
|
|
494
|
+
if tracker:
|
|
495
|
+
tracker.record(
|
|
496
|
+
"dedupe_similar",
|
|
497
|
+
{"key": key, "threshold": threshold, "num_perm": num_perm, "ngram": ngram},
|
|
498
|
+
len(self._data),
|
|
499
|
+
len(result),
|
|
500
|
+
)
|
|
501
|
+
|
|
502
|
+
return DataTransformer(result, _lineage_tracker=tracker)
|
|
446
503
|
|
|
447
504
|
def _get_text_for_similarity(
|
|
448
505
|
self,
|
|
@@ -457,14 +514,14 @@ class DataTransformer:
|
|
|
457
514
|
else:
|
|
458
515
|
raise ValueError(f"不支持的 key 类型: {type(key)}")
|
|
459
516
|
|
|
460
|
-
def _create_minhash(self, text: str, num_perm: int, ngram: int) ->
|
|
517
|
+
def _create_minhash(self, text: str, num_perm: int, ngram: int) -> "MinHash":
|
|
461
518
|
"""创建文本的 MinHash 签名"""
|
|
462
519
|
from datasketch import MinHash
|
|
463
520
|
|
|
464
521
|
m = MinHash(num_perm=num_perm)
|
|
465
522
|
# 使用字符级 n-gram(对中英文都适用)
|
|
466
523
|
for i in range(len(text) - ngram + 1):
|
|
467
|
-
m.update(text[i:i + ngram].encode(
|
|
524
|
+
m.update(text[i : i + ngram].encode("utf-8"))
|
|
468
525
|
return m
|
|
469
526
|
|
|
470
527
|
# ============ 数据信息 ============
|
|
@@ -485,7 +542,7 @@ class DataTransformer:
|
|
|
485
542
|
|
|
486
543
|
return sorted(all_fields)
|
|
487
544
|
|
|
488
|
-
def _extract_fields(self, obj: Any, prefix: str =
|
|
545
|
+
def _extract_fields(self, obj: Any, prefix: str = "") -> List[str]:
|
|
489
546
|
"""递归提取字段名"""
|
|
490
547
|
fields = []
|
|
491
548
|
if isinstance(obj, dict):
|
|
@@ -516,25 +573,21 @@ class DataTransformer:
|
|
|
516
573
|
field_stats[key] = {
|
|
517
574
|
"count": len(values),
|
|
518
575
|
"missing": len(self._data) - len(values),
|
|
519
|
-
"type": type(values[0]).__name__ if values else "unknown"
|
|
576
|
+
"type": type(values[0]).__name__ if values else "unknown",
|
|
520
577
|
}
|
|
521
578
|
|
|
522
|
-
return {
|
|
523
|
-
"total": len(self._data),
|
|
524
|
-
"fields": sorted(all_keys),
|
|
525
|
-
"field_stats": field_stats
|
|
526
|
-
}
|
|
579
|
+
return {"total": len(self._data), "fields": sorted(all_keys), "field_stats": field_stats}
|
|
527
580
|
|
|
528
581
|
# ============ 工具方法 ============
|
|
529
582
|
|
|
530
|
-
def copy(self) ->
|
|
583
|
+
def copy(self) -> "DataTransformer":
|
|
531
584
|
"""深拷贝"""
|
|
532
585
|
return DataTransformer(deepcopy(self._data))
|
|
533
586
|
|
|
534
587
|
# ============ 数据合并 ============
|
|
535
588
|
|
|
536
589
|
@classmethod
|
|
537
|
-
def concat(cls, *sources: Union[str,
|
|
590
|
+
def concat(cls, *sources: Union[str, "DataTransformer"]) -> "DataTransformer":
|
|
538
591
|
"""
|
|
539
592
|
拼接多个数据源。
|
|
540
593
|
|
|
@@ -564,7 +617,7 @@ class DataTransformer:
|
|
|
564
617
|
|
|
565
618
|
return cls(all_data)
|
|
566
619
|
|
|
567
|
-
def __add__(self, other: Union[str,
|
|
620
|
+
def __add__(self, other: Union[str, "DataTransformer"]) -> "DataTransformer":
|
|
568
621
|
"""
|
|
569
622
|
使用 + 运算符拼接数据。
|
|
570
623
|
|
|
@@ -574,14 +627,20 @@ class DataTransformer:
|
|
|
574
627
|
"""
|
|
575
628
|
return DataTransformer.concat(self, other)
|
|
576
629
|
|
|
577
|
-
def shuffle(self, seed: Optional[int] = None) ->
|
|
630
|
+
def shuffle(self, seed: Optional[int] = None) -> "DataTransformer":
|
|
578
631
|
"""打乱顺序(返回新实例)"""
|
|
579
632
|
import random
|
|
633
|
+
|
|
580
634
|
data = self._data[:]
|
|
581
635
|
if seed is not None:
|
|
582
636
|
random.seed(seed)
|
|
583
637
|
random.shuffle(data)
|
|
584
|
-
|
|
638
|
+
|
|
639
|
+
tracker = self._lineage_tracker
|
|
640
|
+
if tracker:
|
|
641
|
+
tracker.record("shuffle", {"seed": seed}, len(self._data), len(data))
|
|
642
|
+
|
|
643
|
+
return DataTransformer(data, _lineage_tracker=tracker)
|
|
585
644
|
|
|
586
645
|
def split(self, ratio: float = 0.8, seed: Optional[int] = None) -> tuple:
|
|
587
646
|
"""
|
|
@@ -596,7 +655,16 @@ class DataTransformer:
|
|
|
596
655
|
"""
|
|
597
656
|
data = self.shuffle(seed).data
|
|
598
657
|
split_idx = int(len(data) * ratio)
|
|
599
|
-
|
|
658
|
+
|
|
659
|
+
# 分割后血缘追踪器各自独立
|
|
660
|
+
tracker = self._lineage_tracker
|
|
661
|
+
if tracker:
|
|
662
|
+
tracker.record("split", {"ratio": ratio, "seed": seed}, len(self._data), len(data))
|
|
663
|
+
|
|
664
|
+
return (
|
|
665
|
+
DataTransformer(data[:split_idx], _lineage_tracker=tracker),
|
|
666
|
+
DataTransformer(data[split_idx:], _lineage_tracker=tracker),
|
|
667
|
+
)
|
|
600
668
|
|
|
601
669
|
# ============ 并行处理 ============
|
|
602
670
|
|
|
@@ -641,7 +709,7 @@ class DataTransformer:
|
|
|
641
709
|
func: Callable[[Dict], bool],
|
|
642
710
|
workers: Optional[int] = None,
|
|
643
711
|
chunksize: int = 1000,
|
|
644
|
-
) ->
|
|
712
|
+
) -> "DataTransformer":
|
|
645
713
|
"""
|
|
646
714
|
并行执行过滤函数(使用多进程)。
|
|
647
715
|
|
|
@@ -700,18 +768,18 @@ class DictWrapper:
|
|
|
700
768
|
"""
|
|
701
769
|
|
|
702
770
|
def __init__(self, data: Dict[str, Any]):
|
|
703
|
-
object.__setattr__(self,
|
|
771
|
+
object.__setattr__(self, "_data", data)
|
|
704
772
|
# 构建规范化名称到原始名称的映射
|
|
705
773
|
alias_map = {}
|
|
706
774
|
for key in data.keys():
|
|
707
775
|
sanitized = _sanitize_key(key)
|
|
708
776
|
if sanitized != key:
|
|
709
777
|
alias_map[sanitized] = key
|
|
710
|
-
object.__setattr__(self,
|
|
778
|
+
object.__setattr__(self, "_alias_map", alias_map)
|
|
711
779
|
|
|
712
780
|
def __getattr__(self, name: str) -> Any:
|
|
713
|
-
data = object.__getattribute__(self,
|
|
714
|
-
alias_map = object.__getattribute__(self,
|
|
781
|
+
data = object.__getattribute__(self, "_data")
|
|
782
|
+
alias_map = object.__getattribute__(self, "_alias_map")
|
|
715
783
|
|
|
716
784
|
# 先尝试直接匹配
|
|
717
785
|
if name in data:
|
|
@@ -730,23 +798,23 @@ class DictWrapper:
|
|
|
730
798
|
raise AttributeError(f"字段不存在: {name}")
|
|
731
799
|
|
|
732
800
|
def __getitem__(self, key: str) -> Any:
|
|
733
|
-
data = object.__getattribute__(self,
|
|
801
|
+
data = object.__getattribute__(self, "_data")
|
|
734
802
|
value = data[key]
|
|
735
803
|
if isinstance(value, dict):
|
|
736
804
|
return DictWrapper(value)
|
|
737
805
|
return value
|
|
738
806
|
|
|
739
807
|
def __contains__(self, key: str) -> bool:
|
|
740
|
-
data = object.__getattribute__(self,
|
|
808
|
+
data = object.__getattribute__(self, "_data")
|
|
741
809
|
return key in data
|
|
742
810
|
|
|
743
811
|
def __repr__(self) -> str:
|
|
744
|
-
data = object.__getattribute__(self,
|
|
812
|
+
data = object.__getattribute__(self, "_data")
|
|
745
813
|
return repr(data)
|
|
746
814
|
|
|
747
815
|
def get(self, key: str, default: Any = None) -> Any:
|
|
748
816
|
"""安全获取字段值"""
|
|
749
|
-
data = object.__getattribute__(self,
|
|
817
|
+
data = object.__getattribute__(self, "_data")
|
|
750
818
|
value = data.get(key, default)
|
|
751
819
|
if isinstance(value, dict):
|
|
752
820
|
return DictWrapper(value)
|
|
@@ -754,4 +822,4 @@ class DictWrapper:
|
|
|
754
822
|
|
|
755
823
|
def to_dict(self) -> Dict[str, Any]:
|
|
756
824
|
"""返回原始字典"""
|
|
757
|
-
return object.__getattribute__(self,
|
|
825
|
+
return object.__getattribute__(self, "_data")
|