kumoai 2.13.0.dev202512040649__cp313-cp313-win_amd64.whl → 2.14.0.dev202512211732__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. kumoai/__init__.py +12 -0
  2. kumoai/_version.py +1 -1
  3. kumoai/client/pquery.py +6 -2
  4. kumoai/experimental/rfm/__init__.py +33 -8
  5. kumoai/experimental/rfm/authenticate.py +3 -4
  6. kumoai/experimental/rfm/backend/local/__init__.py +4 -0
  7. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +52 -91
  8. kumoai/experimental/rfm/backend/local/sampler.py +315 -0
  9. kumoai/experimental/rfm/backend/local/table.py +21 -16
  10. kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
  11. kumoai/experimental/rfm/backend/snow/sampler.py +252 -0
  12. kumoai/experimental/rfm/backend/snow/table.py +102 -48
  13. kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
  14. kumoai/experimental/rfm/backend/sqlite/sampler.py +349 -0
  15. kumoai/experimental/rfm/backend/sqlite/table.py +84 -31
  16. kumoai/experimental/rfm/base/__init__.py +26 -3
  17. kumoai/experimental/rfm/base/column.py +14 -12
  18. kumoai/experimental/rfm/base/column_expression.py +50 -0
  19. kumoai/experimental/rfm/base/sampler.py +773 -0
  20. kumoai/experimental/rfm/base/source.py +1 -0
  21. kumoai/experimental/rfm/base/sql_sampler.py +84 -0
  22. kumoai/experimental/rfm/base/sql_table.py +229 -0
  23. kumoai/experimental/rfm/base/table.py +173 -138
  24. kumoai/experimental/rfm/graph.py +302 -108
  25. kumoai/experimental/rfm/infer/__init__.py +6 -4
  26. kumoai/experimental/rfm/infer/dtype.py +3 -3
  27. kumoai/experimental/rfm/infer/pkey.py +4 -2
  28. kumoai/experimental/rfm/infer/stype.py +35 -0
  29. kumoai/experimental/rfm/infer/time_col.py +1 -2
  30. kumoai/experimental/rfm/pquery/executor.py +27 -27
  31. kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
  32. kumoai/experimental/rfm/rfm.py +299 -230
  33. kumoai/experimental/rfm/sagemaker.py +4 -4
  34. kumoai/kumolib.cp313-win_amd64.pyd +0 -0
  35. kumoai/pquery/predictive_query.py +10 -6
  36. kumoai/testing/snow.py +50 -0
  37. kumoai/utils/__init__.py +3 -2
  38. kumoai/utils/progress_logger.py +178 -12
  39. kumoai/utils/sql.py +3 -0
  40. {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/METADATA +3 -2
  41. {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/RECORD +44 -36
  42. kumoai/experimental/rfm/local_graph_sampler.py +0 -223
  43. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  44. {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/WHEEL +0 -0
  45. {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/licenses/LICENSE +0 -0
  46. {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/top_level.txt +0 -0
kumoai/__init__.py CHANGED
@@ -280,7 +280,19 @@ __all__ = [
280
280
  ]
281
281
 
282
282
 
283
+ def in_snowflake_notebook() -> bool:
284
+ try:
285
+ from snowflake.snowpark.context import get_active_session
286
+ import streamlit # noqa: F401
287
+ get_active_session()
288
+ return True
289
+ except Exception:
290
+ return False
291
+
292
+
283
293
  def in_notebook() -> bool:
294
+ if in_snowflake_notebook():
295
+ return True
284
296
  try:
285
297
  from IPython import get_ipython
286
298
  shell = get_ipython()
kumoai/_version.py CHANGED
@@ -1 +1 @@
1
- __version__ = '2.13.0.dev202512040649'
1
+ __version__ = '2.14.0.dev202512211732'
kumoai/client/pquery.py CHANGED
@@ -176,8 +176,12 @@ def filter_model_plan(
176
176
  # Undefined
177
177
  pass
178
178
 
179
- new_opt_fields.append((field.name, _type, default))
180
- new_opts.append(getattr(section, field.name))
179
+ # Forward compatibility - Remove any newly introduced arguments not
180
+ # returned yet by the backend:
181
+ value = getattr(section, field.name)
182
+ if value != MissingType.VALUE:
183
+ new_opt_fields.append((field.name, _type, default))
184
+ new_opts.append(value)
181
185
 
182
186
  Section = dataclass(
183
187
  config=dict(validate_assignment=True),
kumoai/experimental/rfm/__init__.py CHANGED
@@ -6,11 +6,11 @@ import socket
6
6
  import threading
7
7
  from dataclasses import dataclass
8
8
  from enum import Enum
9
- from typing import Dict, Optional, Tuple
10
9
  from urllib.parse import urlparse
11
10
 
12
11
  import kumoai
13
12
  from kumoai.client.client import KumoClient
13
+ from kumoai.spcs import _get_active_session
14
14
 
15
15
  from .authenticate import authenticate
16
16
  from .sagemaker import (
@@ -49,7 +49,8 @@ class InferenceBackend(str, Enum):
49
49
 
50
50
 
51
51
  def _detect_backend(
52
- url: str) -> Tuple[InferenceBackend, Optional[str], Optional[str]]:
52
+ url: str, #
53
+ ) -> tuple[InferenceBackend, str | None, str | None]:
53
54
  parsed = urlparse(url)
54
55
 
55
56
  # Remote SageMaker
@@ -73,12 +74,27 @@ def _detect_backend(
73
74
  return InferenceBackend.REST, None, None
74
75
 
75
76
 
77
+ def _get_snowflake_url(snowflake_application: str) -> str:
78
+ snowpark_session = _get_active_session()
79
+ if not snowpark_session:
80
+ raise ValueError(
81
+ "Client creation failed: snowflake_application is specified "
82
+ "without an active snowpark session. If running outside "
83
+ "a snowflake notebook, specify a URL and credentials.")
84
+ with snowpark_session.connection.cursor() as cur:
85
+ cur.execute(
86
+ f"DESCRIBE SERVICE {snowflake_application}.user_schema.rfm_service"
87
+ f" ->> SELECT \"dns_name\" from $1")
88
+ dns_name: str = cur.fetchone()[0]
89
+ return f"http://{dns_name}:8000/api"
90
+
91
+
76
92
  @dataclass
77
93
  class RfmGlobalState:
78
94
  _url: str = '__url_not_provided__'
79
95
  _backend: InferenceBackend = InferenceBackend.UNKNOWN
80
- _region: Optional[str] = None
81
- _endpoint_name: Optional[str] = None
96
+ _region: str | None = None
97
+ _endpoint_name: str | None = None
82
98
  _thread_local = threading.local()
83
99
 
84
100
  # Thread-safe init-once.
@@ -121,10 +137,10 @@ global_state = RfmGlobalState()
121
137
 
122
138
 
123
139
  def init(
124
- url: Optional[str] = None,
125
- api_key: Optional[str] = None,
126
- snowflake_credentials: Optional[Dict[str, str]] = None,
127
- snowflake_application: Optional[str] = None,
140
+ url: str | None = None,
141
+ api_key: str | None = None,
142
+ snowflake_credentials: dict[str, str] | None = None,
143
+ snowflake_application: str | None = None,
128
144
  log_level: str = "INFO",
129
145
  ) -> None:
130
146
  with global_state._lock:
@@ -136,6 +152,15 @@ def init(
136
152
  "supported.")
137
153
  return
138
154
 
155
+ if snowflake_application:
156
+ if url is not None:
157
+ raise ValueError(
158
+ "Client creation failed: both snowflake_application and "
159
+ "url are specified. If running from a snowflake notebook, "
160
+ "specify only snowflake_application.")
161
+ url = _get_snowflake_url(snowflake_application)
162
+ api_key = "test:DISABLED"
163
+
139
164
  if url is None:
140
165
  url = os.getenv("RFM_API_URL", "https://kumorfm.ai/api")
141
166
 
kumoai/experimental/rfm/authenticate.py CHANGED
@@ -2,12 +2,11 @@ import logging
2
2
  import os
3
3
  import platform
4
4
  from datetime import datetime
5
- from typing import Optional
6
5
 
7
6
  from kumoai import in_notebook
8
7
 
9
8
 
10
- def authenticate(api_url: Optional[str] = None) -> None:
9
+ def authenticate(api_url: str | None = None) -> None:
11
10
  """Authenticates the user and sets the Kumo API key for the SDK.
12
11
 
13
12
  This function detects the current environment and launches the appropriate
@@ -65,11 +64,11 @@ def _authenticate_local(api_url: str, redirect_port: int = 8765) -> None:
65
64
  import webbrowser
66
65
  from getpass import getpass
67
66
  from socketserver import TCPServer
68
- from typing import Any, Dict
67
+ from typing import Any
69
68
 
70
69
  logger = logging.getLogger('kumoai')
71
70
 
72
- token_status: Dict[str, Any] = {
71
+ token_status: dict[str, Any] = {
73
72
  'token': None,
74
73
  'token_name': None,
75
74
  'failed': False
kumoai/experimental/rfm/backend/local/__init__.py CHANGED
@@ -32,7 +32,11 @@ Please create a feature request at 'https://github.com/kumo-ai/kumo-rfm'."""
32
32
  raise RuntimeError(_msg) from e
33
33
 
34
34
  from .table import LocalTable
35
+ from .graph_store import LocalGraphStore
36
+ from .sampler import LocalSampler
35
37
 
36
38
  __all__ = [
37
39
  'LocalTable',
40
+ 'LocalGraphStore',
41
+ 'LocalSampler',
38
42
  ]
kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} CHANGED
@@ -1,13 +1,13 @@
1
1
  import warnings
2
- from typing import Dict, List, Optional, Tuple, Union
2
+ from typing import TYPE_CHECKING
3
3
 
4
4
  import numpy as np
5
5
  import pandas as pd
6
6
  from kumoapi.rfm.context import Subgraph
7
7
  from kumoapi.typing import Stype
8
8
 
9
- from kumoai.experimental.rfm import Graph, LocalTable
10
- from kumoai.utils import InteractiveProgressLogger, ProgressLogger
9
+ from kumoai.experimental.rfm.backend.local import LocalTable
10
+ from kumoai.utils import ProgressLogger
11
11
 
12
12
  try:
13
13
  import torch
@@ -15,42 +15,40 @@ try:
15
15
  except ImportError:
16
16
  WITH_TORCH = False
17
17
 
18
+ if TYPE_CHECKING:
19
+ from kumoai.experimental.rfm import Graph
20
+
18
21
 
19
22
  class LocalGraphStore:
20
23
  def __init__(
21
24
  self,
22
- graph: Graph,
23
- verbose: Union[bool, ProgressLogger] = True,
25
+ graph: 'Graph',
26
+ verbose: bool | ProgressLogger = True,
24
27
  ) -> None:
25
28
 
26
29
  if not isinstance(verbose, ProgressLogger):
27
- verbose = InteractiveProgressLogger(
28
- "Materializing graph",
30
+ verbose = ProgressLogger.default(
31
+ msg="Materializing graph",
29
32
  verbose=verbose,
30
33
  )
31
34
 
32
35
  with verbose as logger:
33
36
  self.df_dict, self.mask_dict = self.sanitize(graph)
34
- self.stype_dict = self.get_stype_dict(graph)
35
37
  logger.log("Sanitized input data")
36
38
 
37
- self.pkey_name_dict, self.pkey_map_dict = self.get_pkey_data(graph)
39
+ self.pkey_map_dict = self.get_pkey_map_dict(graph)
38
40
  num_pkeys = sum(t.has_primary_key() for t in graph.tables.values())
39
41
  if num_pkeys > 1:
40
42
  logger.log(f"Collected primary keys from {num_pkeys} tables")
41
43
  else:
42
44
  logger.log(f"Collected primary key from {num_pkeys} table")
43
45
 
44
- (
45
- self.time_column_dict,
46
- self.end_time_column_dict,
47
- self.time_dict,
48
- self.min_time,
49
- self.max_time,
50
- ) = self.get_time_data(graph)
51
- if self.max_time != pd.Timestamp.min:
46
+ self.time_dict, self.min_max_time_dict = self.get_time_data(graph)
47
+ if len(self.min_max_time_dict) > 0:
48
+ min_time = min(t for t, _ in self.min_max_time_dict.values())
49
+ max_time = max(t for _, t in self.min_max_time_dict.values())
52
50
  logger.log(f"Identified temporal graph from "
53
- f"{self.min_time.date()} to {self.max_time.date()}")
51
+ f"{min_time.date()} to {max_time.date()}")
54
52
  else:
55
53
  logger.log("Identified static graph without timestamps")
56
54
 
@@ -60,14 +58,6 @@ class LocalGraphStore:
60
58
  logger.log(f"Created graph with {num_nodes:,} nodes and "
61
59
  f"{num_edges:,} edges")
62
60
 
63
- @property
64
- def node_types(self) -> List[str]:
65
- return list(self.df_dict.keys())
66
-
67
- @property
68
- def edge_types(self) -> List[Tuple[str, str, str]]:
69
- return list(self.row_dict.keys())
70
-
71
61
  def get_node_id(self, table_name: str, pkey: pd.Series) -> np.ndarray:
72
62
  r"""Returns the node ID given primary keys.
73
63
 
@@ -103,8 +93,8 @@ class LocalGraphStore:
103
93
 
104
94
  def sanitize(
105
95
  self,
106
- graph: Graph,
107
- ) -> Tuple[Dict[str, pd.DataFrame], Dict[str, np.ndarray]]:
96
+ graph: 'Graph',
97
+ ) -> tuple[dict[str, pd.DataFrame], dict[str, np.ndarray]]:
108
98
  r"""Sanitizes raw data according to table schema definition:
109
99
 
110
100
  In particular, it:
@@ -113,13 +103,13 @@ class LocalGraphStore:
113
103
  * drops duplicate primary keys
114
104
  * removes rows with missing primary keys or time values
115
105
  """
116
- df_dict: Dict[str, pd.DataFrame] = {}
106
+ df_dict: dict[str, pd.DataFrame] = {}
117
107
  for table_name, table in graph.tables.items():
118
108
  assert isinstance(table, LocalTable)
119
109
  df = table._data
120
110
  df_dict[table_name] = df.copy(deep=False).reset_index(drop=True)
121
111
 
122
- mask_dict: Dict[str, np.ndarray] = {}
112
+ mask_dict: dict[str, np.ndarray] = {}
123
113
  for table in graph.tables.values():
124
114
  for col in table.columns:
125
115
  if col.stype == Stype.timestamp:
@@ -136,7 +126,7 @@ class LocalGraphStore:
136
126
  ser = ser.dt.tz_localize(None)
137
127
  df_dict[table.name][col.name] = ser
138
128
 
139
- mask: Optional[np.ndarray] = None
129
+ mask: np.ndarray | None = None
140
130
  if table._time_column is not None:
141
131
  ser = df_dict[table.name][table._time_column]
142
132
  mask = ser.notna().to_numpy()
@@ -151,34 +141,16 @@ class LocalGraphStore:
151
141
 
152
142
  return df_dict, mask_dict
153
143
 
154
- def get_stype_dict(self, graph: Graph) -> Dict[str, Dict[str, Stype]]:
155
- stype_dict: Dict[str, Dict[str, Stype]] = {}
156
- foreign_keys = {(edge.src_table, edge.fkey) for edge in graph.edges}
157
- for table in graph.tables.values():
158
- stype_dict[table.name] = {}
159
- for column in table.columns:
160
- if column == table.primary_key:
161
- continue
162
- if (table.name, column.name) in foreign_keys:
163
- continue
164
- stype_dict[table.name][column.name] = column.stype
165
- return stype_dict
166
-
167
- def get_pkey_data(
144
+ def get_pkey_map_dict(
168
145
  self,
169
- graph: Graph,
170
- ) -> Tuple[
171
- Dict[str, str],
172
- Dict[str, pd.DataFrame],
173
- ]:
174
- pkey_name_dict: Dict[str, str] = {}
175
- pkey_map_dict: Dict[str, pd.DataFrame] = {}
146
+ graph: 'Graph',
147
+ ) -> dict[str, pd.DataFrame]:
148
+ pkey_map_dict: dict[str, pd.DataFrame] = {}
176
149
 
177
150
  for table in graph.tables.values():
178
151
  if table._primary_key is None:
179
152
  continue
180
153
 
181
- pkey_name_dict[table.name] = table._primary_key
182
154
  pkey = self.df_dict[table.name][table._primary_key]
183
155
  pkey_map = pd.DataFrame(
184
156
  dict(arange=range(len(pkey))),
@@ -200,61 +172,50 @@ class LocalGraphStore:
200
172
 
201
173
  pkey_map_dict[table.name] = pkey_map
202
174
 
203
- return pkey_name_dict, pkey_map_dict
175
+ return pkey_map_dict
204
176
 
205
177
  def get_time_data(
206
178
  self,
207
- graph: Graph,
208
- ) -> Tuple[
209
- Dict[str, str],
210
- Dict[str, str],
211
- Dict[str, np.ndarray],
212
- pd.Timestamp,
213
- pd.Timestamp,
179
+ graph: 'Graph',
180
+ ) -> tuple[
181
+ dict[str, np.ndarray],
182
+ dict[str, tuple[pd.Timestamp, pd.Timestamp]],
214
183
  ]:
215
- time_column_dict: Dict[str, str] = {}
216
- end_time_column_dict: Dict[str, str] = {}
217
- time_dict: Dict[str, np.ndarray] = {}
218
- min_time = pd.Timestamp.max
219
- max_time = pd.Timestamp.min
184
+ time_dict: dict[str, np.ndarray] = {}
185
+ min_max_time_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
220
186
  for table in graph.tables.values():
221
- if table._end_time_column is not None:
222
- end_time_column_dict[table.name] = table._end_time_column
223
-
224
187
  if table._time_column is None:
225
188
  continue
226
189
 
227
190
  time = self.df_dict[table.name][table._time_column]
228
- time_dict[table.name] = time.astype('datetime64[ns]').astype(
229
- int).to_numpy() // 1000**3
230
- time_column_dict[table.name] = table._time_column
191
+ if time.dtype != 'datetime64[ns]':
192
+ time = time.astype('datetime64[ns]')
193
+ time_dict[table.name] = time.astype(int).to_numpy() // 1000**3
231
194
 
232
195
  if table.name in self.mask_dict.keys():
233
196
  time = time[self.mask_dict[table.name]]
234
197
  if len(time) > 0:
235
- min_time = min(min_time, time.min())
236
- max_time = max(max_time, time.max())
198
+ min_max_time_dict[table.name] = (time.min(), time.max())
199
+ else:
200
+ min_max_time_dict[table.name] = (
201
+ pd.Timestamp.max,
202
+ pd.Timestamp.min,
203
+ )
237
204
 
238
- return (
239
- time_column_dict,
240
- end_time_column_dict,
241
- time_dict,
242
- min_time,
243
- max_time,
244
- )
205
+ return time_dict, min_max_time_dict
245
206
 
246
207
  def get_csc(
247
208
  self,
248
- graph: Graph,
249
- ) -> Tuple[
250
- Dict[Tuple[str, str, str], np.ndarray],
251
- Dict[Tuple[str, str, str], np.ndarray],
209
+ graph: 'Graph',
210
+ ) -> tuple[
211
+ dict[tuple[str, str, str], np.ndarray],
212
+ dict[tuple[str, str, str], np.ndarray],
252
213
  ]:
253
214
  # A mapping from raw primary keys to node indices (0 to N-1):
254
- map_dict: Dict[str, pd.CategoricalDtype] = {}
215
+ map_dict: dict[str, pd.CategoricalDtype] = {}
255
216
  # A dictionary to manage offsets of node indices for invalid rows:
256
- offset_dict: Dict[str, np.ndarray] = {}
257
- for table_name in set(edge.dst_table for edge in graph.edges):
217
+ offset_dict: dict[str, np.ndarray] = {}
218
+ for table_name in {edge.dst_table for edge in graph.edges}:
258
219
  ser = self.df_dict[table_name][graph[table_name]._primary_key]
259
220
  if table_name in self.mask_dict.keys():
260
221
  mask = self.mask_dict[table_name]
@@ -263,8 +224,8 @@ class LocalGraphStore:
263
224
  map_dict[table_name] = pd.CategoricalDtype(ser, ordered=True)
264
225
 
265
226
  # Build CSC graph representation:
266
- row_dict: Dict[Tuple[str, str, str], np.ndarray] = {}
267
- colptr_dict: Dict[Tuple[str, str, str], np.ndarray] = {}
227
+ row_dict: dict[tuple[str, str, str], np.ndarray] = {}
228
+ colptr_dict: dict[tuple[str, str, str], np.ndarray] = {}
268
229
  for src_table, fkey, dst_table in graph.edges:
269
230
  src_df = self.df_dict[src_table]
270
231
  dst_df = self.df_dict[dst_table]
@@ -326,7 +287,7 @@ def _argsort(input: np.ndarray) -> np.ndarray:
326
287
  return torch.from_numpy(input).argsort().numpy()
327
288
 
328
289
 
329
- def _lexsort(inputs: List[np.ndarray]) -> np.ndarray:
290
+ def _lexsort(inputs: list[np.ndarray]) -> np.ndarray:
330
291
  assert len(inputs) >= 1
331
292
 
332
293
  if not WITH_TORCH: