dtflow 0.3.0-py3-none-any.whl → 0.3.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dtflow/storage/io.py CHANGED
@@ -1,15 +1,18 @@
 """
 Input/Output utilities for saving and loading data.
+
+Uses Polars as the primary I/O engine; roughly 3-5x faster than Pandas.
+Uses orjson as the JSON parsing engine; roughly 10x faster than the standard json module.
 """
-from typing import List, Dict, Any, Optional
-import json
-import os
+
 from pathlib import Path
+from typing import Any, Dict, List, Optional
 
+import orjson
+import polars as pl
 
-def save_data(data: List[Dict[str, Any]],
-              filepath: str,
-              file_format: Optional[str] = None) -> None:
+
+def save_data(data: List[Dict[str, Any]], filepath: str, file_format: Optional[str] = None) -> None:
     """
     Save data to file.
 
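A note on the engine swap above (context added here, not part of the diff): orjson works in bytes rather than str, which is why the rewritten helpers below open files in binary mode. A minimal sketch of the API difference:

import orjson

record = {"id": 1, "text": "héllo", "tags": ["a", "b"]}

# orjson.dumps returns UTF-8 bytes, not str, so files are opened with "wb"/"rb";
# non-ASCII text is encoded directly, matching the old ensure_ascii=False behavior
line = orjson.dumps(record)
assert isinstance(line, bytes)
assert orjson.loads(line) == record

# Pretty-printing uses an option flag instead of json.dumps(..., indent=2)
pretty = orjson.dumps(record, option=orjson.OPT_INDENT_2)
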
@@ -21,23 +24,22 @@ def save_data(data: List[Dict[str, Any]],
     filepath = Path(filepath)
     filepath.parent.mkdir(parents=True, exist_ok=True)
 
-    # Auto-detect format from extension
     if file_format is None:
         file_format = _detect_format(filepath)
 
-    if file_format == 'jsonl':
+    if file_format == "jsonl":
         _save_jsonl(data, filepath)
-    elif file_format == 'json':
+    elif file_format == "json":
         _save_json(data, filepath)
-    elif file_format == 'csv':
+    elif file_format == "csv":
         _save_csv(data, filepath)
-    elif file_format == 'parquet':
+    elif file_format == "parquet":
         _save_parquet(data, filepath)
-    elif file_format == 'arrow':
+    elif file_format == "arrow":
         _save_arrow(data, filepath)
-    elif file_format == 'excel':
+    elif file_format == "excel":
         _save_excel(data, filepath)
-    elif file_format == 'flaxkv':
+    elif file_format == "flaxkv":
         _save_flaxkv(data, filepath)
     else:
         raise ValueError(f"Unknown file format: {file_format}")
@@ -59,23 +61,22 @@ def load_data(filepath: str, file_format: Optional[str] = None) -> List[Dict[str
     if not filepath.exists():
         raise FileNotFoundError(f"File not found: {filepath}")
 
-    # Auto-detect format from extension
     if file_format is None:
         file_format = _detect_format(filepath)
 
-    if file_format == 'jsonl':
+    if file_format == "jsonl":
         return _load_jsonl(filepath)
-    elif file_format == 'json':
+    elif file_format == "json":
         return _load_json(filepath)
-    elif file_format == 'csv':
+    elif file_format == "csv":
         return _load_csv(filepath)
-    elif file_format == 'parquet':
+    elif file_format == "parquet":
         return _load_parquet(filepath)
-    elif file_format == 'arrow':
+    elif file_format == "arrow":
         return _load_arrow(filepath)
-    elif file_format == 'excel':
+    elif file_format == "excel":
         return _load_excel(filepath)
-    elif file_format == 'flaxkv':
+    elif file_format == "flaxkv":
         return _load_flaxkv(filepath)
     else:
         raise ValueError(f"Unknown file format: {file_format}")
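Downstream usage is unchanged by the engine swap: both save_data and load_data still dispatch on the file extension via _detect_format. A hedged usage sketch (file names are illustrative, and the import path assumes the module is used directly):

from dtflow.storage.io import save_data, load_data

rows = [{"id": i, "name": f"row-{i}"} for i in range(3)]

# Format is inferred from the extension
save_data(rows, "out/rows.jsonl")      # jsonl branch
save_data(rows, "out/rows.parquet")    # parquet branch

# An explicit file_format overrides the extension
save_data(rows, "out/rows.dat", file_format="csv")

assert load_data("out/rows.jsonl") == rows
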
@@ -84,200 +85,205 @@ def load_data(filepath: str, file_format: Optional[str] = None) -> List[Dict[str
 def _detect_format(filepath: Path) -> str:
     """Detect file format from extension."""
     ext = filepath.suffix.lower()
-    if ext == '.jsonl':
-        return 'jsonl'
-    elif ext == '.json':
-        return 'json'
-    elif ext == '.csv':
-        return 'csv'
-    elif ext == '.parquet':
-        return 'parquet'
-    elif ext in ('.arrow', '.feather'):
-        return 'arrow'
-    elif ext in ('.xlsx', '.xls'):
-        return 'excel'
-    elif ext == '.flaxkv' or ext == '':
-        # For FlaxKV, filepath is typically a directory
-        return 'flaxkv'
+    if ext == ".jsonl":
+        return "jsonl"
+    elif ext == ".json":
+        return "json"
+    elif ext == ".csv":
+        return "csv"
+    elif ext == ".parquet":
+        return "parquet"
+    elif ext in (".arrow", ".feather"):
+        return "arrow"
+    elif ext in (".xlsx", ".xls"):
+        return "excel"
+    elif ext == ".flaxkv" or ext == "":
+        return "flaxkv"
     else:
-        # Default to JSONL
-        return 'jsonl'
+        return "jsonl"
 
 
 # ============ JSONL Format ============
+# JSONL keeps a native-Python path because it must handle complex nested structures
+
 
 def _save_jsonl(data: List[Dict[str, Any]], filepath: Path) -> None:
     """Save data in JSONL format."""
-    with open(filepath, 'w', encoding='utf-8') as f:
+    with open(filepath, "wb") as f:
         for item in data:
-            json_line = json.dumps(item, ensure_ascii=False)
-            f.write(json_line + '\n')
+            f.write(orjson.dumps(item) + b"\n")
 
 
 def _load_jsonl(filepath: Path) -> List[Dict[str, Any]]:
     """Load data from JSONL format."""
     data = []
-    with open(filepath, 'r', encoding='utf-8') as f:
+    with open(filepath, "rb") as f:
         for line in f:
             line = line.strip()
             if line:
-                data.append(json.loads(line))
+                data.append(orjson.loads(line))
     return data
 
 
 # ============ JSON Format ============
 
+
 def _save_json(data: List[Dict[str, Any]], filepath: Path) -> None:
     """Save data in JSON format."""
-    with open(filepath, 'w', encoding='utf-8') as f:
-        json.dump(data, f, ensure_ascii=False, indent=2)
+    with open(filepath, "wb") as f:
+        f.write(orjson.dumps(data, option=orjson.OPT_INDENT_2))
 
 
 def _load_json(filepath: Path) -> List[Dict[str, Any]]:
     """Load data from JSON format."""
-    with open(filepath, 'r', encoding='utf-8') as f:
-        data = json.load(f)
+    with open(filepath, "rb") as f:
+        data = orjson.loads(f.read())
 
-    # Ensure data is a list
     if not isinstance(data, list):
         data = [data]
 
     return data
 
 
-# ============ CSV Format ============
+# ============ CSV Format (Polars) ============
+
 
 def _save_csv(data: List[Dict[str, Any]], filepath: Path) -> None:
-    """Save data in CSV format."""
-    try:
-        import pandas as pd
-    except ImportError:
-        raise ImportError("pandas is required for CSV support. Install with: pip install pandas")
+    """Save data in CSV format using Polars."""
+    if not data:
+        # Empty data: create an empty file
+        filepath.touch()
+        return
 
-    df = pd.DataFrame(data)
-    df.to_csv(filepath, index=False, encoding='utf-8')
+    # Serialize complex fields to JSON strings
+    serialized = _serialize_complex_fields(data)
+    df = pl.DataFrame(serialized)
+    df.write_csv(filepath)
 
 
 def _load_csv(filepath: Path) -> List[Dict[str, Any]]:
-    """Load data from CSV format."""
-    try:
-        import pandas as pd
-    except ImportError:
-        raise ImportError("pandas is required for CSV support. Install with: pip install pandas")
+    """Load data from CSV format using Polars."""
+    df = pl.read_csv(filepath)
+    data = df.to_dicts()
+    # Deserialize JSON strings
+    return _deserialize_complex_fields(data)
 
-    df = pd.read_csv(filepath, encoding='utf-8')
-    return df.to_dict('records')
 
+# ============ Parquet Format (Polars) ============
 
-# ============ Excel Format ============
 
-def _save_excel(data: List[Dict[str, Any]], filepath: Path) -> None:
-    """Save data in Excel format."""
-    try:
-        import pandas as pd
-    except ImportError:
-        raise ImportError("pandas and openpyxl are required for Excel support. Install with: pip install pandas openpyxl")
+def _save_parquet(data: List[Dict[str, Any]], filepath: Path) -> None:
+    """Save data in Parquet format using Polars."""
+    if not data:
+        # Empty data: write an empty parquet file
+        pl.DataFrame().write_parquet(filepath)
+        return
 
-    df = pd.DataFrame(data)
-    df.to_excel(filepath, index=False)
+    serialized = _serialize_complex_fields(data)
+    df = pl.DataFrame(serialized)
+    df.write_parquet(filepath)
 
 
-def _load_excel(filepath: Path) -> List[Dict[str, Any]]:
-    """Load data from Excel format."""
-    try:
-        import pandas as pd
-    except ImportError:
-        raise ImportError("pandas and openpyxl are required for Excel support. Install with: pip install pandas openpyxl")
+def _load_parquet(filepath: Path) -> List[Dict[str, Any]]:
+    """Load data from Parquet format using Polars."""
+    df = pl.read_parquet(filepath)
+    data = df.to_dicts()
+    return _deserialize_complex_fields(data)
 
-    df = pd.read_excel(filepath)
-    return df.to_dict('records')
 
+# ============ Arrow Format (Polars) ============
 
-# ============ Parquet Format ============
 
-def _save_parquet(data: List[Dict[str, Any]], filepath: Path) -> None:
-    """Save data in Parquet format."""
-    try:
-        import pandas as pd
-    except ImportError:
-        raise ImportError("pandas is required for Parquet support. Install with: pip install pandas pyarrow")
+def _save_arrow(data: List[Dict[str, Any]], filepath: Path) -> None:
+    """Save data in Arrow IPC format using Polars."""
+    if not data:
+        pl.DataFrame().write_ipc(filepath)
+        return
 
-    df = pd.DataFrame(data)
-    df.to_parquet(filepath, index=False, engine='pyarrow')
+    serialized = _serialize_complex_fields(data)
+    df = pl.DataFrame(serialized)
+    df.write_ipc(filepath)
 
 
-def _load_parquet(filepath: Path) -> List[Dict[str, Any]]:
-    """Load data from Parquet format."""
-    try:
-        import pandas as pd
-    except ImportError:
-        raise ImportError("pandas is required for Parquet support. Install with: pip install pandas pyarrow")
+def _load_arrow(filepath: Path) -> List[Dict[str, Any]]:
+    """Load data from Arrow IPC format using Polars."""
+    df = pl.read_ipc(filepath)
+    data = df.to_dicts()
+    return _deserialize_complex_fields(data)
 
-    df = pd.read_parquet(filepath, engine='pyarrow')
-    return df.to_dict('records')
 
+# ============ Excel Format ============
+# Excel needs an extra dependency, so it stays optional
 
-# ============ Arrow Format ============
 
-def _save_arrow(data: List[Dict[str, Any]], filepath: Path) -> None:
-    """Save data in Arrow IPC format (also known as Feather v2).
+def _save_excel(data: List[Dict[str, Any]], filepath: Path) -> None:
+    """Save data in Excel format."""
+    if not data:
+        # Empty data
+        try:
+            import xlsxwriter
 
-    Note: Complex nested structures (like list of dicts) are serialized as JSON strings.
-    """
-    try:
-        import pyarrow as pa
-        import pyarrow.feather as feather
-    except ImportError:
-        raise ImportError("pyarrow is required for Arrow support. Install with: pip install pyarrow")
+            workbook = xlsxwriter.Workbook(str(filepath))
+            workbook.close()
+        except ImportError:
+            raise ImportError(
+                "xlsxwriter is required for Excel write. Install with: pip install xlsxwriter"
+            )
+        return
+
+    serialized = _serialize_complex_fields(data)
+    df = pl.DataFrame(serialized)
+    df.write_excel(filepath)
+
+
+def _load_excel(filepath: Path) -> List[Dict[str, Any]]:
+    """Load data from Excel format."""
+    df = pl.read_excel(filepath)
+    data = df.to_dicts()
+    return _deserialize_complex_fields(data)
+
+
+# ============ Complex Field Serialization ============
 
-    # Serialize complex fields to JSON strings for Arrow compatibility
-    serialized_data = []
+
+def _serialize_complex_fields(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    """Serialize complex fields (list, dict) to JSON strings."""
+    result = []
     for item in data:
         new_item = {}
         for k, v in item.items():
             if isinstance(v, (list, dict)):
-                new_item[k] = json.dumps(v, ensure_ascii=False)
+                new_item[k] = orjson.dumps(v).decode("utf-8")
             else:
                 new_item[k] = v
-        serialized_data.append(new_item)
-
-    table = pa.Table.from_pylist(serialized_data)
-
-    # Use Feather format (simpler and more portable)
-    feather.write_feather(table, filepath)
-
-
-def _load_arrow(filepath: Path) -> List[Dict[str, Any]]:
-    """Load data from Arrow IPC format (also known as Feather v2).
-
-    Note: JSON-serialized fields are automatically deserialized.
-    """
-    try:
-        import pyarrow.feather as feather
-    except ImportError:
-        raise ImportError("pyarrow is required for Arrow support. Install with: pip install pyarrow")
+        result.append(new_item)
+    return result
 
-    table = feather.read_table(filepath)
-    data = table.to_pylist()
 
-    # Deserialize JSON strings back to complex objects
+def _deserialize_complex_fields(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    """Deserialize JSON strings back into complex fields."""
     result = []
     for item in data:
         new_item = {}
         for k, v in item.items():
-            if isinstance(v, str) and v.startswith(('[', '{')):
+            if isinstance(v, str) and v.startswith(("[", "{")):
                 try:
-                    new_item[k] = json.loads(v)
-                except json.JSONDecodeError:
+                    new_item[k] = orjson.loads(v)
+                except orjson.JSONDecodeError:
                     new_item[k] = v
             else:
                 new_item[k] = v
         result.append(new_item)
-
    return result
 
 
-# ============ Additional Utilities ============
+def _clean_null_fields(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    """Remove null fields added by Polars, preserving the original record structure."""
+    return [{k: v for k, v in item.items() if v is not None} for item in data]
+
+
+# ============ Streaming Utilities ============
+
 
 def sample_data(
     data: List[Dict[str, Any]],
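The serialization helpers introduced above exist because columnar formats (CSV, Parquet, Arrow, Excel) cannot hold nested Python objects directly: nested values are JSON-encoded on write and decoded on read. A condensed, standalone sketch of that round trip, mirroring _serialize_complex_fields and _deserialize_complex_fields:

import orjson
import polars as pl

rows = [{"id": 1, "meta": {"lang": "en"}, "tags": ["x", "y"]}]

# Write side: nested values become JSON strings so Polars sees flat columns
flat = [
    {k: orjson.dumps(v).decode("utf-8") if isinstance(v, (list, dict)) else v
     for k, v in row.items()}
    for row in rows
]
pl.DataFrame(flat).write_csv("demo.csv")

# Read side: strings that look like JSON ("[" or "{" prefix) are decoded again
restored = []
for row in pl.read_csv("demo.csv").to_dicts():
    restored.append({
        k: orjson.loads(v) if isinstance(v, str) and v.startswith(("[", "{")) else v
        for k, v in row.items()
    })
assert restored == rows

Note the prefix heuristic on the read side: an ordinary string that happens to start with "[" or "{" but fails to parse is kept as-is, which is exactly what the except orjson.JSONDecodeError branch handles.
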
@@ -290,28 +296,12 @@ def sample_data(
 
     Args:
         data: List of data items
-        num: Number of items to sample.
-            - num > 0: sample specified number of items
-            - num = 0: sample all data
-            - num < 0: Python slice style (e.g., -1 means last 1, -10 means last 10)
+        num: Number of items to sample
         sample_type: Sampling method - "random", "head", or "tail"
-        seed: Random seed for reproducibility (only for random sampling)
+        seed: Random seed for reproducibility
 
     Returns:
         Sampled data list
-
-    Examples:
-        >>> data = [{"id": i} for i in range(100)]
-        >>> sample_data(data, num=5, sample_type="head")
-        [{'id': 0}, {'id': 1}, {'id': 2}, {'id': 3}, {'id': 4}]
-        >>> sample_data(data, num=3, sample_type="tail")
-        [{'id': 97}, {'id': 98}, {'id': 99}]
-        >>> len(sample_data(data, num=0))  # 0 means all
-        100
-        >>> sample_data(data, num=-1, sample_type="head")  # last 1 item
-        [{'id': 99}]
-        >>> sample_data(data, num=-3, sample_type="tail")  # last 3 items
-        [{'id': 97}, {'id': 98}, {'id': 99}]
     """
     import random as rand_module
 
@@ -320,15 +310,11 @@
 
     total = len(data)
 
-    # Determine actual number to sample
     if num == 0:
-        # 0 means sample all data
         actual_num = total
     elif num < 0:
-        # Negative number: Python slice style (e.g., -1 means 1 item, -10 means 10 items)
         actual_num = min(abs(num), total)
     else:
-        # Positive number: normal sampling
         actual_num = min(num, total)
 
     if sample_type == "head":
@@ -349,32 +335,23 @@ def sample_file(
     output: Optional[str] = None,
 ) -> List[Dict[str, Any]]:
     """
-    Sample data from a file with streaming support for large files.
-
-    For head/tail sampling, streaming reads avoid loading the whole file into memory.
-    For random sampling, JSONL uses reservoir sampling; other formats load all data.
+    Sample data from a file with streaming support.
 
     Args:
-        filepath: Input file path (supports csv, xlsx, jsonl, json, parquet, arrow, feather)
+        filepath: Input file path
         num: Number of items to sample
-        sample_type: Sampling method - "random", "head", or "tail"
-        seed: Random seed for reproducibility (only for random sampling)
-        output: Output file path (optional; if provided, saves sampled data)
+        sample_type: Sampling method
+        seed: Random seed
+        output: Output file path
 
     Returns:
         Sampled data list
-
-    Examples:
-        >>> sampled = sample_file("data.jsonl", num=100, sample_type="random")
-        >>> sample_file("data.csv", num=50, output="sampled.jsonl")
     """
     filepath = Path(filepath)
     file_format = _detect_format(filepath)
 
-    # Try streaming sampling first
     sampled = _stream_sample(filepath, file_format, num, sample_type, seed)
 
-    # Save if output specified
     if output:
         save_data(sampled, output)
 
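The usage examples dropped from the docstring above still hold for the new implementation; kept here for reference (paths illustrative):

>>> sampled = sample_file("data.jsonl", num=100, sample_type="random")
>>> sample_file("data.csv", num=50, output="sampled.jsonl")
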
@@ -388,18 +365,7 @@ def _stream_sample(
     sample_type: str,
     seed: Optional[int],
 ) -> List[Dict[str, Any]]:
-    """
-    Streaming sampling implementation.
-
-    Supported streaming optimizations:
-    - head: jsonl, csv, parquet, arrow, excel
-    - tail: jsonl (reverse read)
-    - random: jsonl (reservoir sampling)
-
-    num == 0 means sample all data and falls back to a full load.
-    num < 0 means Python slice style and falls back to a full load.
-    """
-    # num == 0 samples everything and num < 0 is slice style; both need a full load
+    """Streaming sampling implementation."""
     if num <= 0:
         data = load_data(str(filepath))
         return sample_data(data, num=num, sample_type=sample_type, seed=seed)
@@ -417,13 +383,27 @@
     elif file_format == "excel":
         return _stream_head_excel(filepath, num)
 
-    # tail sampling optimization (JSONL only)
-    if sample_type == "tail" and file_format == "jsonl":
-        return _stream_tail_jsonl(filepath, num)
+    # tail sampling optimization
+    if sample_type == "tail":
+        if file_format == "jsonl":
+            return _stream_tail_jsonl(filepath, num)
+        elif file_format == "csv":
+            return _stream_tail_csv(filepath, num)
+        elif file_format == "parquet":
+            return _stream_tail_parquet(filepath, num)
+        elif file_format == "arrow":
+            return _stream_tail_arrow(filepath, num)
 
-    # random sampling optimization (JSONL only, via reservoir sampling)
-    if sample_type == "random" and file_format == "jsonl":
-        return _stream_random_jsonl(filepath, num, seed)
+    # random sampling optimization
+    if sample_type == "random":
+        if file_format == "jsonl":
+            return _stream_random_jsonl(filepath, num, seed)
+        elif file_format == "csv":
+            return _stream_random_csv(filepath, num, seed)
+        elif file_format == "parquet":
+            return _stream_random_parquet(filepath, num, seed)
+        elif file_format == "arrow":
+            return _stream_random_arrow(filepath, num, seed)
 
     # everything else falls back to a full load
     data = load_data(str(filepath))
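The new tail/random branches above all lean on the same Polars lazy-scan pattern, and the hunk below applies it to JSONL with a large-file escape hatch. A condensed, standalone sketch of that count-then-sample strategy (assumes a well-formed NDJSON file; the package code adds fallbacks for lines Polars cannot parse):

import random
from typing import Optional

import orjson
import polars as pl

def count_then_sample(path: str, num: int, seed: Optional[int] = None) -> list:
    # Phase 1: let Polars count rows without materializing them
    total = pl.scan_ndjson(path).select(pl.len()).collect().item()
    if num >= total:
        return pl.read_ndjson(path).to_dicts()

    # Phase 2: choose line numbers up front, then parse only those lines
    rng = random.Random(seed)
    wanted = set(rng.sample(range(total), num))
    out = []
    with open(path, "rb") as f:
        for i, line in enumerate(f):
            if i in wanted:
                out.append(orjson.loads(line))
                if len(out) == num:
                    break
    return out

Memory stays proportional to the sample and only num lines are ever JSON-decoded in Python; the trade-off against the reservoir sampling it replaces is one fast counting pass in exchange for skipping the per-line decode of unsampled records.
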
@@ -431,262 +411,297 @@ def _stream_sample(
 
 
 def _stream_head_jsonl(filepath: Path, num: int) -> List[Dict[str, Any]]:
-    """Stream-read the first N lines of a JSONL file."""
-    result = []
-    with open(filepath, "r", encoding="utf-8") as f:
-        for line in f:
-            line = line.strip()
-            if line:
-                result.append(json.loads(line))
-                if len(result) >= num:
-                    break
-    return result
+    """Stream-read the first N lines of a JSONL file (via Polars ndjson)."""
+    try:
+        df = pl.scan_ndjson(filepath).head(num).collect()
+        return _clean_null_fields(df.to_dicts())
+    except Exception as e:
+        # Fall back to the Python implementation
+        import sys
+
+        print(
+            f"[Warning] Polars ndjson parsing failed, falling back to Python: {type(e).__name__}",
+            file=sys.stderr,
+        )
+
+        result = []
+        with open(filepath, "rb") as f:
+            for line in f:
+                line = line.strip()
+                if line:
+                    try:
+                        result.append(orjson.loads(line))
+                    except orjson.JSONDecodeError:
+                        continue  # skip invalid lines
+                if len(result) >= num:
+                    break
+        return result
 
 
 def _stream_head_csv(filepath: Path, num: int) -> List[Dict[str, Any]]:
-    """Stream-read the first N rows of a CSV file."""
-    try:
-        import pandas as pd
-    except ImportError:
-        raise ImportError("pandas is required for CSV support. Install with: pip install pandas")
-
-    df = pd.read_csv(filepath, encoding="utf-8", nrows=num)
-    return df.to_dict("records")
+    """Stream-read the first N rows of a CSV file (via Polars LazyFrame)."""
+    df = pl.scan_csv(filepath).head(num).collect()
+    return _deserialize_complex_fields(df.to_dicts())
 
 
 def _stream_head_parquet(filepath: Path, num: int) -> List[Dict[str, Any]]:
-    """True streaming read of the first N rows of a Parquet file (iter_batches avoids a full load)."""
-    try:
-        import pyarrow.parquet as pq
-    except ImportError:
-        raise ImportError("pyarrow is required for Parquet support. Install with: pip install pyarrow")
+    """Stream-read the first N rows of a Parquet file (via Polars LazyFrame)."""
+    df = pl.scan_parquet(filepath).head(num).collect()
+    return _deserialize_complex_fields(df.to_dicts())
 
-    parquet_file = pq.ParquetFile(filepath)
-    result = []
 
-    # iter_batches streams for real, reading only the data that is needed
-    for batch in parquet_file.iter_batches(batch_size=min(num, 10000)):
-        batch_data = batch.to_pylist()
-        result.extend(batch_data)
-        if len(result) >= num:
-            break
+def _stream_head_arrow(filepath: Path, num: int) -> List[Dict[str, Any]]:
+    """Stream-read the first N rows of an Arrow file (via Polars LazyFrame)."""
+    df = pl.scan_ipc(filepath).head(num).collect()
+    return _deserialize_complex_fields(df.to_dicts())
 
-    return result[:num]
 
+def _stream_head_excel(filepath: Path, num: int) -> List[Dict[str, Any]]:
+    """Read the first N rows of an Excel file."""
+    # Excel does not support lazy scans, so use a plain read
+    df = pl.read_excel(filepath).head(num)
+    return _deserialize_complex_fields(df.to_dicts())
 
-def _stream_head_arrow(filepath: Path, num: int) -> List[Dict[str, Any]]:
-    """Stream-read the first N rows of an Arrow/Feather file."""
-    try:
-        import pyarrow.feather as feather
-    except ImportError:
-        raise ImportError("pyarrow is required for Arrow support. Install with: pip install pyarrow")
 
-    table = feather.read_table(filepath)
-    sliced = table.slice(0, min(num, table.num_rows))
-    return _deserialize_arrow_data(sliced.to_pylist())
+def _stream_tail_jsonl(filepath: Path, num: int) -> List[Dict[str, Any]]:
+    """Stream-read the last N lines of a JSONL file (via Polars ndjson)."""
+    try:
+        df = pl.scan_ndjson(filepath).tail(num).collect()
+        return _clean_null_fields(df.to_dicts())
+    except Exception as e:
+        # Fall back to a two-pass Python implementation
+        import sys
 
+        print(
+            f"[Warning] Polars ndjson parsing failed, falling back to Python: {type(e).__name__}",
+            file=sys.stderr,
+        )
 
-def _deserialize_arrow_data(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
-    """Deserialize JSON-string fields in Arrow data."""
-    result = []
-    for item in data:
-        new_item = {}
-        for k, v in item.items():
-            if isinstance(v, str) and v.startswith(("[", "{")):
-                try:
-                    new_item[k] = json.loads(v)
-                except json.JSONDecodeError:
-                    new_item[k] = v
-            else:
-                new_item[k] = v
-        result.append(new_item)
-    return result
+        total_lines = 0
+        with open(filepath, "rb") as f:
+            for _ in f:
+                total_lines += 1
 
+        if total_lines <= num:
+            return _load_jsonl(filepath)
 
-def _stream_head_excel(filepath: Path, num: int) -> List[Dict[str, Any]]:
-    """Stream-read the first N rows of an Excel file."""
-    try:
-        import pandas as pd
-    except ImportError:
-        raise ImportError("pandas and openpyxl are required for Excel support")
+        skip_count = total_lines - num
+        result = []
+        with open(filepath, "rb") as f:
+            for i, line in enumerate(f):
+                if i < skip_count:
+                    continue
+                line = line.strip()
+                if line:
+                    try:
+                        result.append(orjson.loads(line))
+                    except orjson.JSONDecodeError:
+                        continue  # skip invalid lines
+        return result
 
-    df = pd.read_excel(filepath, nrows=num)
-    return df.to_dict("records")
 
+def _stream_tail_csv(filepath: Path, num: int) -> List[Dict[str, Any]]:
+    """Stream-read the last N rows of a CSV file (via Polars LazyFrame)."""
+    df = pl.scan_csv(filepath).tail(num).collect()
+    return _deserialize_complex_fields(df.to_dicts())
 
-def append_to_file(data: List[Dict[str, Any]],
-                   filepath: str,
-                   file_format: str = 'jsonl') -> None:
-    """
-    Append data to an existing file.
 
-    Args:
-        data: List of data items to append
-        filepath: Path to file
-        file_format: File format (only 'jsonl' supported for append)
-    """
-    filepath = Path(filepath)
+def _stream_tail_parquet(filepath: Path, num: int) -> List[Dict[str, Any]]:
+    """Stream-read the last N rows of a Parquet file (via Polars LazyFrame)."""
+    df = pl.scan_parquet(filepath).tail(num).collect()
+    return _deserialize_complex_fields(df.to_dicts())
 
-    if file_format != 'jsonl':
-        raise ValueError("Only JSONL format supports appending")
 
-    filepath.parent.mkdir(parents=True, exist_ok=True)
+def _stream_tail_arrow(filepath: Path, num: int) -> List[Dict[str, Any]]:
+    """Stream-read the last N rows of an Arrow file (via Polars LazyFrame)."""
+    df = pl.scan_ipc(filepath).tail(num).collect()
+    return _deserialize_complex_fields(df.to_dicts())
 
-    with open(filepath, 'a', encoding='utf-8') as f:
-        for item in data:
-            json_line = json.dumps(item, ensure_ascii=False)
-            f.write(json_line + '\n')
 
+# File-size threshold: above this, use Python streaming sampling; otherwise use Polars
+_STREAM_THRESHOLD_BYTES = 100 * 1024 * 1024  # 100MB
 
-def count_lines(filepath: str) -> int:
-    """
-    Count number of lines in a JSONL file without loading all data.
 
-    Args:
-        filepath: Path to JSONL file
+def _count_sample_jsonl(
+    filepath: Path, num: int, seed: Optional[int] = None
+) -> List[Dict[str, Any]]:
+    """Streaming JSONL sampling (Polars line count + selective Python reads).
 
-    Returns:
-        Number of lines
+    Strategy:
+    1. Count lines quickly with Polars (about 4x faster than Python)
+    2. Generate random indices
+    3. Walk the file in Python, parsing only the selected lines
     """
-    count = 0
-    with open(filepath, 'r', encoding='utf-8') as f:
-        for _ in f:
-            count += 1
-    return count
+    import random
 
+    # Step 1: count lines quickly with Polars
+    try:
+        total_lines = pl.scan_ndjson(filepath).select(pl.len()).collect().item()
+    except Exception:
+        # Fall back to counting in Python
+        with open(filepath, "rb") as f:
+            total_lines = sum(1 for _ in f)
 
-def stream_jsonl(filepath: str, chunk_size: int = 1000):
-    """
-    Stream JSONL file in chunks.
+    if total_lines == 0:
+        return []
 
-    Args:
-        filepath: Path to JSONL file
-        chunk_size: Number of items per chunk
+    # Sample size exceeds the total line count: read everything
+    if num >= total_lines:
+        return _load_jsonl(filepath)
 
-    Yields:
-        Chunks of data items
-    """
-    chunk = []
-    with open(filepath, 'r', encoding='utf-8') as f:
-        for line in f:
-            line = line.strip()
-            if line:
-                chunk.append(json.loads(line))
-                if len(chunk) >= chunk_size:
-                    yield chunk
-                    chunk = []
+    # Step 2: generate random indices
+    if seed is not None:
+        random.seed(seed)
+    selected_indices = set(random.sample(range(total_lines), num))
 
-    if chunk:
-        yield chunk
+    # Step 3: parse only the selected lines
+    result = []
+    with open(filepath, "rb") as f:
+        for i, line in enumerate(f):
+            if i in selected_indices:
+                line = line.strip()
+                if line:
+                    try:
+                        result.append(orjson.loads(line))
+                    except orjson.JSONDecodeError:
+                        continue
+                if len(result) >= num:
+                    break
 
+    return result
 
-# ============ JSONL Streaming-Sampling Optimizations ============
 
+def _stream_random_jsonl(
+    filepath: Path, num: int, seed: Optional[int] = None
+) -> List[Dict[str, Any]]:
+    """Random sampling for JSONL.
 
-def _stream_tail_jsonl(filepath: Path, num: int) -> List[Dict[str, Any]]:
+    Strategy:
+    - Small files (<100MB): Polars collect + sample
+    - Large files (>=100MB): count+sample streaming (faster and memory-friendly)
     """
-    Reverse-read the last N lines of a JSONL file (avoids a full load).
+    file_size = filepath.stat().st_size
 
-    A deque keeps the last N lines, so memory use is O(num) rather than O(total).
-    """
-    from collections import deque
+    # Large files use streaming sampling (faster)
+    if file_size >= _STREAM_THRESHOLD_BYTES:
+        return _count_sample_jsonl(filepath, num, seed)
 
-    # A deque with maxlen automatically keeps the last N elements
-    buffer = deque(maxlen=num)
+    # Small files try Polars first
+    try:
+        df = pl.scan_ndjson(filepath).collect()
+        if len(df) <= num:
+            return _clean_null_fields(df.to_dicts())
+        sampled = df.sample(n=num, seed=seed)
+        return _clean_null_fields(sampled.to_dicts())
+    except Exception as e:
+        import sys
+
+        print(
+            f"[Warning] Polars ndjson parsing failed, falling back to streaming sampling: {type(e).__name__}", file=sys.stderr
+        )
+        return _count_sample_jsonl(filepath, num, seed)
 
-    with open(filepath, "r", encoding="utf-8") as f:
-        for line in f:
-            line = line.strip()
-            if line:
-                buffer.append(json.loads(line))
 
-    return list(buffer)
+def _stream_random_csv(
+    filepath: Path, num: int, seed: Optional[int] = None
+) -> List[Dict[str, Any]]:
+    """Random sampling for CSV (via Polars)."""
+    df = pl.scan_csv(filepath).collect()
+    if len(df) <= num:
+        return _deserialize_complex_fields(df.to_dicts())
+    sampled = df.sample(n=num, seed=seed)
+    return _deserialize_complex_fields(sampled.to_dicts())
 
 
-def _stream_random_jsonl(
+def _stream_random_parquet(
     filepath: Path, num: int, seed: Optional[int] = None
 ) -> List[Dict[str, Any]]:
-    """
-    Reservoir sampling over a JSONL file.
+    """Random sampling for Parquet (via Polars)."""
+    df = pl.scan_parquet(filepath).collect()
+    if len(df) <= num:
+        return _deserialize_complex_fields(df.to_dicts())
+    sampled = df.sample(n=num, seed=seed)
+    return _deserialize_complex_fields(sampled.to_dicts())
 
-    A single pass over the file with O(num) memory, suited to very large files.
-    The algorithm gives every record an equal probability of being selected.
-    """
-    import random
 
-    if seed is not None:
-        random.seed(seed)
+def _stream_random_arrow(
+    filepath: Path, num: int, seed: Optional[int] = None
+) -> List[Dict[str, Any]]:
+    """Random sampling for Arrow (via Polars)."""
+    df = pl.scan_ipc(filepath).collect()
+    if len(df) <= num:
+        return _deserialize_complex_fields(df.to_dicts())
+    sampled = df.sample(n=num, seed=seed)
+    return _deserialize_complex_fields(sampled.to_dicts())
 
-    reservoir = []  # the reservoir
 
+# ============ Additional Utilities ============
+
+
+def append_to_file(data: List[Dict[str, Any]], filepath: str, file_format: str = "jsonl") -> None:
+    """Append data to an existing file (only JSONL supported)."""
+    filepath = Path(filepath)
+
+    if file_format != "jsonl":
+        raise ValueError("Only JSONL format supports appending")
+
+    filepath.parent.mkdir(parents=True, exist_ok=True)
+
+    with open(filepath, "ab") as f:
+        for item in data:
+            f.write(orjson.dumps(item) + b"\n")
+
+
+def count_lines(filepath: str) -> int:
+    """Count number of lines in a JSONL file."""
+    count = 0
     with open(filepath, "r", encoding="utf-8") as f:
-        for i, line in enumerate(f):
-            line = line.strip()
-            if not line:
-                continue
+        for _ in f:
+            count += 1
+    return count
 
-            item = json.loads(line)
 
-            if len(reservoir) < num:
-                # Reservoir not yet full; append directly
-                reservoir.append(item)
-            else:
-                # Reservoir full; replace with probability num/(i+1)
-                j = random.randint(0, i)
-                if j < num:
-                    reservoir[j] = item
+def stream_jsonl(filepath: str, chunk_size: int = 1000):
+    """Stream JSONL file in chunks."""
+    chunk = []
+    with open(filepath, "rb") as f:
+        for line in f:
+            line = line.strip()
+            if line:
+                chunk.append(orjson.loads(line))
+                if len(chunk) >= chunk_size:
+                    yield chunk
+                    chunk = []
 
-    return reservoir
+    if chunk:
+        yield chunk
 
 
 # ============ FlaxKV Format ============
 
 
-def _save_flaxkv(data: List[Dict[str, Any]], filepath: Path) -> None:
-    """
-    Save data in FlaxKV format.
 
-    Args:
-        data: List of data items to save
-        filepath: Path to FlaxKV database (directory)
-    """
+def _save_flaxkv(data: List[Dict[str, Any]], filepath: Path) -> None:
+    """Save data in FlaxKV format."""
     from flaxkv2 import FlaxKV
 
-    # Use the directory name as the database name
     db_name = filepath.stem if filepath.stem else "data"
     db_path = filepath.parent
 
-    # Create FlaxKV database
     with FlaxKV(db_name, str(db_path)) as db:
-        # Store metadata
-        db["_metadata"] = {
-            "total": len(data),
-            "format": "flaxkv"
-        }
+        db["_metadata"] = {"total": len(data), "format": "flaxkv"}
 
-        # Store each item with index as key
         for i, item in enumerate(data):
             db[f"item:{i}"] = item
 
 
 def _load_flaxkv(filepath: Path) -> List[Dict[str, Any]]:
-    """
-    Load data from FlaxKV format.
-
-    Args:
-        filepath: Path to FlaxKV database (directory)
-
-    Returns:
-        List of data items
-    """
+    """Load data from FlaxKV format."""
    from flaxkv2 import FlaxKV
 
-    # Use the directory name as the database name
    db_name = filepath.stem if filepath.stem else "data"
    db_path = filepath.parent
 
-    # Open FlaxKV database
    with FlaxKV(db_name, str(db_path)) as db:
-        # Collect all items
        items = []
        for key in sorted(db.keys()):
            if key.startswith("item:"):