kumoai 2.13.0.dev202512031731__cp312-cp312-macosx_11_0_arm64.whl → 2.14.0.dev202512301731__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +35 -26
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +24 -0
- kumoai/client/pquery.py +6 -2
- kumoai/experimental/rfm/__init__.py +49 -24
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/local/__init__.py +4 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +62 -110
- kumoai/experimental/rfm/backend/local/sampler.py +312 -0
- kumoai/experimental/rfm/backend/local/table.py +32 -14
- kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
- kumoai/experimental/rfm/backend/snow/table.py +186 -39
- kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
- kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +131 -41
- kumoai/experimental/rfm/base/__init__.py +23 -3
- kumoai/experimental/rfm/base/column.py +96 -10
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +761 -0
- kumoai/experimental/rfm/base/source.py +2 -1
- kumoai/experimental/rfm/base/sql_sampler.py +143 -0
- kumoai/experimental/rfm/base/table.py +380 -185
- kumoai/experimental/rfm/graph.py +404 -144
- kumoai/experimental/rfm/infer/__init__.py +6 -4
- kumoai/experimental/rfm/infer/dtype.py +52 -60
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/pkey.py +4 -2
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +1 -2
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +283 -230
- kumoai/experimental/rfm/sagemaker.py +4 -4
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/__init__.py +3 -2
- kumoai/utils/display.py +51 -0
- kumoai/utils/progress_logger.py +178 -12
- kumoai/utils/sql.py +3 -0
- {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/METADATA +4 -2
- {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/RECORD +48 -38
- kumoai/experimental/rfm/local_graph_sampler.py +0 -223
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import base64
|
|
2
2
|
import json
|
|
3
|
-
from typing import Any
|
|
3
|
+
from typing import Any
|
|
4
4
|
|
|
5
5
|
import requests
|
|
6
6
|
|
|
@@ -48,8 +48,8 @@ class KumoClient_SageMakerAdapter(KumoClient):
|
|
|
48
48
|
|
|
49
49
|
# Recording buffers.
|
|
50
50
|
self._recording_active = False
|
|
51
|
-
self._recorded_reqs:
|
|
52
|
-
self._recorded_resps:
|
|
51
|
+
self._recorded_reqs: list[dict[str, Any]] = []
|
|
52
|
+
self._recorded_resps: list[dict[str, Any]] = []
|
|
53
53
|
|
|
54
54
|
def authenticate(self) -> None:
|
|
55
55
|
# TODO(siyang): call /ping to verify?
|
|
@@ -92,7 +92,7 @@ class KumoClient_SageMakerAdapter(KumoClient):
|
|
|
92
92
|
self._recorded_reqs.clear()
|
|
93
93
|
self._recorded_resps.clear()
|
|
94
94
|
|
|
95
|
-
def end_recording(self) ->
|
|
95
|
+
def end_recording(self) -> list[tuple[dict[str, Any], dict[str, Any]]]:
|
|
96
96
|
"""Stop recording and return recorded requests/responses."""
|
|
97
97
|
assert self._recording_active
|
|
98
98
|
self._recording_active = False
|
|
@@ -370,9 +370,11 @@ class PredictiveQuery:
|
|
|
370
370
|
train_table_job_api = global_state.client.generate_train_table_job_api
|
|
371
371
|
job_id: GenerateTrainTableJobID = train_table_job_api.create(
|
|
372
372
|
GenerateTrainTableRequest(
|
|
373
|
-
dict(custom_tags),
|
|
374
|
-
|
|
375
|
-
|
|
373
|
+
dict(custom_tags),
|
|
374
|
+
pq_id,
|
|
375
|
+
plan,
|
|
376
|
+
None,
|
|
377
|
+
))
|
|
376
378
|
|
|
377
379
|
self._train_table = TrainingTableJob(job_id=job_id)
|
|
378
380
|
if non_blocking:
|
|
@@ -451,9 +453,11 @@ class PredictiveQuery:
|
|
|
451
453
|
bp_table_api = global_state.client.generate_prediction_table_job_api
|
|
452
454
|
job_id: GeneratePredictionTableJobID = bp_table_api.create(
|
|
453
455
|
GeneratePredictionTableRequest(
|
|
454
|
-
dict(custom_tags),
|
|
455
|
-
|
|
456
|
-
|
|
456
|
+
dict(custom_tags),
|
|
457
|
+
pq_id,
|
|
458
|
+
plan,
|
|
459
|
+
None,
|
|
460
|
+
))
|
|
457
461
|
|
|
458
462
|
self._prediction_table = PredictionTableJob(job_id=job_id)
|
|
459
463
|
if non_blocking:
|
kumoai/testing/snow.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
from kumoai.experimental.rfm.backend.snow import Connection
|
|
5
|
+
from kumoai.experimental.rfm.backend.snow import connect as _connect
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def connect(
    region: str,
    id: str,
    account: str,
    user: str,
    warehouse: str,
    database: str | None = None,
    schema: str | None = None,
) -> Connection:
    """Open a Snowflake connection for tests.

    Authentication uses the ``SNOWFLAKE_PASSWORD`` environment variable when
    it is set. Otherwise a PEM private key is read from AWS Secrets Manager
    and handed to the connector as unencrypted DER/PKCS8 bytes.

    Args:
        region: AWS region, used both for the Secrets Manager client and in
            the secret ARN.
        id: Account-ID segment of the secret ARN. (The name shadows the
            ``id`` builtin but is kept for caller compatibility.)
        account: Snowflake account; also forms the secret name
            ``{account}.snowflakecomputing.com``.
        user: Snowflake user name.
        warehouse: Snowflake warehouse to use.
        database: Snowflake database to use. Falls back to ``'KUMO'`` when
            ``None``.
        schema: Optional Snowflake schema.

    Returns:
        The connection object returned by the backend ``connect`` function.
    """
    password = os.getenv('SNOWFLAKE_PASSWORD')
    # The ``password`` key is always forwarded (possibly as ``None``):
    kwargs: dict[str, object] = dict(password=password)

    if password is None:
        # Optional dependencies are only needed for key-pair authentication:
        import boto3
        from cryptography.hazmat.primitives import serialization

        client = boto3.client(
            service_name='secretsmanager',
            region_name=region,
        )
        secret_id = (f'arn:aws:secretsmanager:{region}:{id}:secret:'
                     f'{account}.snowflakecomputing.com')
        response = client.get_secret_value(SecretId=secret_id)['SecretString']
        secret = json.loads(response)

        # Re-encode the stored PEM key as unencrypted DER/PKCS8 bytes:
        private_key = serialization.load_pem_private_key(
            secret['kumo_user_secretkey'].encode(),
            password=None,
        )
        kwargs['private_key'] = private_key.private_bytes(
            encoding=serialization.Encoding.DER,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        )

    return _connect(
        account=account,
        user=user,
        # Honor the caller's choice instead of silently ignoring it:
        warehouse=warehouse,
        database='KUMO' if database is None else database,
        schema=schema,
        session_parameters=dict(CLIENT_TELEMETRY_ENABLED=False),
        **kwargs,
    )
|
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Literal, Mapping, Optional, Union, overload
|
|
3
|
+
|
|
4
|
+
from kumoapi.distilled_model_plan import DistilledModelPlan
|
|
5
|
+
from kumoapi.jobs import DistillationJobRequest, DistillationJobResource
|
|
6
|
+
|
|
7
|
+
from kumoai import global_state
|
|
8
|
+
from kumoai.client.jobs import TrainingJobID
|
|
9
|
+
from kumoai.graph import Graph
|
|
10
|
+
from kumoai.pquery.training_table import TrainingTable, TrainingTableJob
|
|
11
|
+
from kumoai.trainer.job import TrainingJob, TrainingJobResult
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class DistillationTrainer:
    r"""A trainer that supports creating a Kumo machine learning model
    for use in an online serving endpoint. The distillation process involves
    training a shallow model on a :class:`~kumoai.pquery.PredictiveQuery` using
    the embeddings generated by the base model identified by
    :obj:`base_training_job_id`.

    Args:
        model_plan: The distilled model plan to use for the distillation
            process.
        base_training_job_id: The ID of the base training job to use for the
            distillation process.
    """
    def __init__(
        self,
        model_plan: DistilledModelPlan,
        base_training_job_id: TrainingJobID,
    ) -> None:
        self.model_plan: DistilledModelPlan = model_plan
        self.base_training_job_id: TrainingJobID = base_training_job_id

        # Cached from backend:
        self._training_job_id: Optional[TrainingJobID] = None

    # Metadata ################################################################

    @property
    def is_trained(self) -> bool:
        r"""Returns ``True`` if this trainer instance has successfully been
        trained (and is therefore ready for prediction); ``False`` otherwise.

        Raises:
            NotImplementedError: Always; this check is not implemented yet.
        """
        raise NotImplementedError(
            "Checking if a distilled trainer is trained is not "
            "implemented yet.")

    # NOTE Every overload declares `custom_tags` so that type-checked callers
    # can pass it, matching the implementation signature below.

    @overload
    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        custom_tags: Mapping[str, str] = ...,
    ) -> TrainingJobResult:
        ...

    @overload
    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        non_blocking: Literal[False],
        custom_tags: Mapping[str, str] = ...,
    ) -> TrainingJobResult:
        ...

    @overload
    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        non_blocking: Literal[True],
        custom_tags: Mapping[str, str] = ...,
    ) -> TrainingJob:
        ...

    @overload
    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        non_blocking: bool,
        custom_tags: Mapping[str, str] = ...,
    ) -> Union[TrainingJob, TrainingJobResult]:
        ...

    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        non_blocking: bool = False,
        custom_tags: Mapping[str, str] = {},
    ) -> Union[TrainingJob, TrainingJobResult]:
        r"""Fits a model to the specified graph and training table, with the
        strategy defined by :class:`DistillationTrainer`'s :obj:`model_plan`.

        Args:
            graph: The :class:`~kumoai.graph.Graph` object that represents the
                tables and relationships that Kumo will learn from.
            train_table: The :class:`~kumoai.pquery.TrainingTable`, or
                in-progress :class:`~kumoai.pquery.TrainingTableJob`, that
                represents the training data produced by a
                :class:`~kumoai.pquery.PredictiveQuery` on :obj:`graph`.
            non_blocking: Whether this operation should return immediately
                after launching the training job, or await completion of the
                training job.
            custom_tags: Additional, customer defined k-v tags to be associated
                with the job to be launched. Job tags are useful for grouping
                and searching jobs. The mapping is copied and never mutated.

        Returns:
            Union[TrainingJobResult, TrainingJob]:
                If ``non_blocking=False``, returns the result of attaching to
                the launched training job. If ``non_blocking=True``, returns
                the training job future object right after launching.
        """
        # TODO(manan, siyang): remove soon:
        job_id = train_table.job_id
        assert job_id is not None

        # Look up the predictive query behind the training table job:
        train_table_job_api = global_state.client.generate_train_table_job_api
        pq_id = train_table_job_api.get(job_id).config.pquery_id
        assert pq_id is not None

        # Only a materialized training table can carry a custom train table:
        custom_table = None
        if isinstance(train_table, TrainingTable):
            custom_table = train_table._custom_train_table

        # NOTE the backend implementation currently handles sequentialization
        # between a training table future and a training job; that is, if the
        # training table future is still executing, the backend will wait on
        # the job ID completion before executing a training job. This preserves
        # semantics for both futures, ensures that Kumo works as expected if
        # used only via REST API, and allows us to avoid chaining callbacks
        # in an ugly way here:
        api = global_state.client.distillation_job_api
        self._training_job_id = api.create(
            DistillationJobRequest(
                dict(custom_tags),
                pquery_id=pq_id,
                base_training_job_id=self.base_training_job_id,
                distilled_model_plan=self.model_plan,
                graph_snapshot_id=graph.snapshot(non_blocking=non_blocking),
                train_table_job_id=job_id,
                custom_train_table=custom_table,
            ))

        out = TrainingJob(job_id=self._training_job_id)
        if non_blocking:
            return out
        return out.attach()

    @classmethod
    def _load_from_job(
        cls,
        job: DistillationJobResource,
    ) -> 'DistillationTrainer':
        r"""Builds a trainer from an existing distillation job resource,
        reusing the job's model plan, base training job ID and job ID.
        """
        trainer = cls(job.config.distilled_model_plan,
                      job.config.base_training_job_id)
        trainer._training_job_id = job.job_id
        return trainer

    @classmethod
    def load(cls, job_id: TrainingJobID) -> 'DistillationTrainer':
        r"""Creates a :class:`DistillationTrainer` instance from a training
        job ID.

        Raises:
            NotImplementedError: Always; loading is not implemented yet.
        """
        raise NotImplementedError(
            "Loading a distilled trainer from a job ID is not implemented yet."
        )

    @classmethod
    def load_from_tags(cls, tags: Mapping[str, str]) -> 'DistillationTrainer':
        r"""Creates a :class:`DistillationTrainer` instance from job tags.

        Raises:
            NotImplementedError: Always; loading is not implemented yet.
        """
        raise NotImplementedError(
            "Loading a distilled trainer from tags is not implemented yet.")
|
kumoai/utils/__init__.py
CHANGED
|
@@ -1,10 +1,11 @@
|
|
|
1
|
-
from .
|
|
1
|
+
from .sql import quote_ident
|
|
2
|
+
from .progress_logger import ProgressLogger
|
|
2
3
|
from .forecasting import ForecastVisualizer
|
|
3
4
|
from .datasets import from_relbench
|
|
4
5
|
|
|
5
6
|
__all__ = [
|
|
7
|
+
'quote_ident',
|
|
6
8
|
'ProgressLogger',
|
|
7
|
-
'InteractiveProgressLogger',
|
|
8
9
|
'ForecastVisualizer',
|
|
9
10
|
'from_relbench',
|
|
10
11
|
]
|
kumoai/utils/display.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
|
|
3
|
+
import pandas as pd
|
|
4
|
+
|
|
5
|
+
from kumoai import in_notebook, in_snowflake_notebook
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def message(msg: str) -> None:
    """Show ``msg`` using the output channel of the current environment.

    Outside of a notebook, backticks are replaced by single quotes first.
    The text is then rendered via Streamlit markdown in a Snowflake
    notebook, via IPython markdown in any other notebook, and printed
    otherwise.
    """
    if not in_notebook():
        msg = msg.replace("`", "'")

    if in_snowflake_notebook():
        import streamlit as st
        st.markdown(msg)
        return

    if in_notebook():
        from IPython.display import Markdown, display
        display(Markdown(msg))
        return

    print(msg)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def title(msg: str) -> None:
    """Show ``msg`` as a title: a ``###`` markdown heading inside a
    notebook, or the text followed by a colon elsewhere.
    """
    if in_notebook():
        heading = f"### {msg}"
    else:
        heading = f"{msg}:"
    message(heading)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def italic(msg: str) -> None:
    """Show ``msg`` in markdown italics inside a notebook, or unchanged
    elsewhere.
    """
    if in_notebook():
        text = f"*{msg}*"
    else:
        text = msg
    message(text)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def unordered_list(items: Sequence[str]) -> None:
    """Show ``items`` as a bullet list, one entry per line.

    Inside a notebook each entry becomes a markdown ``- `` bullet. Elsewhere
    each entry gets a ``•`` prefix and its backticks are stripped.
    """
    if in_notebook():
        lines = [f"- {item}" for item in items]
    else:
        lines = []
        for item in items:
            plain = item.replace('`', '')
            lines.append(f"• {plain}")
    message('\n'.join(lines))
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def dataframe(df: pd.DataFrame) -> None:
    """Show ``df`` without its index in the current environment.

    Uses ``st.dataframe`` in a Snowflake notebook and a pandas ``Styler``
    in any other notebook. Falls back to printing ``df.to_string`` outside
    of notebooks, or when building/showing the styler raises ``ImportError``.
    """
    if in_snowflake_notebook():
        import streamlit as st
        st.dataframe(df, hide_index=True)
        return

    if not in_notebook():
        print(df.to_string(index=False))
        return

    from IPython.display import display
    try:
        styler = df.style
        if hasattr(styler, 'hide'):
            styled = styler.hide(axis='index')  # newer pandas
        else:
            styled = styler.hide_index()  # older pandas without `hide`
        display(styled)
    except ImportError:
        print(df.to_string(index=False))  # missing jinja2
|
kumoai/utils/progress_logger.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
|
+
import re
|
|
1
2
|
import sys
|
|
2
3
|
import time
|
|
3
|
-
from typing import Any
|
|
4
|
+
from typing import Any
|
|
4
5
|
|
|
5
6
|
from rich.console import Console, ConsoleOptions, RenderResult
|
|
6
7
|
from rich.live import Live
|
|
@@ -20,12 +21,22 @@ from typing_extensions import Self
|
|
|
20
21
|
|
|
21
22
|
|
|
22
23
|
class ProgressLogger:
|
|
23
|
-
def __init__(self, msg: str) -> None:
|
|
24
|
+
def __init__(self, msg: str, verbose: bool = True) -> None:
|
|
24
25
|
self.msg = msg
|
|
25
|
-
self.
|
|
26
|
+
self.verbose = verbose
|
|
27
|
+
|
|
28
|
+
self.logs: list[str] = []
|
|
29
|
+
|
|
30
|
+
self.start_time: float | None = None
|
|
31
|
+
self.end_time: float | None = None
|
|
32
|
+
|
|
33
|
+
@classmethod
|
|
34
|
+
def default(cls, msg: str, verbose: bool = True) -> 'ProgressLogger':
|
|
35
|
+
from kumoai import in_snowflake_notebook
|
|
26
36
|
|
|
27
|
-
|
|
28
|
-
|
|
37
|
+
if in_snowflake_notebook():
|
|
38
|
+
return StreamlitProgressLogger(msg, verbose)
|
|
39
|
+
return RichProgressLogger(msg, verbose)
|
|
29
40
|
|
|
30
41
|
@property
|
|
31
42
|
def duration(self) -> float:
|
|
@@ -37,6 +48,12 @@ class ProgressLogger:
|
|
|
37
48
|
def log(self, msg: str) -> None:
|
|
38
49
|
self.logs.append(msg)
|
|
39
50
|
|
|
51
|
+
def init_progress(self, total: int, description: str) -> None:
|
|
52
|
+
pass
|
|
53
|
+
|
|
54
|
+
def step(self) -> None:
|
|
55
|
+
pass
|
|
56
|
+
|
|
40
57
|
def __enter__(self) -> Self:
|
|
41
58
|
self.start_time = time.perf_counter()
|
|
42
59
|
return self
|
|
@@ -66,22 +83,21 @@ class ColoredTimeRemainingColumn(TimeRemainingColumn):
|
|
|
66
83
|
return Text(str(super().render(task)), style=self.style)
|
|
67
84
|
|
|
68
85
|
|
|
69
|
-
class
|
|
86
|
+
class RichProgressLogger(ProgressLogger):
|
|
70
87
|
def __init__(
|
|
71
88
|
self,
|
|
72
89
|
msg: str,
|
|
73
90
|
verbose: bool = True,
|
|
74
91
|
refresh_per_second: int = 10,
|
|
75
92
|
) -> None:
|
|
76
|
-
super().__init__(msg=msg)
|
|
93
|
+
super().__init__(msg=msg, verbose=verbose)
|
|
77
94
|
|
|
78
|
-
self.verbose = verbose
|
|
79
95
|
self.refresh_per_second = refresh_per_second
|
|
80
96
|
|
|
81
|
-
self._progress:
|
|
82
|
-
self._task:
|
|
97
|
+
self._progress: Progress | None = None
|
|
98
|
+
self._task: int | None = None
|
|
83
99
|
|
|
84
|
-
self._live:
|
|
100
|
+
self._live: Live | None = None
|
|
85
101
|
self._exception: bool = False
|
|
86
102
|
|
|
87
103
|
def init_progress(self, total: int, description: str) -> None:
|
|
@@ -151,7 +167,7 @@ class InteractiveProgressLogger(ProgressLogger):
|
|
|
151
167
|
|
|
152
168
|
table = Table.grid(padding=(0, 1))
|
|
153
169
|
|
|
154
|
-
icon:
|
|
170
|
+
icon: Text | Padding
|
|
155
171
|
if self._exception:
|
|
156
172
|
style = 'red'
|
|
157
173
|
icon = Text('❌', style=style)
|
|
@@ -175,3 +191,153 @@ class InteractiveProgressLogger(ProgressLogger):
|
|
|
175
191
|
|
|
176
192
|
if self.verbose and self._progress is not None:
|
|
177
193
|
yield self._progress.get_renderable()
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
class StreamlitProgressLogger(ProgressLogger):
    """Progress logger that reports through a Streamlit status container.

    On entering the context, the module-level ``STREAMLIT_CSS`` is injected
    and, if ``verbose`` is set, a status container is opened. Log messages
    and an optional progress bar are written into that container, and its
    label and state are updated on exit.

    Args:
        msg: The label of the status container.
        verbose: If ``False``, no status container is created and nothing is
            written; messages are still recorded by the base class.
    """
    def __init__(
        self,
        msg: str,
        verbose: bool = True,
    ) -> None:
        super().__init__(msg=msg, verbose=verbose)

        # Streamlit status container, created in `__enter__` when verbose:
        self._status: Any = None

        self._total = 0
        self._current = 0
        self._description: str = ''
        # Streamlit progress element, created in `init_progress`:
        self._progress: Any = None

    def __enter__(self) -> Self:
        super().__enter__()

        import streamlit as st

        # Adjust layout for prettier output:
        st.markdown(STREAMLIT_CSS, unsafe_allow_html=True)

        if self.verbose:
            self._status = st.status(
                f':blue[{self._sanitize_text(self.msg)}]',
                expanded=True,
            )

        return self

    def log(self, msg: str) -> None:
        """Record ``msg`` and, when verbose, write it to the status
        container.
        """
        super().log(msg)
        if self.verbose and self._status is not None:
            self._status.write(self._sanitize_text(msg))

    def init_progress(self, total: int, description: str) -> None:
        """Start a progress bar of ``total`` steps inside the status
        container. Does nothing unless verbose and inside the context.
        """
        if self.verbose and self._status is not None:
            self._total = total
            self._current = 0
            self._description = self._sanitize_text(description)
            self._progress = self._status.progress(
                value=self._fraction(),
                text=f'{self._description} [{self._current}/{self._total}]',
            )

    def step(self) -> None:
        """Advance the step counter by one and refresh the progress bar, if
        one was started.
        """
        self._current += 1

        if self.verbose and self._progress is not None:
            self._progress.progress(
                value=self._fraction(),
                text=f'{self._description} [{self._current}/{self._total}]',
            )

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        super().__exit__(exc_type, exc_val, exc_tb)

        if not self.verbose or self._status is None:
            return

        label = f'{self._sanitize_text(self.msg)} ({self.duration:.2f}s)'

        # Mark the container red on failure and green on success:
        if exc_type is not None:
            self._status.update(
                label=f':red[{label}]',
                state='error',
                expanded=True,
            )
        else:
            self._status.update(
                label=f':green[{label}]',
                state='complete',
                expanded=True,
            )

    def _fraction(self) -> float:
        """Return the completed fraction, capped at ``1.0``.

        A non-positive total yields ``0.0`` rather than dividing by zero or
        producing a negative value.
        """
        if self._total <= 0:
            return 0.0
        return min(self._current / self._total, 1.0)

    @staticmethod
    def _sanitize_text(msg: str) -> str:
        """Replace ``[bold]``/``[/bold]`` markup tags with markdown ``**``."""
        return re.sub(r'\[/?bold\]', '**', msg)
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
STREAMLIT_CSS = """
|
|
280
|
+
<style>
|
|
281
|
+
/* Fix horizontal scrollbar */
|
|
282
|
+
.stExpander summary {
|
|
283
|
+
width: auto;
|
|
284
|
+
}
|
|
285
|
+
|
|
286
|
+
/* Fix paddings/margins */
|
|
287
|
+
.stExpander summary {
|
|
288
|
+
padding: 0.75rem 1rem 0.5rem;
|
|
289
|
+
}
|
|
290
|
+
.stExpander p {
|
|
291
|
+
margin: 0px 0px 0.2rem;
|
|
292
|
+
}
|
|
293
|
+
.stExpander [data-testid="stExpanderDetails"] {
|
|
294
|
+
padding-bottom: 1.45rem;
|
|
295
|
+
}
|
|
296
|
+
.stExpander .stProgress div:first-child {
|
|
297
|
+
padding-bottom: 4px;
|
|
298
|
+
}
|
|
299
|
+
|
|
300
|
+
/* Fix expand icon position */
|
|
301
|
+
.stExpander summary svg {
|
|
302
|
+
height: 1.5rem;
|
|
303
|
+
}
|
|
304
|
+
|
|
305
|
+
/* Fix summary icons */
|
|
306
|
+
.stExpander summary [data-testid="stExpanderIconCheck"] {
|
|
307
|
+
font-size: 1.8rem;
|
|
308
|
+
margin-top: -3px;
|
|
309
|
+
color: rgb(21, 130, 55);
|
|
310
|
+
}
|
|
311
|
+
.stExpander summary [data-testid="stExpanderIconError"] {
|
|
312
|
+
font-size: 1.8rem;
|
|
313
|
+
margin-top: -3px;
|
|
314
|
+
color: rgb(255, 43, 43);
|
|
315
|
+
}
|
|
316
|
+
.stExpander summary span:first-child span:first-child {
|
|
317
|
+
width: 1.6rem;
|
|
318
|
+
}
|
|
319
|
+
|
|
320
|
+
/* Add border between title and content */
|
|
321
|
+
.stExpander [data-testid="stExpanderDetails"] {
|
|
322
|
+
border-top: 1px solid rgba(30, 37, 47, 0.2);
|
|
323
|
+
padding-top: 0.5rem;
|
|
324
|
+
}
|
|
325
|
+
|
|
326
|
+
/* Fix title font size */
|
|
327
|
+
.stExpander summary p {
|
|
328
|
+
font-size: 1rem;
|
|
329
|
+
}
|
|
330
|
+
|
|
331
|
+
/* Gray out content */
|
|
332
|
+
.stExpander [data-testid="stExpanderDetails"] {
|
|
333
|
+
color: rgba(30, 37, 47, 0.5);
|
|
334
|
+
}
|
|
335
|
+
|
|
336
|
+
/* Fix progress bar font size */
|
|
337
|
+
.stExpander .stProgress p {
|
|
338
|
+
line-height: 1.6;
|
|
339
|
+
font-size: 1rem;
|
|
340
|
+
color: rgba(30, 37, 47, 0.5);
|
|
341
|
+
}
|
|
342
|
+
</style>
|
|
343
|
+
"""
|
kumoai/utils/sql.py
ADDED
{kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/METADATA
RENAMED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: kumoai
|
|
3
|
-
Version: 2.
|
|
3
|
+
Version: 2.14.0.dev202512301731
|
|
4
4
|
Summary: AI on the Modern Data Stack
|
|
5
5
|
Author-email: "Kumo.AI" <hello@kumo.ai>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -23,7 +23,7 @@ Requires-Dist: requests>=2.28.2
|
|
|
23
23
|
Requires-Dist: urllib3
|
|
24
24
|
Requires-Dist: plotly
|
|
25
25
|
Requires-Dist: typing_extensions>=4.5.0
|
|
26
|
-
Requires-Dist: kumo-api==0.
|
|
26
|
+
Requires-Dist: kumo-api==0.49.0
|
|
27
27
|
Requires-Dist: tqdm>=4.66.0
|
|
28
28
|
Requires-Dist: aiohttp>=3.10.0
|
|
29
29
|
Requires-Dist: pydantic>=1.10.21
|
|
@@ -41,7 +41,9 @@ Requires-Dist: requests-mock; extra == "test"
|
|
|
41
41
|
Provides-Extra: sqlite
|
|
42
42
|
Requires-Dist: adbc_driver_sqlite; extra == "sqlite"
|
|
43
43
|
Provides-Extra: snowflake
|
|
44
|
+
Requires-Dist: numpy<2.0; extra == "snowflake"
|
|
44
45
|
Requires-Dist: snowflake-connector-python; extra == "snowflake"
|
|
46
|
+
Requires-Dist: pyyaml; extra == "snowflake"
|
|
45
47
|
Provides-Extra: sagemaker
|
|
46
48
|
Requires-Dist: boto3<2.0,>=1.30.0; extra == "sagemaker"
|
|
47
49
|
Requires-Dist: mypy-boto3-sagemaker-runtime<2.0,>=1.34.0; extra == "sagemaker"
|