kumoai 2.12.0.dev202511071730__cp310-cp310-win_amd64.whl → 2.13.0.dev202512021731__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +6 -9
- kumoai/_version.py +1 -1
- kumoai/client/client.py +9 -13
- kumoai/client/rfm.py +15 -7
- kumoai/connector/utils.py +23 -2
- kumoai/experimental/rfm/__init__.py +164 -46
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +38 -0
- kumoai/experimental/rfm/backend/local/table.py +244 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +32 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +30 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +124 -0
- kumoai/experimental/rfm/base/__init__.py +7 -0
- kumoai/experimental/rfm/base/column.py +66 -0
- kumoai/experimental/rfm/{local_table.py → base/table.py} +71 -139
- kumoai/experimental/rfm/{local_graph.py → graph.py} +144 -57
- kumoai/experimental/rfm/infer/__init__.py +2 -0
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/local_graph_sampler.py +0 -2
- kumoai/experimental/rfm/local_graph_store.py +12 -11
- kumoai/experimental/rfm/local_pquery_driver.py +2 -2
- kumoai/experimental/rfm/rfm.py +83 -28
- kumoai/experimental/rfm/sagemaker.py +138 -0
- kumoai/experimental/rfm/utils.py +1 -120
- kumoai/kumolib.cp310-win_amd64.pyd +0 -0
- kumoai/spcs.py +1 -3
- kumoai/testing/decorators.py +1 -1
- kumoai/utils/progress_logger.py +10 -4
- {kumoai-2.12.0.dev202511071730.dist-info → kumoai-2.13.0.dev202512021731.dist-info}/METADATA +11 -2
- {kumoai-2.12.0.dev202511071730.dist-info → kumoai-2.13.0.dev202512021731.dist-info}/RECORD +33 -23
- {kumoai-2.12.0.dev202511071730.dist-info → kumoai-2.13.0.dev202512021731.dist-info}/WHEEL +0 -0
- {kumoai-2.12.0.dev202511071730.dist-info → kumoai-2.13.0.dev202512021731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.12.0.dev202511071730.dist-info → kumoai-2.13.0.dev202512021731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/rfm.py
CHANGED
|
@@ -5,7 +5,17 @@ from collections import defaultdict
|
|
|
5
5
|
from collections.abc import Generator
|
|
6
6
|
from contextlib import contextmanager
|
|
7
7
|
from dataclasses import dataclass, replace
|
|
8
|
-
from typing import
|
|
8
|
+
from typing import (
|
|
9
|
+
Any,
|
|
10
|
+
Dict,
|
|
11
|
+
Iterator,
|
|
12
|
+
List,
|
|
13
|
+
Literal,
|
|
14
|
+
Optional,
|
|
15
|
+
Tuple,
|
|
16
|
+
Union,
|
|
17
|
+
overload,
|
|
18
|
+
)
|
|
9
19
|
|
|
10
20
|
import numpy as np
|
|
11
21
|
import pandas as pd
|
|
@@ -20,15 +30,16 @@ from kumoapi.rfm import (
|
|
|
20
30
|
)
|
|
21
31
|
from kumoapi.task import TaskType
|
|
22
32
|
|
|
23
|
-
from kumoai import
|
|
33
|
+
from kumoai.client.rfm import RFMAPI
|
|
24
34
|
from kumoai.exceptions import HTTPException
|
|
25
|
-
from kumoai.experimental.rfm import
|
|
35
|
+
from kumoai.experimental.rfm import Graph
|
|
26
36
|
from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
|
|
27
37
|
from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
|
|
28
38
|
from kumoai.experimental.rfm.local_pquery_driver import (
|
|
29
39
|
LocalPQueryDriver,
|
|
30
40
|
date_offset_to_seconds,
|
|
31
41
|
)
|
|
42
|
+
from kumoai.mixin import CastMixin
|
|
32
43
|
from kumoai.utils import InteractiveProgressLogger, ProgressLogger
|
|
33
44
|
|
|
34
45
|
_RANDOM_SEED = 42
|
|
@@ -59,6 +70,17 @@ _SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats}\nPlease "
|
|
|
59
70
|
"beyond this for your use-case.")
|
|
60
71
|
|
|
61
72
|
|
|
73
|
+
@dataclass(repr=False)
|
|
74
|
+
class ExplainConfig(CastMixin):
|
|
75
|
+
"""Configuration for explainability.
|
|
76
|
+
|
|
77
|
+
Args:
|
|
78
|
+
skip_summary: Whether to skip generating a human-readable summary of
|
|
79
|
+
the explanation.
|
|
80
|
+
"""
|
|
81
|
+
skip_summary: bool = False
|
|
82
|
+
|
|
83
|
+
|
|
62
84
|
@dataclass(repr=False)
|
|
63
85
|
class Explanation:
|
|
64
86
|
prediction: pd.DataFrame
|
|
@@ -86,6 +108,12 @@ class Explanation:
|
|
|
86
108
|
def __repr__(self) -> str:
|
|
87
109
|
return str((self.prediction, self.summary))
|
|
88
110
|
|
|
111
|
+
def _ipython_display_(self) -> None:
|
|
112
|
+
from IPython.display import Markdown, display
|
|
113
|
+
|
|
114
|
+
display(self.prediction)
|
|
115
|
+
display(Markdown(self.summary))
|
|
116
|
+
|
|
89
117
|
|
|
90
118
|
class KumoRFM:
|
|
91
119
|
r"""The Kumo Relational Foundation model (RFM) from the `KumoRFM: A
|
|
@@ -95,17 +123,17 @@ class KumoRFM:
|
|
|
95
123
|
:class:`KumoRFM` is a foundation model to generate predictions for any
|
|
96
124
|
relational dataset without training.
|
|
97
125
|
The model is pre-trained and the class provides an interface to query the
|
|
98
|
-
model from a :class:`
|
|
126
|
+
model from a :class:`Graph` object.
|
|
99
127
|
|
|
100
128
|
.. code-block:: python
|
|
101
129
|
|
|
102
|
-
from kumoai.experimental.rfm import
|
|
130
|
+
from kumoai.experimental.rfm import Graph, KumoRFM
|
|
103
131
|
|
|
104
132
|
df_users = pd.DataFrame(...)
|
|
105
133
|
df_items = pd.DataFrame(...)
|
|
106
134
|
df_orders = pd.DataFrame(...)
|
|
107
135
|
|
|
108
|
-
graph =
|
|
136
|
+
graph = Graph.from_data({
|
|
109
137
|
'users': df_users,
|
|
110
138
|
'items': df_items,
|
|
111
139
|
'orders': df_orders,
|
|
@@ -113,9 +141,9 @@ class KumoRFM:
|
|
|
113
141
|
|
|
114
142
|
rfm = KumoRFM(graph)
|
|
115
143
|
|
|
116
|
-
query = ("PREDICT COUNT(
|
|
117
|
-
"FOR users.user_id=
|
|
118
|
-
result = rfm.
|
|
144
|
+
query = ("PREDICT COUNT(orders.*, 0, 30, days)>0 "
|
|
145
|
+
"FOR users.user_id=1")
|
|
146
|
+
result = rfm.predict(query)
|
|
119
147
|
|
|
120
148
|
print(result) # user_id COUNT(transactions.*, 0, 30, days) > 0
|
|
121
149
|
# 1 0.85
|
|
@@ -135,7 +163,7 @@ class KumoRFM:
|
|
|
135
163
|
"""
|
|
136
164
|
def __init__(
|
|
137
165
|
self,
|
|
138
|
-
graph:
|
|
166
|
+
graph: Graph,
|
|
139
167
|
preprocess: bool = False,
|
|
140
168
|
verbose: Union[bool, ProgressLogger] = True,
|
|
141
169
|
) -> None:
|
|
@@ -144,9 +172,20 @@ class KumoRFM:
|
|
|
144
172
|
self._graph_store = LocalGraphStore(graph, preprocess, verbose)
|
|
145
173
|
self._graph_sampler = LocalGraphSampler(self._graph_store)
|
|
146
174
|
|
|
175
|
+
self._client: Optional[RFMAPI] = None
|
|
176
|
+
|
|
147
177
|
self._batch_size: Optional[int | Literal['max']] = None
|
|
148
178
|
self.num_retries: int = 0
|
|
149
179
|
|
|
180
|
+
@property
|
|
181
|
+
def _api_client(self) -> RFMAPI:
|
|
182
|
+
if self._client is not None:
|
|
183
|
+
return self._client
|
|
184
|
+
|
|
185
|
+
from kumoai.experimental.rfm import global_state
|
|
186
|
+
self._client = RFMAPI(global_state.client)
|
|
187
|
+
return self._client
|
|
188
|
+
|
|
150
189
|
def __repr__(self) -> str:
|
|
151
190
|
return f'{self.__class__.__name__}()'
|
|
152
191
|
|
|
@@ -208,7 +247,7 @@ class KumoRFM:
|
|
|
208
247
|
query: str,
|
|
209
248
|
indices: Union[List[str], List[float], List[int], None] = None,
|
|
210
249
|
*,
|
|
211
|
-
explain: Literal[True],
|
|
250
|
+
explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
|
|
212
251
|
anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
|
|
213
252
|
context_anchor_time: Union[pd.Timestamp, None] = None,
|
|
214
253
|
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
@@ -226,7 +265,7 @@ class KumoRFM:
|
|
|
226
265
|
query: str,
|
|
227
266
|
indices: Union[List[str], List[float], List[int], None] = None,
|
|
228
267
|
*,
|
|
229
|
-
explain: bool = False,
|
|
268
|
+
explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
|
|
230
269
|
anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
|
|
231
270
|
context_anchor_time: Union[pd.Timestamp, None] = None,
|
|
232
271
|
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
@@ -246,9 +285,12 @@ class KumoRFM:
|
|
|
246
285
|
be generated for all indices, independent of whether they
|
|
247
286
|
fulfill entity filter constraints. To pre-filter entities, use
|
|
248
287
|
:meth:`~KumoRFM.is_valid_entity`.
|
|
249
|
-
explain:
|
|
250
|
-
|
|
251
|
-
|
|
288
|
+
explain: Configuration for explainability.
|
|
289
|
+
If set to ``True``, will additionally explain the prediction.
|
|
290
|
+
Passing in an :class:`ExplainConfig` instance provides control
|
|
291
|
+
over which parts of explanation are generated.
|
|
292
|
+
Explainability is currently only supported for single entity
|
|
293
|
+
predictions with ``run_mode="FAST"``.
|
|
252
294
|
anchor_time: The anchor timestamp for the prediction. If set to
|
|
253
295
|
``None``, will use the maximum timestamp in the data.
|
|
254
296
|
If set to ``"entity"``, will use the timestamp of the entity.
|
|
@@ -272,16 +314,25 @@ class KumoRFM:
|
|
|
272
314
|
|
|
273
315
|
Returns:
|
|
274
316
|
The predictions as a :class:`pandas.DataFrame`.
|
|
275
|
-
If ``explain
|
|
276
|
-
|
|
317
|
+
If ``explain`` is provided, returns an :class:`Explanation` object
|
|
318
|
+
containing the prediction, summary, and details.
|
|
277
319
|
"""
|
|
320
|
+
explain_config: Optional[ExplainConfig] = None
|
|
321
|
+
if explain is True:
|
|
322
|
+
explain_config = ExplainConfig()
|
|
323
|
+
elif explain is not False:
|
|
324
|
+
explain_config = ExplainConfig._cast(explain)
|
|
325
|
+
|
|
278
326
|
query_def = self._parse_query(query)
|
|
327
|
+
query_str = query_def.to_string()
|
|
279
328
|
|
|
280
329
|
if num_hops != 2 and num_neighbors is not None:
|
|
281
330
|
warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
|
|
282
331
|
f"custom 'num_hops={num_hops}' option")
|
|
283
332
|
|
|
284
|
-
if
|
|
333
|
+
if explain_config is not None and run_mode in {
|
|
334
|
+
RunMode.NORMAL, RunMode.BEST
|
|
335
|
+
}:
|
|
285
336
|
warnings.warn(f"Explainability is currently only supported for "
|
|
286
337
|
f"run mode 'FAST' (got '{run_mode}'). Provided run "
|
|
287
338
|
f"mode has been reset. Please lower the run mode to "
|
|
@@ -298,13 +349,13 @@ class KumoRFM:
|
|
|
298
349
|
if len(indices) == 0:
|
|
299
350
|
raise ValueError("At least one entity is required")
|
|
300
351
|
|
|
301
|
-
if
|
|
352
|
+
if explain_config is not None and len(indices) > 1:
|
|
302
353
|
raise ValueError(
|
|
303
354
|
f"Cannot explain predictions for more than a single entity "
|
|
304
355
|
f"(got {len(indices)})")
|
|
305
356
|
|
|
306
357
|
query_repr = query_def.to_string(rich=True, exclude_predict=True)
|
|
307
|
-
if
|
|
358
|
+
if explain_config is not None:
|
|
308
359
|
msg = f'[bold]EXPLAIN[/bold] {query_repr}'
|
|
309
360
|
else:
|
|
310
361
|
msg = f'[bold]PREDICT[/bold] {query_repr}'
|
|
@@ -355,6 +406,7 @@ class KumoRFM:
|
|
|
355
406
|
request = RFMPredictRequest(
|
|
356
407
|
context=context,
|
|
357
408
|
run_mode=RunMode(run_mode),
|
|
409
|
+
query=query_str,
|
|
358
410
|
use_prediction_time=use_prediction_time,
|
|
359
411
|
)
|
|
360
412
|
with warnings.catch_warnings():
|
|
@@ -378,12 +430,15 @@ class KumoRFM:
|
|
|
378
430
|
|
|
379
431
|
for attempt in range(self.num_retries + 1):
|
|
380
432
|
try:
|
|
381
|
-
if
|
|
382
|
-
resp =
|
|
433
|
+
if explain_config is not None:
|
|
434
|
+
resp = self._api_client.explain(
|
|
435
|
+
request=_bytes,
|
|
436
|
+
skip_summary=explain_config.skip_summary,
|
|
437
|
+
)
|
|
383
438
|
summary = resp.summary
|
|
384
439
|
details = resp.details
|
|
385
440
|
else:
|
|
386
|
-
resp =
|
|
441
|
+
resp = self._api_client.predict(_bytes)
|
|
387
442
|
df = pd.DataFrame(**resp.prediction)
|
|
388
443
|
|
|
389
444
|
# Cast 'ENTITY' to correct data type:
|
|
@@ -430,7 +485,7 @@ class KumoRFM:
|
|
|
430
485
|
else:
|
|
431
486
|
prediction = pd.concat(predictions, ignore_index=True)
|
|
432
487
|
|
|
433
|
-
if
|
|
488
|
+
if explain_config is not None:
|
|
434
489
|
assert len(predictions) == 1
|
|
435
490
|
assert summary is not None
|
|
436
491
|
assert details is not None
|
|
@@ -586,10 +641,10 @@ class KumoRFM:
|
|
|
586
641
|
|
|
587
642
|
if len(request_bytes) > _MAX_SIZE:
|
|
588
643
|
stats_msg = Context.get_memory_stats(request_msg.context)
|
|
589
|
-
raise ValueError(_SIZE_LIMIT_MSG.format(
|
|
644
|
+
raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
|
|
590
645
|
|
|
591
646
|
try:
|
|
592
|
-
resp =
|
|
647
|
+
resp = self._api_client.evaluate(request_bytes)
|
|
593
648
|
except HTTPException as e:
|
|
594
649
|
try:
|
|
595
650
|
msg = json.loads(e.detail)['detail']
|
|
@@ -687,7 +742,8 @@ class KumoRFM:
|
|
|
687
742
|
graph_definition=self._graph_def,
|
|
688
743
|
)
|
|
689
744
|
|
|
690
|
-
resp =
|
|
745
|
+
resp = self._api_client.parse_query(request)
|
|
746
|
+
|
|
691
747
|
# TODO Expose validation warnings.
|
|
692
748
|
|
|
693
749
|
if len(resp.validation_response.warnings) > 0:
|
|
@@ -991,7 +1047,6 @@ class KumoRFM:
|
|
|
991
1047
|
train_time.astype('datetime64[ns]').astype(int).to_numpy(),
|
|
992
1048
|
test_time.astype('datetime64[ns]').astype(int).to_numpy(),
|
|
993
1049
|
]),
|
|
994
|
-
run_mode=run_mode,
|
|
995
1050
|
num_neighbors=num_neighbors,
|
|
996
1051
|
exclude_cols_dict=exclude_cols_dict,
|
|
997
1052
|
)
|
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
import json
|
|
3
|
+
from typing import Any, Dict, List, Tuple
|
|
4
|
+
|
|
5
|
+
import requests
|
|
6
|
+
|
|
7
|
+
from kumoai.client import KumoClient
|
|
8
|
+
from kumoai.client.endpoints import Endpoint, HTTPMethod
|
|
9
|
+
from kumoai.exceptions import HTTPException
|
|
10
|
+
|
|
11
|
+
try:
|
|
12
|
+
# isort: off
|
|
13
|
+
from mypy_boto3_sagemaker_runtime.client import SageMakerRuntimeClient
|
|
14
|
+
from mypy_boto3_sagemaker_runtime.type_defs import (
|
|
15
|
+
InvokeEndpointOutputTypeDef, )
|
|
16
|
+
# isort: on
|
|
17
|
+
except ImportError:
|
|
18
|
+
SageMakerRuntimeClient = Any
|
|
19
|
+
InvokeEndpointOutputTypeDef = Any
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class SageMakerResponseAdapter(requests.Response):
|
|
23
|
+
def __init__(self, sm_response: InvokeEndpointOutputTypeDef):
|
|
24
|
+
super().__init__()
|
|
25
|
+
# Read the body bytes
|
|
26
|
+
self._content = sm_response['Body'].read()
|
|
27
|
+
self.status_code = 200
|
|
28
|
+
self.headers['Content-Type'] = sm_response.get('ContentType',
|
|
29
|
+
'application/json')
|
|
30
|
+
# Optionally, you can store original sm_response for debugging
|
|
31
|
+
self.sm_response = sm_response
|
|
32
|
+
|
|
33
|
+
@property
|
|
34
|
+
def text(self) -> str:
|
|
35
|
+
assert isinstance(self._content, bytes)
|
|
36
|
+
return self._content.decode('utf-8')
|
|
37
|
+
|
|
38
|
+
def json(self, **kwargs) -> dict[str, Any]: # type: ignore
|
|
39
|
+
return json.loads(self.text, **kwargs)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class KumoClient_SageMakerAdapter(KumoClient):
|
|
43
|
+
def __init__(self, region: str, endpoint_name: str):
|
|
44
|
+
import boto3
|
|
45
|
+
self._client: SageMakerRuntimeClient = boto3.client(
|
|
46
|
+
service_name="sagemaker-runtime", region_name=region)
|
|
47
|
+
self._endpoint_name = endpoint_name
|
|
48
|
+
|
|
49
|
+
# Recording buffers.
|
|
50
|
+
self._recording_active = False
|
|
51
|
+
self._recorded_reqs: List[Dict[str, Any]] = []
|
|
52
|
+
self._recorded_resps: List[Dict[str, Any]] = []
|
|
53
|
+
|
|
54
|
+
def authenticate(self) -> None:
|
|
55
|
+
# TODO(siyang): call /ping to verify?
|
|
56
|
+
pass
|
|
57
|
+
|
|
58
|
+
def _request(self, endpoint: Endpoint, **kwargs: Any) -> requests.Response:
|
|
59
|
+
assert endpoint.method == HTTPMethod.POST
|
|
60
|
+
if 'json' in kwargs:
|
|
61
|
+
payload = json.dumps(kwargs.pop('json'))
|
|
62
|
+
elif 'data' in kwargs:
|
|
63
|
+
raw_payload = kwargs.pop('data')
|
|
64
|
+
assert isinstance(raw_payload, bytes)
|
|
65
|
+
payload = base64.b64encode(raw_payload).decode()
|
|
66
|
+
else:
|
|
67
|
+
raise HTTPException(400, 'Unable to send data to KumoRFM.')
|
|
68
|
+
|
|
69
|
+
request = {
|
|
70
|
+
'method': endpoint.get_path().rsplit('/')[-1],
|
|
71
|
+
'payload': payload,
|
|
72
|
+
}
|
|
73
|
+
response: InvokeEndpointOutputTypeDef = self._client.invoke_endpoint(
|
|
74
|
+
EndpointName=self._endpoint_name,
|
|
75
|
+
ContentType="application/json",
|
|
76
|
+
Body=json.dumps(request),
|
|
77
|
+
)
|
|
78
|
+
|
|
79
|
+
adapted_response = SageMakerResponseAdapter(response)
|
|
80
|
+
|
|
81
|
+
# If validation is active, store input/output
|
|
82
|
+
if self._recording_active:
|
|
83
|
+
self._recorded_reqs.append(request)
|
|
84
|
+
self._recorded_resps.append(adapted_response.json())
|
|
85
|
+
|
|
86
|
+
return adapted_response
|
|
87
|
+
|
|
88
|
+
def start_recording(self) -> None:
|
|
89
|
+
"""Start recording requests/responses to/from sagemaker endpoint."""
|
|
90
|
+
assert not self._recording_active
|
|
91
|
+
self._recording_active = True
|
|
92
|
+
self._recorded_reqs.clear()
|
|
93
|
+
self._recorded_resps.clear()
|
|
94
|
+
|
|
95
|
+
def end_recording(self) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]:
|
|
96
|
+
"""Stop recording and return recorded requests/responses."""
|
|
97
|
+
assert self._recording_active
|
|
98
|
+
self._recording_active = False
|
|
99
|
+
recorded = list(zip(self._recorded_reqs, self._recorded_resps))
|
|
100
|
+
self._recorded_reqs.clear()
|
|
101
|
+
self._recorded_resps.clear()
|
|
102
|
+
return recorded
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class KumoClient_SageMakerProxy_Local(KumoClient):
|
|
106
|
+
def __init__(self, url: str):
|
|
107
|
+
self._client = KumoClient(url, api_key=None)
|
|
108
|
+
self._client._api_url = self._client._url
|
|
109
|
+
self._endpoint = Endpoint('/invocations', HTTPMethod.POST)
|
|
110
|
+
|
|
111
|
+
def authenticate(self) -> None:
|
|
112
|
+
try:
|
|
113
|
+
self._client._session.get(
|
|
114
|
+
self._url + '/ping',
|
|
115
|
+
verify=self._verify_ssl).raise_for_status()
|
|
116
|
+
except Exception:
|
|
117
|
+
raise ValueError(
|
|
118
|
+
"Client authentication failed. Please check if you "
|
|
119
|
+
"have a valid API key/credentials.")
|
|
120
|
+
|
|
121
|
+
def _request(self, endpoint: Endpoint, **kwargs: Any) -> requests.Response:
|
|
122
|
+
assert endpoint.method == HTTPMethod.POST
|
|
123
|
+
if 'json' in kwargs:
|
|
124
|
+
payload = json.dumps(kwargs.pop('json'))
|
|
125
|
+
elif 'data' in kwargs:
|
|
126
|
+
raw_payload = kwargs.pop('data')
|
|
127
|
+
assert isinstance(raw_payload, bytes)
|
|
128
|
+
payload = base64.b64encode(raw_payload).decode()
|
|
129
|
+
else:
|
|
130
|
+
raise HTTPException(400, 'Unable to send data to KumoRFM.')
|
|
131
|
+
return self._client._request(
|
|
132
|
+
self._endpoint,
|
|
133
|
+
json={
|
|
134
|
+
'method': endpoint.get_path().rsplit('/')[-1],
|
|
135
|
+
'payload': payload,
|
|
136
|
+
},
|
|
137
|
+
**kwargs,
|
|
138
|
+
)
|
kumoai/experimental/rfm/utils.py
CHANGED
|
@@ -1,127 +1,8 @@
|
|
|
1
1
|
import re
|
|
2
2
|
import warnings
|
|
3
|
-
from typing import
|
|
3
|
+
from typing import Optional
|
|
4
4
|
|
|
5
|
-
import numpy as np
|
|
6
5
|
import pandas as pd
|
|
7
|
-
import pyarrow as pa
|
|
8
|
-
from kumoapi.typing import Dtype, Stype
|
|
9
|
-
|
|
10
|
-
from kumoai.experimental.rfm.infer import (
|
|
11
|
-
contains_categorical,
|
|
12
|
-
contains_id,
|
|
13
|
-
contains_multicategorical,
|
|
14
|
-
contains_timestamp,
|
|
15
|
-
)
|
|
16
|
-
|
|
17
|
-
# Mapping from pandas/numpy dtypes to Kumo Dtypes
|
|
18
|
-
PANDAS_TO_DTYPE: Dict[Any, Dtype] = {
|
|
19
|
-
np.dtype('bool'): Dtype.bool,
|
|
20
|
-
pd.BooleanDtype(): Dtype.bool,
|
|
21
|
-
pa.bool_(): Dtype.bool,
|
|
22
|
-
np.dtype('byte'): Dtype.int,
|
|
23
|
-
pd.UInt8Dtype(): Dtype.int,
|
|
24
|
-
np.dtype('int16'): Dtype.int,
|
|
25
|
-
pd.Int16Dtype(): Dtype.int,
|
|
26
|
-
np.dtype('int32'): Dtype.int,
|
|
27
|
-
pd.Int32Dtype(): Dtype.int,
|
|
28
|
-
np.dtype('int64'): Dtype.int,
|
|
29
|
-
pd.Int64Dtype(): Dtype.int,
|
|
30
|
-
np.dtype('float32'): Dtype.float,
|
|
31
|
-
pd.Float32Dtype(): Dtype.float,
|
|
32
|
-
np.dtype('float64'): Dtype.float,
|
|
33
|
-
pd.Float64Dtype(): Dtype.float,
|
|
34
|
-
np.dtype('object'): Dtype.string,
|
|
35
|
-
pd.StringDtype(storage='python'): Dtype.string,
|
|
36
|
-
pd.StringDtype(storage='pyarrow'): Dtype.string,
|
|
37
|
-
pa.string(): Dtype.string,
|
|
38
|
-
pa.binary(): Dtype.binary,
|
|
39
|
-
np.dtype('datetime64[ns]'): Dtype.date,
|
|
40
|
-
np.dtype('timedelta64[ns]'): Dtype.timedelta,
|
|
41
|
-
pa.list_(pa.float32()): Dtype.floatlist,
|
|
42
|
-
pa.list_(pa.int64()): Dtype.intlist,
|
|
43
|
-
pa.list_(pa.string()): Dtype.stringlist,
|
|
44
|
-
}
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
def to_dtype(ser: pd.Series) -> Dtype:
|
|
48
|
-
"""Extracts the :class:`Dtype` from a :class:`pandas.Series`.
|
|
49
|
-
|
|
50
|
-
Args:
|
|
51
|
-
ser: A :class:`pandas.Series` to analyze.
|
|
52
|
-
|
|
53
|
-
Returns:
|
|
54
|
-
The data type.
|
|
55
|
-
"""
|
|
56
|
-
if pd.api.types.is_datetime64_any_dtype(ser.dtype):
|
|
57
|
-
return Dtype.date
|
|
58
|
-
|
|
59
|
-
if isinstance(ser.dtype, pd.CategoricalDtype):
|
|
60
|
-
return Dtype.string
|
|
61
|
-
|
|
62
|
-
if pd.api.types.is_object_dtype(ser.dtype):
|
|
63
|
-
index = ser.iloc[:1000].first_valid_index()
|
|
64
|
-
if index is not None and pd.api.types.is_list_like(ser[index]):
|
|
65
|
-
pos = ser.index.get_loc(index)
|
|
66
|
-
assert isinstance(pos, int)
|
|
67
|
-
ser = ser.iloc[pos:pos + 1000].dropna()
|
|
68
|
-
|
|
69
|
-
if not ser.map(pd.api.types.is_list_like).all():
|
|
70
|
-
raise ValueError("Data contains a mix of list-like and "
|
|
71
|
-
"non-list-like values")
|
|
72
|
-
|
|
73
|
-
ser = ser[ser.map(lambda x: not isinstance(x, list) or len(x) > 0)]
|
|
74
|
-
|
|
75
|
-
dtypes = ser.apply(lambda x: PANDAS_TO_DTYPE.get(
|
|
76
|
-
np.array(x).dtype, Dtype.string)).unique().tolist()
|
|
77
|
-
|
|
78
|
-
invalid_dtypes = set(dtypes) - {
|
|
79
|
-
Dtype.string,
|
|
80
|
-
Dtype.int,
|
|
81
|
-
Dtype.float,
|
|
82
|
-
}
|
|
83
|
-
if len(invalid_dtypes) > 0:
|
|
84
|
-
raise ValueError(f"Data contains unsupported list data types: "
|
|
85
|
-
f"{list(invalid_dtypes)}")
|
|
86
|
-
|
|
87
|
-
if Dtype.string in dtypes:
|
|
88
|
-
return Dtype.stringlist
|
|
89
|
-
|
|
90
|
-
if dtypes == [Dtype.int]:
|
|
91
|
-
return Dtype.intlist
|
|
92
|
-
|
|
93
|
-
return Dtype.floatlist
|
|
94
|
-
|
|
95
|
-
if ser.dtype not in PANDAS_TO_DTYPE:
|
|
96
|
-
raise ValueError(f"Unsupported data type '{ser.dtype}'")
|
|
97
|
-
|
|
98
|
-
return PANDAS_TO_DTYPE[ser.dtype]
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
def infer_stype(ser: pd.Series, column_name: str, dtype: Dtype) -> Stype:
|
|
102
|
-
r"""Infers the semantic type of a column.
|
|
103
|
-
|
|
104
|
-
Args:
|
|
105
|
-
ser: A :class:`pandas.Series` to analyze.
|
|
106
|
-
column_name: The name of the column (used for pattern matching).
|
|
107
|
-
dtype: The data type.
|
|
108
|
-
|
|
109
|
-
Returns:
|
|
110
|
-
The semantic type.
|
|
111
|
-
"""
|
|
112
|
-
if contains_id(ser, column_name, dtype):
|
|
113
|
-
return Stype.ID
|
|
114
|
-
|
|
115
|
-
if contains_timestamp(ser, column_name, dtype):
|
|
116
|
-
return Stype.timestamp
|
|
117
|
-
|
|
118
|
-
if contains_multicategorical(ser, column_name, dtype):
|
|
119
|
-
return Stype.multicategorical
|
|
120
|
-
|
|
121
|
-
if contains_categorical(ser, column_name, dtype):
|
|
122
|
-
return Stype.categorical
|
|
123
|
-
|
|
124
|
-
return dtype.default_stype
|
|
125
6
|
|
|
126
7
|
|
|
127
8
|
def detect_primary_key(
|
|
Binary file
|
kumoai/spcs.py
CHANGED
|
@@ -54,9 +54,7 @@ def _refresh_spcs_token() -> None:
|
|
|
54
54
|
api_key=global_state._api_key,
|
|
55
55
|
spcs_token=spcs_token,
|
|
56
56
|
)
|
|
57
|
-
|
|
58
|
-
raise ValueError("Client authentication failed. Please check if you "
|
|
59
|
-
"have a valid API key.")
|
|
57
|
+
client.authenticate()
|
|
60
58
|
|
|
61
59
|
# Update state:
|
|
62
60
|
global_state.set_spcs_token(spcs_token)
|
kumoai/testing/decorators.py
CHANGED
|
@@ -25,7 +25,7 @@ def onlyFullTest(func: Callable) -> Callable:
|
|
|
25
25
|
def has_package(package: str) -> bool:
|
|
26
26
|
r"""Returns ``True`` in case ``package`` is installed."""
|
|
27
27
|
req = Requirement(package)
|
|
28
|
-
if importlib.util.find_spec(req.name) is None:
|
|
28
|
+
if importlib.util.find_spec(req.name) is None: # type: ignore
|
|
29
29
|
return False
|
|
30
30
|
|
|
31
31
|
try:
|
kumoai/utils/progress_logger.py
CHANGED
|
@@ -103,10 +103,13 @@ class InteractiveProgressLogger(ProgressLogger):
|
|
|
103
103
|
self._progress.update(self._task, advance=1) # type: ignore
|
|
104
104
|
|
|
105
105
|
def __enter__(self) -> Self:
|
|
106
|
+
from kumoai import in_notebook
|
|
107
|
+
|
|
106
108
|
super().__enter__()
|
|
107
109
|
|
|
108
|
-
|
|
109
|
-
|
|
110
|
+
if not in_notebook(): # Render progress bar in TUI.
|
|
111
|
+
sys.stdout.write("\x1b]9;4;3\x07")
|
|
112
|
+
sys.stdout.flush()
|
|
110
113
|
|
|
111
114
|
if self.verbose:
|
|
112
115
|
self._live = Live(
|
|
@@ -119,6 +122,8 @@ class InteractiveProgressLogger(ProgressLogger):
|
|
|
119
122
|
return self
|
|
120
123
|
|
|
121
124
|
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
|
125
|
+
from kumoai import in_notebook
|
|
126
|
+
|
|
122
127
|
super().__exit__(exc_type, exc_val, exc_tb)
|
|
123
128
|
|
|
124
129
|
if exc_type is not None:
|
|
@@ -134,8 +139,9 @@ class InteractiveProgressLogger(ProgressLogger):
|
|
|
134
139
|
self._live.stop()
|
|
135
140
|
self._live = None
|
|
136
141
|
|
|
137
|
-
|
|
138
|
-
|
|
142
|
+
if not in_notebook():
|
|
143
|
+
sys.stdout.write("\x1b]9;4;0\x07")
|
|
144
|
+
sys.stdout.flush()
|
|
139
145
|
|
|
140
146
|
def __rich_console__(
|
|
141
147
|
self,
|
{kumoai-2.12.0.dev202511071730.dist-info → kumoai-2.13.0.dev202512021731.dist-info}/METADATA
RENAMED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: kumoai
|
|
3
|
-
Version: 2.
|
|
3
|
+
Version: 2.13.0.dev202512021731
|
|
4
4
|
Summary: AI on the Modern Data Stack
|
|
5
5
|
Author-email: "Kumo.AI" <hello@kumo.ai>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -23,7 +23,7 @@ Requires-Dist: requests>=2.28.2
|
|
|
23
23
|
Requires-Dist: urllib3
|
|
24
24
|
Requires-Dist: plotly
|
|
25
25
|
Requires-Dist: typing_extensions>=4.5.0
|
|
26
|
-
Requires-Dist: kumo-api==0.
|
|
26
|
+
Requires-Dist: kumo-api==0.48.0
|
|
27
27
|
Requires-Dist: tqdm>=4.66.0
|
|
28
28
|
Requires-Dist: aiohttp>=3.10.0
|
|
29
29
|
Requires-Dist: pydantic>=1.10.21
|
|
@@ -38,6 +38,15 @@ Provides-Extra: test
|
|
|
38
38
|
Requires-Dist: pytest; extra == "test"
|
|
39
39
|
Requires-Dist: pytest-mock; extra == "test"
|
|
40
40
|
Requires-Dist: requests-mock; extra == "test"
|
|
41
|
+
Provides-Extra: sqlite
|
|
42
|
+
Requires-Dist: adbc_driver_sqlite; extra == "sqlite"
|
|
43
|
+
Provides-Extra: snowflake
|
|
44
|
+
Requires-Dist: snowflake-connector-python; extra == "snowflake"
|
|
45
|
+
Provides-Extra: sagemaker
|
|
46
|
+
Requires-Dist: boto3<2.0,>=1.30.0; extra == "sagemaker"
|
|
47
|
+
Requires-Dist: mypy-boto3-sagemaker-runtime<2.0,>=1.34.0; extra == "sagemaker"
|
|
48
|
+
Provides-Extra: test-sagemaker
|
|
49
|
+
Requires-Dist: sagemaker<3.0; extra == "test-sagemaker"
|
|
41
50
|
Dynamic: license-file
|
|
42
51
|
Dynamic: requires-dist
|
|
43
52
|
|