kumoai-2.13.0.dev202511131731-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kumoai might be problematic.
- kumoai/__init__.py +294 -0
- kumoai/_logging.py +29 -0
- kumoai/_singleton.py +25 -0
- kumoai/_version.py +1 -0
- kumoai/artifact_export/__init__.py +9 -0
- kumoai/artifact_export/config.py +209 -0
- kumoai/artifact_export/job.py +108 -0
- kumoai/client/__init__.py +5 -0
- kumoai/client/client.py +221 -0
- kumoai/client/connector.py +110 -0
- kumoai/client/endpoints.py +150 -0
- kumoai/client/graph.py +120 -0
- kumoai/client/jobs.py +447 -0
- kumoai/client/online.py +78 -0
- kumoai/client/pquery.py +203 -0
- kumoai/client/rfm.py +112 -0
- kumoai/client/source_table.py +53 -0
- kumoai/client/table.py +101 -0
- kumoai/client/utils.py +130 -0
- kumoai/codegen/__init__.py +19 -0
- kumoai/codegen/cli.py +100 -0
- kumoai/codegen/context.py +16 -0
- kumoai/codegen/edits.py +473 -0
- kumoai/codegen/exceptions.py +10 -0
- kumoai/codegen/generate.py +222 -0
- kumoai/codegen/handlers/__init__.py +4 -0
- kumoai/codegen/handlers/connector.py +118 -0
- kumoai/codegen/handlers/graph.py +71 -0
- kumoai/codegen/handlers/pquery.py +62 -0
- kumoai/codegen/handlers/table.py +109 -0
- kumoai/codegen/handlers/utils.py +42 -0
- kumoai/codegen/identity.py +114 -0
- kumoai/codegen/loader.py +93 -0
- kumoai/codegen/naming.py +94 -0
- kumoai/codegen/registry.py +121 -0
- kumoai/connector/__init__.py +31 -0
- kumoai/connector/base.py +153 -0
- kumoai/connector/bigquery_connector.py +200 -0
- kumoai/connector/databricks_connector.py +213 -0
- kumoai/connector/file_upload_connector.py +189 -0
- kumoai/connector/glue_connector.py +150 -0
- kumoai/connector/s3_connector.py +278 -0
- kumoai/connector/snowflake_connector.py +252 -0
- kumoai/connector/source_table.py +471 -0
- kumoai/connector/utils.py +1775 -0
- kumoai/databricks.py +14 -0
- kumoai/encoder/__init__.py +4 -0
- kumoai/exceptions.py +26 -0
- kumoai/experimental/__init__.py +0 -0
- kumoai/experimental/rfm/__init__.py +67 -0
- kumoai/experimental/rfm/authenticate.py +433 -0
- kumoai/experimental/rfm/infer/__init__.py +11 -0
- kumoai/experimental/rfm/infer/categorical.py +40 -0
- kumoai/experimental/rfm/infer/id.py +46 -0
- kumoai/experimental/rfm/infer/multicategorical.py +48 -0
- kumoai/experimental/rfm/infer/timestamp.py +41 -0
- kumoai/experimental/rfm/local_graph.py +810 -0
- kumoai/experimental/rfm/local_graph_sampler.py +184 -0
- kumoai/experimental/rfm/local_graph_store.py +359 -0
- kumoai/experimental/rfm/local_pquery_driver.py +689 -0
- kumoai/experimental/rfm/local_table.py +545 -0
- kumoai/experimental/rfm/pquery/__init__.py +7 -0
- kumoai/experimental/rfm/pquery/executor.py +102 -0
- kumoai/experimental/rfm/pquery/pandas_executor.py +532 -0
- kumoai/experimental/rfm/rfm.py +1130 -0
- kumoai/experimental/rfm/utils.py +344 -0
- kumoai/formatting.py +30 -0
- kumoai/futures.py +99 -0
- kumoai/graph/__init__.py +12 -0
- kumoai/graph/column.py +106 -0
- kumoai/graph/graph.py +948 -0
- kumoai/graph/table.py +838 -0
- kumoai/jobs.py +80 -0
- kumoai/kumolib.cpython-313-x86_64-linux-gnu.so +0 -0
- kumoai/mixin.py +28 -0
- kumoai/pquery/__init__.py +25 -0
- kumoai/pquery/prediction_table.py +287 -0
- kumoai/pquery/predictive_query.py +637 -0
- kumoai/pquery/training_table.py +424 -0
- kumoai/spcs.py +123 -0
- kumoai/testing/__init__.py +8 -0
- kumoai/testing/decorators.py +57 -0
- kumoai/trainer/__init__.py +42 -0
- kumoai/trainer/baseline_trainer.py +93 -0
- kumoai/trainer/config.py +2 -0
- kumoai/trainer/job.py +1192 -0
- kumoai/trainer/online_serving.py +258 -0
- kumoai/trainer/trainer.py +475 -0
- kumoai/trainer/util.py +103 -0
- kumoai/utils/__init__.py +10 -0
- kumoai/utils/datasets.py +83 -0
- kumoai/utils/forecasting.py +209 -0
- kumoai/utils/progress_logger.py +177 -0
- kumoai-2.13.0.dev202511131731.dist-info/METADATA +60 -0
- kumoai-2.13.0.dev202511131731.dist-info/RECORD +98 -0
- kumoai-2.13.0.dev202511131731.dist-info/WHEEL +6 -0
- kumoai-2.13.0.dev202511131731.dist-info/licenses/LICENSE +9 -0
- kumoai-2.13.0.dev202511131731.dist-info/top_level.txt +1 -0
kumoai/experimental/rfm/local_graph_sampler.py
@@ -0,0 +1,184 @@
+from typing import Dict, List, Optional, Tuple
+
+import numpy as np
+import pandas as pd
+from kumoapi.model_plan import RunMode
+from kumoapi.rfm.context import EdgeLayout, Link, Subgraph, Table
+from kumoapi.typing import Stype
+
+import kumoai.kumolib as kumolib
+from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
+from kumoai.experimental.rfm.utils import normalize_text
+
+
+class LocalGraphSampler:
+    def __init__(self, graph_store: LocalGraphStore) -> None:
+        self._graph_store = graph_store
+        self._sampler = kumolib.NeighborSampler(
+            self._graph_store.node_types,
+            self._graph_store.edge_types,
+            {
+                '__'.join(edge_type): colptr
+                for edge_type, colptr in self._graph_store.colptr_dict.items()
+            },
+            {
+                '__'.join(edge_type): row
+                for edge_type, row in self._graph_store.row_dict.items()
+            },
+            self._graph_store.time_dict,
+        )
+
+    def __call__(
+        self,
+        entity_table_names: Tuple[str, ...],
+        node: np.ndarray,
+        time: np.ndarray,
+        run_mode: RunMode,
+        num_neighbors: List[int],
+        exclude_cols_dict: Dict[str, List[str]],
+    ) -> Subgraph:
+
+        (
+            row_dict,
+            col_dict,
+            node_dict,
+            batch_dict,
+            num_sampled_nodes_dict,
+            num_sampled_edges_dict,
+        ) = self._sampler.sample(
+            {
+                '__'.join(edge_type): num_neighbors
+                for edge_type in self._graph_store.edge_types
+            },
+            {},  # time interval based sampling
+            entity_table_names[0],
+            node,
+            time // 1000**3,  # nanoseconds to seconds
+        )
+
+        table_dict: Dict[str, Table] = {}
+        for table_name, node in node_dict.items():
+            batch = batch_dict[table_name]
+
+            if len(node) == 0:
+                continue
+
+            df = self._graph_store.df_dict[table_name]
+
+            num_sampled_nodes = num_sampled_nodes_dict[table_name].tolist()
+            stype_dict = {  # Exclude target columns:
+                column_name: stype
+                for column_name, stype in
+                self._graph_store.stype_dict[table_name].items()
+                if column_name not in exclude_cols_dict.get(table_name, [])
+            }
+            primary_key: Optional[str] = None
+            if table_name in entity_table_names:
+                primary_key = self._graph_store.pkey_name_dict.get(table_name)
+
+            columns: List[str] = []
+            if table_name in entity_table_names:
+                columns += [self._graph_store.pkey_name_dict[table_name]]
+            columns += list(stype_dict.keys())
+
+            if len(columns) == 0:
+                table_dict[table_name] = Table(
+                    df=pd.DataFrame(index=range(len(node))),
+                    row=None,
+                    batch=batch,
+                    num_sampled_nodes=num_sampled_nodes,
+                    stype_dict=stype_dict,
+                    primary_key=primary_key,
+                )
+                continue
+
+            row: Optional[np.ndarray] = None
+            if table_name in self._graph_store.end_time_column_dict:
+                # Set end time to NaT for all values greater than anchor time:
+                df = df.iloc[node].reset_index(drop=True)
+                col_name = self._graph_store.end_time_column_dict[table_name]
+                ser = df[col_name]
+                value = ser.astype('datetime64[ns]').astype(int).to_numpy()
+                mask = value > time[batch]
+                df.loc[mask, col_name] = pd.NaT
+            else:
+                # Only store unique rows in `df` above a certain threshold:
+                unique_node, inverse = np.unique(node, return_inverse=True)
+                if len(node) > 1.05 * len(unique_node):
+                    df = df.iloc[unique_node].reset_index(drop=True)
+                    row = inverse
+                else:
+                    df = df.iloc[node].reset_index(drop=True)
+
+            # Filter data frame to minimal set of columns:
+            df = df[columns]
+
+            # Normalize text (if not already pre-processed):
+            for column_name, stype in stype_dict.items():
+                if stype == Stype.text:
+                    df[column_name] = normalize_text(df[column_name])
+
+            table_dict[table_name] = Table(
+                df=df,
+                row=row,
+                batch=batch,
+                num_sampled_nodes=num_sampled_nodes,
+                stype_dict=stype_dict,
+                primary_key=primary_key,
+            )
+
+        link_dict: Dict[Tuple[str, str, str], Link] = {}
+        for edge_type in self._graph_store.edge_types:
+            edge_type_str = '__'.join(edge_type)
+
+            row = row_dict[edge_type_str]
+            col = col_dict[edge_type_str]
+
+            if len(row) == 0:
+                continue
+
+            # Do not store reverse edge type if it is a replica:
+            rev_edge_type = Subgraph.rev_edge_type(edge_type)
+            rev_edge_type_str = '__'.join(rev_edge_type)
+            if (rev_edge_type in link_dict
+                    and np.array_equal(row, col_dict[rev_edge_type_str])
+                    and np.array_equal(col, row_dict[rev_edge_type_str])):
+                link = Link(
+                    layout=EdgeLayout.REV,
+                    row=None,
+                    col=None,
+                    num_sampled_edges=(
+                        num_sampled_edges_dict[edge_type_str].tolist()),
+                )
+                link_dict[edge_type] = link
+                continue
+
+            layout = EdgeLayout.COO
+            if np.array_equal(row, np.arange(len(row))):
+                row = None
+            if np.array_equal(col, np.arange(len(col))):
+                col = None
+
+            # Store in compressed representation if more efficient:
+            num_cols = table_dict[edge_type[2]].num_rows
+            if col is not None and len(col) > num_cols + 1:
+                layout = EdgeLayout.CSC
+                colcount = np.bincount(col, minlength=num_cols)
+                col = np.empty(num_cols + 1, dtype=col.dtype)
+                col[0] = 0
+                np.cumsum(colcount, out=col[1:])
+
+            link = Link(
+                layout=layout,
+                row=row,
+                col=col,
+                num_sampled_edges=(
+                    num_sampled_edges_dict[edge_type_str].tolist()),
+            )
+            link_dict[edge_type] = link
+
+        return Subgraph(
+            anchor_time=time,
+            table_dict=table_dict,
+            link_dict=link_dict,
+        )
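The `else` branch in `__call__` above deduplicates sampled rows only when it pays off: `np.unique(node, return_inverse=True)` yields each referenced row once plus an `inverse` indirection vector mapping every sampled position back into the deduplicated frame, and the smaller frame is kept only when more than ~5% of the sampled nodes are duplicates. A minimal standalone sketch of this trick on toy data (independent of kumoai; all names are illustrative):

import numpy as np
import pandas as pd

# Toy "table" and a sampled node list with many repeats:
df = pd.DataFrame({'value': ['a', 'b', 'c', 'd']})
node = np.array([2, 0, 2, 3, 0, 2])

# `unique_node` holds every referenced row once; `inverse` maps each
# sampled position back into the deduplicated frame:
unique_node, inverse = np.unique(node, return_inverse=True)

if len(node) > 1.05 * len(unique_node):  # Same threshold as above.
    small_df = df.iloc[unique_node].reset_index(drop=True)
    row = inverse  # Indirection vector stored alongside the frame.
else:
    small_df = df.iloc[node].reset_index(drop=True)
    row = None

# Reconstructing the full sampled view is a single `iloc` lookup:
restored = small_df.iloc[row] if row is not None else small_df
assert restored['value'].tolist() == df.iloc[node]['value'].tolist()

The 1.05 factor trades the memory saved by deduplication against the cost of carrying the extra indirection vector when the sample is already nearly unique.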
kumoai/experimental/rfm/local_graph_store.py
@@ -0,0 +1,359 @@
+import warnings
+from typing import Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import pandas as pd
+from kumoapi.rfm.context import Subgraph
+from kumoapi.typing import Stype
+
+from kumoai.experimental.rfm import LocalGraph
+from kumoai.experimental.rfm.utils import normalize_text
+from kumoai.utils import InteractiveProgressLogger, ProgressLogger
+
+try:
+    import torch
+    WITH_TORCH = True
+except ImportError:
+    WITH_TORCH = False
+
+
+class LocalGraphStore:
+    def __init__(
+        self,
+        graph: LocalGraph,
+        preprocess: bool = False,
+        verbose: Union[bool, ProgressLogger] = True,
+    ) -> None:
+
+        if not isinstance(verbose, ProgressLogger):
+            verbose = InteractiveProgressLogger(
+                "Materializing graph",
+                verbose=verbose,
+            )
+
+        with verbose as logger:
+            self.df_dict, self.mask_dict = self.sanitize(graph, preprocess)
+            self.stype_dict = self.get_stype_dict(graph)
+            logger.log("Sanitized input data")
+
+            self.pkey_name_dict, self.pkey_map_dict = self.get_pkey_data(graph)
+            num_pkeys = sum(t.has_primary_key() for t in graph.tables.values())
+            if num_pkeys > 1:
+                logger.log(f"Collected primary keys from {num_pkeys} tables")
+            else:
+                logger.log(f"Collected primary key from {num_pkeys} table")
+
+            (
+                self.time_column_dict,
+                self.end_time_column_dict,
+                self.time_dict,
+                self.min_time,
+                self.max_time,
+            ) = self.get_time_data(graph)
+            if self.max_time != pd.Timestamp.min:
+                logger.log(f"Identified temporal graph from "
+                           f"{self.min_time.date()} to {self.max_time.date()}")
+            else:
+                logger.log("Identified static graph without timestamps")
+
+            self.row_dict, self.colptr_dict = self.get_csc(graph)
+            num_nodes = sum(len(df) for df in self.df_dict.values())
+            num_edges = sum(len(row) for row in self.row_dict.values())
+            logger.log(f"Created graph with {num_nodes:,} nodes and "
+                       f"{num_edges:,} edges")
+
+    @property
+    def node_types(self) -> List[str]:
+        return list(self.df_dict.keys())
+
+    @property
+    def edge_types(self) -> List[Tuple[str, str, str]]:
+        return list(self.row_dict.keys())
+
+    def get_node_id(self, table_name: str, pkey: pd.Series) -> np.ndarray:
+        r"""Returns the node ID given primary keys.
+
+        Args:
+            table_name: The table name.
+            pkey: The primary keys to receive node IDs for.
+        """
+        if table_name not in self.df_dict.keys():
+            raise KeyError(f"Table '{table_name}' does not exist")
+
+        if table_name not in self.pkey_map_dict.keys():
+            raise ValueError(f"Table '{table_name}' does not have a primary "
+                             f"key")
+
+        if len(pkey) == 0:
+            raise KeyError(f"No primary keys passed for table '{table_name}'")
+
+        pkey_map = self.pkey_map_dict[table_name]
+
+        try:
+            pkey = pkey.astype(type(pkey_map.index[0]))
+        except ValueError as e:
+            raise ValueError(f"Could not cast primary keys "
+                             f"{pkey.tolist()} to the expected data "
+                             f"type '{pkey_map.index.dtype}'") from e
+
+        try:
+            return pkey_map.loc[pkey]['arange'].to_numpy()
+        except KeyError as e:
+            missing = ~np.isin(pkey, pkey_map.index)
+            raise KeyError(f"The primary keys {pkey[missing].tolist()} do "
+                           f"not exist in the '{table_name}' table") from e
+
+    def sanitize(
+        self,
+        graph: LocalGraph,
+        preprocess: bool = False,
+    ) -> Tuple[Dict[str, pd.DataFrame], Dict[str, np.ndarray]]:
+        r"""Sanitizes raw data according to table schema definition:
+
+        In particular, it:
+        * converts timestamp data to `pd.Datetime`
+        * drops timezone information from timestamps
+        * drops duplicate primary keys
+        * removes rows with missing primary keys or time values
+
+        If ``preprocess`` is set to ``True``, it will additionally pre-process
+        data for faster model processing. In particular, it:
+        * tokenizes any text column that is not a foreign key
+        """
+        df_dict: Dict[str, pd.DataFrame] = {
+            table_name: table._data.copy(deep=False).reset_index(drop=True)
+            for table_name, table in graph.tables.items()
+        }
+
+        foreign_keys = {(edge.src_table, edge.fkey) for edge in graph.edges}
+
+        mask_dict: Dict[str, np.ndarray] = {}
+        for table in graph.tables.values():
+            for col in table.columns:
+                if col.stype == Stype.timestamp:
+                    ser = df_dict[table.name][col.name]
+                    if not pd.api.types.is_datetime64_any_dtype(ser):
+                        with warnings.catch_warnings():
+                            warnings.filterwarnings(
+                                'ignore',
+                                message='Could not infer format',
+                            )
+                            ser = pd.to_datetime(ser, errors='coerce')
+                        df_dict[table.name][col.name] = ser
+                    if isinstance(ser.dtype, pd.DatetimeTZDtype):
+                        ser = ser.dt.tz_localize(None)
+                        df_dict[table.name][col.name] = ser
+
+                # Normalize text in advance (but exclude foreign keys):
+                if (preprocess and col.stype == Stype.text
+                        and (table.name, col.name) not in foreign_keys):
+                    ser = df_dict[table.name][col.name]
+                    df_dict[table.name][col.name] = normalize_text(ser)
+
+            mask: Optional[np.ndarray] = None
+            if table._time_column is not None:
+                ser = df_dict[table.name][table._time_column]
+                mask = ser.notna().to_numpy()
+
+            if table._primary_key is not None:
+                ser = df_dict[table.name][table._primary_key]
+                _mask = (~ser.duplicated().to_numpy()) & ser.notna().to_numpy()
+                mask = _mask if mask is None else (_mask & mask)
+
+            if mask is not None and not mask.all():
+                mask_dict[table.name] = mask
+
+        return df_dict, mask_dict
+
+    def get_stype_dict(self, graph: LocalGraph) -> Dict[str, Dict[str, Stype]]:
+        stype_dict: Dict[str, Dict[str, Stype]] = {}
+        foreign_keys = {(edge.src_table, edge.fkey) for edge in graph.edges}
+        for table in graph.tables.values():
+            stype_dict[table.name] = {}
+            for column in table.columns:
+                if column == table.primary_key:
+                    continue
+                if (table.name, column.name) in foreign_keys:
+                    continue
+                stype_dict[table.name][column.name] = column.stype
+        return stype_dict
+
+    def get_pkey_data(
+        self,
+        graph: LocalGraph,
+    ) -> Tuple[
+        Dict[str, str],
+        Dict[str, pd.DataFrame],
+    ]:
+        pkey_name_dict: Dict[str, str] = {}
+        pkey_map_dict: Dict[str, pd.DataFrame] = {}
+
+        for table in graph.tables.values():
+            if table._primary_key is None:
+                continue
+
+            pkey_name_dict[table.name] = table._primary_key
+            pkey = self.df_dict[table.name][table._primary_key]
+            pkey_map = pd.DataFrame(
+                dict(arange=range(len(pkey))),
+                index=pkey,
+            )
+            if table.name in self.mask_dict:
+                pkey_map = pkey_map[self.mask_dict[table.name]]
+
+            if len(pkey_map) == 0:
+                error_msg = f"Found no valid rows in table '{table.name}'. "
+                if table.has_time_column():
+                    error_msg += ("Please make sure that there exists valid "
+                                  "non-N/A primary key and time column pairs "
+                                  "in this table.")
+                else:
+                    error_msg += ("Please make sure that there exists valid "
+                                  "non-N/A primary keys in this table.")
+                raise ValueError(error_msg)
+
+            pkey_map_dict[table.name] = pkey_map
+
+        return pkey_name_dict, pkey_map_dict
+
+    def get_time_data(
+        self,
+        graph: LocalGraph,
+    ) -> Tuple[
+        Dict[str, str],
+        Dict[str, str],
+        Dict[str, np.ndarray],
+        pd.Timestamp,
+        pd.Timestamp,
+    ]:
+        time_column_dict: Dict[str, str] = {}
+        end_time_column_dict: Dict[str, str] = {}
+        time_dict: Dict[str, np.ndarray] = {}
+        min_time = pd.Timestamp.max
+        max_time = pd.Timestamp.min
+        for table in graph.tables.values():
+            if table._end_time_column is not None:
+                end_time_column_dict[table.name] = table._end_time_column
+
+            if table._time_column is None:
+                continue
+
+            time = self.df_dict[table.name][table._time_column]
+            time_dict[table.name] = time.astype('datetime64[ns]').astype(
+                int).to_numpy() // 1000**3
+            time_column_dict[table.name] = table._time_column
+
+            if table.name in self.mask_dict.keys():
+                time = time[self.mask_dict[table.name]]
+            if len(time) > 0:
+                min_time = min(min_time, time.min())
+                max_time = max(max_time, time.max())
+
+        return (
+            time_column_dict,
+            end_time_column_dict,
+            time_dict,
+            min_time,
+            max_time,
+        )
+
+    def get_csc(
+        self,
+        graph: LocalGraph,
+    ) -> Tuple[
+        Dict[Tuple[str, str, str], np.ndarray],
+        Dict[Tuple[str, str, str], np.ndarray],
+    ]:
+        # A mapping from raw primary keys to node indices (0 to N-1):
+        map_dict: Dict[str, pd.CategoricalDtype] = {}
+        # A dictionary to manage offsets of node indices for invalid rows:
+        offset_dict: Dict[str, np.ndarray] = {}
+        for table_name in set(edge.dst_table for edge in graph.edges):
+            ser = self.df_dict[table_name][graph[table_name]._primary_key]
+            if table_name in self.mask_dict.keys():
+                mask = self.mask_dict[table_name]
+                ser = ser[mask]
+                offset_dict[table_name] = np.cumsum(~mask)[mask]
+            map_dict[table_name] = pd.CategoricalDtype(ser, ordered=True)
+
+        # Build CSC graph representation:
+        row_dict: Dict[Tuple[str, str, str], np.ndarray] = {}
+        colptr_dict: Dict[Tuple[str, str, str], np.ndarray] = {}
+        for src_table, fkey, dst_table in graph.edges:
+            src_df = self.df_dict[src_table]
+            dst_df = self.df_dict[dst_table]
+
+            src = np.arange(len(src_df))
+            dst = src_df[fkey].astype(map_dict[dst_table]).cat.codes.to_numpy()
+            dst = dst.astype(int)
+            mask = dst >= 0
+            if dst_table in offset_dict.keys():
+                dst = dst + offset_dict[dst_table][dst]
+            if src_table in self.mask_dict.keys():
+                mask &= self.mask_dict[src_table]
+            src, dst = src[mask], dst[mask]
+
+            # Sort by destination/column (and time within neighborhoods):
+            # `lexsort` is expensive (especially in numpy) so avoid it if
+            # possible by grouping `time` and `node_id` together:
+            if src_table in self.time_dict:
+                src_time = self.time_dict[src_table][src]
+                min_time = int(src_time.min())
+                max_time = int(src_time.max())
+                offset = (max_time - min_time) + 1
+                if offset * len(dst_df) <= np.iinfo(np.int64).max:
+                    index = dst * offset + (src_time - min_time)
+                    perm = _argsort(index)
+                else:  # Safe route to avoid `int64` overflow:
+                    perm = _lexsort([src_time, dst])
+            else:
+                perm = _argsort(dst)
+
+            row, col = src[perm], dst[perm]
+
+            # Convert into compressed representation:
+            colcount = np.bincount(col, minlength=len(dst_df))
+            colptr = np.empty(len(colcount) + 1, dtype=colcount.dtype)
+            colptr[0] = 0
+            np.cumsum(colcount, out=colptr[1:])
+            edge_type = (src_table, fkey, dst_table)
+            row_dict[edge_type] = row
+            colptr_dict[edge_type] = colptr
+
+            # Reverse connection - no sort and no time handling needed since
+            # the reverse mapping is 1-to-many.
+            row, col = dst, src
+            colcount = np.bincount(col, minlength=len(src_df))
+            colptr = np.empty(len(colcount) + 1, dtype=colcount.dtype)
+            colptr[0] = 0
+            np.cumsum(colcount, out=colptr[1:])
+            edge_type = Subgraph.rev_edge_type(edge_type)
+            row_dict[edge_type] = row
+            colptr_dict[edge_type] = colptr
+
+        return row_dict, colptr_dict
+
+
+def _argsort(input: np.ndarray) -> np.ndarray:
+    if not WITH_TORCH:
+        return np.argsort(input)
+    return torch.from_numpy(input).argsort().numpy()
+
+
+def _lexsort(inputs: List[np.ndarray]) -> np.ndarray:
+    assert len(inputs) >= 1
+
+    if not WITH_TORCH:
+        return np.lexsort(inputs)
+
+    try:
+        out = torch.from_numpy(inputs[0]).argsort(stable=True)
+    except Exception:
+        return np.lexsort(inputs)  # PyTorch<1.9 without `stable` support.
+
+    for input in inputs[1:]:
+        index = torch.from_numpy(input)[out]
+        index = index.argsort(stable=True)
+        out = out[index]
+
+    return out.numpy()
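`get_csc` above relies on two numpy idioms worth spelling out: it replaces a two-key `np.lexsort` with a single argsort over the fused key `dst * offset + (src_time - min_time)` whenever that key fits into `int64`, and it compresses the sorted destination column into a `colptr` offset array in which `colptr[v]:colptr[v + 1]` spans the edges pointing at destination `v`. A self-contained sketch of both steps on toy arrays (independent of kumoai; the data is made up):

import numpy as np

# Toy edge list: edge i connects source row src[i] to destination dst[i];
# src_time[i] is the source row's timestamp (already in seconds).
src = np.array([0, 1, 2, 3, 4, 5])
dst = np.array([1, 0, 1, 1, 0, 2])
src_time = np.array([30, 10, 10, 20, 40, 5])
num_dst = 3  # Number of destination nodes.

# Fused sort key: primary key `dst`, secondary key `src_time`.
# Matches np.lexsort([src_time, dst]) as long as the fused key fits
# into int64 (checked in get_csc) and (dst, time) pairs are distinct;
# exact ties may order arbitrarily under a non-stable argsort.
offset = int(src_time.max() - src_time.min()) + 1
key = dst * offset + (src_time - src_time.min())
perm = np.argsort(key)
row, col = src[perm], dst[perm]

assert np.array_equal(perm, np.lexsort([src_time, dst]))

# Compress the sorted `col` into a CSC-style offset array:
# colptr[v]:colptr[v + 1] indexes the edges pointing at destination v.
colcount = np.bincount(col, minlength=num_dst)
colptr = np.empty(num_dst + 1, dtype=colcount.dtype)
colptr[0] = 0
np.cumsum(colcount, out=colptr[1:])

# Neighbors of destination node 1, ordered by source timestamp:
print(row[colptr[1]:colptr[2]])  # -> [2 3 0]

Because both keys fold into one integer, a single (cheaper) argsort reproduces the lexicographic order; the overflow check in `get_csc` falls back to `_lexsort` whenever the fused key would not fit into `int64`.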