kumoai 2.14.0.dev202512271732__cp310-cp310-macosx_11_0_arm64.whl → 2.14.0rc2__cp310-cp310-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
kumoai/__init__.py CHANGED
@@ -1,3 +1,4 @@
+ import warnings
  import os
  import sys
  import threading
@@ -68,9 +69,8 @@ class GlobalState(metaclass=Singleton):
  if self._url is None or (self._api_key is None
  and self._spcs_token is None
  and self._snowpark_session is None):
- raise ValueError(
- "Client creation or authentication failed; please re-create "
- "your client before proceeding.")
+ raise ValueError("Client creation or authentication failed. "
+ "Please re-create your client before proceeding.")

  if hasattr(self.thread_local, '_client'):
  # Set the spcs token in the client to ensure it has the latest.
@@ -123,10 +123,9 @@ def init(
  """ # noqa
  # Avoid mutations to the global state after it is set:
  if global_state.initialized:
- print(
- "Client has already been created. To re-initialize Kumo, please "
- "start a new interpreter. No changes will be made to the current "
- "session.")
+ warnings.warn("Kumo SDK already initialized. To re-initialize the "
+ "SDK, please start a new interpreter. No changes will "
+ "be made to the current session.")
  return

  set_log_level(os.getenv(_ENV_KUMO_LOG, log_level))
@@ -138,15 +137,15 @@ def init(
  if snowflake_application:
  if url is not None:
  raise ValueError(
- "Client creation failed: both snowflake_application and url "
- "are specified. If running from a snowflake notebook, specify"
- "only snowflake_application.")
+ "Kumo SDK initialization failed. Both 'snowflake_application' "
+ "and 'url' are specified. If running from a Snowflake "
+ "notebook, specify only 'snowflake_application'.")
  snowpark_session = _get_active_session()
  if not snowpark_session:
  raise ValueError(
- "Client creation failed: snowflake_application is specified "
- "without an active snowpark session. If running outside "
- "a snowflake notebook, specify a URL and credentials.")
+ "Kumo SDK initialization failed. 'snowflake_application' is "
+ "specified without an active Snowpark session. If running "
+ "outside a Snowflake notebook, specify a URL and credentials.")
  description = snowpark_session.sql(
  f"DESCRIBE SERVICE {snowflake_application}."
  "USER_SCHEMA.KUMO_SERVICE").collect()[0]
@@ -155,14 +154,14 @@ def init(
  if api_key is None and not snowflake_application:
  if snowflake_credentials is None:
  raise ValueError(
- "Client creation failed: Neither API key nor snowflake "
- "credentials provided. Please either set the 'KUMO_API_KEY' "
- "or explicitly call `kumoai.init(...)`.")
+ "Kumo SDK initialization failed. Neither an API key nor "
+ "Snowflake credentials provided. Please either set the "
+ "'KUMO_API_KEY' or explicitly call `kumoai.init(...)`.")
  if (set(snowflake_credentials.keys())
  != {'user', 'password', 'account'}):
  raise ValueError(
- f"Provided credentials should be a dictionary with keys "
- f"'user', 'password', and 'account'. Only "
+ f"Provided Snowflake credentials should be a dictionary with "
+ f"keys 'user', 'password', and 'account'. Only "
  f"{set(snowflake_credentials.keys())} were provided.")

  # Get or infer URL:
@@ -173,10 +172,10 @@ def init(
  except KeyError:
  pass
  if url is None:
- raise ValueError(
- "Client creation failed: endpoint URL not provided. Please "
- "either set the 'KUMO_API_ENDPOINT' environment variable or "
- "explicitly call `kumoai.init(...)`.")
+ raise ValueError("Kumo SDK initialization failed since no endpoint "
+ "URL was provided. Please either set the "
+ "'KUMO_API_ENDPOINT' environment variable or "
+ "explicitly call `kumoai.init(...)`.")

  # Assign global state after verification that client can be created and
  # authenticated successfully:
@@ -198,10 +197,8 @@ def init(
  logger = logging.getLogger('kumoai')
  log_level = logging.getLevelName(logger.getEffectiveLevel())

- logger.info(
- f"Successfully initialized the Kumo SDK (version {__version__}) "
- f"against deployment {url}, with "
- f"log level {log_level}.")
+ logger.info(f"Initialized Kumo SDK v{__version__} against deployment "
+ f"'{url}'")


  def set_log_level(level: str) -> None:
kumoai/_version.py CHANGED
@@ -1 +1 @@
- __version__ = '2.14.0.dev202512271732'
+ __version__ = '2.14.0rc2'
kumoai/client/jobs.py CHANGED
@@ -344,12 +344,14 @@ class GenerateTrainTableJobAPI(CommonJobAPI[GenerateTrainTableRequest,
  id: str,
  source_table_type: SourceTableType,
  train_table_mod: TrainingTableSpec,
+ extensive_validation: bool,
  ) -> ValidationResponse:
  response = self._client._post(
  f'{self._base_endpoint}/{id}/validate_custom_train_table',
  json=to_json_dict({
  'custom_table': source_table_type,
  'train_table_mod': train_table_mod,
+ 'extensive_validation': extensive_validation,
  }),
  )
  return parse_response(ValidationResponse, response)
kumoai/connector/utils.py CHANGED
@@ -1,10 +1,10 @@
  import asyncio
  import csv
- import gc
  import io
  import math
  import os
  import re
+ import sys
  import tempfile
  import threading
  import time
@@ -920,7 +920,10 @@ def _read_remote_file_with_progress(
  if capture_first_line and not seen_nl:
  header_line = bytes(header_acc)

- mv = buf.getbuffer() # zero-copy view of BytesIO internal buffer
+ if sys.version_info >= (3, 13):
+ mv = memoryview(buf.getvalue())
+ else:
+ mv = buf.getbuffer() # zero-copy view of BytesIO internal buffer
  return buf, mv, header_line


@@ -999,7 +1002,10 @@ def _iter_mv_chunks(mv: memoryview,
  n = mv.nbytes
  while pos < n:
  nxt = min(n, pos + part_size)
- yield mv[pos:nxt] # zero-copy slice
+ if sys.version_info >= (3, 13):
+ yield mv[pos:nxt].tobytes()
+ else:
+ yield mv[pos:nxt] # zero-copy slice
  pos = nxt


@@ -1473,13 +1479,17 @@ def _remote_upload_file(name: str, fs: Filesystem, url: str, info: dict,
  if renamed_cols_msg:
  logger.info(renamed_cols_msg)

+ try:
+ if isinstance(data_mv, memoryview):
+ data_mv.release()
+ except Exception:
+ pass
+
  try:
  if buf:
  buf.close()
  except Exception:
  pass
- del buf, data_mv, header_line
- gc.collect()

  logger.info("Upload complete. Validated table %s.", name)

@@ -1719,13 +1729,17 @@ def _remote_upload_directory(
  else:
  break

+ try:
+ if isinstance(data_mv, memoryview):
+ data_mv.release()
+ except Exception:
+ pass
+
  try:
  if buf:
  buf.close()
  except Exception:
  pass
- del buf, data_mv, header_line
- gc.collect()

  _safe_bar_update(file_bar, 1)
  _merge_status_update(fpath)
@@ -20,6 +20,7 @@ from .sagemaker import (
  from .base import Table
  from .backend.local import LocalTable
  from .graph import Graph
+ from .task_table import TaskTable
  from .rfm import ExplainConfig, Explanation, KumoRFM

  logger = logging.getLogger('kumoai_rfm')
@@ -78,9 +79,9 @@ def _get_snowflake_url(snowflake_application: str) -> str:
  snowpark_session = _get_active_session()
  if not snowpark_session:
  raise ValueError(
- "Client creation failed: snowflake_application is specified "
- "without an active snowpark session. If running outside "
- "a snowflake notebook, specify a URL and credentials.")
+ "KumoRFM initialization failed. 'snowflake_application' is "
+ "specified without an active Snowpark session. If running outside "
+ "a Snowflake notebook, specify a URL and credentials.")
  with snowpark_session.connection.cursor() as cur:
  cur.execute(
  f"DESCRIBE SERVICE {snowflake_application}.user_schema.rfm_service"
@@ -103,6 +104,9 @@ class RfmGlobalState:

  @property
  def client(self) -> KumoClient:
+ if self._backend == InferenceBackend.UNKNOWN:
+ raise RuntimeError("KumoRFM is not yet initialized")
+
  if self._backend == InferenceBackend.REST:
  return kumoai.global_state.client

@@ -146,18 +150,19 @@ def init(
  with global_state._lock:
  if global_state._initialized:
  if url != global_state._url:
- raise ValueError(
- "Kumo RFM has already been initialized with a different "
- "URL. Re-initialization with a different URL is not "
+ raise RuntimeError(
+ "KumoRFM has already been initialized with a different "
+ "API URL. Re-initialization with a different URL is not "
  "supported.")
  return

  if snowflake_application:
  if url is not None:
  raise ValueError(
- "Client creation failed: both snowflake_application and "
- "url are specified. If running from a snowflake notebook, "
- "specify only snowflake_application.")
+ "KumoRFM initialization failed. Both "
+ "'snowflake_application' and 'url' are specified. If "
+ "running from a Snowflake notebook, specify only "
+ "'snowflake_application'.")
  url = _get_snowflake_url(snowflake_application)
  api_key = "test:DISABLED"

@@ -166,32 +171,28 @@ def init(

  backend, region, endpoint_name = _detect_backend(url)
  if backend == InferenceBackend.REST:
- # Initialize kumoai.global_state
- if (kumoai.global_state.initialized
- and kumoai.global_state._url != url):
- raise ValueError(
- "Kumo AI SDK has already been initialized with different "
- "API URL. Please restart Python interpreter and "
- "initialize via kumoai.rfm.init()")
- kumoai.init(url=url, api_key=api_key,
- snowflake_credentials=snowflake_credentials,
- snowflake_application=snowflake_application,
- log_level=log_level)
+ kumoai.init(
+ url=url,
+ api_key=api_key,
+ snowflake_credentials=snowflake_credentials,
+ snowflake_application=snowflake_application,
+ log_level=log_level,
+ )
  elif backend == InferenceBackend.AWS_SAGEMAKER:
  assert region
  assert endpoint_name
  KumoClient_SageMakerAdapter(region, endpoint_name).authenticate()
+ logger.info("KumoRFM initialized in AWS SageMaker")
  else:
  assert backend == InferenceBackend.LOCAL_SAGEMAKER
  KumoClient_SageMakerProxy_Local(url).authenticate()
+ logger.info(f"KumoRFM initialized in local SageMaker at '{url}'")

  global_state._url = url
  global_state._backend = backend
  global_state._region = region
  global_state._endpoint_name = endpoint_name
  global_state._initialized = True
- logger.info("Kumo RFM initialized with backend: %s, url: %s", backend,
- url)


  LocalGraph = Graph # NOTE Backward compatibility - do not use anymore.
@@ -202,6 +203,7 @@ __all__ = [
  'Table',
  'LocalTable',
  'Graph',
+ 'TaskTable',
  'KumoRFM',
  'ExplainConfig',
  'Explanation',
@@ -144,11 +144,11 @@ class SnowSampler(SQLSampler):
  query.entity_table: np.arange(len(entity_df)),
  }
  for edge_type, (min_offset, max_offset) in time_offset_dict.items():
- table_name, fkey, _ = edge_type
+ table_name, foreign_key, _ = edge_type
  feat_dict[table_name], batch_dict[table_name] = self._by_time(
  table_name=table_name,
- fkey=fkey,
- pkey=entity_df[self.primary_key_dict[query.entity_table]],
+ foreign_key=foreign_key,
+ index=entity_df[self.primary_key_dict[query.entity_table]],
  anchor_time=time,
  min_offset=min_offset,
  max_offset=max_offset,
@@ -179,7 +179,7 @@ class SnowSampler(SQLSampler):
  def _by_pkey(
  self,
  table_name: str,
- pkey: pd.Series,
+ index: pd.Series,
  columns: set[str],
  ) -> tuple[pd.DataFrame, np.ndarray]:
  key = self.primary_key_dict[table_name]
@@ -189,7 +189,7 @@ class SnowSampler(SQLSampler):
  for column in columns
  ]

- payload = json.dumps(list(pkey))
+ payload = json.dumps(list(index))

  sql = ("WITH TMP as (\n"
  " SELECT\n"
@@ -206,7 +206,7 @@ class SnowSampler(SQLSampler):
  f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
  f"{', '.join(projections)}\n"
  f"FROM TMP\n"
- f"JOIN {self.source_name_dict[table_name]} ENT\n"
+ f"JOIN {self.source_name_dict[table_name]}\n"
  f" ON {key_ref} = TMP.__KUMO_ID__")

  with paramstyle(self._connection), self._connection.cursor() as cursor:
@@ -228,13 +228,82 @@ class SnowSampler(SQLSampler):
  stype_dict=self.table_stype_dict[table_name],
  ), batch

+ def _by_fkey(
+ self,
+ table_name: str,
+ foreign_key: str,
+ index: pd.Series,
+ num_neighbors: int,
+ anchor_time: pd.Series | None,
+ columns: set[str],
+ ) -> tuple[pd.DataFrame, np.ndarray]:
+ time_column = self.time_column_dict.get(table_name)
+
+ if time_column is not None and anchor_time is not None:
+ anchor_time = anchor_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+ payload = json.dumps(list(zip(index, anchor_time)))
+ else:
+ payload = json.dumps(list(zip(index)))
+
+ key_ref = self.table_column_ref_dict[table_name][foreign_key]
+ projections = [
+ self.table_column_proj_dict[table_name][column]
+ for column in columns
+ ]
+
+ sql = ("WITH TMP as (\n"
+ " SELECT\n"
+ " f.index as __KUMO_BATCH__,\n")
+ if self.table_dtype_dict[table_name][foreign_key].is_int():
+ sql += " f.value[0]::NUMBER as __KUMO_ID__"
+ elif self.table_dtype_dict[table_name][foreign_key].is_float():
+ sql += " f.value[0]::FLOAT as __KUMO_ID__"
+ else:
+ sql += " f.value[0]::VARCHAR as __KUMO_ID__"
+ if time_column is not None and anchor_time is not None:
+ sql += (",\n"
+ " f.value[1]::TIMESTAMP_NTZ as __KUMO_TIME__")
+ sql += (f"\n"
+ f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
+ f")\n"
+ f"SELECT "
+ f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+ f"{', '.join(projections)}\n"
+ f"FROM TMP\n"
+ f"JOIN {self.source_name_dict[table_name]}\n"
+ f" ON {key_ref} = TMP.__KUMO_ID__\n")
+ if time_column is not None and anchor_time is not None:
+ time_ref = self.table_column_ref_dict[table_name][time_column]
+ sql += f" AND {time_ref} <= TMP.__KUMO_TIME__\n"
+ sql += ("QUALIFY ROW_NUMBER() OVER (\n"
+ " PARTITION BY TMP.__KUMO_BATCH__\n")
+ if time_column is not None:
+ sql += f" ORDER BY {time_ref} DESC\n"
+ else:
+ sql += f" ORDER BY {key_ref}\n"
+ sql += f") <= {num_neighbors}"
+
+ with paramstyle(self._connection), self._connection.cursor() as cursor:
+ cursor.execute(sql, (payload, ))
+ table = cursor.fetch_arrow_all()
+
+ batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+ batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+ table = table.remove_column(batch_index)
+
+ return Table._sanitize(
+ df=table.to_pandas(),
+ dtype_dict=self.table_dtype_dict[table_name],
+ stype_dict=self.table_stype_dict[table_name],
+ ), batch
+
  # Helper Methods ##########################################################

  def _by_time(
  self,
  table_name: str,
- fkey: str,
- pkey: pd.Series,
+ foreign_key: str,
+ index: pd.Series,
  anchor_time: pd.Series,
  min_offset: pd.DateOffset | None,
  max_offset: pd.DateOffset,
@@ -247,11 +316,11 @@ class SnowSampler(SQLSampler):
  if min_offset is not None:
  start_time = anchor_time + min_offset
  start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
- payload = json.dumps(list(zip(pkey, end_time, start_time)))
+ payload = json.dumps(list(zip(index, end_time, start_time)))
  else:
- payload = json.dumps(list(zip(pkey, end_time)))
+ payload = json.dumps(list(zip(index, end_time)))

- key_ref = self.table_column_ref_dict[table_name][fkey]
+ key_ref = self.table_column_ref_dict[table_name][foreign_key]
  time_ref = self.table_column_ref_dict[table_name][time_column]
  projections = [
  self.table_column_proj_dict[table_name][column]
@@ -260,9 +329,9 @@ class SnowSampler(SQLSampler):
  sql = ("WITH TMP as (\n"
  " SELECT\n"
  " f.index as __KUMO_BATCH__,\n")
- if self.table_dtype_dict[table_name][fkey].is_int():
+ if self.table_dtype_dict[table_name][foreign_key].is_int():
  sql += " f.value[0]::NUMBER as __KUMO_ID__,\n"
- elif self.table_dtype_dict[table_name][fkey].is_float():
+ elif self.table_dtype_dict[table_name][foreign_key].is_float():
  sql += " f.value[0]::FLOAT as __KUMO_ID__,\n"
  else:
  sql += " f.value[0]::VARCHAR as __KUMO_ID__,\n"
@@ -276,7 +345,7 @@ class SnowSampler(SQLSampler):
  f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
  f"{', '.join(projections)}\n"
  f"FROM TMP\n"
- f"JOIN {self.source_name_dict[table_name]} FACT\n"
+ f"JOIN {self.source_name_dict[table_name]}\n"
  f" ON {key_ref} = TMP.__KUMO_ID__\n"
  f" AND {time_ref} <= TMP.__KUMO_END_TIME__")
  if min_offset is not None:
@@ -226,7 +226,7 @@ class SQLiteSampler(SQLSampler):
  def _by_pkey(
  self,
  table_name: str,
- pkey: pd.Series,
+ index: pd.Series,
  columns: set[str],
  ) -> tuple[pd.DataFrame, np.ndarray]:
  source_table = self.source_table_dict[table_name]
@@ -237,7 +237,7 @@ class SQLiteSampler(SQLSampler):
  for column in columns
  ]

- tmp = pa.table([pa.array(pkey)], names=['__kumo_id__'])
+ tmp = pa.table([pa.array(index)], names=['__kumo_id__'])
  tmp_name = f'tmp_{table_name}_{key}_{id(tmp)}'

  sql = (f"SELECT "
@@ -245,7 +245,6 @@ class SQLiteSampler(SQLSampler):
  f"{', '.join(projections)}\n"
  f"FROM {quote_ident(tmp_name)} tmp\n"
  f"JOIN {self.source_name_dict[table_name]} ent\n")
-
  if key in source_table and source_table[key].is_unique_key:
  sql += (f" ON {key_ref} = tmp.__kumo_id__")
  else:
@@ -271,13 +270,70 @@ class SQLiteSampler(SQLSampler):
  stype_dict=self.table_stype_dict[table_name],
  ), batch

+ def _by_fkey(
+ self,
+ table_name: str,
+ foreign_key: str,
+ index: pd.Series,
+ num_neighbors: int,
+ anchor_time: pd.Series | None,
+ columns: set[str],
+ ) -> tuple[pd.DataFrame, np.ndarray]:
+ time_column = self.time_column_dict.get(table_name)
+
+ # NOTE SQLite does not have a native datetime format. Currently, we
+ # assume timestamps are given as `TEXT` in `ISO-8601 UTC`:
+ tmp = pa.table([pa.array(index)], names=['__kumo_id__'])
+ if time_column is not None and anchor_time is not None:
+ anchor_time = anchor_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+ tmp = tmp.append_column('__kumo_time__', pa.array(anchor_time))
+ tmp_name = f'tmp_{table_name}_{foreign_key}_{id(tmp)}'
+
+ key_ref = self.table_column_ref_dict[table_name][foreign_key]
+ projections = [
+ self.table_column_proj_dict[table_name][column]
+ for column in columns
+ ]
+ sql = (f"SELECT "
+ f"tmp.rowid - 1 as __kumo_batch__, "
+ f"{', '.join(projections)}\n"
+ f"FROM {quote_ident(tmp_name)} tmp\n"
+ f"JOIN {self.source_name_dict[table_name]} fact\n"
+ f"ON fact.rowid IN (\n"
+ f" SELECT rowid\n"
+ f" FROM {self.source_name_dict[table_name]}\n"
+ f" WHERE {key_ref} = tmp.__kumo_id__\n")
+ if time_column is not None and anchor_time is not None:
+ time_ref = self.table_column_ref_dict[table_name][time_column]
+ sql += f" AND {time_ref} <= tmp.__kumo_time__\n"
+ if time_column is not None:
+ time_ref = self.table_column_ref_dict[table_name][time_column]
+ sql += f" ORDER BY {time_ref} DESC\n"
+ sql += (f" LIMIT {num_neighbors}\n"
+ f")")
+
+ with self._connection.cursor() as cursor:
+ cursor.adbc_ingest(tmp_name, tmp, mode='replace')
+ cursor.execute(sql)
+ table = cursor.fetch_arrow_table()
+
+ batch = table['__kumo_batch__'].to_numpy()
+ batch_index = table.schema.get_field_index('__kumo_batch__')
+ table = table.remove_column(batch_index)
+
+ return Table._sanitize(
+ df=table.to_pandas(),
+ dtype_dict=self.table_dtype_dict[table_name],
+ stype_dict=self.table_stype_dict[table_name],
+ ), batch
+
  # Helper Methods ##########################################################

  def _by_time(
  self,
  table_name: str,
- fkey: str,
- pkey: pd.Series,
+ foreign_key: str,
+ index: pd.Series,
  anchor_time: pd.Series,
  min_offset: pd.DateOffset | None,
  max_offset: pd.DateOffset,
@@ -287,7 +343,7 @@ class SQLiteSampler(SQLSampler):

  # NOTE SQLite does not have a native datetime format. Currently, we
  # assume timestamps are given as `TEXT` in `ISO-8601 UTC`:
- tmp = pa.table([pa.array(pkey)], names=['__kumo_id__'])
+ tmp = pa.table([pa.array(index)], names=['__kumo_id__'])
  end_time = anchor_time + max_offset
  end_time = end_time.dt.strftime("%Y-%m-%d %H:%M:%S")
  tmp = tmp.append_column('__kumo_end__', pa.array(end_time))
@@ -295,9 +351,9 @@ class SQLiteSampler(SQLSampler):
  start_time = anchor_time + min_offset
  start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
  tmp = tmp.append_column('__kumo_start__', pa.array(start_time))
- tmp_name = f'tmp_{table_name}_{fkey}_{id(tmp)}'
+ tmp_name = f'tmp_{table_name}_{foreign_key}_{id(tmp)}'

- key_ref = self.table_column_ref_dict[table_name][fkey]
+ key_ref = self.table_column_ref_dict[table_name][foreign_key]
  time_ref = self.table_column_ref_dict[table_name][time_column]
  projections = [
  self.table_column_proj_dict[table_name][column]
@@ -307,7 +363,7 @@ class SQLiteSampler(SQLSampler):
  f"tmp.rowid - 1 as __kumo_batch__, "
  f"{', '.join(projections)}\n"
  f"FROM {quote_ident(tmp_name)} tmp\n"
- f"JOIN {self.source_name_dict[table_name]} fact\n"
+ f"JOIN {self.source_name_dict[table_name]}\n"
  f" ON {key_ref} = tmp.__kumo_id__\n"
  f" AND {time_ref} <= tmp.__kumo_end__")
  if min_offset is not None:
@@ -359,11 +415,11 @@ class SQLiteSampler(SQLSampler):
  query.entity_table: np.arange(len(df)),
  }
  for edge_type, (_min, _max) in time_offset_dict.items():
- table_name, fkey, _ = edge_type
+ table_name, foreign_key, _ = edge_type
  feat_dict[table_name], batch_dict[table_name] = self._by_time(
  table_name=table_name,
- fkey=fkey,
- pkey=df[self.primary_key_dict[query.entity_table]],
+ foreign_key=foreign_key,
+ index=df[self.primary_key_dict[query.entity_table]],
  anchor_time=time,
  min_offset=_min,
  max_offset=_max,
@@ -0,0 +1,67 @@
+ import numpy as np
+ import pandas as pd
+
+
+ class Mapper:
+ r"""A mapper to map ``(pkey, batch)`` pairs to contiguous node IDs.
+
+ Args:
+ num_examples: The maximum number of examples to add/retrieve.
+ """
+ def __init__(self, num_examples: int):
+ self._pkey_dtype: pd.CategoricalDtype | None = None
+ self._indices: list[np.ndarray] = []
+ self._index_dtype: pd.CategoricalDtype | None = None
+ self._num_examples = num_examples
+
+ def add(self, pkey: pd.Series, batch: np.ndarray) -> None:
+ r"""Adds a set of ``(pkey, batch)`` pairs to the mapper.
+
+ Args:
+ pkey: The primary keys.
+ batch: The batch vector.
+ """
+ if self._pkey_dtype is not None:
+ category = np.concatenate([
+ self._pkey_dtype.categories.values,
+ pkey,
+ ], axis=0)
+ category = pd.unique(category)
+ self._pkey_dtype = pd.CategoricalDtype(category)
+ elif pd.api.types.is_string_dtype(pkey):
+ category = pd.unique(pkey)
+ self._pkey_dtype = pd.CategoricalDtype(category)
+
+ if self._pkey_dtype is not None:
+ index = pd.Categorical(pkey, dtype=self._pkey_dtype).codes
+ else:
+ index = pkey.to_numpy()
+ index = self._num_examples * index + batch
+ self._indices.append(index)
+ self._index_dtype = None
+
+ def get(self, pkey: pd.Series, batch: np.ndarray) -> np.ndarray:
+ r"""Retrieves the node IDs for a set of ``(pkey, batch)`` pairs.
+
+ Returns ``-1`` for any pair not registered in the mapping.
+
+ Args:
+ pkey: The primary keys.
+ batch: The batch vector.
+ """
+ if len(self._indices) == 0:
+ return np.full(len(pkey), -1, dtype=np.int64)
+
+ if self._index_dtype is None: # Lazy build index:
+ category = pd.unique(np.concatenate(self._indices))
+ self._index_dtype = pd.CategoricalDtype(category)
+
+ if self._pkey_dtype is not None:
+ index = pd.Categorical(pkey, dtype=self._pkey_dtype).codes
+ else:
+ index = pkey.to_numpy()
+ index = self._num_examples * index + batch
+
+ out = pd.Categorical(index, dtype=self._index_dtype).codes
+ out = out.astype('int64')
+ return out