corvic-engine 0.3.0rc61-cp38-abi3-win_amd64.whl → 0.3.0rc63-cp38-abi3-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,8 +6,9 @@ import dataclasses
  import datetime
  import functools
  import math
- from collections.abc import MutableMapping
- from contextlib import nullcontext
+ from collections.abc import Callable, Mapping, MutableMapping
+ from contextlib import AbstractContextManager, ExitStack, nullcontext
+ from types import TracebackType
  from typing import Any, Final, cast

  import numpy as np
@@ -18,7 +19,7 @@ import pyarrow.parquet as pq
  import structlog
  from google.protobuf import json_format, struct_pb2
  from more_itertools import flatten
- from typing_extensions import deprecated
+ from typing_extensions import Self, deprecated

  from corvic import embed, embedding_metric, op_graph, sql
  from corvic.result import (
@@ -170,12 +171,30 @@ def _as_df(
      )


+ @dataclasses.dataclass(frozen=True)
+ class _LazyFrameWithMetrics:
+     data: pl.LazyFrame
+     metrics: dict[str, Any]
+
+     def apply(
+         self, lf_op: Callable[[pl.LazyFrame], pl.LazyFrame]
+     ) -> _LazyFrameWithMetrics:
+         return _LazyFrameWithMetrics(lf_op(self.data), self.metrics)
+
+     def with_data(self, data: pl.LazyFrame):
+         return _LazyFrameWithMetrics(data, self.metrics)
+
+
  @dataclasses.dataclass(frozen=True)
  class _SchemaAndBatches:
      schema: pa.Schema
      batches: list[pa.RecordBatch]
      metrics: dict[str, Any]

+     @classmethod
+     def from_lazy_frame_with_metrics(cls, lfm: _LazyFrameWithMetrics):
+         return cls.from_dataframe(lfm.data.collect(), lfm.metrics)
+
      def to_batch_reader(self):
          return pa.RecordBatchReader.from_batches(
              schema=self.schema,
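
A minimal standalone sketch (not part of the package) of the pattern the new _LazyFrameWithMetrics dataclass enables: a lazy polars plan travels together with its metrics dict, and transformations are applied without forcing computation. Names here are illustrative.

# Illustrative only: a (LazyFrame, metrics) pair in the spirit of
# _LazyFrameWithMetrics; work happens only when .collect() is called.
import dataclasses
from collections.abc import Callable
from typing import Any

import polars as pl


@dataclasses.dataclass(frozen=True)
class LazyWithMetrics:  # hypothetical stand-in for _LazyFrameWithMetrics
    data: pl.LazyFrame
    metrics: dict[str, Any]

    def apply(self, lf_op: Callable[[pl.LazyFrame], pl.LazyFrame]) -> "LazyWithMetrics":
        # transform the plan, carry the metrics along unchanged
        return LazyWithMetrics(lf_op(self.data), self.metrics)


lfm = LazyWithMetrics(pl.DataFrame({"x": [1, 2, 3]}).lazy(), metrics={"source": "demo"})
lfm = lfm.apply(lambda lf: lf.filter(pl.col("x") > 1)).apply(lambda lf: lf.select("x"))
print(lfm.data.collect(), lfm.metrics)
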
@@ -209,16 +228,29 @@ class _SlicedTable:


  @dataclasses.dataclass
- class _InMemoryExecutionContext:
+ class _InMemoryExecutionContext(AbstractContextManager["_InMemoryExecutionContext"]):
      exec_context: ExecutionContext
      current_output_context: TableComputeContext | None = None

      # Using _SchemaAndBatches rather than a RecordBatchReader since the latter's
      # contract only guarantees one iteration and these might be accessed more than
      # once
-     computed_batches_for_op_graph: dict[_SlicedTable, _SchemaAndBatches] = (
+     computed_batches_for_op_graph: dict[_SlicedTable, _LazyFrameWithMetrics] = (
          dataclasses.field(default_factory=dict)
      )
+     exit_stack: ExitStack = dataclasses.field(default_factory=ExitStack)
+
+     def __enter__(self) -> Self:
+         self.exit_stack = self.exit_stack.__enter__()
+         return self
+
+     def __exit__(
+         self,
+         __exc_type: type[BaseException] | None,
+         __exc_value: BaseException | None,
+         __traceback: TracebackType | None,
+     ) -> bool | None:
+         return self.exit_stack.__exit__(__exc_type, __exc_value, __traceback)

      @classmethod
      def count_source_op_uses(
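
The __enter__/__exit__ pair above only delegates to a contextlib.ExitStack, so every resource registered during execution is released when the execution context exits. A self-contained sketch of the same delegation pattern, with illustrative names:

# Illustrative sketch of the ExitStack-delegation pattern used above:
# resources registered during execution are released when the scope exits.
import contextlib
import dataclasses
from types import TracebackType


@dataclasses.dataclass
class ResourceScope(contextlib.AbstractContextManager["ResourceScope"]):
    exit_stack: contextlib.ExitStack = dataclasses.field(
        default_factory=contextlib.ExitStack
    )

    def __enter__(self) -> "ResourceScope":
        self.exit_stack = self.exit_stack.__enter__()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool | None:
        return self.exit_stack.__exit__(exc_type, exc_value, traceback)


with ResourceScope() as scope:
    handle = scope.exit_stack.enter_context(open("example.txt", "w"))
    handle.write("closed automatically when the with-block ends\n")
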
@@ -313,13 +345,13 @@ class InMemoryExecutionResult(ExecutionResult):
      def make(
          cls,
          storage_manager: StorageManager,
-         in_memory_context: _InMemoryExecutionContext,
+         computed_tables: Mapping[_SlicedTable, _SchemaAndBatches],
          context: ExecutionContext,
      ) -> InMemoryExecutionResult:
          tables = [
              InMemoryTableComputeResult(
                  storage_manager,
-                 in_memory_context.computed_batches_for_op_graph[
+                 computed_tables[
                      _SlicedTable(
                          table_context.table_op_graph,
                          table_context.sql_output_slice_args,
@@ -362,70 +394,69 @@ class InMemoryExecutor(OpGraphExecutor):

      def _execute_read_from_parquet(
          self, op: op_graph.op.ReadFromParquet, context: _InMemoryExecutionContext
-     ) -> Ok[_SchemaAndBatches]:
-         batches: list[pa.RecordBatch] = []
-         for blob_name in op.blob_names:
-             with (
-                 self._storage_manager.blob_from_url(blob_name).open("rb") as stream,
-             ):
-                 batches.extend(
-                     # reading files with pyarrow, then converting them to polars
-                     # can cause "ShapeError" bugs. That's why we're not reading this
-                     # using pyarrow.
-                     pl.read_parquet(
-                         source=stream,
-                         columns=op.arrow_schema.names,
-                         use_pyarrow=False,
-                     )
-                     .to_arrow()
-                     .to_batches()
+     ) -> Ok[_LazyFrameWithMetrics]:
+         data = cast(pl.DataFrame, pl.from_arrow(op.arrow_schema.empty_table()))
+         data = pl.scan_parquet(
+             [
+                 context.exit_stack.enter_context(
+                     self._storage_manager.blob_from_url(blob_name).open("rb")
                  )
-         return Ok(_SchemaAndBatches(op.arrow_schema, batches=batches, metrics={}))
+                 for blob_name in op.blob_names
+             ],
+             schema=data.schema,
+         )
+         return Ok(_LazyFrameWithMetrics(data, metrics={}))
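
The parquet reader now builds a lazy scan instead of materializing record batches; the opened blob streams are parked on context.exit_stack so they stay alive until the plan is collected. A hedged sketch of the lazy-scan idea, using plain file paths instead of blob streams:

# Illustrative only: a lazy parquet scan; nothing is read until .collect().
# The real change scans opened blob streams, which is why they are kept open
# on context.exit_stack until collection happens.
import polars as pl

pl.DataFrame({"id": [1, 2], "name": ["a", "b"]}).write_parquet("part-0.parquet")

lf = pl.scan_parquet(["part-0.parquet"])  # builds a plan, no I/O yet
print(lf.select("id", "name").collect())  # I/O happens here
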

      def _execute_rollup_by_aggregation(
          self, op: op_graph.op.RollupByAggregation, context: _InMemoryExecutionContext
-     ) -> Ok[_SchemaAndBatches]:
+     ) -> Ok[_LazyFrameWithMetrics]:
          raise NotImplementedError(
              "rollup by aggregation outside of sql not implemented"
          )

+     def _compute_source_then_apply(
+         self,
+         source: op_graph.Op,
+         lf_op: Callable[[pl.LazyFrame], pl.LazyFrame],
+         context: _InMemoryExecutionContext,
+     ):
+         return self._execute(source, context).map(
+             lambda source_lfm: source_lfm.apply(lf_op)
+         )
+
      def _execute_rename_columns(
          self, op: op_graph.op.RenameColumns, context: _InMemoryExecutionContext
      ):
-         return self._execute(op.source, context).map(
-             lambda source_batches: _SchemaAndBatches.from_dataframe(
-                 _as_df(source_batches).rename(dict(op.old_name_to_new)),
-                 source_batches.metrics,
-             )
+         return self._compute_source_then_apply(
+             op.source, lambda lf: lf.rename(dict(op.old_name_to_new)), context
          )

      def _execute_select_columns(
          self, op: op_graph.op.SelectColumns, context: _InMemoryExecutionContext
      ):
-         return self._execute(op.source, context).map(
-             lambda source_batches: _SchemaAndBatches.from_dataframe(
-                 _as_df(source_batches).select(op.columns), source_batches.metrics
-             )
+         return self._compute_source_then_apply(
+             op.source, lambda lf: lf.select(op.columns), context
          )

      def _execute_limit_rows(
          self, op: op_graph.op.LimitRows, context: _InMemoryExecutionContext
      ):
-         return self._execute(op.source, context).map(
-             lambda source_batches: _SchemaAndBatches.from_dataframe(
-                 _as_df(source_batches).limit(op.num_rows),
-                 source_batches.metrics,
-             )
+         return self._compute_source_then_apply(
+             op.source, lambda lf: lf.limit(op.num_rows), context
+         )
+
+     def _execute_offset_rows(
+         self, op: op_graph.op.OffsetRows, context: _InMemoryExecutionContext
+     ):
+         return self._compute_source_then_apply(
+             op.source, lambda lf: lf.slice(op.num_rows), context
          )

      def _execute_order_by(
          self, op: op_graph.op.OrderBy, context: _InMemoryExecutionContext
      ):
-         return self._execute(op.source, context).map(
-             lambda source_batches: _SchemaAndBatches.from_dataframe(
-                 _as_df(source_batches).sort(op.columns, descending=op.desc),
-                 source_batches.metrics,
-             )
+         return self._compute_source_then_apply(
+             op.source, lambda lf: lf.sort(op.columns, descending=op.desc), context
          )

      def _row_filter_literal_comparison_to_condition(
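
Most per-op executors above now reduce to "execute the source, then apply one LazyFrame-to-LazyFrame function", which is what _compute_source_then_apply factors out. An illustrative sketch of that shape; the Ok/Result plumbing of the real helper is omitted:

# Illustrative only: the "compute the source, then apply one lazy op" shape
# shared by rename/select/limit/offset/order.
from collections.abc import Callable

import polars as pl


def compute_then_apply(
    source: pl.LazyFrame, lf_op: Callable[[pl.LazyFrame], pl.LazyFrame]
) -> pl.LazyFrame:
    # the real helper maps lf_op over the Ok result of executing the source op
    return lf_op(source)


source = pl.DataFrame({"a": [3, 1, 2]}).lazy()
plan = compute_then_apply(source, lambda lf: lf.sort("a").limit(2))
print(plan.collect())
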
@@ -493,34 +524,31 @@ class InMemoryExecutor(OpGraphExecutor):
      def _execute_filter_rows(
          self, op: op_graph.op.FilterRows, context: _InMemoryExecutionContext
      ):
-         match self._execute(op.source, context):
-             case Ok(source_batches):
-                 return self._row_filter_to_condition(op.row_filter).map_or_else(
-                     lambda err: InternalError.from_(err),
-                     lambda row_filter: Ok(
-                         _SchemaAndBatches.from_dataframe(
-                             _as_df(source_batches).filter(row_filter),
-                             source_batches.metrics,
-                         )
-                     ),
-                 )
-             case err:
-                 return err
+         match self._row_filter_to_condition(op.row_filter):
+             case op_graph.OpParseError() as err:
+                 return InternalError.from_(err)
+             case Ok(row_filter):
+                 pass
+         return self._compute_source_then_apply(
+             op.source, lambda lf: lf.filter(row_filter), context
+         )

      def _execute_embedding_metrics(  # noqa: C901
          self, op: op_graph.op.EmbeddingMetrics, context: _InMemoryExecutionContext
      ):
          match self._execute(op.table, context):
-             case Ok(source_batches):
+             case Ok(source_lfm):
                  pass
              case err:
                  return err
-         embedding_df = _as_df(source_batches)
+         embedding_df = source_lfm.data.collect()

          if len(embedding_df) < _MIN_EMBEDDINGS_FOR_EMBEDDINGS_SUMMARY:
              # downstream consumers handle empty metadata by substituting their
              # own values
-             return Ok(source_batches)
+             return Ok(
+                 _LazyFrameWithMetrics(embedding_df.lazy(), metrics=source_lfm.metrics)
+             )

          # before it was configurable, this op assumed that the column's name was
          # this hardcoded name
@@ -531,7 +559,7 @@ class InMemoryExecutor(OpGraphExecutor):
              case InvalidArgumentError() as err:
                  return InternalError.from_(err)

-         metrics = source_batches.metrics.copy()
+         metrics = source_lfm.metrics.copy()
          match embedding_metric.ne_sum(embedding, normalize=True):
              case Ok(metric):
                  metrics["ne_sum"] = metric
@@ -554,17 +582,17 @@ class InMemoryExecutor(OpGraphExecutor):
                  metrics["stable_rank"] = metric
              case InvalidArgumentError() as err:
                  _logger.warning("could not compute stable_rank", exc_info=str(err))
-         return Ok(_SchemaAndBatches.from_dataframe(embedding_df, metrics=metrics))
+         return Ok(_LazyFrameWithMetrics(embedding_df.lazy(), metrics=metrics))

      def _execute_embedding_coordinates(
          self, op: op_graph.op.EmbeddingCoordinates, context: _InMemoryExecutionContext
      ):
          match self._execute(op.table, context):
-             case Ok(source_batches):
+             case Ok(source_lfm):
                  pass
              case err:
                  return err
-         embedding_df = _as_df(source_batches)
+         embedding_df = source_lfm.data.collect()

          # before it was configurable, this op assumed that the column's name was
          # this hardcoded name
@@ -573,16 +601,14 @@ class InMemoryExecutor(OpGraphExecutor):
          # the neighbors of a point includes itself. That does mean, that an n_neighbors
          # value of less than 3 simply does not work
          if len(embedding_df) < _MIN_EMBEDDINGS_FOR_EMBEDDINGS_SUMMARY:
-             coordinates_df = embedding_df.with_columns(
+             coordinates_df = embedding_df.lazy().with_columns(
                  pl.Series(
                      name=embedding_column_name,
                      values=[[0.0] * op.n_components] * len(embedding_df),
                      dtype=pl.List(pl.Float32),
                  )
              )
-             return Ok(
-                 _SchemaAndBatches.from_dataframe(coordinates_df, source_batches.metrics)
-             )
+             return Ok(_LazyFrameWithMetrics(coordinates_df, source_lfm.metrics))

          match get_polars_embedding(embedding_df, embedding_column_name):
              case Ok(embedding):
@@ -598,39 +624,37 @@ class InMemoryExecutor(OpGraphExecutor):
              case InvalidArgumentError() as err:
                  raise err

-         coordinates_df = embedding_df.with_columns(
+         coordinates_df = embedding_df.lazy().with_columns(
              pl.Series(
                  name=embedding_column_name,
                  values=coordinates,
                  dtype=pl.List(pl.Float32),
              )
          )
-         return Ok(
-             _SchemaAndBatches.from_dataframe(coordinates_df, source_batches.metrics)
-         )
+         return Ok(_LazyFrameWithMetrics(coordinates_df, source_lfm.metrics))

      def _execute_distinct_rows(
          self, op: op_graph.op.DistinctRows, context: _InMemoryExecutionContext
      ):
          return self._execute(op.source, context).map(
-             lambda source_batches: _SchemaAndBatches.from_dataframe(
-                 _as_df(source_batches).unique(), source_batches.metrics
+             lambda source_lfm: _LazyFrameWithMetrics(
+                 source_lfm.data.unique(), source_lfm.metrics
              )
          )

      def _execute_join(self, op: op_graph.op.Join, context: _InMemoryExecutionContext):
          match self._execute(op.left_source, context):
-             case Ok(left_batches):
+             case Ok(left_lfm):
                  pass
              case err:
                  return err
          match self._execute(op.right_source, context):
-             case Ok(right_batches):
+             case Ok(right_lfm):
                  pass
              case err:
                  return err
-         left_df = _as_df(left_batches)
-         right_df = _as_df(right_batches)
+         left_lf = left_lfm.data
+         right_lf = right_lfm.data

          match op.how:
              case table_pb2.JOIN_TYPE_INNER:
@@ -641,32 +665,20 @@ class InMemoryExecutor(OpGraphExecutor):
                  join_type = "inner"

          # in our join semantics we drop columns from the right source on conflict
-         right_df = right_df.select(
+         right_lf = right_lf.select(
              [
                  col
-                 for col in right_df.columns
-                 if col in op.right_join_columns or col not in left_df.columns
+                 for col in right_lf.columns
+                 if col in op.right_join_columns or col not in left_lf.columns
              ]
          )
-         metrics = right_batches.metrics.copy()
-         metrics.update(left_batches.metrics)
-
-         # polars doesn't behave so well when one side is empty, just
-         # compute the trivial empty join when the result is guaranteed
-         # to be empty instead.
-         if len(left_df) == 0 or len(right_df) == 0 and join_type == "inner":
-             return Ok(
-                 _SchemaAndBatches(
-                     schema=op.schema.to_arrow(),
-                     batches=op.schema.to_arrow().empty_table().to_batches(),
-                     metrics=metrics,
-                 )
-             )
+         metrics = right_lfm.metrics.copy()
+         metrics.update(left_lfm.metrics)

          return Ok(
-             _SchemaAndBatches.from_dataframe(
-                 left_df.join(
-                     right_df,
+             _LazyFrameWithMetrics(
+                 left_lf.join(
+                     right_lf,
                      left_on=op.left_join_columns,
                      right_on=op.right_join_columns,
                      how=join_type,
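
A small illustrative sketch of the join semantics noted in the hunk: when column names collide, the conflicting column is dropped from the right side before the lazy join so the left table's values win. Table contents and names are invented.

# Illustrative only: drop conflicting right-side columns before a lazy join.
import polars as pl

left = pl.DataFrame({"id": [1, 2], "name": ["a", "b"]}).lazy()
right = pl.DataFrame({"id": [1, 2], "name": ["x", "y"], "score": [0.5, 0.9]}).lazy()

right_join_columns = ["id"]
right = right.select(
    [
        col
        for col in right.columns
        if col in right_join_columns or col not in left.columns
    ]
)
print(left.join(right, left_on="id", right_on="id", how="inner").collect())
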
@@ -676,62 +688,47 @@ class InMemoryExecutor(OpGraphExecutor):
          )

      def _execute_empty(self, op: op_graph.op.Empty, context: _InMemoryExecutionContext):
-         empty_table = pa.schema([]).empty_table()
-         return Ok(
-             _SchemaAndBatches(empty_table.schema, empty_table.to_batches(), metrics={})
-         )
+         empty_table = cast(pl.DataFrame, pl.from_arrow(pa.schema([]).empty_table()))
+         return Ok(_LazyFrameWithMetrics(empty_table.lazy(), metrics={}))

      def _execute_concat(
          self, op: op_graph.op.Concat, context: _InMemoryExecutionContext
      ):
-         source_batches = list[_SchemaAndBatches]()
+         source_lfms = list[_LazyFrameWithMetrics]()
          for table in op.tables:
              match self._execute(table, context):
                  case Ok(batches):
-                     source_batches.append(batches)
+                     source_lfms.append(batches)
                  case err:
                      return err
-         dataframes = [_as_df(batches) for batches in source_batches]
+         data = pl.concat([lfm.data for lfm in source_lfms], how=op.how)
          metrics = dict[str, Any]()
-         for batches in source_batches:
-             metrics.update(batches.metrics)
-         return Ok(
-             _SchemaAndBatches.from_dataframe(
-                 pl.concat(dataframes, how=op.how), metrics=metrics
-             )
-         )
+         for lfm in source_lfms:
+             metrics.update(lfm.metrics)
+         return Ok(_LazyFrameWithMetrics(data, metrics=metrics))

      def _execute_unnest_struct(
          self, op: op_graph.op.UnnestStruct, context: _InMemoryExecutionContext
      ):
-         return self._execute(op.source, context).map(
-             lambda source_batches: _SchemaAndBatches.from_dataframe(
-                 _as_df(source_batches).unnest(op.struct_column_name),
-                 source_batches.metrics,
-             )
+         return self._compute_source_then_apply(
+             op.source, lambda lf: lf.unnest(op.struct_column_name), context
          )

      def _execute_nest_into_struct(
          self, op: op_graph.op.NestIntoStruct, context: _InMemoryExecutionContext
      ):
-         match self._execute(op.source, context):
-             case Ok(source_batches):
-                 pass
-             case err:
-                 return err
          non_struct_columns = [
-             name
-             for name in source_batches.schema.names
-             if name not in op.column_names_to_nest
+             field.name
+             for field in op.source.schema
+             if field.name not in op.column_names_to_nest
          ]
-         return Ok(
-             _SchemaAndBatches.from_dataframe(
-                 _as_df(source_batches).select(
-                     *non_struct_columns,
-                     pl.struct(op.column_names_to_nest).alias(op.struct_column_name),
-                 ),
-                 source_batches.metrics,
-             )
+         return self._compute_source_then_apply(
+             op.source,
+             lambda lf: lf.select(
+                 *non_struct_columns,
+                 pl.struct(op.column_names_to_nest).alias(op.struct_column_name),
+             ),
+             context,
          )

      def _execute_add_literal_column(
@@ -748,57 +745,49 @@ class InMemoryExecutor(OpGraphExecutor):
          else:
              column = pl.Series(name, literals).cast(dtype)

-         def do_work(source_batches: _SchemaAndBatches):
-             return _SchemaAndBatches.from_dataframe(
-                 _as_df(source_batches).with_columns(column),
-                 source_batches.metrics,
-             )
-
-         return self._execute(op.source, context).map(do_work)
+         return self._compute_source_then_apply(
+             op.source,
+             lambda lf: lf.with_columns(column),
+             context,
+         )

      def _execute_combine_columns(
          self, op: op_graph.op.CombineColumns, context: _InMemoryExecutionContext
      ):
-         match self._execute(op.source, context):
-             case Ok(source_batches):
-                 pass
-             case err:
-                 return err
-         source_df = _as_df(source_batches)
          match op.reduction:
-             case op_graph.ConcatString():
+             case op_graph.ConcatString() as reduction:
                  # if we do not ignore nulls then all concatenated rows that
                  # have a single column that contain a null value will be output
                  # as null.
-                 result_df = source_df.with_columns(
-                     pl.concat_str(
-                         [pl.col(col) for col in op.column_names],
-                         separator=op.reduction.separator,
-                         ignore_nulls=True,
-                     ).alias(op.combined_column_name)
-                 )
+                 concat_expr = pl.concat_str(
+                     [pl.col(col) for col in op.column_names],
+                     separator=reduction.separator,
+                     ignore_nulls=True,
+                 ).alias(op.combined_column_name)

              case op_graph.ConcatList():
                  if op.column_names:
-                     result_df = source_df.with_columns(
-                         pl.concat_list(*op.column_names).alias(op.combined_column_name)
+                     concat_expr = pl.concat_list(*op.column_names).alias(
+                         op.combined_column_name
                      )
                  else:
-                     result_df = source_df.with_columns(
-                         pl.Series(op.combined_column_name, [])
-                     )
+                     concat_expr = pl.Series(op.combined_column_name, [])

-         return Ok(_SchemaAndBatches.from_dataframe(result_df, source_batches.metrics))
+         return self._compute_source_then_apply(
+             op.source,
+             lambda lf: lf.with_columns(concat_expr),
+             context,
+         )

      def _execute_embed_column(
          self, op: op_graph.op.EmbedColumn, context: _InMemoryExecutionContext
      ):
          match self._execute(op.source, context):
-             case Ok(source_batches):
+             case Ok(source_lfm):
                  pass
              case err:
                  return err
-         source_df = _as_df(source_batches)
+         source_df = source_lfm.data.collect()
          to_embed = source_df[op.column_name].cast(pl.String())

          embed_context = EmbedTextContext(
@@ -815,17 +804,14 @@ class InMemoryExecutor(OpGraphExecutor):
              case InvalidArgumentError() | InternalError() as err:
                  raise InternalError("Failed to embed column") from err

-         result_df = source_df.with_columns(
-             result.embeddings.alias(op.embedding_column_name)
-         ).drop_nulls(op.embedding_column_name)
-
-         return Ok(
-             _SchemaAndBatches.from_dataframe(
-                 result_df,
-                 source_batches.metrics,
-             )
+         result_df = (
+             source_df.lazy()
+             .with_columns(result.embeddings.alias(op.embedding_column_name))
+             .drop_nulls(op.embedding_column_name)
          )

+         return Ok(source_lfm.with_data(result_df))
+
      @staticmethod
      def get_cyclic_encoding(
          series: pl.Series,
@@ -942,12 +928,12 @@ class InMemoryExecutor(OpGraphExecutor):
          self, op: op_graph.op.EncodeColumns, context: _InMemoryExecutionContext
      ):
          match self._execute(op.source, context):
-             case Ok(source_batches):
+             case Ok(source_lfm):
                  pass
              case err:
                  return err
-         source_df = _as_df(source_batches)
-         metrics = source_batches.metrics.copy()
+         source_df = source_lfm.data.collect()
+         metrics = source_lfm.metrics.copy()
          metric = metrics.get("one_hot_encoder", {})
          for encoder_arg in op.encoded_columns:
              to_encode = source_df[encoder_arg.column_name]
@@ -1074,8 +1060,8 @@ class InMemoryExecutor(OpGraphExecutor):
              )
          metrics["one_hot_encoder"] = metric
          return Ok(
-             _SchemaAndBatches.from_dataframe(
-                 source_df,
+             _LazyFrameWithMetrics(
+                 source_df.lazy(),
                  metrics,
              )
          )
@@ -1112,43 +1098,40 @@ class InMemoryExecutor(OpGraphExecutor):

          metrics = dict[str, Any]()

-         edge_list_batches = list[_SchemaAndBatches]()
+         edge_list_lfms = list[_LazyFrameWithMetrics]()
          for edge_list in op.edge_list_tables:
              match self._execute(edge_list.table, context):
-                 case Ok(source_batches):
-                     edge_list_batches.append(source_batches)
+                 case Ok(source_lfm):
+                     edge_list_lfms.append(source_lfm)
                  case err:
                      return err

          def edge_generator():
-             for edge_list, batches in zip(
-                 op.edge_list_tables, edge_list_batches, strict=True
-             ):
+             for edge_list, lfm in zip(op.edge_list_tables, edge_list_lfms, strict=True):
                  start_column_name = edge_list.start_column_name
                  end_column_name = edge_list.end_column_name
                  start_column_type_name = entities_dtypes[start_column_name]
                  end_column_type_name = entities_dtypes[end_column_name]
-                 metrics.update(batches.metrics)
-                 for batch in batches.batches:
-                     yield (
-                         _as_df(batch)
-                         .with_columns(
-                             pl.col(edge_list.start_column_name).alias(
-                                 f"start_id_{start_column_type_name}"
-                             ),
-                             pl.lit(edge_list.start_entity_name).alias("start_source"),
-                             pl.col(edge_list.end_column_name).alias(
-                                 f"end_id_{end_column_type_name}"
-                             ),
-                             pl.lit(edge_list.end_entity_name).alias("end_source"),
-                         )
-                         .select(
-                             f"start_id_{start_column_type_name}",
-                             "start_source",
-                             f"end_id_{end_column_type_name}",
-                             "end_source",
-                         )
+                 metrics.update(lfm.metrics)
+                 yield (
+                     lfm.data.with_columns(
+                         pl.col(edge_list.start_column_name).alias(
+                             f"start_id_{start_column_type_name}"
+                         ),
+                         pl.lit(edge_list.start_entity_name).alias("start_source"),
+                         pl.col(edge_list.end_column_name).alias(
+                             f"end_id_{end_column_type_name}"
+                         ),
+                         pl.lit(edge_list.end_entity_name).alias("end_source"),
+                     )
+                     .select(
+                         f"start_id_{start_column_type_name}",
+                         "start_source",
+                         f"end_id_{end_column_type_name}",
+                         "end_source",
                      )
+                     .collect()
+                 )

          edges = pl.concat(
              [
@@ -1177,18 +1160,17 @@ class InMemoryExecutor(OpGraphExecutor):
              negative=op.negative,
          )
          n2v_runner.train(epochs=op.epochs)
-         return Ok(_SchemaAndBatches.from_dataframe(n2v_runner.wv.to_polars(), metrics))
+         return Ok(_LazyFrameWithMetrics(n2v_runner.wv.to_polars().lazy(), metrics))

      def _execute_aggregate_columns(
          self, op: op_graph.op.AggregateColumns, context: _InMemoryExecutionContext
      ):
          match self._execute(op.source, context):
-             case Ok(source_batches):
+             case Ok(source_lfm):
                  pass
              case err:
                  return err
-         source_df = _as_df(source_batches)
-         to_aggregate = source_df[op.column_names]
+         to_aggregate = source_lfm.data.select(op.column_names)

          match op.aggregation:
              case op_graph.aggregation.Min():
@@ -1206,106 +1188,92 @@ class InMemoryExecutor(OpGraphExecutor):
              case op_graph.aggregation.NullCount():
                  aggregate = to_aggregate.null_count()

-         return Ok(_SchemaAndBatches.from_dataframe(aggregate, metrics={}))
+         return Ok(source_lfm.with_data(aggregate))

      def _execute_correlate_columns(
          self, op: op_graph.op.CorrelateColumns, context: _InMemoryExecutionContext
      ):
          match self._execute(op.source, context):
-             case Ok(source_batches):
+             case Ok(source_lfm):
                  pass
              case err:
                  return err
-         source_df = _as_df(source_batches)
+         source_df = source_lfm.data.collect()
          with np.errstate(invalid="ignore"):
-             corr_df = source_df[op.column_names].corr(dtype="float32")
+             corr_df = source_df.select(op.column_names).corr(dtype="float32")

-         return Ok(
-             _SchemaAndBatches.from_dataframe(
-                 corr_df,
-                 metrics={},
-             )
-         )
+         return Ok(source_lfm.with_data(corr_df.lazy()))

      def _execute_histogram_column(
          self, op: op_graph.op.HistogramColumn, context: _InMemoryExecutionContext
      ):
-         return self._execute(op.source, context).map(
-             lambda source_batches: _SchemaAndBatches.from_dataframe(
-                 _as_df(source_batches)[op.column_name]
-                 .hist(include_category=False)
-                 .rename(
-                     {
-                         "breakpoint": op.breakpoint_column_name,
-                         "count": op.count_column_name,
-                     }
-                 ),
-                 metrics={},
-             )
+         return self._compute_source_then_apply(
+             op.source,
+             lambda lf: lf.collect()[op.column_name]
+             .hist(include_category=False)
+             .lazy()
+             .rename(
+                 {
+                     "breakpoint": op.breakpoint_column_name,
+                     "count": op.count_column_name,
+                 }
+             ),
+             context,
          )

      def _execute_convert_column_to_string(
          self, op: op_graph.op.ConvertColumnToString, context: _InMemoryExecutionContext
      ):
-         match self._execute(op.source, context):
-             case Ok(source_batches):
-                 pass
-             case err:
-                 return err
-         source_df = _as_df(source_batches)
-         column = source_df[op.column_name]
-         if not column.dtype.is_nested():
-             source_df = source_df.with_columns(column.cast(pl.String(), strict=False))
-         elif isinstance(column.dtype, pl.Array | pl.List):
-             source_df = source_df.with_columns(
-                 column.cast(pl.List(pl.String())).list.join(",")
-             )
+         dtype = op.source.schema.to_polars()[op.column_name]
+         if not dtype.is_nested():
+             cast_expr = pl.col(op.column_name).cast(pl.String(), strict=False)
+         elif isinstance(dtype, pl.Array | pl.List):
+             cast_expr = pl.col(op.column_name).cast(pl.List(pl.String())).list.join(",")
          else:
              raise NotImplementedError(
                  "converting struct columns to strings is not implemented"
              )
-         return Ok(
-             _SchemaAndBatches.from_dataframe(source_df, metrics=source_batches.metrics)
+         return self._compute_source_then_apply(
+             op.source, lambda lf: lf.collect().with_columns(cast_expr).lazy(), context
          )

      def _execute_add_row_index(
          self, op: op_graph.op.AddRowIndex, context: _InMemoryExecutionContext
      ):
-         return self._execute(op.source, context).map(
-             lambda source_batches: _SchemaAndBatches.from_dataframe(
-                 _as_df(source_batches)
-                 .with_row_index(name=op.row_index_column_name, offset=op.offset)
-                 .with_columns(pl.col(op.row_index_column_name).cast(pl.UInt64())),
-                 metrics=source_batches.metrics,
-             )
+         return self._compute_source_then_apply(
+             op.source,
+             lambda lf: lf.with_row_index(
+                 name=op.row_index_column_name, offset=op.offset
+             ).with_columns(pl.col(op.row_index_column_name).cast(pl.UInt64())),
+             context,
          )

      def _execute_output_csv(
          self, op: op_graph.op.OutputCsv, context: _InMemoryExecutionContext
      ):
          match self._execute(op.source, context):
-             case Ok(source_batches):
+             case Ok(source_lfm):
                  pass
              case err:
                  return err
-         source_df = _as_df(source_batches)
+         source_df = source_lfm.data.collect()
          source_df.write_csv(
              op.csv_url,
              quote_style="never",
              include_header=op.include_header,
          )
-         return Ok(source_batches)
+         return Ok(source_lfm.with_data(source_df.lazy()))

      def _execute_truncate_list(
          self, op: op_graph.op.TruncateList, context: _InMemoryExecutionContext
      ):
          # TODO(Patrick): verify this approach works for arrays
          match self._execute(op.source, context):
-             case Ok(source_batches):
+             case Ok(source_lfm):
                  pass
              case err:
                  return err
-         source_df = _as_df(source_batches)
+         source_df = source_lfm.data.collect()
          if len(source_df):
              existing_length = get_polars_embedding_length(
                  source_df, op.column_name
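
Several ops above (the histogram, for example) need a Series-level method, so the lambda collects, works eagerly, and returns to a LazyFrame. A hedged sketch of that eager round-trip; the breakpoint/count column names follow the rename in the hunk and depend on the installed polars version.

# Illustrative only: collect, use a Series-level method, then go back to lazy.
import polars as pl

lf = pl.DataFrame({"value": [1.0, 2.0, 2.5, 3.0, 10.0]}).lazy()

hist_lf = (
    lf.collect()["value"]
    .hist(include_category=False)
    .lazy()
    .rename({"breakpoint": "bucket_end", "count": "rows"})
)
print(hist_lf.collect())
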
@@ -1326,6 +1294,7 @@ class InMemoryExecutor(OpGraphExecutor):
          else:
              return InternalError("unexpected type", cause="expected list or array type")

+         source_df = source_df.lazy()
          if head_length < op.target_column_length:
              padding_length = op.target_column_length - head_length
              padding = [op.padding_value_as_py] * padding_length
@@ -1337,16 +1306,14 @@ class InMemoryExecutor(OpGraphExecutor):
                  .list.to_array(width=op.target_column_length)
                  .cast(pl.List(inner_type))
              )
-         return Ok(
-             _SchemaAndBatches.from_dataframe(source_df, metrics=source_batches.metrics)
-         )
+         return Ok(source_lfm.with_data(source_df))

      def _execute_union(self, op: op_graph.op.Union, context: _InMemoryExecutionContext):
-         sources = list[_SchemaAndBatches]()
+         sources = list[_LazyFrameWithMetrics]()
          for source in op.sources():
              match self._execute(source, context):
-                 case Ok(source_df):
-                     sources.append(source_df)
+                 case Ok(source_lfm):
+                     sources.append(source_lfm)
                  case err:
                      return err

@@ -1354,20 +1321,20 @@ class InMemoryExecutor(OpGraphExecutor):
          for src in sources:
              metrics.update(src.metrics)

-         result_df = pl.concat((_as_df(src) for src in sources), how="vertical_relaxed")
+         result_lf = pl.concat((src.data for src in sources), how="vertical_relaxed")
          if op.distinct:
-             result_df = result_df.unique()
-         return Ok(_SchemaAndBatches.from_dataframe(result_df, metrics=metrics))
+             result_lf = result_lf.unique()
+         return Ok(_LazyFrameWithMetrics(result_lf, metrics=metrics))

      def _execute_embed_image_column(
          self, op: op_graph.op.EmbedImageColumn, context: _InMemoryExecutionContext
      ):
          match self._execute(op.source, context):
-             case Ok(source_batches):
+             case Ok(source_lfm):
                  pass
              case err:
                  return err
-         source_df = _as_df(source_batches)
+         source_df = source_lfm.data.collect()
          to_embed = source_df[op.column_name].cast(pl.Binary())

          embed_context = EmbedImageContext(
@@ -1382,14 +1349,12 @@ class InMemoryExecutor(OpGraphExecutor):
              case InvalidArgumentError() | InternalError() as err:
                  raise InternalError("Failed to embed column") from err

-         result_df = source_df.with_columns(
-             result.embeddings.alias(op.embedding_column_name)
-         ).drop_nulls(op.embedding_column_name)
-
          return Ok(
-             _SchemaAndBatches.from_dataframe(
-                 result_df,
-                 source_batches.metrics,
+             _LazyFrameWithMetrics(
+                 source_df.lazy()
+                 .with_columns(result.embeddings.alias(op.embedding_column_name))
+                 .drop_nulls(op.embedding_column_name),
+                 source_lfm.metrics,
              )
          )

@@ -1397,13 +1362,15 @@ class InMemoryExecutor(OpGraphExecutor):
          self, op: op_graph.op.AddDecisionTreeSummary, context: _InMemoryExecutionContext
      ):
          match self._execute(op.source, context):
-             case Ok(source_batches):
+             case Ok(source_lfm):
                  pass
              case err:
                  return err

-         df_input = _as_df(source_batches)
-         dataframe = df_input[list({*op.feature_column_names, op.label_column_name})]
+         df_input = source_lfm.data.collect()
+         dataframe = df_input.select(
+             list({*op.feature_column_names, op.label_column_name})
+         )
          boolean_columns = [
              name
              for name, dtype in dataframe.schema.items()
@@ -1463,36 +1430,33 @@ class InMemoryExecutor(OpGraphExecutor):
              )
              tree_str = tree_str.replace(f"{boolean_column} > 0.50", boolean_column)

-         metrics = source_batches.metrics.copy()
+         metrics = source_lfm.metrics.copy()
          metrics[op.output_metric_key] = table_pb2.DecisionTreeSummary(
              text=tree_str, graphviz=tree_graphviz
          )
-         return Ok(_SchemaAndBatches.from_dataframe(df_input, metrics=metrics))
+         return Ok(_LazyFrameWithMetrics(df_input.lazy(), metrics=metrics))

      def _execute_unnest_list(
          self, op: op_graph.op.UnnestList, context: _InMemoryExecutionContext
      ):
-         return self._execute(op.source, context).map(
-             lambda source_batches: _SchemaAndBatches.from_dataframe(
-                 _as_df(source_batches)
-                 .with_columns(
-                     pl.col(op.list_column_name).list.get(i).alias(column_name)
-                     for i, column_name in enumerate(op.column_names)
-                 )
-                 .drop(op.list_column_name),
-                 source_batches.metrics,
-             )
+         return self._compute_source_then_apply(
+             op.source,
+             lambda lf: lf.with_columns(
+                 pl.col(op.list_column_name).list.get(i).alias(column_name)
+                 for i, column_name in enumerate(op.column_names)
+             ).drop(op.list_column_name),
+             context,
          )

      def _execute_sample_rows(
          self, op: op_graph.op.SampleRows, context: _InMemoryExecutionContext
      ):
          match self._execute(op.source, context):
-             case Ok(source_batches):
+             case Ok(source_lfm):
                  pass
              case err:
                  return err
-         source_df = _as_df(source_batches)
+         source_df = source_lfm.data.collect()
          n = min(op.num_rows, source_df.shape[0])
          sample_strategy = op.sample_strategy
          match sample_strategy:
@@ -1503,9 +1467,9 @@ class InMemoryExecutor(OpGraphExecutor):
                  )

          return Ok(
-             _SchemaAndBatches.from_dataframe(
-                 result_df,
-                 source_batches.metrics,
+             _LazyFrameWithMetrics(
+                 result_df.lazy(),
+                 source_lfm.metrics,
              )
          )

@@ -1513,15 +1477,16 @@ class InMemoryExecutor(OpGraphExecutor):
          self, op: op_graph.op.DescribeColumns, context: _InMemoryExecutionContext
      ):
          match self._execute(op.source, context):
-             case Ok(source_batches):
+             case Ok(source_lfm):
                  pass
              case err:
                  return err
-         source_df = _as_df(source_batches)
+         source_df = source_lfm.data.collect()
          return Ok(
-             _SchemaAndBatches.from_dataframe(
-                 source_df.describe().rename({"statistic": op.statistic_column_name}),
-                 source_batches.metrics,
+             source_lfm.with_data(
+                 source_df.describe()
+                 .lazy()
+                 .rename({"statistic": op.statistic_column_name})
              )
          )

@@ -1542,7 +1507,7 @@ class InMemoryExecutor(OpGraphExecutor):
          op: op_graph.Op,
          context: _InMemoryExecutionContext,
      ) -> (
-         Ok[_SchemaAndBatches]
+         Ok[_LazyFrameWithMetrics]
          | InternalError
          | ResourceExhaustedError
          | InvalidArgumentError
@@ -1562,13 +1527,12 @@ class InMemoryExecutor(OpGraphExecutor):
                  return InternalError.from_(err)
              case sql.NoRowsError() as err:
                  return Ok(
-                     _SchemaAndBatches.from_dataframe(
+                     _LazyFrameWithMetrics(
                          cast(
                              pl.DataFrame,
                              pl.from_arrow(expected_schema.empty_table()),
-                         ),
+                         ).lazy(),
                          metrics={},
-                         expected_schema=expected_schema,
                      )
                  )
              case Ok(query):
@@ -1576,10 +1540,9 @@ class InMemoryExecutor(OpGraphExecutor):
              return self._staging_db.run_select_query(
                  query, expected_schema, context.current_slice_args
              ).map(
-                 lambda rbr: _SchemaAndBatches.from_dataframe(
-                     _as_df(rbr, expected_schema),
+                 lambda rbr: _LazyFrameWithMetrics(
+                     _as_df(rbr, expected_schema).lazy(),
                      metrics={},
-                     expected_schema=expected_schema,
                  )
              )
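
The NoRowsError branch above builds an empty but correctly typed LazyFrame straight from the expected Arrow schema. A minimal sketch of that conversion with an invented schema:

# Illustrative only: an empty, correctly typed LazyFrame from a pyarrow schema.
from typing import cast

import polars as pl
import pyarrow as pa

expected_schema = pa.schema([("id", pa.int64()), ("name", pa.string())])
empty_lf = cast(pl.DataFrame, pl.from_arrow(expected_schema.empty_table())).lazy()
print(empty_lf.collect())  # zero rows, but the dtypes match the schema
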

@@ -1600,6 +1563,8 @@ class InMemoryExecutor(OpGraphExecutor):
                  return self._execute_select_columns(op, context)
              case op_graph.op.LimitRows():
                  return self._execute_limit_rows(op, context)
+             case op_graph.op.OffsetRows():
+                 return self._execute_offset_rows(op, context)
              case op_graph.op.OrderBy():
                  return self._execute_order_by(op, context)
              case op_graph.op.FilterRows():
@@ -1669,7 +1634,7 @@ class InMemoryExecutor(OpGraphExecutor):
          op: op_graph.Op,
          context: _InMemoryExecutionContext,
      ) -> (
-         Ok[_SchemaAndBatches]
+         Ok[_LazyFrameWithMetrics]
          | InternalError
          | ResourceExhaustedError
          | InvalidArgumentError
@@ -1701,11 +1666,11 @@ class InMemoryExecutor(OpGraphExecutor):

          try:
              _logger.info("starting op execution")
-             maybe_batches = self._do_execute(op=op, context=context)
+             maybe_lfm = self._do_execute(op=op, context=context)
          finally:
              _logger.info("op execution complete")
-         match maybe_batches:
-             case Ok(batches):
+         match maybe_lfm:
+             case Ok(lfm):
                  pass
              case err:
                  if span:
@@ -1716,8 +1681,12 @@ class InMemoryExecutor(OpGraphExecutor):
              sliced_table in context.output_tables
              or sliced_table in context.reused_tables
          ):
-             context.computed_batches_for_op_graph[sliced_table] = batches
-         return Ok(batches)
+             # collect the lazy frame since it will be re-used to avoid
+             # re-computation
+             dataframe = lfm.data.collect()
+             lfm = _LazyFrameWithMetrics(dataframe.lazy(), lfm.metrics)
+             context.computed_batches_for_op_graph[sliced_table] = lfm
+         return Ok(lfm)

      def execute(
          self, context: ExecutionContext
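
Results destined for reuse are collected once and wrapped back into a LazyFrame, so downstream consumers share the materialized data instead of re-running the plan. A small sketch of that collect-then-relazify caching step:

# Illustrative only: run a plan once, then hand out a cheap lazy wrapper over
# the materialized result so later consumers do not recompute it.
import polars as pl

plan = pl.DataFrame({"x": [1, 2, 3]}).lazy().with_columns((pl.col("x") * 2).alias("y"))

dataframe = plan.collect()  # executed exactly once
cached = dataframe.lazy()   # re-used by every later consumer
print(cached.filter(pl.col("y") > 2).collect())
print(cached.select("x").collect())
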
@@ -1727,22 +1696,26 @@ class InMemoryExecutor(OpGraphExecutor):
          | InternalError
          | ResourceExhaustedError
      ):
-         in_memory_context = _InMemoryExecutionContext(context)
-
-         for table_context in context.tables_to_compute:
-             in_memory_context.current_output_context = table_context
-             sliced_table = _SlicedTable(
-                 table_context.table_op_graph, table_context.sql_output_slice_args
-             )
-             if sliced_table not in in_memory_context.computed_batches_for_op_graph:
-                 match self._execute(sliced_table.op_graph, in_memory_context):
-                     case Ok():
-                         pass
-                     case err:
-                         return err
+         with _InMemoryExecutionContext(context) as in_memory_context:
+             for table_context in context.tables_to_compute:
+                 in_memory_context.current_output_context = table_context
+                 sliced_table = _SlicedTable(
+                     table_context.table_op_graph, table_context.sql_output_slice_args
+                 )
+                 if sliced_table not in in_memory_context.computed_batches_for_op_graph:
+                     match self._execute(sliced_table.op_graph, in_memory_context):
+                         case Ok():
+                             pass
+                         case err:
+                             return err
+             args_lfm_iterator = in_memory_context.computed_batches_for_op_graph.items()
+             computed_tables = {
+                 slice_args: _SchemaAndBatches.from_lazy_frame_with_metrics(lfm)
+                 for slice_args, lfm in args_lfm_iterator
+             }

          return Ok(
              InMemoryExecutionResult.make(
-                 self._storage_manager, in_memory_context, context
+                 self._storage_manager, computed_tables, context
              )
          )
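
Putting the pieces together, execute() now runs inside the context-managed execution scope and materializes every requested table on the way out. A very rough, illustrative sketch of that flow with stand-in types (a bare ExitStack plays the role of _InMemoryExecutionContext):

# Illustrative only: build lazy plans inside a managed scope, then collect
# every requested output before the scope (and its open resources) closes.
import contextlib

import polars as pl

source = pl.DataFrame({"x": [1, 2, 3]}).lazy()
requested = {"doubled": lambda lf: lf.with_columns((pl.col("x") * 2).alias("y"))}

with contextlib.ExitStack() as scope:
    # one lazy plan per requested output while resources stay open
    plans = {name: lf_op(source) for name, lf_op in requested.items()}
    # materialize everything before the scope closes
    computed_tables = {name: plan.collect() for name, plan in plans.items()}

print(computed_tables["doubled"])
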