kumoai 2.14.0.dev202512211732__cp313-cp313-macosx_11_0_arm64.whl → 2.15.0.dev202601181732__cp313-cp313-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +38 -30
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +26 -0
- kumoai/connector/utils.py +21 -7
- kumoai/experimental/rfm/__init__.py +24 -22
- kumoai/experimental/rfm/backend/local/graph_store.py +12 -21
- kumoai/experimental/rfm/backend/local/sampler.py +0 -3
- kumoai/experimental/rfm/backend/local/table.py +24 -25
- kumoai/experimental/rfm/backend/snow/sampler.py +235 -80
- kumoai/experimental/rfm/backend/snow/table.py +146 -70
- kumoai/experimental/rfm/backend/sqlite/sampler.py +196 -89
- kumoai/experimental/rfm/backend/sqlite/table.py +85 -55
- kumoai/experimental/rfm/base/__init__.py +6 -9
- kumoai/experimental/rfm/base/column.py +95 -11
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/mapper.py +69 -0
- kumoai/experimental/rfm/base/sampler.py +28 -18
- kumoai/experimental/rfm/base/source.py +1 -1
- kumoai/experimental/rfm/base/sql_sampler.py +320 -19
- kumoai/experimental/rfm/base/table.py +256 -109
- kumoai/experimental/rfm/base/utils.py +36 -0
- kumoai/experimental/rfm/graph.py +134 -114
- kumoai/experimental/rfm/infer/dtype.py +7 -2
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/time_col.py +4 -2
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +541 -307
- kumoai/experimental/rfm/task_table.py +292 -0
- kumoai/pquery/training_table.py +16 -2
- kumoai/testing/snow.py +3 -3
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/display.py +95 -0
- kumoai/utils/progress_logger.py +205 -117
- kumoai/utils/sql.py +2 -2
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601181732.dist-info}/METADATA +2 -2
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601181732.dist-info}/RECORD +40 -35
- kumoai/experimental/rfm/base/column_expression.py +0 -50
- kumoai/experimental/rfm/base/sql_table.py +0 -229
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601181732.dist-info}/WHEEL +0 -0
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601181732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601181732.dist-info}/top_level.txt +0 -0
kumoai/__init__.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import warnings
|
|
1
2
|
import os
|
|
2
3
|
import sys
|
|
3
4
|
import threading
|
|
@@ -68,9 +69,8 @@ class GlobalState(metaclass=Singleton):
|
|
|
68
69
|
if self._url is None or (self._api_key is None
|
|
69
70
|
and self._spcs_token is None
|
|
70
71
|
and self._snowpark_session is None):
|
|
71
|
-
raise ValueError(
|
|
72
|
-
|
|
73
|
-
"your client before proceeding.")
|
|
72
|
+
raise ValueError("Client creation or authentication failed. "
|
|
73
|
+
"Please re-create your client before proceeding.")
|
|
74
74
|
|
|
75
75
|
if hasattr(self.thread_local, '_client'):
|
|
76
76
|
# Set the spcs token in the client to ensure it has the latest.
|
|
@@ -123,10 +123,9 @@ def init(
|
|
|
123
123
|
""" # noqa
|
|
124
124
|
# Avoid mutations to the global state after it is set:
|
|
125
125
|
if global_state.initialized:
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
"session.")
|
|
126
|
+
warnings.warn("Kumo SDK already initialized. To re-initialize the "
|
|
127
|
+
"SDK, please start a new interpreter. No changes will "
|
|
128
|
+
"be made to the current session.")
|
|
130
129
|
return
|
|
131
130
|
|
|
132
131
|
set_log_level(os.getenv(_ENV_KUMO_LOG, log_level))
|
|
@@ -138,15 +137,15 @@ def init(
|
|
|
138
137
|
if snowflake_application:
|
|
139
138
|
if url is not None:
|
|
140
139
|
raise ValueError(
|
|
141
|
-
"
|
|
142
|
-
"are specified. If running from a
|
|
143
|
-
"only snowflake_application.")
|
|
140
|
+
"Kumo SDK initialization failed. Both 'snowflake_application' "
|
|
141
|
+
"and 'url' are specified. If running from a Snowflake "
|
|
142
|
+
"notebook, specify only 'snowflake_application'.")
|
|
144
143
|
snowpark_session = _get_active_session()
|
|
145
144
|
if not snowpark_session:
|
|
146
145
|
raise ValueError(
|
|
147
|
-
"
|
|
148
|
-
"without an active
|
|
149
|
-
"a
|
|
146
|
+
"Kumo SDK initialization failed. 'snowflake_application' is "
|
|
147
|
+
"specified without an active Snowpark session. If running "
|
|
148
|
+
"outside a Snowflake notebook, specify a URL and credentials.")
|
|
150
149
|
description = snowpark_session.sql(
|
|
151
150
|
f"DESCRIBE SERVICE {snowflake_application}."
|
|
152
151
|
"USER_SCHEMA.KUMO_SERVICE").collect()[0]
|
|
@@ -155,14 +154,14 @@ def init(
|
|
|
155
154
|
if api_key is None and not snowflake_application:
|
|
156
155
|
if snowflake_credentials is None:
|
|
157
156
|
raise ValueError(
|
|
158
|
-
"
|
|
159
|
-
"credentials provided. Please either set the
|
|
160
|
-
"or explicitly call `kumoai.init(...)`.")
|
|
157
|
+
"Kumo SDK initialization failed. Neither an API key nor "
|
|
158
|
+
"Snowflake credentials provided. Please either set the "
|
|
159
|
+
"'KUMO_API_KEY' or explicitly call `kumoai.init(...)`.")
|
|
161
160
|
if (set(snowflake_credentials.keys())
|
|
162
161
|
!= {'user', 'password', 'account'}):
|
|
163
162
|
raise ValueError(
|
|
164
|
-
f"Provided credentials should be a dictionary with
|
|
165
|
-
f"'user', 'password', and 'account'. Only "
|
|
163
|
+
f"Provided Snowflake credentials should be a dictionary with "
|
|
164
|
+
f"keys 'user', 'password', and 'account'. Only "
|
|
166
165
|
f"{set(snowflake_credentials.keys())} were provided.")
|
|
167
166
|
|
|
168
167
|
# Get or infer URL:
|
|
@@ -173,10 +172,10 @@ def init(
|
|
|
173
172
|
except KeyError:
|
|
174
173
|
pass
|
|
175
174
|
if url is None:
|
|
176
|
-
raise ValueError(
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
175
|
+
raise ValueError("Kumo SDK initialization failed since no endpoint "
|
|
176
|
+
"URL was provided. Please either set the "
|
|
177
|
+
"'KUMO_API_ENDPOINT' environment variable or "
|
|
178
|
+
"explicitly call `kumoai.init(...)`.")
|
|
180
179
|
|
|
181
180
|
# Assign global state after verification that client can be created and
|
|
182
181
|
# authenticated successfully:
|
|
@@ -198,10 +197,8 @@ def init(
|
|
|
198
197
|
logger = logging.getLogger('kumoai')
|
|
199
198
|
log_level = logging.getLevelName(logger.getEffectiveLevel())
|
|
200
199
|
|
|
201
|
-
logger.info(
|
|
202
|
-
|
|
203
|
-
f"against deployment {url}, with "
|
|
204
|
-
f"log level {log_level}.")
|
|
200
|
+
logger.info(f"Initialized Kumo SDK v{__version__} against deployment "
|
|
201
|
+
f"'{url}'")
|
|
205
202
|
|
|
206
203
|
|
|
207
204
|
def set_log_level(level: str) -> None:
|
|
@@ -280,7 +277,7 @@ __all__ = [
|
|
|
280
277
|
]
|
|
281
278
|
|
|
282
279
|
|
|
283
|
-
def in_snowflake_notebook() -> bool:
|
|
280
|
+
def in_streamlit_notebook() -> bool:
|
|
284
281
|
try:
|
|
285
282
|
from snowflake.snowpark.context import get_active_session
|
|
286
283
|
import streamlit # noqa: F401
|
|
@@ -290,9 +287,7 @@ def in_snowflake_notebook() -> bool:
|
|
|
290
287
|
return False
|
|
291
288
|
|
|
292
289
|
|
|
293
|
-
def in_notebook() -> bool:
|
|
294
|
-
if in_snowflake_notebook():
|
|
295
|
-
return True
|
|
290
|
+
def in_jupyter_notebook() -> bool:
|
|
296
291
|
try:
|
|
297
292
|
from IPython import get_ipython
|
|
298
293
|
shell = get_ipython()
|
|
@@ -301,3 +296,16 @@ def in_notebook() -> bool:
|
|
|
301
296
|
return shell.__class__.__name__ == 'ZMQInteractiveShell'
|
|
302
297
|
except Exception:
|
|
303
298
|
return False
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
def in_vnext_notebook() -> bool:
|
|
302
|
+
try:
|
|
303
|
+
from snowflake.snowpark.context import get_active_session
|
|
304
|
+
get_active_session()
|
|
305
|
+
return in_jupyter_notebook()
|
|
306
|
+
except Exception:
|
|
307
|
+
return False
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
def in_notebook() -> bool:
|
|
311
|
+
return in_streamlit_notebook() or in_jupyter_notebook()
|
kumoai/_version.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = '2.14.0.dev202512211732'
|
|
1
|
+
__version__ = '2.15.0.dev202601181732'
|
kumoai/client/client.py
CHANGED
|
@@ -13,6 +13,7 @@ if TYPE_CHECKING:
|
|
|
13
13
|
ArtifactExportJobAPI,
|
|
14
14
|
BaselineJobAPI,
|
|
15
15
|
BatchPredictionJobAPI,
|
|
16
|
+
DistillationJobAPI,
|
|
16
17
|
GeneratePredictionTableJobAPI,
|
|
17
18
|
GenerateTrainTableJobAPI,
|
|
18
19
|
LLMJobAPI,
|
|
@@ -132,6 +133,11 @@ class KumoClient:
|
|
|
132
133
|
from kumoai.client.jobs import TrainingJobAPI
|
|
133
134
|
return TrainingJobAPI(self)
|
|
134
135
|
|
|
136
|
+
@property
|
|
137
|
+
def distillation_job_api(self) -> 'DistillationJobAPI':
|
|
138
|
+
from kumoai.client.jobs import DistillationJobAPI
|
|
139
|
+
return DistillationJobAPI(self)
|
|
140
|
+
|
|
135
141
|
@property
|
|
136
142
|
def batch_prediction_job_api(self) -> 'BatchPredictionJobAPI':
|
|
137
143
|
from kumoai.client.jobs import BatchPredictionJobAPI
|
kumoai/client/jobs.py
CHANGED
|
@@ -22,6 +22,8 @@ from kumoapi.jobs import (
|
|
|
22
22
|
BatchPredictionRequest,
|
|
23
23
|
CancelBatchPredictionJobResponse,
|
|
24
24
|
CancelTrainingJobResponse,
|
|
25
|
+
DistillationJobRequest,
|
|
26
|
+
DistillationJobResource,
|
|
25
27
|
ErrorDetails,
|
|
26
28
|
GeneratePredictionTableJobResource,
|
|
27
29
|
GeneratePredictionTableRequest,
|
|
@@ -171,6 +173,28 @@ class TrainingJobAPI(CommonJobAPI[TrainingJobRequest, TrainingJobResource]):
|
|
|
171
173
|
return resource.config
|
|
172
174
|
|
|
173
175
|
|
|
176
|
+
class DistillationJobAPI(CommonJobAPI[DistillationJobRequest,
|
|
177
|
+
DistillationJobResource]):
|
|
178
|
+
r"""Typed API definition for the distillation job resource."""
|
|
179
|
+
def __init__(self, client: KumoClient) -> None:
|
|
180
|
+
super().__init__(client, '/training_jobs/distilled_training_job',
|
|
181
|
+
DistillationJobResource)
|
|
182
|
+
|
|
183
|
+
def get_config(self, job_id: str) -> DistillationJobRequest:
|
|
184
|
+
raise NotImplementedError(
|
|
185
|
+
"Getting the configuration for a distillation job is "
|
|
186
|
+
"not implemented yet.")
|
|
187
|
+
|
|
188
|
+
def get_progress(self, id: str) -> AutoTrainerProgress:
|
|
189
|
+
raise NotImplementedError(
|
|
190
|
+
"Getting the progress for a distillation job is not "
|
|
191
|
+
"implemented yet.")
|
|
192
|
+
|
|
193
|
+
def cancel(self, id: str) -> CancelTrainingJobResponse:
|
|
194
|
+
raise NotImplementedError(
|
|
195
|
+
"Cancelling a distillation job is not implemented yet.")
|
|
196
|
+
|
|
197
|
+
|
|
174
198
|
class BatchPredictionJobAPI(CommonJobAPI[BatchPredictionRequest,
|
|
175
199
|
BatchPredictionJobResource]):
|
|
176
200
|
r"""Typed API definition for the prediction job resource."""
|
|
@@ -320,12 +344,14 @@ class GenerateTrainTableJobAPI(CommonJobAPI[GenerateTrainTableRequest,
|
|
|
320
344
|
id: str,
|
|
321
345
|
source_table_type: SourceTableType,
|
|
322
346
|
train_table_mod: TrainingTableSpec,
|
|
347
|
+
extensive_validation: bool,
|
|
323
348
|
) -> ValidationResponse:
|
|
324
349
|
response = self._client._post(
|
|
325
350
|
f'{self._base_endpoint}/{id}/validate_custom_train_table',
|
|
326
351
|
json=to_json_dict({
|
|
327
352
|
'custom_table': source_table_type,
|
|
328
353
|
'train_table_mod': train_table_mod,
|
|
354
|
+
'extensive_validation': extensive_validation,
|
|
329
355
|
}),
|
|
330
356
|
)
|
|
331
357
|
return parse_response(ValidationResponse, response)
|
kumoai/connector/utils.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
import csv
|
|
3
|
-
import gc
|
|
4
3
|
import io
|
|
5
4
|
import math
|
|
6
5
|
import os
|
|
7
6
|
import re
|
|
7
|
+
import sys
|
|
8
8
|
import tempfile
|
|
9
9
|
import threading
|
|
10
10
|
import time
|
|
@@ -920,7 +920,10 @@ def _read_remote_file_with_progress(
|
|
|
920
920
|
if capture_first_line and not seen_nl:
|
|
921
921
|
header_line = bytes(header_acc)
|
|
922
922
|
|
|
923
|
-
|
|
923
|
+
if sys.version_info >= (3, 13):
|
|
924
|
+
mv = memoryview(buf.getvalue())
|
|
925
|
+
else:
|
|
926
|
+
mv = buf.getbuffer() # zero-copy view of BytesIO internal buffer
|
|
924
927
|
return buf, mv, header_line
|
|
925
928
|
|
|
926
929
|
|
|
@@ -999,7 +1002,10 @@ def _iter_mv_chunks(mv: memoryview,
|
|
|
999
1002
|
n = mv.nbytes
|
|
1000
1003
|
while pos < n:
|
|
1001
1004
|
nxt = min(n, pos + part_size)
|
|
1002
|
-
|
|
1005
|
+
if sys.version_info >= (3, 13):
|
|
1006
|
+
yield mv[pos:nxt].tobytes()
|
|
1007
|
+
else:
|
|
1008
|
+
yield mv[pos:nxt] # zero-copy slice
|
|
1003
1009
|
pos = nxt
|
|
1004
1010
|
|
|
1005
1011
|
|
|
@@ -1473,13 +1479,17 @@ def _remote_upload_file(name: str, fs: Filesystem, url: str, info: dict,
|
|
|
1473
1479
|
if renamed_cols_msg:
|
|
1474
1480
|
logger.info(renamed_cols_msg)
|
|
1475
1481
|
|
|
1482
|
+
try:
|
|
1483
|
+
if isinstance(data_mv, memoryview):
|
|
1484
|
+
data_mv.release()
|
|
1485
|
+
except Exception:
|
|
1486
|
+
pass
|
|
1487
|
+
|
|
1476
1488
|
try:
|
|
1477
1489
|
if buf:
|
|
1478
1490
|
buf.close()
|
|
1479
1491
|
except Exception:
|
|
1480
1492
|
pass
|
|
1481
|
-
del buf, data_mv, header_line
|
|
1482
|
-
gc.collect()
|
|
1483
1493
|
|
|
1484
1494
|
logger.info("Upload complete. Validated table %s.", name)
|
|
1485
1495
|
|
|
@@ -1719,13 +1729,17 @@ def _remote_upload_directory(
|
|
|
1719
1729
|
else:
|
|
1720
1730
|
break
|
|
1721
1731
|
|
|
1732
|
+
try:
|
|
1733
|
+
if isinstance(data_mv, memoryview):
|
|
1734
|
+
data_mv.release()
|
|
1735
|
+
except Exception:
|
|
1736
|
+
pass
|
|
1737
|
+
|
|
1722
1738
|
try:
|
|
1723
1739
|
if buf:
|
|
1724
1740
|
buf.close()
|
|
1725
1741
|
except Exception:
|
|
1726
1742
|
pass
|
|
1727
|
-
del buf, data_mv, header_line
|
|
1728
|
-
gc.collect()
|
|
1729
1743
|
|
|
1730
1744
|
_safe_bar_update(file_bar, 1)
|
|
1731
1745
|
_merge_status_update(fpath)
|
|
@@ -20,6 +20,7 @@ from .sagemaker import (
|
|
|
20
20
|
from .base import Table
|
|
21
21
|
from .backend.local import LocalTable
|
|
22
22
|
from .graph import Graph
|
|
23
|
+
from .task_table import TaskTable
|
|
23
24
|
from .rfm import ExplainConfig, Explanation, KumoRFM
|
|
24
25
|
|
|
25
26
|
logger = logging.getLogger('kumoai_rfm')
|
|
@@ -78,9 +79,9 @@ def _get_snowflake_url(snowflake_application: str) -> str:
|
|
|
78
79
|
snowpark_session = _get_active_session()
|
|
79
80
|
if not snowpark_session:
|
|
80
81
|
raise ValueError(
|
|
81
|
-
"
|
|
82
|
-
"without an active
|
|
83
|
-
"a
|
|
82
|
+
"KumoRFM initialization failed. 'snowflake_application' is "
|
|
83
|
+
"specified without an active Snowpark session. If running outside "
|
|
84
|
+
"a Snowflake notebook, specify a URL and credentials.")
|
|
84
85
|
with snowpark_session.connection.cursor() as cur:
|
|
85
86
|
cur.execute(
|
|
86
87
|
f"DESCRIBE SERVICE {snowflake_application}.user_schema.rfm_service"
|
|
@@ -103,6 +104,9 @@ class RfmGlobalState:
|
|
|
103
104
|
|
|
104
105
|
@property
|
|
105
106
|
def client(self) -> KumoClient:
|
|
107
|
+
if self._backend == InferenceBackend.UNKNOWN:
|
|
108
|
+
raise RuntimeError("KumoRFM is not yet initialized")
|
|
109
|
+
|
|
106
110
|
if self._backend == InferenceBackend.REST:
|
|
107
111
|
return kumoai.global_state.client
|
|
108
112
|
|
|
@@ -146,18 +150,19 @@ def init(
|
|
|
146
150
|
with global_state._lock:
|
|
147
151
|
if global_state._initialized:
|
|
148
152
|
if url != global_state._url:
|
|
149
|
-
raise
|
|
150
|
-
"
|
|
151
|
-
"URL. Re-initialization with a different URL is not "
|
|
153
|
+
raise RuntimeError(
|
|
154
|
+
"KumoRFM has already been initialized with a different "
|
|
155
|
+
"API URL. Re-initialization with a different URL is not "
|
|
152
156
|
"supported.")
|
|
153
157
|
return
|
|
154
158
|
|
|
155
159
|
if snowflake_application:
|
|
156
160
|
if url is not None:
|
|
157
161
|
raise ValueError(
|
|
158
|
-
"
|
|
159
|
-
"url are specified. If
|
|
160
|
-
"specify only
|
|
162
|
+
"KumoRFM initialization failed. Both "
|
|
163
|
+
"'snowflake_application' and 'url' are specified. If "
|
|
164
|
+
"running from a Snowflake notebook, specify only "
|
|
165
|
+
"'snowflake_application'.")
|
|
161
166
|
url = _get_snowflake_url(snowflake_application)
|
|
162
167
|
api_key = "test:DISABLED"
|
|
163
168
|
|
|
@@ -166,32 +171,28 @@ def init(
|
|
|
166
171
|
|
|
167
172
|
backend, region, endpoint_name = _detect_backend(url)
|
|
168
173
|
if backend == InferenceBackend.REST:
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
kumoai.init(url=url, api_key=api_key,
|
|
177
|
-
snowflake_credentials=snowflake_credentials,
|
|
178
|
-
snowflake_application=snowflake_application,
|
|
179
|
-
log_level=log_level)
|
|
174
|
+
kumoai.init(
|
|
175
|
+
url=url,
|
|
176
|
+
api_key=api_key,
|
|
177
|
+
snowflake_credentials=snowflake_credentials,
|
|
178
|
+
snowflake_application=snowflake_application,
|
|
179
|
+
log_level=log_level,
|
|
180
|
+
)
|
|
180
181
|
elif backend == InferenceBackend.AWS_SAGEMAKER:
|
|
181
182
|
assert region
|
|
182
183
|
assert endpoint_name
|
|
183
184
|
KumoClient_SageMakerAdapter(region, endpoint_name).authenticate()
|
|
185
|
+
logger.info("KumoRFM initialized in AWS SageMaker")
|
|
184
186
|
else:
|
|
185
187
|
assert backend == InferenceBackend.LOCAL_SAGEMAKER
|
|
186
188
|
KumoClient_SageMakerProxy_Local(url).authenticate()
|
|
189
|
+
logger.info(f"KumoRFM initialized in local SageMaker at '{url}'")
|
|
187
190
|
|
|
188
191
|
global_state._url = url
|
|
189
192
|
global_state._backend = backend
|
|
190
193
|
global_state._region = region
|
|
191
194
|
global_state._endpoint_name = endpoint_name
|
|
192
195
|
global_state._initialized = True
|
|
193
|
-
logger.info("Kumo RFM initialized with backend: %s, url: %s", backend,
|
|
194
|
-
url)
|
|
195
196
|
|
|
196
197
|
|
|
197
198
|
LocalGraph = Graph # NOTE Backward compatibility - do not use anymore.
|
|
@@ -202,6 +203,7 @@ __all__ = [
|
|
|
202
203
|
'Table',
|
|
203
204
|
'LocalTable',
|
|
204
205
|
'Graph',
|
|
206
|
+
'TaskTable',
|
|
205
207
|
'KumoRFM',
|
|
206
208
|
'ExplainConfig',
|
|
207
209
|
'Explanation',
|
|
@@ -1,12 +1,11 @@
|
|
|
1
|
-
import warnings
|
|
2
1
|
from typing import TYPE_CHECKING
|
|
3
2
|
|
|
4
3
|
import numpy as np
|
|
5
4
|
import pandas as pd
|
|
6
5
|
from kumoapi.rfm.context import Subgraph
|
|
7
|
-
from kumoapi.typing import Stype
|
|
8
6
|
|
|
9
7
|
from kumoai.experimental.rfm.backend.local import LocalTable
|
|
8
|
+
from kumoai.experimental.rfm.base import Table
|
|
10
9
|
from kumoai.utils import ProgressLogger
|
|
11
10
|
|
|
12
11
|
try:
|
|
@@ -106,26 +105,20 @@ class LocalGraphStore:
|
|
|
106
105
|
df_dict: dict[str, pd.DataFrame] = {}
|
|
107
106
|
for table_name, table in graph.tables.items():
|
|
108
107
|
assert isinstance(table, LocalTable)
|
|
109
|
-
|
|
110
|
-
|
|
108
|
+
df_dict[table_name] = Table._sanitize(
|
|
109
|
+
df=table._data.copy(deep=False).reset_index(drop=True),
|
|
110
|
+
dtype_dict={
|
|
111
|
+
column.name: column.dtype
|
|
112
|
+
for column in table.columns
|
|
113
|
+
},
|
|
114
|
+
stype_dict={
|
|
115
|
+
column.name: column.stype
|
|
116
|
+
for column in table.columns
|
|
117
|
+
},
|
|
118
|
+
)
|
|
111
119
|
|
|
112
120
|
mask_dict: dict[str, np.ndarray] = {}
|
|
113
121
|
for table in graph.tables.values():
|
|
114
|
-
for col in table.columns:
|
|
115
|
-
if col.stype == Stype.timestamp:
|
|
116
|
-
ser = df_dict[table.name][col.name]
|
|
117
|
-
if not pd.api.types.is_datetime64_any_dtype(ser):
|
|
118
|
-
with warnings.catch_warnings():
|
|
119
|
-
warnings.filterwarnings(
|
|
120
|
-
'ignore',
|
|
121
|
-
message='Could not infer format',
|
|
122
|
-
)
|
|
123
|
-
ser = pd.to_datetime(ser, errors='coerce')
|
|
124
|
-
df_dict[table.name][col.name] = ser
|
|
125
|
-
if isinstance(ser.dtype, pd.DatetimeTZDtype):
|
|
126
|
-
ser = ser.dt.tz_localize(None)
|
|
127
|
-
df_dict[table.name][col.name] = ser
|
|
128
|
-
|
|
129
122
|
mask: np.ndarray | None = None
|
|
130
123
|
if table._time_column is not None:
|
|
131
124
|
ser = df_dict[table.name][table._time_column]
|
|
@@ -188,8 +181,6 @@ class LocalGraphStore:
|
|
|
188
181
|
continue
|
|
189
182
|
|
|
190
183
|
time = self.df_dict[table.name][table._time_column]
|
|
191
|
-
if time.dtype != 'datetime64[ns]':
|
|
192
|
-
time = time.astype('datetime64[ns]')
|
|
193
184
|
time_dict[table.name] = time.astype(int).to_numpy() // 1000**3
|
|
194
185
|
|
|
195
186
|
if table.name in self.mask_dict.keys():
|
|
@@ -219,9 +219,6 @@ class LocalSampler(Sampler):
|
|
|
219
219
|
for edge_type in set(self.edge_types) - set(time_offset_dict.keys()):
|
|
220
220
|
num_neighbors_dict['__'.join(edge_type)] = [0] * num_hops
|
|
221
221
|
|
|
222
|
-
if anchor_time.dtype != 'datetime64[ns]':
|
|
223
|
-
anchor_time = anchor_time.astype('datetime64')
|
|
224
|
-
|
|
225
222
|
count = 0
|
|
226
223
|
ys: list[pd.Series] = []
|
|
227
224
|
mask = np.full(len(index), False, dtype=bool)
|
|
@@ -1,11 +1,15 @@
|
|
|
1
|
-
import warnings
|
|
2
|
-
from typing import cast
|
|
1
|
+
from typing import Sequence, cast
|
|
3
2
|
|
|
4
3
|
import pandas as pd
|
|
5
4
|
from kumoapi.model_plan import MissingType
|
|
6
5
|
|
|
7
|
-
from kumoai.experimental.rfm.base import
|
|
8
|
-
|
|
6
|
+
from kumoai.experimental.rfm.base import (
|
|
7
|
+
ColumnSpec,
|
|
8
|
+
DataBackend,
|
|
9
|
+
SourceColumn,
|
|
10
|
+
SourceForeignKey,
|
|
11
|
+
Table,
|
|
12
|
+
)
|
|
9
13
|
|
|
10
14
|
|
|
11
15
|
class LocalTable(Table):
|
|
@@ -71,7 +75,6 @@ class LocalTable(Table):
|
|
|
71
75
|
|
|
72
76
|
super().__init__(
|
|
73
77
|
name=name,
|
|
74
|
-
columns=list(df.columns),
|
|
75
78
|
primary_key=primary_key,
|
|
76
79
|
time_column=time_column,
|
|
77
80
|
end_time_column=end_time_column,
|
|
@@ -82,33 +85,29 @@ class LocalTable(Table):
|
|
|
82
85
|
return cast(DataBackend, DataBackend.LOCAL)
|
|
83
86
|
|
|
84
87
|
def _get_source_columns(self) -> list[SourceColumn]:
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
dtype = infer_dtype(ser)
|
|
90
|
-
except Exception:
|
|
91
|
-
warnings.warn(f"Encountered unsupported data type "
|
|
92
|
-
f"'{ser.dtype}' for column '{column}' in table "
|
|
93
|
-
f"'{self.name}'. Please change the data type of "
|
|
94
|
-
f"the column in the `pandas.DataFrame` to use "
|
|
95
|
-
f"it within this table, or remove it to "
|
|
96
|
-
f"suppress this warning.")
|
|
97
|
-
continue
|
|
98
|
-
|
|
99
|
-
source_column = SourceColumn(
|
|
100
|
-
name=column,
|
|
101
|
-
dtype=dtype,
|
|
88
|
+
return [
|
|
89
|
+
SourceColumn(
|
|
90
|
+
name=column_name,
|
|
91
|
+
dtype=None,
|
|
102
92
|
is_primary_key=False,
|
|
103
93
|
is_unique_key=False,
|
|
104
94
|
is_nullable=True,
|
|
105
|
-
)
|
|
106
|
-
|
|
95
|
+
) for column_name in self._data.columns
|
|
96
|
+
]
|
|
107
97
|
|
|
108
|
-
|
|
98
|
+
def _get_source_foreign_keys(self) -> list[SourceForeignKey]:
|
|
99
|
+
return []
|
|
109
100
|
|
|
110
101
|
def _get_source_sample_df(self) -> pd.DataFrame:
|
|
111
102
|
return self._data
|
|
112
103
|
|
|
104
|
+
def _get_expr_sample_df(
|
|
105
|
+
self,
|
|
106
|
+
columns: Sequence[ColumnSpec],
|
|
107
|
+
) -> pd.DataFrame:
|
|
108
|
+
raise RuntimeError(f"Column expressions are not supported in "
|
|
109
|
+
f"'{self.__class__.__name__}'. Please apply your "
|
|
110
|
+
f"expressions on the `pd.DataFrame` directly.")
|
|
111
|
+
|
|
113
112
|
def _get_num_rows(self) -> int | None:
|
|
114
113
|
return len(self._data)
|