kumoai 2.10.0.dev202509231831__cp313-cp313-macosx_11_0_arm64.whl → 2.14.0.dev202512161731__cp313-cp313-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kumoai might be problematic. Click here for more details.
- kumoai/__init__.py +22 -11
- kumoai/_version.py +1 -1
- kumoai/client/client.py +17 -16
- kumoai/client/endpoints.py +1 -0
- kumoai/client/pquery.py +6 -2
- kumoai/client/rfm.py +37 -8
- kumoai/connector/utils.py +23 -2
- kumoai/experimental/rfm/__init__.py +164 -46
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +49 -86
- kumoai/experimental/rfm/backend/local/sampler.py +315 -0
- kumoai/experimental/rfm/backend/local/table.py +119 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +274 -0
- kumoai/experimental/rfm/backend/snow/table.py +135 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
- kumoai/experimental/rfm/backend/sqlite/sampler.py +353 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +126 -0
- kumoai/experimental/rfm/base/__init__.py +25 -0
- kumoai/experimental/rfm/base/column.py +66 -0
- kumoai/experimental/rfm/base/sampler.py +773 -0
- kumoai/experimental/rfm/base/source.py +19 -0
- kumoai/experimental/rfm/base/sql_sampler.py +60 -0
- kumoai/experimental/rfm/{local_table.py → base/table.py} +245 -156
- kumoai/experimental/rfm/{local_graph.py → graph.py} +425 -137
- kumoai/experimental/rfm/infer/__init__.py +6 -0
- kumoai/experimental/rfm/infer/dtype.py +79 -0
- kumoai/experimental/rfm/infer/pkey.py +126 -0
- kumoai/experimental/rfm/infer/time_col.py +62 -0
- kumoai/experimental/rfm/infer/timestamp.py +7 -4
- kumoai/experimental/rfm/pquery/__init__.py +4 -4
- kumoai/experimental/rfm/pquery/{backend.py → executor.py} +24 -58
- kumoai/experimental/rfm/pquery/{pandas_backend.py → pandas_executor.py} +278 -224
- kumoai/experimental/rfm/rfm.py +669 -246
- kumoai/experimental/rfm/sagemaker.py +138 -0
- kumoai/jobs.py +1 -0
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/spcs.py +1 -3
- kumoai/testing/decorators.py +1 -1
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/trainer.py +12 -10
- kumoai/utils/__init__.py +3 -2
- kumoai/utils/progress_logger.py +239 -4
- kumoai/utils/sql.py +3 -0
- {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/METADATA +15 -5
- {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/RECORD +50 -32
- kumoai/experimental/rfm/local_graph_sampler.py +0 -176
- kumoai/experimental/rfm/local_pquery_driver.py +0 -404
- kumoai/experimental/rfm/utils.py +0 -344
- {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/WHEEL +0 -0
- {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
import json
|
|
3
|
+
from typing import Any, Dict, List, Tuple
|
|
4
|
+
|
|
5
|
+
import requests
|
|
6
|
+
|
|
7
|
+
from kumoai.client import KumoClient
|
|
8
|
+
from kumoai.client.endpoints import Endpoint, HTTPMethod
|
|
9
|
+
from kumoai.exceptions import HTTPException
|
|
10
|
+
|
|
11
|
+
try:
|
|
12
|
+
# isort: off
|
|
13
|
+
from mypy_boto3_sagemaker_runtime.client import SageMakerRuntimeClient
|
|
14
|
+
from mypy_boto3_sagemaker_runtime.type_defs import (
|
|
15
|
+
InvokeEndpointOutputTypeDef, )
|
|
16
|
+
# isort: on
|
|
17
|
+
except ImportError:
|
|
18
|
+
SageMakerRuntimeClient = Any
|
|
19
|
+
InvokeEndpointOutputTypeDef = Any
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class SageMakerResponseAdapter(requests.Response):
    """Present a SageMaker ``invoke_endpoint`` result as a ``requests.Response``.

    The body is read once, eagerly, from the file-like ``Body`` entry so that
    code written against ``requests`` can use ``.text``, ``.json()`` and
    ``.headers``. The reported status code is always 200.

    Args:
        sm_response: Output of the SageMaker runtime ``invoke_endpoint`` call.
    """
    def __init__(self, sm_response: InvokeEndpointOutputTypeDef):
        super().__init__()
        # Drain the body a single time; accessors below reuse these bytes.
        raw_body = sm_response['Body'].read()
        self._content = raw_body
        self.status_code = 200
        content_type = sm_response.get('ContentType', 'application/json')
        self.headers['Content-Type'] = content_type
        # Raw SageMaker output kept for debugging purposes.
        self.sm_response = sm_response

    @property
    def text(self) -> str:
        """Body decoded as UTF-8."""
        body = self._content
        assert isinstance(body, bytes)
        return body.decode('utf-8')

    def json(self, **kwargs) -> dict[str, Any]:  # type: ignore
        """Parse the UTF-8 body as JSON, forwarding ``kwargs`` to ``json.loads``."""
        decoded = self.text
        return json.loads(decoded, **kwargs)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class KumoClient_SageMakerAdapter(KumoClient):
    """KumoRFM client that routes every call through a SageMaker endpoint.

    Each call is wrapped into a ``{'method': ..., 'payload': ...}`` JSON
    envelope and sent with the SageMaker runtime's ``invoke_endpoint``.
    Between ``start_recording()`` and ``end_recording()``, every envelope and
    its decoded JSON response are captured.

    Args:
        region: AWS region passed to the ``sagemaker-runtime`` boto3 client.
        endpoint_name: Name of the SageMaker endpoint to invoke.
    """
    def __init__(self, region: str, endpoint_name: str):
        import boto3

        self._client: SageMakerRuntimeClient = boto3.client(
            service_name="sagemaker-runtime",
            region_name=region,
        )
        self._endpoint_name = endpoint_name

        # Capture buffers, filled by ``_request`` while recording is active.
        self._recording_active = False
        self._recorded_reqs: List[Dict[str, Any]] = []
        self._recorded_resps: List[Dict[str, Any]] = []

    def authenticate(self) -> None:
        """No-op: no credential check is performed against the endpoint."""
        # TODO(siyang): call /ping to verify?
        pass

    def _request(self, endpoint: Endpoint, **kwargs: Any) -> requests.Response:
        """Send one RFM call to the SageMaker endpoint.

        The payload comes from ``json=`` (serialized to a JSON string) or,
        failing that, from ``data=`` (raw bytes, base64-encoded). Any other
        keyword arguments are ignored.

        Raises:
            HTTPException: If neither ``json`` nor ``data`` is supplied.
        """
        assert endpoint.method == HTTPMethod.POST

        has_json = 'json' in kwargs
        if not has_json and 'data' not in kwargs:
            raise HTTPException(400, 'Unable to send data to KumoRFM.')
        if has_json:
            encoded = json.dumps(kwargs.pop('json'))
        else:
            blob = kwargs.pop('data')
            assert isinstance(blob, bytes)
            encoded = base64.b64encode(blob).decode()

        # The RFM method name is the last segment of the endpoint path.
        envelope = {
            'method': endpoint.get_path().rpartition('/')[2],
            'payload': encoded,
        }
        sm_output: InvokeEndpointOutputTypeDef = (
            self._client.invoke_endpoint(
                EndpointName=self._endpoint_name,
                ContentType="application/json",
                Body=json.dumps(envelope),
            ))
        wrapped = SageMakerResponseAdapter(sm_output)

        if self._recording_active:
            self._recorded_reqs.append(envelope)
            self._recorded_resps.append(wrapped.json())

        return wrapped

    def start_recording(self) -> None:
        """Start recording requests/responses to/from sagemaker endpoint."""
        assert not self._recording_active
        self._recording_active = True
        for buffer in (self._recorded_reqs, self._recorded_resps):
            buffer.clear()

    def end_recording(self) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]:
        """Stop recording and return recorded requests/responses."""
        assert self._recording_active
        self._recording_active = False
        pairs = [
            (req, resp)
            for req, resp in zip(self._recorded_reqs, self._recorded_resps)
        ]
        self._recorded_reqs.clear()
        self._recorded_resps.clear()
        return pairs
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class KumoClient_SageMakerProxy_Local(KumoClient):
    """KumoRFM client for a server exposing SageMaker-style HTTP routes.

    Presumably targets a locally running container. Every RFM call is sent
    as ``POST /invocations`` with a ``{'method': ..., 'payload': ...}`` JSON
    envelope, and ``authenticate()`` probes ``GET /ping``.

    Note that ``KumoClient.__init__`` is not called on this object. All HTTP
    state (session, base URL, SSL flag) lives on the wrapped ``self._client``.

    Args:
        url: Base URL of the proxy server.
    """
    def __init__(self, url: str):
        self._client = KumoClient(url, api_key=None)
        # The proxy serves its routes directly under ``url``.
        self._client._api_url = self._client._url
        self._endpoint = Endpoint('/invocations', HTTPMethod.POST)

    def authenticate(self) -> None:
        """Check that the proxy answers ``GET /ping`` successfully.

        Raises:
            ValueError: If the ping request fails or returns an error status.
        """
        # ``self._url`` / ``self._verify_ssl`` are never set on this object
        # (no ``super().__init__``); read them from the wrapped client.
        try:
            self._client._session.get(
                self._client._url + '/ping',
                verify=self._client._verify_ssl,
            ).raise_for_status()
        except Exception as e:
            raise ValueError(
                "Client authentication failed. Please check if you "
                "have a valid API key/credentials.") from e

    def _request(self, endpoint: Endpoint, **kwargs: Any) -> requests.Response:
        """Forward one RFM call to the proxy's ``/invocations`` route.

        The payload comes from ``json=`` (serialized to a JSON string) or
        ``data=`` (raw bytes, base64-encoded). Remaining keyword arguments
        are passed through to the wrapped client's ``_request``.

        Raises:
            HTTPException: If neither ``json`` nor ``data`` is supplied.
        """
        assert endpoint.method == HTTPMethod.POST
        if 'json' in kwargs:
            payload = json.dumps(kwargs.pop('json'))
        elif 'data' in kwargs:
            raw_payload = kwargs.pop('data')
            assert isinstance(raw_payload, bytes)
            payload = base64.b64encode(raw_payload).decode()
        else:
            raise HTTPException(400, 'Unable to send data to KumoRFM.')
        return self._client._request(
            self._endpoint,
            json={
                # RFM method name = last segment of the endpoint path.
                'method': endpoint.get_path().rsplit('/', 1)[-1],
                'payload': payload,
            },
            **kwargs,
        )
|
kumoai/jobs.py
CHANGED
|
@@ -26,6 +26,7 @@ class JobInterface(ABC, Generic[IDType, JobRequestType, JobResourceType]):
|
|
|
26
26
|
limit (int): Max number of jobs to list, default 10.
|
|
27
27
|
|
|
28
28
|
Example:
|
|
29
|
+
>>> # doctest: +SKIP
|
|
29
30
|
>>> tags = {'pquery_name': 'my_pquery_name'}
|
|
30
31
|
>>> jobs = BatchPredictionJob.search_by_tags(tags)
|
|
31
32
|
Search limited to 10 results based on the `limit` parameter.
|
|
@@ -370,9 +370,11 @@ class PredictiveQuery:
|
|
|
370
370
|
train_table_job_api = global_state.client.generate_train_table_job_api
|
|
371
371
|
job_id: GenerateTrainTableJobID = train_table_job_api.create(
|
|
372
372
|
GenerateTrainTableRequest(
|
|
373
|
-
dict(custom_tags),
|
|
374
|
-
|
|
375
|
-
|
|
373
|
+
dict(custom_tags),
|
|
374
|
+
pq_id,
|
|
375
|
+
plan,
|
|
376
|
+
None,
|
|
377
|
+
))
|
|
376
378
|
|
|
377
379
|
self._train_table = TrainingTableJob(job_id=job_id)
|
|
378
380
|
if non_blocking:
|
|
@@ -451,9 +453,11 @@ class PredictiveQuery:
|
|
|
451
453
|
bp_table_api = global_state.client.generate_prediction_table_job_api
|
|
452
454
|
job_id: GeneratePredictionTableJobID = bp_table_api.create(
|
|
453
455
|
GeneratePredictionTableRequest(
|
|
454
|
-
dict(custom_tags),
|
|
455
|
-
|
|
456
|
-
|
|
456
|
+
dict(custom_tags),
|
|
457
|
+
pq_id,
|
|
458
|
+
plan,
|
|
459
|
+
None,
|
|
460
|
+
))
|
|
457
461
|
|
|
458
462
|
self._prediction_table = PredictionTableJob(job_id=job_id)
|
|
459
463
|
if non_blocking:
|
kumoai/spcs.py
CHANGED
|
@@ -54,9 +54,7 @@ def _refresh_spcs_token() -> None:
|
|
|
54
54
|
api_key=global_state._api_key,
|
|
55
55
|
spcs_token=spcs_token,
|
|
56
56
|
)
|
|
57
|
-
|
|
58
|
-
raise ValueError("Client authentication failed. Please check if you "
|
|
59
|
-
"have a valid API key.")
|
|
57
|
+
client.authenticate()
|
|
60
58
|
|
|
61
59
|
# Update state:
|
|
62
60
|
global_state.set_spcs_token(spcs_token)
|
kumoai/testing/decorators.py
CHANGED
|
@@ -25,7 +25,7 @@ def onlyFullTest(func: Callable) -> Callable:
|
|
|
25
25
|
def has_package(package: str) -> bool:
|
|
26
26
|
r"""Returns ``True`` in case ``package`` is installed."""
|
|
27
27
|
req = Requirement(package)
|
|
28
|
-
if importlib.util.find_spec(req.name) is None:
|
|
28
|
+
if importlib.util.find_spec(req.name) is None: # type: ignore
|
|
29
29
|
return False
|
|
30
30
|
|
|
31
31
|
try:
|
kumoai/testing/snow.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
from kumoai.experimental.rfm.backend.snow import Connection
|
|
5
|
+
from kumoai.experimental.rfm.backend.snow import connect as _connect
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def connect(
    region: str,
    id: str,
    account: str,
    user: str,
    warehouse: str,
    database: str | None = None,
    schema: str | None = None,
) -> Connection:
    """Open a Snowflake connection for tests.

    If the ``SNOWFLAKE_PASSWORD`` environment variable is set, it is used as
    the password. Otherwise the PEM private key stored under
    ``kumo_user_secretkey`` in the AWS Secrets Manager secret
    ``{account}.snowflakecomputing.com`` (AWS account ``id``, region
    ``region``) is loaded and passed on as unencrypted PKCS#8 DER bytes.

    Args:
        region: AWS region of the Secrets Manager secret.
        id: AWS account id that owns the secret.
        account: Snowflake account identifier.
        user: Snowflake user name.
        warehouse: Snowflake warehouse to use.
        database: Snowflake database; ``'KUMO'`` when ``None``.
        schema: Snowflake schema (optional).

    Returns:
        The connection created by the snow backend's ``connect``.
    """
    kwargs = dict(password=os.getenv('SNOWFLAKE_PASSWORD'))
    if kwargs['password'] is None:
        # No password available: fall back to key-pair authentication.
        import boto3
        from cryptography.hazmat.primitives import serialization

        client = boto3.client(
            service_name='secretsmanager',
            region_name=region,
        )
        secret_id = (f'arn:aws:secretsmanager:{region}:{id}:secret:'
                     f'{account}.snowflakecomputing.com')
        response = client.get_secret_value(SecretId=secret_id)['SecretString']
        secret = json.loads(response)

        private_key = serialization.load_pem_private_key(
            secret['kumo_user_secretkey'].encode(),
            password=None,
        )
        # Re-encode the key as unencrypted PKCS#8 DER bytes.
        kwargs['private_key'] = private_key.private_bytes(
            encoding=serialization.Encoding.DER,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        )

    return _connect(
        account=account,
        user=user,
        # Honour the caller's warehouse/database; database falls back to
        # 'KUMO' when omitted.
        warehouse=warehouse,
        database=database if database is not None else 'KUMO',
        schema=schema,
        session_parameters=dict(CLIENT_TELEMETRY_ENABLED=False),
        **kwargs,
    )
|
kumoai/trainer/trainer.py
CHANGED
|
@@ -20,7 +20,6 @@ from kumoapi.jobs import (
|
|
|
20
20
|
TrainingJobResource,
|
|
21
21
|
)
|
|
22
22
|
from kumoapi.model_plan import ModelPlan
|
|
23
|
-
from kumoapi.task import TaskType
|
|
24
23
|
|
|
25
24
|
from kumoai import global_state
|
|
26
25
|
from kumoai.artifact_export.config import OutputConfig
|
|
@@ -360,6 +359,9 @@ class Trainer:
|
|
|
360
359
|
'deprecated. Please use output_config to specify these '
|
|
361
360
|
'parameters.')
|
|
362
361
|
assert output_config is not None
|
|
362
|
+
# Be able to pass output_config as a dictionary
|
|
363
|
+
if isinstance(output_config, dict):
|
|
364
|
+
output_config = OutputConfig(**output_config)
|
|
363
365
|
output_table_name = to_db_table_name(output_config.output_table_name)
|
|
364
366
|
validate_output_arguments(
|
|
365
367
|
output_config.output_types,
|
|
@@ -402,15 +404,15 @@ class Trainer:
|
|
|
402
404
|
pred_table_data_path = prediction_table.table_data_uri
|
|
403
405
|
|
|
404
406
|
api = global_state.client.batch_prediction_job_api
|
|
405
|
-
|
|
406
|
-
from kumoai.pquery.predictive_query import PredictiveQuery
|
|
407
|
-
pquery = PredictiveQuery.load_from_training_job(training_job_id)
|
|
408
|
-
if pquery.get_task_type() == TaskType.BINARY_CLASSIFICATION:
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
407
|
+
# Remove to resolve https://github.com/kumo-ai/kumo/issues/24250
|
|
408
|
+
# from kumoai.pquery.predictive_query import PredictiveQuery
|
|
409
|
+
# pquery = PredictiveQuery.load_from_training_job(training_job_id)
|
|
410
|
+
# if pquery.get_task_type() == TaskType.BINARY_CLASSIFICATION:
|
|
411
|
+
# if binary_classification_threshold is None:
|
|
412
|
+
# logger.warning(
|
|
413
|
+
# "No binary classification threshold provided. "
|
|
414
|
+
# "Using default threshold of 0.5.")
|
|
415
|
+
# binary_classification_threshold = 0.5
|
|
414
416
|
job_id, response = api.maybe_create(
|
|
415
417
|
BatchPredictionRequest(
|
|
416
418
|
dict(custom_tags),
|
kumoai/utils/__init__.py
CHANGED
|
@@ -1,10 +1,11 @@
|
|
|
1
|
-
from .
|
|
1
|
+
from .sql import quote_ident
|
|
2
|
+
from .progress_logger import ProgressLogger
|
|
2
3
|
from .forecasting import ForecastVisualizer
|
|
3
4
|
from .datasets import from_relbench
|
|
4
5
|
|
|
5
6
|
__all__ = [
|
|
7
|
+
'quote_ident',
|
|
6
8
|
'ProgressLogger',
|
|
7
|
-
'InteractiveProgressLogger',
|
|
8
9
|
'ForecastVisualizer',
|
|
9
10
|
'from_relbench',
|
|
10
11
|
]
|
kumoai/utils/progress_logger.py
CHANGED
|
@@ -1,9 +1,19 @@
|
|
|
1
|
+
import re
|
|
2
|
+
import sys
|
|
1
3
|
import time
|
|
2
4
|
from typing import Any, List, Optional, Union
|
|
3
5
|
|
|
4
6
|
from rich.console import Console, ConsoleOptions, RenderResult
|
|
5
7
|
from rich.live import Live
|
|
6
8
|
from rich.padding import Padding
|
|
9
|
+
from rich.progress import (
|
|
10
|
+
BarColumn,
|
|
11
|
+
MofNCompleteColumn,
|
|
12
|
+
Progress,
|
|
13
|
+
Task,
|
|
14
|
+
TextColumn,
|
|
15
|
+
TimeRemainingColumn,
|
|
16
|
+
)
|
|
7
17
|
from rich.spinner import Spinner
|
|
8
18
|
from rich.table import Table
|
|
9
19
|
from rich.text import Text
|
|
@@ -11,13 +21,23 @@ from typing_extensions import Self
|
|
|
11
21
|
|
|
12
22
|
|
|
13
23
|
class ProgressLogger:
|
|
14
|
-
def __init__(self, msg: str) -> None:
|
|
24
|
+
def __init__(self, msg: str, verbose: bool = True) -> None:
    """Create a logger for the operation described by ``msg``.

    Args:
        msg: Title of the tracked operation (displayed by subclasses).
        verbose: Whether subclasses should render any output.
    """
    self.msg = msg
    self.verbose = verbose

    # Messages recorded through ``log()``.
    self.logs: List[str] = []

    # ``time.perf_counter()`` timestamps; ``start_time`` is set on enter,
    # ``end_time`` presumably on exit (not visible here — confirm).
    self.start_time: Optional[float] = None
    self.end_time: Optional[float] = None
|
|
20
32
|
|
|
33
|
+
@classmethod
def default(cls, msg: str, verbose: bool = True) -> 'ProgressLogger':
    """Return the progress logger suited to the current runtime.

    Inside a Snowflake notebook this is a ``StreamlitProgressLogger``;
    everywhere else it is a ``RichProgressLogger``.
    """
    from kumoai import in_snowflake_notebook

    logger_cls = (StreamlitProgressLogger
                  if in_snowflake_notebook() else RichProgressLogger)
    return logger_cls(msg, verbose)
|
|
40
|
+
|
|
21
41
|
@property
|
|
22
42
|
def duration(self) -> float:
|
|
23
43
|
assert self.start_time is not None
|
|
@@ -28,6 +48,12 @@ class ProgressLogger:
|
|
|
28
48
|
def log(self, msg: str) -> None:
    """Record ``msg`` as a sub-step message of the tracked operation.

    The base class only appends it to ``self.logs``; display subclasses
    render these messages as well.
    """
    self.logs.append(msg)
|
|
30
50
|
|
|
51
|
+
def init_progress(self, total: int, description: str) -> None:
    """Start a progress bar of ``total`` steps labelled ``description``.

    No-op hook in the base class; display subclasses override it.
    """
    pass
|
|
53
|
+
|
|
54
|
+
def step(self) -> None:
    """Advance the progress bar started by ``init_progress`` by one step.

    No-op hook in the base class; display subclasses override it.
    """
    pass
|
|
56
|
+
|
|
31
57
|
def __enter__(self) -> Self:
    """Start timing the operation and return ``self``."""
    self.start_time = time.perf_counter()
    return self
|
|
@@ -39,24 +65,68 @@ class ProgressLogger:
|
|
|
39
65
|
return f'{self.__class__.__name__}({self.msg})'
|
|
40
66
|
|
|
41
67
|
|
|
42
|
-
class
|
|
68
|
+
class ColoredMofNCompleteColumn(MofNCompleteColumn):
    """Rich ``completed/total`` column drawn in a single fixed style.

    Args:
        style: Rich style applied to the whole rendered text.
    """
    def __init__(self, style: str = 'green') -> None:
        super().__init__()
        self.style = style

    def render(self, task: Task) -> Text:
        # Take the parent's plain text and re-wrap it in our own style.
        plain_text = str(super().render(task))
        return Text(plain_text, style=self.style)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class ColoredTimeRemainingColumn(TimeRemainingColumn):
    """Rich time-remaining column drawn in a single fixed style.

    Args:
        style: Rich style applied to the whole rendered text.
    """
    def __init__(self, style: str = 'cyan') -> None:
        super().__init__()
        self.style = style

    def render(self, task: Task) -> Text:
        # Take the parent's plain text and re-wrap it in our own style.
        plain_text = str(super().render(task))
        return Text(plain_text, style=self.style)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class RichProgressLogger(ProgressLogger):
|
|
43
87
|
def __init__(
|
|
44
88
|
self,
|
|
45
89
|
msg: str,
|
|
46
90
|
verbose: bool = True,
|
|
47
91
|
refresh_per_second: int = 10,
|
|
48
92
|
) -> None:
|
|
49
|
-
super().__init__(msg=msg)
|
|
93
|
+
super().__init__(msg=msg, verbose=verbose)
|
|
50
94
|
|
|
51
|
-
self.verbose = verbose
|
|
52
95
|
self.refresh_per_second = refresh_per_second
|
|
53
96
|
|
|
97
|
+
self._progress: Optional[Progress] = None
|
|
98
|
+
self._task: Optional[int] = None
|
|
99
|
+
|
|
54
100
|
self._live: Optional[Live] = None
|
|
55
101
|
self._exception: bool = False
|
|
56
102
|
|
|
103
|
+
def init_progress(self, total: int, description: str) -> None:
    """Create the dimmed rich progress bar rendered below the log lines.

    Args:
        total: Number of ``step()`` calls that complete the bar.
        description: Label shown in front of the bar.

    At most one bar may exist at a time (it is cleared on exit). Nothing is
    created when ``verbose`` is ``False``.
    """
    assert self._progress is None
    if not self.verbose:
        return
    columns = (
        TextColumn(f' ↳ {description}', style='dim'),
        BarColumn(bar_width=None),
        ColoredMofNCompleteColumn(style='dim'),
        TextColumn('•', style='dim'),
        ColoredTimeRemainingColumn(style='dim'),
    )
    self._progress = Progress(*columns)
    self._task = self._progress.add_task("Progress", total=total)
|
|
114
|
+
|
|
115
|
+
def step(self) -> None:
    """Advance the progress bar by one; does nothing unless ``verbose``.

    When verbose, ``init_progress`` must have been called first.
    """
    if not self.verbose:
        return
    progress, task_id = self._progress, self._task
    assert progress is not None
    assert task_id is not None
    progress.update(task_id, advance=1)  # type: ignore
|
|
120
|
+
|
|
57
121
|
def __enter__(self) -> Self:
|
|
122
|
+
from kumoai import in_notebook
|
|
123
|
+
|
|
58
124
|
super().__enter__()
|
|
59
125
|
|
|
126
|
+
if not in_notebook(): # Render progress bar in TUI.
|
|
127
|
+
sys.stdout.write("\x1b]9;4;3\x07")
|
|
128
|
+
sys.stdout.flush()
|
|
129
|
+
|
|
60
130
|
if self.verbose:
|
|
61
131
|
self._live = Live(
|
|
62
132
|
self,
|
|
@@ -68,16 +138,27 @@ class InteractiveProgressLogger(ProgressLogger):
|
|
|
68
138
|
return self
|
|
69
139
|
|
|
70
140
|
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
    """Finish the operation and tear down the live display.

    The progress bar is dropped, a final refresh renders the completed state,
    and the live display is stopped. Outside notebooks, the terminal progress
    indicator written in ``__enter__`` is cleared again.
    """
    from kumoai import in_notebook

    super().__exit__(exc_type, exc_val, exc_tb)

    if exc_type is not None:
        # Presumably consulted by ``__rich_console__`` for the final render
        # (its full body is not visible here — confirm).
        self._exception = True

    if self._progress is not None:
        self._progress.stop()
        # With ``_progress`` unset, the final render below omits the bar.
        self._progress = None
        self._task = None

    if self._live is not None:
        self._live.update(self, refresh=True)
        self._live.stop()
        self._live = None

    if not in_notebook():
        # OSC 9;4 terminal progress sequence: state 0 removes the indicator
        # that ``__enter__`` set (state 3). Unsupporting terminals are
        # expected to ignore it — NOTE(review): confirm for target terminals.
        sys.stdout.write("\x1b]9;4;0\x07")
        sys.stdout.flush()
|
|
161
|
+
|
|
81
162
|
def __rich_console__(
|
|
82
163
|
self,
|
|
83
164
|
console: Console,
|
|
@@ -107,3 +188,157 @@ class InteractiveProgressLogger(ProgressLogger):
|
|
|
107
188
|
table.add_row('', Text(f'↳ {log}', style='dim'))
|
|
108
189
|
|
|
109
190
|
yield table
|
|
191
|
+
|
|
192
|
+
if self.verbose and self._progress is not None:
|
|
193
|
+
yield self._progress.get_renderable()
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
class StreamlitProgressLogger(ProgressLogger):
    """Progress logger that renders into a Streamlit ``st.status`` box.

    ``ProgressLogger.default`` selects it inside Snowflake notebooks.
    ``log()`` messages are written into the status box, and
    ``init_progress()``/``step()`` drive a progress bar inside it. On exit
    the label turns green (success) or red (exception) and shows the
    elapsed time. With ``verbose=False`` no status box is created; the
    stylesheet is still injected on enter.

    Args:
        msg: Title of the tracked operation. ``[bold]``/``[/bold]`` markers
            are converted to Markdown ``**``.
        verbose: Whether to render the status box and progress bar.
    """
    def __init__(
        self,
        msg: str,
        verbose: bool = True,
    ) -> None:
        super().__init__(msg=msg, verbose=verbose)

        # ``st.status`` container; created in ``__enter__`` when verbose.
        self._status: Any = None

        # Progress-bar state, (re)set by ``init_progress``.
        self._total = 0
        self._current = 0
        self._description: str = ''
        self._progress: Any = None

    def __enter__(self) -> Self:
        super().__enter__()

        import streamlit as st

        # Adjust layout for prettier output:
        st.markdown(STREAMLIT_CSS, unsafe_allow_html=True)

        if self.verbose:
            self._status = st.status(
                f':blue[{self._sanitize_text(self.msg)}]',
                expanded=True,
            )

        return self

    def log(self, msg: str) -> None:
        """Record ``msg`` and, when verbose, write it into the status box."""
        super().log(msg)
        if self.verbose and self._status is not None:
            self._status.write(self._sanitize_text(msg))

    def init_progress(self, total: int, description: str) -> None:
        """Create a progress bar of ``total`` steps inside the status box.

        Does nothing unless verbose and inside the ``with`` block. A
        non-positive ``total`` is accepted and shows an empty bar instead
        of raising ``ZeroDivisionError``.
        """
        if not self.verbose or self._status is None:
            return
        self._total = total
        self._current = 0
        self._description = self._sanitize_text(description)
        self._progress = self._status.progress(
            value=self._fraction(),
            text=self._progress_text(),
        )

    def step(self) -> None:
        """Advance the progress counter by one and refresh the bar."""
        self._current += 1

        if self.verbose and self._progress is not None:
            self._progress.progress(
                value=self._fraction(),
                text=self._progress_text(),
            )

    def _fraction(self) -> float:
        """Completed share for the bar, capped at ``1.0``.

        Returns ``0.0`` when the total is zero or negative, so that the
        division is never by zero and the value is never negative.
        """
        if self._total <= 0:
            return 0.0
        return min(self._current / self._total, 1.0)

    def _progress_text(self) -> str:
        """Bar caption of the form ``'<description> [<current>/<total>]'``."""
        return f'{self._description} [{self._current}/{self._total}]'

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        super().__exit__(exc_type, exc_val, exc_tb)

        if not self.verbose or self._status is None:
            return

        label = f'{self._sanitize_text(self.msg)} ({self.duration:.2f}s)'

        if exc_type is not None:
            self._status.update(
                label=f':red[{label}]',
                state='error',
                expanded=True,
            )
        else:
            self._status.update(
                label=f':green[{label}]',
                state='complete',
                expanded=True,
            )

    @staticmethod
    def _sanitize_text(msg: str) -> str:
        """Turn rich ``[bold]``/``[/bold]`` markup into Markdown ``**``."""
        return re.sub(r'\[/?bold\]', '**', msg)
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
STREAMLIT_CSS = """
|
|
281
|
+
<style>
|
|
282
|
+
/* Fix horizontal scrollbar */
|
|
283
|
+
.stExpander summary {
|
|
284
|
+
width: auto;
|
|
285
|
+
}
|
|
286
|
+
|
|
287
|
+
/* Fix paddings/margins */
|
|
288
|
+
.stExpander summary {
|
|
289
|
+
padding: 0.75rem 1rem 0.5rem;
|
|
290
|
+
}
|
|
291
|
+
.stExpander p {
|
|
292
|
+
margin: 0px 0px 0.2rem;
|
|
293
|
+
}
|
|
294
|
+
.stExpander [data-testid="stExpanderDetails"] {
|
|
295
|
+
padding-bottom: 1.45rem;
|
|
296
|
+
}
|
|
297
|
+
.stExpander .stProgress div:first-child {
|
|
298
|
+
padding-bottom: 4px;
|
|
299
|
+
}
|
|
300
|
+
|
|
301
|
+
/* Fix expand icon position */
|
|
302
|
+
.stExpander summary svg {
|
|
303
|
+
height: 1.5rem;
|
|
304
|
+
}
|
|
305
|
+
|
|
306
|
+
/* Fix summary icons */
|
|
307
|
+
.stExpander summary [data-testid="stExpanderIconCheck"] {
|
|
308
|
+
font-size: 1.8rem;
|
|
309
|
+
margin-top: -3px;
|
|
310
|
+
color: rgb(21, 130, 55);
|
|
311
|
+
}
|
|
312
|
+
.stExpander summary [data-testid="stExpanderIconError"] {
|
|
313
|
+
font-size: 1.8rem;
|
|
314
|
+
margin-top: -3px;
|
|
315
|
+
color: rgb(255, 43, 43);
|
|
316
|
+
}
|
|
317
|
+
.stExpander summary span:first-child span:first-child {
|
|
318
|
+
width: 1.6rem;
|
|
319
|
+
}
|
|
320
|
+
|
|
321
|
+
/* Add border between title and content */
|
|
322
|
+
.stExpander [data-testid="stExpanderDetails"] {
|
|
323
|
+
border-top: 1px solid rgba(30, 37, 47, 0.2);
|
|
324
|
+
padding-top: 0.5rem;
|
|
325
|
+
}
|
|
326
|
+
|
|
327
|
+
/* Fix title font size */
|
|
328
|
+
.stExpander summary p {
|
|
329
|
+
font-size: 1rem;
|
|
330
|
+
}
|
|
331
|
+
|
|
332
|
+
/* Gray out content */
|
|
333
|
+
.stExpander [data-testid="stExpanderDetails"] {
|
|
334
|
+
color: rgba(30, 37, 47, 0.5);
|
|
335
|
+
}
|
|
336
|
+
|
|
337
|
+
/* Fix progress bar font size */
|
|
338
|
+
.stExpander .stProgress p {
|
|
339
|
+
line-height: 1.6;
|
|
340
|
+
font-size: 1rem;
|
|
341
|
+
color: rgba(30, 37, 47, 0.5);
|
|
342
|
+
}
|
|
343
|
+
</style>
|
|
344
|
+
"""
|
kumoai/utils/sql.py
ADDED
{kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/METADATA
RENAMED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: kumoai
|
|
3
|
-
Version: 2.10.0.dev202509231831
|
|
3
|
+
Version: 2.14.0.dev202512161731
|
|
4
4
|
Summary: AI on the Modern Data Stack
|
|
5
5
|
Author-email: "Kumo.AI" <hello@kumo.ai>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -9,13 +9,12 @@ Project-URL: documentation, https://kumo.ai/docs
|
|
|
9
9
|
Keywords: deep-learning,graph-neural-networks,cloud-data-warehouse
|
|
10
10
|
Classifier: Development Status :: 5 - Production/Stable
|
|
11
11
|
Classifier: Programming Language :: Python
|
|
12
|
-
Classifier: Programming Language :: Python :: 3.9
|
|
13
12
|
Classifier: Programming Language :: Python :: 3.10
|
|
14
13
|
Classifier: Programming Language :: Python :: 3.11
|
|
15
14
|
Classifier: Programming Language :: Python :: 3.12
|
|
16
15
|
Classifier: Programming Language :: Python :: 3.13
|
|
17
16
|
Classifier: Programming Language :: Python :: 3 :: Only
|
|
18
|
-
Requires-Python: >=3.9
|
|
17
|
+
Requires-Python: >=3.10
|
|
19
18
|
Description-Content-Type: text/markdown
|
|
20
19
|
License-File: LICENSE
|
|
21
20
|
Requires-Dist: pandas
|
|
@@ -24,7 +23,7 @@ Requires-Dist: requests>=2.28.2
|
|
|
24
23
|
Requires-Dist: urllib3
|
|
25
24
|
Requires-Dist: plotly
|
|
26
25
|
Requires-Dist: typing_extensions>=4.5.0
|
|
27
|
-
Requires-Dist: kumo-api==0.
|
|
26
|
+
Requires-Dist: kumo-api==0.49.0
|
|
28
27
|
Requires-Dist: tqdm>=4.66.0
|
|
29
28
|
Requires-Dist: aiohttp>=3.10.0
|
|
30
29
|
Requires-Dist: pydantic>=1.10.21
|
|
@@ -39,6 +38,17 @@ Provides-Extra: test
|
|
|
39
38
|
Requires-Dist: pytest; extra == "test"
|
|
40
39
|
Requires-Dist: pytest-mock; extra == "test"
|
|
41
40
|
Requires-Dist: requests-mock; extra == "test"
|
|
41
|
+
Provides-Extra: sqlite
|
|
42
|
+
Requires-Dist: adbc_driver_sqlite; extra == "sqlite"
|
|
43
|
+
Provides-Extra: snowflake
|
|
44
|
+
Requires-Dist: numpy<2.0; extra == "snowflake"
|
|
45
|
+
Requires-Dist: snowflake-connector-python; extra == "snowflake"
|
|
46
|
+
Requires-Dist: pyyaml; extra == "snowflake"
|
|
47
|
+
Provides-Extra: sagemaker
|
|
48
|
+
Requires-Dist: boto3<2.0,>=1.30.0; extra == "sagemaker"
|
|
49
|
+
Requires-Dist: mypy-boto3-sagemaker-runtime<2.0,>=1.34.0; extra == "sagemaker"
|
|
50
|
+
Provides-Extra: test-sagemaker
|
|
51
|
+
Requires-Dist: sagemaker<3.0; extra == "test-sagemaker"
|
|
42
52
|
Dynamic: license-file
|
|
43
53
|
Dynamic: requires-dist
|
|
44
54
|
|
|
@@ -54,7 +64,7 @@ interact with the Kumo machine learning platform
|
|
|
54
64
|
|
|
55
65
|
## Installation
|
|
56
66
|
|
|
57
|
-
The Kumo SDK is available for Python 3.
|
|
67
|
+
The Kumo SDK is available for Python 3.10 to Python 3.13. To install, simply run
|
|
58
68
|
|
|
59
69
|
```
|
|
60
70
|
pip install kumoai
|