corvic-engine 0.3.0rc67__cp38-abi3-win_amd64.whl → 0.3.0rc68__cp38-abi3-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,14 +2,15 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
+ import asyncio
5
6
  import dataclasses
6
7
  import datetime
7
8
  import functools
8
- import math
9
- from collections.abc import Callable, Mapping, MutableMapping
9
+ from collections.abc import Callable, Mapping, MutableMapping, Sequence
10
+ from concurrent.futures import ThreadPoolExecutor
10
11
  from contextlib import AbstractContextManager, ExitStack, nullcontext
11
12
  from types import TracebackType
12
- from typing import Any, Final, cast
13
+ from typing import Any, Final, ParamSpec, Self, TypeVar, cast
13
14
 
14
15
  import numpy as np
15
16
  import polars as pl
@@ -19,8 +20,9 @@ import pyarrow.parquet as pq
19
20
  import structlog
20
21
  from google.protobuf import json_format, struct_pb2
21
22
  from more_itertools import flatten
22
- from typing_extensions import Self, deprecated
23
+ from typing_extensions import deprecated
23
24
 
25
+ import corvic.system._column_encoding as column_encoding
24
26
  from corvic import embed, embedding_metric, op_graph, sql
25
27
  from corvic.result import (
26
28
  InternalError,
@@ -49,48 +51,51 @@ from corvic_generated.orm.v1 import table_pb2
49
51
 
50
52
  _logger = structlog.get_logger()
51
53
 
52
- """Reference and Maximum number of years for normalizing year in Datetime encoder"""
53
- REFERENCE_YEAR: Final = 1900
54
- MAX_NUMBER_OF_YEARS: Final = 200
55
-
56
54
  _MIN_EMBEDDINGS_FOR_EMBEDDINGS_SUMMARY: Final = 3
57
55
 
58
56
 
57
+ def _collect_and_apply(
58
+ data: pl.LazyFrame, func: Callable[[pl.DataFrame], _R]
59
+ ) -> tuple[pl.DataFrame, _R]:
60
+ data_df = data.collect()
61
+ return data_df, func(data_df)
62
+
63
+
59
64
  def get_polars_embedding_length(
60
- embedding_df: pl.DataFrame, embedding_column_name: str
65
+ embedding_series: pl.Series,
61
66
  ) -> Ok[int] | InvalidArgumentError:
62
- outer_type = embedding_df.schema[embedding_column_name]
67
+ outer_type = embedding_series.dtype
63
68
  if isinstance(outer_type, pl.Array):
64
69
  return Ok(outer_type.shape[0])
65
70
  if not isinstance(outer_type, pl.List):
66
71
  return InvalidArgumentError("invalid embedding datatype", dtype=str(outer_type))
67
- if len(embedding_df[embedding_column_name]) == 0:
72
+ if len(embedding_series) == 0:
68
73
  return InvalidArgumentError(
69
74
  "cannot infer embedding length for empty embedding set"
70
75
  )
71
- embedding_length = len(embedding_df[embedding_column_name][0])
76
+ embedding_length = len(embedding_series[0])
72
77
  if embedding_length < 1:
73
78
  return InvalidArgumentError("invalid embedding length", length=embedding_length)
74
79
  return Ok(embedding_length)
75
80
 
76
81
 
77
82
  def get_polars_embedding(
78
- embedding_df: pl.DataFrame, embedding_column_name: str
83
+ embedding_series: pl.Series,
79
84
  ) -> Ok[np.ndarray[Any, Any]] | InvalidArgumentError:
80
- outer_type = embedding_df.schema[embedding_column_name]
85
+ outer_type = embedding_series.dtype
81
86
  if isinstance(outer_type, pl.Array):
82
- return Ok(embedding_df[embedding_column_name].to_numpy())
87
+ return Ok(embedding_series.to_numpy())
83
88
  if not isinstance(outer_type, pl.List):
84
89
  return InvalidArgumentError("invalid embedding datatype", dtype=str(outer_type))
85
- match get_polars_embedding_length(embedding_df, embedding_column_name):
90
+ match get_polars_embedding_length(embedding_series):
86
91
  case Ok(embedding_length):
87
92
  pass
88
93
  case InvalidArgumentError() as err:
89
94
  return err
90
95
  return Ok(
91
- embedding_df[embedding_column_name]
92
- .cast(pl.Array(inner=outer_type.inner, shape=embedding_length))
93
- .to_numpy()
96
+ embedding_series.cast(
97
+ pl.Array(inner=outer_type.inner, shape=embedding_length)
98
+ ).to_numpy()
94
99
  )
95
100
 
96
101
 
@@ -192,8 +197,15 @@ class _SchemaAndBatches:
192
197
  metrics: dict[str, Any]
193
198
 
194
199
  @classmethod
195
- def from_lazy_frame_with_metrics(cls, lfm: _LazyFrameWithMetrics):
196
- return cls.from_dataframe(lfm.data.collect(), lfm.metrics)
200
+ async def from_lazy_frame_with_metrics(
201
+ cls, lfm: _LazyFrameWithMetrics, worker_threads: ThreadPoolExecutor | None
202
+ ):
203
+ return cls.from_dataframe(
204
+ await asyncio.get_running_loop().run_in_executor(
205
+ worker_threads, lfm.data.collect
206
+ ),
207
+ lfm.metrics,
208
+ )
197
209
 
198
210
  def to_batch_reader(self):
199
211
  return pa.RecordBatchReader.from_batches(
@@ -221,6 +233,10 @@ class _SchemaAndBatches:
221
233
  return cls(schema, table.to_batches(), metrics)
222
234
 
223
235
 
236
+ _P = ParamSpec("_P")
237
+ _R = TypeVar("_R")
238
+
239
+
224
240
  @dataclasses.dataclass(frozen=True)
225
241
  class _SlicedTable:
226
242
  op_graph: op_graph.Op
@@ -230,6 +246,7 @@ class _SlicedTable:
230
246
  @dataclasses.dataclass
231
247
  class _InMemoryExecutionContext(AbstractContextManager["_InMemoryExecutionContext"]):
232
248
  exec_context: ExecutionContext
249
+ worker_threads: ThreadPoolExecutor | None
233
250
  current_output_context: TableComputeContext | None = None
234
251
 
235
252
  # Using _SchemaAndBatches rather than a RecordBatchReader since the latter's
@@ -239,6 +256,7 @@ class _InMemoryExecutionContext(AbstractContextManager["_InMemoryExecutionContex
239
256
  dataclasses.field(default_factory=dict)
240
257
  )
241
258
  exit_stack: ExitStack = dataclasses.field(default_factory=ExitStack)
259
+ lock: asyncio.Lock = dataclasses.field(default_factory=asyncio.Lock)
242
260
 
243
261
  def __enter__(self) -> Self:
244
262
  self.exit_stack = self.exit_stack.__enter__()
@@ -252,6 +270,15 @@ class _InMemoryExecutionContext(AbstractContextManager["_InMemoryExecutionContex
252
270
  ) -> bool | None:
253
271
  return self.exit_stack.__exit__(__exc_type, __exc_value, __traceback)
254
272
 
273
+ async def run_on_worker(
274
+ self, func: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs
275
+ ) -> _R:
276
+ # lock here because polars operations aren't guaranteed to be independent
277
+ async with self.lock:
278
+ return await asyncio.get_running_loop().run_in_executor(
279
+ self.worker_threads, lambda: func(*args, **kwargs)
280
+ )
281
+
255
282
  @classmethod
256
283
  def count_source_op_uses(
257
284
  cls,
@@ -407,55 +434,55 @@ class InMemoryExecutor(OpGraphExecutor):
407
434
  )
408
435
  return Ok(_LazyFrameWithMetrics(data, metrics={}))
409
436
 
410
- def _execute_rollup_by_aggregation(
437
+ async def _execute_rollup_by_aggregation(
411
438
  self, op: op_graph.op.RollupByAggregation, context: _InMemoryExecutionContext
412
439
  ) -> Ok[_LazyFrameWithMetrics]:
413
440
  raise NotImplementedError(
414
441
  "rollup by aggregation outside of sql not implemented"
415
442
  )
416
443
 
417
- def _compute_source_then_apply(
444
+ async def _compute_source_then_apply(
418
445
  self,
419
446
  source: op_graph.Op,
420
447
  lf_op: Callable[[pl.LazyFrame], pl.LazyFrame],
421
448
  context: _InMemoryExecutionContext,
422
449
  ):
423
- return self._execute(source, context).map(
450
+ return (await self._execute(source, context)).map(
424
451
  lambda source_lfm: source_lfm.apply(lf_op)
425
452
  )
426
453
 
427
- def _execute_rename_columns(
454
+ async def _execute_rename_columns(
428
455
  self, op: op_graph.op.RenameColumns, context: _InMemoryExecutionContext
429
456
  ):
430
- return self._compute_source_then_apply(
457
+ return await self._compute_source_then_apply(
431
458
  op.source, lambda lf: lf.rename(dict(op.old_name_to_new)), context
432
459
  )
433
460
 
434
- def _execute_select_columns(
461
+ async def _execute_select_columns(
435
462
  self, op: op_graph.op.SelectColumns, context: _InMemoryExecutionContext
436
463
  ):
437
- return self._compute_source_then_apply(
464
+ return await self._compute_source_then_apply(
438
465
  op.source, lambda lf: lf.select(op.columns), context
439
466
  )
440
467
 
441
- def _execute_limit_rows(
468
+ async def _execute_limit_rows(
442
469
  self, op: op_graph.op.LimitRows, context: _InMemoryExecutionContext
443
470
  ):
444
- return self._compute_source_then_apply(
471
+ return await self._compute_source_then_apply(
445
472
  op.source, lambda lf: lf.limit(op.num_rows), context
446
473
  )
447
474
 
448
- def _execute_offset_rows(
475
+ async def _execute_offset_rows(
449
476
  self, op: op_graph.op.OffsetRows, context: _InMemoryExecutionContext
450
477
  ):
451
- return self._compute_source_then_apply(
478
+ return await self._compute_source_then_apply(
452
479
  op.source, lambda lf: lf.slice(op.num_rows), context
453
480
  )
454
481
 
455
- def _execute_order_by(
482
+ async def _execute_order_by(
456
483
  self, op: op_graph.op.OrderBy, context: _InMemoryExecutionContext
457
484
  ):
458
- return self._compute_source_then_apply(
485
+ return await self._compute_source_then_apply(
459
486
  op.source, lambda lf: lf.sort(op.columns, descending=op.desc), context
460
487
  )
461
488
 
@@ -536,7 +563,7 @@ class InMemoryExecutor(OpGraphExecutor):
536
563
  case op_graph.row_filter.CombineFilters():
537
564
  return self._row_filter_combination_to_condition(row_filter)
538
565
 
539
- def _execute_filter_rows(
566
+ async def _execute_filter_rows(
540
567
  self, op: op_graph.op.FilterRows, context: _InMemoryExecutionContext
541
568
  ):
542
569
  match self._row_filter_to_condition(op.row_filter):
@@ -544,78 +571,129 @@ class InMemoryExecutor(OpGraphExecutor):
544
571
  return InternalError.from_(err)
545
572
  case Ok(row_filter):
546
573
  pass
547
- return self._compute_source_then_apply(
574
+ return await self._compute_source_then_apply(
548
575
  op.source, lambda lf: lf.filter(row_filter), context
549
576
  )
550
577
 
551
- def _execute_embedding_metrics( # noqa: C901
578
+ def _get_embedding_column_if(
579
+ self,
580
+ data: pl.LazyFrame,
581
+ embedding_column_name: str,
582
+ pred: Callable[[pl.DataFrame], bool],
583
+ ) -> Ok[tuple[pl.DataFrame, np.ndarray[Any, Any] | None]] | InvalidArgumentError:
584
+ data_df = data.collect()
585
+ if pred(data_df):
586
+ match get_polars_embedding(data_df[embedding_column_name]):
587
+ case InvalidArgumentError() as err:
588
+ return err
589
+ case Ok(embeddings):
590
+ pass
591
+ else:
592
+ embeddings = None
593
+ return Ok((data_df, embeddings))
594
+
595
+ async def _execute_embedding_metrics( # noqa: C901
552
596
  self, op: op_graph.op.EmbeddingMetrics, context: _InMemoryExecutionContext
553
597
  ):
554
- match self._execute(op.table, context):
555
- case Ok(source_lfm):
556
- pass
557
- case err:
558
- return err
559
- embedding_df = source_lfm.data.collect()
560
-
561
- if len(embedding_df) < _MIN_EMBEDDINGS_FOR_EMBEDDINGS_SUMMARY:
562
- # downstream consumers handle empty metadata by substituting their
563
- # own values
564
- return Ok(
565
- _LazyFrameWithMetrics(embedding_df.lazy(), metrics=source_lfm.metrics)
566
- )
567
-
568
598
  # before it was configurable, this op assumed that the column's name was
569
599
  # this hardcoded name
570
600
  embedding_column_name = op.embedding_column_name or "embedding"
571
- match get_polars_embedding(embedding_df, embedding_column_name):
572
- case Ok(embedding):
601
+
602
+ match await self._execute(op.table, context):
603
+ case Ok(source_lfm):
573
604
  pass
605
+ case err:
606
+ return err
607
+ match await context.run_on_worker(
608
+ self._get_embedding_column_if,
609
+ source_lfm.data,
610
+ embedding_column_name,
611
+ lambda df: len(df) >= _MIN_EMBEDDINGS_FOR_EMBEDDINGS_SUMMARY,
612
+ ):
574
613
  case InvalidArgumentError() as err:
575
- return InternalError.from_(err)
614
+ return err
615
+ case Ok(result):
616
+ embedding_df, embeddings = result
617
+
618
+ if embeddings is None:
619
+ return Ok(_LazyFrameWithMetrics(embedding_df.lazy(), source_lfm.metrics))
576
620
 
577
621
  metrics = source_lfm.metrics.copy()
578
- match embedding_metric.ne_sum(embedding, normalize=True):
622
+ async with asyncio.TaskGroup() as tg:
623
+ ne_sum = tg.create_task(
624
+ context.run_on_worker(
625
+ embedding_metric.ne_sum, embeddings, normalize=True
626
+ )
627
+ )
628
+ condition_number = tg.create_task(
629
+ context.run_on_worker(
630
+ embedding_metric.condition_number, embeddings, normalize=True
631
+ )
632
+ )
633
+ rcondition_number = tg.create_task(
634
+ context.run_on_worker(
635
+ embedding_metric.rcondition_number, embeddings, normalize=True
636
+ )
637
+ )
638
+ stable_rank = tg.create_task(
639
+ context.run_on_worker(
640
+ embedding_metric.stable_rank, embeddings, normalize=True
641
+ )
642
+ )
643
+
644
+ match ne_sum.result():
579
645
  case Ok(metric):
580
646
  metrics["ne_sum"] = metric
581
647
  case InvalidArgumentError() as err:
582
648
  _logger.warning("could not compute ne_sum", exc_info=str(err))
583
- match embedding_metric.condition_number(embedding, normalize=True):
649
+
650
+ match condition_number.result():
584
651
  case Ok(metric):
585
652
  metrics["condition_number"] = metric
586
653
  case InvalidArgumentError() as err:
587
654
  _logger.warning("could not compute condition_number", exc_info=str(err))
588
- match embedding_metric.rcondition_number(embedding, normalize=True):
655
+
656
+ match rcondition_number.result():
589
657
  case Ok(metric):
590
658
  metrics["rcondition_number"] = metric
591
659
  case InvalidArgumentError() as err:
592
660
  _logger.warning(
593
661
  "could not compute rcondition_number", exc_info=str(err)
594
662
  )
595
- match embedding_metric.stable_rank(embedding, normalize=True):
663
+ match stable_rank.result():
596
664
  case Ok(metric):
597
665
  metrics["stable_rank"] = metric
598
666
  case InvalidArgumentError() as err:
599
667
  _logger.warning("could not compute stable_rank", exc_info=str(err))
668
+
600
669
  return Ok(_LazyFrameWithMetrics(embedding_df.lazy(), metrics=metrics))
601
670
 
602
- def _execute_embedding_coordinates(
671
+ async def _execute_embedding_coordinates(
603
672
  self, op: op_graph.op.EmbeddingCoordinates, context: _InMemoryExecutionContext
604
673
  ):
605
- match self._execute(op.table, context):
674
+ match await self._execute(op.table, context):
606
675
  case Ok(source_lfm):
607
676
  pass
608
677
  case err:
609
678
  return err
610
- embedding_df = source_lfm.data.collect()
611
679
 
612
680
  # before it was configurable, this op assumed that the column's name was
613
681
  # this hardcoded name
614
682
  embedding_column_name = op.embedding_column_name or "embedding"
683
+ match await context.run_on_worker(
684
+ self._get_embedding_column_if,
685
+ source_lfm.data,
686
+ embedding_column_name,
687
+ lambda df: len(df) >= _MIN_EMBEDDINGS_FOR_EMBEDDINGS_SUMMARY,
688
+ ):
689
+ case InvalidArgumentError() as err:
690
+ return err
691
+ case Ok(result):
692
+ embedding_df, embeddings = result
615
693
 
616
694
  # the neighbors of a point includes itself. That does mean, that an n_neighbors
617
695
  # value of less than 3 simply does not work
618
- if len(embedding_df) < _MIN_EMBEDDINGS_FOR_EMBEDDINGS_SUMMARY:
696
+ if embeddings is None:
619
697
  coordinates_df = embedding_df.lazy().with_columns(
620
698
  pl.Series(
621
699
  name=embedding_column_name,
@@ -625,14 +703,11 @@ class InMemoryExecutor(OpGraphExecutor):
625
703
  )
626
704
  return Ok(_LazyFrameWithMetrics(coordinates_df, source_lfm.metrics))
627
705
 
628
- match get_polars_embedding(embedding_df, embedding_column_name):
629
- case Ok(embedding):
630
- pass
631
- case InvalidArgumentError() as err:
632
- raise err
633
-
634
- match self._dimension_reducer.reduce_dimensions(
635
- embedding, op.n_components, op.metric
706
+ match await context.run_on_worker(
707
+ self._dimension_reducer.reduce_dimensions,
708
+ embeddings,
709
+ op.n_components,
710
+ op.metric,
636
711
  ):
637
712
  case Ok(coordinates):
638
713
  pass
@@ -648,26 +723,37 @@ class InMemoryExecutor(OpGraphExecutor):
648
723
  )
649
724
  return Ok(_LazyFrameWithMetrics(coordinates_df, source_lfm.metrics))
650
725
 
651
- def _execute_distinct_rows(
726
+ async def _execute_distinct_rows(
652
727
  self, op: op_graph.op.DistinctRows, context: _InMemoryExecutionContext
653
728
  ):
654
- return self._execute(op.source, context).map(
655
- lambda source_lfm: _LazyFrameWithMetrics(
656
- source_lfm.data.unique(), source_lfm.metrics
657
- )
729
+ return await self._compute_source_then_apply(
730
+ op.source, lambda source: source.unique(), context
658
731
  )
659
732
 
660
- def _execute_join(self, op: op_graph.op.Join, context: _InMemoryExecutionContext):
661
- match self._execute(op.left_source, context):
733
+ async def _execute_join(
734
+ self, op: op_graph.op.Join, context: _InMemoryExecutionContext
735
+ ):
736
+ async with asyncio.TaskGroup() as tg:
737
+ left_task = tg.create_task(self._execute(op.left_source, context))
738
+ right_task = tg.create_task(self._execute(op.right_source, context))
739
+ match left_task.result():
740
+ case (
741
+ InternalError()
742
+ | ResourceExhaustedError()
743
+ | InvalidArgumentError() as err
744
+ ):
745
+ return err
662
746
  case Ok(left_lfm):
663
747
  pass
664
- case err:
748
+ match right_task.result():
749
+ case (
750
+ InternalError()
751
+ | ResourceExhaustedError()
752
+ | InvalidArgumentError() as err
753
+ ):
665
754
  return err
666
- match self._execute(op.right_source, context):
667
755
  case Ok(right_lfm):
668
756
  pass
669
- case err:
670
- return err
671
757
  left_lf = left_lfm.data
672
758
  right_lf = right_lfm.data
673
759
 
@@ -702,20 +788,31 @@ class InMemoryExecutor(OpGraphExecutor):
702
788
  )
703
789
  )
704
790
 
705
- def _execute_empty(self, op: op_graph.op.Empty, context: _InMemoryExecutionContext):
791
+ async def _execute_empty(
792
+ self, op: op_graph.op.Empty, context: _InMemoryExecutionContext
793
+ ):
706
794
  empty_table = cast(pl.DataFrame, pl.from_arrow(pa.schema([]).empty_table()))
707
795
  return Ok(_LazyFrameWithMetrics(empty_table.lazy(), metrics={}))
708
796
 
709
- def _execute_concat(
797
+ async def _execute_concat(
710
798
  self, op: op_graph.op.Concat, context: _InMemoryExecutionContext
711
799
  ):
800
+ async with asyncio.TaskGroup() as tg:
801
+ tasks = [
802
+ tg.create_task(self._execute(table, context)) for table in op.tables
803
+ ]
804
+
712
805
  source_lfms = list[_LazyFrameWithMetrics]()
713
- for table in op.tables:
714
- match self._execute(table, context):
715
- case Ok(batches):
716
- source_lfms.append(batches)
717
- case err:
806
+ for task in tasks:
807
+ match task.result():
808
+ case (
809
+ InternalError()
810
+ | ResourceExhaustedError()
811
+ | InvalidArgumentError() as err
812
+ ):
718
813
  return err
814
+ case Ok(lfm):
815
+ source_lfms.append(lfm)
719
816
  data = pl.concat([lfm.data for lfm in source_lfms], how=op.how)
720
817
  metrics = dict[str, Any]()
721
818
  for lfm in source_lfms:
@@ -794,16 +891,19 @@ class InMemoryExecutor(OpGraphExecutor):
794
891
  context,
795
892
  )
796
893
 
797
- def _execute_embed_column(
894
+ async def _execute_embed_column(
798
895
  self, op: op_graph.op.EmbedColumn, context: _InMemoryExecutionContext
799
896
  ):
800
- match self._execute(op.source, context):
897
+ match await self._execute(op.source, context):
801
898
  case Ok(source_lfm):
802
899
  pass
803
900
  case err:
804
901
  return err
805
- source_df = source_lfm.data.collect()
806
- to_embed = source_df[op.column_name].cast(pl.String())
902
+ source_df, to_embed = await context.run_on_worker(
903
+ _collect_and_apply,
904
+ source_lfm.data,
905
+ lambda df: df[op.column_name].cast(pl.String),
906
+ )
807
907
 
808
908
  embed_context = EmbedTextContext(
809
909
  inputs=to_embed,
@@ -813,112 +913,19 @@ class InMemoryExecutor(OpGraphExecutor):
813
913
  expected_coordinate_bitwidth=op.expected_coordinate_bitwidth,
814
914
  room_id=context.exec_context.room_id,
815
915
  )
816
- match self._text_embedder.embed(embed_context):
916
+ match await self._text_embedder.aembed(embed_context, context.worker_threads):
817
917
  case Ok(result):
818
918
  pass
819
919
  case InvalidArgumentError() | InternalError() as err:
820
920
  raise InternalError("Failed to embed column") from err
821
921
 
822
- result_df = (
823
- source_df.lazy()
824
- .with_columns(result.embeddings.alias(op.embedding_column_name))
825
- .drop_nulls(op.embedding_column_name)
826
- )
827
-
828
- return Ok(source_lfm.with_data(result_df))
829
-
830
- @staticmethod
831
- def get_cyclic_encoding(
832
- series: pl.Series,
833
- period: int,
834
- ) -> tuple[pl.Series, pl.Series]:
835
- sine_series = (2 * math.pi * series / period).sin().alias(f"{series.name}_sine")
836
- cosine_series = (
837
- (2 * math.pi * series / period).cos().alias(f"{series.name}_cosine")
922
+ return Ok(
923
+ source_lfm.with_data(
924
+ source_df.lazy()
925
+ .with_columns(result.embeddings.alias(op.embedding_column_name))
926
+ .drop_nulls(op.embedding_column_name)
927
+ )
838
928
  )
839
- return sine_series, cosine_series
840
-
841
- @staticmethod
842
- def encode_datetime(series: pl.Series) -> pl.Series:
843
- match series.dtype:
844
- case pl.Date | pl.Time:
845
- pass
846
- case pl.Datetime:
847
- series = series.dt.replace_time_zone("UTC")
848
- case _:
849
- raise ValueError("Invalid arguments, expected a datetime series")
850
-
851
- if series.is_null().all():
852
- zero_vector = pl.zeros(11, dtype=pl.Float32, eager=True)
853
- return pl.Series([zero_vector] * len(series), dtype=pl.List(pl.Float32))
854
-
855
- n = len(series)
856
- year_norm = pl.zeros(n, dtype=pl.Float32, eager=True).alias("year")
857
- month_sine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("month_sine")
858
- month_cosine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("month_cosine")
859
- day_sine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("day_sine")
860
- day_cosine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("day_cosine")
861
- hour_sine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("hour_sine")
862
- hour_cosine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("hour_cosine")
863
- minute_sine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("minute_sine")
864
- minute_cosine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("minute_cosine")
865
- second_sine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("second_sine")
866
- second_cosine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("second_cosine")
867
-
868
- if series.dtype in [pl.Date, pl.Datetime]:
869
- try:
870
- year = series.dt.year().cast(pl.Float32).alias("year")
871
- month = series.dt.month().cast(pl.Float32).alias("month")
872
- day = series.dt.day().cast(pl.Float32).alias("day")
873
-
874
- year_norm = (year - REFERENCE_YEAR) / MAX_NUMBER_OF_YEARS
875
- month_sine, month_cosine = InMemoryExecutor.get_cyclic_encoding(
876
- month, 12
877
- )
878
- day_sine, day_cosine = InMemoryExecutor.get_cyclic_encoding(day, 31)
879
- except pl.exceptions.PanicException as e:
880
- _logger.exception("Error extracting datetime", exc_info=e)
881
-
882
- if series.dtype in [pl.Time, pl.Datetime]:
883
- try:
884
- hour = series.dt.hour().cast(pl.Float32).alias("hour")
885
- minute = series.dt.minute().cast(pl.Float32).alias("minute")
886
- second = series.dt.second().cast(pl.Float32).alias("second")
887
-
888
- hour_sine, hour_cosine = InMemoryExecutor.get_cyclic_encoding(hour, 24)
889
- minute_sine, minute_cosine = InMemoryExecutor.get_cyclic_encoding(
890
- minute, 60
891
- )
892
- second_sine, second_cosine = InMemoryExecutor.get_cyclic_encoding(
893
- second, 60
894
- )
895
- except pl.exceptions.PanicException as e:
896
- _logger.exception("Error extracting datetime", exc_info=e)
897
-
898
- return pl.DataFrame(
899
- [
900
- year_norm.fill_null(0.0),
901
- month_sine.fill_null(0.0),
902
- month_cosine.fill_null(0.0),
903
- day_sine.fill_null(0.0),
904
- day_cosine.fill_null(0.0),
905
- hour_sine.fill_null(0.0),
906
- hour_cosine.fill_null(0.0),
907
- minute_sine.fill_null(0.0),
908
- minute_cosine.fill_null(0.0),
909
- second_sine.fill_null(0.0),
910
- second_cosine.fill_null(0.0),
911
- ]
912
- ).select(pl.concat_list(pl.all()).alias(series.name))[series.name]
913
-
914
- @staticmethod
915
- def encode_duration(series: pl.Series) -> pl.Series:
916
- if series.dtype != pl.Duration:
917
- raise ValueError("Invalid arguments, expected a duration series")
918
- if series.is_null().all():
919
- return pl.zeros(len(series), dtype=pl.Float32, eager=True)
920
-
921
- return series.dt.total_seconds().cast(pl.Float32).fill_null(0.0)
922
929
 
923
930
  @staticmethod
924
931
  def encode_text(series: pl.Series) -> pl.Series:
@@ -939,132 +946,71 @@ class InMemoryExecutor(OpGraphExecutor):
939
946
  pl.List(pl.Float32),
940
947
  )
941
948
 
942
- def _execute_encode_columns( # noqa: C901, PLR0915
949
+ def _encode_column( # noqa: C901
950
+ self, to_encode: pl.Series, encoder: op_graph.Encoder
951
+ ) -> tuple[pl.Series, list[str] | None]:
952
+ match encoder:
953
+ case op_graph.encoder.OneHotEncoder():
954
+ return column_encoding.encode_one_hot(to_encode)
955
+
956
+ case op_graph.encoder.MinMaxScaler():
957
+ return column_encoding.encode_min_max_scale(
958
+ to_encode, encoder.feature_range_min, encoder.feature_range_max
959
+ ), None
960
+
961
+ case op_graph.encoder.LabelBinarizer():
962
+ return column_encoding.encode_label_boolean(
963
+ to_encode, encoder.neg_label, encoder.pos_label
964
+ ), None
965
+
966
+ case op_graph.encoder.LabelEncoder():
967
+ return column_encoding.encode_label(
968
+ to_encode, normalize=encoder.normalize
969
+ ), None
970
+
971
+ case op_graph.encoder.KBinsDiscretizer():
972
+ return column_encoding.encode_kbins(
973
+ to_encode, encoder.n_bins, encoder.encode_method, encoder.strategy
974
+ ), None
975
+
976
+ case op_graph.encoder.Binarizer():
977
+ return column_encoding.encode_boolean(
978
+ to_encode, encoder.threshold
979
+ ), None
980
+
981
+ case op_graph.encoder.MaxAbsScaler():
982
+ return column_encoding.encode_max_abs_scale(to_encode), None
983
+
984
+ case op_graph.encoder.StandardScaler():
985
+ return column_encoding.encode_standard_scale(
986
+ to_encode, with_mean=encoder.with_mean, with_std=encoder.with_std
987
+ ), None
988
+
989
+ case op_graph.encoder.TimestampEncoder():
990
+ if to_encode.dtype == pl.datatypes.Duration:
991
+ return column_encoding.encode_duration(to_encode), None
992
+ return column_encoding.encode_datetime(to_encode), None
993
+
994
+ case op_graph.encoder.TextEncoder():
995
+ return self.encode_text(to_encode), None
996
+
997
+ async def _execute_encode_columns(
943
998
  self, op: op_graph.op.EncodeColumns, context: _InMemoryExecutionContext
944
999
  ):
945
- match self._execute(op.source, context):
1000
+ match await self._execute(op.source, context):
946
1001
  case Ok(source_lfm):
947
1002
  pass
948
1003
  case err:
949
1004
  return err
950
- source_df = source_lfm.data.collect()
1005
+ source_df = await context.run_on_worker(source_lfm.data.collect)
951
1006
  metrics = source_lfm.metrics.copy()
952
1007
  metric = metrics.get("one_hot_encoder", {})
953
1008
  for encoder_arg in op.encoded_columns:
954
- to_encode = source_df[encoder_arg.column_name]
955
- match encoder_arg.encoder:
956
- case op_graph.encoder.OneHotEncoder():
957
- encoded = to_encode.to_dummies()
958
- metric[encoder_arg.column_name] = encoded.columns
959
- encoded = encoded.select(
960
- pl.concat_list(pl.all())
961
- .alias(encoder_arg.encoded_column_name)
962
- .cast(pl.List(pl.Boolean))
963
- )
964
-
965
- case op_graph.encoder.MinMaxScaler():
966
- from sklearn.preprocessing import MinMaxScaler
967
-
968
- encoder = MinMaxScaler(
969
- feature_range=(
970
- encoder_arg.encoder.feature_range_min,
971
- encoder_arg.encoder.feature_range_max,
972
- )
973
- )
974
- encoded = encoder.fit_transform(
975
- to_encode.to_numpy().reshape(-1, 1)
976
- ).flatten()
977
-
978
- case op_graph.encoder.LabelBinarizer():
979
- from sklearn.preprocessing import LabelBinarizer
980
-
981
- encoder = LabelBinarizer(
982
- neg_label=encoder_arg.encoder.neg_label,
983
- pos_label=encoder_arg.encoder.pos_label,
984
- )
985
- encoded = encoder.fit_transform(to_encode.to_numpy().reshape(-1))
986
-
987
- case op_graph.encoder.LabelEncoder():
988
- from sklearn.preprocessing import LabelEncoder
989
-
990
- encoder = LabelEncoder()
991
- encoded = encoder.fit_transform(
992
- to_encode.to_numpy().reshape(-1)
993
- ).flatten()
994
- # `classes_` is only set after fit,
995
- # Creating custom typestubs will not solve this typing issue.
996
- if encoder_arg.encoder.normalize and hasattr(encoder, "classes_"):
997
- classes_ = cast(list[int], encoder.classes_) # pyright: ignore[reportAttributeAccessIssue, reportUnknownMemberType]
998
- max_class: int = len(classes_) - 1
999
- if max_class > 0:
1000
- encoded = encoded.astype(np.float64)
1001
- encoded /= max_class
1002
-
1003
- case op_graph.encoder.KBinsDiscretizer():
1004
- from sklearn.preprocessing import KBinsDiscretizer
1005
-
1006
- encoder = KBinsDiscretizer(
1007
- n_bins=encoder_arg.encoder.n_bins,
1008
- encode=encoder_arg.encoder.encode_method,
1009
- strategy=encoder_arg.encoder.strategy,
1010
- dtype=np.float32,
1011
- )
1012
- encoded = encoder.fit_transform(
1013
- to_encode.to_numpy().reshape(-1, 1)
1014
- ).flatten()
1015
-
1016
- case op_graph.encoder.Binarizer():
1017
- from sklearn.preprocessing import Binarizer
1018
-
1019
- encoder = Binarizer(
1020
- threshold=encoder_arg.encoder.threshold,
1021
- )
1022
- encoded = encoder.fit_transform(
1023
- to_encode.to_numpy().reshape(-1, 1)
1024
- ).flatten()
1025
-
1026
- case op_graph.encoder.MaxAbsScaler():
1027
- from sklearn.preprocessing import MaxAbsScaler
1028
-
1029
- encoder = MaxAbsScaler()
1030
- try:
1031
- encoded = encoder.fit_transform(
1032
- np.nan_to_num(to_encode.to_numpy()).reshape(-1, 1)
1033
- ).flatten()
1034
- except ValueError:
1035
- encoded = np.array([])
1036
-
1037
- case op_graph.encoder.StandardScaler():
1038
- from sklearn.preprocessing import StandardScaler
1039
-
1040
- encoder = StandardScaler(
1041
- with_mean=encoder_arg.encoder.with_mean,
1042
- with_std=encoder_arg.encoder.with_std,
1043
- )
1044
- encoded = encoder.fit_transform(
1045
- to_encode.to_numpy().reshape(-1, 1)
1046
- ).flatten()
1047
-
1048
- case op_graph.encoder.TimestampEncoder():
1049
- if to_encode.dtype == pl.datatypes.Duration:
1050
- encoded = self.encode_duration(to_encode)
1051
- else:
1052
- encoded = self.encode_datetime(to_encode)
1053
- source_df = source_df.with_columns(
1054
- encoded.rename(encoder_arg.encoded_column_name).cast(
1055
- encoder_arg.encoder.output_dtype
1056
- )
1057
- )
1058
- continue
1059
-
1060
- case op_graph.encoder.TextEncoder():
1061
- encoded = self.encode_text(to_encode)
1062
- source_df = source_df.with_columns(
1063
- encoded.rename(encoder_arg.encoded_column_name).cast(
1064
- encoder_arg.encoder.output_dtype
1065
- )
1066
- )
1067
- continue
1009
+ encoded, one_hot_columns = self._encode_column(
1010
+ source_df[encoder_arg.column_name], encoder_arg.encoder
1011
+ )
1012
+ if one_hot_columns is not None:
1013
+ metric[encoder_arg.column_name] = one_hot_columns
1068
1014
 
1069
1015
  source_df = source_df.with_columns(
1070
1016
  pl.Series(
@@ -1081,7 +1027,7 @@ class InMemoryExecutor(OpGraphExecutor):
1081
1027
  )
1082
1028
  )
1083
1029
 
1084
- def _execute_embed_node2vec_from_edge_lists(
1030
+ async def _execute_embed_node2vec_from_edge_lists(
1085
1031
  self,
1086
1032
  op: op_graph.op.EmbedNode2vecFromEdgeLists,
1087
1033
  context: _InMemoryExecutionContext,
@@ -1115,7 +1061,7 @@ class InMemoryExecutor(OpGraphExecutor):
1115
1061
 
1116
1062
  edge_list_lfms = list[_LazyFrameWithMetrics]()
1117
1063
  for edge_list in op.edge_list_tables:
1118
- match self._execute(edge_list.table, context):
1064
+ match await self._execute(edge_list.table, context):
1119
1065
  case Ok(source_lfm):
1120
1066
  edge_list_lfms.append(source_lfm)
1121
1067
  case err:
@@ -1129,58 +1075,65 @@ class InMemoryExecutor(OpGraphExecutor):
1129
1075
  end_column_type_name = entities_dtypes[end_column_name]
1130
1076
  metrics.update(lfm.metrics)
1131
1077
  yield (
1132
- lfm.data.with_columns(
1133
- pl.col(edge_list.start_column_name).alias(
1134
- f"start_id_{start_column_type_name}"
1135
- ),
1136
- pl.lit(edge_list.start_entity_name).alias("start_source"),
1137
- pl.col(edge_list.end_column_name).alias(
1138
- f"end_id_{end_column_type_name}"
1139
- ),
1140
- pl.lit(edge_list.end_entity_name).alias("end_source"),
1141
- )
1142
- .select(
1143
- f"start_id_{start_column_type_name}",
1144
- "start_source",
1145
- f"end_id_{end_column_type_name}",
1146
- "end_source",
1078
+ context.run_on_worker(
1079
+ lfm.data.with_columns(
1080
+ pl.col(edge_list.start_column_name).alias(
1081
+ f"start_id_{start_column_type_name}"
1082
+ ),
1083
+ pl.lit(edge_list.start_entity_name).alias("start_source"),
1084
+ pl.col(edge_list.end_column_name).alias(
1085
+ f"end_id_{end_column_type_name}"
1086
+ ),
1087
+ pl.lit(edge_list.end_entity_name).alias("end_source"),
1088
+ )
1089
+ .select(
1090
+ f"start_id_{start_column_type_name}",
1091
+ "start_source",
1092
+ f"end_id_{end_column_type_name}",
1093
+ "end_source",
1094
+ )
1095
+ .collect
1147
1096
  )
1148
- .collect()
1149
1097
  )
1150
1098
 
1151
- edges = pl.concat(
1152
- [
1153
- empty_edges_table,
1154
- *(edge_list for edge_list in edge_generator()),
1155
- ],
1156
- rechunk=False,
1157
- how="diagonal",
1158
- )
1099
+ async with asyncio.TaskGroup() as tg:
1100
+ edge_tasks = [tg.create_task(edge_list) for edge_list in edge_generator()]
1159
1101
 
1160
- n2v_space = embed.Space(
1161
- edges=edges,
1162
- start_id_column_names=start_id_column_names,
1163
- end_id_column_names=end_id_column_names,
1164
- directed=True,
1165
- )
1166
- n2v_runner = embed.Node2Vec(
1167
- space=n2v_space,
1168
- dim=op.ndim,
1169
- walk_length=op.walk_length,
1170
- window=op.window,
1171
- p=op.p,
1172
- q=op.q,
1173
- alpha=op.alpha,
1174
- min_alpha=op.min_alpha,
1175
- negative=op.negative,
1176
- )
1177
- n2v_runner.train(epochs=op.epochs)
1178
- return Ok(_LazyFrameWithMetrics(n2v_runner.wv.to_polars().lazy(), metrics))
1102
+ def run_n2v():
1103
+ edges = pl.concat(
1104
+ [
1105
+ empty_edges_table,
1106
+ *(task.result() for task in edge_tasks),
1107
+ ],
1108
+ rechunk=False,
1109
+ how="diagonal",
1110
+ )
1111
+ n2v_space = embed.Space(
1112
+ edges=edges,
1113
+ start_id_column_names=start_id_column_names,
1114
+ end_id_column_names=end_id_column_names,
1115
+ directed=True,
1116
+ )
1117
+ n2v_runner = embed.Node2Vec(
1118
+ space=n2v_space,
1119
+ dim=op.ndim,
1120
+ walk_length=op.walk_length,
1121
+ window=op.window,
1122
+ p=op.p,
1123
+ q=op.q,
1124
+ alpha=op.alpha,
1125
+ min_alpha=op.min_alpha,
1126
+ negative=op.negative,
1127
+ )
1128
+ n2v_runner.train(epochs=op.epochs)
1129
+ return n2v_runner.wv.to_polars().lazy()
1130
+
1131
+ return Ok(_LazyFrameWithMetrics(await context.run_on_worker(run_n2v), metrics))
1179
1132
 
1180
- def _execute_aggregate_columns(
1133
+ async def _execute_aggregate_columns(
1181
1134
  self, op: op_graph.op.AggregateColumns, context: _InMemoryExecutionContext
1182
1135
  ):
1183
- match self._execute(op.source, context):
1136
+ match await self._execute(op.source, context):
1184
1137
  case Ok(source_lfm):
1185
1138
  pass
1186
1139
  case err:
@@ -1205,38 +1158,48 @@ class InMemoryExecutor(OpGraphExecutor):
1205
1158
 
1206
1159
  return Ok(source_lfm.with_data(aggregate))
1207
1160
 
1208
- def _execute_correlate_columns(
1161
+ async def _execute_correlate_columns(
1209
1162
  self, op: op_graph.op.CorrelateColumns, context: _InMemoryExecutionContext
1210
1163
  ):
1211
- match self._execute(op.source, context):
1164
+ match await self._execute(op.source, context):
1212
1165
  case Ok(source_lfm):
1213
1166
  pass
1214
1167
  case err:
1215
1168
  return err
1216
- source_df = source_lfm.data.collect()
1217
- with np.errstate(invalid="ignore"):
1218
- corr_df = source_df.select(op.column_names).corr(dtype="float32")
1219
1169
 
1170
+ def correlate(df: pl.DataFrame):
1171
+ with np.errstate(invalid="ignore"):
1172
+ return df.select(op.column_names).corr(dtype="float32")
1173
+
1174
+ _, corr_df = await context.run_on_worker(
1175
+ _collect_and_apply, source_lfm.data, correlate
1176
+ )
1220
1177
  return Ok(source_lfm.with_data(corr_df.lazy()))
1221
1178
 
1222
- def _execute_histogram_column(
1179
+ async def _execute_histogram_column(
1223
1180
  self, op: op_graph.op.HistogramColumn, context: _InMemoryExecutionContext
1224
1181
  ):
1225
- return self._compute_source_then_apply(
1226
- op.source,
1227
- lambda lf: lf.collect()[op.column_name]
1182
+ match await self._execute(op.source, context):
1183
+ case Ok(source_lfm):
1184
+ pass
1185
+ case err:
1186
+ return err
1187
+
1188
+ _, result_df = await context.run_on_worker(
1189
+ _collect_and_apply,
1190
+ source_lfm.data,
1191
+ lambda df: df[op.column_name]
1228
1192
  .hist(include_category=False)
1229
- .lazy()
1230
1193
  .rename(
1231
1194
  {
1232
1195
  "breakpoint": op.breakpoint_column_name,
1233
1196
  "count": op.count_column_name,
1234
1197
  }
1235
1198
  ),
1236
- context,
1237
1199
  )
1200
+ return Ok(source_lfm.with_data(result_df.lazy()))
1238
1201
 
1239
- def _execute_convert_column_to_string(
1202
+ async def _execute_convert_column_to_string(
1240
1203
  self, op: op_graph.op.ConvertColumnToString, context: _InMemoryExecutionContext
1241
1204
  ):
1242
1205
  dtype = op.source.schema.to_polars()[op.column_name]
@@ -1248,14 +1211,15 @@ class InMemoryExecutor(OpGraphExecutor):
1248
1211
  raise NotImplementedError(
1249
1212
  "converting struct columns to strings is not implemented"
1250
1213
  )
1251
- return self._compute_source_then_apply(
1252
- op.source, lambda lf: lf.collect().with_columns(cast_expr).lazy(), context
1214
+
1215
+ return await self._compute_source_then_apply(
1216
+ op.source, lambda lf: lf.with_columns(cast_expr), context
1253
1217
  )
1254
1218
 
1255
- def _execute_add_row_index(
1219
+ async def _execute_add_row_index(
1256
1220
  self, op: op_graph.op.AddRowIndex, context: _InMemoryExecutionContext
1257
1221
  ):
1258
- return self._compute_source_then_apply(
1222
+ return await self._compute_source_then_apply(
1259
1223
  op.source,
1260
1224
  lambda lf: lf.with_row_index(
1261
1225
  name=op.row_index_column_name, offset=op.offset
@@ -1263,70 +1227,76 @@ class InMemoryExecutor(OpGraphExecutor):
1263
1227
  context,
1264
1228
  )
1265
1229
 
1266
- def _execute_output_csv(
1230
+ async def _execute_output_csv(
1267
1231
  self, op: op_graph.op.OutputCsv, context: _InMemoryExecutionContext
1268
1232
  ):
1269
- match self._execute(op.source, context):
1233
+ match await self._execute(op.source, context):
1270
1234
  case Ok(source_lfm):
1271
1235
  pass
1272
1236
  case err:
1273
1237
  return err
1274
- source_df = source_lfm.data.collect()
1275
- source_df.write_csv(
1276
- op.csv_url,
1277
- quote_style="never",
1278
- include_header=op.include_header,
1238
+ source_df, _ = await context.run_on_worker(
1239
+ _collect_and_apply,
1240
+ source_lfm.data,
1241
+ lambda df: df.write_csv(
1242
+ op.csv_url, quote_style="never", include_header=op.include_header
1243
+ ),
1279
1244
  )
1280
1245
  return Ok(source_lfm.with_data(source_df.lazy()))
1281
1246
 
1282
- def _execute_truncate_list(
1247
+ async def _execute_truncate_list(
1283
1248
  self, op: op_graph.op.TruncateList, context: _InMemoryExecutionContext
1284
1249
  ):
1285
1250
  # TODO(Patrick): verify this approach works for arrays
1286
- match self._execute(op.source, context):
1251
+ match await self._execute(op.source, context):
1287
1252
  case Ok(source_lfm):
1288
1253
  pass
1289
1254
  case err:
1290
1255
  return err
1291
- source_df = source_lfm.data.collect()
1292
- if len(source_df):
1293
- existing_length = get_polars_embedding_length(
1294
- source_df, op.column_name
1295
- ).unwrap_or_raise()
1296
- else:
1297
- existing_length = 0
1298
- head_length = (
1299
- op.target_column_length
1300
- if existing_length >= op.target_column_length
1301
- else existing_length
1302
- )
1303
- source_df = source_df.with_columns(
1304
- pl.col(op.column_name).list.head(head_length)
1256
+ source_df, existing_length = await context.run_on_worker(
1257
+ _collect_and_apply,
1258
+ source_lfm.data,
1259
+ lambda df: get_polars_embedding_length(df[op.column_name]).unwrap_or_raise()
1260
+ if len(df)
1261
+ else 0,
1305
1262
  )
1263
+
1306
1264
  outer_type = source_df.schema[op.column_name]
1307
1265
  if isinstance(outer_type, pl.Array | pl.List):
1308
1266
  inner_type = outer_type.inner
1309
1267
  else:
1310
1268
  return InternalError("unexpected type", cause="expected list or array type")
1269
+ result = source_df.lazy()
1270
+
1271
+ head_length = (
1272
+ op.target_column_length
1273
+ if existing_length >= op.target_column_length
1274
+ else existing_length
1275
+ )
1276
+ result = result.with_columns(pl.col(op.column_name).list.head(head_length))
1311
1277
 
1312
- source_df = source_df.lazy()
1313
1278
  if head_length < op.target_column_length:
1314
1279
  padding_length = op.target_column_length - head_length
1315
1280
  padding = [op.padding_value_as_py] * padding_length
1316
- source_df = source_df.with_columns(
1317
- pl.col(op.column_name).list.concat(padding)
1318
- )
1319
- source_df = source_df.with_columns(
1281
+ result = result.with_columns(pl.col(op.column_name).list.concat(padding))
1282
+ result = result.with_columns(
1320
1283
  pl.col(op.column_name)
1321
1284
  .list.to_array(width=op.target_column_length)
1322
1285
  .cast(pl.List(inner_type))
1323
1286
  )
1324
- return Ok(source_lfm.with_data(source_df))
1287
+ return Ok(source_lfm.with_data(result))
1325
1288
 
1326
- def _execute_union(self, op: op_graph.op.Union, context: _InMemoryExecutionContext):
1289
+ async def _execute_union(
1290
+ self, op: op_graph.op.Union, context: _InMemoryExecutionContext
1291
+ ):
1292
+ async with asyncio.TaskGroup() as tg:
1293
+ source_taks = [
1294
+ tg.create_task(self._execute(source, context))
1295
+ for source in op.sources()
1296
+ ]
1327
1297
  sources = list[_LazyFrameWithMetrics]()
1328
- for source in op.sources():
1329
- match self._execute(source, context):
1298
+ for task in source_taks:
1299
+ match task.result():
1330
1300
  case Ok(source_lfm):
1331
1301
  sources.append(source_lfm)
1332
1302
  case err:
@@ -1341,16 +1311,19 @@ class InMemoryExecutor(OpGraphExecutor):
1341
1311
  result_lf = result_lf.unique()
1342
1312
  return Ok(_LazyFrameWithMetrics(result_lf, metrics=metrics))
1343
1313
 
1344
- def _execute_embed_image_column(
1314
+ async def _execute_embed_image_column(
1345
1315
  self, op: op_graph.op.EmbedImageColumn, context: _InMemoryExecutionContext
1346
1316
  ):
1347
- match self._execute(op.source, context):
1317
+ match await self._execute(op.source, context):
1348
1318
  case Ok(source_lfm):
1349
1319
  pass
1350
1320
  case err:
1351
1321
  return err
1352
- source_df = source_lfm.data.collect()
1353
- to_embed = source_df[op.column_name].cast(pl.Binary())
1322
+ source_df, to_embed = await context.run_on_worker(
1323
+ _collect_and_apply,
1324
+ source_lfm.data,
1325
+ lambda df: df[op.column_name].cast(pl.Binary()),
1326
+ )
1354
1327
 
1355
1328
  embed_context = EmbedImageContext(
1356
1329
  inputs=to_embed,
@@ -1358,43 +1331,38 @@ class InMemoryExecutor(OpGraphExecutor):
1358
1331
  expected_vector_length=op.expected_vector_length,
1359
1332
  expected_coordinate_bitwidth=op.expected_coordinate_bitwidth,
1360
1333
  )
1361
- match self._image_embedder.embed(embed_context):
1334
+ match await self._image_embedder.aembed(embed_context, context.worker_threads):
1362
1335
  case Ok(result):
1363
1336
  pass
1364
1337
  case InvalidArgumentError() | InternalError() as err:
1365
1338
  raise InternalError("Failed to embed column") from err
1366
1339
 
1367
1340
  return Ok(
1368
- _LazyFrameWithMetrics(
1341
+ source_lfm.with_data(
1369
1342
  source_df.lazy()
1370
1343
  .with_columns(result.embeddings.alias(op.embedding_column_name))
1371
- .drop_nulls(op.embedding_column_name),
1372
- source_lfm.metrics,
1344
+ .drop_nulls(op.embedding_column_name)
1373
1345
  )
1374
1346
  )
1375
1347
 
1376
- def _execute_add_decision_tree_summary(
1377
- self, op: op_graph.op.AddDecisionTreeSummary, context: _InMemoryExecutionContext
1348
+ def _compute_decision_tree_summary(
1349
+ self,
1350
+ data: pl.DataFrame,
1351
+ feature_column_names: Sequence[str],
1352
+ label_column_name: str,
1353
+ max_depth: int,
1354
+ class_names: Sequence[str] | None,
1378
1355
  ):
1379
- match self._execute(op.source, context):
1380
- case Ok(source_lfm):
1381
- pass
1382
- case err:
1383
- return err
1384
-
1385
- df_input = source_lfm.data.collect()
1386
- dataframe = df_input.select(
1387
- list({*op.feature_column_names, op.label_column_name})
1388
- )
1356
+ dataframe = data.select(list({*feature_column_names, label_column_name}))
1389
1357
  boolean_columns = [
1390
1358
  name
1391
1359
  for name, dtype in dataframe.schema.items()
1392
- if dtype == pl.Boolean() and name in op.feature_column_names
1360
+ if dtype == pl.Boolean() and name in feature_column_names
1393
1361
  ]
1394
1362
 
1395
1363
  # Drop Nan and Null and infinite rows as not supported by decision tree
1396
1364
  dataframe = dataframe.with_columns(
1397
- *[pl.col(col).cast(pl.Float32) for col in op.feature_column_names]
1365
+ *[pl.col(col).cast(pl.Float32) for col in feature_column_names]
1398
1366
  )
1399
1367
  dataframe = dataframe.drop_nans().drop_nulls()
1400
1368
  try:
@@ -1407,9 +1375,8 @@ class InMemoryExecutor(OpGraphExecutor):
1407
1375
  return InvalidArgumentError(
1408
1376
  "a minimum of 1 sample is required by DecisionTreeClassifier"
1409
1377
  )
1410
- features = dataframe[op.feature_column_names]
1411
- classes = dataframe[op.label_column_name]
1412
- max_depth = op.max_depth
1378
+ features = dataframe[feature_column_names]
1379
+ classes = dataframe[label_column_name]
1413
1380
 
1414
1381
  from sklearn.tree import DecisionTreeClassifier, export_graphviz, export_text
1415
1382
  from sklearn.utils.multiclass import check_classification_targets
@@ -1427,15 +1394,15 @@ class InMemoryExecutor(OpGraphExecutor):
1427
1394
 
1428
1395
  tree_str = export_text(
1429
1396
  decision_tree=decision_tree,
1430
- feature_names=op.feature_column_names,
1431
- class_names=op.classes_names,
1397
+ feature_names=feature_column_names,
1398
+ class_names=class_names,
1432
1399
  max_depth=max_depth,
1433
1400
  )
1434
1401
 
1435
1402
  tree_graphviz = export_graphviz(
1436
1403
  decision_tree=decision_tree,
1437
- feature_names=op.feature_column_names,
1438
- class_names=op.classes_names,
1404
+ feature_names=feature_column_names,
1405
+ class_names=class_names,
1439
1406
  max_depth=max_depth,
1440
1407
  )
1441
1408
 
@@ -1445,16 +1412,41 @@ class InMemoryExecutor(OpGraphExecutor):
1445
1412
  )
1446
1413
  tree_str = tree_str.replace(f"{boolean_column} > 0.50", boolean_column)
1447
1414
 
1415
+ return Ok(table_pb2.DecisionTreeSummary(text=tree_str, graphviz=tree_graphviz))
1416
+
1417
+ async def _execute_add_decision_tree_summary(
1418
+ self, op: op_graph.op.AddDecisionTreeSummary, context: _InMemoryExecutionContext
1419
+ ):
1420
+ match await self._execute(op.source, context):
1421
+ case Ok(source_lfm):
1422
+ pass
1423
+ case err:
1424
+ return err
1448
1425
  metrics = source_lfm.metrics.copy()
1449
- metrics[op.output_metric_key] = table_pb2.DecisionTreeSummary(
1450
- text=tree_str, graphviz=tree_graphviz
1426
+ source_df, summary_result = await context.run_on_worker(
1427
+ _collect_and_apply,
1428
+ source_lfm.data,
1429
+ lambda df: self._compute_decision_tree_summary(
1430
+ df,
1431
+ op.feature_column_names,
1432
+ op.label_column_name,
1433
+ op.max_depth,
1434
+ op.classes_names,
1435
+ ),
1451
1436
  )
1452
- return Ok(_LazyFrameWithMetrics(df_input.lazy(), metrics=metrics))
1453
1437
 
1454
- def _execute_unnest_list(
1438
+ match summary_result:
1439
+ case InvalidArgumentError() | InternalError() as err:
1440
+ return err
1441
+ case Ok(tree_summary):
1442
+ metrics[op.output_metric_key] = tree_summary
1443
+
1444
+ return Ok(_LazyFrameWithMetrics(source_df.lazy(), metrics=metrics))
1445
+
1446
+ async def _execute_unnest_list(
1455
1447
  self, op: op_graph.op.UnnestList, context: _InMemoryExecutionContext
1456
1448
  ):
1457
- return self._compute_source_then_apply(
1449
+ return await self._compute_source_then_apply(
1458
1450
  op.source,
1459
1451
  lambda lf: lf.with_columns(
1460
1452
  pl.col(op.list_column_name).list.get(i).alias(column_name)
@@ -1463,47 +1455,42 @@ class InMemoryExecutor(OpGraphExecutor):
1463
1455
  context,
1464
1456
  )
1465
1457
 
1466
- def _execute_sample_rows(
1458
+ async def _execute_sample_rows(
1467
1459
  self, op: op_graph.op.SampleRows, context: _InMemoryExecutionContext
1468
1460
  ):
1469
- match self._execute(op.source, context):
1461
+ match await self._execute(op.source, context):
1470
1462
  case Ok(source_lfm):
1471
1463
  pass
1472
1464
  case err:
1473
1465
  return err
1474
- source_df = source_lfm.data.collect()
1475
- n = min(op.num_rows, source_df.shape[0])
1476
- sample_strategy = op.sample_strategy
1477
- match sample_strategy:
1478
- case op_graph.sample_strategy.UniformRandom():
1479
- result_df = source_df.sample(
1480
- n=n,
1481
- seed=sample_strategy.seed,
1482
- )
1483
1466
 
1484
- return Ok(
1485
- _LazyFrameWithMetrics(
1486
- result_df.lazy(),
1487
- source_lfm.metrics,
1488
- )
1467
+ def sample(df: pl.DataFrame):
1468
+ match op.sample_strategy:
1469
+ case op_graph.sample_strategy.UniformRandom():
1470
+ return df.sample(
1471
+ min(op.num_rows, df.shape[0]), seed=op.sample_strategy.seed
1472
+ )
1473
+
1474
+ _, result_df = await context.run_on_worker(
1475
+ _collect_and_apply, source_lfm.data, sample
1489
1476
  )
1490
1477
 
1491
- def _execute_describe_columns(
1478
+ return Ok(_LazyFrameWithMetrics(result_df.lazy(), source_lfm.metrics))
1479
+
1480
+ async def _execute_describe_columns(
1492
1481
  self, op: op_graph.op.DescribeColumns, context: _InMemoryExecutionContext
1493
1482
  ):
1494
- match self._execute(op.source, context):
1483
+ match await self._execute(op.source, context):
1495
1484
  case Ok(source_lfm):
1496
1485
  pass
1497
1486
  case err:
1498
1487
  return err
1499
- source_df = source_lfm.data.collect()
1500
- return Ok(
1501
- source_lfm.with_data(
1502
- source_df.describe()
1503
- .lazy()
1504
- .rename({"statistic": op.statistic_column_name})
1505
- )
1488
+ _, result_df = await context.run_on_worker(
1489
+ _collect_and_apply,
1490
+ source_lfm.data,
1491
+ lambda df: df.describe().rename({"statistic": op.statistic_column_name}),
1506
1492
  )
1493
+ return Ok(source_lfm.with_data(result_df.lazy()))
1507
1494
 
1508
1495
  def _has_partially_computed_data(
1509
1496
  self, op: op_graph.Op, context: _InMemoryExecutionContext
@@ -1517,7 +1504,7 @@ class InMemoryExecutor(OpGraphExecutor):
1517
1504
  for sub_source in flatten(source.sources() for source in op.sources())
1518
1505
  )
1519
1506
 
1520
- def _do_execute( # noqa: C901
1507
+ async def _do_execute( # noqa: C901
1521
1508
  self,
1522
1509
  op: op_graph.Op,
1523
1510
  context: _InMemoryExecutionContext,
@@ -1552,8 +1539,13 @@ class InMemoryExecutor(OpGraphExecutor):
1552
1539
  )
1553
1540
  case Ok(query):
1554
1541
  pass
1555
- return self._staging_db.run_select_query(
1556
- query, expected_schema, context.current_slice_args
1542
+ return (
1543
+ await context.run_on_worker(
1544
+ self._staging_db.run_select_query,
1545
+ query,
1546
+ expected_schema,
1547
+ context.current_slice_args,
1548
+ )
1557
1549
  ).map(
1558
1550
  lambda rbr: _LazyFrameWithMetrics(
1559
1551
  _as_df(rbr, expected_schema).lazy(),
@@ -1571,80 +1563,80 @@ class InMemoryExecutor(OpGraphExecutor):
1571
1563
  case op_graph.op.ReadFromParquet():
1572
1564
  return self._execute_read_from_parquet(op, context)
1573
1565
  case op_graph.op.RenameColumns():
1574
- return self._execute_rename_columns(op, context)
1566
+ return await self._execute_rename_columns(op, context)
1575
1567
  case op_graph.op.Join():
1576
- return self._execute_join(op, context)
1568
+ return await self._execute_join(op, context)
1577
1569
  case op_graph.op.SelectColumns():
1578
- return self._execute_select_columns(op, context)
1570
+ return await self._execute_select_columns(op, context)
1579
1571
  case op_graph.op.LimitRows():
1580
- return self._execute_limit_rows(op, context)
1572
+ return await self._execute_limit_rows(op, context)
1581
1573
  case op_graph.op.OffsetRows():
1582
- return self._execute_offset_rows(op, context)
1574
+ return await self._execute_offset_rows(op, context)
1583
1575
  case op_graph.op.OrderBy():
1584
- return self._execute_order_by(op, context)
1576
+ return await self._execute_order_by(op, context)
1585
1577
  case op_graph.op.FilterRows():
1586
- return self._execute_filter_rows(op, context)
1578
+ return await self._execute_filter_rows(op, context)
1587
1579
  case op_graph.op.DistinctRows():
1588
- return self._execute_distinct_rows(op, context)
1580
+ return await self._execute_distinct_rows(op, context)
1589
1581
  case (
1590
1582
  op_graph.op.SetMetadata()
1591
1583
  | op_graph.op.UpdateMetadata()
1592
1584
  | op_graph.op.RemoveFromMetadata()
1593
1585
  | op_graph.op.UpdateFeatureTypes()
1594
1586
  ):
1595
- return self._execute(op.source, context)
1587
+ return await self._execute(op.source, context)
1596
1588
  case op_graph.op.EmbeddingMetrics() as op:
1597
- return self._execute_embedding_metrics(op, context)
1589
+ return await self._execute_embedding_metrics(op, context)
1598
1590
  case op_graph.op.EmbeddingCoordinates():
1599
- return self._execute_embedding_coordinates(op, context)
1591
+ return await self._execute_embedding_coordinates(op, context)
1600
1592
  case op_graph.op.RollupByAggregation() as op:
1601
- return self._execute_rollup_by_aggregation(op, context)
1593
+ return await self._execute_rollup_by_aggregation(op, context)
1602
1594
  case op_graph.op.Empty():
1603
- return self._execute_empty(op, context)
1595
+ return await self._execute_empty(op, context)
1604
1596
  case op_graph.op.EmbedNode2vecFromEdgeLists():
1605
- return self._execute_embed_node2vec_from_edge_lists(op, context)
1597
+ return await self._execute_embed_node2vec_from_edge_lists(op, context)
1606
1598
  case op_graph.op.Concat():
1607
- return self._execute_concat(op, context)
1599
+ return await self._execute_concat(op, context)
1608
1600
  case op_graph.op.UnnestStruct():
1609
- return self._execute_unnest_struct(op, context)
1601
+ return await self._execute_unnest_struct(op, context)
1610
1602
  case op_graph.op.NestIntoStruct():
1611
- return self._execute_nest_into_struct(op, context)
1603
+ return await self._execute_nest_into_struct(op, context)
1612
1604
  case op_graph.op.AddLiteralColumn():
1613
- return self._execute_add_literal_column(op, context)
1605
+ return await self._execute_add_literal_column(op, context)
1614
1606
  case op_graph.op.CombineColumns():
1615
- return self._execute_combine_columns(op, context)
1607
+ return await self._execute_combine_columns(op, context)
1616
1608
  case op_graph.op.EmbedColumn():
1617
- return self._execute_embed_column(op, context)
1609
+ return await self._execute_embed_column(op, context)
1618
1610
  case op_graph.op.EncodeColumns():
1619
- return self._execute_encode_columns(op, context)
1611
+ return await self._execute_encode_columns(op, context)
1620
1612
  case op_graph.op.AggregateColumns():
1621
- return self._execute_aggregate_columns(op, context)
1613
+ return await self._execute_aggregate_columns(op, context)
1622
1614
  case op_graph.op.CorrelateColumns():
1623
- return self._execute_correlate_columns(op, context)
1615
+ return await self._execute_correlate_columns(op, context)
1624
1616
  case op_graph.op.HistogramColumn():
1625
- return self._execute_histogram_column(op, context)
1617
+ return await self._execute_histogram_column(op, context)
1626
1618
  case op_graph.op.ConvertColumnToString():
1627
- return self._execute_convert_column_to_string(op, context)
1619
+ return await self._execute_convert_column_to_string(op, context)
1628
1620
  case op_graph.op.AddRowIndex():
1629
- return self._execute_add_row_index(op, context)
1621
+ return await self._execute_add_row_index(op, context)
1630
1622
  case op_graph.op.OutputCsv():
1631
- return self._execute_output_csv(op, context)
1623
+ return await self._execute_output_csv(op, context)
1632
1624
  case op_graph.op.TruncateList():
1633
- return self._execute_truncate_list(op, context)
1625
+ return await self._execute_truncate_list(op, context)
1634
1626
  case op_graph.op.Union():
1635
- return self._execute_union(op, context)
1627
+ return await self._execute_union(op, context)
1636
1628
  case op_graph.op.EmbedImageColumn():
1637
- return self._execute_embed_image_column(op, context)
1629
+ return await self._execute_embed_image_column(op, context)
1638
1630
  case op_graph.op.AddDecisionTreeSummary():
1639
- return self._execute_add_decision_tree_summary(op, context)
1631
+ return await self._execute_add_decision_tree_summary(op, context)
1640
1632
  case op_graph.op.UnnestList():
1641
- return self._execute_unnest_list(op, context)
1633
+ return await self._execute_unnest_list(op, context)
1642
1634
  case op_graph.op.SampleRows():
1643
- return self._execute_sample_rows(op, context)
1635
+ return await self._execute_sample_rows(op, context)
1644
1636
  case op_graph.op.DescribeColumns():
1645
- return self._execute_describe_columns(op, context)
1637
+ return await self._execute_describe_columns(op, context)
1646
1638
 
1647
- def _execute(
1639
+ async def _execute(
1648
1640
  self,
1649
1641
  op: op_graph.Op,
1650
1642
  context: _InMemoryExecutionContext,
@@ -1681,7 +1673,7 @@ class InMemoryExecutor(OpGraphExecutor):
1681
1673
 
1682
1674
  try:
1683
1675
  _logger.info("starting op execution")
1684
- maybe_lfm = self._do_execute(op=op, context=context)
1676
+ maybe_lfm = await self._do_execute(op=op, context=context)
1685
1677
  finally:
1686
1678
  _logger.info("op execution complete")
1687
1679
  match maybe_lfm:
@@ -1703,29 +1695,33 @@ class InMemoryExecutor(OpGraphExecutor):
1703
1695
  context.computed_batches_for_op_graph[sliced_table] = lfm
1704
1696
  return Ok(lfm)
1705
1697
 
1706
- def execute(
1707
- self, context: ExecutionContext
1698
+ async def execute(
1699
+ self,
1700
+ context: ExecutionContext,
1701
+ worker_threads: ThreadPoolExecutor | None = None,
1708
1702
  ) -> (
1709
1703
  Ok[ExecutionResult]
1710
1704
  | InvalidArgumentError
1711
1705
  | InternalError
1712
1706
  | ResourceExhaustedError
1713
1707
  ):
1714
- with _InMemoryExecutionContext(context) as in_memory_context:
1708
+ with _InMemoryExecutionContext(context, worker_threads) as in_memory_context:
1715
1709
  for table_context in context.tables_to_compute:
1716
1710
  in_memory_context.current_output_context = table_context
1717
1711
  sliced_table = _SlicedTable(
1718
1712
  table_context.table_op_graph, table_context.sql_output_slice_args
1719
1713
  )
1720
1714
  if sliced_table not in in_memory_context.computed_batches_for_op_graph:
1721
- match self._execute(sliced_table.op_graph, in_memory_context):
1715
+ match await self._execute(sliced_table.op_graph, in_memory_context):
1722
1716
  case Ok():
1723
1717
  pass
1724
1718
  case err:
1725
1719
  return err
1726
1720
  args_lfm_iterator = in_memory_context.computed_batches_for_op_graph.items()
1727
1721
  computed_tables = {
1728
- slice_args: _SchemaAndBatches.from_lazy_frame_with_metrics(lfm)
1722
+ slice_args: await _SchemaAndBatches.from_lazy_frame_with_metrics(
1723
+ lfm, worker_threads
1724
+ )
1729
1725
  for slice_args, lfm in args_lfm_iterator
1730
1726
  }
1731
1727