kumoai 2.13.0.dev202511191731__cp310-cp310-macosx_11_0_arm64.whl → 2.14.0.dev202512271732__cp310-cp310-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +12 -0
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +24 -0
- kumoai/client/pquery.py +6 -2
- kumoai/connector/utils.py +23 -2
- kumoai/experimental/rfm/__init__.py +52 -52
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +65 -127
- kumoai/experimental/rfm/backend/local/sampler.py +312 -0
- kumoai/experimental/rfm/backend/local/table.py +113 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
- kumoai/experimental/rfm/backend/snow/table.py +242 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
- kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
- kumoai/experimental/rfm/base/__init__.py +30 -0
- kumoai/experimental/rfm/base/column.py +152 -0
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +761 -0
- kumoai/experimental/rfm/base/source.py +19 -0
- kumoai/experimental/rfm/base/sql_sampler.py +143 -0
- kumoai/experimental/rfm/base/table.py +753 -0
- kumoai/experimental/rfm/{local_graph.py → graph.py} +546 -116
- kumoai/experimental/rfm/infer/__init__.py +8 -0
- kumoai/experimental/rfm/infer/dtype.py +81 -0
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/pkey.py +128 -0
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +61 -0
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
- kumoai/experimental/rfm/rfm.py +313 -245
- kumoai/experimental/rfm/sagemaker.py +15 -7
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/testing/decorators.py +1 -1
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/__init__.py +3 -2
- kumoai/utils/progress_logger.py +178 -12
- kumoai/utils/sql.py +3 -0
- {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/METADATA +10 -8
- {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/RECORD +49 -29
- kumoai/experimental/rfm/local_graph_sampler.py +0 -182
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- kumoai/experimental/rfm/local_table.py +0 -545
- kumoai/experimental/rfm/utils.py +0 -344
- {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
from kumoapi.typing import Dtype
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@dataclass
class SourceColumn:
    r"""Metadata record for a single column of a source table.

    Instances are plain data holders; no validation is performed on
    construction.
    """
    # Name of the column.
    name: str
    # Data type of the column, or `None` when no type is available.
    dtype: Dtype | None
    # Flags below are presumably reported by the backing data source --
    # TODO confirm against the code that builds these records.
    is_primary_key: bool
    is_unique_key: bool
    is_nullable: bool
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class SourceForeignKey:
    r"""Metadata record for a foreign key found in a source table.

    Instances are plain data holders; no validation is performed on
    construction.
    """
    # Presumably the foreign key column in the referencing table -- TODO
    # confirm against the code that builds these records.
    name: str
    # Name of the table the foreign key points to.
    dst_table: str
    # Presumably the referenced primary key column in `dst_table`.
    primary_key: str
|
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
from abc import abstractmethod
|
|
2
|
+
from typing import TYPE_CHECKING, Literal
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import pandas as pd
|
|
6
|
+
from kumoapi.typing import Dtype
|
|
7
|
+
|
|
8
|
+
from kumoai.experimental.rfm.base import (
|
|
9
|
+
LocalExpression,
|
|
10
|
+
Sampler,
|
|
11
|
+
SamplerOutput,
|
|
12
|
+
SourceColumn,
|
|
13
|
+
)
|
|
14
|
+
from kumoai.utils import ProgressLogger, quote_ident
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from kumoai.experimental.rfm import Graph
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class SQLSampler(Sampler):
    r"""A :class:`Sampler` that pre-computes SQL naming look-ups for a graph.

    On construction, the sampler walks over all tables of the graph once and
    caches, per table, the quoted source name, the source column metadata,
    the column data types and the SQL snippets used to reference and project
    each column. Sub-classes need to implement :meth:`_by_pkey`.
    """
    def __init__(
        self,
        graph: 'Graph',
        verbose: bool | ProgressLogger = True,
    ) -> None:
        super().__init__(graph=graph, verbose=verbose)

        tables = list(graph.tables.values())

        # Table name -> quoted name of its source table:
        self._source_name_dict: dict[str, str] = {}
        for table in tables:
            self._source_name_dict[table.name] = table._quoted_source_name

        # Table name -> column name -> source column metadata. Columns that
        # are not flagged via `is_source` are left out:
        self._source_table_dict: dict[str, dict[str, SourceColumn]] = {
            table.name: {
                column.name: table._source_column_dict[column.name]
                for column in table.columns if column.is_source
            }
            for table in tables
        }

        # Table name -> column name -> data type (all columns):
        self._table_dtype_dict: dict[str, dict[str, Dtype]] = {
            table.name: {
                column.name: column.dtype
                for column in table.columns
            }
            for table in tables
        }

        # Table name -> column name -> SQL snippet, once for referencing a
        # column and once for projecting it:
        self._table_column_ref_dict: dict[str, dict[str, str]] = {}
        self._table_column_proj_dict: dict[str, dict[str, str]] = {}
        for table in tables:
            refs: dict[str, str] = {}
            projs: dict[str, str] = {}
            for column in table.columns:
                if column.expr is None:
                    # Plain column: reference and projection coincide.
                    quoted_name = quote_ident(column.name)
                    refs[column.name] = quoted_name
                    projs[column.name] = quoted_name
                    continue

                # Expression-backed column: project it under the column
                # name.
                # NOTE(review): the reference uses `expr.value` while the
                # projection formats the expression object itself -- confirm
                # that `str(expr)` yields the intended SQL.
                assert isinstance(column.expr, LocalExpression)
                refs[column.name] = column.expr.value
                projs[column.name] = (
                    f'{column.expr} AS {quote_ident(column.name)}')
            self._table_column_ref_dict[table.name] = refs
            self._table_column_proj_dict[table.name] = projs

    @property
    def source_name_dict(self) -> dict[str, str]:
        r"""Maps each table name in the graph to its quoted source table
        name.
        """
        return self._source_name_dict

    @property
    def source_table_dict(self) -> dict[str, dict[str, SourceColumn]]:
        r"""Maps each table name in the graph to the source column
        information of its source-backed columns.
        """
        return self._source_table_dict

    @property
    def table_dtype_dict(self) -> dict[str, dict[str, Dtype]]:
        r"""Maps each table name in the graph to the data types of its
        columns.
        """
        return self._table_dtype_dict

    @property
    def table_column_ref_dict(self) -> dict[str, dict[str, str]]:
        r"""Maps each table name in the graph to the SQL reference
        expressions of its columns.
        """
        return self._table_column_ref_dict

    @property
    def table_column_proj_dict(self) -> dict[str, dict[str, str]]:
        r"""Maps each table name in the graph to the SQL projection
        expressions of its columns.
        """
        return self._table_column_proj_dict

    def _sample_subgraph(
        self,
        entity_table_name: str,
        entity_pkey: pd.Series,
        anchor_time: pd.Series | Literal['entity'],
        columns_dict: dict[str, set[str]],
        num_neighbors: list[int],
    ) -> SamplerOutput:
        r"""Fetches the entity rows for the given primary keys and wraps them
        into a :class:`SamplerOutput` that holds the entity table only.

        Args:
            entity_table_name: The name of the entity table.
            entity_pkey: The primary keys to look up.
            anchor_time: The anchor time per primary key. If not given as a
                :class:`pandas.Series`, the time column of the fetched
                entity rows is used instead.
            columns_dict: The columns to fetch for each table. Only the
                entry of the entity table is read here.
            num_neighbors: Not read by this implementation.

        Raises:
            KeyError: If the number of fetched rows differs from the number
                of requested primary keys.
        """
        df, batch = self._by_pkey(
            table_name=entity_table_name,
            pkey=entity_pkey,
            columns=columns_dict[entity_table_name],
        )

        num_requested = len(entity_pkey)
        if len(batch) != num_requested:
            # Clear every position that got returned; the positions that
            # remain set identify the primary keys without a match.
            mask = np.ones(num_requested, dtype=bool)
            mask[batch] = False
            raise KeyError(f"The primary keys "
                           f"{entity_pkey.iloc[mask].tolist()} do not exist "
                           f"in the '{entity_table_name}' table")

        # Bring the rows into ascending `batch` order:
        order = batch.argsort()
        batch = batch[order]
        df = df.iloc[order].reset_index(drop=True)

        if isinstance(anchor_time, pd.Series):
            anchor = anchor_time
        else:
            anchor = df[self.time_column_dict[entity_table_name]]

        return SamplerOutput(
            anchor_time=anchor.astype(int).to_numpy(),
            df_dict={entity_table_name: df},
            inverse_dict={},
            batch_dict={entity_table_name: batch},
            num_sampled_nodes_dict={entity_table_name: [len(batch)]},
            row_dict={},
            col_dict={},
            num_sampled_edges_dict={},
        )

    # Abstract Methods ########################################################

    @abstractmethod
    def _by_pkey(
        self,
        table_name: str,
        pkey: pd.Series,
        columns: set[str],
    ) -> tuple[pd.DataFrame, np.ndarray]:
        r"""Looks up rows of a table by primary key.

        Args:
            table_name: The name of the table.
            pkey: The primary keys to look up.
            columns: The columns to fetch.

        Returns:
            A data frame and an index array. :meth:`_sample_subgraph` uses
            the array entries as positions into :obj:`pkey` and expects the
            rows of the data frame to follow the same order as the array.
        """
|