kumoai 2.14.0.dev202601011731__cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kumoai might be problematic. Click here for more details.
- kumoai/__init__.py +300 -0
- kumoai/_logging.py +29 -0
- kumoai/_singleton.py +25 -0
- kumoai/_version.py +1 -0
- kumoai/artifact_export/__init__.py +9 -0
- kumoai/artifact_export/config.py +209 -0
- kumoai/artifact_export/job.py +108 -0
- kumoai/client/__init__.py +5 -0
- kumoai/client/client.py +223 -0
- kumoai/client/connector.py +110 -0
- kumoai/client/endpoints.py +150 -0
- kumoai/client/graph.py +120 -0
- kumoai/client/jobs.py +471 -0
- kumoai/client/online.py +78 -0
- kumoai/client/pquery.py +207 -0
- kumoai/client/rfm.py +112 -0
- kumoai/client/source_table.py +53 -0
- kumoai/client/table.py +101 -0
- kumoai/client/utils.py +130 -0
- kumoai/codegen/__init__.py +19 -0
- kumoai/codegen/cli.py +100 -0
- kumoai/codegen/context.py +16 -0
- kumoai/codegen/edits.py +473 -0
- kumoai/codegen/exceptions.py +10 -0
- kumoai/codegen/generate.py +222 -0
- kumoai/codegen/handlers/__init__.py +4 -0
- kumoai/codegen/handlers/connector.py +118 -0
- kumoai/codegen/handlers/graph.py +71 -0
- kumoai/codegen/handlers/pquery.py +62 -0
- kumoai/codegen/handlers/table.py +109 -0
- kumoai/codegen/handlers/utils.py +42 -0
- kumoai/codegen/identity.py +114 -0
- kumoai/codegen/loader.py +93 -0
- kumoai/codegen/naming.py +94 -0
- kumoai/codegen/registry.py +121 -0
- kumoai/connector/__init__.py +31 -0
- kumoai/connector/base.py +153 -0
- kumoai/connector/bigquery_connector.py +200 -0
- kumoai/connector/databricks_connector.py +213 -0
- kumoai/connector/file_upload_connector.py +189 -0
- kumoai/connector/glue_connector.py +150 -0
- kumoai/connector/s3_connector.py +278 -0
- kumoai/connector/snowflake_connector.py +252 -0
- kumoai/connector/source_table.py +471 -0
- kumoai/connector/utils.py +1796 -0
- kumoai/databricks.py +14 -0
- kumoai/encoder/__init__.py +4 -0
- kumoai/exceptions.py +26 -0
- kumoai/experimental/__init__.py +0 -0
- kumoai/experimental/rfm/__init__.py +210 -0
- kumoai/experimental/rfm/authenticate.py +432 -0
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/backend/local/graph_store.py +297 -0
- kumoai/experimental/rfm/backend/local/sampler.py +312 -0
- kumoai/experimental/rfm/backend/local/table.py +113 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
- kumoai/experimental/rfm/backend/snow/table.py +242 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
- kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
- kumoai/experimental/rfm/base/__init__.py +30 -0
- kumoai/experimental/rfm/base/column.py +152 -0
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +761 -0
- kumoai/experimental/rfm/base/source.py +19 -0
- kumoai/experimental/rfm/base/sql_sampler.py +143 -0
- kumoai/experimental/rfm/base/table.py +736 -0
- kumoai/experimental/rfm/graph.py +1237 -0
- kumoai/experimental/rfm/infer/__init__.py +19 -0
- kumoai/experimental/rfm/infer/categorical.py +40 -0
- kumoai/experimental/rfm/infer/dtype.py +82 -0
- kumoai/experimental/rfm/infer/id.py +46 -0
- kumoai/experimental/rfm/infer/multicategorical.py +48 -0
- kumoai/experimental/rfm/infer/pkey.py +128 -0
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +61 -0
- kumoai/experimental/rfm/infer/timestamp.py +41 -0
- kumoai/experimental/rfm/pquery/__init__.py +7 -0
- kumoai/experimental/rfm/pquery/executor.py +102 -0
- kumoai/experimental/rfm/pquery/pandas_executor.py +530 -0
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +1184 -0
- kumoai/experimental/rfm/sagemaker.py +138 -0
- kumoai/experimental/rfm/task_table.py +231 -0
- kumoai/formatting.py +30 -0
- kumoai/futures.py +99 -0
- kumoai/graph/__init__.py +12 -0
- kumoai/graph/column.py +106 -0
- kumoai/graph/graph.py +948 -0
- kumoai/graph/table.py +838 -0
- kumoai/jobs.py +80 -0
- kumoai/kumolib.cpython-310-x86_64-linux-gnu.so +0 -0
- kumoai/mixin.py +28 -0
- kumoai/pquery/__init__.py +25 -0
- kumoai/pquery/prediction_table.py +287 -0
- kumoai/pquery/predictive_query.py +641 -0
- kumoai/pquery/training_table.py +424 -0
- kumoai/spcs.py +121 -0
- kumoai/testing/__init__.py +8 -0
- kumoai/testing/decorators.py +57 -0
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/__init__.py +42 -0
- kumoai/trainer/baseline_trainer.py +93 -0
- kumoai/trainer/config.py +2 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/trainer/job.py +1192 -0
- kumoai/trainer/online_serving.py +258 -0
- kumoai/trainer/trainer.py +475 -0
- kumoai/trainer/util.py +103 -0
- kumoai/utils/__init__.py +11 -0
- kumoai/utils/datasets.py +83 -0
- kumoai/utils/display.py +51 -0
- kumoai/utils/forecasting.py +209 -0
- kumoai/utils/progress_logger.py +343 -0
- kumoai/utils/sql.py +3 -0
- kumoai-2.14.0.dev202601011731.dist-info/METADATA +71 -0
- kumoai-2.14.0.dev202601011731.dist-info/RECORD +122 -0
- kumoai-2.14.0.dev202601011731.dist-info/WHEEL +6 -0
- kumoai-2.14.0.dev202601011731.dist-info/licenses/LICENSE +9 -0
- kumoai-2.14.0.dev202601011731.dist-info/top_level.txt +1 -0
|
File without changes
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
# The local RFM backend depends on the compiled `kumoai.kumolib` extension
# module. Any exception raised while importing it is treated as "this
# environment is unsupported" and converted into a descriptive RuntimeError.
try:
    import kumoai.kumolib  # noqa: F401
except Exception as e:
    import platform

    _msg = f"""RFM is not supported in your environment.

💻 Your Environment:
Python version: {platform.python_version()}
Operating system: {platform.system()}
CPU architecture: {platform.machine()}
glibc version: {platform.libc_ver()[1]}

✅ Supported Environments:
* Python versions: 3.10, 3.11, 3.12, 3.13
* Operating systems and CPU architectures:
* Linux (x86_64)
* macOS (arm64)
* Windows (x86_64)
* glibc versions: >=2.28

❌ Unsupported Environments:
* Python versions: 3.8, 3.9, 3.14
* Operating systems and CPU architectures:
* Linux (arm64)
* macOS (x86_64)
* Windows (arm64)
* glibc versions: <2.28

Please create a feature request at 'https://github.com/kumo-ai/kumo-rfm'."""

    # Chain the original import failure so its details remain visible.
    raise RuntimeError(_msg) from e

# Only reached when the native extension imported successfully.
from .table import LocalTable
from .graph_store import LocalGraphStore
from .sampler import LocalSampler

# Public API of the local backend package.
__all__ = [
    'LocalTable',
    'LocalGraphStore',
    'LocalSampler',
]
|
|
@@ -0,0 +1,297 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import pandas as pd
|
|
5
|
+
from kumoapi.rfm.context import Subgraph
|
|
6
|
+
|
|
7
|
+
from kumoai.experimental.rfm.backend.local import LocalTable
|
|
8
|
+
from kumoai.experimental.rfm.base import Table
|
|
9
|
+
from kumoai.utils import ProgressLogger
|
|
10
|
+
|
|
11
|
+
try:
|
|
12
|
+
import torch
|
|
13
|
+
WITH_TORCH = True
|
|
14
|
+
except ImportError:
|
|
15
|
+
WITH_TORCH = False
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from kumoai.experimental.rfm import Graph
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class LocalGraphStore:
    r"""In-memory graph store backing the local RFM sampler.

    Materializes all tables of a :class:`Graph` into sanitized
    :class:`pandas.DataFrame` objects and derives the data structures needed
    for neighbor sampling:

    * ``df_dict``: sanitized data frame per table (rows are never dropped).
    * ``mask_dict``: boolean validity mask per table. It is only present for
      tables that contain invalid rows.
    * ``pkey_map_dict``: per-table mapping from primary key to row position.
    * ``time_dict``: per-table time column as integer seconds since epoch.
    * ``min_max_time_dict``: per-table ``(min, max)`` time over valid rows.
    * ``row_dict`` / ``colptr_dict``: CSC adjacency per edge type, including
      reverse edge types.

    Args:
        graph: The graph to materialize.
        verbose: Whether to log progress, or a :class:`ProgressLogger` to log
            into.
    """
    def __init__(
        self,
        graph: 'Graph',
        verbose: bool | ProgressLogger = True,
    ) -> None:

        if not isinstance(verbose, ProgressLogger):
            verbose = ProgressLogger.default(
                msg="Materializing graph",
                verbose=verbose,
            )

        with verbose as logger:
            self.df_dict, self.mask_dict = self.sanitize(graph)
            logger.log("Sanitized input data")

            self.pkey_map_dict = self.get_pkey_map_dict(graph)
            num_pkeys = sum(t.has_primary_key() for t in graph.tables.values())
            # Singular only for exactly one table ("0 tables", "2 tables"):
            if num_pkeys == 1:
                logger.log(f"Collected primary key from {num_pkeys} table")
            else:
                logger.log(f"Collected primary keys from {num_pkeys} tables")

            self.time_dict, self.min_max_time_dict = self.get_time_data(graph)
            if len(self.min_max_time_dict) > 0:
                min_time = min(t for t, _ in self.min_max_time_dict.values())
                max_time = max(t for _, t in self.min_max_time_dict.values())
                logger.log(f"Identified temporal graph from "
                           f"{min_time.date()} to {max_time.date()}")
            else:
                logger.log("Identified static graph without timestamps")

            self.row_dict, self.colptr_dict = self.get_csc(graph)
            num_nodes = sum(len(df) for df in self.df_dict.values())
            num_edges = sum(len(row) for row in self.row_dict.values())
            logger.log(f"Created graph with {num_nodes:,} nodes and "
                       f"{num_edges:,} edges")

    def get_node_id(self, table_name: str, pkey: pd.Series) -> np.ndarray:
        r"""Returns the node ID given primary keys.

        Args:
            table_name: The table name.
            pkey: The primary keys to receive node IDs for.

        Returns:
            The row positions of ``pkey`` within ``df_dict[table_name]``.

        Raises:
            KeyError: If the table does not exist, if ``pkey`` is empty, or
                if any primary key does not exist among the valid rows.
            ValueError: If the table has no primary key, or if ``pkey``
                cannot be cast to the primary key type.
        """
        if table_name not in self.df_dict:
            raise KeyError(f"Table '{table_name}' does not exist")

        if table_name not in self.pkey_map_dict:
            raise ValueError(f"Table '{table_name}' does not have a primary "
                             f"key")

        if len(pkey) == 0:
            raise KeyError(f"No primary keys passed for table '{table_name}'")

        pkey_map = self.pkey_map_dict[table_name]

        # Align the query keys with the type of the stored keys so that the
        # `.loc` lookup below matches. `pkey_map` is never empty (see
        # `get_pkey_map_dict`), so `index[0]` is safe:
        try:
            pkey = pkey.astype(type(pkey_map.index[0]))
        except ValueError as e:
            raise ValueError(f"Could not cast primary keys "
                             f"{pkey.tolist()} to the expected data "
                             f"type '{pkey_map.index.dtype}'") from e

        try:
            return pkey_map.loc[pkey]['arange'].to_numpy()
        except KeyError as e:
            missing = ~np.isin(pkey, pkey_map.index)
            raise KeyError(f"The primary keys {pkey[missing].tolist()} do "
                           f"not exist in the '{table_name}' table") from e

    def sanitize(
        self,
        graph: 'Graph',
    ) -> tuple[dict[str, pd.DataFrame], dict[str, np.ndarray]]:
        r"""Sanitizes raw data according to table schema definition.

        In particular, it:
        * converts timestamp data to `pd.Datetime`
        * drops timezone information from timestamps
        * drops duplicate primary keys
        * removes rows with missing primary keys or time values

        The last two points are not applied by physically removing rows.
        Instead, a boolean validity mask is returned per table. The mask
        marks the rows that have a non-missing time value and a non-missing
        primary key that was not seen in an earlier row.

        Returns:
            A tuple ``(df_dict, mask_dict)``. ``mask_dict`` only contains
            tables with at least one invalid row.
        """
        df_dict: dict[str, pd.DataFrame] = {}
        for table_name, table in graph.tables.items():
            assert isinstance(table, LocalTable)
            df_dict[table_name] = Table._sanitize(
                df=table._data.copy(deep=False).reset_index(drop=True),
                dtype_dict={
                    column.name: column.dtype
                    for column in table.columns
                },
                stype_dict={
                    column.name: column.stype
                    for column in table.columns
                },
            )

        mask_dict: dict[str, np.ndarray] = {}
        for table in graph.tables.values():
            mask: np.ndarray | None = None
            if table._time_column is not None:
                ser = df_dict[table.name][table._time_column]
                mask = ser.notna().to_numpy()

            if table._primary_key is not None:
                ser = df_dict[table.name][table._primary_key]
                # NOTE(review): duplicates are detected across *all* rows,
                # before the time mask is applied. A valid row can therefore
                # be dropped if an earlier row with the same key has a
                # missing time value.
                _mask = (~ser.duplicated().to_numpy()) & ser.notna().to_numpy()
                mask = _mask if mask is None else (_mask & mask)

            if mask is not None and not mask.all():
                mask_dict[table.name] = mask

        return df_dict, mask_dict

    def get_pkey_map_dict(
        self,
        graph: 'Graph',
    ) -> dict[str, pd.DataFrame]:
        r"""Builds a primary key lookup table for every table with a key.

        Each map is a data frame indexed by primary key with a single
        ``'arange'`` column that holds the row position in ``df_dict``.
        Invalid (masked) rows are excluded.

        Raises:
            ValueError: If a table with a primary key has no valid rows.
        """
        pkey_map_dict: dict[str, pd.DataFrame] = {}

        for table in graph.tables.values():
            if table._primary_key is None:
                continue

            pkey = self.df_dict[table.name][table._primary_key]
            pkey_map = pd.DataFrame(
                dict(arange=range(len(pkey))),
                index=pkey,
            )
            if table.name in self.mask_dict:
                pkey_map = pkey_map[self.mask_dict[table.name]]

            if len(pkey_map) == 0:
                error_msg = f"Found no valid rows in table '{table.name}'. "
                if table.has_time_column():
                    error_msg += ("Please make sure that there exists valid "
                                  "non-N/A primary key and time column pairs "
                                  "in this table.")
                else:
                    error_msg += ("Please make sure that there exists valid "
                                  "non-N/A primary keys in this table.")
                raise ValueError(error_msg)

            pkey_map_dict[table.name] = pkey_map

        return pkey_map_dict

    def get_time_data(
        self,
        graph: 'Graph',
    ) -> tuple[
            dict[str, np.ndarray],
            dict[str, tuple[pd.Timestamp, pd.Timestamp]],
    ]:
        r"""Extracts per-table time information.

        Returns:
            A tuple ``(time_dict, min_max_time_dict)``.

            * ``time_dict`` holds the time column of every row, including
              invalid ones, as integer seconds.
            * ``min_max_time_dict`` holds the ``(min, max)`` timestamp over
              valid rows. If no valid row exists, it holds the sentinel
              ``(pd.Timestamp.max, pd.Timestamp.min)``.

            Tables without a time column appear in neither dictionary.
        """
        time_dict: dict[str, np.ndarray] = {}
        min_max_time_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
        for table in graph.tables.values():
            if table._time_column is None:
                continue

            time = self.df_dict[table.name][table._time_column]
            # Cast to explicit `int64` rather than platform `int`. The latter
            # is 32-bit on Windows with NumPy<2, and pandas refuses that cast
            # for datetime data. `// 1000**3` assumes nanosecond storage.
            time_dict[table.name] = time.astype('int64').to_numpy() // 1000**3

            if table.name in self.mask_dict:
                time = time[self.mask_dict[table.name]]
            if len(time) > 0:
                min_max_time_dict[table.name] = (time.min(), time.max())
            else:
                min_max_time_dict[table.name] = (
                    pd.Timestamp.max,
                    pd.Timestamp.min,
                )

        return time_dict, min_max_time_dict

    def get_csc(
        self,
        graph: 'Graph',
    ) -> tuple[
            dict[tuple[str, str, str], np.ndarray],
            dict[tuple[str, str, str], np.ndarray],
    ]:
        r"""Builds CSC adjacency for every edge of the graph.

        For each edge ``(src_table, fkey, dst_table)``, two edge types are
        stored:

        * The forward edge, grouped by the referenced destination row.
          Within a destination's neighborhood, sources are sorted by time if
          the source table has a time column.
        * The reverse edge (via :meth:`Subgraph.rev_edge_type`), grouped by
          source row.

        Edges are dropped when their foreign key does not resolve to a valid
        destination row, or when they start at an invalid source row.

        Returns:
            A tuple ``(row_dict, colptr_dict)`` keyed by edge type. Neighbors
            of column ``i`` are ``row[colptr[i]:colptr[i + 1]]``.
        """
        # A mapping from raw primary keys to node indices (0 to N-1):
        map_dict: dict[str, pd.CategoricalDtype] = {}
        # A dictionary to manage offsets of node indices for invalid rows:
        offset_dict: dict[str, np.ndarray] = {}
        for table_name in {edge.dst_table for edge in graph.edges}:
            ser = self.df_dict[table_name][graph[table_name]._primary_key]
            if table_name in self.mask_dict:
                mask = self.mask_dict[table_name]
                ser = ser[mask]
                # Number of invalid rows preceding each valid row. Adding it
                # to a category code recovers the raw row position.
                offset_dict[table_name] = np.cumsum(~mask)[mask]
            map_dict[table_name] = pd.CategoricalDtype(ser, ordered=True)

        # Build CSC graph representation:
        row_dict: dict[tuple[str, str, str], np.ndarray] = {}
        colptr_dict: dict[tuple[str, str, str], np.ndarray] = {}
        for src_table, fkey, dst_table in graph.edges:
            src_df = self.df_dict[src_table]
            dst_df = self.df_dict[dst_table]

            src = np.arange(len(src_df))
            dst = src_df[fkey].astype(map_dict[dst_table]).cat.codes.to_numpy()
            # Widen the (small-integer) category codes to explicit `int64`.
            # Platform `int` is 32-bit on Windows with NumPy<2, which would
            # let `dst * offset` below overflow silently.
            dst = dst.astype(np.int64)
            mask = dst >= 0  # Code -1 marks keys without a valid match.
            if dst_table in offset_dict:
                dst = dst + offset_dict[dst_table][dst]
            if src_table in self.mask_dict:
                mask &= self.mask_dict[src_table]
            src, dst = src[mask], dst[mask]

            # Sort by destination/column (and time within neighborhoods):
            # `lexsort` is expensive (especially in numpy) so avoid it if
            # possible by grouping `time` and `node_id` together.
            # If no edge survived, skip the time path: `min()`/`max()` on an
            # empty array raise a `ValueError`.
            if src_table in self.time_dict and len(src) > 0:
                src_time = self.time_dict[src_table][src]
                min_time = int(src_time.min())
                max_time = int(src_time.max())
                offset = (max_time - min_time) + 1
                if offset * len(dst_df) <= np.iinfo(np.int64).max:
                    index = dst * offset + (src_time - min_time)
                    perm = _argsort(index)
                else:  # Safe route to avoid `int64` overflow:
                    perm = _lexsort([src_time, dst])
            else:
                perm = _argsort(dst)

            row, col = src[perm], dst[perm]

            # Convert into compressed representation:
            colcount = np.bincount(col, minlength=len(dst_df))
            colptr = np.empty(len(colcount) + 1, dtype=colcount.dtype)
            colptr[0] = 0
            np.cumsum(colcount, out=colptr[1:])
            edge_type = (src_table, fkey, dst_table)
            row_dict[edge_type] = row
            colptr_dict[edge_type] = colptr

            # Reverse connection - no sort and no time handling needed since
            # the reverse mapping is 1-to-many. `src` is ascending because it
            # is a masked `arange`, so `col` is already sorted.
            row, col = dst, src
            colcount = np.bincount(col, minlength=len(src_df))
            colptr = np.empty(len(colcount) + 1, dtype=colcount.dtype)
            colptr[0] = 0
            np.cumsum(colcount, out=colptr[1:])
            edge_type = Subgraph.rev_edge_type(edge_type)
            row_dict[edge_type] = row
            colptr_dict[edge_type] = colptr

        return row_dict, colptr_dict
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
def _argsort(input: np.ndarray) -> np.ndarray:
    r"""Returns the indices that sort :obj:`input` (no stability guarantee).

    Uses PyTorch when it is importable, and NumPy otherwise.
    """
    if WITH_TORCH:
        as_tensor = torch.from_numpy(input)
        return as_tensor.argsort().numpy()
    return np.argsort(input)
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
def _lexsort(inputs: list[np.ndarray]) -> np.ndarray:
    r"""Indirect multi-key sort with the same key order as
    :func:`numpy.lexsort`.

    The *last* array in :obj:`inputs` is the primary sort key and the first
    array is the least significant one. Uses PyTorch's stable argsort when
    available, and falls back to :func:`numpy.lexsort` otherwise.
    """
    assert len(inputs) >= 1

    if not WITH_TORCH:
        return np.lexsort(inputs)

    least_significant, *more_significant = inputs
    try:
        perm = torch.from_numpy(least_significant).argsort(stable=True)
    except Exception:
        # PyTorch versions without `stable` support reject the keyword.
        return np.lexsort(inputs)

    # Stably re-sorting by each more significant key keeps ties in the order
    # established by all less significant keys.
    for key in more_significant:
        order = torch.from_numpy(key)[perm].argsort(stable=True)
        perm = perm[order]

    return perm.numpy()
|
|
@@ -0,0 +1,312 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING, Literal
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import pandas as pd
|
|
5
|
+
from kumoapi.pquery import ValidatedPredictiveQuery
|
|
6
|
+
|
|
7
|
+
from kumoai.experimental.rfm.backend.local import LocalGraphStore
|
|
8
|
+
from kumoai.experimental.rfm.base import Sampler, SamplerOutput
|
|
9
|
+
from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
|
|
10
|
+
from kumoai.utils import ProgressLogger
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from kumoai.experimental.rfm import Graph
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class LocalSampler(Sampler):
    r"""Sampler backed by an in-memory :class:`LocalGraphStore`.

    Neighbor sampling is delegated to the native ``kumolib.NeighborSampler``.
    It is fed the CSC adjacency (keyed by ``'__'``-joined edge types) and the
    per-table time columns (in seconds) of the graph store.

    Args:
        graph: The graph to sample from.
        verbose: Whether to log progress, or a :class:`ProgressLogger` to log
            into.
    """
    def __init__(
        self,
        graph: 'Graph',
        verbose: bool | ProgressLogger = True,
    ) -> None:
        super().__init__(graph=graph, verbose=verbose)

        import kumoai.kumolib as kumolib

        self._graph_store = LocalGraphStore(graph, verbose)
        self._graph_sampler = kumolib.NeighborSampler(
            list(self.table_stype_dict.keys()),
            self.edge_types,
            {
                '__'.join(edge_type): colptr
                for edge_type, colptr in self._graph_store.colptr_dict.items()
            },
            {
                '__'.join(edge_type): row
                for edge_type, row in self._graph_store.row_dict.items()
            },
            self._graph_store.time_dict,
        )

    def _get_min_max_time_dict(
        self,
        table_names: list[str],
    ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
        r"""Returns the ``(min, max)`` time range of the requested tables.

        Requested tables without a time column are not part of the output.
        """
        return {
            key: value
            for key, value in self._graph_store.min_max_time_dict.items()
            if key in table_names
        }

    def _sample_subgraph(
        self,
        entity_table_name: str,
        entity_pkey: pd.Series,
        anchor_time: pd.Series | Literal['entity'],
        columns_dict: dict[str, set[str]],
        num_neighbors: list[int],
    ) -> SamplerOutput:
        r"""Samples a subgraph around the given seed entities.

        Args:
            entity_table_name: The table of the seed entities.
            entity_pkey: The primary keys of the seed entities.
            anchor_time: Anchor timestamps per seed entity, or ``'entity'``
                to use each entity's own time column value.
            columns_dict: The columns to materialize per sampled table.
            num_neighbors: Fan-out passed to the native sampler for every
                edge type (presumably one entry per hop).
        """
        index = self._graph_store.get_node_id(entity_table_name, entity_pkey)

        if isinstance(anchor_time, pd.Series):
            # Cast to explicit `int64`: platform `int` is 32-bit on Windows
            # with NumPy<2, which pandas rejects for datetime data.
            time = anchor_time.astype('int64').to_numpy() // 1000**3  # sec
        else:
            assert anchor_time == 'entity'
            time = self._graph_store.time_dict[entity_table_name][index]

        (
            row_dict,
            col_dict,
            node_dict,
            batch_dict,
            num_sampled_nodes_dict,
            num_sampled_edges_dict,
        ) = self._graph_sampler.sample(
            {
                '__'.join(edge_type): num_neighbors
                for edge_type in self.edge_types
            },
            {},
            entity_table_name,
            index,
            time,
        )

        df_dict: dict[str, pd.DataFrame] = {}
        inverse_dict: dict[str, np.ndarray] = {}
        for table_name, node in node_dict.items():
            df = self._graph_store.df_dict[table_name]
            columns = columns_dict[table_name]
            if self.end_time_column_dict.get(table_name, None) in columns:
                df = df.iloc[node]
            elif len(columns) == 0:
                df = df.iloc[node]
            else:
                # Only store unique rows in `df` above a certain threshold:
                # de-duplicate when more than 5% of sampled rows repeat, and
                # keep an inverse index to expand them again.
                unique_node, inverse = np.unique(node, return_inverse=True)
                if len(node) > 1.05 * len(unique_node):
                    df = df.iloc[unique_node]
                    inverse_dict[table_name] = inverse
                else:
                    df = df.iloc[node]
            df = df.reset_index(drop=True)
            df = df[list(columns)]
            df_dict[table_name] = df

        num_sampled_nodes_dict = {
            table_name: num_sampled_nodes.tolist()
            for table_name, num_sampled_nodes in
            num_sampled_nodes_dict.items()
        }

        # Map the native sampler's string keys back to edge-type tuples:
        row_dict = {
            edge_type: row_dict['__'.join(edge_type)]
            for edge_type in self.edge_types
        }
        col_dict = {
            edge_type: col_dict['__'.join(edge_type)]
            for edge_type in self.edge_types
        }
        num_sampled_edges_dict = {
            edge_type: num_sampled_edges_dict['__'.join(edge_type)].tolist()
            for edge_type in self.edge_types
        }

        return SamplerOutput(
            anchor_time=time * 1000**3,  # to nanoseconds
            df_dict=df_dict,
            inverse_dict=inverse_dict,
            batch_dict=batch_dict,
            num_sampled_nodes_dict=num_sampled_nodes_dict,
            row_dict=row_dict,
            col_dict=col_dict,
            num_sampled_edges_dict=num_sampled_edges_dict,
        )

    def _sample_entity_table(
        self,
        table_name: str,
        columns: set[str],
        num_rows: int,
        random_seed: int | None = None,
    ) -> pd.DataFrame:
        r"""Randomly draws up to :obj:`num_rows` valid rows of a table.

        Only rows that are part of the primary key map (i.e. valid rows) are
        eligible. If the table has at most :obj:`num_rows` valid rows, all of
        them are returned.
        """
        pkey_map = self._graph_store.pkey_map_dict[table_name]
        if len(pkey_map) > num_rows:
            pkey_map = pkey_map.sample(
                n=num_rows,
                random_state=random_seed,
                ignore_index=True,
            )
        df = self._graph_store.df_dict[table_name]
        df = df.iloc[pkey_map['arange']][list(columns)]
        return df

    def _sample_target(
        self,
        query: ValidatedPredictiveQuery,
        entity_df: pd.DataFrame,
        train_index: np.ndarray,
        train_time: pd.Series,
        num_train_examples: int,
        test_index: np.ndarray,
        test_time: pd.Series,
        num_test_examples: int,
        columns_dict: dict[str, set[str]],
        time_offset_dict: dict[
            tuple[str, str, str],
            tuple[pd.DateOffset | None, pd.DateOffset],
        ],
    ) -> tuple[pd.Series, np.ndarray, pd.Series, np.ndarray]:
        r"""Computes target labels for the train and test entity sets.

        Returns:
            A tuple ``(train_y, train_mask, test_y, test_mask)``, as produced
            by two calls to :meth:`_sample_target_set`.
        """
        pkey = entity_df[self.primary_key_dict[query.entity_table]]

        train_y, train_mask = self._sample_target_set(
            query=query,
            pkey=pkey,
            index=train_index,
            anchor_time=train_time,
            num_examples=num_train_examples,
            columns_dict=columns_dict,
            time_offset_dict=time_offset_dict,
        )

        test_y, test_mask = self._sample_target_set(
            query=query,
            pkey=pkey,
            index=test_index,
            anchor_time=test_time,
            num_examples=num_test_examples,
            columns_dict=columns_dict,
            time_offset_dict=time_offset_dict,
        )

        return train_y, train_mask, test_y, test_mask

    # Helper Methods ##########################################################

    def _sample_target_set(
        self,
        query: ValidatedPredictiveQuery,
        pkey: pd.Series,
        index: np.ndarray,
        anchor_time: pd.Series,
        num_examples: int,
        columns_dict: dict[str, set[str]],
        time_offset_dict: dict[
            tuple[str, str, str],
            tuple[pd.DateOffset | None, pd.DateOffset],
        ],
        batch_size: int = 10_000,
    ) -> tuple[pd.Series, np.ndarray]:
        r"""Computes labels for a set of entities in batches.

        Processing stops as soon as at least :obj:`num_examples` labels have
        been collected.

        Returns:
            A tuple ``(y, mask)``. ``y`` concatenates the labels of all
            processed batches. ``mask`` has one entry per element of
            :obj:`index`; entries of unprocessed batches stay ``False``.
        """
        # Edge types with a time window are traversed for one hop, with the
        # window converted to seconds. All remaining edge types get a fan-out
        # of 0.
        num_hops = 1 if len(time_offset_dict) > 0 else 0
        num_neighbors_dict: dict[str, list[int]] = {}
        unix_time_offset_dict: dict[str, list[list[int | None]]] = {}
        for edge_type, (start, end) in time_offset_dict.items():
            unix_time_offset_dict['__'.join(edge_type)] = [[
                date_offset_to_seconds(start) if start is not None else None,
                date_offset_to_seconds(end),
            ]]
        for edge_type in set(self.edge_types) - set(time_offset_dict.keys()):
            num_neighbors_dict['__'.join(edge_type)] = [0] * num_hops

        count = 0
        ys: list[pd.Series] = []
        mask = np.full(len(index), False, dtype=bool)
        for batch_start in range(0, len(index), batch_size):
            batch_end = batch_start + batch_size
            batch_pkey = pkey.iloc[index[batch_start:batch_end]]
            batch_time = anchor_time.iloc[batch_start:batch_end]

            _, _, node_dict, batch_dict, _, _ = self._graph_sampler.sample(
                num_neighbors_dict,
                unix_time_offset_dict,
                query.entity_table,
                self._graph_store.get_node_id(query.entity_table, batch_pkey),
                # Explicit `int64`, see `_sample_subgraph`; then to seconds:
                batch_time.astype('int64').to_numpy() // 1000**3,
            )

            feat_dict: dict[str, pd.DataFrame] = {}
            time_dict: dict[str, pd.Series] = {}
            for table_name, columns in columns_dict.items():
                df = self._graph_store.df_dict[table_name]
                df = df.iloc[node_dict[table_name]].reset_index(drop=True)
                df = df[list(columns)]
                feat_dict[table_name] = df

                time_column = self.time_column_dict.get(table_name)
                if time_column in columns:
                    time_dict[table_name] = df[time_column]

            y, batch_mask = PQueryPandasExecutor().execute(
                query=query,
                feat_dict=feat_dict,
                time_dict=time_dict,
                batch_dict=batch_dict,
                anchor_time=batch_time,
                num_forecasts=query.num_forecasts,
            )
            ys.append(y)
            mask[batch_start:batch_end] = batch_mask

            count += len(y)
            if count >= num_examples:
                break

        if len(ys) == 0:
            y = pd.Series([], dtype=float)
        elif len(ys) == 1:
            y = ys[0]
        else:
            y = pd.concat(ys, axis=0, ignore_index=True)

        return y, mask
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
# Helper Functions ############################################################
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def date_offset_to_seconds(offset: pd.DateOffset) -> int:
    r"""Convert a :class:`pandas.DateOffset` into a number of seconds.

    The relative components ``years``, ``months``, ``weeks``, ``days``,
    ``hours``, ``minutes`` and ``seconds`` are summed up and scaled by the
    offset multiplier ``n``. All other attributes of the offset are ignored,
    for example private implementation details, absolute components such as
    ``weekday=``, and sub-second components.

    .. note::
        We are conservative and take months and years as their maximum value.
        Additional values are then dropped in label computation where we know
        the actual dates.

    Args:
        offset: The date offset to convert.

    Returns:
        The total number of seconds covered by :obj:`offset`.
    """
    MAX_DAYS_IN_MONTH = 31
    MAX_DAYS_IN_YEAR = 366

    SECONDS_IN_MINUTE = 60
    SECONDS_IN_HOUR = 60 * SECONDS_IN_MINUTE
    SECONDS_IN_DAY = 24 * SECONDS_IN_HOUR
    SECONDS_IN_WEEK = 7 * SECONDS_IN_DAY

    # Seconds per supported (plural, relative) `DateOffset` keyword:
    seconds_per_unit = {
        'years': MAX_DAYS_IN_YEAR * SECONDS_IN_DAY,
        'months': MAX_DAYS_IN_MONTH * SECONDS_IN_DAY,
        'weeks': SECONDS_IN_WEEK,
        'days': SECONDS_IN_DAY,
        'hours': SECONDS_IN_HOUR,
        'minutes': SECONDS_IN_MINUTE,
        'seconds': 1,
    }

    total_sec = 0
    multiplier = getattr(offset, 'n', 1)  # The multiplier (if present).

    # `DateOffset` stores its keyword arguments as instance attributes. Only
    # known units are considered, so that unrelated attributes are never
    # scaled. Scaling such attributes can raise for non-numeric values, e.g.
    # relativedelta weekday objects.
    for attr, value in offset.__dict__.items():
        unit_seconds = seconds_per_unit.get(attr)
        if unit_seconds is None or value is None or value == 0:
            continue
        total_sec += value * multiplier * unit_seconds

    return total_sec
|