kumoai 2.14.0.dev202512141732__py3-none-any.whl → 2.15.0.dev202601131732__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +23 -26
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +26 -0
- kumoai/connector/utils.py +21 -7
- kumoai/experimental/rfm/__init__.py +51 -24
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/local/graph_store.py +37 -46
- kumoai/experimental/rfm/backend/local/sampler.py +4 -5
- kumoai/experimental/rfm/backend/local/table.py +24 -30
- kumoai/experimental/rfm/backend/snow/sampler.py +331 -43
- kumoai/experimental/rfm/backend/snow/table.py +166 -56
- kumoai/experimental/rfm/backend/sqlite/__init__.py +2 -2
- kumoai/experimental/rfm/backend/sqlite/sampler.py +372 -30
- kumoai/experimental/rfm/backend/sqlite/table.py +117 -48
- kumoai/experimental/rfm/base/__init__.py +8 -1
- kumoai/experimental/rfm/base/column.py +96 -10
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/mapper.py +69 -0
- kumoai/experimental/rfm/base/sampler.py +28 -18
- kumoai/experimental/rfm/base/source.py +1 -1
- kumoai/experimental/rfm/base/sql_sampler.py +385 -0
- kumoai/experimental/rfm/base/table.py +374 -208
- kumoai/experimental/rfm/base/utils.py +36 -0
- kumoai/experimental/rfm/graph.py +335 -180
- kumoai/experimental/rfm/infer/__init__.py +6 -4
- kumoai/experimental/rfm/infer/dtype.py +10 -5
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/pkey.py +4 -2
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +5 -4
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +29 -31
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +606 -361
- kumoai/experimental/rfm/sagemaker.py +4 -4
- kumoai/experimental/rfm/task_table.py +292 -0
- kumoai/pquery/training_table.py +16 -2
- kumoai/testing/snow.py +3 -3
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/__init__.py +1 -2
- kumoai/utils/display.py +87 -0
- kumoai/utils/progress_logger.py +192 -13
- kumoai/utils/sql.py +2 -2
- {kumoai-2.14.0.dev202512141732.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/METADATA +3 -2
- {kumoai-2.14.0.dev202512141732.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/RECORD +49 -40
- {kumoai-2.14.0.dev202512141732.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/WHEEL +0 -0
- {kumoai-2.14.0.dev202512141732.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.14.0.dev202512141732.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/top_level.txt +0 -0
kumoai/__init__.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import warnings
|
|
1
2
|
import os
|
|
2
3
|
import sys
|
|
3
4
|
import threading
|
|
@@ -68,9 +69,8 @@ class GlobalState(metaclass=Singleton):
|
|
|
68
69
|
if self._url is None or (self._api_key is None
|
|
69
70
|
and self._spcs_token is None
|
|
70
71
|
and self._snowpark_session is None):
|
|
71
|
-
raise ValueError(
|
|
72
|
-
|
|
73
|
-
"your client before proceeding.")
|
|
72
|
+
raise ValueError("Client creation or authentication failed. "
|
|
73
|
+
"Please re-create your client before proceeding.")
|
|
74
74
|
|
|
75
75
|
if hasattr(self.thread_local, '_client'):
|
|
76
76
|
# Set the spcs token in the client to ensure it has the latest.
|
|
@@ -123,10 +123,9 @@ def init(
|
|
|
123
123
|
""" # noqa
|
|
124
124
|
# Avoid mutations to the global state after it is set:
|
|
125
125
|
if global_state.initialized:
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
"session.")
|
|
126
|
+
warnings.warn("Kumo SDK already initialized. To re-initialize the "
|
|
127
|
+
"SDK, please start a new interpreter. No changes will "
|
|
128
|
+
"be made to the current session.")
|
|
130
129
|
return
|
|
131
130
|
|
|
132
131
|
set_log_level(os.getenv(_ENV_KUMO_LOG, log_level))
|
|
@@ -138,15 +137,15 @@ def init(
|
|
|
138
137
|
if snowflake_application:
|
|
139
138
|
if url is not None:
|
|
140
139
|
raise ValueError(
|
|
141
|
-
"
|
|
142
|
-
"are specified. If running from a
|
|
143
|
-
"only snowflake_application.")
|
|
140
|
+
"Kumo SDK initialization failed. Both 'snowflake_application' "
|
|
141
|
+
"and 'url' are specified. If running from a Snowflake "
|
|
142
|
+
"notebook, specify only 'snowflake_application'.")
|
|
144
143
|
snowpark_session = _get_active_session()
|
|
145
144
|
if not snowpark_session:
|
|
146
145
|
raise ValueError(
|
|
147
|
-
"
|
|
148
|
-
"without an active
|
|
149
|
-
"a
|
|
146
|
+
"Kumo SDK initialization failed. 'snowflake_application' is "
|
|
147
|
+
"specified without an active Snowpark session. If running "
|
|
148
|
+
"outside a Snowflake notebook, specify a URL and credentials.")
|
|
150
149
|
description = snowpark_session.sql(
|
|
151
150
|
f"DESCRIBE SERVICE {snowflake_application}."
|
|
152
151
|
"USER_SCHEMA.KUMO_SERVICE").collect()[0]
|
|
@@ -155,14 +154,14 @@ def init(
|
|
|
155
154
|
if api_key is None and not snowflake_application:
|
|
156
155
|
if snowflake_credentials is None:
|
|
157
156
|
raise ValueError(
|
|
158
|
-
"
|
|
159
|
-
"credentials provided. Please either set the
|
|
160
|
-
"or explicitly call `kumoai.init(...)`.")
|
|
157
|
+
"Kumo SDK initialization failed. Neither an API key nor "
|
|
158
|
+
"Snowflake credentials provided. Please either set the "
|
|
159
|
+
"'KUMO_API_KEY' or explicitly call `kumoai.init(...)`.")
|
|
161
160
|
if (set(snowflake_credentials.keys())
|
|
162
161
|
!= {'user', 'password', 'account'}):
|
|
163
162
|
raise ValueError(
|
|
164
|
-
f"Provided credentials should be a dictionary with
|
|
165
|
-
f"'user', 'password', and 'account'. Only "
|
|
163
|
+
f"Provided Snowflake credentials should be a dictionary with "
|
|
164
|
+
f"keys 'user', 'password', and 'account'. Only "
|
|
166
165
|
f"{set(snowflake_credentials.keys())} were provided.")
|
|
167
166
|
|
|
168
167
|
# Get or infer URL:
|
|
@@ -173,10 +172,10 @@ def init(
|
|
|
173
172
|
except KeyError:
|
|
174
173
|
pass
|
|
175
174
|
if url is None:
|
|
176
|
-
raise ValueError(
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
175
|
+
raise ValueError("Kumo SDK initialization failed since no endpoint "
|
|
176
|
+
"URL was provided. Please either set the "
|
|
177
|
+
"'KUMO_API_ENDPOINT' environment variable or "
|
|
178
|
+
"explicitly call `kumoai.init(...)`.")
|
|
180
179
|
|
|
181
180
|
# Assign global state after verification that client can be created and
|
|
182
181
|
# authenticated successfully:
|
|
@@ -198,10 +197,8 @@ def init(
|
|
|
198
197
|
logger = logging.getLogger('kumoai')
|
|
199
198
|
log_level = logging.getLevelName(logger.getEffectiveLevel())
|
|
200
199
|
|
|
201
|
-
logger.info(
|
|
202
|
-
|
|
203
|
-
f"against deployment {url}, with "
|
|
204
|
-
f"log level {log_level}.")
|
|
200
|
+
logger.info(f"Initialized Kumo SDK v{__version__} against deployment "
|
|
201
|
+
f"'{url}'")
|
|
205
202
|
|
|
206
203
|
|
|
207
204
|
def set_log_level(level: str) -> None:
|
kumoai/_version.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = '2.
|
|
1
|
+
__version__ = '2.15.0.dev202601131732'
|
kumoai/client/client.py
CHANGED
|
@@ -13,6 +13,7 @@ if TYPE_CHECKING:
|
|
|
13
13
|
ArtifactExportJobAPI,
|
|
14
14
|
BaselineJobAPI,
|
|
15
15
|
BatchPredictionJobAPI,
|
|
16
|
+
DistillationJobAPI,
|
|
16
17
|
GeneratePredictionTableJobAPI,
|
|
17
18
|
GenerateTrainTableJobAPI,
|
|
18
19
|
LLMJobAPI,
|
|
@@ -132,6 +133,11 @@ class KumoClient:
|
|
|
132
133
|
from kumoai.client.jobs import TrainingJobAPI
|
|
133
134
|
return TrainingJobAPI(self)
|
|
134
135
|
|
|
136
|
+
@property
|
|
137
|
+
def distillation_job_api(self) -> 'DistillationJobAPI':
|
|
138
|
+
from kumoai.client.jobs import DistillationJobAPI
|
|
139
|
+
return DistillationJobAPI(self)
|
|
140
|
+
|
|
135
141
|
@property
|
|
136
142
|
def batch_prediction_job_api(self) -> 'BatchPredictionJobAPI':
|
|
137
143
|
from kumoai.client.jobs import BatchPredictionJobAPI
|
kumoai/client/jobs.py
CHANGED
|
@@ -22,6 +22,8 @@ from kumoapi.jobs import (
|
|
|
22
22
|
BatchPredictionRequest,
|
|
23
23
|
CancelBatchPredictionJobResponse,
|
|
24
24
|
CancelTrainingJobResponse,
|
|
25
|
+
DistillationJobRequest,
|
|
26
|
+
DistillationJobResource,
|
|
25
27
|
ErrorDetails,
|
|
26
28
|
GeneratePredictionTableJobResource,
|
|
27
29
|
GeneratePredictionTableRequest,
|
|
@@ -171,6 +173,28 @@ class TrainingJobAPI(CommonJobAPI[TrainingJobRequest, TrainingJobResource]):
|
|
|
171
173
|
return resource.config
|
|
172
174
|
|
|
173
175
|
|
|
176
|
+
class DistillationJobAPI(CommonJobAPI[DistillationJobRequest,
|
|
177
|
+
DistillationJobResource]):
|
|
178
|
+
r"""Typed API definition for the distillation job resource."""
|
|
179
|
+
def __init__(self, client: KumoClient) -> None:
|
|
180
|
+
super().__init__(client, '/training_jobs/distilled_training_job',
|
|
181
|
+
DistillationJobResource)
|
|
182
|
+
|
|
183
|
+
def get_config(self, job_id: str) -> DistillationJobRequest:
|
|
184
|
+
raise NotImplementedError(
|
|
185
|
+
"Getting the configuration for a distillation job is "
|
|
186
|
+
"not implemented yet.")
|
|
187
|
+
|
|
188
|
+
def get_progress(self, id: str) -> AutoTrainerProgress:
|
|
189
|
+
raise NotImplementedError(
|
|
190
|
+
"Getting the progress for a distillation job is not "
|
|
191
|
+
"implemented yet.")
|
|
192
|
+
|
|
193
|
+
def cancel(self, id: str) -> CancelTrainingJobResponse:
|
|
194
|
+
raise NotImplementedError(
|
|
195
|
+
"Cancelling a distillation job is not implemented yet.")
|
|
196
|
+
|
|
197
|
+
|
|
174
198
|
class BatchPredictionJobAPI(CommonJobAPI[BatchPredictionRequest,
|
|
175
199
|
BatchPredictionJobResource]):
|
|
176
200
|
r"""Typed API definition for the prediction job resource."""
|
|
@@ -320,12 +344,14 @@ class GenerateTrainTableJobAPI(CommonJobAPI[GenerateTrainTableRequest,
|
|
|
320
344
|
id: str,
|
|
321
345
|
source_table_type: SourceTableType,
|
|
322
346
|
train_table_mod: TrainingTableSpec,
|
|
347
|
+
extensive_validation: bool,
|
|
323
348
|
) -> ValidationResponse:
|
|
324
349
|
response = self._client._post(
|
|
325
350
|
f'{self._base_endpoint}/{id}/validate_custom_train_table',
|
|
326
351
|
json=to_json_dict({
|
|
327
352
|
'custom_table': source_table_type,
|
|
328
353
|
'train_table_mod': train_table_mod,
|
|
354
|
+
'extensive_validation': extensive_validation,
|
|
329
355
|
}),
|
|
330
356
|
)
|
|
331
357
|
return parse_response(ValidationResponse, response)
|
kumoai/connector/utils.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
import csv
|
|
3
|
-
import gc
|
|
4
3
|
import io
|
|
5
4
|
import math
|
|
6
5
|
import os
|
|
7
6
|
import re
|
|
7
|
+
import sys
|
|
8
8
|
import tempfile
|
|
9
9
|
import threading
|
|
10
10
|
import time
|
|
@@ -920,7 +920,10 @@ def _read_remote_file_with_progress(
|
|
|
920
920
|
if capture_first_line and not seen_nl:
|
|
921
921
|
header_line = bytes(header_acc)
|
|
922
922
|
|
|
923
|
-
|
|
923
|
+
if sys.version_info >= (3, 13):
|
|
924
|
+
mv = memoryview(buf.getvalue())
|
|
925
|
+
else:
|
|
926
|
+
mv = buf.getbuffer() # zero-copy view of BytesIO internal buffer
|
|
924
927
|
return buf, mv, header_line
|
|
925
928
|
|
|
926
929
|
|
|
@@ -999,7 +1002,10 @@ def _iter_mv_chunks(mv: memoryview,
|
|
|
999
1002
|
n = mv.nbytes
|
|
1000
1003
|
while pos < n:
|
|
1001
1004
|
nxt = min(n, pos + part_size)
|
|
1002
|
-
|
|
1005
|
+
if sys.version_info >= (3, 13):
|
|
1006
|
+
yield mv[pos:nxt].tobytes()
|
|
1007
|
+
else:
|
|
1008
|
+
yield mv[pos:nxt] # zero-copy slice
|
|
1003
1009
|
pos = nxt
|
|
1004
1010
|
|
|
1005
1011
|
|
|
@@ -1473,13 +1479,17 @@ def _remote_upload_file(name: str, fs: Filesystem, url: str, info: dict,
|
|
|
1473
1479
|
if renamed_cols_msg:
|
|
1474
1480
|
logger.info(renamed_cols_msg)
|
|
1475
1481
|
|
|
1482
|
+
try:
|
|
1483
|
+
if isinstance(data_mv, memoryview):
|
|
1484
|
+
data_mv.release()
|
|
1485
|
+
except Exception:
|
|
1486
|
+
pass
|
|
1487
|
+
|
|
1476
1488
|
try:
|
|
1477
1489
|
if buf:
|
|
1478
1490
|
buf.close()
|
|
1479
1491
|
except Exception:
|
|
1480
1492
|
pass
|
|
1481
|
-
del buf, data_mv, header_line
|
|
1482
|
-
gc.collect()
|
|
1483
1493
|
|
|
1484
1494
|
logger.info("Upload complete. Validated table %s.", name)
|
|
1485
1495
|
|
|
@@ -1719,13 +1729,17 @@ def _remote_upload_directory(
|
|
|
1719
1729
|
else:
|
|
1720
1730
|
break
|
|
1721
1731
|
|
|
1732
|
+
try:
|
|
1733
|
+
if isinstance(data_mv, memoryview):
|
|
1734
|
+
data_mv.release()
|
|
1735
|
+
except Exception:
|
|
1736
|
+
pass
|
|
1737
|
+
|
|
1722
1738
|
try:
|
|
1723
1739
|
if buf:
|
|
1724
1740
|
buf.close()
|
|
1725
1741
|
except Exception:
|
|
1726
1742
|
pass
|
|
1727
|
-
del buf, data_mv, header_line
|
|
1728
|
-
gc.collect()
|
|
1729
1743
|
|
|
1730
1744
|
_safe_bar_update(file_bar, 1)
|
|
1731
1745
|
_merge_status_update(fpath)
|
|
@@ -6,11 +6,11 @@ import socket
|
|
|
6
6
|
import threading
|
|
7
7
|
from dataclasses import dataclass
|
|
8
8
|
from enum import Enum
|
|
9
|
-
from typing import Dict, Optional, Tuple
|
|
10
9
|
from urllib.parse import urlparse
|
|
11
10
|
|
|
12
11
|
import kumoai
|
|
13
12
|
from kumoai.client.client import KumoClient
|
|
13
|
+
from kumoai.spcs import _get_active_session
|
|
14
14
|
|
|
15
15
|
from .authenticate import authenticate
|
|
16
16
|
from .sagemaker import (
|
|
@@ -20,6 +20,7 @@ from .sagemaker import (
|
|
|
20
20
|
from .base import Table
|
|
21
21
|
from .backend.local import LocalTable
|
|
22
22
|
from .graph import Graph
|
|
23
|
+
from .task_table import TaskTable
|
|
23
24
|
from .rfm import ExplainConfig, Explanation, KumoRFM
|
|
24
25
|
|
|
25
26
|
logger = logging.getLogger('kumoai_rfm')
|
|
@@ -49,7 +50,8 @@ class InferenceBackend(str, Enum):
|
|
|
49
50
|
|
|
50
51
|
|
|
51
52
|
def _detect_backend(
|
|
52
|
-
url: str
|
|
53
|
+
url: str, #
|
|
54
|
+
) -> tuple[InferenceBackend, str | None, str | None]:
|
|
53
55
|
parsed = urlparse(url)
|
|
54
56
|
|
|
55
57
|
# Remote SageMaker
|
|
@@ -73,12 +75,27 @@ def _detect_backend(
|
|
|
73
75
|
return InferenceBackend.REST, None, None
|
|
74
76
|
|
|
75
77
|
|
|
78
|
+
def _get_snowflake_url(snowflake_application: str) -> str:
|
|
79
|
+
snowpark_session = _get_active_session()
|
|
80
|
+
if not snowpark_session:
|
|
81
|
+
raise ValueError(
|
|
82
|
+
"KumoRFM initialization failed. 'snowflake_application' is "
|
|
83
|
+
"specified without an active Snowpark session. If running outside "
|
|
84
|
+
"a Snowflake notebook, specify a URL and credentials.")
|
|
85
|
+
with snowpark_session.connection.cursor() as cur:
|
|
86
|
+
cur.execute(
|
|
87
|
+
f"DESCRIBE SERVICE {snowflake_application}.user_schema.rfm_service"
|
|
88
|
+
f" ->> SELECT \"dns_name\" from $1")
|
|
89
|
+
dns_name: str = cur.fetchone()[0]
|
|
90
|
+
return f"http://{dns_name}:8000/api"
|
|
91
|
+
|
|
92
|
+
|
|
76
93
|
@dataclass
|
|
77
94
|
class RfmGlobalState:
|
|
78
95
|
_url: str = '__url_not_provided__'
|
|
79
96
|
_backend: InferenceBackend = InferenceBackend.UNKNOWN
|
|
80
|
-
_region:
|
|
81
|
-
_endpoint_name:
|
|
97
|
+
_region: str | None = None
|
|
98
|
+
_endpoint_name: str | None = None
|
|
82
99
|
_thread_local = threading.local()
|
|
83
100
|
|
|
84
101
|
# Thread-safe init-once.
|
|
@@ -87,6 +104,9 @@ class RfmGlobalState:
|
|
|
87
104
|
|
|
88
105
|
@property
|
|
89
106
|
def client(self) -> KumoClient:
|
|
107
|
+
if self._backend == InferenceBackend.UNKNOWN:
|
|
108
|
+
raise RuntimeError("KumoRFM is not yet initialized")
|
|
109
|
+
|
|
90
110
|
if self._backend == InferenceBackend.REST:
|
|
91
111
|
return kumoai.global_state.client
|
|
92
112
|
|
|
@@ -121,52 +141,58 @@ global_state = RfmGlobalState()
|
|
|
121
141
|
|
|
122
142
|
|
|
123
143
|
def init(
|
|
124
|
-
url:
|
|
125
|
-
api_key:
|
|
126
|
-
snowflake_credentials:
|
|
127
|
-
snowflake_application:
|
|
144
|
+
url: str | None = None,
|
|
145
|
+
api_key: str | None = None,
|
|
146
|
+
snowflake_credentials: dict[str, str] | None = None,
|
|
147
|
+
snowflake_application: str | None = None,
|
|
128
148
|
log_level: str = "INFO",
|
|
129
149
|
) -> None:
|
|
130
150
|
with global_state._lock:
|
|
131
151
|
if global_state._initialized:
|
|
132
152
|
if url != global_state._url:
|
|
133
|
-
raise
|
|
134
|
-
"
|
|
135
|
-
"URL. Re-initialization with a different URL is not "
|
|
153
|
+
raise RuntimeError(
|
|
154
|
+
"KumoRFM has already been initialized with a different "
|
|
155
|
+
"API URL. Re-initialization with a different URL is not "
|
|
136
156
|
"supported.")
|
|
137
157
|
return
|
|
138
158
|
|
|
159
|
+
if snowflake_application:
|
|
160
|
+
if url is not None:
|
|
161
|
+
raise ValueError(
|
|
162
|
+
"KumoRFM initialization failed. Both "
|
|
163
|
+
"'snowflake_application' and 'url' are specified. If "
|
|
164
|
+
"running from a Snowflake notebook, specify only "
|
|
165
|
+
"'snowflake_application'.")
|
|
166
|
+
url = _get_snowflake_url(snowflake_application)
|
|
167
|
+
api_key = "test:DISABLED"
|
|
168
|
+
|
|
139
169
|
if url is None:
|
|
140
170
|
url = os.getenv("RFM_API_URL", "https://kumorfm.ai/api")
|
|
141
171
|
|
|
142
172
|
backend, region, endpoint_name = _detect_backend(url)
|
|
143
173
|
if backend == InferenceBackend.REST:
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
kumoai.init(url=url, api_key=api_key,
|
|
152
|
-
snowflake_credentials=snowflake_credentials,
|
|
153
|
-
snowflake_application=snowflake_application,
|
|
154
|
-
log_level=log_level)
|
|
174
|
+
kumoai.init(
|
|
175
|
+
url=url,
|
|
176
|
+
api_key=api_key,
|
|
177
|
+
snowflake_credentials=snowflake_credentials,
|
|
178
|
+
snowflake_application=snowflake_application,
|
|
179
|
+
log_level=log_level,
|
|
180
|
+
)
|
|
155
181
|
elif backend == InferenceBackend.AWS_SAGEMAKER:
|
|
156
182
|
assert region
|
|
157
183
|
assert endpoint_name
|
|
158
184
|
KumoClient_SageMakerAdapter(region, endpoint_name).authenticate()
|
|
185
|
+
logger.info("KumoRFM initialized in AWS SageMaker")
|
|
159
186
|
else:
|
|
160
187
|
assert backend == InferenceBackend.LOCAL_SAGEMAKER
|
|
161
188
|
KumoClient_SageMakerProxy_Local(url).authenticate()
|
|
189
|
+
logger.info(f"KumoRFM initialized in local SageMaker at '{url}'")
|
|
162
190
|
|
|
163
191
|
global_state._url = url
|
|
164
192
|
global_state._backend = backend
|
|
165
193
|
global_state._region = region
|
|
166
194
|
global_state._endpoint_name = endpoint_name
|
|
167
195
|
global_state._initialized = True
|
|
168
|
-
logger.info("Kumo RFM initialized with backend: %s, url: %s", backend,
|
|
169
|
-
url)
|
|
170
196
|
|
|
171
197
|
|
|
172
198
|
LocalGraph = Graph # NOTE Backward compatibility - do not use anymore.
|
|
@@ -177,6 +203,7 @@ __all__ = [
|
|
|
177
203
|
'Table',
|
|
178
204
|
'LocalTable',
|
|
179
205
|
'Graph',
|
|
206
|
+
'TaskTable',
|
|
180
207
|
'KumoRFM',
|
|
181
208
|
'ExplainConfig',
|
|
182
209
|
'Explanation',
|
|
@@ -2,12 +2,11 @@ import logging
|
|
|
2
2
|
import os
|
|
3
3
|
import platform
|
|
4
4
|
from datetime import datetime
|
|
5
|
-
from typing import Optional
|
|
6
5
|
|
|
7
6
|
from kumoai import in_notebook
|
|
8
7
|
|
|
9
8
|
|
|
10
|
-
def authenticate(api_url:
|
|
9
|
+
def authenticate(api_url: str | None = None) -> None:
|
|
11
10
|
"""Authenticates the user and sets the Kumo API key for the SDK.
|
|
12
11
|
|
|
13
12
|
This function detects the current environment and launches the appropriate
|
|
@@ -65,11 +64,11 @@ def _authenticate_local(api_url: str, redirect_port: int = 8765) -> None:
|
|
|
65
64
|
import webbrowser
|
|
66
65
|
from getpass import getpass
|
|
67
66
|
from socketserver import TCPServer
|
|
68
|
-
from typing import Any
|
|
67
|
+
from typing import Any
|
|
69
68
|
|
|
70
69
|
logger = logging.getLogger('kumoai')
|
|
71
70
|
|
|
72
|
-
token_status:
|
|
71
|
+
token_status: dict[str, Any] = {
|
|
73
72
|
'token': None,
|
|
74
73
|
'token_name': None,
|
|
75
74
|
'failed': False
|
|
@@ -1,13 +1,12 @@
|
|
|
1
|
-
import
|
|
2
|
-
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
|
1
|
+
from typing import TYPE_CHECKING
|
|
3
2
|
|
|
4
3
|
import numpy as np
|
|
5
4
|
import pandas as pd
|
|
6
5
|
from kumoapi.rfm.context import Subgraph
|
|
7
|
-
from kumoapi.typing import Stype
|
|
8
6
|
|
|
9
7
|
from kumoai.experimental.rfm.backend.local import LocalTable
|
|
10
|
-
from kumoai.
|
|
8
|
+
from kumoai.experimental.rfm.base import Table
|
|
9
|
+
from kumoai.utils import ProgressLogger
|
|
11
10
|
|
|
12
11
|
try:
|
|
13
12
|
import torch
|
|
@@ -23,12 +22,12 @@ class LocalGraphStore:
|
|
|
23
22
|
def __init__(
|
|
24
23
|
self,
|
|
25
24
|
graph: 'Graph',
|
|
26
|
-
verbose:
|
|
25
|
+
verbose: bool | ProgressLogger = True,
|
|
27
26
|
) -> None:
|
|
28
27
|
|
|
29
28
|
if not isinstance(verbose, ProgressLogger):
|
|
30
|
-
verbose =
|
|
31
|
-
"Materializing graph",
|
|
29
|
+
verbose = ProgressLogger.default(
|
|
30
|
+
msg="Materializing graph",
|
|
32
31
|
verbose=verbose,
|
|
33
32
|
)
|
|
34
33
|
|
|
@@ -94,7 +93,7 @@ class LocalGraphStore:
|
|
|
94
93
|
def sanitize(
|
|
95
94
|
self,
|
|
96
95
|
graph: 'Graph',
|
|
97
|
-
) ->
|
|
96
|
+
) -> tuple[dict[str, pd.DataFrame], dict[str, np.ndarray]]:
|
|
98
97
|
r"""Sanitizes raw data according to table schema definition:
|
|
99
98
|
|
|
100
99
|
In particular, it:
|
|
@@ -103,30 +102,24 @@ class LocalGraphStore:
|
|
|
103
102
|
* drops duplicate primary keys
|
|
104
103
|
* removes rows with missing primary keys or time values
|
|
105
104
|
"""
|
|
106
|
-
df_dict:
|
|
105
|
+
df_dict: dict[str, pd.DataFrame] = {}
|
|
107
106
|
for table_name, table in graph.tables.items():
|
|
108
107
|
assert isinstance(table, LocalTable)
|
|
109
|
-
|
|
110
|
-
|
|
108
|
+
df_dict[table_name] = Table._sanitize(
|
|
109
|
+
df=table._data.copy(deep=False).reset_index(drop=True),
|
|
110
|
+
dtype_dict={
|
|
111
|
+
column.name: column.dtype
|
|
112
|
+
for column in table.columns
|
|
113
|
+
},
|
|
114
|
+
stype_dict={
|
|
115
|
+
column.name: column.stype
|
|
116
|
+
for column in table.columns
|
|
117
|
+
},
|
|
118
|
+
)
|
|
111
119
|
|
|
112
|
-
mask_dict:
|
|
120
|
+
mask_dict: dict[str, np.ndarray] = {}
|
|
113
121
|
for table in graph.tables.values():
|
|
114
|
-
|
|
115
|
-
if col.stype == Stype.timestamp:
|
|
116
|
-
ser = df_dict[table.name][col.name]
|
|
117
|
-
if not pd.api.types.is_datetime64_any_dtype(ser):
|
|
118
|
-
with warnings.catch_warnings():
|
|
119
|
-
warnings.filterwarnings(
|
|
120
|
-
'ignore',
|
|
121
|
-
message='Could not infer format',
|
|
122
|
-
)
|
|
123
|
-
ser = pd.to_datetime(ser, errors='coerce')
|
|
124
|
-
df_dict[table.name][col.name] = ser
|
|
125
|
-
if isinstance(ser.dtype, pd.DatetimeTZDtype):
|
|
126
|
-
ser = ser.dt.tz_localize(None)
|
|
127
|
-
df_dict[table.name][col.name] = ser
|
|
128
|
-
|
|
129
|
-
mask: Optional[np.ndarray] = None
|
|
122
|
+
mask: np.ndarray | None = None
|
|
130
123
|
if table._time_column is not None:
|
|
131
124
|
ser = df_dict[table.name][table._time_column]
|
|
132
125
|
mask = ser.notna().to_numpy()
|
|
@@ -144,8 +137,8 @@ class LocalGraphStore:
|
|
|
144
137
|
def get_pkey_map_dict(
|
|
145
138
|
self,
|
|
146
139
|
graph: 'Graph',
|
|
147
|
-
) ->
|
|
148
|
-
pkey_map_dict:
|
|
140
|
+
) -> dict[str, pd.DataFrame]:
|
|
141
|
+
pkey_map_dict: dict[str, pd.DataFrame] = {}
|
|
149
142
|
|
|
150
143
|
for table in graph.tables.values():
|
|
151
144
|
if table._primary_key is None:
|
|
@@ -177,19 +170,17 @@ class LocalGraphStore:
|
|
|
177
170
|
def get_time_data(
|
|
178
171
|
self,
|
|
179
172
|
graph: 'Graph',
|
|
180
|
-
) ->
|
|
181
|
-
|
|
182
|
-
|
|
173
|
+
) -> tuple[
|
|
174
|
+
dict[str, np.ndarray],
|
|
175
|
+
dict[str, tuple[pd.Timestamp, pd.Timestamp]],
|
|
183
176
|
]:
|
|
184
|
-
time_dict:
|
|
185
|
-
min_max_time_dict:
|
|
177
|
+
time_dict: dict[str, np.ndarray] = {}
|
|
178
|
+
min_max_time_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
|
|
186
179
|
for table in graph.tables.values():
|
|
187
180
|
if table._time_column is None:
|
|
188
181
|
continue
|
|
189
182
|
|
|
190
183
|
time = self.df_dict[table.name][table._time_column]
|
|
191
|
-
if time.dtype != 'datetime64[ns]':
|
|
192
|
-
time = time.astype('datetime64[ns]')
|
|
193
184
|
time_dict[table.name] = time.astype(int).to_numpy() // 1000**3
|
|
194
185
|
|
|
195
186
|
if table.name in self.mask_dict.keys():
|
|
@@ -207,15 +198,15 @@ class LocalGraphStore:
|
|
|
207
198
|
def get_csc(
|
|
208
199
|
self,
|
|
209
200
|
graph: 'Graph',
|
|
210
|
-
) ->
|
|
211
|
-
|
|
212
|
-
|
|
201
|
+
) -> tuple[
|
|
202
|
+
dict[tuple[str, str, str], np.ndarray],
|
|
203
|
+
dict[tuple[str, str, str], np.ndarray],
|
|
213
204
|
]:
|
|
214
205
|
# A mapping from raw primary keys to node indices (0 to N-1):
|
|
215
|
-
map_dict:
|
|
206
|
+
map_dict: dict[str, pd.CategoricalDtype] = {}
|
|
216
207
|
# A dictionary to manage offsets of node indices for invalid rows:
|
|
217
|
-
offset_dict:
|
|
218
|
-
for table_name in
|
|
208
|
+
offset_dict: dict[str, np.ndarray] = {}
|
|
209
|
+
for table_name in {edge.dst_table for edge in graph.edges}:
|
|
219
210
|
ser = self.df_dict[table_name][graph[table_name]._primary_key]
|
|
220
211
|
if table_name in self.mask_dict.keys():
|
|
221
212
|
mask = self.mask_dict[table_name]
|
|
@@ -224,8 +215,8 @@ class LocalGraphStore:
|
|
|
224
215
|
map_dict[table_name] = pd.CategoricalDtype(ser, ordered=True)
|
|
225
216
|
|
|
226
217
|
# Build CSC graph representation:
|
|
227
|
-
row_dict:
|
|
228
|
-
colptr_dict:
|
|
218
|
+
row_dict: dict[tuple[str, str, str], np.ndarray] = {}
|
|
219
|
+
colptr_dict: dict[tuple[str, str, str], np.ndarray] = {}
|
|
229
220
|
for src_table, fkey, dst_table in graph.edges:
|
|
230
221
|
src_df = self.df_dict[src_table]
|
|
231
222
|
dst_df = self.df_dict[dst_table]
|
|
@@ -287,7 +278,7 @@ def _argsort(input: np.ndarray) -> np.ndarray:
|
|
|
287
278
|
return torch.from_numpy(input).argsort().numpy()
|
|
288
279
|
|
|
289
280
|
|
|
290
|
-
def _lexsort(inputs:
|
|
281
|
+
def _lexsort(inputs: list[np.ndarray]) -> np.ndarray:
|
|
291
282
|
assert len(inputs) >= 1
|
|
292
283
|
|
|
293
284
|
if not WITH_TORCH:
|
|
@@ -19,7 +19,7 @@ class LocalSampler(Sampler):
|
|
|
19
19
|
graph: 'Graph',
|
|
20
20
|
verbose: bool | ProgressLogger = True,
|
|
21
21
|
) -> None:
|
|
22
|
-
super().__init__(graph=graph)
|
|
22
|
+
super().__init__(graph=graph, verbose=verbose)
|
|
23
23
|
|
|
24
24
|
import kumoai.kumolib as kumolib
|
|
25
25
|
|
|
@@ -191,6 +191,8 @@ class LocalSampler(Sampler):
|
|
|
191
191
|
|
|
192
192
|
return train_y, train_mask, test_y, test_mask
|
|
193
193
|
|
|
194
|
+
# Helper Methods ##########################################################
|
|
195
|
+
|
|
194
196
|
def _sample_target_set(
|
|
195
197
|
self,
|
|
196
198
|
query: ValidatedPredictiveQuery,
|
|
@@ -217,9 +219,6 @@ class LocalSampler(Sampler):
|
|
|
217
219
|
for edge_type in set(self.edge_types) - set(time_offset_dict.keys()):
|
|
218
220
|
num_neighbors_dict['__'.join(edge_type)] = [0] * num_hops
|
|
219
221
|
|
|
220
|
-
if anchor_time.dtype != 'datetime64[ns]':
|
|
221
|
-
anchor_time = anchor_time.astype('datetime64')
|
|
222
|
-
|
|
223
222
|
count = 0
|
|
224
223
|
ys: list[pd.Series] = []
|
|
225
224
|
mask = np.full(len(index), False, dtype=bool)
|
|
@@ -272,7 +271,7 @@ class LocalSampler(Sampler):
|
|
|
272
271
|
return y, mask
|
|
273
272
|
|
|
274
273
|
|
|
275
|
-
# Helper
|
|
274
|
+
# Helper Functions ############################################################
|
|
276
275
|
|
|
277
276
|
|
|
278
277
|
def date_offset_to_seconds(offset: pd.DateOffset) -> int:
|