kumoai 2.14.0.dev202601081732__cp313-cp313-win_amd64.whl → 2.15.0.dev202601151732__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/_version.py +1 -1
- kumoai/experimental/rfm/backend/snow/sampler.py +61 -20
- kumoai/experimental/rfm/backend/snow/table.py +16 -13
- kumoai/experimental/rfm/backend/sqlite/sampler.py +5 -3
- kumoai/experimental/rfm/base/mapper.py +69 -0
- kumoai/experimental/rfm/base/sampler.py +2 -1
- kumoai/experimental/rfm/base/sql_sampler.py +182 -44
- kumoai/experimental/rfm/base/table.py +3 -22
- kumoai/experimental/rfm/base/utils.py +36 -0
- kumoai/experimental/rfm/graph.py +15 -3
- kumoai/experimental/rfm/infer/dtype.py +3 -1
- kumoai/experimental/rfm/infer/time_col.py +4 -2
- kumoai/experimental/rfm/rfm.py +10 -2
- kumoai/kumolib.cp313-win_amd64.pyd +0 -0
- kumoai/testing/snow.py +3 -3
- kumoai/utils/progress_logger.py +2 -1
- kumoai/utils/sql.py +2 -2
- {kumoai-2.14.0.dev202601081732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/METADATA +2 -2
- {kumoai-2.14.0.dev202601081732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/RECORD +22 -20
- {kumoai-2.14.0.dev202601081732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/WHEEL +0 -0
- {kumoai-2.14.0.dev202601081732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.14.0.dev202601081732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/top_level.txt +0 -0
kumoai/_version.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = '2.
|
|
1
|
+
__version__ = '2.15.0.dev202601151732'
|
|
kumoai/experimental/rfm/backend/snow/sampler.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
import json
|
|
2
|
+
import math
|
|
2
3
|
from collections.abc import Iterator
|
|
3
4
|
from contextlib import contextmanager
|
|
4
|
-
from typing import TYPE_CHECKING
|
|
5
|
+
from typing import TYPE_CHECKING, cast
|
|
5
6
|
|
|
6
7
|
import numpy as np
|
|
7
8
|
import pandas as pd
|
|
@@ -11,7 +12,7 @@ from kumoapi.pquery import ValidatedPredictiveQuery
|
|
|
11
12
|
from kumoai.experimental.rfm.backend.snow import Connection, SnowTable
|
|
12
13
|
from kumoai.experimental.rfm.base import SQLSampler, Table
|
|
13
14
|
from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
|
|
14
|
-
from kumoai.utils import ProgressLogger
|
|
15
|
+
from kumoai.utils import ProgressLogger, quote_ident
|
|
15
16
|
|
|
16
17
|
if TYPE_CHECKING:
|
|
17
18
|
from kumoai.experimental.rfm import Graph
|
|
@@ -37,6 +38,15 @@ class SnowSampler(SQLSampler):
|
|
|
37
38
|
assert isinstance(table, SnowTable)
|
|
38
39
|
self._connection = table._connection
|
|
39
40
|
|
|
41
|
+
self._num_rows_dict: dict[str, int] = {
|
|
42
|
+
table.name: cast(int, table._num_rows)
|
|
43
|
+
for table in graph.tables.values()
|
|
44
|
+
}
|
|
45
|
+
|
|
46
|
+
@property
|
|
47
|
+
def num_rows_dict(self) -> dict[str, int]:
|
|
48
|
+
return self._num_rows_dict
|
|
49
|
+
|
|
40
50
|
def _get_min_max_time_dict(
|
|
41
51
|
self,
|
|
42
52
|
table_names: list[str],
|
|
@@ -45,8 +55,9 @@ class SnowSampler(SQLSampler):
|
|
|
45
55
|
for table_name in table_names:
|
|
46
56
|
column = self.time_column_dict[table_name]
|
|
47
57
|
column_ref = self.table_column_ref_dict[table_name][column]
|
|
58
|
+
ident = quote_ident(table_name, char="'")
|
|
48
59
|
select = (f"SELECT\n"
|
|
49
|
-
f"
|
|
60
|
+
f" {ident} as table_name,\n"
|
|
50
61
|
f" MIN({column_ref}) as min_date,\n"
|
|
51
62
|
f" MAX({column_ref}) as max_date\n"
|
|
52
63
|
f"FROM {self.source_name_dict[table_name]}")
|
|
@@ -54,14 +65,13 @@ class SnowSampler(SQLSampler):
|
|
|
54
65
|
sql = "\nUNION ALL\n".join(selects)
|
|
55
66
|
|
|
56
67
|
out_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
|
|
57
|
-
with
|
|
58
|
-
cursor.execute(sql
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
)
|
|
68
|
+
with self._connection.cursor() as cursor:
|
|
69
|
+
cursor.execute(sql)
|
|
70
|
+
for table_name, _min, _max in cursor.fetchall():
|
|
71
|
+
out_dict[table_name] = (
|
|
72
|
+
pd.Timestamp.max if _min is None else pd.Timestamp(_min),
|
|
73
|
+
pd.Timestamp.min if _max is None else pd.Timestamp(_max),
|
|
74
|
+
)
|
|
65
75
|
|
|
66
76
|
return out_dict
|
|
67
77
|
|
|
@@ -239,9 +249,30 @@ class SnowSampler(SQLSampler):
|
|
|
239
249
|
) -> tuple[pd.DataFrame, np.ndarray]:
|
|
240
250
|
time_column = self.time_column_dict.get(table_name)
|
|
241
251
|
|
|
252
|
+
end_time: pd.Series | None = None
|
|
253
|
+
start_time: pd.Series | None = None
|
|
242
254
|
if time_column is not None and anchor_time is not None:
|
|
243
|
-
|
|
244
|
-
|
|
255
|
+
# In order to avoid a full table scan, we limit foreign key
|
|
256
|
+
# sampling to a certain time range, approximated by the number of
|
|
257
|
+
# rows, timestamp ranges and `num_neighbors` value.
|
|
258
|
+
# Downstream, this helps Snowflake to apply partition pruning:
|
|
259
|
+
dst_table_name = [
|
|
260
|
+
dst_table
|
|
261
|
+
for key, dst_table in self.foreign_key_dict[table_name]
|
|
262
|
+
if key == foreign_key
|
|
263
|
+
][0]
|
|
264
|
+
num_facts = self.num_rows_dict[table_name]
|
|
265
|
+
num_entities = self.num_rows_dict[dst_table_name]
|
|
266
|
+
min_time = self.get_min_time([table_name])
|
|
267
|
+
max_time = self.get_max_time([table_name])
|
|
268
|
+
freq = num_facts / num_entities
|
|
269
|
+
freq = freq / max((max_time - min_time).total_seconds(), 1)
|
|
270
|
+
offset = pd.Timedelta(seconds=math.ceil(5 * num_neighbors / freq))
|
|
271
|
+
|
|
272
|
+
end_time = anchor_time.dt.strftime("%Y-%m-%d %H:%M:%S")
|
|
273
|
+
start_time = anchor_time - offset
|
|
274
|
+
start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
|
|
275
|
+
payload = json.dumps(list(zip(index, end_time, start_time)))
|
|
245
276
|
else:
|
|
246
277
|
payload = json.dumps(list(zip(index)))
|
|
247
278
|
|
|
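The comments in the hunk above describe how foreign-key sampling is bounded to a time window so that Snowflake can prune partitions. The following is a minimal, self-contained sketch of that offset calculation; the row counts and timestamps are made up, and only math and pandas are assumed:

import math
import pandas as pd

# Illustrative inputs, not taken from the package:
num_facts = 1_000_000      # rows in the fact table being sampled
num_entities = 50_000      # rows in the referenced entity table
num_neighbors = 32
min_time = pd.Timestamp('2020-01-01')
max_time = pd.Timestamp('2021-01-01')

# Average facts per entity, per second of observed history:
freq = num_facts / num_entities
freq = freq / max((max_time - min_time).total_seconds(), 1)

# Window sized to (over-)cover roughly 5x the requested neighbors:
offset = pd.Timedelta(seconds=math.ceil(5 * num_neighbors / freq))

anchor_time = pd.Series(pd.to_datetime(['2020-07-01', '2020-09-15']))
end_time = anchor_time.dt.strftime('%Y-%m-%d %H:%M:%S')
start_time = (anchor_time - offset).dt.strftime('%Y-%m-%d %H:%M:%S')
print(offset, list(zip(end_time, start_time)))

The per-row bounds feed the JSON payload, while the global end_time.max() / start_time.min() bounds added to the WHERE clause in the later hunks are what enable partition pruning.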
@@ -260,9 +291,10 @@ class SnowSampler(SQLSampler):
|
|
|
260
291
|
sql += " f.value[0]::FLOAT as __KUMO_ID__"
|
|
261
292
|
else:
|
|
262
293
|
sql += " f.value[0]::VARCHAR as __KUMO_ID__"
|
|
263
|
-
if
|
|
294
|
+
if end_time is not None and start_time is not None:
|
|
264
295
|
sql += (",\n"
|
|
265
|
-
" f.value[1]::TIMESTAMP_NTZ as
|
|
296
|
+
" f.value[1]::TIMESTAMP_NTZ as __KUMO_END_TIME__,\n"
|
|
297
|
+
" f.value[2]::TIMESTAMP_NTZ as __KUMO_START_TIME__")
|
|
266
298
|
sql += (f"\n"
|
|
267
299
|
f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
|
|
268
300
|
f")\n"
|
|
@@ -272,9 +304,13 @@ class SnowSampler(SQLSampler):
|
|
|
272
304
|
f"FROM TMP\n"
|
|
273
305
|
f"JOIN {self.source_name_dict[table_name]}\n"
|
|
274
306
|
f" ON {key_ref} = TMP.__KUMO_ID__\n")
|
|
275
|
-
if
|
|
307
|
+
if end_time is not None and start_time is not None:
|
|
308
|
+
assert time_column is not None
|
|
276
309
|
time_ref = self.table_column_ref_dict[table_name][time_column]
|
|
277
|
-
sql += f" AND {time_ref} <= TMP.
|
|
310
|
+
sql += (f" AND {time_ref} <= TMP.__KUMO_END_TIME__\n"
|
|
311
|
+
f" AND {time_ref} > TMP.__KUMO_START_TIME__\n"
|
|
312
|
+
f"WHERE {time_ref} <= '{end_time.max()}'\n"
|
|
313
|
+
f" AND {time_ref} > '{start_time.min()}'\n")
|
|
278
314
|
sql += ("QUALIFY ROW_NUMBER() OVER (\n"
|
|
279
315
|
" PARTITION BY TMP.__KUMO_BATCH__\n")
|
|
280
316
|
if time_column is not None:
|
|
@@ -313,6 +349,7 @@ class SnowSampler(SQLSampler):
|
|
|
313
349
|
|
|
314
350
|
end_time = anchor_time + max_offset
|
|
315
351
|
end_time = end_time.dt.strftime("%Y-%m-%d %H:%M:%S")
|
|
352
|
+
start_time: pd.Series | None = None
|
|
316
353
|
if min_offset is not None:
|
|
317
354
|
start_time = anchor_time + min_offset
|
|
318
355
|
start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
|
|
@@ -347,9 +384,13 @@ class SnowSampler(SQLSampler):
|
|
|
347
384
|
f"FROM TMP\n"
|
|
348
385
|
f"JOIN {self.source_name_dict[table_name]}\n"
|
|
349
386
|
f" ON {key_ref} = TMP.__KUMO_ID__\n"
|
|
350
|
-
f" AND {time_ref} <= TMP.__KUMO_END_TIME__")
|
|
351
|
-
if
|
|
352
|
-
sql += f"
|
|
387
|
+
f" AND {time_ref} <= TMP.__KUMO_END_TIME__\n")
|
|
388
|
+
if start_time is not None:
|
|
389
|
+
sql += f"AND {time_ref} > TMP.__KUMO_START_TIME__\n"
|
|
390
|
+
# Add global time bounds to enable partition pruning:
|
|
391
|
+
sql += f"WHERE {time_ref} <= '{end_time.max()}'"
|
|
392
|
+
if start_time is not None:
|
|
393
|
+
sql += f"\nAND {time_ref} > '{start_time.min()}'"
|
|
353
394
|
|
|
354
395
|
with paramstyle(self._connection), self._connection.cursor() as cursor:
|
|
355
396
|
cursor.execute(sql, (payload, ))
|
|
kumoai/experimental/rfm/backend/snow/table.py
CHANGED
|
@@ -76,21 +76,13 @@ class SnowTable(Table):
|
|
|
76
76
|
|
|
77
77
|
@property
|
|
78
78
|
def source_name(self) -> str:
|
|
79
|
-
names
|
|
80
|
-
|
|
81
|
-
names.append(self._database)
|
|
82
|
-
if self._schema is not None:
|
|
83
|
-
names.append(self._schema)
|
|
84
|
-
return '.'.join(names + [self._source_name])
|
|
79
|
+
names = [self._database, self._schema, self._source_name]
|
|
80
|
+
return '.'.join(names)
|
|
85
81
|
|
|
86
82
|
@property
|
|
87
83
|
def _quoted_source_name(self) -> str:
|
|
88
|
-
names
|
|
89
|
-
|
|
90
|
-
names.append(quote_ident(self._database))
|
|
91
|
-
if self._schema is not None:
|
|
92
|
-
names.append(quote_ident(self._schema))
|
|
93
|
-
return '.'.join(names + [quote_ident(self._source_name)])
|
|
84
|
+
names = [self._database, self._schema, self._source_name]
|
|
85
|
+
return '.'.join([quote_ident(name) for name in names])
|
|
94
86
|
|
|
95
87
|
@property
|
|
96
88
|
def backend(self) -> DataBackend:
|
|
@@ -159,7 +151,18 @@ class SnowTable(Table):
|
|
|
159
151
|
)
|
|
160
152
|
|
|
161
153
|
def _get_num_rows(self) -> int | None:
|
|
162
|
-
|
|
154
|
+
with self._connection.cursor() as cursor:
|
|
155
|
+
quoted_source_name = quote_ident(self._source_name, char="'")
|
|
156
|
+
sql = (f"SHOW TABLES LIKE {quoted_source_name} "
|
|
157
|
+
f"IN SCHEMA {quote_ident(self._database)}."
|
|
158
|
+
f"{quote_ident(self._schema)}")
|
|
159
|
+
cursor.execute(sql)
|
|
160
|
+
num_rows = cursor.fetchone()[7]
|
|
161
|
+
|
|
162
|
+
if num_rows == 0:
|
|
163
|
+
raise RuntimeError("Table '{self.source_name}' is empty")
|
|
164
|
+
|
|
165
|
+
return num_rows
|
|
163
166
|
|
|
164
167
|
def _get_expr_sample_df(
|
|
165
168
|
self,
|
|
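For orientation, here is a sketch of the statement that the new _get_num_rows assembles; the database, schema and table names are invented, and quote_ident is redefined inline so the snippet runs standalone:

def quote_ident(ident: str, char: str = '"') -> str:
    return char + ident.replace(char, char + char) + char

database, schema, source_name = 'ANALYTICS', 'PUBLIC', 'ORDERS'
table_ident = quote_ident(source_name, char="'")
sql = (f"SHOW TABLES LIKE {table_ident} "
       f"IN SCHEMA {quote_ident(database)}.{quote_ident(schema)}")
print(sql)  # SHOW TABLES LIKE 'ORDERS' IN SCHEMA "ANALYTICS"."PUBLIC"

The diff then reads the row count from column 7 of the SHOW TABLES result, avoiding a full scan of the table.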
kumoai/experimental/rfm/backend/sqlite/sampler.py
CHANGED
|
@@ -121,8 +121,9 @@ class SQLiteSampler(SQLSampler):
|
|
|
121
121
|
for table_name in table_names:
|
|
122
122
|
column = self.time_column_dict[table_name]
|
|
123
123
|
column_ref = self.table_column_ref_dict[table_name][column]
|
|
124
|
+
ident = quote_ident(table_name, char="'")
|
|
124
125
|
select = (f"SELECT\n"
|
|
125
|
-
f"
|
|
126
|
+
f" {ident} as table_name,\n"
|
|
126
127
|
f" MIN({column_ref}) as min_date,\n"
|
|
127
128
|
f" MAX({column_ref}) as max_date\n"
|
|
128
129
|
f"FROM {self.source_name_dict[table_name]}")
|
|
@@ -131,12 +132,13 @@ class SQLiteSampler(SQLSampler):
|
|
|
131
132
|
|
|
132
133
|
out_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
|
|
133
134
|
with self._connection.cursor() as cursor:
|
|
134
|
-
cursor.execute(sql
|
|
135
|
+
cursor.execute(sql)
|
|
135
136
|
for table_name, _min, _max in cursor.fetchall():
|
|
136
137
|
out_dict[table_name] = (
|
|
137
138
|
pd.Timestamp.max if _min is None else pd.Timestamp(_min),
|
|
138
139
|
pd.Timestamp.min if _max is None else pd.Timestamp(_max),
|
|
139
140
|
)
|
|
141
|
+
|
|
140
142
|
return out_dict
|
|
141
143
|
|
|
142
144
|
def _sample_entity_table(
|
|
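To make the assembled statement above concrete, this sketch mirrors the string building for two hypothetical tables and prints the combined UNION ALL query; the table and column names are invented:

def quote_ident(ident: str, char: str = '"') -> str:
    return char + ident.replace(char, char + char) + char

table_refs = {'users': 'users.created_at', 'orders': 'orders.order_date'}

selects = []
for table_name, column_ref in table_refs.items():
    ident = quote_ident(table_name, char="'")
    selects.append(f"SELECT\n"
                   f"  {ident} as table_name,\n"
                   f"  MIN({column_ref}) as min_date,\n"
                   f"  MAX({column_ref}) as max_date\n"
                   f"FROM {table_name}")
print("\nUNION ALL\n".join(selects))

Each backend then reads one (table_name, min_date, max_date) row per table from a single round trip, falling back to pd.Timestamp.max / pd.Timestamp.min when a table has no timestamps.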
@@ -434,7 +436,7 @@ class SQLiteSampler(SQLSampler):
|
|
|
434
436
|
feat_dict=feat_dict,
|
|
435
437
|
time_dict=time_dict,
|
|
436
438
|
batch_dict=batch_dict,
|
|
437
|
-
anchor_time=
|
|
439
|
+
anchor_time=time,
|
|
438
440
|
num_forecasts=query.num_forecasts,
|
|
439
441
|
)
|
|
440
442
|
ys.append(y)
|
|
kumoai/experimental/rfm/base/mapper.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import pandas as pd
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class Mapper:
|
|
6
|
+
r"""A mapper to map ``(pkey, batch)`` pairs to contiguous node IDs.
|
|
7
|
+
|
|
8
|
+
Args:
|
|
9
|
+
num_examples: The maximum number of examples to add/retrieve.
|
|
10
|
+
"""
|
|
11
|
+
def __init__(self, num_examples: int):
|
|
12
|
+
self._pkey_dtype: pd.CategoricalDtype | None = None
|
|
13
|
+
self._indices: list[np.ndarray] = []
|
|
14
|
+
self._index_dtype: pd.CategoricalDtype | None = None
|
|
15
|
+
self._num_examples = num_examples
|
|
16
|
+
|
|
17
|
+
def add(self, pkey: pd.Series, batch: np.ndarray) -> None:
|
|
18
|
+
r"""Adds a set of ``(pkey, batch)`` pairs to the mapper.
|
|
19
|
+
|
|
20
|
+
Args:
|
|
21
|
+
pkey: The primary keys.
|
|
22
|
+
batch: The batch vector.
|
|
23
|
+
"""
|
|
24
|
+
if self._pkey_dtype is not None:
|
|
25
|
+
category = np.concatenate([
|
|
26
|
+
self._pkey_dtype.categories.values,
|
|
27
|
+
pkey,
|
|
28
|
+
], axis=0)
|
|
29
|
+
category = pd.unique(category)
|
|
30
|
+
self._pkey_dtype = pd.CategoricalDtype(category)
|
|
31
|
+
elif pd.api.types.is_string_dtype(pkey):
|
|
32
|
+
category = pd.unique(pkey)
|
|
33
|
+
self._pkey_dtype = pd.CategoricalDtype(category)
|
|
34
|
+
|
|
35
|
+
if self._pkey_dtype is not None:
|
|
36
|
+
index = pd.Categorical(pkey, dtype=self._pkey_dtype).codes
|
|
37
|
+
index = index.astype('int64')
|
|
38
|
+
else:
|
|
39
|
+
index = pkey.to_numpy()
|
|
40
|
+
index = self._num_examples * index + batch
|
|
41
|
+
self._indices.append(index)
|
|
42
|
+
self._index_dtype = None
|
|
43
|
+
|
|
44
|
+
def get(self, pkey: pd.Series, batch: np.ndarray) -> np.ndarray:
|
|
45
|
+
r"""Retrieves the node IDs for a set of ``(pkey, batch)`` pairs.
|
|
46
|
+
|
|
47
|
+
Returns ``-1`` for any pair not registered in the mapping.
|
|
48
|
+
|
|
49
|
+
Args:
|
|
50
|
+
pkey: The primary keys.
|
|
51
|
+
batch: The batch vector.
|
|
52
|
+
"""
|
|
53
|
+
if len(self._indices) == 0:
|
|
54
|
+
return np.full(len(pkey), -1, dtype=np.int64)
|
|
55
|
+
|
|
56
|
+
if self._index_dtype is None: # Lazy build index:
|
|
57
|
+
category = pd.unique(np.concatenate(self._indices))
|
|
58
|
+
self._index_dtype = pd.CategoricalDtype(category)
|
|
59
|
+
|
|
60
|
+
if self._pkey_dtype is not None:
|
|
61
|
+
index = pd.Categorical(pkey, dtype=self._pkey_dtype).codes
|
|
62
|
+
index = index.astype('int64')
|
|
63
|
+
else:
|
|
64
|
+
index = pkey.to_numpy()
|
|
65
|
+
index = self._num_examples * index + batch
|
|
66
|
+
|
|
67
|
+
out = pd.Categorical(index, dtype=self._index_dtype).codes
|
|
68
|
+
out = out.astype('int64')
|
|
69
|
+
return out
|
|
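A short usage sketch of the new Mapper; the import path is taken from the sql_sampler diff below, while the keys and batch assignments are made up:

import numpy as np
import pandas as pd

from kumoai.experimental.rfm.base.mapper import Mapper

mapper = Mapper(num_examples=4)  # at most 4 examples per mini-batch
mapper.add(
    pkey=pd.Series(['a', 'b', 'c']),
    batch=np.array([0, 0, 1]),
)
# Registered (pkey, batch) pairs map to contiguous node IDs,
# while unseen pairs such as ('z', 2) map to -1:
print(mapper.get(pd.Series(['a', 'c', 'z']), np.array([0, 1, 2])))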
kumoai/experimental/rfm/base/sampler.py
CHANGED
|
@@ -295,7 +295,8 @@ class Sampler(ABC):
|
|
|
295
295
|
|
|
296
296
|
# Store in compressed representation if more efficient:
|
|
297
297
|
num_cols = subgraph.table_dict[edge_type[2]].num_rows
|
|
298
|
-
if col is not None and len(col) > num_cols + 1
|
|
298
|
+
if (col is not None and len(col) > num_cols + 1
|
|
299
|
+
and ((col[1:] - col[:-1]) >= 0).all()):
|
|
299
300
|
layout = EdgeLayout.CSC
|
|
300
301
|
colcount = np.bincount(col, minlength=num_cols)
|
|
301
302
|
col = np.empty(num_cols + 1, dtype=col.dtype)
|
|
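For context on the stricter check added above: the compressed layout only works when the column vector is sorted, which is what the new monotonicity test guards. Below is a standalone sketch of the bincount-based pointer construction; only the lines visible in this hunk are mirrored, and the final cumsum step is an assumption about how the allocated array is filled:

import numpy as np

col = np.array([0, 0, 1, 1, 1, 3])  # destination indices, non-decreasing
num_cols = 4

assert len(col) > num_cols + 1 and ((col[1:] - col[:-1]) >= 0).all()

colcount = np.bincount(col, minlength=num_cols)
colptr = np.empty(num_cols + 1, dtype=col.dtype)
colptr[0] = 0
np.cumsum(colcount, out=colptr[1:])
print(colptr)  # [0 2 5 5 6]; edges of column j live in colptr[j]:colptr[j+1]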
kumoai/experimental/rfm/base/sql_sampler.py
CHANGED
|
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Literal
|
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
import pandas as pd
|
|
7
|
+
from kumoapi.rfm.context import Subgraph
|
|
7
8
|
from kumoapi.typing import Dtype
|
|
8
9
|
|
|
9
10
|
from kumoai.experimental.rfm.base import (
|
|
@@ -12,11 +13,14 @@ from kumoai.experimental.rfm.base import (
|
|
|
12
13
|
SamplerOutput,
|
|
13
14
|
SourceColumn,
|
|
14
15
|
)
|
|
16
|
+
from kumoai.experimental.rfm.base.mapper import Mapper
|
|
15
17
|
from kumoai.utils import ProgressLogger, quote_ident
|
|
16
18
|
|
|
17
19
|
if TYPE_CHECKING:
|
|
18
20
|
from kumoai.experimental.rfm import Graph
|
|
19
21
|
|
|
22
|
+
EdgeType = tuple[str, str, str]
|
|
23
|
+
|
|
20
24
|
|
|
21
25
|
class SQLSampler(Sampler):
|
|
22
26
|
def __init__(
|
|
@@ -101,7 +105,8 @@ class SQLSampler(Sampler):
|
|
|
101
105
|
num_neighbors: list[int],
|
|
102
106
|
) -> SamplerOutput:
|
|
103
107
|
|
|
104
|
-
# Make sure to include primary key, foreign key and time columns
|
|
108
|
+
# Make sure to always include primary key, foreign key and time columns
|
|
109
|
+
# during data fetching since these are needed for graph traversal:
|
|
105
110
|
sample_columns_dict: dict[str, set[str]] = {}
|
|
106
111
|
for table, columns in columns_dict.items():
|
|
107
112
|
sample_columns = columns | {
|
|
@@ -110,9 +115,11 @@ class SQLSampler(Sampler):
|
|
|
110
115
|
}
|
|
111
116
|
if primary_key := self.primary_key_dict.get(table):
|
|
112
117
|
sample_columns |= {primary_key}
|
|
113
|
-
if time_column := self.time_column_dict.get(table):
|
|
114
|
-
sample_columns |= {time_column}
|
|
115
118
|
sample_columns_dict[table] = sample_columns
|
|
119
|
+
if not isinstance(anchor_time, pd.Series):
|
|
120
|
+
sample_columns_dict[entity_table_name] |= {
|
|
121
|
+
self.time_column_dict[entity_table_name]
|
|
122
|
+
}
|
|
116
123
|
|
|
117
124
|
# Sample Entity Table #################################################
|
|
118
125
|
|
|
@@ -139,88 +146,219 @@ class SQLSampler(Sampler):
|
|
|
139
146
|
anchor_time = df[time_column]
|
|
140
147
|
assert isinstance(anchor_time, pd.Series)
|
|
141
148
|
|
|
142
|
-
df_hop_dict: dict[tuple[str, int], pd.DataFrame] = {
|
|
143
|
-
(entity_table_name, 0): df,
|
|
144
|
-
}
|
|
145
|
-
batch_hop_dict: dict[tuple[str, int], np.ndarray] = {
|
|
146
|
-
(entity_table_name, 0): batch,
|
|
147
|
-
}
|
|
148
|
-
|
|
149
149
|
# Recursive Neighbor Sampling #########################################
|
|
150
150
|
|
|
151
|
+
mapper_dict: dict[str, Mapper] = defaultdict(
|
|
152
|
+
lambda: Mapper(num_examples=len(entity_pkey)))
|
|
153
|
+
mapper_dict[entity_table_name].add(
|
|
154
|
+
pkey=df[self.primary_key_dict[entity_table_name]],
|
|
155
|
+
batch=batch,
|
|
156
|
+
)
|
|
157
|
+
|
|
158
|
+
dfs_dict: dict[str, list[pd.DataFrame]] = defaultdict(list)
|
|
159
|
+
dfs_dict[entity_table_name].append(df)
|
|
160
|
+
batches_dict: dict[str, list[np.ndarray]] = defaultdict(list)
|
|
161
|
+
batches_dict[entity_table_name].append(batch)
|
|
162
|
+
num_sampled_nodes_dict: dict[str, list[int]] = defaultdict(
|
|
163
|
+
lambda: [0] * (len(num_neighbors) + 1))
|
|
164
|
+
num_sampled_nodes_dict[entity_table_name][0] = len(entity_pkey)
|
|
165
|
+
|
|
166
|
+
rows_dict: dict[EdgeType, list[np.ndarray]] = defaultdict(list)
|
|
167
|
+
cols_dict: dict[EdgeType, list[np.ndarray]] = defaultdict(list)
|
|
168
|
+
num_sampled_edges_dict: dict[EdgeType, list[int]] = defaultdict(
|
|
169
|
+
lambda: [0] * len(num_neighbors))
|
|
170
|
+
|
|
171
|
+
# The start index of data frame slices of the previous hop:
|
|
172
|
+
offset_dict: dict[str, int] = defaultdict(int)
|
|
173
|
+
|
|
151
174
|
for hop, neighbors in enumerate(num_neighbors):
|
|
152
175
|
if neighbors == 0:
|
|
153
176
|
break # Abort early.
|
|
154
177
|
|
|
155
|
-
|
|
156
|
-
|
|
178
|
+
for table in list(num_sampled_nodes_dict.keys()):
|
|
179
|
+
# Only sample from tables that have been visited in the
|
|
180
|
+
# previous hop:
|
|
181
|
+
if num_sampled_nodes_dict[table][hop] == 0:
|
|
182
|
+
continue
|
|
157
183
|
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
184
|
+
# Collect the slices of data sampled in the previous hop
|
|
185
|
+
# (but maintain only required key information):
|
|
186
|
+
cols = [fkey for fkey, _ in self.foreign_key_dict[table]]
|
|
187
|
+
if table in self.primary_key_dict:
|
|
188
|
+
cols.append(self.primary_key_dict[table])
|
|
189
|
+
dfs = [df[cols] for df in dfs_dict[table][offset_dict[table]:]]
|
|
190
|
+
df = pd.concat(
|
|
191
|
+
dfs,
|
|
192
|
+
axis=0,
|
|
193
|
+
ignore_index=True,
|
|
194
|
+
) if len(dfs) > 1 else dfs[0]
|
|
195
|
+
batches = batches_dict[table][offset_dict[table]:]
|
|
196
|
+
batch = (np.concatenate(batches)
|
|
197
|
+
if len(batches) > 1 else batches[0])
|
|
198
|
+
offset_dict[table] = len(batches_dict[table]) # Increase.
|
|
199
|
+
|
|
200
|
+
pkey: pd.Series | None = None
|
|
201
|
+
index: np.ndarray | None = None
|
|
202
|
+
if table in self.primary_key_dict:
|
|
203
|
+
pkey = df[self.primary_key_dict[table]]
|
|
204
|
+
index = mapper_dict[table].get(pkey, batch)
|
|
162
205
|
|
|
163
206
|
# Iterate over foreign keys in the current table:
|
|
164
207
|
for fkey, dst_table in self.foreign_key_dict[table]:
|
|
165
|
-
|
|
208
|
+
row = mapper_dict[dst_table].get(df[fkey], batch)
|
|
209
|
+
mask = row == -1
|
|
210
|
+
if mask.any():
|
|
211
|
+
key_df = pd.DataFrame({
|
|
212
|
+
'fkey': df[fkey],
|
|
213
|
+
'batch': batch,
|
|
214
|
+
}).iloc[mask]
|
|
215
|
+
# Only maintain unique keys per example:
|
|
216
|
+
unique_key_df = key_df.drop_duplicates()
|
|
217
|
+
# Fully de-duplicate keys across examples:
|
|
218
|
+
code, fkey_index = pd.factorize(unique_key_df['fkey'])
|
|
219
|
+
|
|
220
|
+
_df, _batch = self._by_pkey(
|
|
221
|
+
table_name=dst_table,
|
|
222
|
+
index=fkey_index,
|
|
223
|
+
columns=sample_columns_dict[dst_table],
|
|
224
|
+
) # Ensure result is sorted according to input order:
|
|
225
|
+
_df = _df.iloc[_batch.argsort()]
|
|
226
|
+
|
|
227
|
+
# Compute valid entries (without dangling foreign keys)
|
|
228
|
+
# in `unique_fkey_df`:
|
|
229
|
+
_mask = np.full(len(fkey_index), fill_value=False)
|
|
230
|
+
_mask[_batch] = True
|
|
231
|
+
_mask = _mask[code]
|
|
232
|
+
|
|
233
|
+
# Reconstruct unique (key, batch) pairs:
|
|
234
|
+
code, _ = pd.factorize(unique_key_df['fkey'][_mask])
|
|
235
|
+
_df = _df.iloc[code].reset_index(drop=True)
|
|
236
|
+
_batch = unique_key_df['batch'].to_numpy()[_mask]
|
|
237
|
+
|
|
238
|
+
# Register node IDs:
|
|
239
|
+
mapper_dict[dst_table].add(
|
|
240
|
+
pkey=_df[self.primary_key_dict[dst_table]],
|
|
241
|
+
batch=_batch,
|
|
242
|
+
)
|
|
243
|
+
row[mask] = mapper_dict[dst_table].get(
|
|
244
|
+
pkey=key_df['fkey'],
|
|
245
|
+
batch=key_df['batch'].to_numpy(),
|
|
246
|
+
) # NOTE `row` may still hold `-1` for dangling fkeys.
|
|
247
|
+
|
|
248
|
+
dfs_dict[dst_table].append(_df)
|
|
249
|
+
batches_dict[dst_table].append(_batch)
|
|
250
|
+
num_sampled_nodes_dict[dst_table][hop + 1] += ( #
|
|
251
|
+
len(_batch))
|
|
252
|
+
|
|
253
|
+
mask = row != -1
|
|
254
|
+
|
|
255
|
+
col = index
|
|
256
|
+
if col is None:
|
|
257
|
+
start = sum(num_sampled_nodes_dict[table][:hop])
|
|
258
|
+
end = sum(num_sampled_nodes_dict[table][:hop + 1])
|
|
259
|
+
col = np.arange(start, end)
|
|
260
|
+
|
|
261
|
+
row = row[mask]
|
|
262
|
+
col = col[mask]
|
|
263
|
+
|
|
264
|
+
edge_type = (table, fkey, dst_table)
|
|
265
|
+
edge_type = Subgraph.rev_edge_type(edge_type)
|
|
266
|
+
rows_dict[edge_type].append(row)
|
|
267
|
+
cols_dict[edge_type].append(col)
|
|
268
|
+
num_sampled_edges_dict[edge_type][hop] = len(col)
|
|
166
269
|
|
|
167
270
|
# Iterate over foreign keys that reference the current table:
|
|
168
271
|
for src_table, fkey in self.rev_foreign_key_dict[table]:
|
|
272
|
+
assert pkey is not None and index is not None
|
|
169
273
|
_df, _batch = self._by_fkey(
|
|
170
274
|
table_name=src_table,
|
|
171
275
|
foreign_key=fkey,
|
|
172
|
-
index=
|
|
276
|
+
index=pkey,
|
|
173
277
|
num_neighbors=neighbors,
|
|
174
278
|
anchor_time=anchor_time.iloc[batch],
|
|
175
279
|
columns=sample_columns_dict[src_table],
|
|
176
280
|
)
|
|
177
|
-
_batch = batch[_batch]
|
|
178
281
|
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
282
|
+
edge_type = (src_table, fkey, table)
|
|
283
|
+
cols_dict[edge_type].append(index[_batch])
|
|
284
|
+
num_sampled_edges_dict[edge_type][hop] = len(_batch)
|
|
182
285
|
|
|
183
|
-
|
|
286
|
+
_batch = batch[_batch]
|
|
287
|
+
num_nodes = sum(num_sampled_nodes_dict[src_table])
|
|
288
|
+
if src_table in self.primary_key_dict:
|
|
289
|
+
_pkey = _df[self.primary_key_dict[src_table]]
|
|
290
|
+
mapper_dict[src_table].add(_pkey, _batch)
|
|
291
|
+
row = mapper_dict[src_table].get(_pkey, _batch)
|
|
292
|
+
|
|
293
|
+
# Only preserve unknown rows:
|
|
294
|
+
mask = row >= num_nodes # type: ignore
|
|
295
|
+
mask[pd.duplicated(row)] = False
|
|
296
|
+
_df = _df.iloc[mask]
|
|
297
|
+
_batch = _batch[mask]
|
|
298
|
+
else:
|
|
299
|
+
row = np.arange(num_nodes, num_nodes + len(_batch))
|
|
300
|
+
|
|
301
|
+
rows_dict[edge_type].append(row)
|
|
302
|
+
num_sampled_nodes_dict[src_table][hop + 1] += len(_batch)
|
|
303
|
+
|
|
304
|
+
dfs_dict[src_table].append(_df)
|
|
305
|
+
batches_dict[src_table].append(_batch)
|
|
184
306
|
|
|
185
307
|
# Post-Processing #####################################################
|
|
186
308
|
|
|
187
|
-
|
|
188
|
-
batches_dict: dict[str, list[np.ndarray]] = defaultdict(list)
|
|
189
|
-
num_hops = max(hop for _, hop in df_hop_dict.keys()) # TODO
|
|
190
|
-
num_sampled_nodes_dict: dict[str, list[int]] = {
|
|
191
|
-
table: [0] * (num_hops + 1)
|
|
192
|
-
for table in [table for table, _ in df_hop_dict.keys()]
|
|
193
|
-
}
|
|
194
|
-
for (table, hop), df in df_hop_dict.items():
|
|
195
|
-
dfs_dict[table].append(df)
|
|
196
|
-
batches_dict[table].append(batch_hop_dict[(table, hop)])
|
|
197
|
-
num_sampled_nodes_dict[table][hop] = len(df)
|
|
198
|
-
|
|
199
|
-
df_dict = { # Concatenate data frames across hops:
|
|
309
|
+
df_dict = {
|
|
200
310
|
table:
|
|
201
311
|
pd.concat(dfs, axis=0, ignore_index=True)
|
|
202
312
|
if len(dfs) > 1 else dfs[0]
|
|
203
313
|
for table, dfs in dfs_dict.items()
|
|
204
314
|
}
|
|
315
|
+
|
|
316
|
+
# Only store unique rows in `df` above a certain threshold:
|
|
317
|
+
inverse_dict: dict[str, np.ndarray] = {}
|
|
318
|
+
for table, df in df_dict.items():
|
|
319
|
+
if table not in self.primary_key_dict:
|
|
320
|
+
continue
|
|
321
|
+
unique, index, inverse = np.unique(
|
|
322
|
+
df_dict[table][self.primary_key_dict[table]],
|
|
323
|
+
return_index=True,
|
|
324
|
+
return_inverse=True,
|
|
325
|
+
)
|
|
326
|
+
if len(df) > 1.05 * len(unique):
|
|
327
|
+
df_dict[table] = df.iloc[index].reset_index(drop=True)
|
|
328
|
+
inverse_dict[table] = inverse
|
|
329
|
+
|
|
205
330
|
df_dict = { # Post-filter column set:
|
|
206
331
|
table: df[list(columns_dict[table])]
|
|
207
|
-
for
|
|
332
|
+
for table, df in df_dict.items()
|
|
208
333
|
}
|
|
209
|
-
batch_dict = {
|
|
210
|
-
table:
|
|
211
|
-
np.concatenate(batches, axis=0) if len(batches) > 1 else batches[0]
|
|
334
|
+
batch_dict = {
|
|
335
|
+
table: np.concatenate(batches) if len(batches) > 1 else batches[0]
|
|
212
336
|
for table, batches in batches_dict.items()
|
|
213
337
|
}
|
|
338
|
+
row_dict = {
|
|
339
|
+
edge_type: np.concatenate(rows)
|
|
340
|
+
for edge_type, rows in rows_dict.items()
|
|
341
|
+
}
|
|
342
|
+
col_dict = {
|
|
343
|
+
edge_type: np.concatenate(cols)
|
|
344
|
+
for edge_type, cols in cols_dict.items()
|
|
345
|
+
}
|
|
346
|
+
|
|
347
|
+
if len(num_sampled_edges_dict) == 0: # Single table:
|
|
348
|
+
num_sampled_nodes_dict = {
|
|
349
|
+
key: value[:1]
|
|
350
|
+
for key, value in num_sampled_nodes_dict.items()
|
|
351
|
+
}
|
|
214
352
|
|
|
215
353
|
return SamplerOutput(
|
|
216
354
|
anchor_time=anchor_time.astype(int).to_numpy(),
|
|
217
355
|
df_dict=df_dict,
|
|
218
|
-
inverse_dict=
|
|
356
|
+
inverse_dict=inverse_dict,
|
|
219
357
|
batch_dict=batch_dict,
|
|
220
358
|
num_sampled_nodes_dict=num_sampled_nodes_dict,
|
|
221
|
-
row_dict=
|
|
222
|
-
col_dict=
|
|
223
|
-
num_sampled_edges_dict=
|
|
359
|
+
row_dict=row_dict,
|
|
360
|
+
col_dict=col_dict,
|
|
361
|
+
num_sampled_edges_dict=num_sampled_edges_dict,
|
|
224
362
|
)
|
|
225
363
|
|
|
226
364
|
# Abstract Methods ########################################################
|
|
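A compact sketch of the de-duplication step introduced at the end of the new sampling loop: rows of a table are only compacted when doing so saves more than roughly 5%, and the inverse vector keeps the original row order recoverable. The column values here are invented:

import numpy as np
import pandas as pd

df = pd.DataFrame({
    'user_id': [7, 3, 7, 7, 3, 9],
    'age':     [31, 25, 31, 31, 25, 40],
})

unique, index, inverse = np.unique(
    df['user_id'], return_index=True, return_inverse=True)

if len(df) > 1.05 * len(unique):
    df = df.iloc[index].reset_index(drop=True)  # one row per primary key
    # `inverse` maps every original row back to its compacted row:
    assert list(df['user_id'].to_numpy()[inverse]) == [7, 3, 7, 7, 3, 9]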
kumoai/experimental/rfm/base/table.py
CHANGED
|
@@ -1,11 +1,9 @@
|
|
|
1
|
-
import warnings
|
|
2
1
|
from abc import ABC, abstractmethod
|
|
3
2
|
from collections.abc import Sequence
|
|
4
3
|
from functools import cached_property
|
|
5
4
|
|
|
6
5
|
import numpy as np
|
|
7
6
|
import pandas as pd
|
|
8
|
-
import pyarrow as pa
|
|
9
7
|
from kumoapi.model_plan import MissingType
|
|
10
8
|
from kumoapi.source_table import UnavailableSourceTable
|
|
11
9
|
from kumoapi.table import Column as ColumnDefinition
|
|
@@ -21,6 +19,7 @@ from kumoai.experimental.rfm.base import (
|
|
|
21
19
|
SourceColumn,
|
|
22
20
|
SourceForeignKey,
|
|
23
21
|
)
|
|
22
|
+
from kumoai.experimental.rfm.base.utils import to_datetime
|
|
24
23
|
from kumoai.experimental.rfm.infer import (
|
|
25
24
|
infer_dtype,
|
|
26
25
|
infer_primary_key,
|
|
@@ -624,24 +623,6 @@ class Table(ABC):
|
|
|
624
623
|
r"""Sanitzes a :class:`pandas.DataFrame` in-place such that its data
|
|
625
624
|
types match table data and semantic type specification.
|
|
626
625
|
"""
|
|
627
|
-
def _to_datetime(ser: pd.Series) -> pd.Series:
|
|
628
|
-
if (not pd.api.types.is_datetime64_any_dtype(ser)
|
|
629
|
-
and not (isinstance(ser.dtype, pd.ArrowDtype) and
|
|
630
|
-
pa.types.is_timestamp(ser.dtype.pyarrow_dtype))):
|
|
631
|
-
with warnings.catch_warnings():
|
|
632
|
-
warnings.filterwarnings(
|
|
633
|
-
'ignore',
|
|
634
|
-
message='Could not infer format',
|
|
635
|
-
)
|
|
636
|
-
ser = pd.to_datetime(ser, errors='coerce')
|
|
637
|
-
if (isinstance(ser.dtype, pd.DatetimeTZDtype)
|
|
638
|
-
or (isinstance(ser.dtype, pd.ArrowDtype)
|
|
639
|
-
and ser.dtype.pyarrow_dtype.tz is not None)):
|
|
640
|
-
ser = ser.dt.tz_localize(None)
|
|
641
|
-
if ser.dtype != 'datetime64[ns]':
|
|
642
|
-
ser = ser.astype('datetime64[ns]')
|
|
643
|
-
return ser
|
|
644
|
-
|
|
645
626
|
def _to_list(ser: pd.Series, dtype: Dtype | None) -> pd.Series:
|
|
646
627
|
if (pd.api.types.is_string_dtype(ser)
|
|
647
628
|
and dtype in {Dtype.intlist, Dtype.floatlist}):
|
|
@@ -672,9 +653,9 @@ class Table(ABC):
|
|
|
672
653
|
stype = (stype_dict or {}).get(column_name)
|
|
673
654
|
|
|
674
655
|
if dtype == Dtype.time:
|
|
675
|
-
df[column_name] =
|
|
656
|
+
df[column_name] = to_datetime(df[column_name])
|
|
676
657
|
elif stype == Stype.timestamp:
|
|
677
|
-
df[column_name] =
|
|
658
|
+
df[column_name] = to_datetime(df[column_name])
|
|
678
659
|
elif dtype is not None and dtype.is_list():
|
|
679
660
|
df[column_name] = _to_list(df[column_name], dtype)
|
|
680
661
|
elif stype == Stype.sequence:
|
|
kumoai/experimental/rfm/base/utils.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
|
|
3
|
+
import pandas as pd
|
|
4
|
+
import pyarrow as pa
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def is_datetime(ser: pd.Series) -> bool:
|
|
8
|
+
r"""Check whether a :class:`pandas.Series` holds datetime values."""
|
|
9
|
+
if isinstance(ser.dtype, pd.ArrowDtype):
|
|
10
|
+
dtype = ser.dtype.pyarrow_dtype
|
|
11
|
+
return (pa.types.is_timestamp(dtype) or pa.types.is_date(dtype)
|
|
12
|
+
or pa.types.is_time(dtype))
|
|
13
|
+
|
|
14
|
+
return pd.api.types.is_datetime64_any_dtype(ser)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def to_datetime(ser: pd.Series) -> pd.Series:
|
|
18
|
+
"""Converts a :class:`pandas.Series` to ``datetime64[ns]`` format."""
|
|
19
|
+
if isinstance(ser.dtype, pd.ArrowDtype):
|
|
20
|
+
ser = pd.Series(ser.to_numpy(), index=ser.index, name=ser.name)
|
|
21
|
+
|
|
22
|
+
if not pd.api.types.is_datetime64_any_dtype(ser):
|
|
23
|
+
with warnings.catch_warnings():
|
|
24
|
+
warnings.filterwarnings(
|
|
25
|
+
'ignore',
|
|
26
|
+
message='Could not infer format',
|
|
27
|
+
)
|
|
28
|
+
ser = pd.to_datetime(ser, errors='coerce')
|
|
29
|
+
|
|
30
|
+
if isinstance(ser.dtype, pd.DatetimeTZDtype):
|
|
31
|
+
ser = ser.dt.tz_localize(None)
|
|
32
|
+
|
|
33
|
+
if ser.dtype != 'datetime64[ns]':
|
|
34
|
+
ser = ser.astype('datetime64[ns]')
|
|
35
|
+
|
|
36
|
+
return ser
|
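A quick usage sketch of the two new helpers; the sample values are arbitrary:

import pandas as pd

from kumoai.experimental.rfm.base.utils import is_datetime, to_datetime

ser = pd.Series(['2024-01-01', '2024-02-15', 'not a date'])
print(is_datetime(ser))  # False: still plain strings
out = to_datetime(ser)   # invalid entries are coerced to NaT
print(out.dtype)         # datetime64[ns]
print(is_datetime(out))  # True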
kumoai/experimental/rfm/graph.py
CHANGED
|
@@ -19,6 +19,7 @@ from typing_extensions import Self
|
|
|
19
19
|
|
|
20
20
|
from kumoai import in_notebook, in_snowflake_notebook
|
|
21
21
|
from kumoai.experimental.rfm.base import ColumnSpec, DataBackend, Table
|
|
22
|
+
from kumoai.experimental.rfm.infer import infer_time_column
|
|
22
23
|
from kumoai.graph import Edge
|
|
23
24
|
from kumoai.mixin import CastMixin
|
|
24
25
|
from kumoai.utils import display
|
|
@@ -415,8 +416,9 @@ class Graph:
|
|
|
415
416
|
assert isinstance(connection, Connection)
|
|
416
417
|
|
|
417
418
|
with connection.cursor() as cursor:
|
|
418
|
-
|
|
419
|
-
|
|
419
|
+
sql = (f"SELECT SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW("
|
|
420
|
+
f"'{semantic_view_name}')")
|
|
421
|
+
cursor.execute(sql)
|
|
420
422
|
cfg = yaml.safe_load(cursor.fetchone()[0])
|
|
421
423
|
|
|
422
424
|
graph = cls(tables=[])
|
|
@@ -492,7 +494,17 @@ class Graph:
|
|
|
492
494
|
)
|
|
493
495
|
|
|
494
496
|
# TODO Add a way to register time columns without heuristic usage.
|
|
495
|
-
|
|
497
|
+
time_candidates = [
|
|
498
|
+
column_cfg['name']
|
|
499
|
+
for column_cfg in table_cfg.get('time_dimensions', [])
|
|
500
|
+
if table.has_column(column_cfg['name'])
|
|
501
|
+
and table[column_cfg['name']].stype == Stype.timestamp
|
|
502
|
+
]
|
|
503
|
+
if time_column := infer_time_column(
|
|
504
|
+
df=table._get_sample_df(),
|
|
505
|
+
candidates=time_candidates,
|
|
506
|
+
):
|
|
507
|
+
table.time_column = time_column
|
|
496
508
|
|
|
497
509
|
graph.add_table(table)
|
|
498
510
|
|
|
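A small sketch of how the heuristic above is invoked for the filtered candidates; the frame and column names are made up, and only the df and candidates arguments visible in this diff are used:

import pandas as pd

from kumoai.experimental.rfm.infer import infer_time_column

df = pd.DataFrame({
    'created_at': pd.to_datetime(['2024-01-05', '2024-03-01']),
    'updated_at': pd.to_datetime(['2024-05-01', '2024-06-10']),
})
# Returns one of the candidate column names (or a falsy value if none fits):
print(infer_time_column(df=df, candidates=['created_at', 'updated_at']))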
kumoai/experimental/rfm/infer/dtype.py
CHANGED
|
@@ -3,6 +3,8 @@ import pandas as pd
|
|
|
3
3
|
import pyarrow as pa
|
|
4
4
|
from kumoapi.typing import Dtype
|
|
5
5
|
|
|
6
|
+
from kumoai.experimental.rfm.base.utils import is_datetime
|
|
7
|
+
|
|
6
8
|
PANDAS_TO_DTYPE: dict[str, Dtype] = {
|
|
7
9
|
'bool': Dtype.bool,
|
|
8
10
|
'boolean': Dtype.bool,
|
|
@@ -34,7 +36,7 @@ def infer_dtype(ser: pd.Series) -> Dtype:
|
|
|
34
36
|
Returns:
|
|
35
37
|
The data type.
|
|
36
38
|
"""
|
|
37
|
-
if
|
|
39
|
+
if is_datetime(ser):
|
|
38
40
|
return Dtype.date
|
|
39
41
|
if pd.api.types.is_timedelta64_dtype(ser.dtype):
|
|
40
42
|
return Dtype.timedelta
|
|
kumoai/experimental/rfm/infer/time_col.py
CHANGED
|
@@ -3,6 +3,8 @@ import warnings
|
|
|
3
3
|
|
|
4
4
|
import pandas as pd
|
|
5
5
|
|
|
6
|
+
from kumoai.experimental.rfm.base.utils import to_datetime
|
|
7
|
+
|
|
6
8
|
|
|
7
9
|
def infer_time_column(
|
|
8
10
|
df: pd.DataFrame,
|
|
@@ -43,11 +45,11 @@ def infer_time_column(
|
|
|
43
45
|
with warnings.catch_warnings():
|
|
44
46
|
warnings.filterwarnings('ignore', message='Could not infer format')
|
|
45
47
|
min_timestamp_dict = {
|
|
46
|
-
key:
|
|
48
|
+
key: to_datetime(df[key].iloc[:10_000])
|
|
47
49
|
for key in candidates
|
|
48
50
|
}
|
|
49
51
|
min_timestamp_dict = {
|
|
50
|
-
key: value.min()
|
|
52
|
+
key: value.min()
|
|
51
53
|
for key, value in min_timestamp_dict.items()
|
|
52
54
|
}
|
|
53
55
|
min_timestamp_dict = {
|
kumoai/experimental/rfm/rfm.py
CHANGED
|
@@ -1044,8 +1044,16 @@ class KumoRFM:
|
|
|
1044
1044
|
if len(self._sampler.time_column_dict) == 0:
|
|
1045
1045
|
return # Graph without timestamps
|
|
1046
1046
|
|
|
1047
|
-
|
|
1048
|
-
|
|
1047
|
+
if query.query_type == QueryType.TEMPORAL:
|
|
1048
|
+
aggr_table_names = [
|
|
1049
|
+
aggr._get_target_column_name().split('.')[0]
|
|
1050
|
+
for aggr in query.get_all_target_aggregations()
|
|
1051
|
+
]
|
|
1052
|
+
min_time = self._sampler.get_min_time(aggr_table_names)
|
|
1053
|
+
max_time = self._sampler.get_max_time(aggr_table_names)
|
|
1054
|
+
else:
|
|
1055
|
+
min_time = self._sampler.get_min_time()
|
|
1056
|
+
max_time = self._sampler.get_max_time()
|
|
1049
1057
|
|
|
1050
1058
|
if anchor_time < min_time:
|
|
1051
1059
|
raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
|
|
kumoai/kumolib.cp313-win_amd64.pyd
CHANGED
|
Binary file
|
kumoai/testing/snow.py
CHANGED
|
@@ -10,7 +10,7 @@ def connect(
|
|
|
10
10
|
id: str,
|
|
11
11
|
account: str,
|
|
12
12
|
user: str,
|
|
13
|
-
warehouse: str,
|
|
13
|
+
warehouse: str | None = None,
|
|
14
14
|
database: str | None = None,
|
|
15
15
|
schema: str | None = None,
|
|
16
16
|
) -> Connection:
|
|
@@ -42,8 +42,8 @@ def connect(
|
|
|
42
42
|
return _connect(
|
|
43
43
|
account=account,
|
|
44
44
|
user=user,
|
|
45
|
-
warehouse='WH_XS',
|
|
46
|
-
database='KUMO',
|
|
45
|
+
warehouse=warehouse or 'WH_XS',
|
|
46
|
+
database=database or 'KUMO',
|
|
47
47
|
schema=schema,
|
|
48
48
|
session_parameters=dict(CLIENT_TELEMETRY_ENABLED=False),
|
|
49
49
|
**kwargs,
|
kumoai/utils/progress_logger.py
CHANGED
|
@@ -57,7 +57,8 @@ class ProgressLogger:
|
|
|
57
57
|
|
|
58
58
|
def __enter__(self) -> Self:
|
|
59
59
|
self.depth += 1
|
|
60
|
-
self.
|
|
60
|
+
if self.depth == 1:
|
|
61
|
+
self.start_time = time.perf_counter()
|
|
61
62
|
return self
|
|
62
63
|
|
|
63
64
|
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
kumoai/utils/sql.py
CHANGED
|
@@ -1,3 +1,3 @@
|
|
|
1
|
-
def quote_ident(
|
|
1
|
+
def quote_ident(ident: str, char: str = '"') -> str:
|
|
2
2
|
r"""Quotes a SQL identifier."""
|
|
3
|
-
return
|
|
3
|
+
return char + ident.replace(char, char + char) + char
|
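The rewritten helper doubles any embedded quote character, which is the standard way to escape quoted identifiers in SQL. A few illustrative calls:

from kumoai.utils import quote_ident

print(quote_ident('users'))             # "users"
print(quote_ident('my "table"'))        # "my ""table"""
print(quote_ident('events', char="'"))  # 'events'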
{kumoai-2.14.0.dev202601081732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/METADATA
RENAMED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: kumoai
|
|
3
|
-
Version: 2.
|
|
3
|
+
Version: 2.15.0.dev202601151732
|
|
4
4
|
Summary: AI on the Modern Data Stack
|
|
5
5
|
Author-email: "Kumo.AI" <hello@kumo.ai>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -23,7 +23,7 @@ Requires-Dist: requests>=2.28.2
|
|
|
23
23
|
Requires-Dist: urllib3
|
|
24
24
|
Requires-Dist: plotly
|
|
25
25
|
Requires-Dist: typing_extensions>=4.5.0
|
|
26
|
-
Requires-Dist: kumo-api
|
|
26
|
+
Requires-Dist: kumo-api<1.0.0,>=0.53.0
|
|
27
27
|
Requires-Dist: tqdm>=4.66.0
|
|
28
28
|
Requires-Dist: aiohttp>=3.10.0
|
|
29
29
|
Requires-Dist: pydantic>=1.10.21
|
|
{kumoai-2.14.0.dev202601081732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/RECORD
RENAMED
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
kumoai/__init__.py,sha256=cKL7QeT-b5OHi75jtvFzbIKGjeJV5Tago7jKLX0nuYE,11207
|
|
2
2
|
kumoai/_logging.py,sha256=qL4JbMQwKXri2f-SEJoFB8TY5ALG12S-nobGTNWxW-A,915
|
|
3
3
|
kumoai/_singleton.py,sha256=i2BHWKpccNh5SJGDyU0IXsnYzJAYr8Xb0wz4c6LRbpo,861
|
|
4
|
-
kumoai/_version.py,sha256=
|
|
4
|
+
kumoai/_version.py,sha256=IPXsgCkME-eTqhUBmuva76Ngg8kOKbQ3VUDkC6t6dI4,39
|
|
5
5
|
kumoai/databricks.py,sha256=ahwJz6DWLXMkndT0XwEDBxF-hoqhidFR8wBUQ4TLZ68,490
|
|
6
6
|
kumoai/exceptions.py,sha256=7TMs0SC8xrU009_Pgd4QXtSF9lxJq8MtRbeX9pcQUy4,859
|
|
7
7
|
kumoai/formatting.py,sha256=o3uCnLwXPhe1KI5WV9sBgRrcU7ed4rgu_pf89GL9Nc0,983
|
|
8
8
|
kumoai/futures.py,sha256=J8rtZMEYFzdn5xF_x-LAiKJz3KGL6PT02f6rq_2bOJk,3836
|
|
9
9
|
kumoai/jobs.py,sha256=dCi7BAdfm2tCnonYlGU4WJokJWbh3RzFfaOX2EYCIHU,2576
|
|
10
|
-
kumoai/kumolib.cp313-win_amd64.pyd,sha256=
|
|
10
|
+
kumoai/kumolib.cp313-win_amd64.pyd,sha256=KuxCQKoXH9eksQws8WB2LImapu-jOY0d42huAGFInoQ,198144
|
|
11
11
|
kumoai/mixin.py,sha256=IaiB8SAI0VqOoMVzzIaUlqMt53-QPUK6OB0HikG-V9E,840
|
|
12
12
|
kumoai/spcs.py,sha256=KWfENrwSLruprlD-QPh63uU0N6npiNrwkeKfBk3EUyQ,4260
|
|
13
13
|
kumoai/artifact_export/__init__.py,sha256=UXAQI5q92ChBzWAk8o3J6pElzYHudAzFZssQXd4o7i8,247
|
|
@@ -55,9 +55,9 @@ kumoai/encoder/__init__.py,sha256=8FeP6mUyCeXxr1b8kUIi5dxe5vEXQRft9tPoaV1CBqg,18
|
|
|
55
55
|
kumoai/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
56
56
|
kumoai/experimental/rfm/__init__.py,sha256=dibc0t7g-PYanT90TncRlceD0ZqxtKStVdzzG1_cXC8,7226
|
|
57
57
|
kumoai/experimental/rfm/authenticate.py,sha256=odKaqOAEkdC_wB340cs_ozjSvQLTce45WLiJSEzQaL8,19283
|
|
58
|
-
kumoai/experimental/rfm/graph.py,sha256=
|
|
58
|
+
kumoai/experimental/rfm/graph.py,sha256=DiJZaEXiwNB7DzujRc9Fo__8u19VAsz7VagjmSKScVQ,48106
|
|
59
59
|
kumoai/experimental/rfm/relbench.py,sha256=30O7QAKYcMgr6C9Qpgev7gxSMAtWXop25p7DtmzrBlE,2352
|
|
60
|
-
kumoai/experimental/rfm/rfm.py,sha256=
|
|
60
|
+
kumoai/experimental/rfm/rfm.py,sha256=l31iaWoDujjmPilzTbh8BL_Ajlvpg4TSTYdnkbelsIg,61436
|
|
61
61
|
kumoai/experimental/rfm/sagemaker.py,sha256=7Yk4um0gBBn7u-Bz8JRv53z0__FcD0uESoiImJhxsBw,5101
|
|
62
62
|
kumoai/experimental/rfm/task_table.py,sha256=4sx9z6JhHQVQaPAlbyfDwbyOBApOUs6SEXHHcfsdxl0,10139
|
|
63
63
|
kumoai/experimental/rfm/backend/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -66,26 +66,28 @@ kumoai/experimental/rfm/backend/local/graph_store.py,sha256=fmBOdXK6a7hHqfB5Nqpc
|
|
|
66
66
|
kumoai/experimental/rfm/backend/local/sampler.py,sha256=tD3l5xfcxjsWDaC45V-xOAI_-Jyyk_au-E7wyrMqCx4,11038
|
|
67
67
|
kumoai/experimental/rfm/backend/local/table.py,sha256=86lztrVxdpya25X4r8mR2c_t-tI8gAEyahz-mNmk9tA,3602
|
|
68
68
|
kumoai/experimental/rfm/backend/snow/__init__.py,sha256=lsF0sJXZ0Pc3NvBTBXJHudp-iZJXdidrhyqFQKEU5_Q,1030
|
|
69
|
-
kumoai/experimental/rfm/backend/snow/sampler.py,sha256=
|
|
70
|
-
kumoai/experimental/rfm/backend/snow/table.py,sha256=
|
|
69
|
+
kumoai/experimental/rfm/backend/snow/sampler.py,sha256=qmjhO_Nz9cCiqmMCesw6PCGwFfY6705EkOdireHI0KM,16729
|
|
70
|
+
kumoai/experimental/rfm/backend/snow/table.py,sha256=5F_E3E4pGelFwbGe0zhXH31BZa5qZnDec0Uxtn38d2M,9323
|
|
71
71
|
kumoai/experimental/rfm/backend/sqlite/__init__.py,sha256=wkSr2D_E5VCH4RGW8FCN2iJp-6wb_RTCMO8R3p5lkiw,934
|
|
72
|
-
kumoai/experimental/rfm/backend/sqlite/sampler.py,sha256=
|
|
72
|
+
kumoai/experimental/rfm/backend/sqlite/sampler.py,sha256=3vt4i2sv1cOjHq3_4JHya_NJEjMcy5TijopDwTY8F0Q,19155
|
|
73
73
|
kumoai/experimental/rfm/backend/sqlite/table.py,sha256=nH3S3lBVfG6aWp0DtCUVJRBZhlQV4ieskbz-5D0AlG0,6867
|
|
74
74
|
kumoai/experimental/rfm/base/__init__.py,sha256=is8HTLng28h5AtpledQ-hdIheGM052JdBhjv8HtKhDw,754
|
|
75
75
|
kumoai/experimental/rfm/base/column.py,sha256=JeDKSZnTChFHMaIC3TcEgdPG9Rr2PATTAMIMhjnvXrs,5117
|
|
76
76
|
kumoai/experimental/rfm/base/expression.py,sha256=04NgmrrvjM1yFXnOMDZtb5V1-oFufqCamv2KTETOHik,1296
|
|
77
|
-
kumoai/experimental/rfm/base/
|
|
77
|
+
kumoai/experimental/rfm/base/mapper.py,sha256=WWPwpIOYa4Ppw8UOh4zf2D9fY2t3FtYKdEM3VCHJNiE,2489
|
|
78
|
+
kumoai/experimental/rfm/base/sampler.py,sha256=2hvFfvgwjqbQljZKAKtZCIcQCsmWGj_CMpSEWpSx3uk,32734
|
|
78
79
|
kumoai/experimental/rfm/base/source.py,sha256=67rpePejkZli4B_eDWzDrn_8Q5Msyo2XZ9F8IGB0ImI,320
|
|
79
|
-
kumoai/experimental/rfm/base/sql_sampler.py,sha256=
|
|
80
|
-
kumoai/experimental/rfm/base/table.py,sha256=
|
|
80
|
+
kumoai/experimental/rfm/base/sql_sampler.py,sha256=bsHQ1UTIjodkkU4c3oY-_6DRZHS8m6RIqLqSwPUfLZ4,16067
|
|
81
|
+
kumoai/experimental/rfm/base/table.py,sha256=N47ym2p7jJ5RO8jjt_KQWfATTW__MlJnGi4zV6EVUIk,26737
|
|
82
|
+
kumoai/experimental/rfm/base/utils.py,sha256=VEMeOehsQCKVatqPzMTzXnVFsLz-NkyQQ67pgivtuCE,1169
|
|
81
83
|
kumoai/experimental/rfm/infer/__init__.py,sha256=Uf4Od7B2G80U61mkkxsnxHPGu1Hh2RqOazTkOYtNLvA,538
|
|
82
84
|
kumoai/experimental/rfm/infer/categorical.py,sha256=bqmfrE5ZCBTcb35lA4SyAkCu3MgttAn29VBJYMBNhVg,893
|
|
83
|
-
kumoai/experimental/rfm/infer/dtype.py,sha256=
|
|
85
|
+
kumoai/experimental/rfm/infer/dtype.py,sha256=SDZR9ULx6Z35Ij29v6t79y-VuTvikEfrHDQLOIL_xI4,2895
|
|
84
86
|
kumoai/experimental/rfm/infer/id.py,sha256=xaJBETLZa8ttzZCsDwFSwfyCi3VYsLc_kDWT_t_6Ih4,954
|
|
85
87
|
kumoai/experimental/rfm/infer/multicategorical.py,sha256=mMuRCbfs0zsfOoPB_eCs6nlt4WgNPvklmYPRq7w85L4,1167
|
|
86
88
|
kumoai/experimental/rfm/infer/pkey.py,sha256=GCAUN8Hz5-leVv2-H8soP3k-DsXJ1O_uQU25-CsSWN0,4540
|
|
87
89
|
kumoai/experimental/rfm/infer/stype.py,sha256=lOgiGJ_rsaeiFWyVUw0IMwn_7hGOqL8mvy2rGzXfi3Q,929
|
|
88
|
-
kumoai/experimental/rfm/infer/time_col.py,sha256
|
|
90
|
+
kumoai/experimental/rfm/infer/time_col.py,sha256=G2zMtcy7gEPgz7O4ljXBws5LgZ1qpQpoFUk3t5q5eqA,1881
|
|
89
91
|
kumoai/experimental/rfm/infer/timestamp.py,sha256=L2VxjtYTSyUBYAo4M-L08xSQlPpqnHMAVF5_vxjh3Y0,1135
|
|
90
92
|
kumoai/experimental/rfm/pquery/__init__.py,sha256=RkTn0I74uXOUuOiBpa6S-_QEYctMutkUnBEfF9ztQzI,159
|
|
91
93
|
kumoai/experimental/rfm/pquery/executor.py,sha256=mz5mqhHbgZM0f5oNFLyThWGM4UePx_kd1O4zyJ_8ToQ,2830
|
|
@@ -100,7 +102,7 @@ kumoai/pquery/predictive_query.py,sha256=I5Ntc7YO1qEGxKrLuhAzZO3SySr8Wnjhde8eDbb
|
|
|
100
102
|
kumoai/pquery/training_table.py,sha256=ex5FpA4_rY5OSIl2koisQENFoPbTz2PmG-DR3rvnysg,17004
|
|
101
103
|
kumoai/testing/__init__.py,sha256=XBQ_Sa3WnOYlpXZ3gUn8w6nVfZt-nfPhytfIBeiPt4w,178
|
|
102
104
|
kumoai/testing/decorators.py,sha256=p79ZCQqPY_MHWy0_l7-xQ6wUIqFTn4AbrGWTHLvpbQY,1664
|
|
103
|
-
kumoai/testing/snow.py,sha256=
|
|
105
|
+
kumoai/testing/snow.py,sha256=QItmVyelgPRW7dRcG1IQGAUdXFuWNULtz5Jo7GrxDtM,1576
|
|
104
106
|
kumoai/trainer/__init__.py,sha256=uCFXy9bw_byn_wYd3M-BTZCHTVvv4XXr8qRlh-QOvag,981
|
|
105
107
|
kumoai/trainer/baseline_trainer.py,sha256=oXweh8j1sar6KhQfr3A7gmQxcDq7SG0Bx3jIenbtyC4,4117
|
|
106
108
|
kumoai/trainer/config.py,sha256=7_Jv1w1mqaokCQwQdJkqCSgVpmh8GqE3fL1Ky_vvttI,100
|
|
@@ -113,10 +115,10 @@ kumoai/utils/__init__.py,sha256=lazi9gAl5YBg1Nk121zSDg-BIKTVETjFTZwTFUlGngo,267
|
|
|
113
115
|
kumoai/utils/datasets.py,sha256=UyAII-oAn7x3ombuvpbSQ41aVF9SYKBjQthTD-vcT2A,3011
|
|
114
116
|
kumoai/utils/display.py,sha256=oPNcXLUUnSKo0m2Hxc330QFPPtnV-wjJMjKoBseB1HY,2519
|
|
115
117
|
kumoai/utils/forecasting.py,sha256=ZgKeUCbWLOot0giAkoigwU5du8LkrwAicFOi5hVn6wg,7624
|
|
116
|
-
kumoai/utils/progress_logger.py,sha256=
|
|
117
|
-
kumoai/utils/sql.py,sha256=
|
|
118
|
-
kumoai-2.
|
|
119
|
-
kumoai-2.
|
|
120
|
-
kumoai-2.
|
|
121
|
-
kumoai-2.
|
|
122
|
-
kumoai-2.
|
|
118
|
+
kumoai/utils/progress_logger.py,sha256=z1eZwxMLcSymhS3r9_GQ35AgoRl1Hz5BfxAyUJkmifg,9893
|
|
119
|
+
kumoai/utils/sql.py,sha256=e4dMLBxIdxqOLgwdgsFshX1JQq4gpA5UlStI-XiuUBw,150
|
|
120
|
+
kumoai-2.15.0.dev202601151732.dist-info/licenses/LICENSE,sha256=ZUilBDp--4vbhsEr6f_Upw9rnIx09zQ3K9fXQ0rfd6w,1111
|
|
121
|
+
kumoai-2.15.0.dev202601151732.dist-info/METADATA,sha256=YzS0_Lc5sPPg0ArySpiRvrYGdMmeO7W6n8jAUD0Y8jA,2635
|
|
122
|
+
kumoai-2.15.0.dev202601151732.dist-info/WHEEL,sha256=qV0EIPljj1XC_vuSatRWjn02nZIz3N1t8jsZz7HBr2U,101
|
|
123
|
+
kumoai-2.15.0.dev202601151732.dist-info/top_level.txt,sha256=YjU6UcmomoDx30vEXLsOU784ED7VztQOsFApk1SFwvs,7
|
|
124
|
+
kumoai-2.15.0.dev202601151732.dist-info/RECORD,,
|
|
{kumoai-2.14.0.dev202601081732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/WHEEL
RENAMED
|
File without changes
|
{kumoai-2.14.0.dev202601081732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/licenses/LICENSE
RENAMED
|
File without changes
|
{kumoai-2.14.0.dev202601081732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/top_level.txt
RENAMED
|
File without changes
|