kumoai 2.14.0.dev202601051732__cp311-cp311-macosx_11_0_arm64.whl → 2.15.0.dev202601141731__cp311-cp311-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/_version.py +1 -1
- kumoai/client/jobs.py +2 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +138 -28
- kumoai/experimental/rfm/backend/snow/table.py +16 -13
- kumoai/experimental/rfm/backend/sqlite/sampler.py +73 -15
- kumoai/experimental/rfm/base/mapper.py +69 -0
- kumoai/experimental/rfm/base/sampler.py +23 -1
- kumoai/experimental/rfm/base/sql_sampler.py +252 -11
- kumoai/experimental/rfm/base/table.py +15 -29
- kumoai/experimental/rfm/base/utils.py +36 -0
- kumoai/experimental/rfm/graph.py +9 -9
- kumoai/experimental/rfm/infer/dtype.py +3 -1
- kumoai/experimental/rfm/infer/time_col.py +4 -2
- kumoai/experimental/rfm/rfm.py +195 -114
- kumoai/experimental/rfm/task_table.py +2 -0
- kumoai/pquery/training_table.py +16 -2
- kumoai/testing/snow.py +3 -3
- kumoai/utils/display.py +44 -8
- kumoai/utils/progress_logger.py +2 -1
- kumoai/utils/sql.py +2 -2
- {kumoai-2.14.0.dev202601051732.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/METADATA +2 -2
- {kumoai-2.14.0.dev202601051732.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/RECORD +25 -23
- {kumoai-2.14.0.dev202601051732.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/WHEEL +0 -0
- {kumoai-2.14.0.dev202601051732.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.14.0.dev202601051732.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/top_level.txt +0 -0
kumoai/_version.py
CHANGED
@@ -1 +1 @@
-__version__ = '2.14.0.dev202601051732'
+__version__ = '2.15.0.dev202601141731'
kumoai/client/jobs.py
CHANGED
@@ -344,12 +344,14 @@ class GenerateTrainTableJobAPI(CommonJobAPI[GenerateTrainTableRequest,
         id: str,
         source_table_type: SourceTableType,
         train_table_mod: TrainingTableSpec,
+        extensive_validation: bool,
     ) -> ValidationResponse:
         response = self._client._post(
             f'{self._base_endpoint}/{id}/validate_custom_train_table',
             json=to_json_dict({
                 'custom_table': source_table_type,
                 'train_table_mod': train_table_mod,
+                'extensive_validation': extensive_validation,
             }),
         )
         return parse_response(ValidationResponse, response)
kumoai/experimental/rfm/backend/snow/sampler.py
CHANGED
@@ -1,7 +1,8 @@
 import json
+import math
 from collections.abc import Iterator
 from contextlib import contextmanager
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, cast
 
 import numpy as np
 import pandas as pd
@@ -11,7 +12,7 @@ from kumoapi.pquery import ValidatedPredictiveQuery
 from kumoai.experimental.rfm.backend.snow import Connection, SnowTable
 from kumoai.experimental.rfm.base import SQLSampler, Table
 from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
-from kumoai.utils import ProgressLogger
+from kumoai.utils import ProgressLogger, quote_ident
 
 if TYPE_CHECKING:
     from kumoai.experimental.rfm import Graph
@@ -37,6 +38,15 @@ class SnowSampler(SQLSampler):
         assert isinstance(table, SnowTable)
         self._connection = table._connection
 
+        self._num_rows_dict: dict[str, int] = {
+            table.name: cast(int, table._num_rows)
+            for table in graph.tables.values()
+        }
+
+    @property
+    def num_rows_dict(self) -> dict[str, int]:
+        return self._num_rows_dict
+
     def _get_min_max_time_dict(
         self,
         table_names: list[str],
@@ -45,8 +55,9 @@ class SnowSampler(SQLSampler):
         for table_name in table_names:
             column = self.time_column_dict[table_name]
             column_ref = self.table_column_ref_dict[table_name][column]
+            ident = quote_ident(table_name, char="'")
             select = (f"SELECT\n"
-                      f"
+                      f"  {ident} as table_name,\n"
                       f"  MIN({column_ref}) as min_date,\n"
                       f"  MAX({column_ref}) as max_date\n"
                       f"FROM {self.source_name_dict[table_name]}")
@@ -54,14 +65,13 @@ class SnowSampler(SQLSampler):
         sql = "\nUNION ALL\n".join(selects)
 
         out_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
-        with
-            cursor.execute(sql
-
-
-
-
-
-        )
+        with self._connection.cursor() as cursor:
+            cursor.execute(sql)
+            for table_name, _min, _max in cursor.fetchall():
+                out_dict[table_name] = (
+                    pd.Timestamp.max if _min is None else pd.Timestamp(_min),
+                    pd.Timestamp.min if _max is None else pd.Timestamp(_max),
+                )
 
         return out_dict
 
@@ -144,11 +154,11 @@ class SnowSampler(SQLSampler):
             query.entity_table: np.arange(len(entity_df)),
         }
         for edge_type, (min_offset, max_offset) in time_offset_dict.items():
-            table_name,
+            table_name, foreign_key, _ = edge_type
             feat_dict[table_name], batch_dict[table_name] = self._by_time(
                 table_name=table_name,
-
-
+                foreign_key=foreign_key,
+                index=entity_df[self.primary_key_dict[query.entity_table]],
                 anchor_time=time,
                 min_offset=min_offset,
                 max_offset=max_offset,
@@ -179,7 +189,7 @@ class SnowSampler(SQLSampler):
     def _by_pkey(
         self,
         table_name: str,
-
+        index: pd.Series,
         columns: set[str],
     ) -> tuple[pd.DataFrame, np.ndarray]:
         key = self.primary_key_dict[table_name]
@@ -189,7 +199,7 @@ class SnowSampler(SQLSampler):
             for column in columns
         ]
 
-        payload = json.dumps(list(
+        payload = json.dumps(list(index))
 
         sql = ("WITH TMP as (\n"
               " SELECT\n"
@@ -206,7 +216,7 @@ class SnowSampler(SQLSampler):
               f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
              f"{', '.join(projections)}\n"
               f"FROM TMP\n"
-              f"JOIN {self.source_name_dict[table_name]}
+              f"JOIN {self.source_name_dict[table_name]}\n"
               f" ON {key_ref} = TMP.__KUMO_ID__")
 
         with paramstyle(self._connection), self._connection.cursor() as cursor:
@@ -228,13 +238,108 @@ class SnowSampler(SQLSampler):
             stype_dict=self.table_stype_dict[table_name],
         ), batch
 
+    def _by_fkey(
+        self,
+        table_name: str,
+        foreign_key: str,
+        index: pd.Series,
+        num_neighbors: int,
+        anchor_time: pd.Series | None,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        time_column = self.time_column_dict.get(table_name)
+
+        end_time: pd.Series | None = None
+        start_time: pd.Series | None = None
+        if time_column is not None and anchor_time is not None:
+            # In order to avoid a full table scan, we limit foreign key
+            # sampling to a certain time range, approximated by the number of
+            # rows, timestamp ranges and `num_neighbors` value.
+            # Downstream, this helps Snowflake to apply partition pruning:
+            dst_table_name = [
+                dst_table
+                for key, dst_table in self.foreign_key_dict[table_name]
+                if key == foreign_key
+            ][0]
+            num_facts = self.num_rows_dict[table_name]
+            num_entities = self.num_rows_dict[dst_table_name]
+            min_time = self.get_min_time([table_name])
+            max_time = self.get_max_time([table_name])
+            freq = num_facts / num_entities
+            freq = freq / max((max_time - min_time).total_seconds(), 1)
+            offset = pd.Timedelta(seconds=math.ceil(5 * num_neighbors / freq))
+
+            end_time = anchor_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+            start_time = anchor_time - offset
+            start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+            payload = json.dumps(list(zip(index, end_time, start_time)))
+        else:
+            payload = json.dumps(list(zip(index)))
+
+        key_ref = self.table_column_ref_dict[table_name][foreign_key]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+
+        sql = ("WITH TMP as (\n"
+               " SELECT\n"
+               " f.index as __KUMO_BATCH__,\n")
+        if self.table_dtype_dict[table_name][foreign_key].is_int():
+            sql += " f.value[0]::NUMBER as __KUMO_ID__"
+        elif self.table_dtype_dict[table_name][foreign_key].is_float():
+            sql += " f.value[0]::FLOAT as __KUMO_ID__"
+        else:
+            sql += " f.value[0]::VARCHAR as __KUMO_ID__"
+        if end_time is not None and start_time is not None:
+            sql += (",\n"
+                    " f.value[1]::TIMESTAMP_NTZ as __KUMO_END_TIME__,\n"
+                    " f.value[2]::TIMESTAMP_NTZ as __KUMO_START_TIME__")
+        sql += (f"\n"
+                f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
+                f")\n"
+                f"SELECT "
+                f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+                f"{', '.join(projections)}\n"
+                f"FROM TMP\n"
+                f"JOIN {self.source_name_dict[table_name]}\n"
+                f" ON {key_ref} = TMP.__KUMO_ID__\n")
+        if end_time is not None and start_time is not None:
+            assert time_column is not None
+            time_ref = self.table_column_ref_dict[table_name][time_column]
+            sql += (f" AND {time_ref} <= TMP.__KUMO_END_TIME__\n"
+                    f" AND {time_ref} > TMP.__KUMO_START_TIME__\n"
+                    f"WHERE {time_ref} <= '{end_time.max()}'\n"
+                    f" AND {time_ref} > '{start_time.min()}'\n")
+        sql += ("QUALIFY ROW_NUMBER() OVER (\n"
+                " PARTITION BY TMP.__KUMO_BATCH__\n")
+        if time_column is not None:
+            sql += f" ORDER BY {time_ref} DESC\n"
+        else:
+            sql += f" ORDER BY {key_ref}\n"
+        sql += f") <= {num_neighbors}"
+
+        with paramstyle(self._connection), self._connection.cursor() as cursor:
+            cursor.execute(sql, (payload, ))
+            table = cursor.fetch_arrow_all()
+
+        batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+        batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+        table = table.remove_column(batch_index)
+
+        return Table._sanitize(
+            df=table.to_pandas(),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
+
     # Helper Methods ##########################################################
 
     def _by_time(
         self,
         table_name: str,
-
-
+        foreign_key: str,
+        index: pd.Series,
         anchor_time: pd.Series,
         min_offset: pd.DateOffset | None,
         max_offset: pd.DateOffset,
@@ -244,14 +349,15 @@ class SnowSampler(SQLSampler):
 
         end_time = anchor_time + max_offset
         end_time = end_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+        start_time: pd.Series | None = None
         if min_offset is not None:
             start_time = anchor_time + min_offset
             start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
-            payload = json.dumps(list(zip(
+            payload = json.dumps(list(zip(index, end_time, start_time)))
         else:
-            payload = json.dumps(list(zip(
+            payload = json.dumps(list(zip(index, end_time)))
 
-        key_ref = self.table_column_ref_dict[table_name][
+        key_ref = self.table_column_ref_dict[table_name][foreign_key]
         time_ref = self.table_column_ref_dict[table_name][time_column]
         projections = [
             self.table_column_proj_dict[table_name][column]
@@ -260,9 +366,9 @@ class SnowSampler(SQLSampler):
         sql = ("WITH TMP as (\n"
                " SELECT\n"
                " f.index as __KUMO_BATCH__,\n")
-        if self.table_dtype_dict[table_name][
+        if self.table_dtype_dict[table_name][foreign_key].is_int():
             sql += " f.value[0]::NUMBER as __KUMO_ID__,\n"
-        elif self.table_dtype_dict[table_name][
+        elif self.table_dtype_dict[table_name][foreign_key].is_float():
             sql += " f.value[0]::FLOAT as __KUMO_ID__,\n"
         else:
             sql += " f.value[0]::VARCHAR as __KUMO_ID__,\n"
@@ -276,11 +382,15 @@ class SnowSampler(SQLSampler):
                f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
                f"{', '.join(projections)}\n"
                f"FROM TMP\n"
-               f"JOIN {self.source_name_dict[table_name]}
+               f"JOIN {self.source_name_dict[table_name]}\n"
                f" ON {key_ref} = TMP.__KUMO_ID__\n"
-               f" AND {time_ref} <= TMP.__KUMO_END_TIME__")
-        if
-            sql += f"
+               f" AND {time_ref} <= TMP.__KUMO_END_TIME__\n")
+        if start_time is not None:
+            sql += f"AND {time_ref} > TMP.__KUMO_START_TIME__\n"
+        # Add global time bounds to enable partition pruning:
+        sql += f"WHERE {time_ref} <= '{end_time.max()}'"
+        if start_time is not None:
+            sql += f"\nAND {time_ref} > '{start_time.min()}'"
 
         with paramstyle(self._connection), self._connection.cursor() as cursor:
             cursor.execute(sql, (payload, ))
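The new Snowflake `_by_fkey` sampler bounds the scanned time range with a heuristic: it estimates how many facts each entity produces per second of covered history and looks back just far enough to expect roughly `5 * num_neighbors` rows per entity, which it then repeats as global `WHERE` bounds so Snowflake can prune partitions. A minimal, self-contained sketch of that arithmetic (not part of the package; all numbers below are hypothetical):

import math

import pandas as pd

num_facts = 10_000_000    # rows in the fact table (hypothetical)
num_entities = 50_000     # rows in the referenced entity table (hypothetical)
min_time = pd.Timestamp('2020-01-01')
max_time = pd.Timestamp('2021-01-01')
num_neighbors = 16

# Average number of facts per entity, per second of covered history:
freq = num_facts / num_entities
freq = freq / max((max_time - min_time).total_seconds(), 1)

# Look back far enough to expect roughly 5x `num_neighbors` facts per entity:
offset = pd.Timedelta(seconds=math.ceil(5 * num_neighbors / freq))
print(offset)  # ~146 days for the numbers above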
kumoai/experimental/rfm/backend/snow/table.py
CHANGED
@@ -76,21 +76,13 @@ class SnowTable(Table):
 
     @property
     def source_name(self) -> str:
-        names
-
-        names.append(self._database)
-        if self._schema is not None:
-            names.append(self._schema)
-        return '.'.join(names + [self._source_name])
+        names = [self._database, self._schema, self._source_name]
+        return '.'.join(names)
 
     @property
     def _quoted_source_name(self) -> str:
-        names
-
-        names.append(quote_ident(self._database))
-        if self._schema is not None:
-            names.append(quote_ident(self._schema))
-        return '.'.join(names + [quote_ident(self._source_name)])
+        names = [self._database, self._schema, self._source_name]
+        return '.'.join([quote_ident(name) for name in names])
 
     @property
     def backend(self) -> DataBackend:
@@ -159,7 +151,18 @@ class SnowTable(Table):
         )
 
     def _get_num_rows(self) -> int | None:
-
+        with self._connection.cursor() as cursor:
+            quoted_source_name = quote_ident(self._source_name, char="'")
+            sql = (f"SHOW TABLES LIKE {quoted_source_name} "
+                   f"IN SCHEMA {quote_ident(self._database)}."
+                   f"{quote_ident(self._schema)}")
+            cursor.execute(sql)
+            num_rows = cursor.fetchone()[7]
+
+        if num_rows == 0:
+            raise RuntimeError("Table '{self.source_name}' is empty")
+
+        return num_rows
 
     def _get_expr_sample_df(
         self,
kumoai/experimental/rfm/backend/sqlite/sampler.py
CHANGED
@@ -121,8 +121,9 @@ class SQLiteSampler(SQLSampler):
         for table_name in table_names:
             column = self.time_column_dict[table_name]
             column_ref = self.table_column_ref_dict[table_name][column]
+            ident = quote_ident(table_name, char="'")
             select = (f"SELECT\n"
-                      f"
+                      f"  {ident} as table_name,\n"
                       f"  MIN({column_ref}) as min_date,\n"
                       f"  MAX({column_ref}) as max_date\n"
                       f"FROM {self.source_name_dict[table_name]}")
@@ -131,12 +132,13 @@ class SQLiteSampler(SQLSampler):
 
         out_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
         with self._connection.cursor() as cursor:
-            cursor.execute(sql
+            cursor.execute(sql)
             for table_name, _min, _max in cursor.fetchall():
                 out_dict[table_name] = (
                     pd.Timestamp.max if _min is None else pd.Timestamp(_min),
                     pd.Timestamp.min if _max is None else pd.Timestamp(_max),
                 )
+
         return out_dict
 
     def _sample_entity_table(
@@ -226,7 +228,7 @@ class SQLiteSampler(SQLSampler):
     def _by_pkey(
        self,
         table_name: str,
-
+        index: pd.Series,
         columns: set[str],
     ) -> tuple[pd.DataFrame, np.ndarray]:
         source_table = self.source_table_dict[table_name]
@@ -237,7 +239,7 @@ class SQLiteSampler(SQLSampler):
             for column in columns
         ]
 
-        tmp = pa.table([pa.array(
+        tmp = pa.table([pa.array(index)], names=['__kumo_id__'])
         tmp_name = f'tmp_{table_name}_{key}_{id(tmp)}'
 
         sql = (f"SELECT "
@@ -245,7 +247,6 @@ class SQLiteSampler(SQLSampler):
                f"{', '.join(projections)}\n"
                f"FROM {quote_ident(tmp_name)} tmp\n"
                f"JOIN {self.source_name_dict[table_name]} ent\n")
-
         if key in source_table and source_table[key].is_unique_key:
             sql += (f" ON {key_ref} = tmp.__kumo_id__")
         else:
@@ -271,13 +272,70 @@ class SQLiteSampler(SQLSampler):
             stype_dict=self.table_stype_dict[table_name],
         ), batch
 
+    def _by_fkey(
+        self,
+        table_name: str,
+        foreign_key: str,
+        index: pd.Series,
+        num_neighbors: int,
+        anchor_time: pd.Series | None,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        time_column = self.time_column_dict.get(table_name)
+
+        # NOTE SQLite does not have a native datetime format. Currently, we
+        # assume timestamps are given as `TEXT` in `ISO-8601 UTC`:
+        tmp = pa.table([pa.array(index)], names=['__kumo_id__'])
+        if time_column is not None and anchor_time is not None:
+            anchor_time = anchor_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+            tmp = tmp.append_column('__kumo_time__', pa.array(anchor_time))
+        tmp_name = f'tmp_{table_name}_{foreign_key}_{id(tmp)}'
+
+        key_ref = self.table_column_ref_dict[table_name][foreign_key]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+        sql = (f"SELECT "
+               f"tmp.rowid - 1 as __kumo_batch__, "
+               f"{', '.join(projections)}\n"
+               f"FROM {quote_ident(tmp_name)} tmp\n"
+               f"JOIN {self.source_name_dict[table_name]} fact\n"
+               f"ON fact.rowid IN (\n"
+               f"  SELECT rowid\n"
+               f"  FROM {self.source_name_dict[table_name]}\n"
+               f"  WHERE {key_ref} = tmp.__kumo_id__\n")
+        if time_column is not None and anchor_time is not None:
+            time_ref = self.table_column_ref_dict[table_name][time_column]
+            sql += f"  AND {time_ref} <= tmp.__kumo_time__\n"
+        if time_column is not None:
+            time_ref = self.table_column_ref_dict[table_name][time_column]
+            sql += f"  ORDER BY {time_ref} DESC\n"
+        sql += (f"  LIMIT {num_neighbors}\n"
+                f")")
+
+        with self._connection.cursor() as cursor:
+            cursor.adbc_ingest(tmp_name, tmp, mode='replace')
+            cursor.execute(sql)
+            table = cursor.fetch_arrow_table()
+
+        batch = table['__kumo_batch__'].to_numpy()
+        batch_index = table.schema.get_field_index('__kumo_batch__')
+        table = table.remove_column(batch_index)
+
+        return Table._sanitize(
+            df=table.to_pandas(),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
+
     # Helper Methods ##########################################################
 
     def _by_time(
         self,
         table_name: str,
-
-
+        foreign_key: str,
+        index: pd.Series,
         anchor_time: pd.Series,
         min_offset: pd.DateOffset | None,
         max_offset: pd.DateOffset,
@@ -287,7 +345,7 @@ class SQLiteSampler(SQLSampler):
 
         # NOTE SQLite does not have a native datetime format. Currently, we
         # assume timestamps are given as `TEXT` in `ISO-8601 UTC`:
-        tmp = pa.table([pa.array(
+        tmp = pa.table([pa.array(index)], names=['__kumo_id__'])
         end_time = anchor_time + max_offset
         end_time = end_time.dt.strftime("%Y-%m-%d %H:%M:%S")
         tmp = tmp.append_column('__kumo_end__', pa.array(end_time))
@@ -295,9 +353,9 @@ class SQLiteSampler(SQLSampler):
             start_time = anchor_time + min_offset
             start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
             tmp = tmp.append_column('__kumo_start__', pa.array(start_time))
-        tmp_name = f'tmp_{table_name}_{
+        tmp_name = f'tmp_{table_name}_{foreign_key}_{id(tmp)}'
 
-        key_ref = self.table_column_ref_dict[table_name][
+        key_ref = self.table_column_ref_dict[table_name][foreign_key]
         time_ref = self.table_column_ref_dict[table_name][time_column]
         projections = [
             self.table_column_proj_dict[table_name][column]
@@ -307,7 +365,7 @@ class SQLiteSampler(SQLSampler):
                f"tmp.rowid - 1 as __kumo_batch__, "
                f"{', '.join(projections)}\n"
                f"FROM {quote_ident(tmp_name)} tmp\n"
-               f"JOIN {self.source_name_dict[table_name]}
+               f"JOIN {self.source_name_dict[table_name]}\n"
               f" ON {key_ref} = tmp.__kumo_id__\n"
               f" AND {time_ref} <= tmp.__kumo_end__")
         if min_offset is not None:
@@ -359,11 +417,11 @@ class SQLiteSampler(SQLSampler):
             query.entity_table: np.arange(len(df)),
         }
         for edge_type, (_min, _max) in time_offset_dict.items():
-            table_name,
+            table_name, foreign_key, _ = edge_type
             feat_dict[table_name], batch_dict[table_name] = self._by_time(
                 table_name=table_name,
-
-
+                foreign_key=foreign_key,
+                index=df[self.primary_key_dict[query.entity_table]],
                 anchor_time=time,
                 min_offset=_min,
                 max_offset=_max,
@@ -378,7 +436,7 @@ class SQLiteSampler(SQLSampler):
             feat_dict=feat_dict,
             time_dict=time_dict,
             batch_dict=batch_dict,
-            anchor_time=
+            anchor_time=time,
             num_forecasts=query.num_forecasts,
         )
         ys.append(y)
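In the SQLite backend, `_by_fkey` caps the number of joined rows per entity with a correlated `rowid IN (... ORDER BY ... LIMIT num_neighbors)` subquery instead of a window function. A minimal sqlite3 sketch of that pattern (not part of the package; table and column names are hypothetical):

import sqlite3

con = sqlite3.connect(':memory:')
con.execute("CREATE TABLE fact (user_id INTEGER, ts TEXT)")
con.executemany("INSERT INTO fact VALUES (?, ?)", [
    (1, '2024-01-01 00:00:00'),
    (1, '2024-01-02 00:00:00'),
    (1, '2024-01-03 00:00:00'),
    (2, '2024-01-01 12:00:00'),
])
con.execute("CREATE TABLE tmp (__kumo_id__ INTEGER)")
con.executemany("INSERT INTO tmp VALUES (?)", [(1, ), (2, )])

num_neighbors = 2
rows = con.execute(
    "SELECT tmp.rowid - 1 AS batch, fact.user_id, fact.ts "
    "FROM tmp "
    "JOIN fact ON fact.rowid IN ("
    "  SELECT rowid FROM fact"
    "  WHERE fact.user_id = tmp.__kumo_id__"
    "  ORDER BY fact.ts DESC"
    f"  LIMIT {num_neighbors}"
    ")").fetchall()
# batch 0 joins the two latest facts of user 1; batch 1 joins the single fact of user 2:
print(rows)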
kumoai/experimental/rfm/base/mapper.py
ADDED
@@ -0,0 +1,69 @@
+import numpy as np
+import pandas as pd
+
+
+class Mapper:
+    r"""A mapper to map ``(pkey, batch)`` pairs to contiguous node IDs.
+
+    Args:
+        num_examples: The maximum number of examples to add/retrieve.
+    """
+    def __init__(self, num_examples: int):
+        self._pkey_dtype: pd.CategoricalDtype | None = None
+        self._indices: list[np.ndarray] = []
+        self._index_dtype: pd.CategoricalDtype | None = None
+        self._num_examples = num_examples
+
+    def add(self, pkey: pd.Series, batch: np.ndarray) -> None:
+        r"""Adds a set of ``(pkey, batch)`` pairs to the mapper.
+
+        Args:
+            pkey: The primary keys.
+            batch: The batch vector.
+        """
+        if self._pkey_dtype is not None:
+            category = np.concatenate([
+                self._pkey_dtype.categories.values,
+                pkey,
+            ], axis=0)
+            category = pd.unique(category)
+            self._pkey_dtype = pd.CategoricalDtype(category)
+        elif pd.api.types.is_string_dtype(pkey):
+            category = pd.unique(pkey)
+            self._pkey_dtype = pd.CategoricalDtype(category)
+
+        if self._pkey_dtype is not None:
+            index = pd.Categorical(pkey, dtype=self._pkey_dtype).codes
+            index = index.astype('int64')
+        else:
+            index = pkey.to_numpy()
+        index = self._num_examples * index + batch
+        self._indices.append(index)
+        self._index_dtype = None
+
+    def get(self, pkey: pd.Series, batch: np.ndarray) -> np.ndarray:
+        r"""Retrieves the node IDs for a set of ``(pkey, batch)`` pairs.
+
+        Returns ``-1`` for any pair not registered in the mapping.
+
+        Args:
+            pkey: The primary keys.
+            batch: The batch vector.
+        """
+        if len(self._indices) == 0:
+            return np.full(len(pkey), -1, dtype=np.int64)
+
+        if self._index_dtype is None:  # Lazy build index:
+            category = pd.unique(np.concatenate(self._indices))
+            self._index_dtype = pd.CategoricalDtype(category)
+
+        if self._pkey_dtype is not None:
+            index = pd.Categorical(pkey, dtype=self._pkey_dtype).codes
+            index = index.astype('int64')
+        else:
+            index = pkey.to_numpy()
+        index = self._num_examples * index + batch
+
+        out = pd.Categorical(index, dtype=self._index_dtype).codes
+        out = out.astype('int64')
+        return out
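The new `Mapper` encodes each ``(pkey, batch)`` pair as ``num_examples * pkey_code + batch`` and resolves lookups through pandas categoricals. A small usage sketch (not part of the package; the import path is assumed from the new file location and the inputs are hypothetical):

import numpy as np
import pandas as pd

from kumoai.experimental.rfm.base.mapper import Mapper  # assumed import path

mapper = Mapper(num_examples=4)
mapper.add(pkey=pd.Series(['a', 'b', 'c']), batch=np.array([0, 0, 1]))

# Previously added pairs resolve to contiguous IDs; unseen pairs resolve to -1:
out = mapper.get(pkey=pd.Series(['a', 'c', 'c']), batch=np.array([0, 1, 3]))
print(out)  # [ 0  2 -1]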
kumoai/experimental/rfm/base/sampler.py
CHANGED
@@ -59,6 +59,17 @@ class Sampler(ABC):
             self._edge_types.append(edge_type)
             self._edge_types.append(Subgraph.rev_edge_type(edge_type))
 
+        # Source Table -> [(Foreign Key, Destination Table)]
+        self._foreign_key_dict: dict[str, list[tuple[str, str]]] = {}
+        # Destination Table -> [(Source Table, Foreign Key)]
+        self._rev_foreign_key_dict: dict[str, list[tuple[str, str]]] = {}
+        for table in graph.tables.values():
+            self._foreign_key_dict[table.name] = []
+            self._rev_foreign_key_dict[table.name] = []
+        for src_table, fkey, dst_table in graph.edges:
+            self._foreign_key_dict[src_table].append((fkey, dst_table))
+            self._rev_foreign_key_dict[dst_table].append((src_table, fkey))
+
         self._primary_key_dict: dict[str, str] = {
             table.name: table._primary_key
             for table in graph.tables.values()
@@ -98,6 +109,16 @@ class Sampler(ABC):
         r"""All available edge types in the graph."""
         return self._edge_types
 
+    @property
+    def foreign_key_dict(self) -> dict[str, list[tuple[str, str]]]:
+        r"""The foreign keys for all tables in the graph."""
+        return self._foreign_key_dict
+
+    @property
+    def rev_foreign_key_dict(self) -> dict[str, list[tuple[str, str]]]:
+        r"""The foreign key back references for all tables in the graph."""
+        return self._rev_foreign_key_dict
+
     @property
     def primary_key_dict(self) -> dict[str, str]:
         r"""All available primary keys in the graph."""
@@ -274,7 +295,8 @@ class Sampler(ABC):
 
         # Store in compressed representation if more efficient:
         num_cols = subgraph.table_dict[edge_type[2]].num_rows
-        if col is not None and len(col) > num_cols + 1
+        if (col is not None and len(col) > num_cols + 1
+                and ((col[1:] - col[:-1]) >= 0).all()):
             layout = EdgeLayout.CSC
             colcount = np.bincount(col, minlength=num_cols)
             col = np.empty(num_cols + 1, dtype=col.dtype)