kumoai 2.12.1__py3-none-any.whl → 2.14.0.dev202512141732__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +18 -9
- kumoai/_version.py +1 -1
- kumoai/client/client.py +9 -13
- kumoai/client/pquery.py +6 -2
- kumoai/connector/utils.py +23 -2
- kumoai/experimental/rfm/__init__.py +162 -46
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +37 -90
- kumoai/experimental/rfm/backend/local/sampler.py +313 -0
- kumoai/experimental/rfm/backend/local/table.py +119 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +119 -0
- kumoai/experimental/rfm/backend/snow/table.py +135 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
- kumoai/experimental/rfm/backend/sqlite/sampler.py +112 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +115 -0
- kumoai/experimental/rfm/base/__init__.py +23 -0
- kumoai/experimental/rfm/base/column.py +66 -0
- kumoai/experimental/rfm/base/sampler.py +773 -0
- kumoai/experimental/rfm/base/source.py +19 -0
- kumoai/experimental/rfm/{local_table.py → base/table.py} +152 -141
- kumoai/experimental/rfm/{local_graph.py → graph.py} +352 -80
- kumoai/experimental/rfm/infer/__init__.py +6 -0
- kumoai/experimental/rfm/infer/dtype.py +79 -0
- kumoai/experimental/rfm/infer/pkey.py +126 -0
- kumoai/experimental/rfm/infer/time_col.py +62 -0
- kumoai/experimental/rfm/pquery/pandas_executor.py +1 -1
- kumoai/experimental/rfm/rfm.py +233 -174
- kumoai/experimental/rfm/sagemaker.py +138 -0
- kumoai/spcs.py +1 -3
- kumoai/testing/decorators.py +1 -1
- kumoai/testing/snow.py +50 -0
- kumoai/utils/__init__.py +2 -0
- kumoai/utils/sql.py +3 -0
- {kumoai-2.12.1.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/METADATA +12 -2
- {kumoai-2.12.1.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/RECORD +40 -23
- kumoai/experimental/rfm/local_graph_sampler.py +0 -184
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- kumoai/experimental/rfm/utils.py +0 -344
- {kumoai-2.12.1.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/WHEEL +0 -0
- {kumoai-2.12.1.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.12.1.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
import json
|
|
3
|
+
from typing import Any, Dict, List, Tuple
|
|
4
|
+
|
|
5
|
+
import requests
|
|
6
|
+
|
|
7
|
+
from kumoai.client import KumoClient
|
|
8
|
+
from kumoai.client.endpoints import Endpoint, HTTPMethod
|
|
9
|
+
from kumoai.exceptions import HTTPException
|
|
10
|
+
|
|
11
|
+
try:
|
|
12
|
+
# isort: off
|
|
13
|
+
from mypy_boto3_sagemaker_runtime.client import SageMakerRuntimeClient
|
|
14
|
+
from mypy_boto3_sagemaker_runtime.type_defs import (
|
|
15
|
+
InvokeEndpointOutputTypeDef, )
|
|
16
|
+
# isort: on
|
|
17
|
+
except ImportError:
|
|
18
|
+
SageMakerRuntimeClient = Any
|
|
19
|
+
InvokeEndpointOutputTypeDef = Any
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class SageMakerResponseAdapter(requests.Response):
    r"""Presents a SageMaker ``invoke_endpoint`` result as a
    :class:`requests.Response`.

    The streaming ``Body`` of the SageMaker response is read eagerly in the
    constructor and stored as the response content.

    Args:
        sm_response: The mapping returned by
            ``SageMakerRuntimeClient.invoke_endpoint``. Its ``'Body'`` entry
            must expose a ``read()`` method.
    """
    def __init__(self, sm_response: InvokeEndpointOutputTypeDef):
        super().__init__()
        # Read the body bytes (the stream is consumed here, exactly once):
        self._content = sm_response['Body'].read()
        # Propagate the HTTP status reported by boto3 when it is available
        # instead of unconditionally claiming success. Falls back to 200 for
        # responses that carry no `ResponseMetadata`.
        metadata = sm_response.get('ResponseMetadata') or {}
        self.status_code = metadata.get('HTTPStatusCode', 200)
        self.headers['Content-Type'] = sm_response.get('ContentType',
                                                       'application/json')
        # Keep the original response around for debugging:
        self.sm_response = sm_response

    @property
    def text(self) -> str:
        r"""The response body decoded as UTF-8."""
        # `_content` is assigned from `Body.read()` in `__init__`; the assert
        # narrows the type for static checkers.
        assert isinstance(self._content, bytes)
        return self._content.decode('utf-8')

    def json(self, **kwargs) -> dict[str, Any]:  # type: ignore
        r"""Parses the response body as JSON.

        Args:
            **kwargs: Forwarded to :func:`json.loads`.
        """
        return json.loads(self.text, **kwargs)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class KumoClient_SageMakerAdapter(KumoClient):
    r"""A :class:`KumoClient` whose requests are sent through a SageMaker
    runtime endpoint instead of a regular HTTP session.

    Requests and responses can optionally be captured between
    :meth:`start_recording` and :meth:`end_recording`.

    Args:
        region: Region name passed to ``boto3.client``.
        endpoint_name: Name of the SageMaker endpoint to invoke.
    """
    def __init__(self, region: str, endpoint_name: str):
        # NOTE(review): `KumoClient.__init__` is intentionally not invoked
        # here (as in the original); only the attributes below are set.
        import boto3

        self._client: SageMakerRuntimeClient = boto3.client(
            service_name="sagemaker-runtime",
            region_name=region,
        )
        self._endpoint_name = endpoint_name

        # State used by `start_recording()` / `end_recording()`:
        self._recording_active = False
        self._recorded_reqs: List[Dict[str, Any]] = []
        self._recorded_resps: List[Dict[str, Any]] = []

    def authenticate(self) -> None:
        r"""No-op; nothing is verified at the moment."""
        # TODO(siyang): call /ping to verify?
        return None

    def _request(self, endpoint: Endpoint, **kwargs: Any) -> requests.Response:
        r"""Wraps the request into a ``{'method', 'payload'}`` envelope and
        invokes the SageMaker endpoint with it.

        Args:
            endpoint: The target endpoint. Only ``POST`` is supported, and
                the last segment of its path is used as the method name.
            **kwargs: Must contain either ``json`` (any JSON-serializable
                object) or ``data`` (raw ``bytes``, sent base64-encoded).

        Returns:
            The endpoint output wrapped in a
            :class:`SageMakerResponseAdapter`.

        Raises:
            HTTPException: If neither ``json`` nor ``data`` is given.
        """
        assert endpoint.method == HTTPMethod.POST

        if 'json' not in kwargs and 'data' not in kwargs:
            raise HTTPException(400, 'Unable to send data to KumoRFM.')

        if 'json' in kwargs:
            payload = json.dumps(kwargs.pop('json'))
        else:
            raw_bytes = kwargs.pop('data')
            assert isinstance(raw_bytes, bytes)
            payload = base64.b64encode(raw_bytes).decode()

        method_name = endpoint.get_path().split('/')[-1]
        envelope = {'method': method_name, 'payload': payload}

        sm_output: InvokeEndpointOutputTypeDef = self._client.invoke_endpoint(
            EndpointName=self._endpoint_name,
            ContentType="application/json",
            Body=json.dumps(envelope),
        )
        adapted = SageMakerResponseAdapter(sm_output)

        if not self._recording_active:
            return adapted

        # Recording is on: remember the envelope and the decoded response.
        self._recorded_reqs.append(envelope)
        self._recorded_resps.append(adapted.json())
        return adapted

    def start_recording(self) -> None:
        """Start recording requests/responses to/from sagemaker endpoint."""
        assert not self._recording_active
        self._recording_active = True
        for buffer in (self._recorded_reqs, self._recorded_resps):
            buffer.clear()

    def end_recording(self) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]:
        """Stop recording and return recorded requests/responses."""
        assert self._recording_active
        self._recording_active = False
        # Snapshot the pairs before the buffers are emptied:
        pairs = [(req, resp) for req, resp in zip(self._recorded_reqs,
                                                   self._recorded_resps)]
        for buffer in (self._recorded_reqs, self._recorded_resps):
            buffer.clear()
        return pairs
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class KumoClient_SageMakerProxy_Local(KumoClient):
    r"""A :class:`KumoClient` that forwards every request to the
    ``/invocations`` route of a server reachable at ``url``, using the same
    ``{'method', 'payload'}`` envelope as the SageMaker adapter.

    Args:
        url: Base URL of the server. It is used as-is as the API URL of the
            wrapped client.
    """
    def __init__(self, url: str):
        # NOTE: `KumoClient.__init__` is not invoked on `self`; all HTTP
        # state (URL, session, SSL settings) lives on the wrapped client.
        self._client = KumoClient(url, api_key=None)
        self._client._api_url = self._client._url
        self._endpoint = Endpoint('/invocations', HTTPMethod.POST)

    def authenticate(self) -> None:
        r"""Checks that the server answers ``GET /ping`` successfully.

        Raises:
            ValueError: If the ping request fails for any reason.
        """
        try:
            # Use the wrapped client's URL and SSL settings: `self` never
            # ran `KumoClient.__init__`, so it sets neither `_url` nor
            # `_verify_ssl` itself.
            self._client._session.get(
                self._client._url + '/ping',
                verify=self._client._verify_ssl,
            ).raise_for_status()
        except Exception as e:
            raise ValueError(
                "Client authentication failed. Please check if you "
                "have a valid API key/credentials.") from e

    def _request(self, endpoint: Endpoint, **kwargs: Any) -> requests.Response:
        r"""Wraps the request into a ``{'method', 'payload'}`` envelope and
        posts it to ``/invocations`` via the wrapped client.

        Args:
            endpoint: The target endpoint. Only ``POST`` is supported, and
                the last segment of its path is used as the method name.
            **kwargs: Must contain either ``json`` (any JSON-serializable
                object) or ``data`` (raw ``bytes``, sent base64-encoded).
                Remaining entries are forwarded to the wrapped client.

        Raises:
            HTTPException: If neither ``json`` nor ``data`` is given.
        """
        assert endpoint.method == HTTPMethod.POST
        if 'json' in kwargs:
            payload = json.dumps(kwargs.pop('json'))
        elif 'data' in kwargs:
            raw_payload = kwargs.pop('data')
            assert isinstance(raw_payload, bytes)
            payload = base64.b64encode(raw_payload).decode()
        else:
            raise HTTPException(400, 'Unable to send data to KumoRFM.')
        return self._client._request(
            self._endpoint,
            json={
                'method': endpoint.get_path().rsplit('/', 1)[-1],
                'payload': payload,
            },
            **kwargs,
        )
|
kumoai/spcs.py
CHANGED
|
@@ -54,9 +54,7 @@ def _refresh_spcs_token() -> None:
|
|
|
54
54
|
api_key=global_state._api_key,
|
|
55
55
|
spcs_token=spcs_token,
|
|
56
56
|
)
|
|
57
|
-
|
|
58
|
-
raise ValueError("Client authentication failed. Please check if you "
|
|
59
|
-
"have a valid API key.")
|
|
57
|
+
client.authenticate()
|
|
60
58
|
|
|
61
59
|
# Update state:
|
|
62
60
|
global_state.set_spcs_token(spcs_token)
|
kumoai/testing/decorators.py
CHANGED
|
@@ -25,7 +25,7 @@ def onlyFullTest(func: Callable) -> Callable:
|
|
|
25
25
|
def has_package(package: str) -> bool:
|
|
26
26
|
r"""Returns ``True`` in case ``package`` is installed."""
|
|
27
27
|
req = Requirement(package)
|
|
28
|
-
if importlib.util.find_spec(req.name) is None:
|
|
28
|
+
if importlib.util.find_spec(req.name) is None: # type: ignore
|
|
29
29
|
return False
|
|
30
30
|
|
|
31
31
|
try:
|
kumoai/testing/snow.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
from kumoai.experimental.rfm.backend.snow import Connection
|
|
5
|
+
from kumoai.experimental.rfm.backend.snow import connect as _connect
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def connect(
    region: str,
    id: str,
    account: str,
    user: str,
    warehouse: str,
    database: str | None = None,
    schema: str | None = None,
) -> Connection:
    r"""Opens a Snowflake connection for testing.

    If the ``SNOWFLAKE_PASSWORD`` environment variable is set, it is used as
    the password. Otherwise, a PEM private key is fetched from AWS Secrets
    Manager and passed to the connector in unencrypted DER/PKCS8 form.

    Args:
        region: AWS region of the Secrets Manager secret.
        id: Account segment of the secret ARN
            (``arn:aws:secretsmanager:{region}:{id}:secret:...``).
        account: Snowflake account; also part of the secret name.
        user: Snowflake user.
        warehouse: Snowflake warehouse to use.
        database: Snowflake database to use. Defaults to ``'KUMO'``.
        schema: Snowflake schema to use.

    Returns:
        The connection created by the Snowflake backend's ``connect``.
    """
    kwargs = dict(password=os.getenv('SNOWFLAKE_PASSWORD'))
    if kwargs['password'] is None:
        # Imported lazily so that password-based usage does not require
        # `boto3` or `cryptography` to be installed.
        import boto3
        from cryptography.hazmat.primitives import serialization

        client = boto3.client(
            service_name='secretsmanager',
            region_name=region,
        )
        secret_id = (f'arn:aws:secretsmanager:{region}:{id}:secret:'
                     f'{account}.snowflakecomputing.com')
        response = client.get_secret_value(SecretId=secret_id)['SecretString']
        secret = json.loads(response)

        # The secret stores the key as PEM text; re-encode it as DER/PKCS8
        # bytes before handing it to the connector.
        private_key = serialization.load_pem_private_key(
            secret['kumo_user_secretkey'].encode(),
            password=None,
        )
        kwargs['private_key'] = private_key.private_bytes(
            encoding=serialization.Encoding.DER,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        )

    return _connect(
        account=account,
        user=user,
        # Honor the caller-supplied warehouse/database instead of silently
        # ignoring them; 'KUMO' stays the default database.
        warehouse=warehouse,
        database=database if database is not None else 'KUMO',
        schema=schema,
        session_parameters=dict(CLIENT_TELEMETRY_ENABLED=False),
        **kwargs,
    )
|
kumoai/utils/__init__.py
CHANGED
|
@@ -1,8 +1,10 @@
|
|
|
1
|
+
from .sql import quote_ident
|
|
1
2
|
from .progress_logger import ProgressLogger, InteractiveProgressLogger
|
|
2
3
|
from .forecasting import ForecastVisualizer
|
|
3
4
|
from .datasets import from_relbench
|
|
4
5
|
|
|
5
6
|
__all__ = [
|
|
7
|
+
'quote_ident',
|
|
6
8
|
'ProgressLogger',
|
|
7
9
|
'InteractiveProgressLogger',
|
|
8
10
|
'ForecastVisualizer',
|
kumoai/utils/sql.py
ADDED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: kumoai
|
|
3
|
-
Version: 2.
|
|
3
|
+
Version: 2.14.0.dev202512141732
|
|
4
4
|
Summary: AI on the Modern Data Stack
|
|
5
5
|
Author-email: "Kumo.AI" <hello@kumo.ai>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -23,7 +23,7 @@ Requires-Dist: requests>=2.28.2
|
|
|
23
23
|
Requires-Dist: urllib3
|
|
24
24
|
Requires-Dist: plotly
|
|
25
25
|
Requires-Dist: typing_extensions>=4.5.0
|
|
26
|
-
Requires-Dist: kumo-api==0.
|
|
26
|
+
Requires-Dist: kumo-api==0.49.0
|
|
27
27
|
Requires-Dist: tqdm>=4.66.0
|
|
28
28
|
Requires-Dist: aiohttp>=3.10.0
|
|
29
29
|
Requires-Dist: pydantic>=1.10.21
|
|
@@ -38,6 +38,16 @@ Provides-Extra: test
|
|
|
38
38
|
Requires-Dist: pytest; extra == "test"
|
|
39
39
|
Requires-Dist: pytest-mock; extra == "test"
|
|
40
40
|
Requires-Dist: requests-mock; extra == "test"
|
|
41
|
+
Provides-Extra: sqlite
|
|
42
|
+
Requires-Dist: adbc_driver_sqlite; extra == "sqlite"
|
|
43
|
+
Provides-Extra: snowflake
|
|
44
|
+
Requires-Dist: snowflake-connector-python; extra == "snowflake"
|
|
45
|
+
Requires-Dist: pyyaml; extra == "snowflake"
|
|
46
|
+
Provides-Extra: sagemaker
|
|
47
|
+
Requires-Dist: boto3<2.0,>=1.30.0; extra == "sagemaker"
|
|
48
|
+
Requires-Dist: mypy-boto3-sagemaker-runtime<2.0,>=1.34.0; extra == "sagemaker"
|
|
49
|
+
Provides-Extra: test-sagemaker
|
|
50
|
+
Requires-Dist: sagemaker<3.0; extra == "test-sagemaker"
|
|
41
51
|
Dynamic: license-file
|
|
42
52
|
Dynamic: requires-dist
|
|
43
53
|
|
|
@@ -1,25 +1,25 @@
|
|
|
1
|
-
kumoai/__init__.py,sha256=
|
|
1
|
+
kumoai/__init__.py,sha256=Nn9YH_x9kAeEFn8RWbP95slZow0qFnakPZZ1WADe1hY,10843
|
|
2
2
|
kumoai/_logging.py,sha256=U2_5ROdyk92P4xO4H2WJV8EC7dr6YxmmnM-b7QX9M7I,886
|
|
3
3
|
kumoai/_singleton.py,sha256=UTwrbDkoZSGB8ZelorvprPDDv9uZkUi1q_SrmsyngpQ,836
|
|
4
|
-
kumoai/_version.py,sha256=
|
|
4
|
+
kumoai/_version.py,sha256=jDuUtpQbekehTWtqQXeU4rYieU0OA-7URATbFO1iyvo,39
|
|
5
5
|
kumoai/databricks.py,sha256=e6E4lOFvZHXFwh4CO1kXU1zzDU3AapLQYMxjiHPC-HQ,476
|
|
6
6
|
kumoai/exceptions.py,sha256=b-_sdbAKOg50uaJZ65GmBLdTo4HANdjl8_R0sJpwaN0,833
|
|
7
7
|
kumoai/formatting.py,sha256=jA_rLDCGKZI8WWCha-vtuLenVKTZvli99Tqpurz1H84,953
|
|
8
8
|
kumoai/futures.py,sha256=oJFIfdCM_3nWIqQteBKYMY4fPhoYlYWE_JA2o6tx-ng,3737
|
|
9
9
|
kumoai/jobs.py,sha256=NrdLEFNo7oeCYSy-kj2nAvCFrz9BZ_xrhkqHFHk5ksY,2496
|
|
10
10
|
kumoai/mixin.py,sha256=MP413xzuCqWhxAPUHmloLA3j4ZyF1tEtfi516b_hOXQ,812
|
|
11
|
-
kumoai/spcs.py,sha256=
|
|
11
|
+
kumoai/spcs.py,sha256=N31d7rLa-bgYh8e2J4YzX1ScxGLqiVXrqJnCl1y4Mts,4139
|
|
12
12
|
kumoai/artifact_export/__init__.py,sha256=BsfDrc3mCHpO9-BqvqKm8qrXDIwfdaoH5UIoG4eQkc4,238
|
|
13
13
|
kumoai/artifact_export/config.py,sha256=jOPDduduxv0uuB-7xVlDiZglfpmFF5lzQhhH1SMkGvw,8024
|
|
14
14
|
kumoai/artifact_export/job.py,sha256=GEisSwvcjK_35RgOfsLXGgxMTXIWm765B_BW_Kgs-V0,3275
|
|
15
15
|
kumoai/client/__init__.py,sha256=MkyOuMaHQ2c8GPxjBDQSVFhfRE2d2_6CXQ6rxj4ps4w,64
|
|
16
|
-
kumoai/client/client.py,sha256=
|
|
16
|
+
kumoai/client/client.py,sha256=Jda8V9yiu3LbhxlcgRWPeYi7eF6jzCKcq8-B_vEd1ik,8514
|
|
17
17
|
kumoai/client/connector.py,sha256=x3i2aBTJTEMZvYRcWkY-UfWVOANZjqAso4GBbcshFjw,3920
|
|
18
18
|
kumoai/client/endpoints.py,sha256=iF2ZD25AJCIVbmBJ8tTZ8y1Ch0m6nTp18ydN7h4WiTk,5382
|
|
19
19
|
kumoai/client/graph.py,sha256=zvLEDExLT_RVbUMHqVl0m6tO6s2gXmYSoWmPF6YMlnA,3831
|
|
20
20
|
kumoai/client/jobs.py,sha256=iu_Wrta6BQMlV6ZtzSnmhjwNPKDMQDXOsqVVIyWodqw,17074
|
|
21
21
|
kumoai/client/online.py,sha256=pkBBh_DEC3GAnPcNw6bopNRlGe7EUbIFe7_seQqZRaw,2720
|
|
22
|
-
kumoai/client/pquery.py,sha256=
|
|
22
|
+
kumoai/client/pquery.py,sha256=IQ8As-OOJOkuMoMosphOsA5hxQYLCbzOQJO7RezK8uY,7091
|
|
23
23
|
kumoai/client/rfm.py,sha256=NxKk8mH2A-B58rSXhDWaph4KeiSyJYDq-RO-vAHh7es,3726
|
|
24
24
|
kumoai/client/source_table.py,sha256=VCsCcM7KYcnjGP7HLTb-AOSEGEVsJTWjk8bMg1JdgPU,2101
|
|
25
25
|
kumoai/client/table.py,sha256=cQG-RPm-e91idEgse1IPJDvBmzddIDGDkuyrR1rq4wU,3235
|
|
@@ -49,26 +49,41 @@ kumoai/connector/glue_connector.py,sha256=HivT0QYQ8-XeB4QLgWvghiqXuq7jyBK9G2R1py
|
|
|
49
49
|
kumoai/connector/s3_connector.py,sha256=3kbv-h7DwD8O260Q0h1GPm5wwQpLt-Tb3d_CBSaie44,10155
|
|
50
50
|
kumoai/connector/snowflake_connector.py,sha256=K0s-H9tW3rve8g2x1PbyxvzSpkROfGQZz-Qa4PoT4UE,9022
|
|
51
51
|
kumoai/connector/source_table.py,sha256=QLT8bEYaxeMwy-b168url0VfnkTrs5K6VKLbxTI4hEY,17539
|
|
52
|
-
kumoai/connector/utils.py,sha256=
|
|
52
|
+
kumoai/connector/utils.py,sha256=wlqQxMmPvnFNoCcczGkKYjSu05h8OhWh4fhTzQm_2bQ,64694
|
|
53
53
|
kumoai/encoder/__init__.py,sha256=VPGs4miBC_WfwWeOXeHhFomOUocERFavhKf5fqITcds,182
|
|
54
54
|
kumoai/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
55
|
-
kumoai/experimental/rfm/__init__.py,sha256=
|
|
55
|
+
kumoai/experimental/rfm/__init__.py,sha256=slliYcrh80xPtQQ_nnsp3ny9IbmHCyirmdZUfKTdME4,6064
|
|
56
56
|
kumoai/experimental/rfm/authenticate.py,sha256=FiuHMvP7V3zBZUlHMDMbNLhc-UgDZgz4hjVSTuQ7DRw,18888
|
|
57
|
-
kumoai/experimental/rfm/
|
|
58
|
-
kumoai/experimental/rfm/
|
|
59
|
-
kumoai/experimental/rfm/
|
|
60
|
-
kumoai/experimental/rfm/
|
|
61
|
-
kumoai/experimental/rfm/
|
|
62
|
-
kumoai/experimental/rfm/
|
|
63
|
-
kumoai/experimental/rfm/
|
|
64
|
-
kumoai/experimental/rfm/
|
|
57
|
+
kumoai/experimental/rfm/graph.py,sha256=awVJSk4cWRMacS5CJvJtR8TR56FEbrJPcQCukNydQOc,40392
|
|
58
|
+
kumoai/experimental/rfm/rfm.py,sha256=j4LfCbmHPs9RrwRNjTEE-sPUC6RcHNcymKBtG_Rt-M4,49670
|
|
59
|
+
kumoai/experimental/rfm/sagemaker.py,sha256=_hTrFg4qfXe7uzwqSEG_wze-IFkwn7qde9OpUodCpbc,4982
|
|
60
|
+
kumoai/experimental/rfm/backend/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
61
|
+
kumoai/experimental/rfm/backend/local/__init__.py,sha256=2s9sSA-E-8pfkkzCH4XPuaSxSznEURMfMgwEIfYYPsg,1014
|
|
62
|
+
kumoai/experimental/rfm/backend/local/graph_store.py,sha256=5cHuExHljU_Z56KV3s-PwzeiLuPKgh2mCxcjTMmPZ8E,11928
|
|
63
|
+
kumoai/experimental/rfm/backend/local/sampler.py,sha256=wJXQdiOOCIjqpNkDoXgPV0cnsKjeJ2Zi77-pJ8zAEM4,10738
|
|
64
|
+
kumoai/experimental/rfm/backend/local/table.py,sha256=-R_9nncosByAfSMfUt6HgCUNoW_MLGJW3F5SnAd4Ru0,3744
|
|
65
|
+
kumoai/experimental/rfm/backend/snow/__init__.py,sha256=BYfsiuJ4Ee30GjG9EuUtitMHXnRfvVKi85zNlIwldV4,993
|
|
66
|
+
kumoai/experimental/rfm/backend/snow/sampler.py,sha256=d5hyKyiNxnjmINGxM7K5nCVnN7ZaknM1QdHLt8EBt5c,4287
|
|
67
|
+
kumoai/experimental/rfm/backend/snow/table.py,sha256=bnBGgkXmgLcSZlr0t4pYo9aSk7DO5PuSDM06HBgA5BE,4841
|
|
68
|
+
kumoai/experimental/rfm/backend/sqlite/__init__.py,sha256=cA-PZL1oTaLxthZbfLSudexImtF6jRsGkdjSp-66dCM,914
|
|
69
|
+
kumoai/experimental/rfm/backend/sqlite/sampler.py,sha256=n3etgKH7viiIgMXHFLM3KQvKvCpylXRTw0WBNoSKd7s,4057
|
|
70
|
+
kumoai/experimental/rfm/backend/sqlite/table.py,sha256=9F2_6iqgM3h9jkIcecHnEwQq-UG3RBmR9Y9XPcFLR04,4161
|
|
71
|
+
kumoai/experimental/rfm/base/__init__.py,sha256=IdQwFFEeXZnTBkj1I65qQDL6QlcbmjVoTP6-quCtmFc,481
|
|
72
|
+
kumoai/experimental/rfm/base/column.py,sha256=izCJmufJcd1RSi-ptFMfrue-JYag38MJxizka7ya0-A,2319
|
|
73
|
+
kumoai/experimental/rfm/base/sampler.py,sha256=aCD98t0CUhAvGXEFv24Vq2g4otuclpKkkyL1rMR_mFg,31449
|
|
74
|
+
kumoai/experimental/rfm/base/source.py,sha256=RqlI_kBoRV0ADb8KdEKn15RNHMdFUzEVzb57lIoyBM4,294
|
|
75
|
+
kumoai/experimental/rfm/base/table.py,sha256=neGldEZaweoJ8VRgnEnaSpAISSkSTkgXxItuuywBM4E,20010
|
|
76
|
+
kumoai/experimental/rfm/infer/__init__.py,sha256=krdMFN8iKZlSFOl-M5MW1KuSviQV3H1E18jj2uB8g6Q,469
|
|
65
77
|
kumoai/experimental/rfm/infer/categorical.py,sha256=VwNaKwKbRYkTxEJ1R6gziffC8dGsEThcDEfbi-KqW5c,853
|
|
78
|
+
kumoai/experimental/rfm/infer/dtype.py,sha256=ZZ6ztqJnTR1CaC2z5Uhf0o0rSdNThnss5tem5JNQkck,2607
|
|
66
79
|
kumoai/experimental/rfm/infer/id.py,sha256=ZIO0DWIoiEoS_8MVc5lkqBfkTWWQ0yGCgjkwLdaYa_Q,908
|
|
67
80
|
kumoai/experimental/rfm/infer/multicategorical.py,sha256=0-cLpDnGryhr76QhZNO-klKokJ6MUSfxXcGdQ61oykY,1102
|
|
81
|
+
kumoai/experimental/rfm/infer/pkey.py,sha256=ubNqW1LIjLKiXbjXELAY3g6n2f3u2Eis_uC2DEiXFiU,4393
|
|
82
|
+
kumoai/experimental/rfm/infer/time_col.py,sha256=7R5Itl8RRBOr61qLpRTanIqrUVZFZcAXzDA9lCw4nx4,1820
|
|
68
83
|
kumoai/experimental/rfm/infer/timestamp.py,sha256=vM9--7eStzaGG13Y-oLYlpNJyhL6f9dp17HDXwtl_DM,1094
|
|
69
84
|
kumoai/experimental/rfm/pquery/__init__.py,sha256=X0O3EIq5SMfBEE-ii5Cq6iDhR3s3XMXB52Cx5htoePw,152
|
|
70
85
|
kumoai/experimental/rfm/pquery/executor.py,sha256=f7-pJhL0BgFU9E4o4gQpQyArOvyrZtwxFmks34-QOAE,2741
|
|
71
|
-
kumoai/experimental/rfm/pquery/pandas_executor.py,sha256=
|
|
86
|
+
kumoai/experimental/rfm/pquery/pandas_executor.py,sha256=wYI9a3smClR2pQGwsYRdmpOm0PlUsbtyW9wpAVpCEe4,18492
|
|
72
87
|
kumoai/graph/__init__.py,sha256=n8X4X8luox4hPBHTRC9R-3JzvYYMoR8n7lF1H4w4Hzc,228
|
|
73
88
|
kumoai/graph/column.py,sha256=t7wBmcx0VYKXjIoESU9Nq-AisiJOdlqd80t8zby1R8Y,4189
|
|
74
89
|
kumoai/graph/graph.py,sha256=iyp4klPIMn2ttuEqMJvsrxKb_tmz_DTnvziIhCegduM,38291
|
|
@@ -78,7 +93,8 @@ kumoai/pquery/prediction_table.py,sha256=QPDH22X1UB0NIufY7qGuV2XW7brG3Pv--FbjNez
|
|
|
78
93
|
kumoai/pquery/predictive_query.py,sha256=UXn1s8ztubYZMNGl4ijaeidMiGlFveb1TGw9qI5-TAo,24901
|
|
79
94
|
kumoai/pquery/training_table.py,sha256=elmPDZx11kPiC_dkOhJcBUGtHKgL32GCBvZ9k6U0pMg,15809
|
|
80
95
|
kumoai/testing/__init__.py,sha256=goHIIo3JE7uHV7njo4_aTd89mVVR74BEAZ2uyBaOR0w,170
|
|
81
|
-
kumoai/testing/decorators.py,sha256=
|
|
96
|
+
kumoai/testing/decorators.py,sha256=83tMifuPTpUqX7zHxMttkj1TDdB62EBtAP-Fjj72Zdo,1607
|
|
97
|
+
kumoai/testing/snow.py,sha256=ubx3yJP0UHxsNiar1-jNdv8ZfszKc8Js3_Gg70uf008,1487
|
|
82
98
|
kumoai/trainer/__init__.py,sha256=zUdFl-f-sBWmm2x8R-rdVzPBeU2FaMzUY5mkcgoTa1k,939
|
|
83
99
|
kumoai/trainer/baseline_trainer.py,sha256=LlfViNOmswNv4c6zJJLsyv0pC2mM2WKMGYx06ogtEVc,4024
|
|
84
100
|
kumoai/trainer/config.py,sha256=-2RfK10AsVVThSyfWtlyfH4Fc4EwTdu0V3yrDRtIOjk,98
|
|
@@ -86,12 +102,13 @@ kumoai/trainer/job.py,sha256=Wk69nzFhbvuA3nEvtCstI04z5CxkgvQ6tHnGchE0Lkg,44938
|
|
|
86
102
|
kumoai/trainer/online_serving.py,sha256=9cddb5paeZaCgbUeceQdAOxysCtV5XP-KcsgFz_XR5w,9566
|
|
87
103
|
kumoai/trainer/trainer.py,sha256=hBXO7gwpo3t59zKFTeIkK65B8QRmWCwO33sbDuEAPlY,20133
|
|
88
104
|
kumoai/trainer/util.py,sha256=bDPGkMF9KOy4HgtA-OwhXP17z9cbrfMnZGtyGuUq_Eo,4062
|
|
89
|
-
kumoai/utils/__init__.py,sha256=
|
|
105
|
+
kumoai/utils/__init__.py,sha256=cF5ACzp1X61sqhlCHc6biQk6fc4gW_oyhGsBrjx-SoM,316
|
|
90
106
|
kumoai/utils/datasets.py,sha256=ptKIUoBONVD55pTVNdRCkQT3NWdN_r9UAUu4xewPa3U,2928
|
|
91
107
|
kumoai/utils/forecasting.py,sha256=-nDS6ucKNfQhTQOfebjefj0wwWH3-KYNslIomxwwMBM,7415
|
|
92
108
|
kumoai/utils/progress_logger.py,sha256=pngEGzMHkiOUKOa6fbzxCEc2xlA4SJKV4TDTVVoqObM,5062
|
|
93
|
-
kumoai
|
|
94
|
-
kumoai-2.
|
|
95
|
-
kumoai-2.
|
|
96
|
-
kumoai-2.
|
|
97
|
-
kumoai-2.
|
|
109
|
+
kumoai/utils/sql.py,sha256=f6lR6rBEW7Dtk0NdM26dOZXUHDizEHb1WPlBCJrwoq0,118
|
|
110
|
+
kumoai-2.14.0.dev202512141732.dist-info/licenses/LICENSE,sha256=TbWlyqRmhq9PEzCaTI0H0nWLQCCOywQM8wYH8MbjfLo,1102
|
|
111
|
+
kumoai-2.14.0.dev202512141732.dist-info/METADATA,sha256=afHVHJGv5rxq96TO9d3VqyuN_Nhvf-dFGx3nDJo41lk,2510
|
|
112
|
+
kumoai-2.14.0.dev202512141732.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
113
|
+
kumoai-2.14.0.dev202512141732.dist-info/top_level.txt,sha256=YjU6UcmomoDx30vEXLsOU784ED7VztQOsFApk1SFwvs,7
|
|
114
|
+
kumoai-2.14.0.dev202512141732.dist-info/RECORD,,
|
|
@@ -1,184 +0,0 @@
|
|
|
1
|
-
from typing import Dict, List, Optional, Tuple
|
|
2
|
-
|
|
3
|
-
import numpy as np
|
|
4
|
-
import pandas as pd
|
|
5
|
-
from kumoapi.model_plan import RunMode
|
|
6
|
-
from kumoapi.rfm.context import EdgeLayout, Link, Subgraph, Table
|
|
7
|
-
from kumoapi.typing import Stype
|
|
8
|
-
|
|
9
|
-
import kumoai.kumolib as kumolib
|
|
10
|
-
from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
|
|
11
|
-
from kumoai.experimental.rfm.utils import normalize_text
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
class LocalGraphSampler:
    r"""Samples a :class:`Subgraph` around seed rows of a
    :class:`LocalGraphStore` via ``kumolib.NeighborSampler``.

    Args:
        graph_store: The graph store providing node/edge types, the
            ``colptr``/``row`` arrays per edge type, time information and
            the underlying data frames.
    """
    def __init__(self, graph_store: LocalGraphStore) -> None:
        self._graph_store = graph_store

        node_types = graph_store.node_types
        edge_types = graph_store.edge_types
        # The native sampler addresses edge types by their '__'-joined name:
        colptr_by_name = {
            '__'.join(edge_type): colptr
            for edge_type, colptr in graph_store.colptr_dict.items()
        }
        row_by_name = {
            '__'.join(edge_type): row
            for edge_type, row in graph_store.row_dict.items()
        }
        self._sampler = kumolib.NeighborSampler(
            node_types,
            edge_types,
            colptr_by_name,
            row_by_name,
            graph_store.time_dict,
        )

    def __call__(
        self,
        entity_table_names: Tuple[str, ...],
        node: np.ndarray,
        time: np.ndarray,
        run_mode: RunMode,
        num_neighbors: List[int],
        exclude_cols_dict: Dict[str, List[str]],
    ) -> Subgraph:
        r"""Samples a subgraph seeded at ``node`` of the first entity table.

        Args:
            entity_table_names: Names of the entity tables; sampling starts
                from the first one, and primary keys are only kept for
                tables listed here.
            node: Seed node indices.
            time: Anchor times, one per seed; divided by ``1000**3``
                (nanoseconds to seconds) before being given to the sampler.
            run_mode: Unused by this method.
            num_neighbors: Number of neighbors per hop, applied to every
                edge type.
            exclude_cols_dict: Per-table column names to drop from the
                sampled tables.

        Returns:
            The sampled subgraph holding per-table data and per-edge-type
            links.
        """
        store = self._graph_store

        fanout_by_name = {
            '__'.join(edge_type): num_neighbors
            for edge_type in store.edge_types
        }
        (
            row_dict,
            col_dict,
            node_dict,
            batch_dict,
            num_sampled_nodes_dict,
            num_sampled_edges_dict,
        ) = self._sampler.sample(
            fanout_by_name,
            {},  # time interval based sampling
            entity_table_names[0],
            node,
            time // 1000**3,  # nanoseconds to seconds
        )

        table_dict: Dict[str, Table] = {}
        for table_name, sampled_node in node_dict.items():
            batch = batch_dict[table_name]

            if len(sampled_node) == 0:
                continue  # Nothing was sampled for this table.

            df = store.df_dict[table_name]
            num_sampled_nodes = num_sampled_nodes_dict[table_name].tolist()

            table_dict[table_name] = self._build_table(
                table_name,
                df,
                sampled_node,
                batch,
                time,
                num_sampled_nodes,
                table_name in entity_table_names,
                exclude_cols_dict.get(table_name, []),
            )

        link_dict = self._build_links(
            row_dict,
            col_dict,
            num_sampled_edges_dict,
            table_dict,
        )

        return Subgraph(
            anchor_time=time,
            table_dict=table_dict,
            link_dict=link_dict,
        )

    def _build_table(
        self,
        table_name: str,
        df: pd.DataFrame,
        sampled_node: np.ndarray,
        batch: np.ndarray,
        time: np.ndarray,
        num_sampled_nodes: List[int],
        is_entity: bool,
        excluded_cols: List[str],
    ) -> Table:
        r"""Materializes the sampled rows of a single table."""
        store = self._graph_store

        stype_dict = {  # Exclude target columns:
            name: stype
            for name, stype in store.stype_dict[table_name].items()
            if name not in excluded_cols
        }

        # Entity tables additionally carry their primary key column:
        primary_key: Optional[str] = None
        columns: List[str] = []
        if is_entity:
            primary_key = store.pkey_name_dict[table_name]
            columns.append(primary_key)
        columns.extend(stype_dict)

        if not columns:
            # No columns to ship: only the number of rows matters.
            return Table(
                df=pd.DataFrame(index=range(len(sampled_node))),
                row=None,
                batch=batch,
                num_sampled_nodes=num_sampled_nodes,
                stype_dict=stype_dict,
                primary_key=primary_key,
            )

        row: Optional[np.ndarray] = None
        if table_name in store.end_time_column_dict:
            # Set end time to NaT for all values greater than anchor time:
            df = df.iloc[sampled_node].reset_index(drop=True)
            end_col = store.end_time_column_dict[table_name]
            end_time = df[end_col].astype('datetime64[ns]').astype(
                int).to_numpy()
            df.loc[end_time > time[batch], end_col] = pd.NaT
        else:
            # Only store unique rows in `df` above a certain threshold:
            unique_node, inverse = np.unique(sampled_node,
                                             return_inverse=True)
            if len(sampled_node) > 1.05 * len(unique_node):
                df = df.iloc[unique_node].reset_index(drop=True)
                row = inverse
            else:
                df = df.iloc[sampled_node].reset_index(drop=True)

        # Filter data frame to minimal set of columns:
        df = df[columns]

        # Normalize text (if not already pre-processed):
        for name, stype in stype_dict.items():
            if stype == Stype.text:
                df[name] = normalize_text(df[name])

        return Table(
            df=df,
            row=row,
            batch=batch,
            num_sampled_nodes=num_sampled_nodes,
            stype_dict=stype_dict,
            primary_key=primary_key,
        )

    def _build_links(
        self,
        row_dict: Dict[str, np.ndarray],
        col_dict: Dict[str, np.ndarray],
        num_sampled_edges_dict: Dict[str, np.ndarray],
        table_dict: Dict[str, Table],
    ) -> Dict[Tuple[str, str, str], Link]:
        r"""Converts the sampled edges of every edge type into links, using
        the most compact layout available."""
        link_dict: Dict[Tuple[str, str, str], Link] = {}
        for edge_type in self._graph_store.edge_types:
            name = '__'.join(edge_type)

            row = row_dict[name]
            col = col_dict[name]

            if len(row) == 0:
                continue  # No edges were sampled for this edge type.

            # Do not store reverse edge type if it is a replica:
            rev_edge_type = Subgraph.rev_edge_type(edge_type)
            rev_name = '__'.join(rev_edge_type)
            if (rev_edge_type in link_dict
                    and np.array_equal(row, col_dict[rev_name])
                    and np.array_equal(col, row_dict[rev_name])):
                link_dict[edge_type] = Link(
                    layout=EdgeLayout.REV,
                    row=None,
                    col=None,
                    num_sampled_edges=(
                        num_sampled_edges_dict[name].tolist()),
                )
                continue

            # An index that is just `0, 1, 2, ...` is implied, so drop it:
            layout = EdgeLayout.COO
            if np.array_equal(row, np.arange(len(row))):
                row = None
            if np.array_equal(col, np.arange(len(col))):
                col = None

            # Store in compressed representation if more efficient:
            num_cols = table_dict[edge_type[2]].num_rows
            if col is not None and len(col) > num_cols + 1:
                layout = EdgeLayout.CSC
                counts = np.bincount(col, minlength=num_cols)
                col = np.empty(num_cols + 1, dtype=col.dtype)
                col[0] = 0
                np.cumsum(counts, out=col[1:])

            link_dict[edge_type] = Link(
                layout=layout,
                row=row,
                col=col,
                num_sampled_edges=(num_sampled_edges_dict[name].tolist()),
            )

        return link_dict
|