corvic_engine-0.3.0rc66-cp38-abi3-win_amd64.whl → corvic_engine-0.3.0rc68-cp38-abi3-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- corvic/engine/_native.pyd +0 -0
- corvic/model/_base_model.py +3 -4
- corvic/model/_completion_model.py +2 -4
- corvic/model/_feature_view.py +5 -6
- corvic/model/_pipeline.py +1 -2
- corvic/model/_resource.py +1 -2
- corvic/model/_source.py +1 -2
- corvic/model/_space.py +1 -2
- corvic/orm/base.py +4 -5
- corvic/orm/ids.py +1 -2
- corvic/orm/mixins.py +18 -9
- corvic/pa_scalar/_temporal.py +1 -1
- corvic/result/__init__.py +1 -2
- corvic/sql/parse_ops.py +5 -1
- corvic/system/_column_encoding.py +215 -0
- corvic/system/_embedder.py +24 -2
- corvic/system/_image_embedder.py +38 -0
- corvic/system/_planner.py +6 -3
- corvic/system/_text_embedder.py +21 -0
- corvic/system/client.py +2 -1
- corvic/system/in_memory_executor.py +503 -507
- corvic/system/op_graph_executor.py +7 -3
- corvic/system/storage.py +1 -3
- corvic/table/table.py +5 -5
- {corvic_engine-0.3.0rc66.dist-info → corvic_engine-0.3.0rc68.dist-info}/METADATA +3 -4
- {corvic_engine-0.3.0rc66.dist-info → corvic_engine-0.3.0rc68.dist-info}/RECORD +28 -27
- {corvic_engine-0.3.0rc66.dist-info → corvic_engine-0.3.0rc68.dist-info}/WHEEL +0 -0
- {corvic_engine-0.3.0rc66.dist-info → corvic_engine-0.3.0rc68.dist-info}/licenses/LICENSE +0 -0
corvic/system/in_memory_executor.py
@@ -2,14 +2,15 @@

 from __future__ import annotations

+import asyncio
 import dataclasses
 import datetime
 import functools
-import math
-from
+from collections.abc import Callable, Mapping, MutableMapping, Sequence
+from concurrent.futures import ThreadPoolExecutor
 from contextlib import AbstractContextManager, ExitStack, nullcontext
 from types import TracebackType
-from typing import Any, Final, cast
+from typing import Any, Final, ParamSpec, Self, TypeVar, cast

 import numpy as np
 import polars as pl
@@ -19,8 +20,9 @@ import pyarrow.parquet as pq
 import structlog
 from google.protobuf import json_format, struct_pb2
 from more_itertools import flatten
-from typing_extensions import
+from typing_extensions import deprecated

+import corvic.system._column_encoding as column_encoding
 from corvic import embed, embedding_metric, op_graph, sql
 from corvic.result import (
     InternalError,
@@ -49,48 +51,51 @@ from corvic_generated.orm.v1 import table_pb2

 _logger = structlog.get_logger()

-"""Reference and Maximum number of years for normalizing year in Datetime encoder"""
-REFERENCE_YEAR: Final = 1900
-MAX_NUMBER_OF_YEARS: Final = 200
-
 _MIN_EMBEDDINGS_FOR_EMBEDDINGS_SUMMARY: Final = 3


+def _collect_and_apply(
+    data: pl.LazyFrame, func: Callable[[pl.DataFrame], _R]
+) -> tuple[pl.DataFrame, _R]:
+    data_df = data.collect()
+    return data_df, func(data_df)
+
+
 def get_polars_embedding_length(
-
+    embedding_series: pl.Series,
 ) -> Ok[int] | InvalidArgumentError:
-    outer_type =
+    outer_type = embedding_series.dtype
     if isinstance(outer_type, pl.Array):
         return Ok(outer_type.shape[0])
     if not isinstance(outer_type, pl.List):
         return InvalidArgumentError("invalid embedding datatype", dtype=str(outer_type))
-    if len(
+    if len(embedding_series) == 0:
         return InvalidArgumentError(
             "cannot infer embedding length for empty embedding set"
         )
-    embedding_length = len(
+    embedding_length = len(embedding_series[0])
     if embedding_length < 1:
         return InvalidArgumentError("invalid embedding length", length=embedding_length)
     return Ok(embedding_length)


 def get_polars_embedding(
-
+    embedding_series: pl.Series,
 ) -> Ok[np.ndarray[Any, Any]] | InvalidArgumentError:
-    outer_type =
+    outer_type = embedding_series.dtype
     if isinstance(outer_type, pl.Array):
-        return Ok(
+        return Ok(embedding_series.to_numpy())
     if not isinstance(outer_type, pl.List):
         return InvalidArgumentError("invalid embedding datatype", dtype=str(outer_type))
-    match get_polars_embedding_length(
+    match get_polars_embedding_length(embedding_series):
         case Ok(embedding_length):
             pass
         case InvalidArgumentError() as err:
             return err
     return Ok(
-
-
-        .to_numpy()
+        embedding_series.cast(
+            pl.Array(inner=outer_type.inner, shape=embedding_length)
+        ).to_numpy()
     )

@@ -192,8 +197,15 @@ class _SchemaAndBatches:
     metrics: dict[str, Any]

     @classmethod
-    def from_lazy_frame_with_metrics(
-
+    async def from_lazy_frame_with_metrics(
+        cls, lfm: _LazyFrameWithMetrics, worker_threads: ThreadPoolExecutor | None
+    ):
+        return cls.from_dataframe(
+            await asyncio.get_running_loop().run_in_executor(
+                worker_threads, lfm.data.collect
+            ),
+            lfm.metrics,
+        )

     def to_batch_reader(self):
         return pa.RecordBatchReader.from_batches(
@@ -221,6 +233,10 @@ class _SchemaAndBatches:
         return cls(schema, table.to_batches(), metrics)


+_P = ParamSpec("_P")
+_R = TypeVar("_R")
+
+
 @dataclasses.dataclass(frozen=True)
 class _SlicedTable:
     op_graph: op_graph.Op
@@ -230,6 +246,7 @@ class _SlicedTable:
 @dataclasses.dataclass
 class _InMemoryExecutionContext(AbstractContextManager["_InMemoryExecutionContext"]):
     exec_context: ExecutionContext
+    worker_threads: ThreadPoolExecutor | None
     current_output_context: TableComputeContext | None = None

     # Using _SchemaAndBatches rather than a RecordBatchReader since the latter's
@@ -239,6 +256,7 @@ class _InMemoryExecutionContext(AbstractContextManager["_InMemoryExecutionContext"]):
         dataclasses.field(default_factory=dict)
     )
     exit_stack: ExitStack = dataclasses.field(default_factory=ExitStack)
+    lock: asyncio.Lock = dataclasses.field(default_factory=asyncio.Lock)

     def __enter__(self) -> Self:
         self.exit_stack = self.exit_stack.__enter__()
@@ -252,6 +270,15 @@ class _InMemoryExecutionContext(AbstractContextManager["_InMemoryExecutionContext"]):
     ) -> bool | None:
         return self.exit_stack.__exit__(__exc_type, __exc_value, __traceback)

+    async def run_on_worker(
+        self, func: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs
+    ) -> _R:
+        # lock here because polars operations aren't guaranteed to be independent
+        async with self.lock:
+            return await asyncio.get_running_loop().run_in_executor(
+                self.worker_threads, lambda: func(*args, **kwargs)
+            )
+
     @classmethod
     def count_source_op_uses(
         cls,
@@ -407,55 +434,55 @@ class InMemoryExecutor(OpGraphExecutor):
         )
         return Ok(_LazyFrameWithMetrics(data, metrics={}))

-    def _execute_rollup_by_aggregation(
+    async def _execute_rollup_by_aggregation(
         self, op: op_graph.op.RollupByAggregation, context: _InMemoryExecutionContext
     ) -> Ok[_LazyFrameWithMetrics]:
         raise NotImplementedError(
             "rollup by aggregation outside of sql not implemented"
         )

-    def _compute_source_then_apply(
+    async def _compute_source_then_apply(
         self,
         source: op_graph.Op,
         lf_op: Callable[[pl.LazyFrame], pl.LazyFrame],
         context: _InMemoryExecutionContext,
     ):
-        return self._execute(source, context).map(
+        return (await self._execute(source, context)).map(
             lambda source_lfm: source_lfm.apply(lf_op)
         )

-    def _execute_rename_columns(
+    async def _execute_rename_columns(
         self, op: op_graph.op.RenameColumns, context: _InMemoryExecutionContext
     ):
-        return self._compute_source_then_apply(
+        return await self._compute_source_then_apply(
             op.source, lambda lf: lf.rename(dict(op.old_name_to_new)), context
         )

-    def _execute_select_columns(
+    async def _execute_select_columns(
         self, op: op_graph.op.SelectColumns, context: _InMemoryExecutionContext
     ):
-        return self._compute_source_then_apply(
+        return await self._compute_source_then_apply(
             op.source, lambda lf: lf.select(op.columns), context
         )

-    def _execute_limit_rows(
+    async def _execute_limit_rows(
         self, op: op_graph.op.LimitRows, context: _InMemoryExecutionContext
     ):
-        return self._compute_source_then_apply(
+        return await self._compute_source_then_apply(
             op.source, lambda lf: lf.limit(op.num_rows), context
         )

-    def _execute_offset_rows(
+    async def _execute_offset_rows(
         self, op: op_graph.op.OffsetRows, context: _InMemoryExecutionContext
     ):
-        return self._compute_source_then_apply(
+        return await self._compute_source_then_apply(
             op.source, lambda lf: lf.slice(op.num_rows), context
         )

-    def _execute_order_by(
+    async def _execute_order_by(
         self, op: op_graph.op.OrderBy, context: _InMemoryExecutionContext
     ):
-        return self._compute_source_then_apply(
+        return await self._compute_source_then_apply(
             op.source, lambda lf: lf.sort(op.columns, descending=op.desc), context
         )

@@ -536,7 +563,7 @@ class InMemoryExecutor(OpGraphExecutor):
             case op_graph.row_filter.CombineFilters():
                 return self._row_filter_combination_to_condition(row_filter)

-    def _execute_filter_rows(
+    async def _execute_filter_rows(
         self, op: op_graph.op.FilterRows, context: _InMemoryExecutionContext
     ):
         match self._row_filter_to_condition(op.row_filter):
@@ -544,78 +571,129 @@ class InMemoryExecutor(OpGraphExecutor):
                 return InternalError.from_(err)
             case Ok(row_filter):
                 pass
-        return self._compute_source_then_apply(
+        return await self._compute_source_then_apply(
             op.source, lambda lf: lf.filter(row_filter), context
         )

-    def
+    def _get_embedding_column_if(
+        self,
+        data: pl.LazyFrame,
+        embedding_column_name: str,
+        pred: Callable[[pl.DataFrame], bool],
+    ) -> Ok[tuple[pl.DataFrame, np.ndarray[Any, Any] | None]] | InvalidArgumentError:
+        data_df = data.collect()
+        if pred(data_df):
+            match get_polars_embedding(data_df[embedding_column_name]):
+                case InvalidArgumentError() as err:
+                    return err
+                case Ok(embeddings):
+                    pass
+        else:
+            embeddings = None
+        return Ok((data_df, embeddings))
+
+    async def _execute_embedding_metrics(  # noqa: C901
         self, op: op_graph.op.EmbeddingMetrics, context: _InMemoryExecutionContext
     ):
-        match self._execute(op.table, context):
-            case Ok(source_lfm):
-                pass
-            case err:
-                return err
-        embedding_df = source_lfm.data.collect()
-
-        if len(embedding_df) < _MIN_EMBEDDINGS_FOR_EMBEDDINGS_SUMMARY:
-            # downstream consumers handle empty metadata by substituting their
-            # own values
-            return Ok(
-                _LazyFrameWithMetrics(embedding_df.lazy(), metrics=source_lfm.metrics)
-            )
-
         # before it was configurable, this op assumed that the column's name was
         # this hardcoded name
         embedding_column_name = op.embedding_column_name or "embedding"
-
-
+
+        match await self._execute(op.table, context):
+            case Ok(source_lfm):
                 pass
+            case err:
+                return err
+        match await context.run_on_worker(
+            self._get_embedding_column_if,
+            source_lfm.data,
+            embedding_column_name,
+            lambda df: len(df) >= _MIN_EMBEDDINGS_FOR_EMBEDDINGS_SUMMARY,
+        ):
             case InvalidArgumentError() as err:
-                return
+                return err
+            case Ok(result):
+                embedding_df, embeddings = result
+
+        if embeddings is None:
+            return Ok(_LazyFrameWithMetrics(embedding_df.lazy(), source_lfm.metrics))

         metrics = source_lfm.metrics.copy()
-
+        async with asyncio.TaskGroup() as tg:
+            ne_sum = tg.create_task(
+                context.run_on_worker(
+                    embedding_metric.ne_sum, embeddings, normalize=True
+                )
+            )
+            condition_number = tg.create_task(
+                context.run_on_worker(
+                    embedding_metric.condition_number, embeddings, normalize=True
+                )
+            )
+            rcondition_number = tg.create_task(
+                context.run_on_worker(
+                    embedding_metric.rcondition_number, embeddings, normalize=True
+                )
+            )
+            stable_rank = tg.create_task(
+                context.run_on_worker(
+                    embedding_metric.stable_rank, embeddings, normalize=True
+                )
+            )
+
+        match ne_sum.result():
             case Ok(metric):
                 metrics["ne_sum"] = metric
             case InvalidArgumentError() as err:
                 _logger.warning("could not compute ne_sum", exc_info=str(err))
-
+
+        match condition_number.result():
             case Ok(metric):
                 metrics["condition_number"] = metric
             case InvalidArgumentError() as err:
                 _logger.warning("could not compute condition_number", exc_info=str(err))
-
+
+        match rcondition_number.result():
             case Ok(metric):
                 metrics["rcondition_number"] = metric
             case InvalidArgumentError() as err:
                 _logger.warning(
                     "could not compute rcondition_number", exc_info=str(err)
                 )
-        match
+        match stable_rank.result():
             case Ok(metric):
                 metrics["stable_rank"] = metric
             case InvalidArgumentError() as err:
                 _logger.warning("could not compute stable_rank", exc_info=str(err))
+
         return Ok(_LazyFrameWithMetrics(embedding_df.lazy(), metrics=metrics))

-    def _execute_embedding_coordinates(
+    async def _execute_embedding_coordinates(
         self, op: op_graph.op.EmbeddingCoordinates, context: _InMemoryExecutionContext
     ):
-        match self._execute(op.table, context):
+        match await self._execute(op.table, context):
             case Ok(source_lfm):
                 pass
             case err:
                 return err
-        embedding_df = source_lfm.data.collect()

         # before it was configurable, this op assumed that the column's name was
         # this hardcoded name
         embedding_column_name = op.embedding_column_name or "embedding"
+        match await context.run_on_worker(
+            self._get_embedding_column_if,
+            source_lfm.data,
+            embedding_column_name,
+            lambda df: len(df) >= _MIN_EMBEDDINGS_FOR_EMBEDDINGS_SUMMARY,
+        ):
+            case InvalidArgumentError() as err:
+                return err
+            case Ok(result):
+                embedding_df, embeddings = result

         # the neighbors of a point includes itself. That does mean, that an n_neighbors
         # value of less than 3 simply does not work
-        if
+        if embeddings is None:
             coordinates_df = embedding_df.lazy().with_columns(
                 pl.Series(
                     name=embedding_column_name,
@@ -625,14 +703,11 @@ class InMemoryExecutor(OpGraphExecutor):
             )
             return Ok(_LazyFrameWithMetrics(coordinates_df, source_lfm.metrics))

-        match
-
-
-
-
-
-        match self._dimension_reducer.reduce_dimensions(
-            embedding, op.n_components, op.metric
+        match await context.run_on_worker(
+            self._dimension_reducer.reduce_dimensions,
+            embeddings,
+            op.n_components,
+            op.metric,
         ):
             case Ok(coordinates):
                 pass
@@ -648,26 +723,37 @@ class InMemoryExecutor(OpGraphExecutor):
             )
         return Ok(_LazyFrameWithMetrics(coordinates_df, source_lfm.metrics))

-    def _execute_distinct_rows(
+    async def _execute_distinct_rows(
         self, op: op_graph.op.DistinctRows, context: _InMemoryExecutionContext
     ):
-        return self.
-            lambda
-                source_lfm.data.unique(), source_lfm.metrics
-            )
+        return await self._compute_source_then_apply(
+            op.source, lambda source: source.unique(), context
         )

-    def _execute_join(
-
+    async def _execute_join(
+        self, op: op_graph.op.Join, context: _InMemoryExecutionContext
+    ):
+        async with asyncio.TaskGroup() as tg:
+            left_task = tg.create_task(self._execute(op.left_source, context))
+            right_task = tg.create_task(self._execute(op.right_source, context))
+        match left_task.result():
+            case (
+                InternalError()
+                | ResourceExhaustedError()
+                | InvalidArgumentError() as err
+            ):
+                return err
             case Ok(left_lfm):
                 pass
-
+        match right_task.result():
+            case (
+                InternalError()
+                | ResourceExhaustedError()
+                | InvalidArgumentError() as err
+            ):
                 return err
-        match self._execute(op.right_source, context):
             case Ok(right_lfm):
                 pass
-            case err:
-                return err
         left_lf = left_lfm.data
         right_lf = right_lfm.data

@@ -702,20 +788,31 @@ class InMemoryExecutor(OpGraphExecutor):
             )
         )

-    def _execute_empty(
+    async def _execute_empty(
+        self, op: op_graph.op.Empty, context: _InMemoryExecutionContext
+    ):
         empty_table = cast(pl.DataFrame, pl.from_arrow(pa.schema([]).empty_table()))
         return Ok(_LazyFrameWithMetrics(empty_table.lazy(), metrics={}))

-    def _execute_concat(
+    async def _execute_concat(
         self, op: op_graph.op.Concat, context: _InMemoryExecutionContext
     ):
+        async with asyncio.TaskGroup() as tg:
+            tasks = [
+                tg.create_task(self._execute(table, context)) for table in op.tables
+            ]
+
         source_lfms = list[_LazyFrameWithMetrics]()
-        for
-            match
-                case
-
-
+        for task in tasks:
+            match task.result():
+                case (
+                    InternalError()
+                    | ResourceExhaustedError()
+                    | InvalidArgumentError() as err
+                ):
                     return err
+                case Ok(lfm):
+                    source_lfms.append(lfm)
         data = pl.concat([lfm.data for lfm in source_lfms], how=op.how)
         metrics = dict[str, Any]()
         for lfm in source_lfms:
@@ -794,16 +891,19 @@ class InMemoryExecutor(OpGraphExecutor):
             context,
         )

-    def _execute_embed_column(
+    async def _execute_embed_column(
         self, op: op_graph.op.EmbedColumn, context: _InMemoryExecutionContext
     ):
-        match self._execute(op.source, context):
+        match await self._execute(op.source, context):
             case Ok(source_lfm):
                 pass
             case err:
                 return err
-        source_df =
-
+        source_df, to_embed = await context.run_on_worker(
+            _collect_and_apply,
+            source_lfm.data,
+            lambda df: df[op.column_name].cast(pl.String),
+        )

         embed_context = EmbedTextContext(
             inputs=to_embed,
@@ -813,112 +913,19 @@ class InMemoryExecutor(OpGraphExecutor):
             expected_coordinate_bitwidth=op.expected_coordinate_bitwidth,
             room_id=context.exec_context.room_id,
         )
-        match self._text_embedder.
+        match await self._text_embedder.aembed(embed_context, context.worker_threads):
             case Ok(result):
                 pass
             case InvalidArgumentError() | InternalError() as err:
                 raise InternalError("Failed to embed column") from err

-
-
-
-
-
-
-        return Ok(source_lfm.with_data(result_df))
-
-    @staticmethod
-    def get_cyclic_encoding(
-        series: pl.Series,
-        period: int,
-    ) -> tuple[pl.Series, pl.Series]:
-        sine_series = (2 * math.pi * series / period).sin().alias(f"{series.name}_sine")
-        cosine_series = (
-            (2 * math.pi * series / period).cos().alias(f"{series.name}_cosine")
+        return Ok(
+            source_lfm.with_data(
+                source_df.lazy()
+                .with_columns(result.embeddings.alias(op.embedding_column_name))
+                .drop_nulls(op.embedding_column_name)
+            )
         )
-        return sine_series, cosine_series
-
-    @staticmethod
-    def encode_datetime(series: pl.Series) -> pl.Series:
-        match series.dtype:
-            case pl.Date | pl.Time:
-                pass
-            case pl.Datetime:
-                series = series.dt.replace_time_zone("UTC")
-            case _:
-                raise ValueError("Invalid arguments, expected a datetime series")
-
-        if series.is_null().all():
-            zero_vector = pl.zeros(11, dtype=pl.Float32, eager=True)
-            return pl.Series([zero_vector] * len(series), dtype=pl.List(pl.Float32))
-
-        n = len(series)
-        year_norm = pl.zeros(n, dtype=pl.Float32, eager=True).alias("year")
-        month_sine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("month_sine")
-        month_cosine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("month_cosine")
-        day_sine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("day_sine")
-        day_cosine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("day_cosine")
-        hour_sine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("hour_sine")
-        hour_cosine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("hour_cosine")
-        minute_sine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("minute_sine")
-        minute_cosine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("minute_cosine")
-        second_sine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("second_sine")
-        second_cosine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("second_cosine")
-
-        if series.dtype in [pl.Date, pl.Datetime]:
-            try:
-                year = series.dt.year().cast(pl.Float32).alias("year")
-                month = series.dt.month().cast(pl.Float32).alias("month")
-                day = series.dt.day().cast(pl.Float32).alias("day")
-
-                year_norm = (year - REFERENCE_YEAR) / MAX_NUMBER_OF_YEARS
-                month_sine, month_cosine = InMemoryExecutor.get_cyclic_encoding(
-                    month, 12
-                )
-                day_sine, day_cosine = InMemoryExecutor.get_cyclic_encoding(day, 31)
-            except pl.exceptions.PanicException as e:
-                _logger.exception("Error extracting datetime", exc_info=e)
-
-        if series.dtype in [pl.Time, pl.Datetime]:
-            try:
-                hour = series.dt.hour().cast(pl.Float32).alias("hour")
-                minute = series.dt.minute().cast(pl.Float32).alias("minute")
-                second = series.dt.second().cast(pl.Float32).alias("second")
-
-                hour_sine, hour_cosine = InMemoryExecutor.get_cyclic_encoding(hour, 24)
-                minute_sine, minute_cosine = InMemoryExecutor.get_cyclic_encoding(
-                    minute, 60
-                )
-                second_sine, second_cosine = InMemoryExecutor.get_cyclic_encoding(
-                    second, 60
-                )
-            except pl.exceptions.PanicException as e:
-                _logger.exception("Error extracting datetime", exc_info=e)
-
-        return pl.DataFrame(
-            [
-                year_norm.fill_null(0.0),
-                month_sine.fill_null(0.0),
-                month_cosine.fill_null(0.0),
-                day_sine.fill_null(0.0),
-                day_cosine.fill_null(0.0),
-                hour_sine.fill_null(0.0),
-                hour_cosine.fill_null(0.0),
-                minute_sine.fill_null(0.0),
-                minute_cosine.fill_null(0.0),
-                second_sine.fill_null(0.0),
-                second_cosine.fill_null(0.0),
-            ]
-        ).select(pl.concat_list(pl.all()).alias(series.name))[series.name]
-
-    @staticmethod
-    def encode_duration(series: pl.Series) -> pl.Series:
-        if series.dtype != pl.Duration:
-            raise ValueError("Invalid arguments, expected a duration series")
-        if series.is_null().all():
-            return pl.zeros(len(series), dtype=pl.Float32, eager=True)
-
-        return series.dt.total_seconds().cast(pl.Float32).fill_null(0.0)

     @staticmethod
     def encode_text(series: pl.Series) -> pl.Series:
@@ -939,132 +946,71 @@ class InMemoryExecutor(OpGraphExecutor):
             pl.List(pl.Float32),
         )

-    def
+    def _encode_column(  # noqa: C901
+        self, to_encode: pl.Series, encoder: op_graph.Encoder
+    ) -> tuple[pl.Series, list[str] | None]:
+        match encoder:
+            case op_graph.encoder.OneHotEncoder():
+                return column_encoding.encode_one_hot(to_encode)
+
+            case op_graph.encoder.MinMaxScaler():
+                return column_encoding.encode_min_max_scale(
+                    to_encode, encoder.feature_range_min, encoder.feature_range_max
+                ), None
+
+            case op_graph.encoder.LabelBinarizer():
+                return column_encoding.encode_label_boolean(
+                    to_encode, encoder.neg_label, encoder.pos_label
+                ), None
+
+            case op_graph.encoder.LabelEncoder():
+                return column_encoding.encode_label(
+                    to_encode, normalize=encoder.normalize
+                ), None
+
+            case op_graph.encoder.KBinsDiscretizer():
+                return column_encoding.encode_kbins(
+                    to_encode, encoder.n_bins, encoder.encode_method, encoder.strategy
+                ), None
+
+            case op_graph.encoder.Binarizer():
+                return column_encoding.encode_boolean(
+                    to_encode, encoder.threshold
+                ), None
+
+            case op_graph.encoder.MaxAbsScaler():
+                return column_encoding.encode_max_abs_scale(to_encode), None
+
+            case op_graph.encoder.StandardScaler():
+                return column_encoding.encode_standard_scale(
+                    to_encode, with_mean=encoder.with_mean, with_std=encoder.with_std
+                ), None
+
+            case op_graph.encoder.TimestampEncoder():
+                if to_encode.dtype == pl.datatypes.Duration:
+                    return column_encoding.encode_duration(to_encode), None
+                return column_encoding.encode_datetime(to_encode), None
+
+            case op_graph.encoder.TextEncoder():
+                return self.encode_text(to_encode), None
+
+    async def _execute_encode_columns(
         self, op: op_graph.op.EncodeColumns, context: _InMemoryExecutionContext
     ):
-        match self._execute(op.source, context):
+        match await self._execute(op.source, context):
             case Ok(source_lfm):
                 pass
             case err:
                 return err
-        source_df = source_lfm.data.collect
+        source_df = await context.run_on_worker(source_lfm.data.collect)
         metrics = source_lfm.metrics.copy()
         metric = metrics.get("one_hot_encoder", {})
         for encoder_arg in op.encoded_columns:
-
-
-
-
-
-                    encoded = encoded.select(
-                        pl.concat_list(pl.all())
-                        .alias(encoder_arg.encoded_column_name)
-                        .cast(pl.List(pl.Boolean))
-                    )
-
-                case op_graph.encoder.MinMaxScaler():
-                    from sklearn.preprocessing import MinMaxScaler
-
-                    encoder = MinMaxScaler(
-                        feature_range=(
-                            encoder_arg.encoder.feature_range_min,
-                            encoder_arg.encoder.feature_range_max,
-                        )
-                    )
-                    encoded = encoder.fit_transform(
-                        to_encode.to_numpy().reshape(-1, 1)
-                    ).flatten()
-
-                case op_graph.encoder.LabelBinarizer():
-                    from sklearn.preprocessing import LabelBinarizer
-
-                    encoder = LabelBinarizer(
-                        neg_label=encoder_arg.encoder.neg_label,
-                        pos_label=encoder_arg.encoder.pos_label,
-                    )
-                    encoded = encoder.fit_transform(to_encode.to_numpy().reshape(-1))
-
-                case op_graph.encoder.LabelEncoder():
-                    from sklearn.preprocessing import LabelEncoder
-
-                    encoder = LabelEncoder()
-                    encoded = encoder.fit_transform(
-                        to_encode.to_numpy().reshape(-1)
-                    ).flatten()
-                    # `classes_` is only set after fit,
-                    # Creating custom typestubs will not solve this typing issue.
-                    if encoder_arg.encoder.normalize and hasattr(encoder, "classes_"):
-                        classes_ = cast(list[int], encoder.classes_)  # pyright: ignore[reportAttributeAccessIssue, reportUnknownMemberType]
-                        max_class: int = len(classes_) - 1
-                        if max_class > 0:
-                            encoded = encoded.astype(np.float64)
-                            encoded /= max_class
-
-                case op_graph.encoder.KBinsDiscretizer():
-                    from sklearn.preprocessing import KBinsDiscretizer
-
-                    encoder = KBinsDiscretizer(
-                        n_bins=encoder_arg.encoder.n_bins,
-                        encode=encoder_arg.encoder.encode_method,
-                        strategy=encoder_arg.encoder.strategy,
-                        dtype=np.float32,
-                    )
-                    encoded = encoder.fit_transform(
-                        to_encode.to_numpy().reshape(-1, 1)
-                    ).flatten()
-
-                case op_graph.encoder.Binarizer():
-                    from sklearn.preprocessing import Binarizer
-
-                    encoder = Binarizer(
-                        threshold=encoder_arg.encoder.threshold,
-                    )
-                    encoded = encoder.fit_transform(
-                        to_encode.to_numpy().reshape(-1, 1)
-                    ).flatten()
-
-                case op_graph.encoder.MaxAbsScaler():
-                    from sklearn.preprocessing import MaxAbsScaler
-
-                    encoder = MaxAbsScaler()
-                    try:
-                        encoded = encoder.fit_transform(
-                            np.nan_to_num(to_encode.to_numpy()).reshape(-1, 1)
-                        ).flatten()
-                    except ValueError:
-                        encoded = np.array([])
-
-                case op_graph.encoder.StandardScaler():
-                    from sklearn.preprocessing import StandardScaler
-
-                    encoder = StandardScaler(
-                        with_mean=encoder_arg.encoder.with_mean,
-                        with_std=encoder_arg.encoder.with_std,
-                    )
-                    encoded = encoder.fit_transform(
-                        to_encode.to_numpy().reshape(-1, 1)
-                    ).flatten()
-
-                case op_graph.encoder.TimestampEncoder():
-                    if to_encode.dtype == pl.datatypes.Duration:
-                        encoded = self.encode_duration(to_encode)
-                    else:
-                        encoded = self.encode_datetime(to_encode)
-                    source_df = source_df.with_columns(
-                        encoded.rename(encoder_arg.encoded_column_name).cast(
-                            encoder_arg.encoder.output_dtype
-                        )
-                    )
-                    continue
-
-                case op_graph.encoder.TextEncoder():
-                    encoded = self.encode_text(to_encode)
-                    source_df = source_df.with_columns(
-                        encoded.rename(encoder_arg.encoded_column_name).cast(
-                            encoder_arg.encoder.output_dtype
-                        )
-                    )
-                    continue
+            encoded, one_hot_columns = self._encode_column(
+                source_df[encoder_arg.column_name], encoder_arg.encoder
+            )
+            if one_hot_columns is not None:
+                metric[encoder_arg.column_name] = one_hot_columns

             source_df = source_df.with_columns(
                 pl.Series(
@@ -1081,7 +1027,7 @@ class InMemoryExecutor(OpGraphExecutor):
                 )
             )

-    def _execute_embed_node2vec_from_edge_lists(
+    async def _execute_embed_node2vec_from_edge_lists(
         self,
         op: op_graph.op.EmbedNode2vecFromEdgeLists,
         context: _InMemoryExecutionContext,
@@ -1115,7 +1061,7 @@ class InMemoryExecutor(OpGraphExecutor):

         edge_list_lfms = list[_LazyFrameWithMetrics]()
         for edge_list in op.edge_list_tables:
-            match self._execute(edge_list.table, context):
+            match await self._execute(edge_list.table, context):
                 case Ok(source_lfm):
                     edge_list_lfms.append(source_lfm)
                 case err:
@@ -1129,58 +1075,65 @@ class InMemoryExecutor(OpGraphExecutor):
             end_column_type_name = entities_dtypes[end_column_name]
             metrics.update(lfm.metrics)
             yield (
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                context.run_on_worker(
+                    lfm.data.with_columns(
+                        pl.col(edge_list.start_column_name).alias(
+                            f"start_id_{start_column_type_name}"
+                        ),
+                        pl.lit(edge_list.start_entity_name).alias("start_source"),
+                        pl.col(edge_list.end_column_name).alias(
+                            f"end_id_{end_column_type_name}"
+                        ),
+                        pl.lit(edge_list.end_entity_name).alias("end_source"),
+                    )
+                    .select(
+                        f"start_id_{start_column_type_name}",
+                        "start_source",
+                        f"end_id_{end_column_type_name}",
+                        "end_source",
+                    )
+                    .collect
                 )
-                .collect()
             )

-
-            [
-                empty_edges_table,
-                *(edge_list for edge_list in edge_generator()),
-            ],
-            rechunk=False,
-            how="diagonal",
-        )
+        async with asyncio.TaskGroup() as tg:
+            edge_tasks = [tg.create_task(edge_list) for edge_list in edge_generator()]

-
-            edges=
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        def run_n2v():
+            edges = pl.concat(
+                [
+                    empty_edges_table,
+                    *(task.result() for task in edge_tasks),
+                ],
+                rechunk=False,
+                how="diagonal",
+            )
+            n2v_space = embed.Space(
+                edges=edges,
+                start_id_column_names=start_id_column_names,
+                end_id_column_names=end_id_column_names,
+                directed=True,
+            )
+            n2v_runner = embed.Node2Vec(
+                space=n2v_space,
+                dim=op.ndim,
+                walk_length=op.walk_length,
+                window=op.window,
+                p=op.p,
+                q=op.q,
+                alpha=op.alpha,
+                min_alpha=op.min_alpha,
+                negative=op.negative,
+            )
+            n2v_runner.train(epochs=op.epochs)
+            return n2v_runner.wv.to_polars().lazy()
+
+        return Ok(_LazyFrameWithMetrics(await context.run_on_worker(run_n2v), metrics))

-    def _execute_aggregate_columns(
+    async def _execute_aggregate_columns(
         self, op: op_graph.op.AggregateColumns, context: _InMemoryExecutionContext
     ):
-        match self._execute(op.source, context):
+        match await self._execute(op.source, context):
             case Ok(source_lfm):
                 pass
             case err:
@@ -1205,38 +1158,48 @@ class InMemoryExecutor(OpGraphExecutor):

         return Ok(source_lfm.with_data(aggregate))

-    def _execute_correlate_columns(
+    async def _execute_correlate_columns(
         self, op: op_graph.op.CorrelateColumns, context: _InMemoryExecutionContext
     ):
-        match self._execute(op.source, context):
+        match await self._execute(op.source, context):
             case Ok(source_lfm):
                 pass
             case err:
                 return err
-        source_df = source_lfm.data.collect()
-        with np.errstate(invalid="ignore"):
-            corr_df = source_df.select(op.column_names).corr(dtype="float32")

+        def correlate(df: pl.DataFrame):
+            with np.errstate(invalid="ignore"):
+                return df.select(op.column_names).corr(dtype="float32")
+
+        _, corr_df = await context.run_on_worker(
+            _collect_and_apply, source_lfm.data, correlate
+        )
         return Ok(source_lfm.with_data(corr_df.lazy()))

-    def _execute_histogram_column(
+    async def _execute_histogram_column(
         self, op: op_graph.op.HistogramColumn, context: _InMemoryExecutionContext
     ):
-
-
-
+        match await self._execute(op.source, context):
+            case Ok(source_lfm):
+                pass
+            case err:
+                return err
+
+        _, result_df = await context.run_on_worker(
+            _collect_and_apply,
+            source_lfm.data,
+            lambda df: df[op.column_name]
             .hist(include_category=False)
-            .lazy()
             .rename(
                 {
                     "breakpoint": op.breakpoint_column_name,
                     "count": op.count_column_name,
                 }
             ),
-            context,
         )
+        return Ok(source_lfm.with_data(result_df.lazy()))

-    def _execute_convert_column_to_string(
+    async def _execute_convert_column_to_string(
         self, op: op_graph.op.ConvertColumnToString, context: _InMemoryExecutionContext
     ):
         dtype = op.source.schema.to_polars()[op.column_name]
@@ -1248,14 +1211,15 @@ class InMemoryExecutor(OpGraphExecutor):
             raise NotImplementedError(
                 "converting struct columns to strings is not implemented"
             )
-
-
+
+        return await self._compute_source_then_apply(
+            op.source, lambda lf: lf.with_columns(cast_expr), context
         )

-    def _execute_add_row_index(
+    async def _execute_add_row_index(
         self, op: op_graph.op.AddRowIndex, context: _InMemoryExecutionContext
     ):
-        return self._compute_source_then_apply(
+        return await self._compute_source_then_apply(
             op.source,
             lambda lf: lf.with_row_index(
                 name=op.row_index_column_name, offset=op.offset
@@ -1263,70 +1227,76 @@ class InMemoryExecutor(OpGraphExecutor):
             context,
         )

-    def _execute_output_csv(
+    async def _execute_output_csv(
         self, op: op_graph.op.OutputCsv, context: _InMemoryExecutionContext
     ):
-        match self._execute(op.source, context):
+        match await self._execute(op.source, context):
             case Ok(source_lfm):
                 pass
             case err:
                 return err
-        source_df =
-
-
-
-
+        source_df, _ = await context.run_on_worker(
+            _collect_and_apply,
+            source_lfm.data,
+            lambda df: df.write_csv(
+                op.csv_url, quote_style="never", include_header=op.include_header
+            ),
         )
         return Ok(source_lfm.with_data(source_df.lazy()))

-    def _execute_truncate_list(
+    async def _execute_truncate_list(
         self, op: op_graph.op.TruncateList, context: _InMemoryExecutionContext
     ):
         # TODO(Patrick): verify this approach works for arrays
-        match self._execute(op.source, context):
+        match await self._execute(op.source, context):
             case Ok(source_lfm):
                 pass
             case err:
                 return err
-        source_df =
-
-
-
-
-
-            existing_length = 0
-        head_length = (
-            op.target_column_length
-            if existing_length >= op.target_column_length
-            else existing_length
-        )
-        source_df = source_df.with_columns(
-            pl.col(op.column_name).list.head(head_length)
+        source_df, existing_length = await context.run_on_worker(
+            _collect_and_apply,
+            source_lfm.data,
+            lambda df: get_polars_embedding_length(df[op.column_name]).unwrap_or_raise()
+            if len(df)
+            else 0,
         )
+
         outer_type = source_df.schema[op.column_name]
         if isinstance(outer_type, pl.Array | pl.List):
             inner_type = outer_type.inner
         else:
             return InternalError("unexpected type", cause="expected list or array type")
+        result = source_df.lazy()
+
+        head_length = (
+            op.target_column_length
+            if existing_length >= op.target_column_length
+            else existing_length
+        )
+        result = result.with_columns(pl.col(op.column_name).list.head(head_length))

-        source_df = source_df.lazy()
         if head_length < op.target_column_length:
             padding_length = op.target_column_length - head_length
             padding = [op.padding_value_as_py] * padding_length
-
-
-            )
-            source_df = source_df.with_columns(
+            result = result.with_columns(pl.col(op.column_name).list.concat(padding))
+        result = result.with_columns(
             pl.col(op.column_name)
             .list.to_array(width=op.target_column_length)
             .cast(pl.List(inner_type))
         )
-        return Ok(source_lfm.with_data(
+        return Ok(source_lfm.with_data(result))

-    def _execute_union(
+    async def _execute_union(
+        self, op: op_graph.op.Union, context: _InMemoryExecutionContext
+    ):
+        async with asyncio.TaskGroup() as tg:
+            source_taks = [
+                tg.create_task(self._execute(source, context))
+                for source in op.sources()
+            ]
         sources = list[_LazyFrameWithMetrics]()
-        for
-            match
+        for task in source_taks:
+            match task.result():
                 case Ok(source_lfm):
                     sources.append(source_lfm)
                 case err:
@@ -1341,16 +1311,19 @@ class InMemoryExecutor(OpGraphExecutor):
             result_lf = result_lf.unique()
         return Ok(_LazyFrameWithMetrics(result_lf, metrics=metrics))

-    def _execute_embed_image_column(
+    async def _execute_embed_image_column(
         self, op: op_graph.op.EmbedImageColumn, context: _InMemoryExecutionContext
     ):
-        match self._execute(op.source, context):
+        match await self._execute(op.source, context):
             case Ok(source_lfm):
                 pass
             case err:
                 return err
-        source_df =
-
+        source_df, to_embed = await context.run_on_worker(
+            _collect_and_apply,
+            source_lfm.data,
+            lambda df: df[op.column_name].cast(pl.Binary()),
+        )

         embed_context = EmbedImageContext(
             inputs=to_embed,
@@ -1358,43 +1331,38 @@ class InMemoryExecutor(OpGraphExecutor):
             expected_vector_length=op.expected_vector_length,
             expected_coordinate_bitwidth=op.expected_coordinate_bitwidth,
         )
-        match self._image_embedder.
+        match await self._image_embedder.aembed(embed_context, context.worker_threads):
             case Ok(result):
                 pass
             case InvalidArgumentError() | InternalError() as err:
                 raise InternalError("Failed to embed column") from err

         return Ok(
-
+            source_lfm.with_data(
                 source_df.lazy()
                 .with_columns(result.embeddings.alias(op.embedding_column_name))
-                .drop_nulls(op.embedding_column_name)
-            source_lfm.metrics,
+                .drop_nulls(op.embedding_column_name)
             )
         )

-    def
-        self,
+    def _compute_decision_tree_summary(
+        self,
+        data: pl.DataFrame,
+        feature_column_names: Sequence[str],
+        label_column_name: str,
+        max_depth: int,
+        class_names: Sequence[str] | None,
     ):
-
-            case Ok(source_lfm):
-                pass
-            case err:
-                return err
-
-        df_input = source_lfm.data.collect()
-        dataframe = df_input.select(
-            list({*op.feature_column_names, op.label_column_name})
-        )
+        dataframe = data.select(list({*feature_column_names, label_column_name}))
         boolean_columns = [
             name
             for name, dtype in dataframe.schema.items()
-            if dtype == pl.Boolean() and name in
+            if dtype == pl.Boolean() and name in feature_column_names
         ]

         # Drop Nan and Null and infinite rows as not supported by decision tree
         dataframe = dataframe.with_columns(
-            *[pl.col(col).cast(pl.Float32) for col in
+            *[pl.col(col).cast(pl.Float32) for col in feature_column_names]
         )
         dataframe = dataframe.drop_nans().drop_nulls()
         try:
@@ -1407,9 +1375,8 @@ class InMemoryExecutor(OpGraphExecutor):
             return InvalidArgumentError(
                 "a minimum of 1 sample is required by DecisionTreeClassifier"
             )
-        features = dataframe[
-        classes = dataframe[
-        max_depth = op.max_depth
+        features = dataframe[feature_column_names]
+        classes = dataframe[label_column_name]

         from sklearn.tree import DecisionTreeClassifier, export_graphviz, export_text
         from sklearn.utils.multiclass import check_classification_targets
@@ -1427,15 +1394,15 @@ class InMemoryExecutor(OpGraphExecutor):

         tree_str = export_text(
             decision_tree=decision_tree,
-            feature_names=
-            class_names=
+            feature_names=feature_column_names,
+            class_names=class_names,
             max_depth=max_depth,
         )

         tree_graphviz = export_graphviz(
             decision_tree=decision_tree,
-            feature_names=
-            class_names=
+            feature_names=feature_column_names,
+            class_names=class_names,
             max_depth=max_depth,
         )

@@ -1445,16 +1412,41 @@ class InMemoryExecutor(OpGraphExecutor):
             )
             tree_str = tree_str.replace(f"{boolean_column} > 0.50", boolean_column)

+        return Ok(table_pb2.DecisionTreeSummary(text=tree_str, graphviz=tree_graphviz))
+
+    async def _execute_add_decision_tree_summary(
+        self, op: op_graph.op.AddDecisionTreeSummary, context: _InMemoryExecutionContext
+    ):
+        match await self._execute(op.source, context):
+            case Ok(source_lfm):
+                pass
+            case err:
+                return err
         metrics = source_lfm.metrics.copy()
-
-
+        source_df, summary_result = await context.run_on_worker(
+            _collect_and_apply,
+            source_lfm.data,
+            lambda df: self._compute_decision_tree_summary(
+                df,
+                op.feature_column_names,
+                op.label_column_name,
+                op.max_depth,
+                op.classes_names,
+            ),
         )
-        return Ok(_LazyFrameWithMetrics(df_input.lazy(), metrics=metrics))

-
+        match summary_result:
+            case InvalidArgumentError() | InternalError() as err:
+                return err
+            case Ok(tree_summary):
+                metrics[op.output_metric_key] = tree_summary
+
+        return Ok(_LazyFrameWithMetrics(source_df.lazy(), metrics=metrics))
+
+    async def _execute_unnest_list(
         self, op: op_graph.op.UnnestList, context: _InMemoryExecutionContext
     ):
-        return self._compute_source_then_apply(
+        return await self._compute_source_then_apply(
             op.source,
             lambda lf: lf.with_columns(
                 pl.col(op.list_column_name).list.get(i).alias(column_name)
@@ -1463,47 +1455,42 @@ class InMemoryExecutor(OpGraphExecutor):
             context,
         )

-    def _execute_sample_rows(
+    async def _execute_sample_rows(
         self, op: op_graph.op.SampleRows, context: _InMemoryExecutionContext
     ):
-        match self._execute(op.source, context):
+        match await self._execute(op.source, context):
             case Ok(source_lfm):
                 pass
             case err:
                 return err
-        source_df = source_lfm.data.collect()
-        n = min(op.num_rows, source_df.shape[0])
-        sample_strategy = op.sample_strategy
-        match sample_strategy:
-            case op_graph.sample_strategy.UniformRandom():
-                result_df = source_df.sample(
-                    n=n,
-                    seed=sample_strategy.seed,
-                )

-
-
-
-
-
+        def sample(df: pl.DataFrame):
+            match op.sample_strategy:
+                case op_graph.sample_strategy.UniformRandom():
+                    return df.sample(
+                        min(op.num_rows, df.shape[0]), seed=op.sample_strategy.seed
+                    )
+
+        _, result_df = await context.run_on_worker(
+            _collect_and_apply, source_lfm.data, sample
        )

-
+        return Ok(_LazyFrameWithMetrics(result_df.lazy(), source_lfm.metrics))
+
+    async def _execute_describe_columns(
         self, op: op_graph.op.DescribeColumns, context: _InMemoryExecutionContext
     ):
-        match self._execute(op.source, context):
+        match await self._execute(op.source, context):
             case Ok(source_lfm):
                 pass
             case err:
                 return err
-
-
-            source_lfm.
-
-            .lazy()
-            .rename({"statistic": op.statistic_column_name})
-        )
+        _, result_df = await context.run_on_worker(
+            _collect_and_apply,
+            source_lfm.data,
+            lambda df: df.describe().rename({"statistic": op.statistic_column_name}),
         )
+        return Ok(source_lfm.with_data(result_df.lazy()))

     def _has_partially_computed_data(
         self, op: op_graph.Op, context: _InMemoryExecutionContext
@@ -1517,7 +1504,7 @@ class InMemoryExecutor(OpGraphExecutor):
             for sub_source in flatten(source.sources() for source in op.sources())
         )

-    def _do_execute(  # noqa: C901
+    async def _do_execute(  # noqa: C901
         self,
         op: op_graph.Op,
         context: _InMemoryExecutionContext,
@@ -1552,8 +1539,13 @@ class InMemoryExecutor(OpGraphExecutor):
             )
             case Ok(query):
                 pass
-            return
-
+            return (
+                await context.run_on_worker(
+                    self._staging_db.run_select_query,
+                    query,
+                    expected_schema,
+                    context.current_slice_args,
+                )
             ).map(
                 lambda rbr: _LazyFrameWithMetrics(
                     _as_df(rbr, expected_schema).lazy(),
@@ -1571,80 +1563,80 @@ class InMemoryExecutor(OpGraphExecutor):
            case op_graph.op.ReadFromParquet():
                 return self._execute_read_from_parquet(op, context)
             case op_graph.op.RenameColumns():
-                return self._execute_rename_columns(op, context)
+                return await self._execute_rename_columns(op, context)
             case op_graph.op.Join():
-                return self._execute_join(op, context)
+                return await self._execute_join(op, context)
             case op_graph.op.SelectColumns():
-                return self._execute_select_columns(op, context)
+                return await self._execute_select_columns(op, context)
             case op_graph.op.LimitRows():
-                return self._execute_limit_rows(op, context)
+                return await self._execute_limit_rows(op, context)
             case op_graph.op.OffsetRows():
-                return self._execute_offset_rows(op, context)
+                return await self._execute_offset_rows(op, context)
             case op_graph.op.OrderBy():
-                return self._execute_order_by(op, context)
+                return await self._execute_order_by(op, context)
             case op_graph.op.FilterRows():
-                return self._execute_filter_rows(op, context)
+                return await self._execute_filter_rows(op, context)
             case op_graph.op.DistinctRows():
-                return self._execute_distinct_rows(op, context)
+                return await self._execute_distinct_rows(op, context)
             case (
                 op_graph.op.SetMetadata()
                 | op_graph.op.UpdateMetadata()
                 | op_graph.op.RemoveFromMetadata()
                 | op_graph.op.UpdateFeatureTypes()
             ):
-                return self._execute(op.source, context)
+                return await self._execute(op.source, context)
             case op_graph.op.EmbeddingMetrics() as op:
-                return self._execute_embedding_metrics(op, context)
+                return await self._execute_embedding_metrics(op, context)
             case op_graph.op.EmbeddingCoordinates():
-                return self._execute_embedding_coordinates(op, context)
+                return await self._execute_embedding_coordinates(op, context)
             case op_graph.op.RollupByAggregation() as op:
-                return self._execute_rollup_by_aggregation(op, context)
+                return await self._execute_rollup_by_aggregation(op, context)
             case op_graph.op.Empty():
-                return self._execute_empty(op, context)
+                return await self._execute_empty(op, context)
             case op_graph.op.EmbedNode2vecFromEdgeLists():
-                return self._execute_embed_node2vec_from_edge_lists(op, context)
+                return await self._execute_embed_node2vec_from_edge_lists(op, context)
             case op_graph.op.Concat():
-                return self._execute_concat(op, context)
+                return await self._execute_concat(op, context)
             case op_graph.op.UnnestStruct():
-                return self._execute_unnest_struct(op, context)
+                return await self._execute_unnest_struct(op, context)
             case op_graph.op.NestIntoStruct():
-                return self._execute_nest_into_struct(op, context)
+                return await self._execute_nest_into_struct(op, context)
             case op_graph.op.AddLiteralColumn():
-                return self._execute_add_literal_column(op, context)
+                return await self._execute_add_literal_column(op, context)
             case op_graph.op.CombineColumns():
-                return self._execute_combine_columns(op, context)
+                return await self._execute_combine_columns(op, context)
             case op_graph.op.EmbedColumn():
-                return self._execute_embed_column(op, context)
+                return await self._execute_embed_column(op, context)
             case op_graph.op.EncodeColumns():
-                return self._execute_encode_columns(op, context)
+                return await self._execute_encode_columns(op, context)
             case op_graph.op.AggregateColumns():
-                return self._execute_aggregate_columns(op, context)
+                return await self._execute_aggregate_columns(op, context)
             case op_graph.op.CorrelateColumns():
-                return self._execute_correlate_columns(op, context)
+                return await self._execute_correlate_columns(op, context)
             case op_graph.op.HistogramColumn():
-                return self._execute_histogram_column(op, context)
+                return await self._execute_histogram_column(op, context)
             case op_graph.op.ConvertColumnToString():
-                return self._execute_convert_column_to_string(op, context)
+                return await self._execute_convert_column_to_string(op, context)
             case op_graph.op.AddRowIndex():
-                return self._execute_add_row_index(op, context)
+                return await self._execute_add_row_index(op, context)
             case op_graph.op.OutputCsv():
-                return self._execute_output_csv(op, context)
+                return await self._execute_output_csv(op, context)
             case op_graph.op.TruncateList():
-                return self._execute_truncate_list(op, context)
+                return await self._execute_truncate_list(op, context)
             case op_graph.op.Union():
-                return self._execute_union(op, context)
+                return await self._execute_union(op, context)
             case op_graph.op.EmbedImageColumn():
-                return self._execute_embed_image_column(op, context)
+                return await self._execute_embed_image_column(op, context)
             case op_graph.op.AddDecisionTreeSummary():
-                return self._execute_add_decision_tree_summary(op, context)
+                return await self._execute_add_decision_tree_summary(op, context)
             case op_graph.op.UnnestList():
-                return self._execute_unnest_list(op, context)
+                return await self._execute_unnest_list(op, context)
             case op_graph.op.SampleRows():
-                return self._execute_sample_rows(op, context)
+                return await self._execute_sample_rows(op, context)
             case op_graph.op.DescribeColumns():
-                return self._execute_describe_columns(op, context)
+                return await self._execute_describe_columns(op, context)

-    def _execute(
+    async def _execute(
         self,
         op: op_graph.Op,
         context: _InMemoryExecutionContext,
@@ -1681,7 +1673,7 @@ class InMemoryExecutor(OpGraphExecutor):

         try:
             _logger.info("starting op execution")
-            maybe_lfm = self._do_execute(op=op, context=context)
+            maybe_lfm = await self._do_execute(op=op, context=context)
         finally:
             _logger.info("op execution complete")
         match maybe_lfm:
@@ -1703,29 +1695,33 @@ class InMemoryExecutor(OpGraphExecutor):
             context.computed_batches_for_op_graph[sliced_table] = lfm
         return Ok(lfm)

-    def execute(
-        self,
+    async def execute(
+        self,
+        context: ExecutionContext,
+        worker_threads: ThreadPoolExecutor | None = None,
     ) -> (
         Ok[ExecutionResult]
         | InvalidArgumentError
         | InternalError
         | ResourceExhaustedError
     ):
-        with _InMemoryExecutionContext(context) as in_memory_context:
+        with _InMemoryExecutionContext(context, worker_threads) as in_memory_context:
             for table_context in context.tables_to_compute:
                 in_memory_context.current_output_context = table_context
                 sliced_table = _SlicedTable(
                     table_context.table_op_graph, table_context.sql_output_slice_args
                 )
                 if sliced_table not in in_memory_context.computed_batches_for_op_graph:
-                    match self._execute(sliced_table.op_graph, in_memory_context):
+                    match await self._execute(sliced_table.op_graph, in_memory_context):
                         case Ok():
                             pass
                         case err:
                             return err
             args_lfm_iterator = in_memory_context.computed_batches_for_op_graph.items()
             computed_tables = {
-                slice_args: _SchemaAndBatches.from_lazy_frame_with_metrics(
+                slice_args: await _SchemaAndBatches.from_lazy_frame_with_metrics(
+                    lfm, worker_threads
+                )
                 for slice_args, lfm in args_lfm_iterator
             }
