dtflow 0.2.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dtflow/storage/io.py CHANGED
@@ -1,15 +1,19 @@
  """
  Input/Output utilities for saving and loading data.
+
+ Uses Polars as the primary I/O engine, 3-5x faster than Pandas.
+ Uses orjson as the JSON parsing engine, 10x faster than the standard json module.
  """
- from typing import List, Dict, Any, Optional
- import json
- import os
+ import orjson
  from pathlib import Path
+ from typing import Any, Dict, List, Optional
+
+ import polars as pl
 
 
- def save_data(data: List[Dict[str, Any]],
-               filepath: str,
-               file_format: Optional[str] = None) -> None:
+ def save_data(
+     data: List[Dict[str, Any]], filepath: str, file_format: Optional[str] = None
+ ) -> None:
      """
      Save data to file.
 
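The import swap above propagates through the whole file: orjson produces and consumes bytes rather than str, which is why the rewritten helpers below open files in "wb"/"rb" and append b"\n" instead of "\n". A minimal sketch of the equivalence, outside the package:

    import json

    import orjson

    record = {"text": "héllo", "tags": ["a", "b"]}

    # stdlib json, text mode (0.2.0 style)
    line_str = json.dumps(record, ensure_ascii=False)   # -> str

    # orjson, binary mode (0.3.1 style)
    line_bytes = orjson.dumps(record)                   # -> bytes, always UTF-8

    assert orjson.loads(line_bytes) == json.loads(line_str) == record

There is no orjson counterpart to ensure_ascii=False because orjson never escapes non-ASCII characters.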
@@ -21,23 +25,22 @@ def save_data(data: List[Dict[str, Any]],
      filepath = Path(filepath)
      filepath.parent.mkdir(parents=True, exist_ok=True)
 
-     # Auto-detect format from extension
      if file_format is None:
          file_format = _detect_format(filepath)
 
-     if file_format == 'jsonl':
+     if file_format == "jsonl":
          _save_jsonl(data, filepath)
-     elif file_format == 'json':
+     elif file_format == "json":
          _save_json(data, filepath)
-     elif file_format == 'csv':
+     elif file_format == "csv":
          _save_csv(data, filepath)
-     elif file_format == 'parquet':
+     elif file_format == "parquet":
          _save_parquet(data, filepath)
-     elif file_format == 'arrow':
+     elif file_format == "arrow":
          _save_arrow(data, filepath)
-     elif file_format == 'excel':
+     elif file_format == "excel":
          _save_excel(data, filepath)
-     elif file_format == 'flaxkv':
+     elif file_format == "flaxkv":
          _save_flaxkv(data, filepath)
      else:
          raise ValueError(f"Unknown file format: {file_format}")
@@ -59,23 +62,22 @@ def load_data(filepath: str, file_format: Optional[str] = None) -> List[Dict[str
      if not filepath.exists():
          raise FileNotFoundError(f"File not found: {filepath}")
 
-     # Auto-detect format from extension
      if file_format is None:
          file_format = _detect_format(filepath)
 
-     if file_format == 'jsonl':
+     if file_format == "jsonl":
          return _load_jsonl(filepath)
-     elif file_format == 'json':
+     elif file_format == "json":
          return _load_json(filepath)
-     elif file_format == 'csv':
+     elif file_format == "csv":
          return _load_csv(filepath)
-     elif file_format == 'parquet':
+     elif file_format == "parquet":
          return _load_parquet(filepath)
-     elif file_format == 'arrow':
+     elif file_format == "arrow":
          return _load_arrow(filepath)
-     elif file_format == 'excel':
+     elif file_format == "excel":
          return _load_excel(filepath)
-     elif file_format == 'flaxkv':
+     elif file_format == "flaxkv":
          return _load_flaxkv(filepath)
      else:
          raise ValueError(f"Unknown file format: {file_format}")
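Both dispatchers keep the extension-based auto-detection from 0.2.0, so the public API is unchanged; only the engines behind the format branches moved. A small usage sketch (paths are illustrative):

    from dtflow.storage.io import load_data, save_data

    rows = [{"id": 1, "meta": {"lang": "en"}}, {"id": 2, "meta": {"lang": "zh"}}]

    save_data(rows, "out/data.parquet")   # format inferred from the .parquet suffix
    back = load_data("out/data.parquet")  # nested "meta" dicts are restored on load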
@@ -84,200 +86,204 @@ def load_data(filepath: str, file_format: Optional[str] = None) -> List[Dict[str
  def _detect_format(filepath: Path) -> str:
      """Detect file format from extension."""
      ext = filepath.suffix.lower()
-     if ext == '.jsonl':
-         return 'jsonl'
-     elif ext == '.json':
-         return 'json'
-     elif ext == '.csv':
-         return 'csv'
-     elif ext == '.parquet':
-         return 'parquet'
-     elif ext in ('.arrow', '.feather'):
-         return 'arrow'
-     elif ext in ('.xlsx', '.xls'):
-         return 'excel'
-     elif ext == '.flaxkv' or ext == '':
-         # For FlaxKV, filepath is typically a directory
-         return 'flaxkv'
+     if ext == ".jsonl":
+         return "jsonl"
+     elif ext == ".json":
+         return "json"
+     elif ext == ".csv":
+         return "csv"
+     elif ext == ".parquet":
+         return "parquet"
+     elif ext in (".arrow", ".feather"):
+         return "arrow"
+     elif ext in (".xlsx", ".xls"):
+         return "excel"
+     elif ext == ".flaxkv" or ext == "":
+         return "flaxkv"
      else:
-         # Default to JSONL
-         return 'jsonl'
+         return "jsonl"
 
 
  # ============ JSONL Format ============
+ # JSONL stays in native Python because it must handle complex nested structures
+
 
  def _save_jsonl(data: List[Dict[str, Any]], filepath: Path) -> None:
      """Save data in JSONL format."""
-     with open(filepath, 'w', encoding='utf-8') as f:
+     with open(filepath, "wb") as f:
          for item in data:
-             json_line = json.dumps(item, ensure_ascii=False)
-             f.write(json_line + '\n')
+             f.write(orjson.dumps(item) + b"\n")
 
 
  def _load_jsonl(filepath: Path) -> List[Dict[str, Any]]:
      """Load data from JSONL format."""
      data = []
-     with open(filepath, 'r', encoding='utf-8') as f:
+     with open(filepath, "rb") as f:
          for line in f:
              line = line.strip()
              if line:
-                 data.append(json.loads(line))
+                 data.append(orjson.loads(line))
      return data
 
 
  # ============ JSON Format ============
 
+
  def _save_json(data: List[Dict[str, Any]], filepath: Path) -> None:
      """Save data in JSON format."""
-     with open(filepath, 'w', encoding='utf-8') as f:
-         json.dump(data, f, ensure_ascii=False, indent=2)
+     with open(filepath, "wb") as f:
+         f.write(orjson.dumps(data, option=orjson.OPT_INDENT_2))
 
 
  def _load_json(filepath: Path) -> List[Dict[str, Any]]:
      """Load data from JSON format."""
-     with open(filepath, 'r', encoding='utf-8') as f:
-         data = json.load(f)
+     with open(filepath, "rb") as f:
+         data = orjson.loads(f.read())
 
-     # Ensure data is a list
      if not isinstance(data, list):
          data = [data]
 
      return data
 
 
- # ============ CSV Format ============
+ # ============ CSV Format (Polars) ============
+
 
  def _save_csv(data: List[Dict[str, Any]], filepath: Path) -> None:
-     """Save data in CSV format."""
-     try:
-         import pandas as pd
-     except ImportError:
-         raise ImportError("pandas is required for CSV support. Install with: pip install pandas")
+     """Save data in CSV format using Polars."""
+     if not data:
+         # Empty data: create an empty file
+         filepath.touch()
+         return
 
-     df = pd.DataFrame(data)
-     df.to_csv(filepath, index=False, encoding='utf-8')
+     # Serialize complex fields to JSON strings
+     serialized = _serialize_complex_fields(data)
+     df = pl.DataFrame(serialized)
+     df.write_csv(filepath)
 
 
  def _load_csv(filepath: Path) -> List[Dict[str, Any]]:
-     """Load data from CSV format."""
-     try:
-         import pandas as pd
-     except ImportError:
-         raise ImportError("pandas is required for CSV support. Install with: pip install pandas")
+     """Load data from CSV format using Polars."""
+     df = pl.read_csv(filepath)
+     data = df.to_dicts()
+     # Deserialize JSON strings
+     return _deserialize_complex_fields(data)
 
-     df = pd.read_csv(filepath, encoding='utf-8')
-     return df.to_dict('records')
 
+ # ============ Parquet Format (Polars) ============
 
- # ============ Excel Format ============
 
- def _save_excel(data: List[Dict[str, Any]], filepath: Path) -> None:
-     """Save data in Excel format."""
-     try:
-         import pandas as pd
-     except ImportError:
-         raise ImportError("pandas and openpyxl are required for Excel support. Install with: pip install pandas openpyxl")
+ def _save_parquet(data: List[Dict[str, Any]], filepath: Path) -> None:
+     """Save data in Parquet format using Polars."""
+     if not data:
+         # Empty data: write an empty parquet
+         pl.DataFrame().write_parquet(filepath)
+         return
 
-     df = pd.DataFrame(data)
-     df.to_excel(filepath, index=False)
+     serialized = _serialize_complex_fields(data)
+     df = pl.DataFrame(serialized)
+     df.write_parquet(filepath)
 
 
- def _load_excel(filepath: Path) -> List[Dict[str, Any]]:
-     """Load data from Excel format."""
-     try:
-         import pandas as pd
-     except ImportError:
-         raise ImportError("pandas and openpyxl are required for Excel support. Install with: pip install pandas openpyxl")
+ def _load_parquet(filepath: Path) -> List[Dict[str, Any]]:
+     """Load data from Parquet format using Polars."""
+     df = pl.read_parquet(filepath)
+     data = df.to_dicts()
+     return _deserialize_complex_fields(data)
 
-     df = pd.read_excel(filepath)
-     return df.to_dict('records')
 
+ # ============ Arrow Format (Polars) ============
 
- # ============ Parquet Format ============
 
- def _save_parquet(data: List[Dict[str, Any]], filepath: Path) -> None:
-     """Save data in Parquet format."""
-     try:
-         import pandas as pd
-     except ImportError:
-         raise ImportError("pandas is required for Parquet support. Install with: pip install pandas pyarrow")
+ def _save_arrow(data: List[Dict[str, Any]], filepath: Path) -> None:
+     """Save data in Arrow IPC format using Polars."""
+     if not data:
+         pl.DataFrame().write_ipc(filepath)
+         return
 
-     df = pd.DataFrame(data)
-     df.to_parquet(filepath, index=False, engine='pyarrow')
+     serialized = _serialize_complex_fields(data)
+     df = pl.DataFrame(serialized)
+     df.write_ipc(filepath)
 
 
- def _load_parquet(filepath: Path) -> List[Dict[str, Any]]:
-     """Load data from Parquet format."""
-     try:
-         import pandas as pd
-     except ImportError:
-         raise ImportError("pandas is required for Parquet support. Install with: pip install pandas pyarrow")
+ def _load_arrow(filepath: Path) -> List[Dict[str, Any]]:
+     """Load data from Arrow IPC format using Polars."""
+     df = pl.read_ipc(filepath)
+     data = df.to_dicts()
+     return _deserialize_complex_fields(data)
+
 
-     df = pd.read_parquet(filepath, engine='pyarrow')
-     return df.to_dict('records')
+ # ============ Excel Format ============
+ # Excel needs an extra dependency; keep it optional
 
 
- # ============ Arrow Format ============
+ def _save_excel(data: List[Dict[str, Any]], filepath: Path) -> None:
+     """Save data in Excel format."""
+     if not data:
+         # Empty data
+         try:
+             import xlsxwriter
+             workbook = xlsxwriter.Workbook(str(filepath))
+             workbook.close()
+         except ImportError:
+             raise ImportError(
+                 "xlsxwriter is required for Excel write. Install with: pip install xlsxwriter"
+             )
+         return
+
+     serialized = _serialize_complex_fields(data)
+     df = pl.DataFrame(serialized)
+     df.write_excel(filepath)
 
- def _save_arrow(data: List[Dict[str, Any]], filepath: Path) -> None:
-     """Save data in Arrow IPC format (also known as Feather v2).
 
-     Note: Complex nested structures (like list of dicts) are serialized as JSON strings.
-     """
-     try:
-         import pyarrow as pa
-         import pyarrow.feather as feather
-     except ImportError:
-         raise ImportError("pyarrow is required for Arrow support. Install with: pip install pyarrow")
+ def _load_excel(filepath: Path) -> List[Dict[str, Any]]:
+     """Load data from Excel format."""
+     df = pl.read_excel(filepath)
+     data = df.to_dicts()
+     return _deserialize_complex_fields(data)
+
 
-     # Serialize complex fields to JSON strings for Arrow compatibility
-     serialized_data = []
+ # ============ Complex Field Serialization ============
+
+
+ def _serialize_complex_fields(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+     """Serialize complex fields (list, dict) to JSON strings."""
+     result = []
      for item in data:
          new_item = {}
          for k, v in item.items():
              if isinstance(v, (list, dict)):
-                 new_item[k] = json.dumps(v, ensure_ascii=False)
+                 new_item[k] = orjson.dumps(v).decode("utf-8")
              else:
                  new_item[k] = v
-         serialized_data.append(new_item)
-
-     table = pa.Table.from_pylist(serialized_data)
-
-     # Use Feather format (simpler and more portable)
-     feather.write_feather(table, filepath)
-
-
- def _load_arrow(filepath: Path) -> List[Dict[str, Any]]:
-     """Load data from Arrow IPC format (also known as Feather v2).
-
-     Note: JSON-serialized fields are automatically deserialized.
-     """
-     try:
-         import pyarrow.feather as feather
-     except ImportError:
-         raise ImportError("pyarrow is required for Arrow support. Install with: pip install pyarrow")
+         result.append(new_item)
+     return result
 
-     table = feather.read_table(filepath)
-     data = table.to_pylist()
 
-     # Deserialize JSON strings back to complex objects
+ def _deserialize_complex_fields(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+     """Deserialize JSON strings back into complex fields."""
      result = []
      for item in data:
          new_item = {}
          for k, v in item.items():
-             if isinstance(v, str) and v.startswith(('[', '{')):
+             if isinstance(v, str) and v.startswith(("[", "{")):
                  try:
-                     new_item[k] = json.loads(v)
-                 except json.JSONDecodeError:
+                     new_item[k] = orjson.loads(v)
+                 except orjson.JSONDecodeError:
                      new_item[k] = v
              else:
                  new_item[k] = v
          result.append(new_item)
-
      return result
 
 
- # ============ Additional Utilities ============
+ def _clean_null_fields(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+     """Strip null fields added by Polars, preserving the original data structure."""
+     return [{k: v for k, v in item.items() if v is not None} for item in data]
+
+
+ # ============ Streaming Utilities ============
+
 
  def sample_data(
      data: List[Dict[str, Any]],
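The new _serialize_complex_fields/_deserialize_complex_fields pair is what lets nested values survive the columnar formats: lists and dicts become JSON strings on write, and any string that starts with "[" or "{" is tentatively parsed back on read. A worked round trip of that scheme:

    import orjson

    row = {"id": 7, "spans": [[0, 4], [5, 9]], "meta": {"lang": "en"}}

    # write side: nested values -> JSON strings
    flat = {k: orjson.dumps(v).decode("utf-8") if isinstance(v, (list, dict)) else v
            for k, v in row.items()}

    # read side: JSON-looking strings -> values
    restored = {k: orjson.loads(v) if isinstance(v, str) and v.startswith(("[", "{")) else v
                for k, v in flat.items()}
    assert restored == row

The heuristic has a known cost: a legitimate string field that happens to start with "[" or "{" and parses as valid JSON will be converted into a list or dict on load.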
@@ -290,28 +296,12 @@ def sample_data(
 
      Args:
          data: List of data items
-         num: Number of items to sample.
-             - num > 0: sample specified number of items
-             - num = 0: sample all data
-             - num < 0: Python slice style (e.g., -1 means last 1, -10 means last 10)
+         num: Number of items to sample
          sample_type: Sampling method - "random", "head", or "tail"
-         seed: Random seed for reproducibility (only for random sampling)
+         seed: Random seed for reproducibility
 
      Returns:
          Sampled data list
-
-     Examples:
-         >>> data = [{"id": i} for i in range(100)]
-         >>> sample_data(data, num=5, sample_type="head")
-         [{'id': 0}, {'id': 1}, {'id': 2}, {'id': 3}, {'id': 4}]
-         >>> sample_data(data, num=3, sample_type="tail")
-         [{'id': 97}, {'id': 98}, {'id': 99}]
-         >>> len(sample_data(data, num=0))  # 0 means all
-         100
-         >>> sample_data(data, num=-1, sample_type="head")  # last 1 item
-         [{'id': 99}]
-         >>> sample_data(data, num=-3, sample_type="tail")  # last 3 items
-         [{'id': 97}, {'id': 98}, {'id': 99}]
      """
      import random as rand_module
 
@@ -320,15 +310,11 @@ def sample_data(
 
      total = len(data)
 
-     # Determine actual number to sample
      if num == 0:
-         # 0 means sample all data
          actual_num = total
      elif num < 0:
-         # Negative number: Python slice style (e.g., -1 means 1 item, -10 means 10 items)
          actual_num = min(abs(num), total)
      else:
-         # Positive number: normal sampling
          actual_num = min(num, total)
 
      if sample_type == "head":
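The docstring trim above removes the worked examples, but the num semantics are still encoded in the branches that follow: num == 0 selects everything, and a negative num takes |num| items, slice-style. The removed doctests summarized the behavior:

    data = [{"id": i} for i in range(100)]

    sample_data(data, num=5, sample_type="head")   # first 5 items
    sample_data(data, num=3, sample_type="tail")   # last 3 items
    len(sample_data(data, num=0))                  # 100 -- 0 means all
    sample_data(data, num=-3, sample_type="tail")  # also the last 3 items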
@@ -349,32 +335,23 @@ def sample_file(
      output: Optional[str] = None,
  ) -> List[Dict[str, Any]]:
      """
-     Sample data from a file with streaming support for large files.
-
-     For head/tail sampling, streaming reads are supported; the whole file need not be loaded into memory.
-     For random sampling, JSONL uses reservoir sampling; other formats must load all data.
+     Sample data from a file with streaming support.
 
      Args:
-         filepath: Input file path (supports csv, xlsx, jsonl, json, parquet, arrow, feather)
+         filepath: Input file path
          num: Number of items to sample
-         sample_type: Sampling method - "random", "head", or "tail"
-         seed: Random seed for reproducibility (only for random sampling)
-         output: Output file path (optional, if provided, saves sampled data)
+         sample_type: Sampling method
+         seed: Random seed
+         output: Output file path
 
      Returns:
          Sampled data list
-
-     Examples:
-         >>> sampled = sample_file("data.jsonl", num=100, sample_type="random")
-         >>> sample_file("data.csv", num=50, output="sampled.jsonl")
      """
      filepath = Path(filepath)
      file_format = _detect_format(filepath)
 
-     # Try streaming sampling first
      sampled = _stream_sample(filepath, file_format, num, sample_type, seed)
 
-     # Save if output specified
      if output:
          save_data(sampled, output)
 
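The examples dropped from this docstring still describe the intended call pattern (the explicit seed shown here is optional):

    # random sample of 100 records from a JSONL file, reproducible via seed
    sampled = sample_file("data.jsonl", num=100, sample_type="random", seed=42)

    # head-sample a CSV and persist the result as JSONL
    sample_file("data.csv", num=50, output="sampled.jsonl")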
@@ -388,18 +365,7 @@ def _stream_sample(
      sample_type: str,
      seed: Optional[int],
  ) -> List[Dict[str, Any]]:
-     """
-     Streaming sampling implementation.
-
-     Supported streaming optimizations:
-     - head: jsonl, csv, parquet, arrow, excel
-     - tail: jsonl (reverse read)
-     - random: jsonl (reservoir sampling)
-
-     num == 0 means sample all data; falls back to a full load.
-     num < 0 means Python slice style; falls back to a full load.
-     """
-     # num == 0 samples everything, num < 0 is slice style; both need a full load
+     """Streaming sampling implementation."""
      if num <= 0:
          data = load_data(str(filepath))
          return sample_data(data, num=num, sample_type=sample_type, seed=seed)
@@ -417,13 +383,27 @@ def _stream_sample(
      elif file_format == "excel":
          return _stream_head_excel(filepath, num)
 
-     # tail sampling optimization (JSONL only)
-     if sample_type == "tail" and file_format == "jsonl":
-         return _stream_tail_jsonl(filepath, num)
+     # tail sampling optimization
+     if sample_type == "tail":
+         if file_format == "jsonl":
+             return _stream_tail_jsonl(filepath, num)
+         elif file_format == "csv":
+             return _stream_tail_csv(filepath, num)
+         elif file_format == "parquet":
+             return _stream_tail_parquet(filepath, num)
+         elif file_format == "arrow":
+             return _stream_tail_arrow(filepath, num)
 
-     # random sampling optimization (JSONL only, reservoir sampling)
-     if sample_type == "random" and file_format == "jsonl":
-         return _stream_random_jsonl(filepath, num, seed)
+     # random sampling optimization
+     if sample_type == "random":
+         if file_format == "jsonl":
+             return _stream_random_jsonl(filepath, num, seed)
+         elif file_format == "csv":
+             return _stream_random_csv(filepath, num, seed)
+         elif file_format == "parquet":
+             return _stream_random_parquet(filepath, num, seed)
+         elif file_format == "arrow":
+             return _stream_random_arrow(filepath, num, seed)
 
      # Fall back to a full load in all other cases
      data = load_data(str(filepath))
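All of the new tail and random branches funnel into the same Polars idiom: build a LazyFrame with a scan_* call, attach head/tail to the query plan, and only then collect, letting Polars limit how much of the file is actually read. A minimal sketch of that pattern, with an illustrative path:

    import polars as pl

    lf = pl.scan_parquet("data.parquet")   # lazy: nothing is read yet

    first_10 = lf.head(10).collect().to_dicts()   # the limit is part of the plan
    last_10 = lf.tail(10).collect().to_dicts()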
@@ -431,262 +411,288 @@
 
 
  def _stream_head_jsonl(filepath: Path, num: int) -> List[Dict[str, Any]]:
-     """Stream-read the first N lines of a JSONL file."""
-     result = []
-     with open(filepath, "r", encoding="utf-8") as f:
-         for line in f:
-             line = line.strip()
-             if line:
-                 result.append(json.loads(line))
-                 if len(result) >= num:
-                     break
-     return result
+     """Stream-read the first N lines of a JSONL file (via Polars ndjson)."""
+     try:
+         df = pl.scan_ndjson(filepath).head(num).collect()
+         return _clean_null_fields(df.to_dicts())
+     except Exception as e:
+         # Fall back to the Python implementation
+         import sys
+         print(f"[Warning] Polars ndjson parsing failed, falling back to Python: {type(e).__name__}", file=sys.stderr)
+
+         result = []
+         with open(filepath, "rb") as f:
+             for line in f:
+                 line = line.strip()
+                 if line:
+                     try:
+                         result.append(orjson.loads(line))
+                     except orjson.JSONDecodeError:
+                         continue  # skip invalid lines
+                     if len(result) >= num:
+                         break
+         return result
 
 
  def _stream_head_csv(filepath: Path, num: int) -> List[Dict[str, Any]]:
-     """Stream-read the first N rows of a CSV file."""
-     try:
-         import pandas as pd
-     except ImportError:
-         raise ImportError("pandas is required for CSV support. Install with: pip install pandas")
-
-     df = pd.read_csv(filepath, encoding="utf-8", nrows=num)
-     return df.to_dict("records")
+     """Stream-read the first N rows of a CSV file (via Polars LazyFrame)."""
+     df = pl.scan_csv(filepath).head(num).collect()
+     return _deserialize_complex_fields(df.to_dicts())
 
 
  def _stream_head_parquet(filepath: Path, num: int) -> List[Dict[str, Any]]:
-     """Truly stream the first N rows of a Parquet file (iter_batches avoids a full load)."""
-     try:
-         import pyarrow.parquet as pq
-     except ImportError:
-         raise ImportError("pyarrow is required for Parquet support. Install with: pip install pyarrow")
+     """Stream-read the first N rows of a Parquet file (via Polars LazyFrame)."""
+     df = pl.scan_parquet(filepath).head(num).collect()
+     return _deserialize_complex_fields(df.to_dicts())
 
-     parquet_file = pq.ParquetFile(filepath)
-     result = []
 
-     # Use iter_batches for true streaming; read only the data that is needed
-     for batch in parquet_file.iter_batches(batch_size=min(num, 10000)):
-         batch_data = batch.to_pylist()
-         result.extend(batch_data)
-         if len(result) >= num:
-             break
+ def _stream_head_arrow(filepath: Path, num: int) -> List[Dict[str, Any]]:
+     """Stream-read the first N rows of an Arrow file (via Polars LazyFrame)."""
+     df = pl.scan_ipc(filepath).head(num).collect()
+     return _deserialize_complex_fields(df.to_dicts())
 
-     return result[:num]
 
+ def _stream_head_excel(filepath: Path, num: int) -> List[Dict[str, Any]]:
+     """Read the first N rows of an Excel file."""
+     # Excel does not support lazy scans; use a plain read
+     df = pl.read_excel(filepath).head(num)
+     return _deserialize_complex_fields(df.to_dicts())
 
- def _stream_head_arrow(filepath: Path, num: int) -> List[Dict[str, Any]]:
-     """Stream-read the first N rows of an Arrow/Feather file."""
+
+ def _stream_tail_jsonl(filepath: Path, num: int) -> List[Dict[str, Any]]:
+     """Stream-read the last N lines of a JSONL file (via Polars ndjson)."""
      try:
-         import pyarrow.feather as feather
-     except ImportError:
-         raise ImportError("pyarrow is required for Arrow support. Install with: pip install pyarrow")
+         df = pl.scan_ndjson(filepath).tail(num).collect()
+         return _clean_null_fields(df.to_dicts())
+     except Exception as e:
+         # Fall back to a two-pass Python implementation
+         import sys
+         print(f"[Warning] Polars ndjson parsing failed, falling back to Python: {type(e).__name__}", file=sys.stderr)
 
-     table = feather.read_table(filepath)
-     sliced = table.slice(0, min(num, table.num_rows))
-     return _deserialize_arrow_data(sliced.to_pylist())
+         total_lines = 0
+         with open(filepath, "rb") as f:
+             for _ in f:
+                 total_lines += 1
 
+         if total_lines <= num:
+             return _load_jsonl(filepath)
 
- def _deserialize_arrow_data(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
-     """Deserialize JSON-string fields in Arrow data."""
-     result = []
-     for item in data:
-         new_item = {}
-         for k, v in item.items():
-             if isinstance(v, str) and v.startswith(("[", "{")):
-                 try:
-                     new_item[k] = json.loads(v)
-                 except json.JSONDecodeError:
-                     new_item[k] = v
-             else:
-                 new_item[k] = v
-         result.append(new_item)
-     return result
+         skip_count = total_lines - num
+         result = []
+         with open(filepath, "rb") as f:
+             for i, line in enumerate(f):
+                 if i < skip_count:
+                     continue
+                 line = line.strip()
+                 if line:
+                     try:
+                         result.append(orjson.loads(line))
+                     except orjson.JSONDecodeError:
+                         continue  # skip invalid lines
+         return result
 
 
- def _stream_head_excel(filepath: Path, num: int) -> List[Dict[str, Any]]:
-     """Stream-read the first N rows of an Excel file."""
-     try:
-         import pandas as pd
-     except ImportError:
-         raise ImportError("pandas and openpyxl are required for Excel support")
+ def _stream_tail_csv(filepath: Path, num: int) -> List[Dict[str, Any]]:
+     """Stream-read the last N rows of a CSV file (via Polars LazyFrame)."""
+     df = pl.scan_csv(filepath).tail(num).collect()
+     return _deserialize_complex_fields(df.to_dicts())
 
-     df = pd.read_excel(filepath, nrows=num)
-     return df.to_dict("records")
 
+ def _stream_tail_parquet(filepath: Path, num: int) -> List[Dict[str, Any]]:
+     """Stream-read the last N rows of a Parquet file (via Polars LazyFrame)."""
+     df = pl.scan_parquet(filepath).tail(num).collect()
+     return _deserialize_complex_fields(df.to_dicts())
 
- def append_to_file(data: List[Dict[str, Any]],
-                    filepath: str,
-                    file_format: str = 'jsonl') -> None:
-     """
-     Append data to an existing file.
 
-     Args:
-         data: List of data items to append
-         filepath: Path to file
-         file_format: File format (only 'jsonl' supported for append)
-     """
-     filepath = Path(filepath)
+ def _stream_tail_arrow(filepath: Path, num: int) -> List[Dict[str, Any]]:
+     """Stream-read the last N rows of an Arrow file (via Polars LazyFrame)."""
+     df = pl.scan_ipc(filepath).tail(num).collect()
+     return _deserialize_complex_fields(df.to_dicts())
 
-     if file_format != 'jsonl':
-         raise ValueError("Only JSONL format supports appending")
 
-     filepath.parent.mkdir(parents=True, exist_ok=True)
-
-     with open(filepath, 'a', encoding='utf-8') as f:
-         for item in data:
-             json_line = json.dumps(item, ensure_ascii=False)
-             f.write(json_line + '\n')
+ # File-size threshold: above this, use Python streaming sampling; otherwise use Polars
+ _STREAM_THRESHOLD_BYTES = 100 * 1024 * 1024  # 100MB
 
 
- def count_lines(filepath: str) -> int:
-     """
-     Count number of lines in a JSONL file without loading all data.
-
-     Args:
-         filepath: Path to JSONL file
+ def _count_sample_jsonl(
+     filepath: Path, num: int, seed: Optional[int] = None
+ ) -> List[Dict[str, Any]]:
+     """JSONL streaming sample (Polars line count + selective Python parsing)
 
-     Returns:
-         Number of lines
+     Strategy:
+     1. Use Polars to get the line count quickly (about 4x faster than Python)
+     2. Generate random indices
+     3. Walk the file in Python, parsing only the selected lines
      """
-     count = 0
-     with open(filepath, 'r', encoding='utf-8') as f:
-         for _ in f:
-             count += 1
-     return count
+     import random
 
+     # Step 1: get the line count quickly with Polars
+     try:
+         total_lines = pl.scan_ndjson(filepath).select(pl.len()).collect().item()
+     except Exception:
+         # Fall back to counting in Python
+         with open(filepath, "rb") as f:
+             total_lines = sum(1 for _ in f)
 
- def stream_jsonl(filepath: str, chunk_size: int = 1000):
-     """
-     Stream JSONL file in chunks.
+     if total_lines == 0:
+         return []
 
-     Args:
-         filepath: Path to JSONL file
-         chunk_size: Number of items per chunk
+     # Sample size reaches the total line count: read everything
+     if num >= total_lines:
+         return _load_jsonl(filepath)
 
-     Yields:
-         Chunks of data items
-     """
-     chunk = []
-     with open(filepath, 'r', encoding='utf-8') as f:
-         for line in f:
-             line = line.strip()
-             if line:
-                 chunk.append(json.loads(line))
-                 if len(chunk) >= chunk_size:
-                     yield chunk
-                     chunk = []
+     # Step 2: generate random indices
+     if seed is not None:
+         random.seed(seed)
+     selected_indices = set(random.sample(range(total_lines), num))
 
-     if chunk:
-         yield chunk
+     # Step 3: parse only the selected lines
+     result = []
+     with open(filepath, "rb") as f:
+         for i, line in enumerate(f):
+             if i in selected_indices:
+                 line = line.strip()
+                 if line:
+                     try:
+                         result.append(orjson.loads(line))
+                     except orjson.JSONDecodeError:
+                         continue
+                 if len(result) >= num:
+                     break
 
+     return result
 
- # ============ JSONL streaming sampling optimizations ============
 
+ def _stream_random_jsonl(
+     filepath: Path, num: int, seed: Optional[int] = None
+ ) -> List[Dict[str, Any]]:
+     """Random sampling of JSONL
 
- def _stream_tail_jsonl(filepath: Path, num: int) -> List[Dict[str, Any]]:
+     Strategy:
+     - Small files (<100MB): use Polars collect + sample
+     - Large files (>=100MB): use count + sample streaming (faster and memory-friendly)
      """
-     Read the last N lines of a JSONL file in reverse (avoids a full load).
+     file_size = filepath.stat().st_size
 
-     Uses a double-ended queue to keep the last N lines; memory usage is O(num) rather than O(total).
-     """
-     from collections import deque
+     # Large files use streaming sampling (faster)
+     if file_size >= _STREAM_THRESHOLD_BYTES:
+         return _count_sample_jsonl(filepath, num, seed)
 
-     # Use deque's maxlen to automatically keep the last N elements
-     buffer = deque(maxlen=num)
+     # Small files: try Polars first
+     try:
+         df = pl.scan_ndjson(filepath).collect()
+         if len(df) <= num:
+             return _clean_null_fields(df.to_dicts())
+         sampled = df.sample(n=num, seed=seed)
+         return _clean_null_fields(sampled.to_dicts())
+     except Exception as e:
+         import sys
+         print(f"[Warning] Polars ndjson parsing failed, falling back to streaming sampling: {type(e).__name__}", file=sys.stderr)
+         return _count_sample_jsonl(filepath, num, seed)
+
+
+ def _stream_random_csv(
+     filepath: Path, num: int, seed: Optional[int] = None
+ ) -> List[Dict[str, Any]]:
+     """Random sampling of CSV (via Polars)."""
+     df = pl.scan_csv(filepath).collect()
+     if len(df) <= num:
+         return _deserialize_complex_fields(df.to_dicts())
+     sampled = df.sample(n=num, seed=seed)
+     return _deserialize_complex_fields(sampled.to_dicts())
 
-     with open(filepath, "r", encoding="utf-8") as f:
-         for line in f:
-             line = line.strip()
-             if line:
-                 buffer.append(json.loads(line))
 
-     return list(buffer)
+ def _stream_random_parquet(
+     filepath: Path, num: int, seed: Optional[int] = None
+ ) -> List[Dict[str, Any]]:
+     """Random sampling of Parquet (via Polars)."""
+     df = pl.scan_parquet(filepath).collect()
+     if len(df) <= num:
+         return _deserialize_complex_fields(df.to_dicts())
+     sampled = df.sample(n=num, seed=seed)
+     return _deserialize_complex_fields(sampled.to_dicts())
 
 
- def _stream_random_jsonl(
+ def _stream_random_arrow(
      filepath: Path, num: int, seed: Optional[int] = None
  ) -> List[Dict[str, Any]]:
-     """
-     JSONL reservoir sampling.
+     """Random sampling of Arrow (via Polars)."""
+     df = pl.scan_ipc(filepath).collect()
+     if len(df) <= num:
+         return _deserialize_complex_fields(df.to_dicts())
+     sampled = df.sample(n=num, seed=seed)
+     return _deserialize_complex_fields(sampled.to_dicts())
 
-     A single pass over the file with O(num) memory; suited to random sampling of very large files.
-     The algorithm gives every record an equal probability of being selected.
-     """
-     import random
 
-     if seed is not None:
-         random.seed(seed)
+ # ============ Additional Utilities ============
+
+
+ def append_to_file(
+     data: List[Dict[str, Any]], filepath: str, file_format: str = "jsonl"
+ ) -> None:
+     """Append data to an existing file (only JSONL supported)."""
+     filepath = Path(filepath)
+
+     if file_format != "jsonl":
+         raise ValueError("Only JSONL format supports appending")
+
+     filepath.parent.mkdir(parents=True, exist_ok=True)
+
+     with open(filepath, "ab") as f:
+         for item in data:
+             f.write(orjson.dumps(item) + b"\n")
 
-     reservoir = []  # the reservoir
 
+ def count_lines(filepath: str) -> int:
+     """Count number of lines in a JSONL file."""
+     count = 0
      with open(filepath, "r", encoding="utf-8") as f:
-         for i, line in enumerate(f):
-             line = line.strip()
-             if not line:
-                 continue
+         for _ in f:
+             count += 1
+     return count
 
-             item = json.loads(line)
 
-             if len(reservoir) < num:
-                 # Reservoir not yet full: append directly
-                 reservoir.append(item)
-             else:
-                 # Reservoir full: replace with probability num/(i+1)
-                 j = random.randint(0, i)
-                 if j < num:
-                     reservoir[j] = item
+ def stream_jsonl(filepath: str, chunk_size: int = 1000):
+     """Stream JSONL file in chunks."""
+     chunk = []
+     with open(filepath, "rb") as f:
+         for line in f:
+             line = line.strip()
+             if line:
+                 chunk.append(orjson.loads(line))
+                 if len(chunk) >= chunk_size:
+                     yield chunk
+                     chunk = []
 
-     return reservoir
+     if chunk:
+         yield chunk
 
 
  # ============ FlaxKV Format ============
 
 
- def _save_flaxkv(data: List[Dict[str, Any]], filepath: Path) -> None:
-     """
-     Save data in FlaxKV format.
 
-     Args:
-         data: List of data items to save
-         filepath: Path to FlaxKV database (directory)
-     """
+ def _save_flaxkv(data: List[Dict[str, Any]], filepath: Path) -> None:
+     """Save data in FlaxKV format."""
      from flaxkv2 import FlaxKV
 
-     # Use the directory name as the database name
      db_name = filepath.stem if filepath.stem else "data"
      db_path = filepath.parent
 
-     # Create FlaxKV database
      with FlaxKV(db_name, str(db_path)) as db:
-         # Store metadata
-         db["_metadata"] = {
-             "total": len(data),
-             "format": "flaxkv"
-         }
+         db["_metadata"] = {"total": len(data), "format": "flaxkv"}
 
-         # Store each item with index as key
          for i, item in enumerate(data):
              db[f"item:{i}"] = item
 
 
  def _load_flaxkv(filepath: Path) -> List[Dict[str, Any]]:
-     """
-     Load data from FlaxKV format.
-
-     Args:
-         filepath: Path to FlaxKV database (directory)
-
-     Returns:
-         List of data items
-     """
+     """Load data from FlaxKV format."""
      from flaxkv2 import FlaxKV
 
-     # Use the directory name as the database name
      db_name = filepath.stem if filepath.stem else "data"
      db_path = filepath.parent
 
-     # Open FlaxKV database
      with FlaxKV(db_name, str(db_path)) as db:
-         # Collect all items
          items = []
          for key in sorted(db.keys()):
              if key.startswith("item:"):
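The headline change in this final hunk is the count+sample strategy in _count_sample_jsonl: count lines cheaply, draw the target indices up front, then parse only the selected lines in a single pass. A self-contained sketch of the idea, without the Polars fast path:

    import random
    from pathlib import Path
    from typing import Optional

    import orjson

    def count_sample(path: Path, num: int, seed: Optional[int] = None) -> list:
        """Sample num JSONL records by index, parsing only the chosen lines."""
        with open(path, "rb") as f:
            total = sum(1 for _ in f)      # pass 1: count lines, no JSON parsing
        num = min(num, total)
        chosen = set(random.Random(seed).sample(range(total), num))
        result = []
        with open(path, "rb") as f:        # pass 2: decode the selected lines only
            for i, line in enumerate(f):
                if i in chosen and line.strip():
                    result.append(orjson.loads(line))
        return result

Memory stays proportional to the sample size (the result list plus the index set), and orjson decoding cost is paid only on the sampled lines.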