kumoai 2.14.0.dev202601051732__cp311-cp311-macosx_11_0_arm64.whl → 2.15.0.dev202601141731__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
kumoai/_version.py CHANGED
@@ -1 +1 @@
- __version__ = '2.14.0.dev202601051732'
+ __version__ = '2.15.0.dev202601141731'
kumoai/client/jobs.py CHANGED
@@ -344,12 +344,14 @@ class GenerateTrainTableJobAPI(CommonJobAPI[GenerateTrainTableRequest,
  id: str,
  source_table_type: SourceTableType,
  train_table_mod: TrainingTableSpec,
+ extensive_validation: bool,
  ) -> ValidationResponse:
  response = self._client._post(
  f'{self._base_endpoint}/{id}/validate_custom_train_table',
  json=to_json_dict({
  'custom_table': source_table_type,
  'train_table_mod': train_table_mod,
+ 'extensive_validation': extensive_validation,
  }),
  )
  return parse_response(ValidationResponse, response)
@@ -1,7 +1,8 @@
  import json
+ import math
  from collections.abc import Iterator
  from contextlib import contextmanager
- from typing import TYPE_CHECKING
+ from typing import TYPE_CHECKING, cast

  import numpy as np
  import pandas as pd
@@ -11,7 +12,7 @@ from kumoapi.pquery import ValidatedPredictiveQuery
  from kumoai.experimental.rfm.backend.snow import Connection, SnowTable
  from kumoai.experimental.rfm.base import SQLSampler, Table
  from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
- from kumoai.utils import ProgressLogger
+ from kumoai.utils import ProgressLogger, quote_ident

  if TYPE_CHECKING:
  from kumoai.experimental.rfm import Graph
@@ -37,6 +38,15 @@ class SnowSampler(SQLSampler):
  assert isinstance(table, SnowTable)
  self._connection = table._connection

+ self._num_rows_dict: dict[str, int] = {
+ table.name: cast(int, table._num_rows)
+ for table in graph.tables.values()
+ }
+
+ @property
+ def num_rows_dict(self) -> dict[str, int]:
+ return self._num_rows_dict
+
  def _get_min_max_time_dict(
  self,
  table_names: list[str],
@@ -45,8 +55,9 @@ class SnowSampler(SQLSampler):
  for table_name in table_names:
  column = self.time_column_dict[table_name]
  column_ref = self.table_column_ref_dict[table_name][column]
+ ident = quote_ident(table_name, char="'")
  select = (f"SELECT\n"
- f" ? as table_name,\n"
+ f" {ident} as table_name,\n"
  f" MIN({column_ref}) as min_date,\n"
  f" MAX({column_ref}) as max_date\n"
  f"FROM {self.source_name_dict[table_name]}")
@@ -54,14 +65,13 @@ class SnowSampler(SQLSampler):
  sql = "\nUNION ALL\n".join(selects)

  out_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
- with paramstyle(self._connection), self._connection.cursor() as cursor:
- cursor.execute(sql, table_names)
- rows = cursor.fetchall()
- for table_name, _min, _max in rows:
- out_dict[table_name] = (
- pd.Timestamp.max if _min is None else pd.Timestamp(_min),
- pd.Timestamp.min if _max is None else pd.Timestamp(_max),
- )
+ with self._connection.cursor() as cursor:
+ cursor.execute(sql)
+ for table_name, _min, _max in cursor.fetchall():
+ out_dict[table_name] = (
+ pd.Timestamp.max if _min is None else pd.Timestamp(_min),
+ pd.Timestamp.min if _max is None else pd.Timestamp(_max),
+ )

  return out_dict

@@ -144,11 +154,11 @@ class SnowSampler(SQLSampler):
  query.entity_table: np.arange(len(entity_df)),
  }
  for edge_type, (min_offset, max_offset) in time_offset_dict.items():
- table_name, fkey, _ = edge_type
+ table_name, foreign_key, _ = edge_type
  feat_dict[table_name], batch_dict[table_name] = self._by_time(
  table_name=table_name,
- fkey=fkey,
- pkey=entity_df[self.primary_key_dict[query.entity_table]],
+ foreign_key=foreign_key,
+ index=entity_df[self.primary_key_dict[query.entity_table]],
  anchor_time=time,
  min_offset=min_offset,
  max_offset=max_offset,
@@ -179,7 +189,7 @@ class SnowSampler(SQLSampler):
  def _by_pkey(
  self,
  table_name: str,
- pkey: pd.Series,
+ index: pd.Series,
  columns: set[str],
  ) -> tuple[pd.DataFrame, np.ndarray]:
  key = self.primary_key_dict[table_name]
@@ -189,7 +199,7 @@ class SnowSampler(SQLSampler):
  for column in columns
  ]

- payload = json.dumps(list(pkey))
+ payload = json.dumps(list(index))

  sql = ("WITH TMP as (\n"
  " SELECT\n"
@@ -206,7 +216,7 @@ class SnowSampler(SQLSampler):
  f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
  f"{', '.join(projections)}\n"
  f"FROM TMP\n"
- f"JOIN {self.source_name_dict[table_name]} ENT\n"
+ f"JOIN {self.source_name_dict[table_name]}\n"
  f" ON {key_ref} = TMP.__KUMO_ID__")

  with paramstyle(self._connection), self._connection.cursor() as cursor:
@@ -228,13 +238,108 @@ class SnowSampler(SQLSampler):
  stype_dict=self.table_stype_dict[table_name],
  ), batch

+ def _by_fkey(
+ self,
+ table_name: str,
+ foreign_key: str,
+ index: pd.Series,
+ num_neighbors: int,
+ anchor_time: pd.Series | None,
+ columns: set[str],
+ ) -> tuple[pd.DataFrame, np.ndarray]:
+ time_column = self.time_column_dict.get(table_name)
+
+ end_time: pd.Series | None = None
+ start_time: pd.Series | None = None
+ if time_column is not None and anchor_time is not None:
+ # In order to avoid a full table scan, we limit foreign key
+ # sampling to a certain time range, approximated by the number of
+ # rows, timestamp ranges and `num_neighbors` value.
+ # Downstream, this helps Snowflake to apply partition pruning:
+ dst_table_name = [
+ dst_table
+ for key, dst_table in self.foreign_key_dict[table_name]
+ if key == foreign_key
+ ][0]
+ num_facts = self.num_rows_dict[table_name]
+ num_entities = self.num_rows_dict[dst_table_name]
+ min_time = self.get_min_time([table_name])
+ max_time = self.get_max_time([table_name])
+ freq = num_facts / num_entities
+ freq = freq / max((max_time - min_time).total_seconds(), 1)
+ offset = pd.Timedelta(seconds=math.ceil(5 * num_neighbors / freq))
+
+ end_time = anchor_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+ start_time = anchor_time - offset
+ start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+ payload = json.dumps(list(zip(index, end_time, start_time)))
+ else:
+ payload = json.dumps(list(zip(index)))
+
+ key_ref = self.table_column_ref_dict[table_name][foreign_key]
+ projections = [
+ self.table_column_proj_dict[table_name][column]
+ for column in columns
+ ]
+
+ sql = ("WITH TMP as (\n"
+ " SELECT\n"
+ " f.index as __KUMO_BATCH__,\n")
+ if self.table_dtype_dict[table_name][foreign_key].is_int():
+ sql += " f.value[0]::NUMBER as __KUMO_ID__"
+ elif self.table_dtype_dict[table_name][foreign_key].is_float():
+ sql += " f.value[0]::FLOAT as __KUMO_ID__"
+ else:
+ sql += " f.value[0]::VARCHAR as __KUMO_ID__"
+ if end_time is not None and start_time is not None:
+ sql += (",\n"
+ " f.value[1]::TIMESTAMP_NTZ as __KUMO_END_TIME__,\n"
+ " f.value[2]::TIMESTAMP_NTZ as __KUMO_START_TIME__")
+ sql += (f"\n"
+ f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
+ f")\n"
+ f"SELECT "
+ f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+ f"{', '.join(projections)}\n"
+ f"FROM TMP\n"
+ f"JOIN {self.source_name_dict[table_name]}\n"
+ f" ON {key_ref} = TMP.__KUMO_ID__\n")
+ if end_time is not None and start_time is not None:
+ assert time_column is not None
+ time_ref = self.table_column_ref_dict[table_name][time_column]
+ sql += (f" AND {time_ref} <= TMP.__KUMO_END_TIME__\n"
+ f" AND {time_ref} > TMP.__KUMO_START_TIME__\n"
+ f"WHERE {time_ref} <= '{end_time.max()}'\n"
+ f" AND {time_ref} > '{start_time.min()}'\n")
+ sql += ("QUALIFY ROW_NUMBER() OVER (\n"
+ " PARTITION BY TMP.__KUMO_BATCH__\n")
+ if time_column is not None:
+ sql += f" ORDER BY {time_ref} DESC\n"
+ else:
+ sql += f" ORDER BY {key_ref}\n"
+ sql += f") <= {num_neighbors}"
+
+ with paramstyle(self._connection), self._connection.cursor() as cursor:
+ cursor.execute(sql, (payload, ))
+ table = cursor.fetch_arrow_all()
+
+ batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+ batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+ table = table.remove_column(batch_index)
+
+ return Table._sanitize(
+ df=table.to_pandas(),
+ dtype_dict=self.table_dtype_dict[table_name],
+ stype_dict=self.table_stype_dict[table_name],
+ ), batch
+
  # Helper Methods ##########################################################

  def _by_time(
  self,
  table_name: str,
- fkey: str,
- pkey: pd.Series,
+ foreign_key: str,
+ index: pd.Series,
  anchor_time: pd.Series,
  min_offset: pd.DateOffset | None,
  max_offset: pd.DateOffset,
@@ -244,14 +349,15 @@ class SnowSampler(SQLSampler):

  end_time = anchor_time + max_offset
  end_time = end_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+ start_time: pd.Series | None = None
  if min_offset is not None:
  start_time = anchor_time + min_offset
  start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
- payload = json.dumps(list(zip(pkey, end_time, start_time)))
+ payload = json.dumps(list(zip(index, end_time, start_time)))
  else:
- payload = json.dumps(list(zip(pkey, end_time)))
+ payload = json.dumps(list(zip(index, end_time)))

- key_ref = self.table_column_ref_dict[table_name][fkey]
+ key_ref = self.table_column_ref_dict[table_name][foreign_key]
  time_ref = self.table_column_ref_dict[table_name][time_column]
  projections = [
  self.table_column_proj_dict[table_name][column]
@@ -260,9 +366,9 @@ class SnowSampler(SQLSampler):
  sql = ("WITH TMP as (\n"
  " SELECT\n"
  " f.index as __KUMO_BATCH__,\n")
- if self.table_dtype_dict[table_name][fkey].is_int():
+ if self.table_dtype_dict[table_name][foreign_key].is_int():
  sql += " f.value[0]::NUMBER as __KUMO_ID__,\n"
- elif self.table_dtype_dict[table_name][fkey].is_float():
+ elif self.table_dtype_dict[table_name][foreign_key].is_float():
  sql += " f.value[0]::FLOAT as __KUMO_ID__,\n"
  else:
  sql += " f.value[0]::VARCHAR as __KUMO_ID__,\n"
@@ -276,11 +382,15 @@ class SnowSampler(SQLSampler):
  f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
  f"{', '.join(projections)}\n"
  f"FROM TMP\n"
- f"JOIN {self.source_name_dict[table_name]} FACT\n"
+ f"JOIN {self.source_name_dict[table_name]}\n"
  f" ON {key_ref} = TMP.__KUMO_ID__\n"
- f" AND {time_ref} <= TMP.__KUMO_END_TIME__")
- if min_offset is not None:
- sql += f"\n AND {time_ref} > TMP.__KUMO_START_TIME__"
+ f" AND {time_ref} <= TMP.__KUMO_END_TIME__\n")
+ if start_time is not None:
+ sql += f"AND {time_ref} > TMP.__KUMO_START_TIME__\n"
+ # Add global time bounds to enable partition pruning:
+ sql += f"WHERE {time_ref} <= '{end_time.max()}'"
+ if start_time is not None:
+ sql += f"\nAND {time_ref} > '{start_time.min()}'"

  with paramstyle(self._connection), self._connection.cursor() as cursor:
  cursor.execute(sql, (payload, ))
@@ -76,21 +76,13 @@ class SnowTable(Table):

  @property
  def source_name(self) -> str:
- names: list[str] = []
- if self._database is not None:
- names.append(self._database)
- if self._schema is not None:
- names.append(self._schema)
- return '.'.join(names + [self._source_name])
+ names = [self._database, self._schema, self._source_name]
+ return '.'.join(names)

  @property
  def _quoted_source_name(self) -> str:
- names: list[str] = []
- if self._database is not None:
- names.append(quote_ident(self._database))
- if self._schema is not None:
- names.append(quote_ident(self._schema))
- return '.'.join(names + [quote_ident(self._source_name)])
+ names = [self._database, self._schema, self._source_name]
+ return '.'.join([quote_ident(name) for name in names])

  @property
  def backend(self) -> DataBackend:
@@ -159,7 +151,18 @@ class SnowTable(Table):
  )

  def _get_num_rows(self) -> int | None:
- return None
+ with self._connection.cursor() as cursor:
+ quoted_source_name = quote_ident(self._source_name, char="'")
+ sql = (f"SHOW TABLES LIKE {quoted_source_name} "
+ f"IN SCHEMA {quote_ident(self._database)}."
+ f"{quote_ident(self._schema)}")
+ cursor.execute(sql)
+ num_rows = cursor.fetchone()[7]
+
+ if num_rows == 0:
+ raise RuntimeError("Table '{self.source_name}' is empty")
+
+ return num_rows

  def _get_expr_sample_df(
  self,
@@ -121,8 +121,9 @@ class SQLiteSampler(SQLSampler):
  for table_name in table_names:
  column = self.time_column_dict[table_name]
  column_ref = self.table_column_ref_dict[table_name][column]
+ ident = quote_ident(table_name, char="'")
  select = (f"SELECT\n"
- f" ? as table_name,\n"
+ f" {ident} as table_name,\n"
  f" MIN({column_ref}) as min_date,\n"
  f" MAX({column_ref}) as max_date\n"
  f"FROM {self.source_name_dict[table_name]}")
@@ -131,12 +132,13 @@ class SQLiteSampler(SQLSampler):

  out_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
  with self._connection.cursor() as cursor:
- cursor.execute(sql, table_names)
+ cursor.execute(sql)
  for table_name, _min, _max in cursor.fetchall():
  out_dict[table_name] = (
  pd.Timestamp.max if _min is None else pd.Timestamp(_min),
  pd.Timestamp.min if _max is None else pd.Timestamp(_max),
  )
+
  return out_dict

  def _sample_entity_table(
@@ -226,7 +228,7 @@ class SQLiteSampler(SQLSampler):
  def _by_pkey(
  self,
  table_name: str,
- pkey: pd.Series,
+ index: pd.Series,
  columns: set[str],
  ) -> tuple[pd.DataFrame, np.ndarray]:
  source_table = self.source_table_dict[table_name]
@@ -237,7 +239,7 @@ class SQLiteSampler(SQLSampler):
  for column in columns
  ]

- tmp = pa.table([pa.array(pkey)], names=['__kumo_id__'])
+ tmp = pa.table([pa.array(index)], names=['__kumo_id__'])
  tmp_name = f'tmp_{table_name}_{key}_{id(tmp)}'

  sql = (f"SELECT "
@@ -245,7 +247,6 @@ class SQLiteSampler(SQLSampler):
  f"{', '.join(projections)}\n"
  f"FROM {quote_ident(tmp_name)} tmp\n"
  f"JOIN {self.source_name_dict[table_name]} ent\n")
-
  if key in source_table and source_table[key].is_unique_key:
  sql += (f" ON {key_ref} = tmp.__kumo_id__")
  else:
@@ -271,13 +272,70 @@ class SQLiteSampler(SQLSampler):
  stype_dict=self.table_stype_dict[table_name],
  ), batch

+ def _by_fkey(
+ self,
+ table_name: str,
+ foreign_key: str,
+ index: pd.Series,
+ num_neighbors: int,
+ anchor_time: pd.Series | None,
+ columns: set[str],
+ ) -> tuple[pd.DataFrame, np.ndarray]:
+ time_column = self.time_column_dict.get(table_name)
+
+ # NOTE SQLite does not have a native datetime format. Currently, we
+ # assume timestamps are given as `TEXT` in `ISO-8601 UTC`:
+ tmp = pa.table([pa.array(index)], names=['__kumo_id__'])
+ if time_column is not None and anchor_time is not None:
+ anchor_time = anchor_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+ tmp = tmp.append_column('__kumo_time__', pa.array(anchor_time))
+ tmp_name = f'tmp_{table_name}_{foreign_key}_{id(tmp)}'
+
+ key_ref = self.table_column_ref_dict[table_name][foreign_key]
+ projections = [
+ self.table_column_proj_dict[table_name][column]
+ for column in columns
+ ]
+ sql = (f"SELECT "
+ f"tmp.rowid - 1 as __kumo_batch__, "
+ f"{', '.join(projections)}\n"
+ f"FROM {quote_ident(tmp_name)} tmp\n"
+ f"JOIN {self.source_name_dict[table_name]} fact\n"
+ f"ON fact.rowid IN (\n"
+ f" SELECT rowid\n"
+ f" FROM {self.source_name_dict[table_name]}\n"
+ f" WHERE {key_ref} = tmp.__kumo_id__\n")
+ if time_column is not None and anchor_time is not None:
+ time_ref = self.table_column_ref_dict[table_name][time_column]
+ sql += f" AND {time_ref} <= tmp.__kumo_time__\n"
+ if time_column is not None:
+ time_ref = self.table_column_ref_dict[table_name][time_column]
+ sql += f" ORDER BY {time_ref} DESC\n"
+ sql += (f" LIMIT {num_neighbors}\n"
+ f")")
+
+ with self._connection.cursor() as cursor:
+ cursor.adbc_ingest(tmp_name, tmp, mode='replace')
+ cursor.execute(sql)
+ table = cursor.fetch_arrow_table()
+
+ batch = table['__kumo_batch__'].to_numpy()
+ batch_index = table.schema.get_field_index('__kumo_batch__')
+ table = table.remove_column(batch_index)
+
+ return Table._sanitize(
+ df=table.to_pandas(),
+ dtype_dict=self.table_dtype_dict[table_name],
+ stype_dict=self.table_stype_dict[table_name],
+ ), batch
+
  # Helper Methods ##########################################################

  def _by_time(
  self,
  table_name: str,
- fkey: str,
- pkey: pd.Series,
+ foreign_key: str,
+ index: pd.Series,
  anchor_time: pd.Series,
  min_offset: pd.DateOffset | None,
  max_offset: pd.DateOffset,
@@ -287,7 +345,7 @@ class SQLiteSampler(SQLSampler):

  # NOTE SQLite does not have a native datetime format. Currently, we
  # assume timestamps are given as `TEXT` in `ISO-8601 UTC`:
- tmp = pa.table([pa.array(pkey)], names=['__kumo_id__'])
+ tmp = pa.table([pa.array(index)], names=['__kumo_id__'])
  end_time = anchor_time + max_offset
  end_time = end_time.dt.strftime("%Y-%m-%d %H:%M:%S")
  tmp = tmp.append_column('__kumo_end__', pa.array(end_time))
@@ -295,9 +353,9 @@ class SQLiteSampler(SQLSampler):
  start_time = anchor_time + min_offset
  start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
  tmp = tmp.append_column('__kumo_start__', pa.array(start_time))
- tmp_name = f'tmp_{table_name}_{fkey}_{id(tmp)}'
+ tmp_name = f'tmp_{table_name}_{foreign_key}_{id(tmp)}'

- key_ref = self.table_column_ref_dict[table_name][fkey]
+ key_ref = self.table_column_ref_dict[table_name][foreign_key]
  time_ref = self.table_column_ref_dict[table_name][time_column]
  projections = [
  self.table_column_proj_dict[table_name][column]
@@ -307,7 +365,7 @@ class SQLiteSampler(SQLSampler):
  f"tmp.rowid - 1 as __kumo_batch__, "
  f"{', '.join(projections)}\n"
  f"FROM {quote_ident(tmp_name)} tmp\n"
- f"JOIN {self.source_name_dict[table_name]} fact\n"
+ f"JOIN {self.source_name_dict[table_name]}\n"
  f" ON {key_ref} = tmp.__kumo_id__\n"
  f" AND {time_ref} <= tmp.__kumo_end__")
  if min_offset is not None:
@@ -359,11 +417,11 @@ class SQLiteSampler(SQLSampler):
  query.entity_table: np.arange(len(df)),
  }
  for edge_type, (_min, _max) in time_offset_dict.items():
- table_name, fkey, _ = edge_type
+ table_name, foreign_key, _ = edge_type
  feat_dict[table_name], batch_dict[table_name] = self._by_time(
  table_name=table_name,
- fkey=fkey,
- pkey=df[self.primary_key_dict[query.entity_table]],
+ foreign_key=foreign_key,
+ index=df[self.primary_key_dict[query.entity_table]],
  anchor_time=time,
  min_offset=_min,
  max_offset=_max,
@@ -378,7 +436,7 @@ class SQLiteSampler(SQLSampler):
  feat_dict=feat_dict,
  time_dict=time_dict,
  batch_dict=batch_dict,
- anchor_time=anchor_time,
+ anchor_time=time,
  num_forecasts=query.num_forecasts,
  )
  ys.append(y)
@@ -0,0 +1,69 @@
+ import numpy as np
+ import pandas as pd
+
+
+ class Mapper:
+ r"""A mapper to map ``(pkey, batch)`` pairs to contiguous node IDs.
+
+ Args:
+ num_examples: The maximum number of examples to add/retrieve.
+ """
+ def __init__(self, num_examples: int):
+ self._pkey_dtype: pd.CategoricalDtype | None = None
+ self._indices: list[np.ndarray] = []
+ self._index_dtype: pd.CategoricalDtype | None = None
+ self._num_examples = num_examples
+
+ def add(self, pkey: pd.Series, batch: np.ndarray) -> None:
+ r"""Adds a set of ``(pkey, batch)`` pairs to the mapper.
+
+ Args:
+ pkey: The primary keys.
+ batch: The batch vector.
+ """
+ if self._pkey_dtype is not None:
+ category = np.concatenate([
+ self._pkey_dtype.categories.values,
+ pkey,
+ ], axis=0)
+ category = pd.unique(category)
+ self._pkey_dtype = pd.CategoricalDtype(category)
+ elif pd.api.types.is_string_dtype(pkey):
+ category = pd.unique(pkey)
+ self._pkey_dtype = pd.CategoricalDtype(category)
+
+ if self._pkey_dtype is not None:
+ index = pd.Categorical(pkey, dtype=self._pkey_dtype).codes
+ index = index.astype('int64')
+ else:
+ index = pkey.to_numpy()
+ index = self._num_examples * index + batch
+ self._indices.append(index)
+ self._index_dtype = None
+
+ def get(self, pkey: pd.Series, batch: np.ndarray) -> np.ndarray:
+ r"""Retrieves the node IDs for a set of ``(pkey, batch)`` pairs.
+
+ Returns ``-1`` for any pair not registered in the mapping.
+
+ Args:
+ pkey: The primary keys.
+ batch: The batch vector.
+ """
+ if len(self._indices) == 0:
+ return np.full(len(pkey), -1, dtype=np.int64)
+
+ if self._index_dtype is None: # Lazy build index:
+ category = pd.unique(np.concatenate(self._indices))
+ self._index_dtype = pd.CategoricalDtype(category)
+
+ if self._pkey_dtype is not None:
+ index = pd.Categorical(pkey, dtype=self._pkey_dtype).codes
+ index = index.astype('int64')
+ else:
+ index = pkey.to_numpy()
+ index = self._num_examples * index + batch
+
+ out = pd.Categorical(index, dtype=self._index_dtype).codes
+ out = out.astype('int64')
+ return out
@@ -59,6 +59,17 @@ class Sampler(ABC):
  self._edge_types.append(edge_type)
  self._edge_types.append(Subgraph.rev_edge_type(edge_type))

+ # Source Table -> [(Foreign Key, Destination Table)]
+ self._foreign_key_dict: dict[str, list[tuple[str, str]]] = {}
+ # Destination Table -> [(Source Table, Foreign Key)]
+ self._rev_foreign_key_dict: dict[str, list[tuple[str, str]]] = {}
+ for table in graph.tables.values():
+ self._foreign_key_dict[table.name] = []
+ self._rev_foreign_key_dict[table.name] = []
+ for src_table, fkey, dst_table in graph.edges:
+ self._foreign_key_dict[src_table].append((fkey, dst_table))
+ self._rev_foreign_key_dict[dst_table].append((src_table, fkey))
+
  self._primary_key_dict: dict[str, str] = {
  table.name: table._primary_key
  for table in graph.tables.values()
@@ -98,6 +109,16 @@ class Sampler(ABC):
  r"""All available edge types in the graph."""
  return self._edge_types

+ @property
+ def foreign_key_dict(self) -> dict[str, list[tuple[str, str]]]:
+ r"""The foreign keys for all tables in the graph."""
+ return self._foreign_key_dict
+
+ @property
+ def rev_foreign_key_dict(self) -> dict[str, list[tuple[str, str]]]:
+ r"""The foreign key back references for all tables in the graph."""
+ return self._rev_foreign_key_dict
+
  @property
  def primary_key_dict(self) -> dict[str, str]:
  r"""All available primary keys in the graph."""
@@ -274,7 +295,8 @@ class Sampler(ABC):

  # Store in compressed representation if more efficient:
  num_cols = subgraph.table_dict[edge_type[2]].num_rows
- if col is not None and len(col) > num_cols + 1:
+ if (col is not None and len(col) > num_cols + 1
+ and ((col[1:] - col[:-1]) >= 0).all()):
  layout = EdgeLayout.CSC
  colcount = np.bincount(col, minlength=num_cols)
  col = np.empty(num_cols + 1, dtype=col.dtype)
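The extra condition added above only allows the CSC layout when the column vector is already sorted in non-decreasing order. That matters because a compressed pointer built from per-column counts only round-trips to the original vector when it is sorted. A sketch of the check together with one possible pointer construction (the cumsum step is not shown in the diff and is assumed here):

import numpy as np

def to_csc_colptr(col: np.ndarray, num_cols: int) -> np.ndarray:
    # Per-column counts followed by a cumulative sum yield the CSC pointer.
    colcount = np.bincount(col, minlength=num_cols)
    colptr = np.empty(num_cols + 1, dtype=col.dtype)
    colptr[0] = 0
    np.cumsum(colcount, out=colptr[1:])
    return colptr

sorted_col = np.array([0, 0, 1, 2, 2, 2])
print(((sorted_col[1:] - sorted_col[:-1]) >= 0).all())  # True -> safe to compress
print(to_csc_colptr(sorted_col, num_cols=3))            # [0 2 3 6]

unsorted_col = np.array([2, 0, 1, 2, 0, 2])
print(((unsorted_col[1:] - unsorted_col[:-1]) >= 0).all())  # False -> keep the uncompressed layout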