kumoai 2.14.0.dev202512271732__cp310-cp310-macosx_11_0_arm64.whl → 2.14.0rc2__cp310-cp310-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +23 -26
- kumoai/_version.py +1 -1
- kumoai/client/jobs.py +2 -0
- kumoai/connector/utils.py +21 -7
- kumoai/experimental/rfm/__init__.py +24 -22
- kumoai/experimental/rfm/backend/snow/sampler.py +83 -14
- kumoai/experimental/rfm/backend/sqlite/sampler.py +68 -12
- kumoai/experimental/rfm/base/mapper.py +67 -0
- kumoai/experimental/rfm/base/sampler.py +21 -0
- kumoai/experimental/rfm/base/sql_sampler.py +233 -10
- kumoai/experimental/rfm/base/table.py +41 -53
- kumoai/experimental/rfm/graph.py +57 -60
- kumoai/experimental/rfm/infer/dtype.py +2 -1
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +529 -303
- kumoai/experimental/rfm/task_table.py +292 -0
- kumoai/pquery/training_table.py +16 -2
- kumoai/utils/display.py +87 -0
- kumoai/utils/progress_logger.py +13 -1
- {kumoai-2.14.0.dev202512271732.dist-info → kumoai-2.14.0rc2.dist-info}/METADATA +2 -2
- {kumoai-2.14.0.dev202512271732.dist-info → kumoai-2.14.0rc2.dist-info}/RECORD +24 -20
- {kumoai-2.14.0.dev202512271732.dist-info → kumoai-2.14.0rc2.dist-info}/WHEEL +0 -0
- {kumoai-2.14.0.dev202512271732.dist-info → kumoai-2.14.0rc2.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.14.0.dev202512271732.dist-info → kumoai-2.14.0rc2.dist-info}/top_level.txt +0 -0
kumoai/__init__.py
CHANGED
@@ -1,3 +1,4 @@
+import warnings
 import os
 import sys
 import threading
@@ -68,9 +69,8 @@ class GlobalState(metaclass=Singleton):
         if self._url is None or (self._api_key is None
                                  and self._spcs_token is None
                                  and self._snowpark_session is None):
-            raise ValueError(
-                …
-                … "your client before proceeding.")
+            raise ValueError("Client creation or authentication failed. "
+                             "Please re-create your client before proceeding.")

         if hasattr(self.thread_local, '_client'):
             # Set the spcs token in the client to ensure it has the latest.
@@ -123,10 +123,9 @@ def init(
     """ # noqa
     # Avoid mutations to the global state after it is set:
     if global_state.initialized:
-        …
-        …
-        …
-        … "session.")
+        warnings.warn("Kumo SDK already initialized. To re-initialize the "
+                      "SDK, please start a new interpreter. No changes will "
+                      "be made to the current session.")
         return

     set_log_level(os.getenv(_ENV_KUMO_LOG, log_level))
@@ -138,15 +137,15 @@ def init(
     if snowflake_application:
         if url is not None:
             raise ValueError(
-                "…
-                "are specified. If running from a …
-                "only snowflake_application.")
+                "Kumo SDK initialization failed. Both 'snowflake_application' "
+                "and 'url' are specified. If running from a Snowflake "
+                "notebook, specify only 'snowflake_application'.")
         snowpark_session = _get_active_session()
         if not snowpark_session:
             raise ValueError(
-                "…
-                "without an active …
-                "a …
+                "Kumo SDK initialization failed. 'snowflake_application' is "
+                "specified without an active Snowpark session. If running "
+                "outside a Snowflake notebook, specify a URL and credentials.")
         description = snowpark_session.sql(
             f"DESCRIBE SERVICE {snowflake_application}."
             "USER_SCHEMA.KUMO_SERVICE").collect()[0]
@@ -155,14 +154,14 @@ def init(
     if api_key is None and not snowflake_application:
         if snowflake_credentials is None:
             raise ValueError(
-                "…
-                "credentials provided. Please either set the …
-                "or explicitly call `kumoai.init(...)`.")
+                "Kumo SDK initialization failed. Neither an API key nor "
+                "Snowflake credentials provided. Please either set the "
+                "'KUMO_API_KEY' or explicitly call `kumoai.init(...)`.")
         if (set(snowflake_credentials.keys())
                 != {'user', 'password', 'account'}):
             raise ValueError(
-                f"Provided credentials should be a dictionary with …
-                f"'user', 'password', and 'account'. Only "
+                f"Provided Snowflake credentials should be a dictionary with "
+                f"keys 'user', 'password', and 'account'. Only "
                 f"{set(snowflake_credentials.keys())} were provided.")

     # Get or infer URL:
@@ -173,10 +172,10 @@ def init(
     except KeyError:
         pass
     if url is None:
-        raise ValueError(
-            …
-            …
-            …
+        raise ValueError("Kumo SDK initialization failed since no endpoint "
+                         "URL was provided. Please either set the "
+                         "'KUMO_API_ENDPOINT' environment variable or "
+                         "explicitly call `kumoai.init(...)`.")

     # Assign global state after verification that client can be created and
     # authenticated successfully:
@@ -198,10 +197,8 @@ def init(
     logger = logging.getLogger('kumoai')
     log_level = logging.getLevelName(logger.getEffectiveLevel())

-    logger.info(
-        …
-        f"against deployment {url}, with "
-        f"log level {log_level}.")
+    logger.info(f"Initialized Kumo SDK v{__version__} against deployment "
+                f"'{url}'")


 def set_log_level(level: str) -> None:
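Most of the `kumoai/__init__.py` changes reword error messages so each one names the operation that failed; the one behavioral nuance is that duplicate initialization is now reported through `warnings.warn` (hence the new top-level `import warnings`). A minimal sketch of the resulting behavior, with placeholder endpoint and key:

import kumoai

# Placeholder values; substitute a real deployment URL and API key.
kumoai.init(url="https://<deployment>/api", api_key="<KUMO_API_KEY>")

# Repeated calls leave the current session untouched and emit:
#   "Kumo SDK already initialized. To re-initialize the SDK, please start
#    a new interpreter. No changes will be made to the current session."
kumoai.init(url="https://<deployment>/api", api_key="<KUMO_API_KEY>")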
kumoai/_version.py
CHANGED
@@ -1 +1 @@
-__version__ = '2.14.0.dev202512271732'
+__version__ = '2.14.0rc2'
kumoai/client/jobs.py
CHANGED
@@ -344,12 +344,14 @@ class GenerateTrainTableJobAPI(CommonJobAPI[GenerateTrainTableRequest,
         id: str,
         source_table_type: SourceTableType,
         train_table_mod: TrainingTableSpec,
+        extensive_validation: bool,
     ) -> ValidationResponse:
         response = self._client._post(
             f'{self._base_endpoint}/{id}/validate_custom_train_table',
             json=to_json_dict({
                 'custom_table': source_table_type,
                 'train_table_mod': train_table_mod,
                 'extensive_validation': extensive_validation,
             }),
         )
         return parse_response(ValidationResponse, response)
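The custom-train-table validation call gains a required `extensive_validation` flag, forwarded verbatim in the request body. A hedged sketch of the JSON payload the client now sends (the first two values are whatever `SourceTableType` and `TrainingTableSpec` objects the caller passes; `True` is illustrative):

# Body of POST {base_endpoint}/{id}/validate_custom_train_table:
payload = {
    'custom_table': source_table_type,      # SourceTableType
    'train_table_mod': train_table_mod,     # TrainingTableSpec
    'extensive_validation': True,           # new in this version
}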
kumoai/connector/utils.py
CHANGED
@@ -1,10 +1,10 @@
 import asyncio
 import csv
-import gc
 import io
 import math
 import os
 import re
+import sys
 import tempfile
 import threading
 import time
@@ -920,7 +920,10 @@ def _read_remote_file_with_progress(
     if capture_first_line and not seen_nl:
         header_line = bytes(header_acc)

-    …
+    if sys.version_info >= (3, 13):
+        mv = memoryview(buf.getvalue())
+    else:
+        mv = buf.getbuffer()  # zero-copy view of BytesIO internal buffer
     return buf, mv, header_line
@@ -999,7 +1002,10 @@ def _iter_mv_chunks(mv: memoryview,
     n = mv.nbytes
     while pos < n:
         nxt = min(n, pos + part_size)
-        …
+        if sys.version_info >= (3, 13):
+            yield mv[pos:nxt].tobytes()
+        else:
+            yield mv[pos:nxt]  # zero-copy slice
         pos = nxt
@@ -1473,13 +1479,17 @@ def _remote_upload_file(name: str, fs: Filesystem, url: str, info: dict,
     if renamed_cols_msg:
         logger.info(renamed_cols_msg)

+    try:
+        if isinstance(data_mv, memoryview):
+            data_mv.release()
+    except Exception:
+        pass
+
     try:
         if buf:
             buf.close()
     except Exception:
         pass
-    del buf, data_mv, header_line
-    gc.collect()

     logger.info("Upload complete. Validated table %s.", name)
@@ -1719,13 +1729,17 @@ def _remote_upload_directory(
     else:
         break

+    try:
+        if isinstance(data_mv, memoryview):
+            data_mv.release()
+    except Exception:
+        pass
+
     try:
         if buf:
             buf.close()
     except Exception:
         pass
-    del buf, data_mv, header_line
-    gc.collect()

     _safe_bar_update(file_bar, 1)
     _merge_status_update(fpath)
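All four hunks concern `BytesIO` buffer lifetime: on Python 3.13+ the code pays a `getvalue()` copy instead of holding a `getbuffer()` export, and the cleanup paths now release the `memoryview` explicitly before `buf.close()` rather than relying on `del` plus `gc.collect()`. The motivation is standard CPython behavior: a `BytesIO` with a live `getbuffer()` export cannot be resized or closed. A minimal, self-contained illustration:

import io

buf = io.BytesIO(b"header\nrow1\nrow2\n")
mv = buf.getbuffer()  # zero-copy view; the buffer is now "exported"

try:
    buf.close()  # raises BufferError while the export is alive
except BufferError as exc:
    print(exc)

mv.release()  # drop the export first ...
buf.close()   # ... and closing now succeeds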
kumoai/experimental/rfm/__init__.py
CHANGED
@@ -20,6 +20,7 @@ from .sagemaker import (
 from .base import Table
 from .backend.local import LocalTable
 from .graph import Graph
+from .task_table import TaskTable
 from .rfm import ExplainConfig, Explanation, KumoRFM

 logger = logging.getLogger('kumoai_rfm')
@@ -78,9 +79,9 @@ def _get_snowflake_url(snowflake_application: str) -> str:
     snowpark_session = _get_active_session()
     if not snowpark_session:
         raise ValueError(
-            "…
-            "without an active …
-            "a …
+            "KumoRFM initialization failed. 'snowflake_application' is "
+            "specified without an active Snowpark session. If running outside "
+            "a Snowflake notebook, specify a URL and credentials.")
     with snowpark_session.connection.cursor() as cur:
         cur.execute(
             f"DESCRIBE SERVICE {snowflake_application}.user_schema.rfm_service"
@@ -103,6 +104,9 @@ class RfmGlobalState:

     @property
     def client(self) -> KumoClient:
+        if self._backend == InferenceBackend.UNKNOWN:
+            raise RuntimeError("KumoRFM is not yet initialized")
+
         if self._backend == InferenceBackend.REST:
             return kumoai.global_state.client

@@ -146,18 +150,19 @@ def init(
     with global_state._lock:
         if global_state._initialized:
             if url != global_state._url:
-                raise …
-                    "…
-                    "URL. Re-initialization with a different URL is not "
+                raise RuntimeError(
+                    "KumoRFM has already been initialized with a different "
+                    "API URL. Re-initialization with a different URL is not "
                     "supported.")
             return

         if snowflake_application:
             if url is not None:
                 raise ValueError(
-                    "…
-                    "url are specified. If …
-                    "specify only …
+                    "KumoRFM initialization failed. Both "
+                    "'snowflake_application' and 'url' are specified. If "
+                    "running from a Snowflake notebook, specify only "
+                    "'snowflake_application'.")
             url = _get_snowflake_url(snowflake_application)
             api_key = "test:DISABLED"
@@ -166,32 +171,28 @@ def init(
     backend, region, endpoint_name = _detect_backend(url)
     if backend == InferenceBackend.REST:
-        …
-        kumoai.init(url=url, api_key=api_key,
-                    snowflake_credentials=snowflake_credentials,
-                    snowflake_application=snowflake_application,
-                    log_level=log_level)
+        kumoai.init(
+            url=url,
+            api_key=api_key,
+            snowflake_credentials=snowflake_credentials,
+            snowflake_application=snowflake_application,
+            log_level=log_level,
+        )
     elif backend == InferenceBackend.AWS_SAGEMAKER:
         assert region
         assert endpoint_name
         KumoClient_SageMakerAdapter(region, endpoint_name).authenticate()
+        logger.info("KumoRFM initialized in AWS SageMaker")
     else:
         assert backend == InferenceBackend.LOCAL_SAGEMAKER
         KumoClient_SageMakerProxy_Local(url).authenticate()
+        logger.info(f"KumoRFM initialized in local SageMaker at '{url}'")

     global_state._url = url
     global_state._backend = backend
     global_state._region = region
     global_state._endpoint_name = endpoint_name
     global_state._initialized = True
-    logger.info("Kumo RFM initialized with backend: %s, url: %s", backend,
-                url)


 LocalGraph = Graph  # NOTE Backward compatibility - do not use anymore.
@@ -202,6 +203,7 @@ __all__ = [
     'Table',
     'LocalTable',
     'Graph',
+    'TaskTable',
     'KumoRFM',
     'ExplainConfig',
     'Explanation',
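Two user-visible changes here, sketched in a minimal snippet (`rfm` is the package as imported from `kumoai.experimental`):

from kumoai.experimental import rfm

# `TaskTable` now ships alongside the existing package-level exports:
print('TaskTable' in rfm.__all__)  # True
print('KumoRFM' in rfm.__all__)    # True

# Separately, `RfmGlobalState.client` now fails fast: accessing the client
# while the backend is still `InferenceBackend.UNKNOWN`, i.e. before
# `rfm.init(...)` has run, raises
#   RuntimeError("KumoRFM is not yet initialized")
# instead of failing deeper in the stack.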
kumoai/experimental/rfm/backend/snow/sampler.py
CHANGED
@@ -144,11 +144,11 @@ class SnowSampler(SQLSampler):
             query.entity_table: np.arange(len(entity_df)),
         }
         for edge_type, (min_offset, max_offset) in time_offset_dict.items():
-            table_name, …
+            table_name, foreign_key, _ = edge_type
             feat_dict[table_name], batch_dict[table_name] = self._by_time(
                 table_name=table_name,
-                …
-                …
+                foreign_key=foreign_key,
+                index=entity_df[self.primary_key_dict[query.entity_table]],
                 anchor_time=time,
                 min_offset=min_offset,
                 max_offset=max_offset,
@@ -179,7 +179,7 @@ class SnowSampler(SQLSampler):
     def _by_pkey(
         self,
         table_name: str,
-        …
+        index: pd.Series,
         columns: set[str],
     ) -> tuple[pd.DataFrame, np.ndarray]:
         key = self.primary_key_dict[table_name]
@@ -189,7 +189,7 @@ class SnowSampler(SQLSampler):
             for column in columns
         ]

-        payload = json.dumps(list(…
+        payload = json.dumps(list(index))

         sql = ("WITH TMP as (\n"
                "  SELECT\n"
@@ -206,7 +206,7 @@ class SnowSampler(SQLSampler):
                f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
                f"{', '.join(projections)}\n"
                f"FROM TMP\n"
-               f"JOIN {self.source_name_dict[table_name]} …
+               f"JOIN {self.source_name_dict[table_name]}\n"
                f"  ON {key_ref} = TMP.__KUMO_ID__")

         with paramstyle(self._connection), self._connection.cursor() as cursor:
@@ -228,13 +228,82 @@ class SnowSampler(SQLSampler):
             stype_dict=self.table_stype_dict[table_name],
         ), batch

+    def _by_fkey(
+        self,
+        table_name: str,
+        foreign_key: str,
+        index: pd.Series,
+        num_neighbors: int,
+        anchor_time: pd.Series | None,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        time_column = self.time_column_dict.get(table_name)
+
+        if time_column is not None and anchor_time is not None:
+            anchor_time = anchor_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+            payload = json.dumps(list(zip(index, anchor_time)))
+        else:
+            payload = json.dumps(list(zip(index)))
+
+        key_ref = self.table_column_ref_dict[table_name][foreign_key]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+
+        sql = ("WITH TMP as (\n"
+               "  SELECT\n"
+               "    f.index as __KUMO_BATCH__,\n")
+        if self.table_dtype_dict[table_name][foreign_key].is_int():
+            sql += "    f.value[0]::NUMBER as __KUMO_ID__"
+        elif self.table_dtype_dict[table_name][foreign_key].is_float():
+            sql += "    f.value[0]::FLOAT as __KUMO_ID__"
+        else:
+            sql += "    f.value[0]::VARCHAR as __KUMO_ID__"
+        if time_column is not None and anchor_time is not None:
+            sql += (",\n"
+                    "    f.value[1]::TIMESTAMP_NTZ as __KUMO_TIME__")
+        sql += (f"\n"
+                f"  FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
+                f")\n"
+                f"SELECT "
+                f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+                f"{', '.join(projections)}\n"
+                f"FROM TMP\n"
+                f"JOIN {self.source_name_dict[table_name]}\n"
+                f"  ON {key_ref} = TMP.__KUMO_ID__\n")
+        if time_column is not None and anchor_time is not None:
+            time_ref = self.table_column_ref_dict[table_name][time_column]
+            sql += f"  AND {time_ref} <= TMP.__KUMO_TIME__\n"
+        sql += ("QUALIFY ROW_NUMBER() OVER (\n"
+                "  PARTITION BY TMP.__KUMO_BATCH__\n")
+        if time_column is not None:
+            sql += f"  ORDER BY {time_ref} DESC\n"
+        else:
+            sql += f"  ORDER BY {key_ref}\n"
+        sql += f") <= {num_neighbors}"
+
+        with paramstyle(self._connection), self._connection.cursor() as cursor:
+            cursor.execute(sql, (payload, ))
+            table = cursor.fetch_arrow_all()
+
+        batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+        batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+        table = table.remove_column(batch_index)
+
+        return Table._sanitize(
+            df=table.to_pandas(),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
+
     # Helper Methods ##########################################################

     def _by_time(
         self,
         table_name: str,
-        …
-        …
+        foreign_key: str,
+        index: pd.Series,
         anchor_time: pd.Series,
         min_offset: pd.DateOffset | None,
         max_offset: pd.DateOffset,
@@ -247,11 +316,11 @@ class SnowSampler(SQLSampler):
         if min_offset is not None:
             start_time = anchor_time + min_offset
             start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
-            payload = json.dumps(list(zip(…
+            payload = json.dumps(list(zip(index, end_time, start_time)))
         else:
-            payload = json.dumps(list(zip(…
+            payload = json.dumps(list(zip(index, end_time)))

-        key_ref = self.table_column_ref_dict[table_name][…
+        key_ref = self.table_column_ref_dict[table_name][foreign_key]
         time_ref = self.table_column_ref_dict[table_name][time_column]
         projections = [
             self.table_column_proj_dict[table_name][column]
@@ -260,9 +329,9 @@ class SnowSampler(SQLSampler):
         sql = ("WITH TMP as (\n"
                "  SELECT\n"
                "    f.index as __KUMO_BATCH__,\n")
-        if self.table_dtype_dict[table_name][…
+        if self.table_dtype_dict[table_name][foreign_key].is_int():
             sql += "    f.value[0]::NUMBER as __KUMO_ID__,\n"
-        elif self.table_dtype_dict[table_name][…
+        elif self.table_dtype_dict[table_name][foreign_key].is_float():
             sql += "    f.value[0]::FLOAT as __KUMO_ID__,\n"
         else:
             sql += "    f.value[0]::VARCHAR as __KUMO_ID__,\n"
@@ -276,7 +345,7 @@ class SnowSampler(SQLSampler):
                f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
                f"{', '.join(projections)}\n"
                f"FROM TMP\n"
-               f"JOIN {self.source_name_dict[table_name]} …
+               f"JOIN {self.source_name_dict[table_name]}\n"
                f"  ON {key_ref} = TMP.__KUMO_ID__\n"
                f"  AND {time_ref} <= TMP.__KUMO_END_TIME__")
         if min_offset is not None:
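The new `_by_fkey` is the Snowflake side of foreign-key neighbor sampling: `(pkey, anchor_time)` pairs are shipped as one JSON parameter, flattened server-side via `FLATTEN(PARSE_JSON(?))`, joined on the foreign key, and capped at `num_neighbors` rows per batch position with `QUALIFY ROW_NUMBER()`. For reference, a hypothetical rendering for a time-aware `ORDERS` table with integer foreign key `USER_ID`, timestamp `TS`, one projected column `AMOUNT`, and `num_neighbors=16` (all identifiers illustrative):

sql = """WITH TMP as (
  SELECT
    f.index as __KUMO_BATCH__,
    f.value[0]::NUMBER as __KUMO_ID__,
    f.value[1]::TIMESTAMP_NTZ as __KUMO_TIME__
  FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f
)
SELECT TMP.__KUMO_BATCH__ as __KUMO_BATCH__, ORDERS.AMOUNT
FROM TMP
JOIN ORDERS
  ON ORDERS.USER_ID = TMP.__KUMO_ID__
  AND ORDERS.TS <= TMP.__KUMO_TIME__
QUALIFY ROW_NUMBER() OVER (
  PARTITION BY TMP.__KUMO_BATCH__
  ORDER BY ORDERS.TS DESC
) <= 16"""
# The lone `?` is bound to the JSON payload; the QUALIFY clause keeps only
# the 16 most recent rows at or before each entity's anchor time.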
kumoai/experimental/rfm/backend/sqlite/sampler.py
CHANGED
@@ -226,7 +226,7 @@ class SQLiteSampler(SQLSampler):
     def _by_pkey(
         self,
         table_name: str,
-        …
+        index: pd.Series,
         columns: set[str],
     ) -> tuple[pd.DataFrame, np.ndarray]:
         source_table = self.source_table_dict[table_name]
@@ -237,7 +237,7 @@ class SQLiteSampler(SQLSampler):
             for column in columns
         ]

-        tmp = pa.table([pa.array(…
+        tmp = pa.table([pa.array(index)], names=['__kumo_id__'])
         tmp_name = f'tmp_{table_name}_{key}_{id(tmp)}'

         sql = (f"SELECT "
@@ -245,7 +245,6 @@ class SQLiteSampler(SQLSampler):
                f"{', '.join(projections)}\n"
                f"FROM {quote_ident(tmp_name)} tmp\n"
                f"JOIN {self.source_name_dict[table_name]} ent\n")
-
         if key in source_table and source_table[key].is_unique_key:
             sql += (f"  ON {key_ref} = tmp.__kumo_id__")
         else:
@@ -271,13 +270,70 @@ class SQLiteSampler(SQLSampler):
             stype_dict=self.table_stype_dict[table_name],
         ), batch

+    def _by_fkey(
+        self,
+        table_name: str,
+        foreign_key: str,
+        index: pd.Series,
+        num_neighbors: int,
+        anchor_time: pd.Series | None,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        time_column = self.time_column_dict.get(table_name)
+
+        # NOTE SQLite does not have a native datetime format. Currently, we
+        # assume timestamps are given as `TEXT` in `ISO-8601 UTC`:
+        tmp = pa.table([pa.array(index)], names=['__kumo_id__'])
+        if time_column is not None and anchor_time is not None:
+            anchor_time = anchor_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+            tmp = tmp.append_column('__kumo_time__', pa.array(anchor_time))
+        tmp_name = f'tmp_{table_name}_{foreign_key}_{id(tmp)}'
+
+        key_ref = self.table_column_ref_dict[table_name][foreign_key]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+        sql = (f"SELECT "
+               f"tmp.rowid - 1 as __kumo_batch__, "
+               f"{', '.join(projections)}\n"
+               f"FROM {quote_ident(tmp_name)} tmp\n"
+               f"JOIN {self.source_name_dict[table_name]} fact\n"
+               f"ON fact.rowid IN (\n"
+               f"  SELECT rowid\n"
+               f"  FROM {self.source_name_dict[table_name]}\n"
+               f"  WHERE {key_ref} = tmp.__kumo_id__\n")
+        if time_column is not None and anchor_time is not None:
+            time_ref = self.table_column_ref_dict[table_name][time_column]
+            sql += f"  AND {time_ref} <= tmp.__kumo_time__\n"
+        if time_column is not None:
+            time_ref = self.table_column_ref_dict[table_name][time_column]
+            sql += f"  ORDER BY {time_ref} DESC\n"
+        sql += (f"  LIMIT {num_neighbors}\n"
+                f")")
+
+        with self._connection.cursor() as cursor:
+            cursor.adbc_ingest(tmp_name, tmp, mode='replace')
+            cursor.execute(sql)
+            table = cursor.fetch_arrow_table()
+
+        batch = table['__kumo_batch__'].to_numpy()
+        batch_index = table.schema.get_field_index('__kumo_batch__')
+        table = table.remove_column(batch_index)
+
+        return Table._sanitize(
+            df=table.to_pandas(),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
+
     # Helper Methods ##########################################################

     def _by_time(
         self,
         table_name: str,
-        …
-        …
+        foreign_key: str,
+        index: pd.Series,
         anchor_time: pd.Series,
         min_offset: pd.DateOffset | None,
         max_offset: pd.DateOffset,
@@ -287,7 +343,7 @@ class SQLiteSampler(SQLSampler):

         # NOTE SQLite does not have a native datetime format. Currently, we
         # assume timestamps are given as `TEXT` in `ISO-8601 UTC`:
-        tmp = pa.table([pa.array(…
+        tmp = pa.table([pa.array(index)], names=['__kumo_id__'])
         end_time = anchor_time + max_offset
         end_time = end_time.dt.strftime("%Y-%m-%d %H:%M:%S")
         tmp = tmp.append_column('__kumo_end__', pa.array(end_time))
@@ -295,9 +351,9 @@ class SQLiteSampler(SQLSampler):
             start_time = anchor_time + min_offset
             start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
             tmp = tmp.append_column('__kumo_start__', pa.array(start_time))
-        tmp_name = f'tmp_{table_name}_{…
+        tmp_name = f'tmp_{table_name}_{foreign_key}_{id(tmp)}'

-        key_ref = self.table_column_ref_dict[table_name][…
+        key_ref = self.table_column_ref_dict[table_name][foreign_key]
         time_ref = self.table_column_ref_dict[table_name][time_column]
         projections = [
             self.table_column_proj_dict[table_name][column]
@@ -307,7 +363,7 @@ class SQLiteSampler(SQLSampler):
                f"tmp.rowid - 1 as __kumo_batch__, "
                f"{', '.join(projections)}\n"
                f"FROM {quote_ident(tmp_name)} tmp\n"
-               f"JOIN {self.source_name_dict[table_name]} …
+               f"JOIN {self.source_name_dict[table_name]}\n"
                f"  ON {key_ref} = tmp.__kumo_id__\n"
                f"  AND {time_ref} <= tmp.__kumo_end__")
         if min_offset is not None:
@@ -359,11 +415,11 @@ class SQLiteSampler(SQLSampler):
             query.entity_table: np.arange(len(df)),
         }
         for edge_type, (_min, _max) in time_offset_dict.items():
-            table_name, …
+            table_name, foreign_key, _ = edge_type
             feat_dict[table_name], batch_dict[table_name] = self._by_time(
                 table_name=table_name,
-                …
-                …
+                foreign_key=foreign_key,
+                index=df[self.primary_key_dict[query.entity_table]],
                 anchor_time=time,
                 min_offset=_min,
                 max_offset=_max,
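SQLite has no `QUALIFY`, so its `_by_fkey` gets the same "at most `num_neighbors` most recent rows per entity" semantics from a correlated `rowid IN (... LIMIT k)` subquery, with `tmp.rowid - 1` providing the 0-based batch position. A self-contained sketch of the pattern on a toy schema (all names and values illustrative; row order may vary by query plan, hence the `sorted`):

import sqlite3

con = sqlite3.connect(":memory:")
con.executescript("""
    CREATE TABLE tmp (kumo_id INTEGER, kumo_time TEXT);
    CREATE TABLE fact (user_id INTEGER, ts TEXT, amount REAL);
    INSERT INTO tmp VALUES (1, '2025-01-03 00:00:00'), (2, '2025-01-03 00:00:00');
    INSERT INTO fact VALUES
        (1, '2025-01-01 00:00:00', 10.0),
        (1, '2025-01-02 00:00:00', 20.0),
        (1, '2025-01-04 00:00:00', 99.0),  -- after the anchor time: excluded
        (2, '2025-01-01 00:00:00', 5.0);
""")
k = 1  # plays the role of `num_neighbors`
rows = con.execute(f"""
    SELECT tmp.rowid - 1 AS batch, fact.amount
    FROM tmp
    JOIN fact ON fact.rowid IN (
        SELECT rowid FROM fact
        WHERE fact.user_id = tmp.kumo_id
          AND fact.ts <= tmp.kumo_time
        ORDER BY fact.ts DESC
        LIMIT {k}
    )
""").fetchall()
print(sorted(rows))  # [(0, 20.0), (1, 5.0)]: newest in-window row per entity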
kumoai/experimental/rfm/base/mapper.py
ADDED
@@ -0,0 +1,67 @@
+import numpy as np
+import pandas as pd
+
+
+class Mapper:
+    r"""A mapper to map ``(pkey, batch)`` pairs to contiguous node IDs.
+
+    Args:
+        num_examples: The maximum number of examples to add/retrieve.
+    """
+    def __init__(self, num_examples: int):
+        self._pkey_dtype: pd.CategoricalDtype | None = None
+        self._indices: list[np.ndarray] = []
+        self._index_dtype: pd.CategoricalDtype | None = None
+        self._num_examples = num_examples
+
+    def add(self, pkey: pd.Series, batch: np.ndarray) -> None:
+        r"""Adds a set of ``(pkey, batch)`` pairs to the mapper.
+
+        Args:
+            pkey: The primary keys.
+            batch: The batch vector.
+        """
+        if self._pkey_dtype is not None:
+            category = np.concatenate([
+                self._pkey_dtype.categories.values,
+                pkey,
+            ], axis=0)
+            category = pd.unique(category)
+            self._pkey_dtype = pd.CategoricalDtype(category)
+        elif pd.api.types.is_string_dtype(pkey):
+            category = pd.unique(pkey)
+            self._pkey_dtype = pd.CategoricalDtype(category)
+
+        if self._pkey_dtype is not None:
+            index = pd.Categorical(pkey, dtype=self._pkey_dtype).codes
+        else:
+            index = pkey.to_numpy()
+        index = self._num_examples * index + batch
+        self._indices.append(index)
+        self._index_dtype = None
+
+    def get(self, pkey: pd.Series, batch: np.ndarray) -> np.ndarray:
+        r"""Retrieves the node IDs for a set of ``(pkey, batch)`` pairs.
+
+        Returns ``-1`` for any pair not registered in the mapping.
+
+        Args:
+            pkey: The primary keys.
+            batch: The batch vector.
+        """
+        if len(self._indices) == 0:
+            return np.full(len(pkey), -1, dtype=np.int64)
+
+        if self._index_dtype is None:  # Lazy build index:
+            category = pd.unique(np.concatenate(self._indices))
+            self._index_dtype = pd.CategoricalDtype(category)
+
+        if self._pkey_dtype is not None:
+            index = pd.Categorical(pkey, dtype=self._pkey_dtype).codes
+        else:
+            index = pkey.to_numpy()
+        index = self._num_examples * index + batch
+
+        out = pd.Categorical(index, dtype=self._index_dtype).codes
+        out = out.astype('int64')
+        return out
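Since `Mapper` is a new file, a quick usage sketch helps: node IDs are dense categorical codes over the distinct `num_examples * pkey_code + batch` values seen so far, and unregistered pairs come back as `-1`. (Import path assumed from the file location; data values illustrative.)

import numpy as np
import pandas as pd

from kumoai.experimental.rfm.base.mapper import Mapper

mapper = Mapper(num_examples=2)  # at most two examples per batch

mapper.add(pkey=pd.Series(['u1', 'u2', 'u1']), batch=np.array([0, 0, 1]))

# ('u1', 1) was registered above; 'u9' was never added:
out = mapper.get(pkey=pd.Series(['u1', 'u9']), batch=np.array([1, 0]))
print(out)  # [ 2 -1]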