kumoai 2.13.0.dev202512061731__cp313-cp313-macosx_11_0_arm64.whl → 2.14.0.dev202512151351__cp313-cp313-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/_version.py +1 -1
- kumoai/client/pquery.py +6 -2
- kumoai/experimental/rfm/backend/local/graph_store.py +19 -62
- kumoai/experimental/rfm/backend/local/sampler.py +229 -45
- kumoai/experimental/rfm/backend/local/table.py +12 -2
- kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +264 -0
- kumoai/experimental/rfm/backend/snow/table.py +35 -17
- kumoai/experimental/rfm/backend/sqlite/__init__.py +2 -0
- kumoai/experimental/rfm/backend/sqlite/sampler.py +354 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +36 -11
- kumoai/experimental/rfm/base/__init__.py +16 -5
- kumoai/experimental/rfm/base/sampler.py +538 -52
- kumoai/experimental/rfm/base/source.py +1 -0
- kumoai/experimental/rfm/base/sql_sampler.py +56 -0
- kumoai/experimental/rfm/base/table.py +12 -1
- kumoai/experimental/rfm/graph.py +26 -9
- kumoai/experimental/rfm/pquery/pandas_executor.py +1 -1
- kumoai/experimental/rfm/rfm.py +214 -151
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/testing/snow.py +50 -0
- kumoai/utils/__init__.py +2 -0
- kumoai/utils/sql.py +3 -0
- {kumoai-2.13.0.dev202512061731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/METADATA +2 -2
- {kumoai-2.13.0.dev202512061731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/RECORD +28 -25
- kumoai/experimental/rfm/local_graph_sampler.py +0 -223
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- {kumoai-2.13.0.dev202512061731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202512061731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202512061731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
from abc import abstractmethod
|
|
2
|
+
from typing import Literal
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import pandas as pd
|
|
6
|
+
|
|
7
|
+
from kumoai.experimental.rfm.base import Sampler, SamplerOutput
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class SQLSampler(Sampler):
    r"""Base class for samplers that read their data through SQL queries.

    Subclasses implement :meth:`_by_pkey` to fetch rows by primary key. The
    default :meth:`_sample_subgraph` only fetches the seed entities: the
    edge-related outputs it returns are empty and ``num_neighbors`` is not
    used.
    """

    def _sample_subgraph(
        self,
        entity_table_name: str,
        entity_pkey: pd.Series,
        anchor_time: pd.Series | Literal['entity'],
        columns_dict: dict[str, set[str]],
        num_neighbors: list[int],
    ) -> SamplerOutput:
        r"""Fetches the seed entities and wraps them in a sampler output.

        Args:
            entity_table_name: Name of the table holding the seed entities.
            entity_pkey: Primary keys of the seed entities.
            anchor_time: Anchor time per seed entity, or ``'entity'`` to use
                the value of the entity table's time column instead.
            columns_dict: Columns to fetch, keyed by table name.
            num_neighbors: Per-hop neighbor counts (not used here).

        Returns:
            A :class:`SamplerOutput` containing only the entity table.

        Raises:
            KeyError: If the number of fetched rows does not match the number
                of requested primary keys.
        """
        df, batch = self._by_pkey(
            table_name=entity_table_name,
            pkey=entity_pkey,
            columns=columns_dict[entity_table_name],
        )
        if len(batch) != len(entity_pkey):
            raise KeyError(f"Could not fetch all requested primary keys from "
                           f"table '{entity_table_name}' (expected "
                           f"{len(entity_pkey)} rows, got {len(batch)})")

        # SQL result order is arbitrary. Sorting by ``batch`` restores the
        # request order, assuming ``batch`` holds each row's position in
        # ``entity_pkey`` (TODO confirm against the ``_by_pkey`` implementations).
        perm = batch.argsort()
        batch = batch[perm]
        df = df.iloc[perm].reset_index(drop=True)

        if not isinstance(anchor_time, pd.Series):
            # 'entity' mode: each entity is anchored at its own timestamp.
            time_column = self.time_column_dict[entity_table_name]
            anchor_time = df[time_column]

        return SamplerOutput(
            # Cast to an explicit 64-bit dtype: plain ``int`` may map to a
            # 32-bit dtype on some platforms, and pandas rejects
            # datetime64 -> int32 casts.
            anchor_time=anchor_time.astype('int64').to_numpy(),
            df_dict={entity_table_name: df},
            inverse_dict={},
            batch_dict={entity_table_name: batch},
            num_sampled_nodes_dict={entity_table_name: [len(batch)]},
            row_dict={},
            col_dict={},
            num_sampled_edges_dict={},
        )

    # Abstract Methods ########################################################

    @abstractmethod
    def _by_pkey(
        self,
        table_name: str,
        pkey: pd.Series,
        columns: set[str],
    ) -> tuple[pd.DataFrame, np.ndarray]:
        r"""Fetches the rows of ``table_name`` matching the primary keys.

        Args:
            table_name: Name of the table to query.
            pkey: Primary key values to look up.
            columns: Columns to return.

        Returns:
            A ``(df, batch)`` tuple. Presumably ``batch[i]`` is the position
            in ``pkey`` that row ``i`` of ``df`` belongs to; verify against
            the backend implementations.
        """
        pass
|
|
@@ -11,7 +11,12 @@ from kumoapi.typing import Stype
|
|
|
11
11
|
from typing_extensions import Self
|
|
12
12
|
|
|
13
13
|
from kumoai import in_notebook, in_snowflake_notebook
|
|
14
|
-
from kumoai.experimental.rfm.base import
|
|
14
|
+
from kumoai.experimental.rfm.base import (
|
|
15
|
+
Column,
|
|
16
|
+
DataBackend,
|
|
17
|
+
SourceColumn,
|
|
18
|
+
SourceForeignKey,
|
|
19
|
+
)
|
|
15
20
|
from kumoai.experimental.rfm.infer import (
|
|
16
21
|
contains_categorical,
|
|
17
22
|
contains_id,
|
|
@@ -503,6 +508,12 @@ class Table(ABC):
|
|
|
503
508
|
|
|
504
509
|
# Abstract Methods ########################################################
|
|
505
510
|
|
|
511
|
+
@property
@abstractmethod
def backend(self) -> DataBackend:
    r"""Data backend this table is bound to.

    Every concrete table implementation must override this property.
    """
    ...
|
|
516
|
+
|
|
506
517
|
@cached_property
|
|
507
518
|
def _source_column_dict(self) -> Dict[str, SourceColumn]:
|
|
508
519
|
return {col.name: col for col in self._get_source_columns()}
|
kumoai/experimental/rfm/graph.py
CHANGED
|
@@ -13,7 +13,7 @@ from kumoapi.typing import Stype
|
|
|
13
13
|
from typing_extensions import Self
|
|
14
14
|
|
|
15
15
|
from kumoai import in_notebook, in_snowflake_notebook
|
|
16
|
-
from kumoai.experimental.rfm import Table
|
|
16
|
+
from kumoai.experimental.rfm.base import DataBackend, Table
|
|
17
17
|
from kumoai.graph import Edge
|
|
18
18
|
from kumoai.mixin import CastMixin
|
|
19
19
|
|
|
@@ -218,10 +218,12 @@ class Graph:
|
|
|
218
218
|
connect,
|
|
219
219
|
)
|
|
220
220
|
|
|
221
|
+
internal_connection = False
|
|
221
222
|
if not isinstance(connection, Connection):
|
|
222
223
|
connection = SqliteConnectionConfig._cast(connection)
|
|
223
224
|
assert isinstance(connection, SqliteConnectionConfig)
|
|
224
225
|
connection = connect(connection.uri, **connection.kwargs)
|
|
226
|
+
internal_connection = True
|
|
225
227
|
assert isinstance(connection, Connection)
|
|
226
228
|
|
|
227
229
|
if table_names is None:
|
|
@@ -234,6 +236,9 @@ class Graph:
|
|
|
234
236
|
|
|
235
237
|
graph = cls(tables, edges=edges or [])
|
|
236
238
|
|
|
239
|
+
if internal_connection:
|
|
240
|
+
graph._connection = connection # type: ignore
|
|
241
|
+
|
|
237
242
|
if infer_metadata:
|
|
238
243
|
graph.infer_metadata(False)
|
|
239
244
|
|
|
@@ -394,7 +399,14 @@ class Graph:
|
|
|
394
399
|
|
|
395
400
|
return graph
|
|
396
401
|
|
|
397
|
-
#
|
|
402
|
+
# Backend #################################################################
|
|
403
|
+
|
|
404
|
+
@property
def backend(self) -> DataBackend | None:
    r"""The data backend shared by the tables of this graph.

    Returns the backend of the first registered table, or ``None`` if the
    graph has no tables. Only the first table is inspected because
    ``add_table`` rejects tables whose backend differs from the graph's.
    """
    first_table = next(iter(self._tables.values()), None)
    if first_table is None:
        return None
    return first_table.backend
|
|
408
|
+
|
|
409
|
+
# Tables ##################################################################
|
|
398
410
|
|
|
399
411
|
def has_table(self, name: str) -> bool:
|
|
400
412
|
r"""Returns ``True`` if the graph has a table with name ``name``;
|
|
@@ -433,13 +445,10 @@ class Graph:
|
|
|
433
445
|
raise KeyError(f"Cannot add table with name '{table.name}' to "
|
|
434
446
|
f"this graph; table names must be globally unique.")
|
|
435
447
|
|
|
436
|
-
if
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
f"'{table.__class__.__name__}' to this "
|
|
441
|
-
f"graph since other tables are of type "
|
|
442
|
-
f"'{cls.__name__}'.")
|
|
448
|
+
if self.backend is not None and table.backend != self.backend:
|
|
449
|
+
raise ValueError(f"Cannot register a table with backend "
|
|
450
|
+
f"'{table.backend}' to this graph since other "
|
|
451
|
+
f"tables have backend '{self.backend}'.")
|
|
443
452
|
|
|
444
453
|
self._tables[table.name] = table
|
|
445
454
|
|
|
@@ -826,6 +835,10 @@ class Graph:
|
|
|
826
835
|
raise ValueError("At least one table needs to be added to the "
|
|
827
836
|
"graph")
|
|
828
837
|
|
|
838
|
+
backends = {table.backend for table in self._tables.values()}
|
|
839
|
+
if len(backends) != 1:
|
|
840
|
+
raise ValueError("Found multiple table backends in the graph")
|
|
841
|
+
|
|
829
842
|
for edge in self.edges:
|
|
830
843
|
src_table, fkey, dst_table = edge
|
|
831
844
|
|
|
@@ -1063,3 +1076,7 @@ class Graph:
|
|
|
1063
1076
|
f' tables={tables},\n'
|
|
1064
1077
|
f' edges={edges},\n'
|
|
1065
1078
|
f')')
|
|
1079
|
+
|
|
1080
|
+
def __del__(self) -> None:
    r"""Closes the connection owned by this graph, if there is one.

    A ``_connection`` attribute is attached when the graph opened its own
    SQLite connection. Closing is best-effort: finalizers may run on
    partially constructed objects or during interpreter shutdown, and an
    exception escaping ``__del__`` would only be reported as an ignored
    error.
    """
    connection = getattr(self, '_connection', None)
    if connection is None:
        return
    try:
        connection.close()
    except Exception:
        # Never let a failing close propagate out of a finalizer.
        pass
|
|
@@ -134,7 +134,7 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
|
|
|
134
134
|
outs: List[pd.Series] = []
|
|
135
135
|
masks: List[np.ndarray] = []
|
|
136
136
|
for _ in range(num_forecasts):
|
|
137
|
-
anchor_target_time = anchor_time[target_batch]
|
|
137
|
+
anchor_target_time = anchor_time.iloc[target_batch]
|
|
138
138
|
anchor_target_time = anchor_target_time.reset_index(drop=True)
|
|
139
139
|
|
|
140
140
|
time_filter_mask = (target_time <= anchor_target_time +
|