kumoai 2.13.0.dev202511131731__cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (98)
  1. kumoai/__init__.py +294 -0
  2. kumoai/_logging.py +29 -0
  3. kumoai/_singleton.py +25 -0
  4. kumoai/_version.py +1 -0
  5. kumoai/artifact_export/__init__.py +9 -0
  6. kumoai/artifact_export/config.py +209 -0
  7. kumoai/artifact_export/job.py +108 -0
  8. kumoai/client/__init__.py +5 -0
  9. kumoai/client/client.py +221 -0
  10. kumoai/client/connector.py +110 -0
  11. kumoai/client/endpoints.py +150 -0
  12. kumoai/client/graph.py +120 -0
  13. kumoai/client/jobs.py +447 -0
  14. kumoai/client/online.py +78 -0
  15. kumoai/client/pquery.py +203 -0
  16. kumoai/client/rfm.py +112 -0
  17. kumoai/client/source_table.py +53 -0
  18. kumoai/client/table.py +101 -0
  19. kumoai/client/utils.py +130 -0
  20. kumoai/codegen/__init__.py +19 -0
  21. kumoai/codegen/cli.py +100 -0
  22. kumoai/codegen/context.py +16 -0
  23. kumoai/codegen/edits.py +473 -0
  24. kumoai/codegen/exceptions.py +10 -0
  25. kumoai/codegen/generate.py +222 -0
  26. kumoai/codegen/handlers/__init__.py +4 -0
  27. kumoai/codegen/handlers/connector.py +118 -0
  28. kumoai/codegen/handlers/graph.py +71 -0
  29. kumoai/codegen/handlers/pquery.py +62 -0
  30. kumoai/codegen/handlers/table.py +109 -0
  31. kumoai/codegen/handlers/utils.py +42 -0
  32. kumoai/codegen/identity.py +114 -0
  33. kumoai/codegen/loader.py +93 -0
  34. kumoai/codegen/naming.py +94 -0
  35. kumoai/codegen/registry.py +121 -0
  36. kumoai/connector/__init__.py +31 -0
  37. kumoai/connector/base.py +153 -0
  38. kumoai/connector/bigquery_connector.py +200 -0
  39. kumoai/connector/databricks_connector.py +213 -0
  40. kumoai/connector/file_upload_connector.py +189 -0
  41. kumoai/connector/glue_connector.py +150 -0
  42. kumoai/connector/s3_connector.py +278 -0
  43. kumoai/connector/snowflake_connector.py +252 -0
  44. kumoai/connector/source_table.py +471 -0
  45. kumoai/connector/utils.py +1775 -0
  46. kumoai/databricks.py +14 -0
  47. kumoai/encoder/__init__.py +4 -0
  48. kumoai/exceptions.py +26 -0
  49. kumoai/experimental/__init__.py +0 -0
  50. kumoai/experimental/rfm/__init__.py +67 -0
  51. kumoai/experimental/rfm/authenticate.py +433 -0
  52. kumoai/experimental/rfm/infer/__init__.py +11 -0
  53. kumoai/experimental/rfm/infer/categorical.py +40 -0
  54. kumoai/experimental/rfm/infer/id.py +46 -0
  55. kumoai/experimental/rfm/infer/multicategorical.py +48 -0
  56. kumoai/experimental/rfm/infer/timestamp.py +41 -0
  57. kumoai/experimental/rfm/local_graph.py +810 -0
  58. kumoai/experimental/rfm/local_graph_sampler.py +184 -0
  59. kumoai/experimental/rfm/local_graph_store.py +359 -0
  60. kumoai/experimental/rfm/local_pquery_driver.py +689 -0
  61. kumoai/experimental/rfm/local_table.py +545 -0
  62. kumoai/experimental/rfm/pquery/__init__.py +7 -0
  63. kumoai/experimental/rfm/pquery/executor.py +102 -0
  64. kumoai/experimental/rfm/pquery/pandas_executor.py +532 -0
  65. kumoai/experimental/rfm/rfm.py +1130 -0
  66. kumoai/experimental/rfm/utils.py +344 -0
  67. kumoai/formatting.py +30 -0
  68. kumoai/futures.py +99 -0
  69. kumoai/graph/__init__.py +12 -0
  70. kumoai/graph/column.py +106 -0
  71. kumoai/graph/graph.py +948 -0
  72. kumoai/graph/table.py +838 -0
  73. kumoai/jobs.py +80 -0
  74. kumoai/kumolib.cpython-313-x86_64-linux-gnu.so +0 -0
  75. kumoai/mixin.py +28 -0
  76. kumoai/pquery/__init__.py +25 -0
  77. kumoai/pquery/prediction_table.py +287 -0
  78. kumoai/pquery/predictive_query.py +637 -0
  79. kumoai/pquery/training_table.py +424 -0
  80. kumoai/spcs.py +123 -0
  81. kumoai/testing/__init__.py +8 -0
  82. kumoai/testing/decorators.py +57 -0
  83. kumoai/trainer/__init__.py +42 -0
  84. kumoai/trainer/baseline_trainer.py +93 -0
  85. kumoai/trainer/config.py +2 -0
  86. kumoai/trainer/job.py +1192 -0
  87. kumoai/trainer/online_serving.py +258 -0
  88. kumoai/trainer/trainer.py +475 -0
  89. kumoai/trainer/util.py +103 -0
  90. kumoai/utils/__init__.py +10 -0
  91. kumoai/utils/datasets.py +83 -0
  92. kumoai/utils/forecasting.py +209 -0
  93. kumoai/utils/progress_logger.py +177 -0
  94. kumoai-2.13.0.dev202511131731.dist-info/METADATA +60 -0
  95. kumoai-2.13.0.dev202511131731.dist-info/RECORD +98 -0
  96. kumoai-2.13.0.dev202511131731.dist-info/WHEEL +6 -0
  97. kumoai-2.13.0.dev202511131731.dist-info/licenses/LICENSE +9 -0
  98. kumoai-2.13.0.dev202511131731.dist-info/top_level.txt +1 -0
kumoai/experimental/rfm/local_graph_sampler.py
@@ -0,0 +1,184 @@
+ from typing import Dict, List, Optional, Tuple
+
+ import numpy as np
+ import pandas as pd
+ from kumoapi.model_plan import RunMode
+ from kumoapi.rfm.context import EdgeLayout, Link, Subgraph, Table
+ from kumoapi.typing import Stype
+
+ import kumoai.kumolib as kumolib
+ from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
+ from kumoai.experimental.rfm.utils import normalize_text
+
+
+ class LocalGraphSampler:
+     def __init__(self, graph_store: LocalGraphStore) -> None:
+         self._graph_store = graph_store
+         self._sampler = kumolib.NeighborSampler(
+             self._graph_store.node_types,
+             self._graph_store.edge_types,
+             {
+                 '__'.join(edge_type): colptr
+                 for edge_type, colptr in self._graph_store.colptr_dict.items()
+             },
+             {
+                 '__'.join(edge_type): row
+                 for edge_type, row in self._graph_store.row_dict.items()
+             },
+             self._graph_store.time_dict,
+         )
+
+     def __call__(
+         self,
+         entity_table_names: Tuple[str, ...],
+         node: np.ndarray,
+         time: np.ndarray,
+         run_mode: RunMode,
+         num_neighbors: List[int],
+         exclude_cols_dict: Dict[str, List[str]],
+     ) -> Subgraph:
+
+         (
+             row_dict,
+             col_dict,
+             node_dict,
+             batch_dict,
+             num_sampled_nodes_dict,
+             num_sampled_edges_dict,
+         ) = self._sampler.sample(
+             {
+                 '__'.join(edge_type): num_neighbors
+                 for edge_type in self._graph_store.edge_types
+             },
+             {},  # time interval based sampling
+             entity_table_names[0],
+             node,
+             time // 1000**3,  # nanoseconds to seconds
+         )
+
+         table_dict: Dict[str, Table] = {}
+         for table_name, node in node_dict.items():
+             batch = batch_dict[table_name]
+
+             if len(node) == 0:
+                 continue
+
+             df = self._graph_store.df_dict[table_name]
+
+             num_sampled_nodes = num_sampled_nodes_dict[table_name].tolist()
+             stype_dict = {  # Exclude target columns:
+                 column_name: stype
+                 for column_name, stype in
+                 self._graph_store.stype_dict[table_name].items()
+                 if column_name not in exclude_cols_dict.get(table_name, [])
+             }
+             primary_key: Optional[str] = None
+             if table_name in entity_table_names:
+                 primary_key = self._graph_store.pkey_name_dict.get(table_name)
+
+             columns: List[str] = []
+             if table_name in entity_table_names:
+                 columns += [self._graph_store.pkey_name_dict[table_name]]
+             columns += list(stype_dict.keys())
+
+             if len(columns) == 0:
+                 table_dict[table_name] = Table(
+                     df=pd.DataFrame(index=range(len(node))),
+                     row=None,
+                     batch=batch,
+                     num_sampled_nodes=num_sampled_nodes,
+                     stype_dict=stype_dict,
+                     primary_key=primary_key,
+                 )
+                 continue
+
+             row: Optional[np.ndarray] = None
+             if table_name in self._graph_store.end_time_column_dict:
+                 # Set end time to NaT for all values greater than anchor time:
+                 df = df.iloc[node].reset_index(drop=True)
+                 col_name = self._graph_store.end_time_column_dict[table_name]
+                 ser = df[col_name]
+                 value = ser.astype('datetime64[ns]').astype(int).to_numpy()
+                 mask = value > time[batch]
+                 df.loc[mask, col_name] = pd.NaT
+             else:
+                 # Only store unique rows in `df` above a certain threshold:
+                 unique_node, inverse = np.unique(node, return_inverse=True)
+                 if len(node) > 1.05 * len(unique_node):
+                     df = df.iloc[unique_node].reset_index(drop=True)
+                     row = inverse
+                 else:
+                     df = df.iloc[node].reset_index(drop=True)
+
+             # Filter data frame to minimal set of columns:
+             df = df[columns]
+
+             # Normalize text (if not already pre-processed):
+             for column_name, stype in stype_dict.items():
+                 if stype == Stype.text:
+                     df[column_name] = normalize_text(df[column_name])
+
+             table_dict[table_name] = Table(
+                 df=df,
+                 row=row,
+                 batch=batch,
+                 num_sampled_nodes=num_sampled_nodes,
+                 stype_dict=stype_dict,
+                 primary_key=primary_key,
+             )
+
+         link_dict: Dict[Tuple[str, str, str], Link] = {}
+         for edge_type in self._graph_store.edge_types:
+             edge_type_str = '__'.join(edge_type)
+
+             row = row_dict[edge_type_str]
+             col = col_dict[edge_type_str]
+
+             if len(row) == 0:
+                 continue
+
+             # Do not store reverse edge type if it is a replica:
+             rev_edge_type = Subgraph.rev_edge_type(edge_type)
+             rev_edge_type_str = '__'.join(rev_edge_type)
+             if (rev_edge_type in link_dict
+                     and np.array_equal(row, col_dict[rev_edge_type_str])
+                     and np.array_equal(col, row_dict[rev_edge_type_str])):
+                 link = Link(
+                     layout=EdgeLayout.REV,
+                     row=None,
+                     col=None,
+                     num_sampled_edges=(
+                         num_sampled_edges_dict[edge_type_str].tolist()),
+                 )
+                 link_dict[edge_type] = link
+                 continue
+
+             layout = EdgeLayout.COO
+             if np.array_equal(row, np.arange(len(row))):
+                 row = None
+             if np.array_equal(col, np.arange(len(col))):
+                 col = None
+
+             # Store in compressed representation if more efficient:
+             num_cols = table_dict[edge_type[2]].num_rows
+             if col is not None and len(col) > num_cols + 1:
+                 layout = EdgeLayout.CSC
+                 colcount = np.bincount(col, minlength=num_cols)
+                 col = np.empty(num_cols + 1, dtype=col.dtype)
+                 col[0] = 0
+                 np.cumsum(colcount, out=col[1:])
+
+             link = Link(
+                 layout=layout,
+                 row=row,
+                 col=col,
+                 num_sampled_edges=(
+                     num_sampled_edges_dict[edge_type_str].tolist()),
+             )
+             link_dict[edge_type] = link
+
+         return Subgraph(
+             anchor_time=time,
+             table_dict=table_dict,
+             link_dict=link_dict,
+         )
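For orientation, here is a minimal usage sketch of the sampler added above. It is not taken from the package docs; it assumes an already-constructed LocalGraph named `graph`, a hypothetical 'users' table, and that RunMode exposes a FAST member — only the LocalGraphStore, LocalGraphSampler, and get_node_id signatures come from the diffed code:

import numpy as np
import pandas as pd
from kumoapi.model_plan import RunMode

from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
from kumoai.experimental.rfm.local_graph_store import LocalGraphStore

store = LocalGraphStore(graph, preprocess=True)  # `graph`: a LocalGraph
sampler = LocalGraphSampler(store)

# Resolve raw primary keys of the (hypothetical) 'users' table to node IDs:
node = store.get_node_id('users', pd.Series([1, 2, 3]))
# One anchor time per node, in nanoseconds since epoch; the sampler
# converts to seconds internally via `time // 1000**3`:
time = np.full(len(node), pd.Timestamp('2024-01-01').value)

subgraph = sampler(
    entity_table_names=('users',),
    node=node,
    time=time,
    run_mode=RunMode.FAST,   # assumption: a FAST run mode exists
    num_neighbors=[16, 16],  # two-hop neighborhood fan-out
    exclude_cols_dict={},    # no target columns to exclude
)

The returned Subgraph carries the per-table data frames (table_dict), the sampled edges (link_dict), and the anchor times echoed back for downstream use.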
kumoai/experimental/rfm/local_graph_store.py
@@ -0,0 +1,359 @@
+ import warnings
+ from typing import Dict, List, Optional, Tuple, Union
+
+ import numpy as np
+ import pandas as pd
+ from kumoapi.rfm.context import Subgraph
+ from kumoapi.typing import Stype
+
+ from kumoai.experimental.rfm import LocalGraph
+ from kumoai.experimental.rfm.utils import normalize_text
+ from kumoai.utils import InteractiveProgressLogger, ProgressLogger
+
+ try:
+     import torch
+     WITH_TORCH = True
+ except ImportError:
+     WITH_TORCH = False
+
+
+ class LocalGraphStore:
+     def __init__(
+         self,
+         graph: LocalGraph,
+         preprocess: bool = False,
+         verbose: Union[bool, ProgressLogger] = True,
+     ) -> None:
+
+         if not isinstance(verbose, ProgressLogger):
+             verbose = InteractiveProgressLogger(
+                 "Materializing graph",
+                 verbose=verbose,
+             )
+
+         with verbose as logger:
+             self.df_dict, self.mask_dict = self.sanitize(graph, preprocess)
+             self.stype_dict = self.get_stype_dict(graph)
+             logger.log("Sanitized input data")
+
+             self.pkey_name_dict, self.pkey_map_dict = self.get_pkey_data(graph)
+             num_pkeys = sum(t.has_primary_key() for t in graph.tables.values())
+             if num_pkeys > 1:
+                 logger.log(f"Collected primary keys from {num_pkeys} tables")
+             else:
+                 logger.log(f"Collected primary key from {num_pkeys} table")
+
+             (
+                 self.time_column_dict,
+                 self.end_time_column_dict,
+                 self.time_dict,
+                 self.min_time,
+                 self.max_time,
+             ) = self.get_time_data(graph)
+             if self.max_time != pd.Timestamp.min:
+                 logger.log(f"Identified temporal graph from "
+                            f"{self.min_time.date()} to {self.max_time.date()}")
+             else:
+                 logger.log("Identified static graph without timestamps")
+
+             self.row_dict, self.colptr_dict = self.get_csc(graph)
+             num_nodes = sum(len(df) for df in self.df_dict.values())
+             num_edges = sum(len(row) for row in self.row_dict.values())
+             logger.log(f"Created graph with {num_nodes:,} nodes and "
+                        f"{num_edges:,} edges")
+
+     @property
+     def node_types(self) -> List[str]:
+         return list(self.df_dict.keys())
+
+     @property
+     def edge_types(self) -> List[Tuple[str, str, str]]:
+         return list(self.row_dict.keys())
+
+     def get_node_id(self, table_name: str, pkey: pd.Series) -> np.ndarray:
+         r"""Returns the node ID given primary keys.
+
+         Args:
+             table_name: The table name.
+             pkey: The primary keys to receive node IDs for.
+         """
+         if table_name not in self.df_dict.keys():
+             raise KeyError(f"Table '{table_name}' does not exist")
+
+         if table_name not in self.pkey_map_dict.keys():
+             raise ValueError(f"Table '{table_name}' does not have a primary "
+                              f"key")
+
+         if len(pkey) == 0:
+             raise KeyError(f"No primary keys passed for table '{table_name}'")
+
+         pkey_map = self.pkey_map_dict[table_name]
+
+         try:
+             pkey = pkey.astype(type(pkey_map.index[0]))
+         except ValueError as e:
+             raise ValueError(f"Could not cast primary keys "
+                              f"{pkey.tolist()} to the expected data "
+                              f"type '{pkey_map.index.dtype}'") from e
+
+         try:
+             return pkey_map.loc[pkey]['arange'].to_numpy()
+         except KeyError as e:
+             missing = ~np.isin(pkey, pkey_map.index)
+             raise KeyError(f"The primary keys {pkey[missing].tolist()} do "
+                            f"not exist in the '{table_name}' table") from e
+
+     def sanitize(
+         self,
+         graph: LocalGraph,
+         preprocess: bool = False,
+     ) -> Tuple[Dict[str, pd.DataFrame], Dict[str, np.ndarray]]:
+         r"""Sanitizes raw data according to table schema definition:
+
+         In particular, it:
+         * converts timestamp data to `pd.Datetime`
+         * drops timezone information from timestamps
+         * drops duplicate primary keys
+         * removes rows with missing primary keys or time values
+
+         If ``preprocess`` is set to ``True``, it will additionally pre-process
+         data for faster model processing. In particular, it:
+         * tokenizes any text column that is not a foreign key
+         """
+         df_dict: Dict[str, pd.DataFrame] = {
+             table_name: table._data.copy(deep=False).reset_index(drop=True)
+             for table_name, table in graph.tables.items()
+         }
+
+         foreign_keys = {(edge.src_table, edge.fkey) for edge in graph.edges}
+
+         mask_dict: Dict[str, np.ndarray] = {}
+         for table in graph.tables.values():
+             for col in table.columns:
+                 if col.stype == Stype.timestamp:
+                     ser = df_dict[table.name][col.name]
+                     if not pd.api.types.is_datetime64_any_dtype(ser):
+                         with warnings.catch_warnings():
+                             warnings.filterwarnings(
+                                 'ignore',
+                                 message='Could not infer format',
+                             )
+                             ser = pd.to_datetime(ser, errors='coerce')
+                         df_dict[table.name][col.name] = ser
+                     if isinstance(ser.dtype, pd.DatetimeTZDtype):
+                         ser = ser.dt.tz_localize(None)
+                         df_dict[table.name][col.name] = ser
+
+                 # Normalize text in advance (but exclude foreign keys):
+                 if (preprocess and col.stype == Stype.text
+                         and (table.name, col.name) not in foreign_keys):
+                     ser = df_dict[table.name][col.name]
+                     df_dict[table.name][col.name] = normalize_text(ser)
+
+             mask: Optional[np.ndarray] = None
+             if table._time_column is not None:
+                 ser = df_dict[table.name][table._time_column]
+                 mask = ser.notna().to_numpy()
+
+             if table._primary_key is not None:
+                 ser = df_dict[table.name][table._primary_key]
+                 _mask = (~ser.duplicated().to_numpy()) & ser.notna().to_numpy()
+                 mask = _mask if mask is None else (_mask & mask)
+
+             if mask is not None and not mask.all():
+                 mask_dict[table.name] = mask
+
+         return df_dict, mask_dict
+
+     def get_stype_dict(self, graph: LocalGraph) -> Dict[str, Dict[str, Stype]]:
+         stype_dict: Dict[str, Dict[str, Stype]] = {}
+         foreign_keys = {(edge.src_table, edge.fkey) for edge in graph.edges}
+         for table in graph.tables.values():
+             stype_dict[table.name] = {}
+             for column in table.columns:
+                 if column == table.primary_key:
+                     continue
+                 if (table.name, column.name) in foreign_keys:
+                     continue
+                 stype_dict[table.name][column.name] = column.stype
+         return stype_dict
+
+     def get_pkey_data(
+         self,
+         graph: LocalGraph,
+     ) -> Tuple[
+             Dict[str, str],
+             Dict[str, pd.DataFrame],
+     ]:
+         pkey_name_dict: Dict[str, str] = {}
+         pkey_map_dict: Dict[str, pd.DataFrame] = {}
+
+         for table in graph.tables.values():
+             if table._primary_key is None:
+                 continue
+
+             pkey_name_dict[table.name] = table._primary_key
+             pkey = self.df_dict[table.name][table._primary_key]
+             pkey_map = pd.DataFrame(
+                 dict(arange=range(len(pkey))),
+                 index=pkey,
+             )
+             if table.name in self.mask_dict:
+                 pkey_map = pkey_map[self.mask_dict[table.name]]
+
+             if len(pkey_map) == 0:
+                 error_msg = f"Found no valid rows in table '{table.name}'. "
+                 if table.has_time_column():
+                     error_msg += ("Please make sure that there exists valid "
+                                   "non-N/A primary key and time column pairs "
+                                   "in this table.")
+                 else:
+                     error_msg += ("Please make sure that there exists valid "
+                                   "non-N/A primary keys in this table.")
+                 raise ValueError(error_msg)
+
+             pkey_map_dict[table.name] = pkey_map
+
+         return pkey_name_dict, pkey_map_dict
+
+     def get_time_data(
+         self,
+         graph: LocalGraph,
+     ) -> Tuple[
+             Dict[str, str],
+             Dict[str, str],
+             Dict[str, np.ndarray],
+             pd.Timestamp,
+             pd.Timestamp,
+     ]:
+         time_column_dict: Dict[str, str] = {}
+         end_time_column_dict: Dict[str, str] = {}
+         time_dict: Dict[str, np.ndarray] = {}
+         min_time = pd.Timestamp.max
+         max_time = pd.Timestamp.min
+         for table in graph.tables.values():
+             if table._end_time_column is not None:
+                 end_time_column_dict[table.name] = table._end_time_column
+
+             if table._time_column is None:
+                 continue
+
+             time = self.df_dict[table.name][table._time_column]
+             time_dict[table.name] = time.astype('datetime64[ns]').astype(
+                 int).to_numpy() // 1000**3
+             time_column_dict[table.name] = table._time_column
+
+             if table.name in self.mask_dict.keys():
+                 time = time[self.mask_dict[table.name]]
+             if len(time) > 0:
+                 min_time = min(min_time, time.min())
+                 max_time = max(max_time, time.max())
+
+         return (
+             time_column_dict,
+             end_time_column_dict,
+             time_dict,
+             min_time,
+             max_time,
+         )
+
+     def get_csc(
+         self,
+         graph: LocalGraph,
+     ) -> Tuple[
+             Dict[Tuple[str, str, str], np.ndarray],
+             Dict[Tuple[str, str, str], np.ndarray],
+     ]:
+         # A mapping from raw primary keys to node indices (0 to N-1):
+         map_dict: Dict[str, pd.CategoricalDtype] = {}
+         # A dictionary to manage offsets of node indices for invalid rows:
+         offset_dict: Dict[str, np.ndarray] = {}
+         for table_name in set(edge.dst_table for edge in graph.edges):
+             ser = self.df_dict[table_name][graph[table_name]._primary_key]
+             if table_name in self.mask_dict.keys():
+                 mask = self.mask_dict[table_name]
+                 ser = ser[mask]
+                 offset_dict[table_name] = np.cumsum(~mask)[mask]
+             map_dict[table_name] = pd.CategoricalDtype(ser, ordered=True)
+
+         # Build CSC graph representation:
+         row_dict: Dict[Tuple[str, str, str], np.ndarray] = {}
+         colptr_dict: Dict[Tuple[str, str, str], np.ndarray] = {}
+         for src_table, fkey, dst_table in graph.edges:
+             src_df = self.df_dict[src_table]
+             dst_df = self.df_dict[dst_table]
+
+             src = np.arange(len(src_df))
+             dst = src_df[fkey].astype(map_dict[dst_table]).cat.codes.to_numpy()
+             dst = dst.astype(int)
+             mask = dst >= 0
+             if dst_table in offset_dict.keys():
+                 dst = dst + offset_dict[dst_table][dst]
+             if src_table in self.mask_dict.keys():
+                 mask &= self.mask_dict[src_table]
+             src, dst = src[mask], dst[mask]
+
+             # Sort by destination/column (and time within neighborhoods):
+             # `lexsort` is expensive (especially in numpy) so avoid it if
+             # possible by grouping `time` and `node_id` together:
+             if src_table in self.time_dict:
+                 src_time = self.time_dict[src_table][src]
+                 min_time = int(src_time.min())
+                 max_time = int(src_time.max())
+                 offset = (max_time - min_time) + 1
+                 if offset * len(dst_df) <= np.iinfo(np.int64).max:
+                     index = dst * offset + (src_time - min_time)
+                     perm = _argsort(index)
+                 else:  # Safe route to avoid `int64` overflow:
+                     perm = _lexsort([src_time, dst])
+             else:
+                 perm = _argsort(dst)
+
+             row, col = src[perm], dst[perm]
+
+             # Convert into compressed representation:
+             colcount = np.bincount(col, minlength=len(dst_df))
+             colptr = np.empty(len(colcount) + 1, dtype=colcount.dtype)
+             colptr[0] = 0
+             np.cumsum(colcount, out=colptr[1:])
+             edge_type = (src_table, fkey, dst_table)
+             row_dict[edge_type] = row
+             colptr_dict[edge_type] = colptr
+
+             # Reverse connection - no sort and no time handling needed since
+             # the reverse mapping is 1-to-many.
+             row, col = dst, src
+             colcount = np.bincount(col, minlength=len(src_df))
+             colptr = np.empty(len(colcount) + 1, dtype=colcount.dtype)
+             colptr[0] = 0
+             np.cumsum(colcount, out=colptr[1:])
+             edge_type = Subgraph.rev_edge_type(edge_type)
+             row_dict[edge_type] = row
+             colptr_dict[edge_type] = colptr
+
+         return row_dict, colptr_dict
+
+
+ def _argsort(input: np.ndarray) -> np.ndarray:
+     if not WITH_TORCH:
+         return np.argsort(input)
+     return torch.from_numpy(input).argsort().numpy()
+
+
+ def _lexsort(inputs: List[np.ndarray]) -> np.ndarray:
+     assert len(inputs) >= 1
+
+     if not WITH_TORCH:
+         return np.lexsort(inputs)
+
+     try:
+         out = torch.from_numpy(inputs[0]).argsort(stable=True)
+     except Exception:
+         return np.lexsort(inputs)  # PyTorch<1.9 without `stable` support.
+
+     for input in inputs[1:]:
+         index = torch.from_numpy(input)[out]
+         index = index.argsort(stable=True)
+         out = out[index]
+
+     return out.numpy()
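As a side note on the layout get_csc produces (and LocalGraphSampler consumes): edges are stored per edge type as a row array sorted by destination plus a colptr offset array, so the neighbors of destination j are row[colptr[j]:colptr[j+1]]. A small self-contained numpy sketch of the same bincount-plus-cumsum pattern, with toy arrays invented for illustration:

import numpy as np

# Toy edge list, already sorted by destination: edge i connects
# source row[i] to destination col[i].
row = np.array([0, 2, 1, 3, 4])
col = np.array([0, 0, 1, 2, 2])
num_cols = 4  # number of destination nodes

# Count edges per destination, then exclusive prefix-sum into colptr:
colcount = np.bincount(col, minlength=num_cols)
colptr = np.empty(num_cols + 1, dtype=colcount.dtype)
colptr[0] = 0
np.cumsum(colcount, out=colptr[1:])

assert colptr.tolist() == [0, 2, 3, 5, 5]
assert row[colptr[2]:colptr[3]].tolist() == [3, 4]  # neighbors of node 2

The sort that precedes this step is where the code saves work: when (max_time - min_time + 1) * len(dst_df) fits in int64, destination and timestamp are fused into the single integer key dst * offset + (src_time - min_time), so one plain argsort replaces the costlier lexsort.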