corvic-engine 0.3.0rc62-cp38-abi3-win_amd64.whl → 0.3.0rc64-cp38-abi3-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. buf/validate/validate_pb2.py +415 -698
  2. buf/validate/validate_pb2.pyi +173 -362
  3. buf/validate/validate_pb2_grpc.py +1 -1
  4. buf/validate/validate_pb2_grpc.pyi +6 -10
  5. corvic/engine/_native.pyd +0 -0
  6. corvic/system/_embedder.py +31 -8
  7. corvic/system/_image_embedder.py +33 -12
  8. corvic/system/in_memory_executor.py +292 -331
  9. corvic/system_sqlite/staging.py +17 -9
  10. {corvic_engine-0.3.0rc62.dist-info → corvic_engine-0.3.0rc64.dist-info}/METADATA +1 -1
  11. {corvic_engine-0.3.0rc62.dist-info → corvic_engine-0.3.0rc64.dist-info}/RECORD +27 -35
  12. corvic_generated/feature/v1/experiment_pb2.py +2 -2
  13. corvic_generated/feature/v1/space_pb2.py +2 -2
  14. corvic_generated/feature/v2/feature_view_pb2.py +2 -2
  15. corvic_generated/feature/v2/space_pb2.py +5 -5
  16. corvic_generated/ingest/v2/pipeline_pb2.py +24 -22
  17. corvic_generated/ingest/v2/pipeline_pb2.pyi +4 -2
  18. corvic_generated/ingest/v2/resource_pb2.py +4 -4
  19. corvic_generated/ingest/v2/room_pb2.py +31 -31
  20. corvic_generated/ingest/v2/room_pb2.pyi +4 -2
  21. corvic_generated/ingest/v2/source_pb2.py +4 -4
  22. corvic_generated/ingest/v2/table_pb2.py +3 -3
  23. corvic_generated/orm/v1/agent_pb2.py +2 -2
  24. corvic_generated/orm/v1/agent_pb2.pyi +6 -0
  25. corvic_generated/orm/v1/table_pb2.py +2 -2
  26. buf/validate/expression_pb2.py +0 -37
  27. buf/validate/expression_pb2.pyi +0 -52
  28. buf/validate/expression_pb2_grpc.py +0 -4
  29. buf/validate/expression_pb2_grpc.pyi +0 -34
  30. buf/validate/priv/private_pb2.py +0 -37
  31. buf/validate/priv/private_pb2.pyi +0 -37
  32. buf/validate/priv/private_pb2_grpc.py +0 -4
  33. buf/validate/priv/private_pb2_grpc.pyi +0 -34
  34. {corvic_engine-0.3.0rc62.dist-info → corvic_engine-0.3.0rc64.dist-info}/WHEEL +0 -0
  35. {corvic_engine-0.3.0rc62.dist-info → corvic_engine-0.3.0rc64.dist-info}/licenses/LICENSE +0 -0
@@ -6,8 +6,9 @@ import dataclasses
 import datetime
 import functools
 import math
-from collections.abc import MutableMapping
-from contextlib import nullcontext
+from collections.abc import Callable, Mapping, MutableMapping
+from contextlib import AbstractContextManager, ExitStack, nullcontext
+from types import TracebackType
 from typing import Any, Final, cast
 
 import numpy as np
@@ -18,7 +19,7 @@ import pyarrow.parquet as pq
 import structlog
 from google.protobuf import json_format, struct_pb2
 from more_itertools import flatten
-from typing_extensions import deprecated
+from typing_extensions import Self, deprecated
 
 from corvic import embed, embedding_metric, op_graph, sql
 from corvic.result import (
@@ -170,12 +171,30 @@ def _as_df(
     )
 
 
+@dataclasses.dataclass(frozen=True)
+class _LazyFrameWithMetrics:
+    data: pl.LazyFrame
+    metrics: dict[str, Any]
+
+    def apply(
+        self, lf_op: Callable[[pl.LazyFrame], pl.LazyFrame]
+    ) -> _LazyFrameWithMetrics:
+        return _LazyFrameWithMetrics(lf_op(self.data), self.metrics)
+
+    def with_data(self, data: pl.LazyFrame):
+        return _LazyFrameWithMetrics(data, self.metrics)
+
+
 @dataclasses.dataclass(frozen=True)
 class _SchemaAndBatches:
     schema: pa.Schema
     batches: list[pa.RecordBatch]
     metrics: dict[str, Any]
 
+    @classmethod
+    def from_lazy_frame_with_metrics(cls, lfm: _LazyFrameWithMetrics):
+        return cls.from_dataframe(lfm.data.collect(), lfm.metrics)
+
     def to_batch_reader(self):
         return pa.RecordBatchReader.from_batches(
             schema=self.schema,
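Note: the new _LazyFrameWithMetrics wrapper threads a metrics dict alongside a pl.LazyFrame, so per-op transformations are appended to a lazy plan instead of materializing record batches at every step. A minimal standalone sketch of the same pattern (illustrative only, assuming nothing beyond polars; the names below are not the package's):

import dataclasses
from collections.abc import Callable
from typing import Any

import polars as pl


@dataclasses.dataclass(frozen=True)
class LazyFrameWithMetrics:
    data: pl.LazyFrame
    metrics: dict[str, Any]

    def apply(self, lf_op: Callable[[pl.LazyFrame], pl.LazyFrame]) -> "LazyFrameWithMetrics":
        # Append a transformation to the lazy plan; metrics ride along unchanged.
        return LazyFrameWithMetrics(lf_op(self.data), self.metrics)


lfm = LazyFrameWithMetrics(pl.DataFrame({"a": [1, 2, 3]}).lazy(), metrics={"source": "demo"})
lfm = lfm.apply(lambda lf: lf.filter(pl.col("a") > 1)).apply(lambda lf: lf.rename({"a": "b"}))
print(lfm.data.collect())  # nothing is materialized until this collect
print(lfm.metrics)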
@@ -209,16 +228,29 @@ class _SlicedTable:
 
 
 @dataclasses.dataclass
-class _InMemoryExecutionContext:
+class _InMemoryExecutionContext(AbstractContextManager["_InMemoryExecutionContext"]):
     exec_context: ExecutionContext
     current_output_context: TableComputeContext | None = None
 
     # Using _SchemaAndBatches rather than a RecordBatchReader since the latter's
     # contract only guarantees one iteration and these might be accessed more than
     # once
-    computed_batches_for_op_graph: dict[_SlicedTable, _SchemaAndBatches] = (
+    computed_batches_for_op_graph: dict[_SlicedTable, _LazyFrameWithMetrics] = (
         dataclasses.field(default_factory=dict)
     )
+    exit_stack: ExitStack = dataclasses.field(default_factory=ExitStack)
+
+    def __enter__(self) -> Self:
+        self.exit_stack = self.exit_stack.__enter__()
+        return self
+
+    def __exit__(
+        self,
+        __exc_type: type[BaseException] | None,
+        __exc_value: BaseException | None,
+        __traceback: TracebackType | None,
+    ) -> bool | None:
+        return self.exit_stack.__exit__(__exc_type, __exc_value, __traceback)
 
     @classmethod
     def count_source_op_uses(
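Note: turning the execution context into an AbstractContextManager that delegates to an ExitStack lets resources opened while building lazy plans (such as the blob streams handed to pl.scan_parquet below) stay open until the whole execution scope exits. A small sketch of that delegation pattern, using io.BytesIO as a stand-in resource (the class name is made up):

import io
from contextlib import AbstractContextManager, ExitStack
from types import TracebackType


class ResourceScope(AbstractContextManager["ResourceScope"]):
    def __init__(self) -> None:
        self.exit_stack = ExitStack()

    def __enter__(self) -> "ResourceScope":
        self.exit_stack.__enter__()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool | None:
        # Everything registered on the stack is closed here, in reverse order.
        return self.exit_stack.__exit__(exc_type, exc_value, traceback)


with ResourceScope() as scope:
    stream = scope.exit_stack.enter_context(io.BytesIO(b"payload"))
    assert not stream.closed
assert stream.closed  # released when the scope exited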
@@ -313,13 +345,13 @@ class InMemoryExecutionResult(ExecutionResult):
     def make(
         cls,
         storage_manager: StorageManager,
-        in_memory_context: _InMemoryExecutionContext,
+        computed_tables: Mapping[_SlicedTable, _SchemaAndBatches],
         context: ExecutionContext,
     ) -> InMemoryExecutionResult:
         tables = [
             InMemoryTableComputeResult(
                 storage_manager,
-                in_memory_context.computed_batches_for_op_graph[
+                computed_tables[
                     _SlicedTable(
                         table_context.table_op_graph,
                         table_context.sql_output_slice_args,
@@ -362,80 +394,69 @@ class InMemoryExecutor(OpGraphExecutor):
 
     def _execute_read_from_parquet(
         self, op: op_graph.op.ReadFromParquet, context: _InMemoryExecutionContext
-    ) -> Ok[_SchemaAndBatches]:
-        batches: list[pa.RecordBatch] = []
-        for blob_name in op.blob_names:
-            with (
-                self._storage_manager.blob_from_url(blob_name).open("rb") as stream,
-            ):
-                batches.extend(
-                    # reading files with pyarrow, then converting them to polars
-                    # can cause "ShapeError" bugs. That's why we're not reading this
-                    # using pyarrow.
-                    pl.read_parquet(
-                        source=stream,
-                        columns=op.arrow_schema.names,
-                        use_pyarrow=False,
-                    )
-                    .to_arrow()
-                    .to_batches()
+    ) -> Ok[_LazyFrameWithMetrics]:
+        data = cast(pl.DataFrame, pl.from_arrow(op.arrow_schema.empty_table()))
+        data = pl.scan_parquet(
+            [
+                context.exit_stack.enter_context(
+                    self._storage_manager.blob_from_url(blob_name).open("rb")
                 )
-        return Ok(_SchemaAndBatches(op.arrow_schema, batches=batches, metrics={}))
+                for blob_name in op.blob_names
+            ],
+            schema=data.schema,
+        )
+        return Ok(_LazyFrameWithMetrics(data, metrics={}))
 
     def _execute_rollup_by_aggregation(
         self, op: op_graph.op.RollupByAggregation, context: _InMemoryExecutionContext
-    ) -> Ok[_SchemaAndBatches]:
+    ) -> Ok[_LazyFrameWithMetrics]:
         raise NotImplementedError(
             "rollup by aggregation outside of sql not implemented"
         )
 
+    def _compute_source_then_apply(
+        self,
+        source: op_graph.Op,
+        lf_op: Callable[[pl.LazyFrame], pl.LazyFrame],
+        context: _InMemoryExecutionContext,
+    ):
+        return self._execute(source, context).map(
+            lambda source_lfm: source_lfm.apply(lf_op)
+        )
+
     def _execute_rename_columns(
         self, op: op_graph.op.RenameColumns, context: _InMemoryExecutionContext
     ):
-        return self._execute(op.source, context).map(
-            lambda source_batches: _SchemaAndBatches.from_dataframe(
-                _as_df(source_batches).rename(dict(op.old_name_to_new)),
-                source_batches.metrics,
-            )
+        return self._compute_source_then_apply(
+            op.source, lambda lf: lf.rename(dict(op.old_name_to_new)), context
         )
 
     def _execute_select_columns(
         self, op: op_graph.op.SelectColumns, context: _InMemoryExecutionContext
     ):
-        return self._execute(op.source, context).map(
-            lambda source_batches: _SchemaAndBatches.from_dataframe(
-                _as_df(source_batches).select(op.columns), source_batches.metrics
-            )
+        return self._compute_source_then_apply(
+            op.source, lambda lf: lf.select(op.columns), context
        )
 
     def _execute_limit_rows(
         self, op: op_graph.op.LimitRows, context: _InMemoryExecutionContext
     ):
-        return self._execute(op.source, context).map(
-            lambda source_batches: _SchemaAndBatches.from_dataframe(
-                _as_df(source_batches).limit(op.num_rows),
-                source_batches.metrics,
-            )
+        return self._compute_source_then_apply(
+            op.source, lambda lf: lf.limit(op.num_rows), context
         )
 
     def _execute_offset_rows(
         self, op: op_graph.op.OffsetRows, context: _InMemoryExecutionContext
     ):
-        return self._execute(op.source, context).map(
-            lambda source_batches: _SchemaAndBatches.from_dataframe(
-                _as_df(source_batches).slice(op.num_rows),
-                source_batches.metrics,
-            )
+        return self._compute_source_then_apply(
+            op.source, lambda lf: lf.slice(op.num_rows), context
         )
 
     def _execute_order_by(
         self, op: op_graph.op.OrderBy, context: _InMemoryExecutionContext
     ):
-        return self._execute(op.source, context).map(
-            lambda source_batches: _SchemaAndBatches.from_dataframe(
-                _as_df(source_batches).sort(op.columns, descending=op.desc),
-                source_batches.metrics,
-            )
+        return self._compute_source_then_apply(
+            op.source, lambda lf: lf.sort(op.columns, descending=op.desc), context
        )
 
     def _row_filter_literal_comparison_to_condition(
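Note: _execute_read_from_parquet switches from eagerly decoding each blob with pl.read_parquet to one pl.scan_parquet over streams kept alive on the context's exit stack, and the simple column ops now route through _compute_source_then_apply, so selections and filters become part of a single lazy plan. A toy eager-versus-lazy comparison on a local file (paths and column names are invented for illustration):

import tempfile
from pathlib import Path

import polars as pl

path = Path(tempfile.mkdtemp()) / "demo.parquet"
pl.DataFrame({"id": [1, 2, 3], "payload": ["a", "b", "c"]}).write_parquet(path)

# Eager: the whole file is decoded before the column pruning happens.
eager = pl.read_parquet(path, columns=["id"])

# Lazy: the select is part of the plan, so the scan can prune to the "id" column.
lazy = pl.scan_parquet(path).select("id").collect()

assert eager.to_dicts() == lazy.to_dicts()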
@@ -503,34 +524,31 @@ class InMemoryExecutor(OpGraphExecutor):
     def _execute_filter_rows(
         self, op: op_graph.op.FilterRows, context: _InMemoryExecutionContext
     ):
-        match self._execute(op.source, context):
-            case Ok(source_batches):
-                return self._row_filter_to_condition(op.row_filter).map_or_else(
-                    lambda err: InternalError.from_(err),
-                    lambda row_filter: Ok(
-                        _SchemaAndBatches.from_dataframe(
-                            _as_df(source_batches).filter(row_filter),
-                            source_batches.metrics,
-                        )
-                    ),
-                )
-            case err:
-                return err
+        match self._row_filter_to_condition(op.row_filter):
+            case op_graph.OpParseError() as err:
+                return InternalError.from_(err)
+            case Ok(row_filter):
+                pass
+        return self._compute_source_then_apply(
+            op.source, lambda lf: lf.filter(row_filter), context
+        )
 
     def _execute_embedding_metrics(  # noqa: C901
         self, op: op_graph.op.EmbeddingMetrics, context: _InMemoryExecutionContext
     ):
         match self._execute(op.table, context):
-            case Ok(source_batches):
+            case Ok(source_lfm):
                 pass
             case err:
                 return err
-        embedding_df = _as_df(source_batches)
+        embedding_df = source_lfm.data.collect()
 
         if len(embedding_df) < _MIN_EMBEDDINGS_FOR_EMBEDDINGS_SUMMARY:
             # downstream consumers handle empty metadata by substituting their
             # own values
-            return Ok(source_batches)
+            return Ok(
+                _LazyFrameWithMetrics(embedding_df.lazy(), metrics=source_lfm.metrics)
+            )
 
         # before it was configurable, this op assumed that the column's name was
         # this hardcoded name
@@ -541,7 +559,7 @@ class InMemoryExecutor(OpGraphExecutor):
             case InvalidArgumentError() as err:
                 return InternalError.from_(err)
 
-        metrics = source_batches.metrics.copy()
+        metrics = source_lfm.metrics.copy()
         match embedding_metric.ne_sum(embedding, normalize=True):
             case Ok(metric):
                 metrics["ne_sum"] = metric
@@ -564,17 +582,17 @@ class InMemoryExecutor(OpGraphExecutor):
                 metrics["stable_rank"] = metric
             case InvalidArgumentError() as err:
                 _logger.warning("could not compute stable_rank", exc_info=str(err))
-        return Ok(_SchemaAndBatches.from_dataframe(embedding_df, metrics=metrics))
+        return Ok(_LazyFrameWithMetrics(embedding_df.lazy(), metrics=metrics))
 
     def _execute_embedding_coordinates(
         self, op: op_graph.op.EmbeddingCoordinates, context: _InMemoryExecutionContext
     ):
         match self._execute(op.table, context):
-            case Ok(source_batches):
+            case Ok(source_lfm):
                 pass
             case err:
                 return err
-        embedding_df = _as_df(source_batches)
+        embedding_df = source_lfm.data.collect()
 
         # before it was configurable, this op assumed that the column's name was
         # this hardcoded name
@@ -583,16 +601,14 @@ class InMemoryExecutor(OpGraphExecutor):
         # the neighbors of a point includes itself. That does mean, that an n_neighbors
         # value of less than 3 simply does not work
         if len(embedding_df) < _MIN_EMBEDDINGS_FOR_EMBEDDINGS_SUMMARY:
-            coordinates_df = embedding_df.with_columns(
+            coordinates_df = embedding_df.lazy().with_columns(
                 pl.Series(
                     name=embedding_column_name,
                     values=[[0.0] * op.n_components] * len(embedding_df),
                     dtype=pl.List(pl.Float32),
                 )
             )
-            return Ok(
-                _SchemaAndBatches.from_dataframe(coordinates_df, source_batches.metrics)
-            )
+            return Ok(_LazyFrameWithMetrics(coordinates_df, source_lfm.metrics))
 
         match get_polars_embedding(embedding_df, embedding_column_name):
             case Ok(embedding):
@@ -608,39 +624,37 @@ class InMemoryExecutor(OpGraphExecutor):
             case InvalidArgumentError() as err:
                 raise err
 
-        coordinates_df = embedding_df.with_columns(
+        coordinates_df = embedding_df.lazy().with_columns(
             pl.Series(
                 name=embedding_column_name,
                 values=coordinates,
                 dtype=pl.List(pl.Float32),
             )
         )
-        return Ok(
-            _SchemaAndBatches.from_dataframe(coordinates_df, source_batches.metrics)
-        )
+        return Ok(_LazyFrameWithMetrics(coordinates_df, source_lfm.metrics))
 
     def _execute_distinct_rows(
         self, op: op_graph.op.DistinctRows, context: _InMemoryExecutionContext
     ):
         return self._execute(op.source, context).map(
-            lambda source_batches: _SchemaAndBatches.from_dataframe(
-                _as_df(source_batches).unique(), source_batches.metrics
+            lambda source_lfm: _LazyFrameWithMetrics(
+                source_lfm.data.unique(), source_lfm.metrics
             )
         )
 
     def _execute_join(self, op: op_graph.op.Join, context: _InMemoryExecutionContext):
         match self._execute(op.left_source, context):
-            case Ok(left_batches):
+            case Ok(left_lfm):
                 pass
             case err:
                 return err
         match self._execute(op.right_source, context):
-            case Ok(right_batches):
+            case Ok(right_lfm):
                 pass
             case err:
                 return err
-        left_df = _as_df(left_batches)
-        right_df = _as_df(right_batches)
+        left_lf = left_lfm.data
+        right_lf = right_lfm.data
 
         match op.how:
             case table_pb2.JOIN_TYPE_INNER:
@@ -651,32 +665,20 @@ class InMemoryExecutor(OpGraphExecutor):
                 join_type = "inner"
 
         # in our join semantics we drop columns from the right source on conflict
-        right_df = right_df.select(
+        right_lf = right_lf.select(
             [
                 col
-                for col in right_df.columns
-                if col in op.right_join_columns or col not in left_df.columns
+                for col in right_lf.columns
+                if col in op.right_join_columns or col not in left_lf.columns
             ]
         )
-        metrics = right_batches.metrics.copy()
-        metrics.update(left_batches.metrics)
-
-        # polars doesn't behave so well when one side is empty, just
-        # compute the trivial empty join when the result is guaranteed
-        # to be empty instead.
-        if len(left_df) == 0 or len(right_df) == 0 and join_type == "inner":
-            return Ok(
-                _SchemaAndBatches(
-                    schema=op.schema.to_arrow(),
-                    batches=op.schema.to_arrow().empty_table().to_batches(),
-                    metrics=metrics,
-                )
-            )
+        metrics = right_lfm.metrics.copy()
+        metrics.update(left_lfm.metrics)
 
         return Ok(
-            _SchemaAndBatches.from_dataframe(
-                left_df.join(
-                    right_df,
+            _LazyFrameWithMetrics(
+                left_lf.join(
+                    right_lf,
                     left_on=op.left_join_columns,
                     right_on=op.right_join_columns,
                     how=join_type,
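Note: the join now stays lazy on both sides, and the special case for empty inputs is gone. The conflict rule described by the comment above (right-source columns are dropped when they collide with the left schema, unless they are join keys) can be sketched in isolation like this, with table contents invented for illustration:

import polars as pl

left = pl.DataFrame({"id": [1, 2], "name": ["a", "b"]}).lazy()
right = pl.DataFrame({"id": [1, 2], "name": ["x", "y"], "score": [0.1, 0.9]}).lazy()

join_columns = ["id"]
# Keep right columns only if they are join keys or do not clash with the left schema.
right_pruned = right.select(
    [col for col in right.columns if col in join_columns or col not in left.columns]
)

result = left.join(right_pruned, left_on=join_columns, right_on=join_columns, how="inner")
print(result.collect())  # "name" comes from the left table; "score" survives from the right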
@@ -686,62 +688,47 @@ class InMemoryExecutor(OpGraphExecutor):
         )
 
     def _execute_empty(self, op: op_graph.op.Empty, context: _InMemoryExecutionContext):
-        empty_table = pa.schema([]).empty_table()
-        return Ok(
-            _SchemaAndBatches(empty_table.schema, empty_table.to_batches(), metrics={})
-        )
+        empty_table = cast(pl.DataFrame, pl.from_arrow(pa.schema([]).empty_table()))
+        return Ok(_LazyFrameWithMetrics(empty_table.lazy(), metrics={}))
 
     def _execute_concat(
         self, op: op_graph.op.Concat, context: _InMemoryExecutionContext
     ):
-        source_batches = list[_SchemaAndBatches]()
+        source_lfms = list[_LazyFrameWithMetrics]()
         for table in op.tables:
             match self._execute(table, context):
                 case Ok(batches):
-                    source_batches.append(batches)
+                    source_lfms.append(batches)
                 case err:
                     return err
-        dataframes = [_as_df(batches) for batches in source_batches]
+        data = pl.concat([lfm.data for lfm in source_lfms], how=op.how)
         metrics = dict[str, Any]()
-        for batches in source_batches:
-            metrics.update(batches.metrics)
-        return Ok(
-            _SchemaAndBatches.from_dataframe(
-                pl.concat(dataframes, how=op.how), metrics=metrics
-            )
-        )
+        for lfm in source_lfms:
+            metrics.update(lfm.metrics)
+        return Ok(_LazyFrameWithMetrics(data, metrics=metrics))
 
     def _execute_unnest_struct(
         self, op: op_graph.op.UnnestStruct, context: _InMemoryExecutionContext
     ):
-        return self._execute(op.source, context).map(
-            lambda source_batches: _SchemaAndBatches.from_dataframe(
-                _as_df(source_batches).unnest(op.struct_column_name),
-                source_batches.metrics,
-            )
+        return self._compute_source_then_apply(
+            op.source, lambda lf: lf.unnest(op.struct_column_name), context
        )
 
     def _execute_nest_into_struct(
         self, op: op_graph.op.NestIntoStruct, context: _InMemoryExecutionContext
     ):
-        match self._execute(op.source, context):
-            case Ok(source_batches):
-                pass
-            case err:
-                return err
         non_struct_columns = [
-            name
-            for name in source_batches.schema.names
-            if name not in op.column_names_to_nest
+            field.name
+            for field in op.source.schema
+            if field.name not in op.column_names_to_nest
         ]
-        return Ok(
-            _SchemaAndBatches.from_dataframe(
-                _as_df(source_batches).select(
-                    *non_struct_columns,
-                    pl.struct(op.column_names_to_nest).alias(op.struct_column_name),
-                ),
-                source_batches.metrics,
-            )
+        return self._compute_source_then_apply(
+            op.source,
+            lambda lf: lf.select(
+                *non_struct_columns,
+                pl.struct(op.column_names_to_nest).alias(op.struct_column_name),
+            ),
+            context,
         )
 
     def _execute_add_literal_column(
@@ -758,57 +745,49 @@ class InMemoryExecutor(OpGraphExecutor):
         else:
             column = pl.Series(name, literals).cast(dtype)
 
-        def do_work(source_batches: _SchemaAndBatches):
-            return _SchemaAndBatches.from_dataframe(
-                _as_df(source_batches).with_columns(column),
-                source_batches.metrics,
-            )
-
-        return self._execute(op.source, context).map(do_work)
+        return self._compute_source_then_apply(
+            op.source,
+            lambda lf: lf.with_columns(column),
+            context,
+        )
 
     def _execute_combine_columns(
         self, op: op_graph.op.CombineColumns, context: _InMemoryExecutionContext
     ):
-        match self._execute(op.source, context):
-            case Ok(source_batches):
-                pass
-            case err:
-                return err
-        source_df = _as_df(source_batches)
         match op.reduction:
-            case op_graph.ConcatString():
+            case op_graph.ConcatString() as reduction:
                 # if we do not ignore nulls then all concatenated rows that
                 # have a single column that contain a null value will be output
                 # as null.
-                result_df = source_df.with_columns(
-                    pl.concat_str(
-                        [pl.col(col) for col in op.column_names],
-                        separator=op.reduction.separator,
-                        ignore_nulls=True,
-                    ).alias(op.combined_column_name)
-                )
+                concat_expr = pl.concat_str(
+                    [pl.col(col) for col in op.column_names],
+                    separator=reduction.separator,
+                    ignore_nulls=True,
+                ).alias(op.combined_column_name)
 
             case op_graph.ConcatList():
                 if op.column_names:
-                    result_df = source_df.with_columns(
-                        pl.concat_list(*op.column_names).alias(op.combined_column_name)
+                    concat_expr = pl.concat_list(*op.column_names).alias(
+                        op.combined_column_name
                     )
                 else:
-                    result_df = source_df.with_columns(
-                        pl.Series(op.combined_column_name, [])
-                    )
+                    concat_expr = pl.Series(op.combined_column_name, [])
 
-        return Ok(_SchemaAndBatches.from_dataframe(result_df, source_batches.metrics))
+        return self._compute_source_then_apply(
+            op.source,
+            lambda lf: lf.with_columns(concat_expr),
+            context,
+        )
 
     def _execute_embed_column(
         self, op: op_graph.op.EmbedColumn, context: _InMemoryExecutionContext
     ):
         match self._execute(op.source, context):
-            case Ok(source_batches):
+            case Ok(source_lfm):
                 pass
             case err:
                 return err
-        source_df = _as_df(source_batches)
+        source_df = source_lfm.data.collect()
         to_embed = source_df[op.column_name].cast(pl.String())
 
         embed_context = EmbedTextContext(
@@ -825,17 +804,14 @@ class InMemoryExecutor(OpGraphExecutor):
             case InvalidArgumentError() | InternalError() as err:
                 raise InternalError("Failed to embed column") from err
 
-        result_df = source_df.with_columns(
-            result.embeddings.alias(op.embedding_column_name)
-        ).drop_nulls(op.embedding_column_name)
-
-        return Ok(
-            _SchemaAndBatches.from_dataframe(
-                result_df,
-                source_batches.metrics,
-            )
+        result_df = (
+            source_df.lazy()
+            .with_columns(result.embeddings.alias(op.embedding_column_name))
+            .drop_nulls(op.embedding_column_name)
         )
 
+        return Ok(source_lfm.with_data(result_df))
+
     @staticmethod
     def get_cyclic_encoding(
         series: pl.Series,
@@ -952,12 +928,12 @@ class InMemoryExecutor(OpGraphExecutor):
         self, op: op_graph.op.EncodeColumns, context: _InMemoryExecutionContext
     ):
         match self._execute(op.source, context):
-            case Ok(source_batches):
+            case Ok(source_lfm):
                 pass
             case err:
                 return err
-        source_df = _as_df(source_batches)
-        metrics = source_batches.metrics.copy()
+        source_df = source_lfm.data.collect()
+        metrics = source_lfm.metrics.copy()
         metric = metrics.get("one_hot_encoder", {})
         for encoder_arg in op.encoded_columns:
             to_encode = source_df[encoder_arg.column_name]
@@ -1084,8 +1060,8 @@ class InMemoryExecutor(OpGraphExecutor):
             )
         metrics["one_hot_encoder"] = metric
         return Ok(
-            _SchemaAndBatches.from_dataframe(
-                source_df,
+            _LazyFrameWithMetrics(
+                source_df.lazy(),
                 metrics,
             )
         )
@@ -1122,43 +1098,40 @@ class InMemoryExecutor(OpGraphExecutor):
 
         metrics = dict[str, Any]()
 
-        edge_list_batches = list[_SchemaAndBatches]()
+        edge_list_lfms = list[_LazyFrameWithMetrics]()
         for edge_list in op.edge_list_tables:
             match self._execute(edge_list.table, context):
-                case Ok(source_batches):
-                    edge_list_batches.append(source_batches)
+                case Ok(source_lfm):
+                    edge_list_lfms.append(source_lfm)
                 case err:
                     return err
 
         def edge_generator():
-            for edge_list, batches in zip(
-                op.edge_list_tables, edge_list_batches, strict=True
-            ):
+            for edge_list, lfm in zip(op.edge_list_tables, edge_list_lfms, strict=True):
                 start_column_name = edge_list.start_column_name
                 end_column_name = edge_list.end_column_name
                 start_column_type_name = entities_dtypes[start_column_name]
                 end_column_type_name = entities_dtypes[end_column_name]
-                metrics.update(batches.metrics)
-                for batch in batches.batches:
-                    yield (
-                        _as_df(batch)
-                        .with_columns(
-                            pl.col(edge_list.start_column_name).alias(
-                                f"start_id_{start_column_type_name}"
-                            ),
-                            pl.lit(edge_list.start_entity_name).alias("start_source"),
-                            pl.col(edge_list.end_column_name).alias(
-                                f"end_id_{end_column_type_name}"
-                            ),
-                            pl.lit(edge_list.end_entity_name).alias("end_source"),
-                        )
-                        .select(
-                            f"start_id_{start_column_type_name}",
-                            "start_source",
-                            f"end_id_{end_column_type_name}",
-                            "end_source",
-                        )
+                metrics.update(lfm.metrics)
+                yield (
+                    lfm.data.with_columns(
+                        pl.col(edge_list.start_column_name).alias(
+                            f"start_id_{start_column_type_name}"
+                        ),
+                        pl.lit(edge_list.start_entity_name).alias("start_source"),
+                        pl.col(edge_list.end_column_name).alias(
+                            f"end_id_{end_column_type_name}"
+                        ),
+                        pl.lit(edge_list.end_entity_name).alias("end_source"),
+                    )
+                    .select(
+                        f"start_id_{start_column_type_name}",
+                        "start_source",
+                        f"end_id_{end_column_type_name}",
+                        "end_source",
                     )
+                    .collect()
+                )
 
         edges = pl.concat(
             [
@@ -1187,18 +1160,17 @@ class InMemoryExecutor(OpGraphExecutor):
             negative=op.negative,
         )
         n2v_runner.train(epochs=op.epochs)
-        return Ok(_SchemaAndBatches.from_dataframe(n2v_runner.wv.to_polars(), metrics))
+        return Ok(_LazyFrameWithMetrics(n2v_runner.wv.to_polars().lazy(), metrics))
 
     def _execute_aggregate_columns(
         self, op: op_graph.op.AggregateColumns, context: _InMemoryExecutionContext
     ):
         match self._execute(op.source, context):
-            case Ok(source_batches):
+            case Ok(source_lfm):
                 pass
             case err:
                 return err
-        source_df = _as_df(source_batches)
-        to_aggregate = source_df[op.column_names]
+        to_aggregate = source_lfm.data.select(op.column_names)
 
         match op.aggregation:
             case op_graph.aggregation.Min():
@@ -1216,106 +1188,92 @@ class InMemoryExecutor(OpGraphExecutor):
             case op_graph.aggregation.NullCount():
                 aggregate = to_aggregate.null_count()
 
-        return Ok(_SchemaAndBatches.from_dataframe(aggregate, metrics={}))
+        return Ok(source_lfm.with_data(aggregate))
 
     def _execute_correlate_columns(
         self, op: op_graph.op.CorrelateColumns, context: _InMemoryExecutionContext
     ):
         match self._execute(op.source, context):
-            case Ok(source_batches):
+            case Ok(source_lfm):
                 pass
             case err:
                 return err
-        source_df = _as_df(source_batches)
+        source_df = source_lfm.data.collect()
         with np.errstate(invalid="ignore"):
-            corr_df = source_df[op.column_names].corr(dtype="float32")
+            corr_df = source_df.select(op.column_names).corr(dtype="float32")
 
-        return Ok(
-            _SchemaAndBatches.from_dataframe(
-                corr_df,
-                metrics={},
-            )
-        )
+        return Ok(source_lfm.with_data(corr_df.lazy()))
 
     def _execute_histogram_column(
         self, op: op_graph.op.HistogramColumn, context: _InMemoryExecutionContext
     ):
-        return self._execute(op.source, context).map(
-            lambda source_batches: _SchemaAndBatches.from_dataframe(
-                _as_df(source_batches)[op.column_name]
-                .hist(include_category=False)
-                .rename(
-                    {
-                        "breakpoint": op.breakpoint_column_name,
-                        "count": op.count_column_name,
-                    }
-                ),
-                metrics={},
-            )
+        return self._compute_source_then_apply(
+            op.source,
+            lambda lf: lf.collect()[op.column_name]
+            .hist(include_category=False)
+            .lazy()
+            .rename(
+                {
+                    "breakpoint": op.breakpoint_column_name,
+                    "count": op.count_column_name,
+                }
+            ),
+            context,
         )
 
     def _execute_convert_column_to_string(
         self, op: op_graph.op.ConvertColumnToString, context: _InMemoryExecutionContext
     ):
-        match self._execute(op.source, context):
-            case Ok(source_batches):
-                pass
-            case err:
-                return err
-        source_df = _as_df(source_batches)
-        column = source_df[op.column_name]
-        if not column.dtype.is_nested():
-            source_df = source_df.with_columns(column.cast(pl.String(), strict=False))
-        elif isinstance(column.dtype, pl.Array | pl.List):
-            source_df = source_df.with_columns(
-                column.cast(pl.List(pl.String())).list.join(",")
-            )
+        dtype = op.source.schema.to_polars()[op.column_name]
+        if not dtype.is_nested():
+            cast_expr = pl.col(op.column_name).cast(pl.String(), strict=False)
+        elif isinstance(dtype, pl.Array | pl.List):
+            cast_expr = pl.col(op.column_name).cast(pl.List(pl.String())).list.join(",")
         else:
             raise NotImplementedError(
                 "converting struct columns to strings is not implemented"
             )
-        return Ok(
-            _SchemaAndBatches.from_dataframe(source_df, metrics=source_batches.metrics)
+        return self._compute_source_then_apply(
+            op.source, lambda lf: lf.collect().with_columns(cast_expr).lazy(), context
         )
 
     def _execute_add_row_index(
         self, op: op_graph.op.AddRowIndex, context: _InMemoryExecutionContext
     ):
-        return self._execute(op.source, context).map(
-            lambda source_batches: _SchemaAndBatches.from_dataframe(
-                _as_df(source_batches)
-                .with_row_index(name=op.row_index_column_name, offset=op.offset)
-                .with_columns(pl.col(op.row_index_column_name).cast(pl.UInt64())),
-                metrics=source_batches.metrics,
-            )
+        return self._compute_source_then_apply(
+            op.source,
+            lambda lf: lf.with_row_index(
+                name=op.row_index_column_name, offset=op.offset
+            ).with_columns(pl.col(op.row_index_column_name).cast(pl.UInt64())),
+            context,
        )
 
     def _execute_output_csv(
         self, op: op_graph.op.OutputCsv, context: _InMemoryExecutionContext
     ):
         match self._execute(op.source, context):
-            case Ok(source_batches):
+            case Ok(source_lfm):
                 pass
             case err:
                 return err
-        source_df = _as_df(source_batches)
+        source_df = source_lfm.data.collect()
         source_df.write_csv(
             op.csv_url,
             quote_style="never",
             include_header=op.include_header,
         )
-        return Ok(source_batches)
+        return Ok(source_lfm.with_data(source_df.lazy()))
 
     def _execute_truncate_list(
         self, op: op_graph.op.TruncateList, context: _InMemoryExecutionContext
     ):
         # TODO(Patrick): verify this approach works for arrays
         match self._execute(op.source, context):
-            case Ok(source_batches):
+            case Ok(source_lfm):
                 pass
             case err:
                 return err
-        source_df = _as_df(source_batches)
+        source_df = source_lfm.data.collect()
         if len(source_df):
             existing_length = get_polars_embedding_length(
                 source_df, op.column_name
@@ -1336,6 +1294,7 @@ class InMemoryExecutor(OpGraphExecutor):
         else:
             return InternalError("unexpected type", cause="expected list or array type")
 
+        source_df = source_df.lazy()
         if head_length < op.target_column_length:
             padding_length = op.target_column_length - head_length
             padding = [op.padding_value_as_py] * padding_length
@@ -1347,16 +1306,14 @@ class InMemoryExecutor(OpGraphExecutor):
             .list.to_array(width=op.target_column_length)
             .cast(pl.List(inner_type))
         )
-        return Ok(
-            _SchemaAndBatches.from_dataframe(source_df, metrics=source_batches.metrics)
-        )
+        return Ok(source_lfm.with_data(source_df))
 
     def _execute_union(self, op: op_graph.op.Union, context: _InMemoryExecutionContext):
-        sources = list[_SchemaAndBatches]()
+        sources = list[_LazyFrameWithMetrics]()
         for source in op.sources():
             match self._execute(source, context):
-                case Ok(source_df):
-                    sources.append(source_df)
+                case Ok(source_lfm):
+                    sources.append(source_lfm)
                 case err:
                     return err
 
@@ -1364,20 +1321,20 @@ class InMemoryExecutor(OpGraphExecutor):
         for src in sources:
             metrics.update(src.metrics)
 
-        result_df = pl.concat((_as_df(src) for src in sources), how="vertical_relaxed")
+        result_lf = pl.concat((src.data for src in sources), how="vertical_relaxed")
         if op.distinct:
-            result_df = result_df.unique()
-        return Ok(_SchemaAndBatches.from_dataframe(result_df, metrics=metrics))
+            result_lf = result_lf.unique()
+        return Ok(_LazyFrameWithMetrics(result_lf, metrics=metrics))
 
     def _execute_embed_image_column(
         self, op: op_graph.op.EmbedImageColumn, context: _InMemoryExecutionContext
     ):
         match self._execute(op.source, context):
-            case Ok(source_batches):
+            case Ok(source_lfm):
                 pass
             case err:
                 return err
-        source_df = _as_df(source_batches)
+        source_df = source_lfm.data.collect()
         to_embed = source_df[op.column_name].cast(pl.Binary())
 
         embed_context = EmbedImageContext(
@@ -1392,14 +1349,12 @@ class InMemoryExecutor(OpGraphExecutor):
             case InvalidArgumentError() | InternalError() as err:
                 raise InternalError("Failed to embed column") from err
 
-        result_df = source_df.with_columns(
-            result.embeddings.alias(op.embedding_column_name)
-        ).drop_nulls(op.embedding_column_name)
-
         return Ok(
-            _SchemaAndBatches.from_dataframe(
-                result_df,
-                source_batches.metrics,
+            _LazyFrameWithMetrics(
+                source_df.lazy()
+                .with_columns(result.embeddings.alias(op.embedding_column_name))
+                .drop_nulls(op.embedding_column_name),
+                source_lfm.metrics,
             )
         )
 
@@ -1407,13 +1362,15 @@ class InMemoryExecutor(OpGraphExecutor):
         self, op: op_graph.op.AddDecisionTreeSummary, context: _InMemoryExecutionContext
     ):
         match self._execute(op.source, context):
-            case Ok(source_batches):
+            case Ok(source_lfm):
                 pass
             case err:
                 return err
 
-        df_input = _as_df(source_batches)
-        dataframe = df_input[list({*op.feature_column_names, op.label_column_name})]
+        df_input = source_lfm.data.collect()
+        dataframe = df_input.select(
+            list({*op.feature_column_names, op.label_column_name})
+        )
         boolean_columns = [
             name
             for name, dtype in dataframe.schema.items()
@@ -1473,36 +1430,33 @@ class InMemoryExecutor(OpGraphExecutor):
             )
             tree_str = tree_str.replace(f"{boolean_column} > 0.50", boolean_column)
 
-        metrics = source_batches.metrics.copy()
+        metrics = source_lfm.metrics.copy()
         metrics[op.output_metric_key] = table_pb2.DecisionTreeSummary(
             text=tree_str, graphviz=tree_graphviz
         )
-        return Ok(_SchemaAndBatches.from_dataframe(df_input, metrics=metrics))
+        return Ok(_LazyFrameWithMetrics(df_input.lazy(), metrics=metrics))
 
     def _execute_unnest_list(
         self, op: op_graph.op.UnnestList, context: _InMemoryExecutionContext
     ):
-        return self._execute(op.source, context).map(
-            lambda source_batches: _SchemaAndBatches.from_dataframe(
-                _as_df(source_batches)
-                .with_columns(
-                    pl.col(op.list_column_name).list.get(i).alias(column_name)
-                    for i, column_name in enumerate(op.column_names)
-                )
-                .drop(op.list_column_name),
-                source_batches.metrics,
-            )
+        return self._compute_source_then_apply(
+            op.source,
+            lambda lf: lf.with_columns(
+                pl.col(op.list_column_name).list.get(i).alias(column_name)
+                for i, column_name in enumerate(op.column_names)
+            ).drop(op.list_column_name),
+            context,
        )
 
     def _execute_sample_rows(
         self, op: op_graph.op.SampleRows, context: _InMemoryExecutionContext
     ):
         match self._execute(op.source, context):
-            case Ok(source_batches):
+            case Ok(source_lfm):
                 pass
            case err:
                 return err
-        source_df = _as_df(source_batches)
+        source_df = source_lfm.data.collect()
         n = min(op.num_rows, source_df.shape[0])
         sample_strategy = op.sample_strategy
         match sample_strategy:
@@ -1513,9 +1467,9 @@ class InMemoryExecutor(OpGraphExecutor):
                 )
 
         return Ok(
-            _SchemaAndBatches.from_dataframe(
-                result_df,
-                source_batches.metrics,
+            _LazyFrameWithMetrics(
+                result_df.lazy(),
+                source_lfm.metrics,
             )
         )
 
@@ -1523,15 +1477,16 @@ class InMemoryExecutor(OpGraphExecutor):
         self, op: op_graph.op.DescribeColumns, context: _InMemoryExecutionContext
     ):
         match self._execute(op.source, context):
-            case Ok(source_batches):
+            case Ok(source_lfm):
                 pass
             case err:
                 return err
-        source_df = _as_df(source_batches)
+        source_df = source_lfm.data.collect()
         return Ok(
-            _SchemaAndBatches.from_dataframe(
-                source_df.describe().rename({"statistic": op.statistic_column_name}),
-                source_batches.metrics,
+            source_lfm.with_data(
+                source_df.describe()
+                .lazy()
+                .rename({"statistic": op.statistic_column_name})
             )
         )
 
@@ -1552,7 +1507,7 @@ class InMemoryExecutor(OpGraphExecutor):
         op: op_graph.Op,
         context: _InMemoryExecutionContext,
     ) -> (
-        Ok[_SchemaAndBatches]
+        Ok[_LazyFrameWithMetrics]
         | InternalError
         | ResourceExhaustedError
         | InvalidArgumentError
@@ -1572,13 +1527,12 @@ class InMemoryExecutor(OpGraphExecutor):
                 return InternalError.from_(err)
             case sql.NoRowsError() as err:
                 return Ok(
-                    _SchemaAndBatches.from_dataframe(
+                    _LazyFrameWithMetrics(
                         cast(
                             pl.DataFrame,
                             pl.from_arrow(expected_schema.empty_table()),
-                        ),
+                        ).lazy(),
                         metrics={},
-                        expected_schema=expected_schema,
                     )
                 )
             case Ok(query):
@@ -1586,10 +1540,9 @@ class InMemoryExecutor(OpGraphExecutor):
                 return self._staging_db.run_select_query(
                     query, expected_schema, context.current_slice_args
                 ).map(
-                    lambda rbr: _SchemaAndBatches.from_dataframe(
-                        _as_df(rbr, expected_schema),
+                    lambda rbr: _LazyFrameWithMetrics(
+                        _as_df(rbr, expected_schema).lazy(),
                         metrics={},
-                        expected_schema=expected_schema,
                     )
                 )
 
@@ -1681,7 +1634,7 @@ class InMemoryExecutor(OpGraphExecutor):
         op: op_graph.Op,
         context: _InMemoryExecutionContext,
     ) -> (
-        Ok[_SchemaAndBatches]
+        Ok[_LazyFrameWithMetrics]
         | InternalError
         | ResourceExhaustedError
         | InvalidArgumentError
@@ -1713,11 +1666,11 @@ class InMemoryExecutor(OpGraphExecutor):
 
         try:
             _logger.info("starting op execution")
-            maybe_batches = self._do_execute(op=op, context=context)
+            maybe_lfm = self._do_execute(op=op, context=context)
         finally:
             _logger.info("op execution complete")
-        match maybe_batches:
-            case Ok(batches):
+        match maybe_lfm:
+            case Ok(lfm):
                 pass
             case err:
                 if span:
@@ -1728,8 +1681,12 @@ class InMemoryExecutor(OpGraphExecutor):
             sliced_table in context.output_tables
             or sliced_table in context.reused_tables
         ):
-            context.computed_batches_for_op_graph[sliced_table] = batches
-        return Ok(batches)
+            # collect the lazy frame since it will be re-used to avoid
+            # re-computation
+            dataframe = lfm.data.collect()
+            lfm = _LazyFrameWithMetrics(dataframe.lazy(), lfm.metrics)
+            context.computed_batches_for_op_graph[sliced_table] = lfm
+        return Ok(lfm)
 
     def execute(
         self, context: ExecutionContext
@@ -1739,22 +1696,26 @@ class InMemoryExecutor(OpGraphExecutor):
         | InternalError
         | ResourceExhaustedError
     ):
-        in_memory_context = _InMemoryExecutionContext(context)
-
-        for table_context in context.tables_to_compute:
-            in_memory_context.current_output_context = table_context
-            sliced_table = _SlicedTable(
-                table_context.table_op_graph, table_context.sql_output_slice_args
-            )
-            if sliced_table not in in_memory_context.computed_batches_for_op_graph:
-                match self._execute(sliced_table.op_graph, in_memory_context):
-                    case Ok():
-                        pass
-                    case err:
-                        return err
+        with _InMemoryExecutionContext(context) as in_memory_context:
+            for table_context in context.tables_to_compute:
+                in_memory_context.current_output_context = table_context
+                sliced_table = _SlicedTable(
+                    table_context.table_op_graph, table_context.sql_output_slice_args
+                )
+                if sliced_table not in in_memory_context.computed_batches_for_op_graph:
+                    match self._execute(sliced_table.op_graph, in_memory_context):
+                        case Ok():
+                            pass
+                        case err:
+                            return err
+            args_lfm_iterator = in_memory_context.computed_batches_for_op_graph.items()
+            computed_tables = {
+                slice_args: _SchemaAndBatches.from_lazy_frame_with_metrics(lfm)
+                for slice_args, lfm in args_lfm_iterator
+            }
 
         return Ok(
             InMemoryExecutionResult.make(
-                self._storage_manager, in_memory_context, context
+                self._storage_manager, computed_tables, context
             )
         )
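Note: because results stored in computed_batches_for_op_graph may feed several downstream ops or output tables, _execute collects such a lazy frame once and caches a re-lazified copy, and execute only converts the cached frames to _SchemaAndBatches at the very end. The caching idea in isolation (dataframe contents are illustrative):

import polars as pl

expensive_plan = (
    pl.DataFrame({"x": list(range(1_000))})
    .lazy()
    .with_columns((pl.col("x") * 2).alias("y"))
)

# Consumers of the raw plan would each re-run the whole computation:
high = expensive_plan.filter(pl.col("y") > 100).collect()
low = expensive_plan.filter(pl.col("y") <= 100).collect()

# Collecting once and re-wrapping gives downstream consumers a cheap lazy view:
cached = expensive_plan.collect().lazy()
assert high.to_dicts() == cached.filter(pl.col("y") > 100).collect().to_dicts()
assert low.to_dicts() == cached.filter(pl.col("y") <= 100).collect().to_dicts()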