kumoai 2.13.0.dev202512040651__cp312-cp312-win_amd64.whl → 2.14.0.dev202512111731__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
kumoai/__init__.py CHANGED
@@ -280,7 +280,19 @@ __all__ = [
 ]
 
 
+def in_snowflake_notebook() -> bool:
+    try:
+        from snowflake.snowpark.context import get_active_session
+        import streamlit  # noqa: F401
+        get_active_session()
+        return True
+    except Exception:
+        return False
+
+
 def in_notebook() -> bool:
+    if in_snowflake_notebook():
+        return True
     try:
         from IPython import get_ipython
         shell = get_ipython()
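The new helper defines "running in a Snowflake notebook" as "an active Snowpark session exists and streamlit is importable", and `in_notebook()` short-circuits on it because Snowflake notebooks do not expose an IPython shell. A hedged sketch of a call site (the logger choice below is illustrative, not taken from this diff):

    from kumoai import in_notebook
    from kumoai.utils import InteractiveProgressLogger, ProgressLogger

    # Hypothetical: pick the interactive logger only when some notebook
    # frontend (IPython or, now, Snowflake) is detected.
    logger_cls = InteractiveProgressLogger if in_notebook() else ProgressLogger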
kumoai/_version.py CHANGED
@@ -1 +1 @@
-__version__ = '2.13.0.dev202512040651'
+__version__ = '2.14.0.dev202512111731'
@@ -32,7 +32,11 @@ Please create a feature request at 'https://github.com/kumo-ai/kumo-rfm'."""
     raise RuntimeError(_msg) from e
 
 from .table import LocalTable
+from .graph_store import LocalGraphStore
+from .sampler import LocalSampler
 
 __all__ = [
     'LocalTable',
+    'LocalGraphStore',
+    'LocalSampler',
 ]
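With the new exports, callers can build the local sampling stack directly from the backend package. A minimal sketch, assuming `graph` is a ready-built kumoai.experimental.rfm Graph (`LocalSampler` constructs its own `LocalGraphStore` internally, as the new sampler module later in this diff shows):

    from kumoai.experimental.rfm.backend.local import LocalSampler

    sampler = LocalSampler(graph, verbose=True)  # wraps a LocalGraphStore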
@@ -1,12 +1,12 @@
 import warnings
-from typing import Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import pandas as pd
 from kumoapi.rfm.context import Subgraph
 from kumoapi.typing import Stype
 
-from kumoai.experimental.rfm import Graph, LocalTable
+from kumoai.experimental.rfm.backend.local import LocalTable
 from kumoai.utils import InteractiveProgressLogger, ProgressLogger
 
 try:
@@ -15,11 +15,14 @@ try:
 except ImportError:
     WITH_TORCH = False
 
+if TYPE_CHECKING:
+    from kumoai.experimental.rfm import Graph
+
 
 class LocalGraphStore:
     def __init__(
         self,
-        graph: Graph,
+        graph: 'Graph',
         verbose: Union[bool, ProgressLogger] = True,
     ) -> None:
 
@@ -31,26 +34,21 @@ class LocalGraphStore:
 
         with verbose as logger:
             self.df_dict, self.mask_dict = self.sanitize(graph)
-            self.stype_dict = self.get_stype_dict(graph)
             logger.log("Sanitized input data")
 
-            self.pkey_name_dict, self.pkey_map_dict = self.get_pkey_data(graph)
+            self.pkey_map_dict = self.get_pkey_map_dict(graph)
             num_pkeys = sum(t.has_primary_key() for t in graph.tables.values())
             if num_pkeys > 1:
                 logger.log(f"Collected primary keys from {num_pkeys} tables")
             else:
                 logger.log(f"Collected primary key from {num_pkeys} table")
 
-            (
-                self.time_column_dict,
-                self.end_time_column_dict,
-                self.time_dict,
-                self.min_time,
-                self.max_time,
-            ) = self.get_time_data(graph)
-            if self.max_time != pd.Timestamp.min:
+            self.time_dict, self.min_max_time_dict = self.get_time_data(graph)
+            if len(self.min_max_time_dict) > 0:
+                min_time = min(t for t, _ in self.min_max_time_dict.values())
+                max_time = max(t for _, t in self.min_max_time_dict.values())
                 logger.log(f"Identified temporal graph from "
-                           f"{self.min_time.date()} to {self.max_time.date()}")
+                           f"{min_time.date()} to {max_time.date()}")
             else:
                 logger.log("Identified static graph without timestamps")
 
@@ -60,14 +58,6 @@ class LocalGraphStore:
             logger.log(f"Created graph with {num_nodes:,} nodes and "
                        f"{num_edges:,} edges")
 
-    @property
-    def node_types(self) -> List[str]:
-        return list(self.df_dict.keys())
-
-    @property
-    def edge_types(self) -> List[Tuple[str, str, str]]:
-        return list(self.row_dict.keys())
-
     def get_node_id(self, table_name: str, pkey: pd.Series) -> np.ndarray:
         r"""Returns the node ID given primary keys.
 
@@ -103,7 +93,7 @@ class LocalGraphStore:
 
     def sanitize(
         self,
-        graph: Graph,
+        graph: 'Graph',
     ) -> Tuple[Dict[str, pd.DataFrame], Dict[str, np.ndarray]]:
         r"""Sanitizes raw data according to table schema definition:
 
@@ -151,34 +141,16 @@ class LocalGraphStore:
 
         return df_dict, mask_dict
 
-    def get_stype_dict(self, graph: Graph) -> Dict[str, Dict[str, Stype]]:
-        stype_dict: Dict[str, Dict[str, Stype]] = {}
-        foreign_keys = {(edge.src_table, edge.fkey) for edge in graph.edges}
-        for table in graph.tables.values():
-            stype_dict[table.name] = {}
-            for column in table.columns:
-                if column == table.primary_key:
-                    continue
-                if (table.name, column.name) in foreign_keys:
-                    continue
-                stype_dict[table.name][column.name] = column.stype
-        return stype_dict
-
-    def get_pkey_data(
+    def get_pkey_map_dict(
         self,
-        graph: Graph,
-    ) -> Tuple[
-        Dict[str, str],
-        Dict[str, pd.DataFrame],
-    ]:
-        pkey_name_dict: Dict[str, str] = {}
+        graph: 'Graph',
+    ) -> Dict[str, pd.DataFrame]:
         pkey_map_dict: Dict[str, pd.DataFrame] = {}
 
         for table in graph.tables.values():
             if table._primary_key is None:
                 continue
 
-            pkey_name_dict[table.name] = table._primary_key
             pkey = self.df_dict[table.name][table._primary_key]
             pkey_map = pd.DataFrame(
                 dict(arange=range(len(pkey))),
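The hunk is cut off here, but the surviving lines show the shape of the mapping: each table's primary-key values index a single `arange` column of row positions, which `get_node_id()` and `_sample_entity_table()` later use for `iloc` lookups. A sketch with a hypothetical key column:

    import pandas as pd

    pkey = pd.Series([101, 102, 105], name='user_id')  # hypothetical keys
    pkey_map = pd.DataFrame(dict(arange=range(len(pkey))), index=pkey)
    pkey_map.loc[105, 'arange']  # -> 2, the positional row of that entity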
@@ -200,52 +172,41 @@ class LocalGraphStore:
 
             pkey_map_dict[table.name] = pkey_map
 
-        return pkey_name_dict, pkey_map_dict
+        return pkey_map_dict
 
     def get_time_data(
         self,
-        graph: Graph,
+        graph: 'Graph',
     ) -> Tuple[
-        Dict[str, str],
-        Dict[str, str],
         Dict[str, np.ndarray],
-        pd.Timestamp,
-        pd.Timestamp,
+        Dict[str, Tuple[pd.Timestamp, pd.Timestamp]],
     ]:
-        time_column_dict: Dict[str, str] = {}
-        end_time_column_dict: Dict[str, str] = {}
         time_dict: Dict[str, np.ndarray] = {}
-        min_time = pd.Timestamp.max
-        max_time = pd.Timestamp.min
+        min_max_time_dict: Dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
         for table in graph.tables.values():
-            if table._end_time_column is not None:
-                end_time_column_dict[table.name] = table._end_time_column
-
             if table._time_column is None:
                 continue
 
             time = self.df_dict[table.name][table._time_column]
-            time_dict[table.name] = time.astype('datetime64[ns]').astype(
-                int).to_numpy() // 1000**3
-            time_column_dict[table.name] = table._time_column
+            if time.dtype != 'datetime64[ns]':
+                time = time.astype('datetime64[ns]')
+            time_dict[table.name] = time.astype(int).to_numpy() // 1000**3
 
             if table.name in self.mask_dict.keys():
                 time = time[self.mask_dict[table.name]]
             if len(time) > 0:
-                min_time = min(min_time, time.min())
-                max_time = max(max_time, time.max())
+                min_max_time_dict[table.name] = (time.min(), time.max())
+            else:
+                min_max_time_dict[table.name] = (
+                    pd.Timestamp.max,
+                    pd.Timestamp.min,
+                )
 
-        return (
-            time_column_dict,
-            end_time_column_dict,
-            time_dict,
-            min_time,
-            max_time,
-        )
+        return time_dict, min_max_time_dict
 
     def get_csc(
         self,
-        graph: Graph,
+        graph: 'Graph',
     ) -> Tuple[
         Dict[Tuple[str, str, str], np.ndarray],
         Dict[Tuple[str, str, str], np.ndarray],
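The refactor replaces the two global scalars with per-table `(min, max)` tuples; a table whose unmasked time column is empty gets the `(Timestamp.max, Timestamp.min)` sentinel, which can never win the aggregation performed in the `__init__` hunk above. A minimal check with hypothetical table names:

    import pandas as pd

    min_max_time_dict = {
        'users': (pd.Timestamp('2020-01-01'), pd.Timestamp('2021-06-01')),
        'orders': (pd.Timestamp.max, pd.Timestamp.min),  # no unmasked rows
    }
    min_time = min(t for t, _ in min_max_time_dict.values())  # 2020-01-01
    max_time = max(t for _, t in min_max_time_dict.values())  # 2021-06-01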
@@ -0,0 +1,313 @@
+from typing import TYPE_CHECKING, Literal
+
+import numpy as np
+import pandas as pd
+from kumoapi.pquery import ValidatedPredictiveQuery
+
+from kumoai.experimental.rfm.backend.local import LocalGraphStore
+from kumoai.experimental.rfm.base import Sampler, SamplerOutput
+from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
+from kumoai.utils import ProgressLogger
+
+if TYPE_CHECKING:
+    from kumoai.experimental.rfm import Graph
+
+
+class LocalSampler(Sampler):
+    def __init__(
+        self,
+        graph: 'Graph',
+        verbose: bool | ProgressLogger = True,
+    ) -> None:
+        super().__init__(graph=graph)
+
+        import kumoai.kumolib as kumolib
+
+        self._graph_store = LocalGraphStore(graph, verbose)
+        self._graph_sampler = kumolib.NeighborSampler(
+            list(self.table_stype_dict.keys()),
+            self.edge_types,
+            {
+                '__'.join(edge_type): colptr
+                for edge_type, colptr in self._graph_store.colptr_dict.items()
+            },
+            {
+                '__'.join(edge_type): row
+                for edge_type, row in self._graph_store.row_dict.items()
+            },
+            self._graph_store.time_dict,
+        )
+
+    def _get_min_max_time_dict(
+        self,
+        table_names: list[str],
+    ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
+        return {
+            key: value
+            for key, value in self._graph_store.min_max_time_dict.items()
+            if key in table_names
+        }
+
+    def _sample_subgraph(
+        self,
+        entity_table_name: str,
+        entity_pkey: pd.Series,
+        anchor_time: pd.Series | Literal['entity'],
+        columns_dict: dict[str, set[str]],
+        num_neighbors: list[int],
+    ) -> SamplerOutput:
+
+        index = self._graph_store.get_node_id(entity_table_name, entity_pkey)
+
+        if isinstance(anchor_time, pd.Series):
+            time = anchor_time.astype(int).to_numpy() // 1000**3  # to seconds
+        else:
+            assert anchor_time == 'entity'
+            time = self._graph_store.time_dict[entity_table_name][index]
+
+        (
+            row_dict,
+            col_dict,
+            node_dict,
+            batch_dict,
+            num_sampled_nodes_dict,
+            num_sampled_edges_dict,
+        ) = self._graph_sampler.sample(
+            {
+                '__'.join(edge_type): num_neighbors
+                for edge_type in self.edge_types
+            },
+            {},
+            entity_table_name,
+            index,
+            time,
+        )
+
+        df_dict: dict[str, pd.DataFrame] = {}
+        inverse_dict: dict[str, np.ndarray] = {}
+        for table_name, node in node_dict.items():
+            df = self._graph_store.df_dict[table_name]
+            columns = columns_dict[table_name]
+            if self.end_time_column_dict.get(table_name, None) in columns:
+                df = df.iloc[node]
+            elif len(columns) == 0:
+                df = df.iloc[node]
+            else:
+                # Only store unique rows in `df` above a certain threshold:
+                unique_node, inverse = np.unique(node, return_inverse=True)
+                if len(node) > 1.05 * len(unique_node):
+                    df = df.iloc[unique_node]
+                    inverse_dict[table_name] = inverse
+                else:
+                    df = df.iloc[node]
+            df = df.reset_index(drop=True)
+            df = df[list(columns)]
+            df_dict[table_name] = df
+
+        num_sampled_nodes_dict = {
+            table_name: num_sampled_nodes.tolist()
+            for table_name, num_sampled_nodes in
+            num_sampled_nodes_dict.items()
+        }
+
+        row_dict = {
+            edge_type: row_dict['__'.join(edge_type)]
+            for edge_type in self.edge_types
+        }
+        col_dict = {
+            edge_type: col_dict['__'.join(edge_type)]
+            for edge_type in self.edge_types
+        }
+        num_sampled_edges_dict = {
+            edge_type: num_sampled_edges_dict['__'.join(edge_type)].tolist()
+            for edge_type in self.edge_types
+        }
+
+        return SamplerOutput(
+            anchor_time=time * 1000**3,  # to nanoseconds
+            df_dict=df_dict,
+            inverse_dict=inverse_dict,
+            batch_dict=batch_dict,
+            num_sampled_nodes_dict=num_sampled_nodes_dict,
+            row_dict=row_dict,
+            col_dict=col_dict,
+            num_sampled_edges_dict=num_sampled_edges_dict,
+        )
+
+    def _sample_entity_table(
+        self,
+        table_name: str,
+        columns: set[str],
+        num_rows: int,
+        random_seed: int | None = None,
+    ) -> pd.DataFrame:
+        pkey_map = self._graph_store.pkey_map_dict[table_name]
+        if len(pkey_map) > num_rows:
+            pkey_map = pkey_map.sample(
+                n=num_rows,
+                random_state=random_seed,
+                ignore_index=True,
+            )
+        df = self._graph_store.df_dict[table_name]
+        df = df.iloc[pkey_map['arange']][list(columns)]
+        return df
+
+    def _sample_target(
+        self,
+        query: ValidatedPredictiveQuery,
+        entity_df: pd.DataFrame,
+        train_index: np.ndarray,
+        train_time: pd.Series,
+        num_train_examples: int,
+        test_index: np.ndarray,
+        test_time: pd.Series,
+        num_test_examples: int,
+        columns_dict: dict[str, set[str]],
+        time_offset_dict: dict[
+            tuple[str, str, str],
+            tuple[pd.DateOffset | None, pd.DateOffset],
+        ],
+    ) -> tuple[pd.Series, np.ndarray, pd.Series, np.ndarray]:
+
+        train_y, train_mask = self._sample_target_set(
+            query=query,
+            pkey=entity_df[self.primary_key_dict[query.entity_table]],
+            index=train_index,
+            anchor_time=train_time,
+            num_examples=num_train_examples,
+            columns_dict=columns_dict,
+            time_offset_dict=time_offset_dict,
+        )
+
+        test_y, test_mask = self._sample_target_set(
+            query=query,
+            pkey=entity_df[self.primary_key_dict[query.entity_table]],
+            index=test_index,
+            anchor_time=test_time,
+            num_examples=num_test_examples,
+            columns_dict=columns_dict,
+            time_offset_dict=time_offset_dict,
+        )
+
+        return train_y, train_mask, test_y, test_mask
+
+    def _sample_target_set(
+        self,
+        query: ValidatedPredictiveQuery,
+        pkey: pd.Series,
+        index: np.ndarray,
+        anchor_time: pd.Series,
+        num_examples: int,
+        columns_dict: dict[str, set[str]],
+        time_offset_dict: dict[
+            tuple[str, str, str],
+            tuple[pd.DateOffset | None, pd.DateOffset],
+        ],
+        batch_size: int = 10_000,
+    ) -> tuple[pd.Series, np.ndarray]:
+
+        num_hops = 1 if len(time_offset_dict) > 0 else 0
+        num_neighbors_dict: dict[str, list[int]] = {}
+        unix_time_offset_dict: dict[str, list[list[int | None]]] = {}
+        for edge_type, (start, end) in time_offset_dict.items():
+            unix_time_offset_dict['__'.join(edge_type)] = [[
+                date_offset_to_seconds(start) if start is not None else None,
+                date_offset_to_seconds(end),
+            ]]
+        for edge_type in set(self.edge_types) - set(time_offset_dict.keys()):
+            num_neighbors_dict['__'.join(edge_type)] = [0] * num_hops
+
+        if anchor_time.dtype != 'datetime64[ns]':
+            anchor_time = anchor_time.astype('datetime64')
+
+        count = 0
+        ys: list[pd.Series] = []
+        mask = np.full(len(index), False, dtype=bool)
+        for start in range(0, len(index), batch_size):
+            subset = pkey.iloc[index[start:start + batch_size]]
+            time = anchor_time.iloc[start:start + batch_size]
+
+            _, _, node_dict, batch_dict, _, _ = self._graph_sampler.sample(
+                num_neighbors_dict,
+                unix_time_offset_dict,
+                query.entity_table,
+                self._graph_store.get_node_id(query.entity_table, subset),
+                time.astype(int).to_numpy() // 1000**3,  # to seconds
+            )
+
+            feat_dict: dict[str, pd.DataFrame] = {}
+            time_dict: dict[str, pd.Series] = {}
+            for table_name, columns in columns_dict.items():
+                df = self._graph_store.df_dict[table_name]
+                df = df.iloc[node_dict[table_name]].reset_index(drop=True)
+                df = df[list(columns)]
+                feat_dict[table_name] = df
+
+                time_column = self.time_column_dict.get(table_name)
+                if time_column in columns:
+                    time_dict[table_name] = df[time_column]
+
+            y, _mask = PQueryPandasExecutor().execute(
+                query=query,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=time,
+                num_forecasts=query.num_forecasts,
+            )
+            ys.append(y)
+            mask[start:start + batch_size] = _mask
+
+            count += len(y)
+            if count >= num_examples:
+                break
+
+        if len(ys) == 0:
+            y = pd.Series([], dtype=float)
+        elif len(ys) == 1:
+            y = ys[0]
+        else:
+            y = pd.concat(ys, axis=0, ignore_index=True)
+
+        return y, mask
+
+
+# Helper Methods ##############################################################
+
+
+def date_offset_to_seconds(offset: pd.DateOffset) -> int:
+    r"""Convert a :class:`pandas.DateOffset` into a number of seconds.
+
+    .. note::
+        We are conservative and take months and years as their maximum value.
+        Additional values are then dropped in label computation where we know
+        the actual dates.
+    """
+    MAX_DAYS_IN_MONTH = 31
+    MAX_DAYS_IN_YEAR = 366
+
+    SECONDS_IN_MINUTE = 60
+    SECONDS_IN_HOUR = 60 * SECONDS_IN_MINUTE
+    SECONDS_IN_DAY = 24 * SECONDS_IN_HOUR
+
+    total_sec = 0
+    multiplier = getattr(offset, 'n', 1)  # The multiplier (if present).
+
+    for attr, value in offset.__dict__.items():
+        if value is None or value == 0:
+            continue
+        scaled_value = value * multiplier
+        if attr == 'years':
+            total_sec += scaled_value * MAX_DAYS_IN_YEAR * SECONDS_IN_DAY
+        elif attr == 'months':
+            total_sec += scaled_value * MAX_DAYS_IN_MONTH * SECONDS_IN_DAY
+        elif attr == 'days':
+            total_sec += scaled_value * SECONDS_IN_DAY
+        elif attr == 'hours':
+            total_sec += scaled_value * SECONDS_IN_HOUR
+        elif attr == 'minutes':
+            total_sec += scaled_value * SECONDS_IN_MINUTE
+        elif attr == 'seconds':
+            total_sec += scaled_value
+
+    return total_sec
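Two details of the new sampler are easy to sanity-check in isolation: the `np.unique(..., return_inverse=True)` deduplication stores each repeated row once and rebuilds the batch through the inverse index, and `date_offset_to_seconds` is a conservative upper bound (months count as 31 days, years as 366). A self-contained sketch of both, independent of kumoai:

    import numpy as np

    # Deduplication: unique rows plus an inverse index recover the batch.
    node = np.array([3, 1, 3, 3, 1])
    unique_node, inverse = np.unique(node, return_inverse=True)
    assert (unique_node[inverse] == node).all()

    # Offset bound: DateOffset(months=2) maps to 2 * 31 days of seconds.
    assert 2 * 31 * 24 * 60 * 60 == 5_356_800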
@@ -4,7 +4,7 @@ from typing import List, Optional, Sequence
 import pandas as pd
 from kumoapi.typing import Dtype
 
-from kumoai.experimental.rfm.backend.sqlite import Connection
+from kumoai.experimental.rfm.backend.snow import Connection
 from kumoai.experimental.rfm.base import SourceColumn, SourceForeignKey, Table
 
 
@@ -1,10 +1,13 @@
 from .source import SourceColumn, SourceForeignKey
 from .column import Column
 from .table import Table
+from .sampler import SamplerOutput, Sampler
 
 __all__ = [
     'SourceColumn',
     'SourceForeignKey',
     'Column',
     'Table',
+    'SamplerOutput',
+    'Sampler',
 ]