mplang-nightly 0.1.dev263__py3-none-any.whl → 0.1.dev264__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
@@ -20,12 +20,15 @@ Implements execution logic for Table primitives using DuckDB and PyArrow.
20
20
  from __future__ import annotations
21
21
 
22
22
  import base64
23
- from typing import Any, ClassVar
23
+ from abc import ABC, abstractmethod
24
+ from dataclasses import dataclass
25
+ from typing import Any, ClassVar, Protocol, Self, runtime_checkable
24
26
 
25
27
  import duckdb
26
28
  import pandas as pd
27
29
  import pyarrow as pa
28
30
 
31
+ import mplang.v2.edsl.typing as elt
29
32
  from mplang.v2.backends.tensor_impl import TensorValue
30
33
  from mplang.v2.dialects import table
31
34
  from mplang.v2.edsl import serde
@@ -33,13 +36,464 @@ from mplang.v2.edsl.graph import Operation
33
36
  from mplang.v2.runtime.interpreter import Interpreter
34
37
  from mplang.v2.runtime.value import WrapValue
35
38
 
39
+
40
+ class BatchReader(ABC):
41
+ @property
42
+ @abstractmethod
43
+ def schema(self) -> pa.Schema: ...
44
+
45
+ @abstractmethod
46
+ def read_next_batch(self) -> pa.RecordBatch: ...
47
+ @abstractmethod
48
+ def close(self) -> None: ...
49
+
50
+ def __enter__(self) -> Self:
51
+ return self
52
+
53
+ def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None:
54
+ self.close()
55
+
56
+ def __iter__(self) -> Self:
57
+ return self
58
+
59
+ def __next__(self) -> pa.RecordBatch:
60
+ return self.read_next_batch()
61
+
62
+
63
+ class TableReader(BatchReader):
64
+ """A reader for streaming table data from PyArrow RecordBatchReader or Table.
65
+
66
+ This class provides an efficient way to read large tables in batches,
67
+ with support for custom batch sizes and proper handling of data boundaries.
68
+ It implements the iterator protocol for easy consumption of data.
69
+ """
70
+
71
+ def __init__(
72
+ self,
73
+ data: pa.RecordBatchReader | pa.Table,
74
+ num_rows: int = -1,
75
+ batch_size: int = -1,
76
+ ) -> None:
77
+ """Initialize a TableReader.
78
+
79
+ Args:
80
+ data: Either a RecordBatchReader or Table to read from
81
+ num_rows: Expected number of rows in the data. -1 indicates unknown
82
+ batch_size: Size of each batch to read. -1 means use default/reader's batch size
83
+ """
84
+ # Store the underlying reader and row count based on input type
85
+ if isinstance(data, pa.RecordBatchReader):
86
+ self._reader = data
87
+ self._num_rows = num_rows
88
+ else:
89
+ # Convert Table to RecordBatchReader for consistent interface
90
+ self._reader = data.to_reader()
91
+ self._num_rows = data.num_rows
92
+
93
+ # Configuration for batch reading
94
+ self._batch_size = batch_size
95
+
96
+ # Internal state for handling custom batch sizes
97
+ self._remain: pa.RecordBatch | None = (
98
+ None # Stores partial batch from previous read
99
+ )
100
+ self._eof = False # Flag to indicate end of data
101
+
102
+ @property
103
+ def num_rows(self) -> int:
104
+ """Get the total number of rows in the table.
105
+
106
+ Returns:
107
+ Total number of rows, or -1 if unknown
108
+ """
109
+ return self._num_rows
110
+
111
+ @property
112
+ def schema(self) -> pa.Schema:
113
+ """Get the schema of the table.
114
+
115
+ Returns:
116
+ PyArrow Schema describing the table's columns and types
117
+ """
118
+ return self._reader.schema
119
+
120
+ def read_all(self) -> pa.Table:
121
+ """Read all remaining data as a Table.
122
+
123
+ This is a convenience method that reads all data from the reader
124
+ and returns it as a single PyArrow Table.
125
+
126
+ Returns:
127
+ Complete table containing all remaining data
128
+ """
129
+ return self._reader.read_all()
130
+
131
+ def read_next_batch(self) -> pa.RecordBatch:
132
+ """Read the next batch of records.
133
+
134
+ This method respects the configured batch size. If the native reader
135
+ returns batches larger than the configured size, this method will split
136
+ them appropriately. Any partial data from previous reads is included
137
+ in the returned batch.
138
+
139
+ Returns:
140
+ Next RecordBatch of data
141
+
142
+ Raises:
143
+ StopIteration: When no more data is available
144
+ """
145
+ # Check if we've reached end of file
146
+ if self._eof:
147
+ raise StopIteration
148
+
149
+ # Get the next batch using internal logic
150
+ batch = self._read_next_batch()
151
+
152
+ # Handle end of data
153
+ if batch is None:
154
+ self._eof = True
155
+ raise StopIteration
156
+
157
+ return batch
158
+
159
+ def _read_next_batch(self) -> pa.RecordBatch | None:
160
+ """Internal method to read and process the next batch.
161
+
162
+ This method handles the complex logic of:
163
+ - Using default batch size when none is specified
164
+ - Accumulating data from multiple native batches to reach the target size
165
+ - Splitting oversized batches and saving the remainder
166
+ - Converting between Table and RecordBatch formats as needed
167
+
168
+ Returns:
169
+ Next RecordBatch of the configured size, or None if no more data
170
+ """
171
+ # If no batch size specified, just return the reader's native batches
172
+ if self._batch_size <= 0:
173
+ try:
174
+ batch = self._reader.read_next_batch()
175
+ # Convert to RecordBatch if the reader returns a Table
176
+ if isinstance(batch, pa.Table) and batch.num_rows > 0:
177
+ return batch.to_batches()[0]
178
+ return batch
179
+ except StopIteration:
180
+ return None
181
+
182
+ # We have a custom batch size - need to accumulate/split batches
183
+ batches: list[pa.RecordBatch] = []
184
+ num_rows: int = 0
185
+
186
+ # First, include any remaining data from the previous read
187
+ if self._remain is not None:
188
+ num_rows = self._remain.num_rows
189
+ batches = [self._remain]
190
+ self._remain = None
191
+
192
+ # Keep reading until we have enough rows or run out of data
193
+ while num_rows < self._batch_size:
194
+ try:
195
+ batch = self._reader.read_next_batch()
196
+
197
+ # Handle the case where reader returns a Table instead of RecordBatch
198
+ if isinstance(batch, pa.Table):
199
+ if batch.num_rows > 0:
200
+ # Convert each batch from the Table
201
+ for rb in batch.to_batches():
202
+ num_rows += rb.num_rows
203
+ if rb.num_rows > 0: # Skip empty batches
204
+ batches.append(rb)
205
+ else:
206
+ # Already a RecordBatch
207
+ num_rows += batch.num_rows
208
+ if batch.num_rows > 0: # Skip empty batches
209
+ batches.append(batch)
210
+ except StopIteration:
211
+ # Mark EOF but continue processing what we have
212
+ self._eof = True
213
+ break
214
+
215
+ # If we didn't get any data, return None
216
+ if num_rows == 0:
217
+ return None
218
+
219
+ # Split the last batch if we have more rows than needed
220
+ if num_rows > self._batch_size:
221
+ last = batches[-1]
222
+ remain_size = num_rows - self._batch_size
223
+ last_size = last.num_rows - remain_size
224
+
225
+ # Keep only what we need from the last batch
226
+ batches[-1] = last.slice(0, last_size)
227
+ # Save the remainder for the next read
228
+ self._remain = last.slice(last_size, remain_size)
229
+
230
+ # Optimized path: if we only have one batch, return it directly
231
+ if len(batches) == 1:
232
+ return batches[0]
233
+
234
+ # Otherwise, combine all batches and return as a single RecordBatch
235
+ combined = pa.Table.from_batches(batches)
236
+ return combined.to_batches()[0]
237
+
238
+ def close(self) -> None:
239
+ """Close the reader and release all resources.
240
+
241
+ This method should be called when the reader is no longer needed.
242
+ It closes the underlying reader and clears internal state.
243
+ """
244
+ # Close the underlying reader
245
+ self._reader.close()
246
+ # Clear internal state
247
+ self._remain = None
248
+ self._eof = False
249
+
250
+
251
+ DEFAULT_BATCH_SIZE = 1_000_000
252
+
253
+
254
+ class TableSource(ABC):
255
+ """Abstract base class for lazy table operations.
256
+
257
+ Provides deferred execution for table operations to prevent OOM issues.
258
+ """
259
+
260
+ @abstractmethod
261
+ def register(
262
+ self, conn: duckdb.DuckDBPyConnection, name: str, replace: bool = True
263
+ ) -> None: ...
264
+
265
+ @abstractmethod
266
+ def open(self, batch_size: int = DEFAULT_BATCH_SIZE) -> TableReader:
267
+ """Read data as a stream of record batches."""
268
+ ...
269
+
270
+
271
+ class ParquetReader(pa.RecordBatchReader):
272
+ """A reader that implements the pa.RecordBatchReader interface for Parquet files."""
273
+
274
+ def __init__(self, source: Any, columns: list[str] | None = None):
275
+ import pyarrow.parquet as pq
276
+
277
+ file = pq.ParquetFile(source)
278
+
279
+ # Use schema_arrow to get the proper pa.Schema
280
+ if columns:
281
+ # Filter the schema to only include selected columns
282
+ fields = [
283
+ file.schema_arrow.field(col)
284
+ for col in columns
285
+ if col in file.schema_arrow.names
286
+ ]
287
+ schema = pa.schema(fields)
288
+ else:
289
+ schema = file.schema_arrow
290
+
291
+ self._file = file
292
+ self._schema = schema
293
+ self._cast = False
294
+ self._num_rows = int(file.metadata.num_rows)
295
+ self._iter = file.iter_batches(columns=columns)
296
+
297
+ @property
298
+ def num_rows(self) -> int:
299
+ return self._num_rows
300
+
301
+ @property
302
+ def schema(self) -> pa.Schema:
303
+ return self._schema
304
+
305
+ def cast(self, target_schema: pa.Schema) -> ParquetReader:
306
+ # Validate that the number of columns is the same
307
+ if len(target_schema) != len(self._schema):
308
+ raise ValueError(
309
+ f"Cannot cast schema: target schema has {len(target_schema)} columns, "
310
+ f"but current schema has {len(self._schema)} columns"
311
+ )
312
+
313
+ # Check if there are any changes in the schema
314
+ schema_changed = False
315
+ for i, (target_field, current_field) in enumerate(
316
+ zip(target_schema, self._schema, strict=True)
317
+ ):
318
+ # Check if field names are the same (allowing type changes)
319
+ if target_field.name != current_field.name:
320
+ raise ValueError(
321
+ f"Cannot cast schema: field name at position {i} differs. "
322
+ f"Current: '{current_field.name}', Target: '{target_field.name}'. "
323
+ f"Field names must match."
324
+ )
325
+ # Check if types are different
326
+ if target_field.type != current_field.type:
327
+ schema_changed = True
328
+
329
+ # Only set _cast if there are actual changes
330
+ if schema_changed:
331
+ self._schema = target_schema
332
+ self._cast = True
333
+
334
+ return self
335
+
336
+ def read_all(self) -> pa.Table:
337
+ batches = []
338
+ try:
339
+ while True:
340
+ batch = self.read_next_batch()
341
+ batches.append(batch)
342
+ except StopIteration:
343
+ pass
344
+ if batches:
345
+ return pa.Table.from_batches(batches)
346
+ return pa.Table.from_batches([])
347
+
348
+ def read_next_batch(self) -> pa.RecordBatch:
349
+ batch = next(self._iter)
350
+ if self._cast:
351
+ return batch.cast(self._schema)
352
+ else:
353
+ return batch
354
+
355
+ def close(self) -> None:
356
+ """Close the Parquet reader and release resources."""
357
+ self._file.close()
358
+
359
+
360
+ _type_mapping = {
361
+ elt.bool_: pa.bool_(),
362
+ elt.i8: pa.int8(),
363
+ elt.i16: pa.int16(),
364
+ elt.i32: pa.int32(),
365
+ elt.i64: pa.int64(),
366
+ elt.u8: pa.uint8(),
367
+ elt.u16: pa.uint16(),
368
+ elt.u32: pa.uint32(),
369
+ elt.u64: pa.uint64(),
370
+ elt.f16: pa.float16(),
371
+ elt.f32: pa.float32(),
372
+ elt.f64: pa.float64(),
373
+ elt.STRING: pa.string(),
374
+ elt.DATE: pa.date64(),
375
+ elt.TIME: pa.time32("ms"),
376
+ elt.TIMESTAMP: pa.timestamp("ms"),
377
+ elt.DECIMAL: pa.decimal128(38, 10),
378
+ elt.BINARY: pa.binary(),
379
+ elt.JSON: pa.json_(),
380
+ }
381
+
382
+
383
+ def _pa_schema(s: elt.TableType) -> pa.Schema:
384
+ fields = []
385
+ for k, v in s.schema.items():
386
+ if v not in _type_mapping:
387
+ raise ValueError(f"cannot convert to pyarrow type. type={v}, name={k}")
388
+ fields.append(pa.field(k, _type_mapping[v]))
389
+
390
+ return pa.schema(fields)
391
+
392
+
393
+ @dataclass
394
+ class FileTableSource(TableSource):
395
+ """Lazy table handle for file-based operations with streaming reads."""
396
+
397
+ path: str
398
+ format: str
399
+ schema: pa.Schema | None = None
400
+
401
+ def register(
402
+ self, conn: duckdb.DuckDBPyConnection, name: str, replace: bool = True
403
+ ) -> None:
404
+ """Register the file as a view in DuckDB."""
405
+ func_name = ""
406
+ match self.format:
407
+ case "parquet":
408
+ func_name = "read_parquet"
409
+ case "csv":
410
+ func_name = "read_csv_auto"
411
+ case "json":
412
+ func_name = "read_json_auto"
413
+ case _:
414
+ raise ValueError(f"Unsupported format: {self.format}")
415
+
416
+ safe_path = self.path.replace("'", "''")
417
+ base_query = f"SELECT * FROM {func_name}('{safe_path}')"
418
+ if replace:
419
+ query = f"CREATE OR REPLACE VIEW {name} AS {base_query}"
420
+ else:
421
+ query = f"CREATE VIEW {name} AS {base_query}"
422
+ conn.execute(query)
423
+
424
+ def open(self, batch_size: int = DEFAULT_BATCH_SIZE) -> TableReader:
425
+ """Create a streaming reader for the file."""
426
+ import pyarrow.csv as pa_csv
427
+ import pyarrow.json as pa_json
428
+
429
+ columns = self.schema.names if self.schema else None
430
+
431
+ reader = None
432
+ num_rows = -1
433
+ match self.format.lower():
434
+ case "parquet":
435
+ reader = ParquetReader(self.path, columns)
436
+ num_rows = reader.num_rows
437
+ case "csv":
438
+ read_options = pa_csv.ReadOptions(use_threads=True)
439
+ convert_options = pa_csv.ConvertOptions(
440
+ column_types=self.schema,
441
+ include_columns=columns,
442
+ )
443
+ reader = pa_csv.open_csv(
444
+ self.path,
445
+ read_options=read_options,
446
+ convert_options=convert_options,
447
+ )
448
+ case "json":
449
+ read_options = pa_json.ReadOptions(use_threads=True)
450
+ table = pa_json.read_json(self.path, read_options=read_options)
451
+ if columns:
452
+ table = table.select(columns)
453
+ reader = table.to_reader()
454
+ num_rows = table.num_rows
455
+ case _:
456
+ raise ValueError(f"Unsupported format: {self.format}")
457
+
458
+ if self.schema and self.schema != reader.schema:
459
+ reader = reader.cast(self.schema)
460
+
461
+ return TableReader(reader, num_rows=num_rows, batch_size=batch_size)
462
+
463
+
464
+ class DuckDBState:
465
+ def __init__(self, conn: duckdb.DuckDBPyConnection) -> None:
466
+ self.conn = conn
467
+ self.tables: dict[str, Any] = {}
468
+
469
+
470
+ @dataclass(frozen=True)
471
+ class QueryTableSource(TableSource):
472
+ """Handle for existing DuckDB relations (kept for compatibility)."""
473
+
474
+ relation: duckdb.DuckDBPyRelation
475
+ state: DuckDBState
476
+
477
+ def register(
478
+ self, conn: duckdb.DuckDBPyConnection, name: str, replace: bool = True
479
+ ) -> None:
480
+ self.relation.create_view(name, replace)
481
+
482
+ def open(self, batch_size: int = DEFAULT_BATCH_SIZE) -> TableReader:
483
+ """Read from the DuckDB relation."""
484
+ if batch_size <= 0:
485
+ batch_size = DEFAULT_BATCH_SIZE
486
+ reader = self.relation.arrow(batch_size)
487
+ return TableReader(reader)
488
+
489
+
36
490
  # =============================================================================
37
491
  # TableValue Wrapper
38
492
  # =============================================================================
39
493
 
40
494
 
41
495
  @serde.register_class
42
- class TableValue(WrapValue[pa.Table]):
496
+ class TableValue(WrapValue[pa.Table | TableSource]):
43
497
  """Runtime value wrapping a PyArrow Table.
44
498
 
45
499
  Provides serialization via Arrow IPC format (streaming).
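
The TableReader introduced in this hunk re-chunks whatever the underlying reader yields into batches of the requested size, carrying leftover rows over to the next read and raising StopIteration once the source is exhausted. A minimal sketch of that behavior (the sample data is made up; the import path follows the module location listed in the RECORD below):

import pyarrow as pa

from mplang.v2.backends.table_impl import TableReader

# One in-memory table with five rows in a single chunk.
tbl = pa.table({"x": [1, 2, 3, 4, 5], "y": ["a", "b", "c", "d", "e"]})

# Ask for two-row batches: the single native batch is split, and the
# remainder is carried over to the following reads.
with TableReader(tbl, batch_size=2) as reader:
    sizes = [batch.num_rows for batch in reader]

print(sizes)  # expected: [2, 2, 1]
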
@@ -48,27 +502,43 @@ class TableValue(WrapValue[pa.Table]):
48
502
 
49
503
  _serde_kind: ClassVar[str] = "table_impl.TableValue"
50
504
 
505
+ @property
506
+ def data(self) -> pa.Table:
507
+ """Get the underlying PyArrow Table data.
508
+
509
+ For lazy TableSource, this triggers a full read of the data and caches
510
+ the result in self._data. Subsequent calls will return the cached table.
511
+
512
+ Returns:
513
+ The PyArrow Table containing all data
514
+ """
515
+ if isinstance(self._data, TableSource):
516
+ source = self._data
517
+ with source.open() as reader:
518
+ self._data = reader.read_all()
519
+
520
+ return self._data
521
+
51
522
  # =========== Wrap/Unwrap ===========
52
523
 
53
- def _convert(self, data: Any) -> pa.Table:
54
- """Convert input data to pa.Table."""
524
+ def _convert(self, data: Any) -> pa.Table | TableSource:
525
+ """Convert input data to pa.Table or TableSource."""
55
526
  if isinstance(data, TableValue):
56
527
  return data.unwrap()
57
528
  if isinstance(data, pd.DataFrame):
58
- return pa.Table.from_pandas(data)
59
- if not isinstance(data, pa.Table):
60
- raise TypeError(f"Expected pa.Table or pd.DataFrame, got {type(data)}")
529
+ data = pa.Table.from_pandas(data)
530
+ if not isinstance(data, pa.Table | TableSource):
531
+ raise TypeError(f"Expected pa.Table or TableSource, got {type(data)}")
61
532
  return data
62
533
 
63
- # unwrap() is inherited from WrapValue
64
-
65
534
  # =========== Serialization ===========
66
535
 
67
536
  def to_json(self) -> dict[str, Any]:
68
537
  # Serialize using Arrow IPC streaming format
538
+ data = self.data
69
539
  sink = pa.BufferOutputStream()
70
- with pa.ipc.new_stream(sink, self._data.schema) as writer:
71
- writer.write_table(self._data)
540
+ with pa.ipc.new_stream(sink, data.schema) as writer:
541
+ writer.write_table(data)
72
542
  ipc_bytes = sink.getvalue().to_pybytes()
73
543
  return {"ipc": base64.b64encode(ipc_bytes).decode("ascii")}
74
544
 
@@ -81,7 +551,7 @@ class TableValue(WrapValue[pa.Table]):
81
551
 
82
552
 
83
553
  # Module-level helpers for convenience (delegate to class methods)
84
- def _wrap(val: pa.Table | pd.DataFrame | TableValue) -> TableValue:
554
+ def _wrap(val: pa.Table | pd.DataFrame | TableSource | TableValue) -> TableValue:
85
555
  """Wrap a table-like value into TableValue."""
86
556
  return TableValue.wrap(val)
87
557
 
@@ -89,13 +559,13 @@ def _wrap(val: pa.Table | pd.DataFrame | TableValue) -> TableValue:
89
559
  def _unwrap(val: TableValue | pa.Table | pd.DataFrame) -> pa.Table:
90
560
  """Unwrap TableValue to pa.Table, also accepts raw pa.Table/DataFrame."""
91
561
  if isinstance(val, TableValue):
92
- return val.unwrap()
562
+ return val.data
93
563
  if isinstance(val, pd.DataFrame):
94
564
  return pa.Table.from_pandas(val)
95
565
  if isinstance(val, pa.Table):
96
566
  return val
97
567
  # Handle RecordBatchReader from newer PyArrow versions
98
- if isinstance(val, pa.lib.RecordBatchReader):
568
+ if isinstance(val, pa.RecordBatchReader):
99
569
  return val.read_all()
100
570
  raise TypeError(
101
571
  f"Expected TableValue, pa.Table, pd.DataFrame, or RecordBatchReader, got {type(val)}"
@@ -117,18 +587,40 @@ def run_sql_impl(interpreter: Interpreter, op: Operation, *args: Any) -> TableVa
117
587
  if dialect != "duckdb":
118
588
  raise ValueError(f"Unsupported dialect: {dialect}")
119
589
 
120
- # Use in-memory DuckDB connection
121
- conn = duckdb.connect(":memory:")
590
+ state: DuckDBState | None = None
591
+ tables: list[TableValue] = []
592
+ for arg in args:
593
+ tbl = _wrap(arg)
594
+ tables.append(tbl)
595
+ data = tbl.unwrap()
596
+ if isinstance(data, QueryTableSource):
597
+ if state is None:
598
+ state = data.state
599
+ elif state != data.state:
600
+ raise ValueError("All tables must belong to the same DuckDB connection")
601
+
602
+ if state is None:
603
+ conn = duckdb.connect()
604
+ state = DuckDBState(conn)
122
605
 
123
- for name, arg in zip(table_names, args, strict=True):
124
- conn.register(name, _unwrap(arg))
125
-
126
- # Execute query and fetch result as Arrow table
127
606
  try:
128
- arrow_result = conn.execute(query).arrow()
129
- # In newer DuckDB versions, .arrow() returns RecordBatchReader
130
- res = arrow_result.read_all()
131
- return _wrap(res)
607
+ conn = state.conn
608
+ # register tables or create view
609
+ for name, tbl in zip(table_names, tables, strict=True):
610
+ data = tbl.unwrap()
611
+ if name in state.tables:
612
+ if state.tables[name] is not data:
613
+ # TODO: rename and rewrite sql??
614
+ raise ValueError(f"{name} has been registered.")
615
+ else:
616
+ state.tables[name] = data
617
+ if isinstance(data, TableSource):
618
+ data.register(state.conn, name)
619
+ else:
620
+ conn.register(name, data)
621
+
622
+ relation = conn.sql(query)
623
+ return _wrap(QueryTableSource(relation, state))
132
624
  except Exception as e:
133
625
  raise RuntimeError(f"Failed to execute SQL query: {query}") from e
134
626
 
@@ -204,8 +696,6 @@ def _infer_format(path: str, format_hint: str) -> str:
204
696
  return "csv"
205
697
  elif path_lower.endswith((".json", ".jsonl")):
206
698
  return "json"
207
- elif path_lower.endswith((".feather", ".arrow")):
208
- return "feather"
209
699
  else:
210
700
  # Default to parquet
211
701
  return "parquet"
@@ -215,58 +705,134 @@ def _infer_format(path: str, format_hint: str) -> str:
215
705
  def read_impl(interpreter: Interpreter, op: Operation) -> TableValue:
216
706
  """Read table from file.
217
707
 
218
- Supported formats: parquet, csv, json, feather
708
+ Supported formats: parquet, csv, json
219
709
  """
220
- import pyarrow.csv as pv_csv
221
- import pyarrow.json as pv_json
222
- import pyarrow.parquet as pq
710
+ import os
223
711
 
224
712
  path: str = op.attrs["path"]
713
+ schema: elt.TableType = op.attrs["schema"]
225
714
  format_hint: str = op.attrs.get("format", "auto")
226
-
227
715
  fmt = _infer_format(path, format_hint)
228
-
229
- if fmt == "parquet":
230
- return _wrap(pq.read_table(path))
231
- elif fmt == "csv":
232
- return _wrap(pv_csv.read_csv(path))
233
- elif fmt == "json":
234
- return _wrap(pv_json.read_json(path))
235
- elif fmt == "feather":
236
- import pyarrow.feather as feather
237
-
238
- return _wrap(feather.read_table(path))
239
- else:
240
- raise ValueError(f"Unsupported format: {fmt}")
716
+ if not os.path.exists(path):
717
+ raise FileNotFoundError(f"{path} not exists")
718
+
719
+ pa_schema = _pa_schema(schema) if schema else None
720
+ return _wrap(FileTableSource(path=path, format=fmt, schema=pa_schema))
721
+
722
+
723
+ class MultiTableReader(BatchReader):
724
+ def __init__(self, readers: list[TableReader]) -> None:
725
+ fields = {}
726
+ for r in readers:
727
+ for f in r.schema:
728
+ if f.name in fields:
729
+ raise ValueError(f"Field name conflict. {f.name}")
730
+ fields[f.name] = f
731
+
732
+ self._readers = readers
733
+ self._schema = pa.schema(list(fields.values()))
734
+
735
+ @property
736
+ def schema(self) -> pa.Schema:
737
+ return self._schema
738
+
739
+ def read_next_batch(self) -> pa.RecordBatch:
740
+ num_rows = -1
741
+ columns: list[pa.ChunkedArray] = []
742
+ for idx, r in enumerate(self._readers):
743
+ batch = r.read_next_batch()
744
+ if num_rows == -1:
745
+ num_rows = batch.num_rows
746
+ elif num_rows != batch.num_rows:
747
+ raise ValueError(
748
+ f"Batch {idx} has {batch.num_rows} rows, expected {num_rows}"
749
+ )
750
+ columns.extend(batch.columns)
751
+ return pa.RecordBatch.from_arrays(columns, names=self._schema.names)
752
+
753
+ def close(self) -> None:
754
+ for r in self._readers:
755
+ r.close()
241
756
 
242
757
 
243
758
  @table.write_p.def_impl
244
- def write_impl(interpreter: Interpreter, op: Operation, table_val: Any) -> None:
759
+ def write_impl(interpreter: Interpreter, op: Operation, *tables: TableValue) -> None:
245
760
  """Write table to file.
246
761
 
247
- Supported formats: parquet, csv, json, feather
762
+ Supported formats: parquet, csv, json
763
+
764
+ For LazyTable, performs streaming writes when supported.
765
+ For regular Tables, performs direct writes.
248
766
  """
249
- import pyarrow.csv as pv_csv
250
- import pyarrow.parquet as pq
767
+ import os
251
768
 
252
769
  path: str = op.attrs["path"]
253
770
  format_hint: str = op.attrs.get("format", "parquet")
254
771
 
255
772
  fmt = _infer_format(path, format_hint)
256
773
 
257
- tbl = _unwrap(table_val)
774
+ batch_size = DEFAULT_BATCH_SIZE if len(tables) > 1 else -1
775
+ readers: list[TableReader] = []
776
+ for t in tables:
777
+ data = t.unwrap()
778
+ readers.append(
779
+ data.open(batch_size)
780
+ if isinstance(data, TableSource)
781
+ else TableReader(data)
782
+ )
258
783
 
259
- if fmt == "parquet":
260
- pq.write_table(tbl, path)
261
- elif fmt == "csv":
262
- pv_csv.write_csv(tbl, path)
263
- elif fmt == "json":
264
- # PyArrow doesn't have direct JSON write, convert to pandas
265
- df = tbl.to_pandas()
266
- df.to_json(path, orient="records", lines=True)
267
- elif fmt == "feather":
268
- import pyarrow.feather as feather
269
-
270
- feather.write_feather(tbl, path)
271
- else:
272
- raise ValueError(f"Unsupported format: {fmt}")
784
+ reader: BatchReader = readers[0] if len(readers) == 1 else MultiTableReader(readers)
785
+
786
+ import pyarrow.csv as pa_csv
787
+ import pyarrow.parquet as pa_pq
788
+
789
+ @runtime_checkable
790
+ class BatchWriter(Protocol):
791
+ def write_batch(self, batch: pa.RecordBatch) -> None: ...
792
+ def close(self) -> None: ...
793
+
794
+ class JsonWriter(BatchWriter):
795
+ def __init__(self, path: str) -> None:
796
+ self._path = path
797
+ self._batches: list[pa.RecordBatch] = []
798
+
799
+ def write_batch(self, batch: pa.RecordBatch) -> None:
800
+ self._batches.append(batch)
801
+
802
+ def close(self) -> None:
803
+ # PyArrow doesn't have direct JSON write, convert to pandas
804
+ tbl = pa.Table.from_batches(self._batches)
805
+ df = tbl.to_pandas()
806
+ df.to_json(self._path, orient="records", lines=True)
807
+
808
+ def _safe_remove_file(path: str) -> None:
809
+ if os.path.exists(path):
810
+ try:
811
+ os.remove(path)
812
+ except Exception:
813
+ pass # Ignore cleanup errors
814
+
815
+ try:
816
+ match fmt:
817
+ case "parquet":
818
+ writer = pa_pq.ParquetWriter(path, reader.schema)
819
+ case "csv":
820
+ writer = pa_csv.CSVWriter(path, reader.schema)
821
+ case "json":
822
+ writer = JsonWriter(path)
823
+ case _:
824
+ raise ValueError(f"Unsupported format: {fmt}")
825
+ except Exception as e:
826
+ reader.close()
827
+ _safe_remove_file(path)
828
+ raise e
829
+
830
+ try:
831
+ for batch in reader:
832
+ writer.write_batch(batch)
833
+ except Exception as e:
834
+ _safe_remove_file(path)
835
+ raise e
836
+ finally:
837
+ reader.close()
838
+ writer.close()
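
write_impl above streams record batches from a BatchReader into a format-specific writer rather than materializing the whole table, and removes the partially written file if anything fails. The core pyarrow pattern it builds on, shown standalone (the file paths are placeholders):

import pyarrow as pa
import pyarrow.parquet as pq

# Placeholder paths, for illustration only.
src_path, dst_path = "/tmp/input.parquet", "/tmp/output.parquet"
pq.write_table(pa.table({"x": list(range(10))}), src_path)

src = pq.ParquetFile(src_path)
writer = pq.ParquetWriter(dst_path, src.schema_arrow)
try:
    # Copy the file batch by batch; peak memory stays bounded by the batch size.
    for batch in src.iter_batches(batch_size=4):
        writer.write_batch(batch)
finally:
    writer.close()
    src.close()
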
@@ -16,7 +16,7 @@
16
16
 
17
17
  from __future__ import annotations
18
18
 
19
- from typing import Any
19
+ from typing import Any, cast
20
20
 
21
21
  import mplang.v2.edsl as el
22
22
  import mplang.v2.edsl.typing as elt
@@ -305,32 +305,54 @@ def read(
305
305
 
306
306
 
307
307
  @write_p.def_abstract_eval
308
- def _write_ae(table_type: elt.TableType, *, path: str, format: str) -> elt.TableType:
308
+ def _write_ae(in_types: list[elt.BaseType], *, path: str, format: str) -> elt.TableType:
309
309
  """Infer output type for table.write.
310
310
 
311
311
  Args:
312
- table_type: Input table's type
312
+ in_types: Input table's type
313
313
  path: File path to write to
314
314
  format: Output format ("csv", "parquet")
315
315
 
316
316
  Returns:
317
- The input table type (passthrough)
317
+ The input table type
318
318
 
319
319
  Raises:
320
320
  TypeError: If input is not a TableType
321
321
  ValueError: If path is empty or format is invalid
322
322
  """
323
+
324
+ if not in_types:
325
+ raise ValueError(
326
+ f"write requires at least one input table, got {len(in_types)}"
327
+ )
328
+
329
+ # Verify all inputs are TableType
330
+ for i, t in enumerate(in_types):
331
+ if not isinstance(t, elt.TableType):
332
+ raise TypeError(f"Input {i} is not TableType: {type(t)}")
333
+
334
+ table_types = cast(list[elt.TableType], in_types)
335
+ columns = {}
336
+ for table_type in table_types:
337
+ for col_name in table_type.schema:
338
+ if col_name in columns:
339
+ raise ValueError(
340
+ f"Duplicate column name '{col_name}' found across tables. "
341
+ f"When writing multiple tables, column names must be unique."
342
+ )
343
+ columns.update(table_type.schema)
344
+
323
345
  if not isinstance(path, str) or not path:
324
346
  raise ValueError("path must be a non-empty string")
325
- if not isinstance(table_type, elt.TableType):
326
- raise TypeError(f"Expected TableType input, got {type(table_type).__name__}")
327
- if format not in ("csv", "parquet"):
328
- raise ValueError(f"format must be 'csv' or 'parquet', got {format!r}")
329
- return table_type
347
+ if format not in ("auto", "parquet", "csv", "json"):
348
+ raise ValueError(
349
+ f"format must be in ['auto', 'parquet', 'csv', 'json'], got {format!r}"
350
+ )
351
+ return elt.TableType(columns)
330
352
 
331
353
 
332
354
  def write(
333
- table: el.Object | Any,
355
+ tables: el.Object | list[el.Object] | Any,
334
356
  path: str,
335
357
  *,
336
358
  format: str = "parquet",
@@ -359,9 +381,14 @@ def write(
359
381
  >>> table.write(result, "/data/output.parquet")
360
382
  """
361
383
  # Auto-wrap runtime values
362
- if not isinstance(table, el.Object):
363
- table = constant(table)
364
- return write_p.bind(table, path=path, format=format) # type: ignore[no-any-return]
384
+ if not isinstance(tables, list):
385
+ tables = [tables]
386
+
387
+ for idx, tbl in enumerate(tables):
388
+ if not isinstance(tbl, el.Object):
389
+ tables[idx] = constant(tbl)
390
+
391
+ return write_p.bind(*tables, path=path, format=format) # type: ignore[no-any-return]
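
The reworked abstract eval unions the input schemas and rejects duplicate column names, so writing several tables side by side behaves like a column-wise concatenation. A standalone rendering of that check with plain dicts (the column names and type labels are made up; this is not the package API):

def merge_schemas(schemas: list[dict[str, str]]) -> dict[str, str]:
    # Union the per-table column schemas, rejecting duplicates as _write_ae does.
    merged: dict[str, str] = {}
    for schema in schemas:
        for col in schema:
            if col in merged:
                raise ValueError(f"Duplicate column name '{col}' found across tables.")
        merged.update(schema)
    return merged

print(merge_schemas([{"id": "i64", "name": "string"}, {"score": "f64"}]))
# {'id': 'i64', 'name': 'string', 'score': 'f64'}

At the dialect level this corresponds to passing a list, e.g. table.write([users, scores], "/data/out.parquet"), where users and scores are hypothetical el.Object tables with non-overlapping column names.
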
365
392
 
366
393
 
367
394
  __all__ = [
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mplang-nightly
3
- Version: 0.1.dev263
3
+ Version: 0.1.dev264
4
4
  Summary: Multi-Party Programming Language
5
5
  Author-email: SecretFlow Team <secretflow-contact@service.alipay.com>
6
6
  License: Apache License
@@ -211,7 +211,7 @@ Requires-Dist: cryptography>=43.0.0
211
211
  Requires-Dist: duckdb>=1.0.0
212
212
  Requires-Dist: fastapi
213
213
  Requires-Dist: flax>=0.12.0
214
- Requires-Dist: httpx>=0.27.0
214
+ Requires-Dist: httpx<1.0.0,>=0.27.0
215
215
  Requires-Dist: jax[cpu]==0.8.0
216
216
  Requires-Dist: lightphe<0.1.0,>=0.0.15
217
217
  Requires-Dist: numpy>=2.0.0
@@ -90,7 +90,7 @@ mplang/v2/backends/simp_design.md,sha256=CXvfxrvV1TmKlFm8IbKTbcHHwLl6AhwlY_cNqMd
90
90
  mplang/v2/backends/spu_impl.py,sha256=nDmpntXMKlFhaOUMXAOO_-RZTzqGLsgxEvwJuVA6h1g,9047
91
91
  mplang/v2/backends/spu_state.py,sha256=wj876IvNPhKyWISN6WwKBYoaDQFFJ8jemdJUVeH5IfA,4144
92
92
  mplang/v2/backends/store_impl.py,sha256=RyhADTNsnnNnwsatAMr7eeewXkVXtfNWA1oFiLXg8H0,2222
93
- mplang/v2/backends/table_impl.py,sha256=7W6Zm3bYrDx-6OTHsJe_SlVjxoDDJXNw5qfBGGBbE4U,8759
93
+ mplang/v2/backends/table_impl.py,sha256=Qmd-Z_PLjSbDngWkHz0wc6VykoGHfS2-rCOk1aWudws,27566
94
94
  mplang/v2/backends/tee_impl.py,sha256=Gp-vqqJPtEMNqP7y68tLhL3a-EW3BQwpo_qCJOSHqKs,7044
95
95
  mplang/v2/backends/tensor_impl.py,sha256=8f9f4-_e-m4JWGZSbXLmSSHcgPykRBc1sAYrA3OIxEg,18906
96
96
  mplang/v2/backends/simp_driver/__init__.py,sha256=ahOPYYvtFVwqxiFxkpSNP8BCTao_MfCXmtt5zsMaJxg,1258
@@ -114,7 +114,7 @@ mplang/v2/dialects/phe.py,sha256=PkehfF2NVBOu05zXITZ87yl-YQa4hwLs7zmUPbk2XhY,228
114
114
  mplang/v2/dialects/simp.py,sha256=ON7iegkHp3um5UX8V4Y5I-fGgFJ3YVwmFsXsleiqqUE,32869
115
115
  mplang/v2/dialects/spu.py,sha256=3JO-D394TKNH2VdFDRp5ohmG0uOcOHEs_ivFHbMZIgA,11385
116
116
  mplang/v2/dialects/store.py,sha256=RqUBzMAgtEMBmdT8axV5lVCv1hp5w0ZZM0Tu4iOZt-c,2114
117
- mplang/v2/dialects/table.py,sha256=ax9Yjvcb8jJ8fqNJodMQ_mrS8tf-xECHQFvUKUWPp70,12714
117
+ mplang/v2/dialects/table.py,sha256=jwNKHhpTRnpZVu_UhXGHKRAV0ekI8nXl5lLHa5PpxTE,13543
118
118
  mplang/v2/dialects/tee.py,sha256=oj_G8ebhtuz9_HarK8rKoaJNJ9ZkRbqcIxhp3m0xsjQ,10129
119
119
  mplang/v2/dialects/tensor.py,sha256=VVIlWtSHpeYFwGuKw7yWxwMQ_a35XJ-2ardeBed2HL8,39900
120
120
  mplang/v2/edsl/README.md,sha256=viflvdRojOa6Xk_UMRPqpuPGXcPGmdlv2-XR6LO7B58,7592
@@ -170,8 +170,8 @@ mplang/v2/runtime/dialect_state.py,sha256=HxO1i4kSOujS2tQzAF9-WmI3nChSaGgupf2_07
170
170
  mplang/v2/runtime/interpreter.py,sha256=UzrM5oepka6H0YKRZncNXhsuwKVm4pliG5J92fFRZMI,32300
171
171
  mplang/v2/runtime/object_store.py,sha256=yT6jtKG2GUEJVmpq3gnQ8mCMvUFYzgBciC5A-J5KRdk,5998
172
172
  mplang/v2/runtime/value.py,sha256=CMOxElJP78v7pjasPhEpbxWbSgB2KsLbpPmzz0mQX0E,4317
173
- mplang_nightly-0.1.dev263.dist-info/METADATA,sha256=8dpzKpue2cMj5yUJ7WMNf6NnUkUSiPFVAQ1k3Sjhj2g,16768
174
- mplang_nightly-0.1.dev263.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
175
- mplang_nightly-0.1.dev263.dist-info/entry_points.txt,sha256=mG1oJT-GAjQR834a62_QIWb7litzWPPyVnwFqm-rWuY,55
176
- mplang_nightly-0.1.dev263.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
177
- mplang_nightly-0.1.dev263.dist-info/RECORD,,
173
+ mplang_nightly-0.1.dev264.dist-info/METADATA,sha256=lTGfgTEgk6Ptaf1GEFsmMwixspGR58Tb2SrGlT6vJKM,16775
174
+ mplang_nightly-0.1.dev264.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
175
+ mplang_nightly-0.1.dev264.dist-info/entry_points.txt,sha256=mG1oJT-GAjQR834a62_QIWb7litzWPPyVnwFqm-rWuY,55
176
+ mplang_nightly-0.1.dev264.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
177
+ mplang_nightly-0.1.dev264.dist-info/RECORD,,