kumoai 2.8.0.dev202508221830__cp312-cp312-win_amd64.whl → 2.13.0.dev202512041141__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kumoai might be problematic. Click here for more details.
- kumoai/__init__.py +22 -11
- kumoai/_version.py +1 -1
- kumoai/client/client.py +17 -16
- kumoai/client/endpoints.py +1 -0
- kumoai/client/rfm.py +37 -8
- kumoai/connector/file_upload_connector.py +94 -85
- kumoai/connector/utils.py +1399 -210
- kumoai/experimental/rfm/__init__.py +164 -46
- kumoai/experimental/rfm/authenticate.py +8 -5
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +38 -0
- kumoai/experimental/rfm/backend/local/table.py +109 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +35 -0
- kumoai/experimental/rfm/backend/snow/table.py +117 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +30 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +101 -0
- kumoai/experimental/rfm/base/__init__.py +10 -0
- kumoai/experimental/rfm/base/column.py +66 -0
- kumoai/experimental/rfm/base/source.py +18 -0
- kumoai/experimental/rfm/base/table.py +545 -0
- kumoai/experimental/rfm/{local_graph.py → graph.py} +413 -144
- kumoai/experimental/rfm/infer/__init__.py +6 -0
- kumoai/experimental/rfm/infer/dtype.py +79 -0
- kumoai/experimental/rfm/infer/pkey.py +126 -0
- kumoai/experimental/rfm/infer/time_col.py +62 -0
- kumoai/experimental/rfm/infer/timestamp.py +7 -4
- kumoai/experimental/rfm/local_graph_sampler.py +58 -11
- kumoai/experimental/rfm/local_graph_store.py +45 -37
- kumoai/experimental/rfm/local_pquery_driver.py +342 -46
- kumoai/experimental/rfm/pquery/__init__.py +4 -4
- kumoai/experimental/rfm/pquery/{backend.py → executor.py} +28 -58
- kumoai/experimental/rfm/pquery/pandas_executor.py +532 -0
- kumoai/experimental/rfm/rfm.py +559 -148
- kumoai/experimental/rfm/sagemaker.py +138 -0
- kumoai/jobs.py +27 -1
- kumoai/kumolib.cp312-win_amd64.pyd +0 -0
- kumoai/pquery/prediction_table.py +5 -3
- kumoai/pquery/training_table.py +5 -3
- kumoai/spcs.py +1 -3
- kumoai/testing/decorators.py +1 -1
- kumoai/trainer/job.py +9 -30
- kumoai/trainer/trainer.py +19 -10
- kumoai/utils/__init__.py +2 -1
- kumoai/utils/progress_logger.py +96 -16
- {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/METADATA +14 -5
- {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/RECORD +49 -36
- kumoai/experimental/rfm/local_table.py +0 -448
- kumoai/experimental/rfm/pquery/pandas_backend.py +0 -437
- kumoai/experimental/rfm/utils.py +0 -347
- {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/WHEEL +0 -0
- {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
import json
|
|
3
|
+
from typing import Any, Dict, List, Tuple
|
|
4
|
+
|
|
5
|
+
import requests
|
|
6
|
+
|
|
7
|
+
from kumoai.client import KumoClient
|
|
8
|
+
from kumoai.client.endpoints import Endpoint, HTTPMethod
|
|
9
|
+
from kumoai.exceptions import HTTPException
|
|
10
|
+
|
|
11
|
+
try:
|
|
12
|
+
# isort: off
|
|
13
|
+
from mypy_boto3_sagemaker_runtime.client import SageMakerRuntimeClient
|
|
14
|
+
from mypy_boto3_sagemaker_runtime.type_defs import (
|
|
15
|
+
InvokeEndpointOutputTypeDef, )
|
|
16
|
+
# isort: on
|
|
17
|
+
except ImportError:
|
|
18
|
+
SageMakerRuntimeClient = Any
|
|
19
|
+
InvokeEndpointOutputTypeDef = Any
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class SageMakerResponseAdapter(requests.Response):
    r"""Presents a SageMaker ``invoke_endpoint`` result as a
    :class:`requests.Response`.

    The SageMaker ``Body`` stream is read eagerly into the response content,
    the status code is fixed to ``200``, and the ``Content-Type`` header is
    taken from the SageMaker ``ContentType`` field (``application/json`` when
    absent). The untouched SageMaker output is kept in ``sm_response``.
    """
    def __init__(self, sm_response: InvokeEndpointOutputTypeDef):
        super().__init__()
        body_bytes = sm_response['Body'].read()
        content_type = sm_response.get('ContentType', 'application/json')

        self._content = body_bytes
        self.status_code = 200
        self.headers['Content-Type'] = content_type
        # Keep the raw SageMaker output around for debugging.
        self.sm_response = sm_response

    @property
    def text(self) -> str:
        r"""The response body decoded as UTF-8."""
        raw = self._content
        assert isinstance(raw, bytes)
        return raw.decode('utf-8')

    def json(self, **kwargs) -> dict[str, Any]:  # type: ignore
        r"""Parses :attr:`text` as JSON, forwarding ``kwargs`` to
        :func:`json.loads`."""
        decoded = self.text
        return json.loads(decoded, **kwargs)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class KumoClient_SageMakerAdapter(KumoClient):
    r"""Client that delivers KumoRFM requests to an AWS SageMaker endpoint.

    Each request is wrapped into a ``{'method': ..., 'payload': ...}`` JSON
    envelope and sent through the ``sagemaker-runtime`` ``invoke_endpoint``
    API. Request/response pairs can be captured between
    :meth:`start_recording` and :meth:`end_recording`.

    Args:
        region: AWS region handed to ``boto3.client``.
        endpoint_name: Name of the SageMaker endpoint to invoke.
    """
    def __init__(self, region: str, endpoint_name: str):
        # boto3 is imported here rather than at module level.
        import boto3

        self._client: SageMakerRuntimeClient = boto3.client(
            service_name="sagemaker-runtime",
            region_name=region,
        )
        self._endpoint_name = endpoint_name

        # State used while recording is switched on.
        self._recording_active = False
        self._recorded_reqs: List[Dict[str, Any]] = []
        self._recorded_resps: List[Dict[str, Any]] = []

    def authenticate(self) -> None:
        r"""No-op: no verification is performed against the endpoint."""
        # TODO(siyang): call /ping to verify?
        return None

    def _request(self, endpoint: Endpoint, **kwargs: Any) -> requests.Response:
        r"""Sends the request for ``endpoint`` to SageMaker.

        Only POST endpoints are supported. A ``json`` keyword argument is
        serialized with :func:`json.dumps`; a ``data`` keyword argument must
        be ``bytes`` and is base64-encoded. The last path segment of the
        endpoint becomes the envelope's ``method`` field.

        Raises:
            HTTPException: If neither ``json`` nor ``data`` is supplied.
        """
        assert endpoint.method == HTTPMethod.POST

        if 'json' in kwargs:
            body = json.dumps(kwargs.pop('json'))
        elif 'data' in kwargs:
            blob = kwargs.pop('data')
            assert isinstance(blob, bytes)
            # Binary data is base64-encoded so it fits in the JSON envelope.
            body = base64.b64encode(blob).decode()
        else:
            raise HTTPException(400, 'Unable to send data to KumoRFM.')

        method_name = endpoint.get_path().rpartition('/')[2]
        envelope = dict(method=method_name, payload=body)

        sm_output: InvokeEndpointOutputTypeDef = self._client.invoke_endpoint(
            EndpointName=self._endpoint_name,
            ContentType="application/json",
            Body=json.dumps(envelope),
        )
        wrapped = SageMakerResponseAdapter(sm_output)

        # Capture the exchange when recording is enabled.
        if self._recording_active:
            self._recorded_reqs.append(envelope)
            self._recorded_resps.append(wrapped.json())

        return wrapped

    def start_recording(self) -> None:
        """Begin capturing requests/responses sent to the sagemaker endpoint."""
        assert not self._recording_active
        self._recording_active = True
        for buffer in (self._recorded_reqs, self._recorded_resps):
            buffer.clear()

    def end_recording(self) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]:
        """Stop capturing and return the recorded (request, response) pairs."""
        assert self._recording_active
        self._recording_active = False
        pairs = [(req, resp) for req, resp in zip(self._recorded_reqs,
                                                   self._recorded_resps)]
        for buffer in (self._recorded_reqs, self._recorded_resps):
            buffer.clear()
        return pairs
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class KumoClient_SageMakerProxy_Local(KumoClient):
    r"""Client that forwards KumoRFM requests to a local SageMaker-style
    proxy exposing ``/ping`` and ``/invocations`` routes.

    Requests are delegated to an inner :class:`KumoClient` stored in
    ``self._client``; ``KumoClient.__init__`` is not invoked on this object
    itself.

    Args:
        url: Base URL of the local proxy.
    """
    def __init__(self, url: str):
        self._client = KumoClient(url, api_key=None)
        # Resolve endpoint paths directly against the bare proxy URL.
        self._client._api_url = self._client._url
        self._endpoint = Endpoint('/invocations', HTTPMethod.POST)

    def authenticate(self) -> None:
        r"""Checks that the proxy is reachable by requesting ``/ping``.

        Raises:
            ValueError: If the ping request fails or returns an error status.
        """
        # ``self._url`` / ``self._verify_ssl`` are never set on this object
        # (its ``__init__`` does not call ``KumoClient.__init__``), so reading
        # them here raised an AttributeError that the handler below turned
        # into a spurious authentication failure. Use the inner client's.
        client = self._client
        try:
            client._session.get(
                client._url + '/ping',
                verify=client._verify_ssl).raise_for_status()
        except Exception as e:
            raise ValueError(
                "Client authentication failed. Please check if you "
                "have a valid API key/credentials.") from e

    def _request(self, endpoint: Endpoint, **kwargs: Any) -> requests.Response:
        r"""Wraps the request for ``endpoint`` into a
        ``{'method': ..., 'payload': ...}`` envelope and POSTs it to the
        proxy's ``/invocations`` route via the inner client.

        Only POST endpoints are supported. A ``json`` keyword argument is
        serialized with :func:`json.dumps`; a ``data`` keyword argument must
        be ``bytes`` and is base64-encoded. Remaining ``kwargs`` are passed
        through to the inner client's ``_request``.

        Raises:
            HTTPException: If neither ``json`` nor ``data`` is supplied.
        """
        assert endpoint.method == HTTPMethod.POST
        if 'json' in kwargs:
            payload = json.dumps(kwargs.pop('json'))
        elif 'data' in kwargs:
            raw_payload = kwargs.pop('data')
            assert isinstance(raw_payload, bytes)
            # Binary data is base64-encoded so it fits in the JSON envelope.
            payload = base64.b64encode(raw_payload).decode()
        else:
            raise HTTPException(400, 'Unable to send data to KumoRFM.')
        return self._client._request(
            self._endpoint,
            json={
                'method': endpoint.get_path().rsplit('/')[-1],
                'payload': payload,
            },
            **kwargs,
        )
|
kumoai/jobs.py
CHANGED
|
@@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
|
|
|
2
2
|
from typing import Generic, Mapping, Optional, TypeVar
|
|
3
3
|
|
|
4
4
|
from kumoapi.jobs import JobStatusReport
|
|
5
|
+
from typing_extensions import Self
|
|
5
6
|
|
|
6
7
|
from kumoai.client.jobs import CommonJobAPI, JobRequestType, JobResourceType
|
|
7
8
|
|
|
@@ -9,12 +10,37 @@ IDType = TypeVar('IDType', bound=str)
|
|
|
9
10
|
|
|
10
11
|
|
|
11
12
|
class JobInterface(ABC, Generic[IDType, JobRequestType, JobResourceType]):
|
|
12
|
-
r"""Defines a standard interface for job objects"""
|
|
13
|
+
r"""Defines a standard interface for job objects."""
|
|
13
14
|
@staticmethod
|
|
14
15
|
@abstractmethod
|
|
15
16
|
def _api() -> CommonJobAPI[JobRequestType, JobResourceType]:
|
|
16
17
|
pass
|
|
17
18
|
|
|
19
|
+
@classmethod
def search_by_tags(cls, tags: Mapping[str, str],
                   limit: int = 10) -> list[Self]:
    r"""Returns a list of job instances from a set of job tags.

    Prints the applied limit and the number of jobs found.

    Args:
        tags (Mapping[str, str]): Tags by which to search.
        limit (int): Max number of jobs to list, default 10.

    Returns:
        list: One instance of this job class per job ID returned by the
        job API listing.

    Example:
        >>> tags = {'pquery_name': 'my_pquery_name'}  # doctest: +SKIP
        >>> jobs = BatchPredictionJob.search_by_tags(tags)  # doctest: +SKIP
        Search limited to 10 results based on the `limit` parameter.
        Found 2 jobs.
    """
    print(f"Search limited to {limit} results based on the `limit` "
          "parameter.")

    jobs = cls._api().list(limit=limit, additional_tags=tags)

    print(f"Found {len(jobs)} jobs.")

    # Subclasses are constructed from a job ID alone.
    return [cls(j.job_id) for j in jobs]  # type: ignore
|
|
43
|
+
|
|
18
44
|
@property
|
|
19
45
|
@abstractmethod
|
|
20
46
|
def id(self) -> IDType:
|
|
Binary file
|
|
@@ -4,6 +4,7 @@ import asyncio
|
|
|
4
4
|
import logging
|
|
5
5
|
from concurrent.futures import Future
|
|
6
6
|
from datetime import datetime
|
|
7
|
+
from functools import cached_property
|
|
7
8
|
from typing import List, Optional, Union
|
|
8
9
|
|
|
9
10
|
import pandas as pd
|
|
@@ -217,9 +218,10 @@ class PredictionTableJob(JobInterface[GeneratePredictionTableJobID,
|
|
|
217
218
|
) -> None:
|
|
218
219
|
self.job_id = job_id
|
|
219
220
|
self.job: Optional[GeneratePredictionTableJobResource] = None
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
221
|
+
|
|
222
|
+
@cached_property
|
|
223
|
+
def _fut(self) -> Future:
|
|
224
|
+
return create_future(self._poll())
|
|
223
225
|
|
|
224
226
|
@override
|
|
225
227
|
@property
|
kumoai/pquery/training_table.py
CHANGED
|
@@ -5,6 +5,7 @@ import logging
|
|
|
5
5
|
import os
|
|
6
6
|
import time
|
|
7
7
|
from concurrent.futures import Future
|
|
8
|
+
from functools import cached_property
|
|
8
9
|
from typing import List, Optional, Tuple, Union
|
|
9
10
|
|
|
10
11
|
import pandas as pd
|
|
@@ -308,9 +309,10 @@ class TrainingTableJob(JobInterface[GenerateTrainTableJobID,
|
|
|
308
309
|
job_id: GenerateTrainTableJobID,
|
|
309
310
|
) -> None:
|
|
310
311
|
self.job_id = job_id
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
312
|
+
|
|
313
|
+
@cached_property
|
|
314
|
+
def _fut(self) -> Future[TrainingTable]:
|
|
315
|
+
return create_future(_poll(self.job_id))
|
|
314
316
|
|
|
315
317
|
@override
|
|
316
318
|
@property
|
kumoai/spcs.py
CHANGED
|
@@ -54,9 +54,7 @@ def _refresh_spcs_token() -> None:
|
|
|
54
54
|
api_key=global_state._api_key,
|
|
55
55
|
spcs_token=spcs_token,
|
|
56
56
|
)
|
|
57
|
-
|
|
58
|
-
raise ValueError("Client authentication failed. Please check if you "
|
|
59
|
-
"have a valid API key.")
|
|
57
|
+
client.authenticate()
|
|
60
58
|
|
|
61
59
|
# Update state:
|
|
62
60
|
global_state.set_spcs_token(spcs_token)
|
kumoai/testing/decorators.py
CHANGED
|
@@ -25,7 +25,7 @@ def onlyFullTest(func: Callable) -> Callable:
|
|
|
25
25
|
def has_package(package: str) -> bool:
|
|
26
26
|
r"""Returns ``True`` in case ``package`` is installed."""
|
|
27
27
|
req = Requirement(package)
|
|
28
|
-
if importlib.util.find_spec(req.name) is None:
|
|
28
|
+
if importlib.util.find_spec(req.name) is None: # type: ignore
|
|
29
29
|
return False
|
|
30
30
|
|
|
31
31
|
try:
|
kumoai/trainer/job.py
CHANGED
|
@@ -4,7 +4,7 @@ import concurrent.futures
|
|
|
4
4
|
import time
|
|
5
5
|
from datetime import datetime, timezone
|
|
6
6
|
from functools import cached_property
|
|
7
|
-
from typing import TYPE_CHECKING, Dict, List,
|
|
7
|
+
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
|
8
8
|
from urllib.parse import urlparse, urlunparse
|
|
9
9
|
|
|
10
10
|
import pandas as pd
|
|
@@ -600,8 +600,10 @@ class TrainingJob(JobInterface[TrainingJobID, TrainingJobRequest,
|
|
|
600
600
|
|
|
601
601
|
def __init__(self, job_id: TrainingJobID) -> None:
|
|
602
602
|
self.job_id = job_id
|
|
603
|
-
|
|
604
|
-
|
|
603
|
+
|
|
604
|
+
@cached_property
|
|
605
|
+
def _fut(self) -> concurrent.futures.Future:
|
|
606
|
+
return create_future(_poll_training(self.job_id))
|
|
605
607
|
|
|
606
608
|
@override
|
|
607
609
|
@property
|
|
@@ -1002,31 +1004,6 @@ class BatchPredictionJob(JobInterface[BatchPredictionJobID,
|
|
|
1002
1004
|
"""
|
|
1003
1005
|
return self._api().get_config(self.job_id)
|
|
1004
1006
|
|
|
1005
|
-
@classmethod
|
|
1006
|
-
def search_by_tags(cls, tags: Mapping[str, str],
|
|
1007
|
-
limit: int = 10) -> list['BatchPredictionJob']:
|
|
1008
|
-
r"""Returns a list of :class:`~kumoai.trainer.job.BatchPredictionJob`
|
|
1009
|
-
instances from a set of job tags.
|
|
1010
|
-
|
|
1011
|
-
Args:
|
|
1012
|
-
tags (Mapping[str, str]): Tags by which to search.
|
|
1013
|
-
limit (int): Max number of jobs to list, default 10.
|
|
1014
|
-
|
|
1015
|
-
Example:
|
|
1016
|
-
>>> tags = {'pquery_name': 'my_pquery_name'}
|
|
1017
|
-
>>> jobs = BatchPredictionJob.search_by_tags(tags)
|
|
1018
|
-
Search limited to 10 results based on the `limit` parameter.
|
|
1019
|
-
Found 2 jobs.
|
|
1020
|
-
"""
|
|
1021
|
-
print(f"Search limited to {limit} results based on the `limit` "
|
|
1022
|
-
"parameter.")
|
|
1023
|
-
|
|
1024
|
-
jobs = cls._api().list(limit=limit, additional_tags=tags)
|
|
1025
|
-
|
|
1026
|
-
print(f"Found {len(jobs)} jobs.")
|
|
1027
|
-
|
|
1028
|
-
return [cls(j.job_id) for j in jobs]
|
|
1029
|
-
|
|
1030
1007
|
|
|
1031
1008
|
def _get_batch_prediction_job(job_id: str) -> BatchPredictionJobResource:
|
|
1032
1009
|
api = global_state.client.batch_prediction_job_api
|
|
@@ -1097,8 +1074,10 @@ class BaselineJob(JobInterface[BaselineJobID, BaselineJobRequest,
|
|
|
1097
1074
|
|
|
1098
1075
|
def __init__(self, job_id: BaselineJobID) -> None:
|
|
1099
1076
|
self.job_id = job_id
|
|
1100
|
-
|
|
1101
|
-
|
|
1077
|
+
|
|
1078
|
+
@cached_property
|
|
1079
|
+
def _fut(self) -> concurrent.futures.Future:
|
|
1080
|
+
return create_future(_poll_baseline(self.job_id))
|
|
1102
1081
|
|
|
1103
1082
|
@override
|
|
1104
1083
|
@property
|
kumoai/trainer/trainer.py
CHANGED
|
@@ -20,7 +20,6 @@ from kumoapi.jobs import (
|
|
|
20
20
|
TrainingJobResource,
|
|
21
21
|
)
|
|
22
22
|
from kumoapi.model_plan import ModelPlan
|
|
23
|
-
from kumoapi.task import TaskType
|
|
24
23
|
|
|
25
24
|
from kumoai import global_state
|
|
26
25
|
from kumoai.artifact_export.config import OutputConfig
|
|
@@ -190,6 +189,7 @@ class Trainer:
|
|
|
190
189
|
*,
|
|
191
190
|
non_blocking: bool = False,
|
|
192
191
|
custom_tags: Mapping[str, str] = {},
|
|
192
|
+
warm_start_job_id: Optional[TrainingJobID] = None,
|
|
193
193
|
) -> Union[TrainingJob, TrainingJobResult]:
|
|
194
194
|
r"""Fits a model to the specified graph and training table, with the
|
|
195
195
|
strategy defined by this :class:`Trainer`'s :obj:`model_plan`.
|
|
@@ -207,6 +207,11 @@ class Trainer:
|
|
|
207
207
|
custom_tags: Additional, customer defined k-v tags to be associated
|
|
208
208
|
with the job to be launched. Job tags are useful for grouping
|
|
209
209
|
and searching jobs.
|
|
210
|
+
warm_start_job_id: Optional job ID of a completed training job to
|
|
211
|
+
warm start from. Initializes the new model with the best
|
|
212
|
+
weights from the specified job, using its model
|
|
213
|
+
architecture, column processing, and neighbor sampling
|
|
214
|
+
configurations.
|
|
210
215
|
|
|
211
216
|
Returns:
|
|
212
217
|
Union[TrainingJobResult, TrainingJob]:
|
|
@@ -241,6 +246,7 @@ class Trainer:
|
|
|
241
246
|
graph_snapshot_id=graph.snapshot(non_blocking=non_blocking),
|
|
242
247
|
train_table_job_id=job_id,
|
|
243
248
|
custom_train_table=custom_table,
|
|
249
|
+
warm_start_job_id=warm_start_job_id,
|
|
244
250
|
))
|
|
245
251
|
|
|
246
252
|
out = TrainingJob(job_id=self._training_job_id)
|
|
@@ -353,6 +359,9 @@ class Trainer:
|
|
|
353
359
|
'deprecated. Please use output_config to specify these '
|
|
354
360
|
'parameters.')
|
|
355
361
|
assert output_config is not None
|
|
362
|
+
# Be able to pass output_config as a dictionary
|
|
363
|
+
if isinstance(output_config, dict):
|
|
364
|
+
output_config = OutputConfig(**output_config)
|
|
356
365
|
output_table_name = to_db_table_name(output_config.output_table_name)
|
|
357
366
|
validate_output_arguments(
|
|
358
367
|
output_config.output_types,
|
|
@@ -395,15 +404,15 @@ class Trainer:
|
|
|
395
404
|
pred_table_data_path = prediction_table.table_data_uri
|
|
396
405
|
|
|
397
406
|
api = global_state.client.batch_prediction_job_api
|
|
398
|
-
|
|
399
|
-
from kumoai.pquery.predictive_query import PredictiveQuery
|
|
400
|
-
pquery = PredictiveQuery.load_from_training_job(training_job_id)
|
|
401
|
-
if pquery.get_task_type() == TaskType.BINARY_CLASSIFICATION:
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
+
# Remove to resolve https://github.com/kumo-ai/kumo/issues/24250
|
|
408
|
+
# from kumoai.pquery.predictive_query import PredictiveQuery
|
|
409
|
+
# pquery = PredictiveQuery.load_from_training_job(training_job_id)
|
|
410
|
+
# if pquery.get_task_type() == TaskType.BINARY_CLASSIFICATION:
|
|
411
|
+
# if binary_classification_threshold is None:
|
|
412
|
+
# logger.warning(
|
|
413
|
+
# "No binary classification threshold provided. "
|
|
414
|
+
# "Using default threshold of 0.5.")
|
|
415
|
+
# binary_classification_threshold = 0.5
|
|
407
416
|
job_id, response = api.maybe_create(
|
|
408
417
|
BatchPredictionRequest(
|
|
409
418
|
dict(custom_tags),
|
kumoai/utils/__init__.py
CHANGED
|
@@ -1,9 +1,10 @@
|
|
|
1
|
-
from .progress_logger import ProgressLogger
|
|
1
|
+
from .progress_logger import ProgressLogger, InteractiveProgressLogger
|
|
2
2
|
from .forecasting import ForecastVisualizer
|
|
3
3
|
from .datasets import from_relbench
|
|
4
4
|
|
|
5
5
|
__all__ = [
|
|
6
6
|
'ProgressLogger',
|
|
7
|
+
'InteractiveProgressLogger',
|
|
7
8
|
'ForecastVisualizer',
|
|
8
9
|
'from_relbench',
|
|
9
10
|
]
|
kumoai/utils/progress_logger.py
CHANGED
|
@@ -1,9 +1,18 @@
|
|
|
1
|
+
import sys
|
|
1
2
|
import time
|
|
2
3
|
from typing import Any, List, Optional, Union
|
|
3
4
|
|
|
4
5
|
from rich.console import Console, ConsoleOptions, RenderResult
|
|
5
6
|
from rich.live import Live
|
|
6
7
|
from rich.padding import Padding
|
|
8
|
+
from rich.progress import (
|
|
9
|
+
BarColumn,
|
|
10
|
+
MofNCompleteColumn,
|
|
11
|
+
Progress,
|
|
12
|
+
Task,
|
|
13
|
+
TextColumn,
|
|
14
|
+
TimeRemainingColumn,
|
|
15
|
+
)
|
|
7
16
|
from rich.spinner import Spinner
|
|
8
17
|
from rich.table import Table
|
|
9
18
|
from rich.text import Text
|
|
@@ -11,27 +20,13 @@ from typing_extensions import Self
|
|
|
11
20
|
|
|
12
21
|
|
|
13
22
|
class ProgressLogger:
|
|
14
|
-
def __init__(
|
|
15
|
-
self,
|
|
16
|
-
msg: str,
|
|
17
|
-
verbose: bool = True,
|
|
18
|
-
refresh_per_second: int = 10,
|
|
19
|
-
) -> None:
|
|
20
|
-
|
|
23
|
+
def __init__(self, msg: str) -> None:
|
|
21
24
|
self.msg = msg
|
|
22
|
-
self.verbose = verbose
|
|
23
|
-
self.refresh_per_second = refresh_per_second
|
|
24
25
|
self.logs: List[str] = []
|
|
25
26
|
|
|
26
27
|
self.start_time: Optional[float] = None
|
|
27
28
|
self.end_time: Optional[float] = None
|
|
28
29
|
|
|
29
|
-
self._live: Optional[Live] = None
|
|
30
|
-
self._exception: bool = False
|
|
31
|
-
|
|
32
|
-
def __repr__(self) -> str:
|
|
33
|
-
return f'{self.__class__.__name__}({self.msg})'
|
|
34
|
-
|
|
35
30
|
@property
|
|
36
31
|
def duration(self) -> float:
|
|
37
32
|
assert self.start_time is not None
|
|
@@ -44,6 +39,77 @@ class ProgressLogger:
|
|
|
44
39
|
|
|
45
40
|
def __enter__(self) -> Self:
|
|
46
41
|
self.start_time = time.perf_counter()
|
|
42
|
+
return self
|
|
43
|
+
|
|
44
|
+
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
|
45
|
+
self.end_time = time.perf_counter()
|
|
46
|
+
|
|
47
|
+
def __repr__(self) -> str:
|
|
48
|
+
return f'{self.__class__.__name__}({self.msg})'
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class ColoredMofNCompleteColumn(MofNCompleteColumn):
    r"""A :class:`MofNCompleteColumn` whose ``completed/total`` text is
    rendered with a configurable style.

    Args:
        style: Rich style applied to the rendered text (default ``'green'``).
    """
    def __init__(self, style: str = 'green') -> None:
        super().__init__()
        self.style = style

    def render(self, task: Task) -> Text:
        counts = super().render(task)
        return Text(str(counts), style=self.style)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class ColoredTimeRemainingColumn(TimeRemainingColumn):
    r"""A :class:`TimeRemainingColumn` that renders its time estimate with a
    configurable style.

    Args:
        style: Rich style applied to the rendered text (default ``'cyan'``).
    """
    def __init__(self, style: str = 'cyan') -> None:
        super().__init__()
        self.style = style

    def render(self, task: Task) -> Text:
        remaining = super().render(task)
        return Text(str(remaining), style=self.style)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class InteractiveProgressLogger(ProgressLogger):
|
|
70
|
+
def __init__(
|
|
71
|
+
self,
|
|
72
|
+
msg: str,
|
|
73
|
+
verbose: bool = True,
|
|
74
|
+
refresh_per_second: int = 10,
|
|
75
|
+
) -> None:
|
|
76
|
+
super().__init__(msg=msg)
|
|
77
|
+
|
|
78
|
+
self.verbose = verbose
|
|
79
|
+
self.refresh_per_second = refresh_per_second
|
|
80
|
+
|
|
81
|
+
self._progress: Optional[Progress] = None
|
|
82
|
+
self._task: Optional[int] = None
|
|
83
|
+
|
|
84
|
+
self._live: Optional[Live] = None
|
|
85
|
+
self._exception: bool = False
|
|
86
|
+
|
|
87
|
+
def init_progress(self, total: int, description: str) -> None:
|
|
88
|
+
assert self._progress is None
|
|
89
|
+
if self.verbose:
|
|
90
|
+
self._progress = Progress(
|
|
91
|
+
TextColumn(f' ↳ {description}', style='dim'),
|
|
92
|
+
BarColumn(bar_width=None),
|
|
93
|
+
ColoredMofNCompleteColumn(style='dim'),
|
|
94
|
+
TextColumn('•', style='dim'),
|
|
95
|
+
ColoredTimeRemainingColumn(style='dim'),
|
|
96
|
+
)
|
|
97
|
+
self._task = self._progress.add_task("Progress", total=total)
|
|
98
|
+
|
|
99
|
+
def step(self) -> None:
|
|
100
|
+
if self.verbose:
|
|
101
|
+
assert self._progress is not None
|
|
102
|
+
assert self._task is not None
|
|
103
|
+
self._progress.update(self._task, advance=1) # type: ignore
|
|
104
|
+
|
|
105
|
+
def __enter__(self) -> Self:
|
|
106
|
+
from kumoai import in_notebook
|
|
107
|
+
|
|
108
|
+
super().__enter__()
|
|
109
|
+
|
|
110
|
+
if not in_notebook(): # Render progress bar in TUI.
|
|
111
|
+
sys.stdout.write("\x1b]9;4;3\x07")
|
|
112
|
+
sys.stdout.flush()
|
|
47
113
|
|
|
48
114
|
if self.verbose:
|
|
49
115
|
self._live = Live(
|
|
@@ -56,16 +122,27 @@ class ProgressLogger:
|
|
|
56
122
|
return self
|
|
57
123
|
|
|
58
124
|
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
|
59
|
-
|
|
125
|
+
from kumoai import in_notebook
|
|
126
|
+
|
|
127
|
+
super().__exit__(exc_type, exc_val, exc_tb)
|
|
60
128
|
|
|
61
129
|
if exc_type is not None:
|
|
62
130
|
self._exception = True
|
|
63
131
|
|
|
132
|
+
if self._progress is not None:
|
|
133
|
+
self._progress.stop()
|
|
134
|
+
self._progress = None
|
|
135
|
+
self._task = None
|
|
136
|
+
|
|
64
137
|
if self._live is not None:
|
|
65
138
|
self._live.update(self, refresh=True)
|
|
66
139
|
self._live.stop()
|
|
67
140
|
self._live = None
|
|
68
141
|
|
|
142
|
+
if not in_notebook():
|
|
143
|
+
sys.stdout.write("\x1b]9;4;0\x07")
|
|
144
|
+
sys.stdout.flush()
|
|
145
|
+
|
|
69
146
|
def __rich_console__(
|
|
70
147
|
self,
|
|
71
148
|
console: Console,
|
|
@@ -95,3 +172,6 @@ class ProgressLogger:
|
|
|
95
172
|
table.add_row('', Text(f'↳ {log}', style='dim'))
|
|
96
173
|
|
|
97
174
|
yield table
|
|
175
|
+
|
|
176
|
+
if self.verbose and self._progress is not None:
|
|
177
|
+
yield self._progress.get_renderable()
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: kumoai
|
|
3
|
-
Version: 2.
|
|
3
|
+
Version: 2.13.0.dev202512041141
|
|
4
4
|
Summary: AI on the Modern Data Stack
|
|
5
5
|
Author-email: "Kumo.AI" <hello@kumo.ai>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -9,13 +9,12 @@ Project-URL: documentation, https://kumo.ai/docs
|
|
|
9
9
|
Keywords: deep-learning,graph-neural-networks,cloud-data-warehouse
|
|
10
10
|
Classifier: Development Status :: 5 - Production/Stable
|
|
11
11
|
Classifier: Programming Language :: Python
|
|
12
|
-
Classifier: Programming Language :: Python :: 3.9
|
|
13
12
|
Classifier: Programming Language :: Python :: 3.10
|
|
14
13
|
Classifier: Programming Language :: Python :: 3.11
|
|
15
14
|
Classifier: Programming Language :: Python :: 3.12
|
|
16
15
|
Classifier: Programming Language :: Python :: 3.13
|
|
17
16
|
Classifier: Programming Language :: Python :: 3 :: Only
|
|
18
|
-
Requires-Python: >=3.
|
|
17
|
+
Requires-Python: >=3.10
|
|
19
18
|
Description-Content-Type: text/markdown
|
|
20
19
|
License-File: LICENSE
|
|
21
20
|
Requires-Dist: pandas
|
|
@@ -24,7 +23,7 @@ Requires-Dist: requests>=2.28.2
|
|
|
24
23
|
Requires-Dist: urllib3
|
|
25
24
|
Requires-Dist: plotly
|
|
26
25
|
Requires-Dist: typing_extensions>=4.5.0
|
|
27
|
-
Requires-Dist: kumo-api==0.
|
|
26
|
+
Requires-Dist: kumo-api==0.48.0
|
|
28
27
|
Requires-Dist: tqdm>=4.66.0
|
|
29
28
|
Requires-Dist: aiohttp>=3.10.0
|
|
30
29
|
Requires-Dist: pydantic>=1.10.21
|
|
@@ -39,6 +38,16 @@ Provides-Extra: test
|
|
|
39
38
|
Requires-Dist: pytest; extra == "test"
|
|
40
39
|
Requires-Dist: pytest-mock; extra == "test"
|
|
41
40
|
Requires-Dist: requests-mock; extra == "test"
|
|
41
|
+
Provides-Extra: sqlite
|
|
42
|
+
Requires-Dist: adbc_driver_sqlite; extra == "sqlite"
|
|
43
|
+
Provides-Extra: snowflake
|
|
44
|
+
Requires-Dist: snowflake-connector-python; extra == "snowflake"
|
|
45
|
+
Requires-Dist: pyyaml; extra == "snowflake"
|
|
46
|
+
Provides-Extra: sagemaker
|
|
47
|
+
Requires-Dist: boto3<2.0,>=1.30.0; extra == "sagemaker"
|
|
48
|
+
Requires-Dist: mypy-boto3-sagemaker-runtime<2.0,>=1.34.0; extra == "sagemaker"
|
|
49
|
+
Provides-Extra: test-sagemaker
|
|
50
|
+
Requires-Dist: sagemaker<3.0; extra == "test-sagemaker"
|
|
42
51
|
Dynamic: license-file
|
|
43
52
|
Dynamic: requires-dist
|
|
44
53
|
|
|
@@ -54,7 +63,7 @@ interact with the Kumo machine learning platform
|
|
|
54
63
|
|
|
55
64
|
## Installation
|
|
56
65
|
|
|
57
|
-
The Kumo SDK is available for Python 3.
|
|
66
|
+
The Kumo SDK is available for Python 3.10 to Python 3.13. To install, simply run
|
|
58
67
|
|
|
59
68
|
```
|
|
60
69
|
pip install kumoai
|