kumoai 2.13.0.dev202511131731__cp310-cp310-macosx_11_0_arm64.whl → 2.14.0.dev202512271732__cp310-cp310-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. kumoai/__init__.py +18 -9
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +15 -13
  4. kumoai/client/jobs.py +24 -0
  5. kumoai/client/pquery.py +6 -2
  6. kumoai/connector/utils.py +23 -2
  7. kumoai/experimental/rfm/__init__.py +191 -50
  8. kumoai/experimental/rfm/authenticate.py +3 -4
  9. kumoai/experimental/rfm/backend/__init__.py +0 -0
  10. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  11. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +65 -127
  12. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  13. kumoai/experimental/rfm/backend/local/table.py +113 -0
  14. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  15. kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
  16. kumoai/experimental/rfm/backend/snow/table.py +242 -0
  17. kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
  18. kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
  19. kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
  20. kumoai/experimental/rfm/base/__init__.py +30 -0
  21. kumoai/experimental/rfm/base/column.py +152 -0
  22. kumoai/experimental/rfm/base/expression.py +44 -0
  23. kumoai/experimental/rfm/base/sampler.py +761 -0
  24. kumoai/experimental/rfm/base/source.py +19 -0
  25. kumoai/experimental/rfm/base/sql_sampler.py +143 -0
  26. kumoai/experimental/rfm/base/table.py +753 -0
  27. kumoai/experimental/rfm/{local_graph.py → graph.py} +546 -116
  28. kumoai/experimental/rfm/infer/__init__.py +8 -0
  29. kumoai/experimental/rfm/infer/dtype.py +81 -0
  30. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  31. kumoai/experimental/rfm/infer/pkey.py +128 -0
  32. kumoai/experimental/rfm/infer/stype.py +35 -0
  33. kumoai/experimental/rfm/infer/time_col.py +61 -0
  34. kumoai/experimental/rfm/pquery/executor.py +27 -27
  35. kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
  36. kumoai/experimental/rfm/rfm.py +322 -252
  37. kumoai/experimental/rfm/sagemaker.py +138 -0
  38. kumoai/pquery/predictive_query.py +10 -6
  39. kumoai/spcs.py +1 -3
  40. kumoai/testing/decorators.py +1 -1
  41. kumoai/testing/snow.py +50 -0
  42. kumoai/trainer/distilled_trainer.py +175 -0
  43. kumoai/utils/__init__.py +3 -2
  44. kumoai/utils/progress_logger.py +178 -12
  45. kumoai/utils/sql.py +3 -0
  46. {kumoai-2.13.0.dev202511131731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/METADATA +13 -2
  47. {kumoai-2.13.0.dev202511131731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/RECORD +50 -29
  48. kumoai/experimental/rfm/local_graph_sampler.py +0 -184
  49. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  50. kumoai/experimental/rfm/local_table.py +0 -545
  51. kumoai/experimental/rfm/utils.py +0 -344
  52. {kumoai-2.13.0.dev202511131731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/WHEEL +0 -0
  53. {kumoai-2.13.0.dev202511131731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/licenses/LICENSE +0 -0
  54. {kumoai-2.13.0.dev202511131731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,138 @@
1
+ import base64
2
+ import json
3
+ from typing import Any
4
+
5
+ import requests
6
+
7
+ from kumoai.client import KumoClient
8
+ from kumoai.client.endpoints import Endpoint, HTTPMethod
9
+ from kumoai.exceptions import HTTPException
10
+
11
+ try:
12
+ # isort: off
13
+ from mypy_boto3_sagemaker_runtime.client import SageMakerRuntimeClient
14
+ from mypy_boto3_sagemaker_runtime.type_defs import (
15
+ InvokeEndpointOutputTypeDef, )
16
+ # isort: on
17
+ except ImportError:
18
+ SageMakerRuntimeClient = Any
19
+ InvokeEndpointOutputTypeDef = Any
20
+
21
+
22
class SageMakerResponseAdapter(requests.Response):
    r"""Presents a SageMaker ``invoke_endpoint`` result through the
    :class:`requests.Response` interface.

    The body stream is drained eagerly on construction, the status code is
    fixed to ``200``, and the ``Content-Type`` header mirrors the SageMaker
    ``ContentType`` field (``application/json`` when absent).

    Args:
        sm_response: The raw output of ``invoke_endpoint``.
    """
    def __init__(self, sm_response: InvokeEndpointOutputTypeDef):
        super().__init__()
        body_stream = sm_response['Body']
        self._content = body_stream.read()
        self.status_code = 200
        content_type = sm_response.get('ContentType', 'application/json')
        self.headers['Content-Type'] = content_type
        # Keep the untouched SageMaker response around for debugging:
        self.sm_response = sm_response

    @property
    def text(self) -> str:
        r"""The response body decoded as UTF-8."""
        raw = self._content
        assert isinstance(raw, bytes)
        return raw.decode('utf-8')

    def json(self, **kwargs) -> dict[str, Any]:  # type: ignore
        r"""Decodes the response body as JSON. Keyword arguments are
        forwarded to :func:`json.loads`."""
        return json.loads(self.text, **kwargs)
40
+
41
+
42
class KumoClient_SageMakerAdapter(KumoClient):
    r"""A :class:`KumoClient` that sends KumoRFM requests to an AWS
    SageMaker endpoint through the ``sagemaker-runtime`` client.

    Request/response pairs can optionally be captured between calls to
    :meth:`start_recording` and :meth:`end_recording`.

    Args:
        region: The AWS region the SageMaker endpoint lives in.
        endpoint_name: The name of the SageMaker endpoint to invoke.
    """
    def __init__(self, region: str, endpoint_name: str):
        import boto3
        self._client: SageMakerRuntimeClient = boto3.client(
            service_name="sagemaker-runtime",
            region_name=region,
        )
        self._endpoint_name = endpoint_name

        # State used while recording is active:
        self._recording_active = False
        self._recorded_reqs: list[dict[str, Any]] = []
        self._recorded_resps: list[dict[str, Any]] = []

    def authenticate(self) -> None:
        # Intentionally a no-op for now.
        # TODO(siyang): call /ping to verify?
        pass

    def _request(self, endpoint: Endpoint, **kwargs: Any) -> requests.Response:
        r"""Wraps the payload into a ``{'method', 'payload'}`` envelope and
        invokes the SageMaker endpoint with it.

        Raises:
            HTTPException: If neither ``json`` nor ``data`` is provided.
        """
        assert endpoint.method == HTTPMethod.POST

        # JSON bodies are serialized; raw binary bodies are base64-encoded:
        if 'json' in kwargs:
            body = json.dumps(kwargs.pop('json'))
        elif 'data' in kwargs:
            raw = kwargs.pop('data')
            assert isinstance(raw, bytes)
            body = base64.b64encode(raw).decode()
        else:
            raise HTTPException(400, 'Unable to send data to KumoRFM.')

        # The last path segment of the endpoint is sent as the method name:
        request = dict(
            method=endpoint.get_path().rsplit('/', 1)[-1],
            payload=body,
        )
        sm_response: InvokeEndpointOutputTypeDef
        sm_response = self._client.invoke_endpoint(
            EndpointName=self._endpoint_name,
            ContentType="application/json",
            Body=json.dumps(request),
        )
        adapted = SageMakerResponseAdapter(sm_response)

        if self._recording_active:
            self._recorded_reqs.append(request)
            self._recorded_resps.append(adapted.json())

        return adapted

    def start_recording(self) -> None:
        r"""Begins capturing requests sent to and responses received from
        the SageMaker endpoint. Previously buffered entries are discarded."""
        assert not self._recording_active
        self._recording_active = True
        self._recorded_reqs.clear()
        self._recorded_resps.clear()

    def end_recording(self) -> list[tuple[dict[str, Any], dict[str, Any]]]:
        r"""Stops capturing and returns the recorded ``(request, response)``
        pairs in the order in which they were issued."""
        assert self._recording_active
        self._recording_active = False
        recorded = [
            (req, resp)
            for req, resp in zip(self._recorded_reqs, self._recorded_resps)
        ]
        self._recorded_reqs.clear()
        self._recorded_resps.clear()
        return recorded
103
+
104
+
105
class KumoClient_SageMakerProxy_Local(KumoClient):
    r"""A :class:`KumoClient` that forwards KumoRFM requests to a local
    server exposing ``/ping`` and ``/invocations`` routes.

    All traffic goes through a wrapped :class:`KumoClient` stored in
    ``self._client``. ``super().__init__()`` is never called, so connection
    settings such as the URL and SSL verification flag only exist on that
    wrapped client, not on ``self``.

    Args:
        url: The base URL of the local server.
    """
    def __init__(self, url: str):
        self._client = KumoClient(url, api_key=None)
        # Make request paths resolve relative to ``url`` itself (presumably
        # instead of a versioned API prefix -- confirm in KumoClient):
        self._client._api_url = self._client._url
        self._endpoint = Endpoint('/invocations', HTTPMethod.POST)

    def authenticate(self) -> None:
        r"""Checks that the local server answers on its ``/ping`` route.

        Raises:
            ValueError: If the health-check request fails.
        """
        # Use the wrapped client's settings: ``self`` has no ``_url`` or
        # ``_verify_ssl`` because the base initializer is never run.
        # Reading them from ``self`` raised an AttributeError that was
        # swallowed below, so authentication always failed.
        try:
            self._client._session.get(
                self._client._url + '/ping',
                verify=self._client._verify_ssl,
            ).raise_for_status()
        except Exception as e:
            raise ValueError(
                "Client authentication failed. Please check if you "
                "have a valid API key/credentials.") from e

    def _request(self, endpoint: Endpoint, **kwargs: Any) -> requests.Response:
        r"""Wraps the payload into a ``{'method', 'payload'}`` envelope and
        POSTs it to the ``/invocations`` route of the local server.

        Raises:
            HTTPException: If neither ``json`` nor ``data`` is provided.
        """
        assert endpoint.method == HTTPMethod.POST
        # JSON bodies are serialized; raw binary bodies are base64-encoded:
        if 'json' in kwargs:
            payload = json.dumps(kwargs.pop('json'))
        elif 'data' in kwargs:
            raw_payload = kwargs.pop('data')
            assert isinstance(raw_payload, bytes)
            payload = base64.b64encode(raw_payload).decode()
        else:
            raise HTTPException(400, 'Unable to send data to KumoRFM.')
        # The last path segment of the endpoint is sent as the method name:
        return self._client._request(
            self._endpoint,
            json={
                'method': endpoint.get_path().rsplit('/', 1)[-1],
                'payload': payload,
            },
            **kwargs,
        )
@@ -370,9 +370,11 @@ class PredictiveQuery:
370
370
  train_table_job_api = global_state.client.generate_train_table_job_api
371
371
  job_id: GenerateTrainTableJobID = train_table_job_api.create(
372
372
  GenerateTrainTableRequest(
373
- dict(custom_tags), pq_id, plan,
374
- graph_snapshot_id=self.graph.snapshot(
375
- non_blocking=non_blocking)))
373
+ dict(custom_tags),
374
+ pq_id,
375
+ plan,
376
+ None,
377
+ ))
376
378
 
377
379
  self._train_table = TrainingTableJob(job_id=job_id)
378
380
  if non_blocking:
@@ -451,9 +453,11 @@ class PredictiveQuery:
451
453
  bp_table_api = global_state.client.generate_prediction_table_job_api
452
454
  job_id: GeneratePredictionTableJobID = bp_table_api.create(
453
455
  GeneratePredictionTableRequest(
454
- dict(custom_tags), pq_id, plan,
455
- graph_snapshot_id=self.graph.snapshot(
456
- non_blocking=non_blocking)))
456
+ dict(custom_tags),
457
+ pq_id,
458
+ plan,
459
+ None,
460
+ ))
457
461
 
458
462
  self._prediction_table = PredictionTableJob(job_id=job_id)
459
463
  if non_blocking:
kumoai/spcs.py CHANGED
@@ -54,9 +54,7 @@ def _refresh_spcs_token() -> None:
54
54
  api_key=global_state._api_key,
55
55
  spcs_token=spcs_token,
56
56
  )
57
- if not client.authenticate():
58
- raise ValueError("Client authentication failed. Please check if you "
59
- "have a valid API key.")
57
+ client.authenticate()
60
58
 
61
59
  # Update state:
62
60
  global_state.set_spcs_token(spcs_token)
@@ -25,7 +25,7 @@ def onlyFullTest(func: Callable) -> Callable:
25
25
  def has_package(package: str) -> bool:
26
26
  r"""Returns ``True`` in case ``package`` is installed."""
27
27
  req = Requirement(package)
28
- if importlib.util.find_spec(req.name) is None:
28
+ if importlib.util.find_spec(req.name) is None: # type: ignore
29
29
  return False
30
30
 
31
31
  try:
kumoai/testing/snow.py ADDED
@@ -0,0 +1,50 @@
1
+ import json
2
+ import os
3
+
4
+ from kumoai.experimental.rfm.backend.snow import Connection
5
+ from kumoai.experimental.rfm.backend.snow import connect as _connect
6
+
7
+
8
def connect(
    region: str,
    id: str,
    account: str,
    user: str,
    warehouse: str,
    database: str | None = None,
    schema: str | None = None,
) -> Connection:
    r"""Opens a Snowflake connection for testing.

    Credentials come from the ``SNOWFLAKE_PASSWORD`` environment variable
    when it is set. Otherwise, the PEM private key stored under
    ``kumo_user_secretkey`` in an AWS Secrets Manager secret is loaded. The
    secret is named after ``{account}.snowflakecomputing.com`` and lives in
    account ``id`` and region ``region``. The key is passed on in unencrypted
    DER/PKCS#8 form.

    Args:
        region: The AWS region of the Secrets Manager secret.
        id: The AWS account ID that owns the secret.
        account: The Snowflake account identifier.
        user: The Snowflake user name.
        warehouse: The Snowflake warehouse to use.
        database: The Snowflake database to use (``'KUMO'`` if ``None``).
        schema: The Snowflake schema to use.
    """
    kwargs = dict(password=os.getenv('SNOWFLAKE_PASSWORD'))
    if kwargs['password'] is None:
        import boto3
        from cryptography.hazmat.primitives import serialization

        client = boto3.client(
            service_name='secretsmanager',
            region_name=region,
        )
        secret_id = (f'arn:aws:secretsmanager:{region}:{id}:secret:'
                     f'{account}.snowflakecomputing.com')
        response = client.get_secret_value(SecretId=secret_id)['SecretString']
        secret = json.loads(response)

        private_key = serialization.load_pem_private_key(
            secret['kumo_user_secretkey'].encode(),
            password=None,
        )
        kwargs['private_key'] = private_key.private_bytes(
            encoding=serialization.Encoding.DER,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        )

    # Honor the caller's warehouse/database; these were previously
    # hard-coded and the parameters were silently ignored:
    return _connect(
        account=account,
        user=user,
        warehouse=warehouse,
        database=database if database is not None else 'KUMO',
        schema=schema,
        session_parameters=dict(CLIENT_TELEMETRY_ENABLED=False),
        **kwargs,
    )
@@ -0,0 +1,175 @@
1
+ import logging
2
+ from typing import Literal, Mapping, Optional, Union, overload
3
+
4
+ from kumoapi.distilled_model_plan import DistilledModelPlan
5
+ from kumoapi.jobs import DistillationJobRequest, DistillationJobResource
6
+
7
+ from kumoai import global_state
8
+ from kumoai.client.jobs import TrainingJobID
9
+ from kumoai.graph import Graph
10
+ from kumoai.pquery.training_table import TrainingTable, TrainingTableJob
11
+ from kumoai.trainer.job import TrainingJob, TrainingJobResult
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
class DistillationTrainer:
    r"""A trainer that creates a Kumo machine learning model for use in an
    online serving endpoint. The distillation process involves training a
    shallow model on a :class:`~kumoai.pquery.PredictiveQuery` using the
    embeddings generated by the base model of :obj:`base_training_job_id`.

    Args:
        model_plan: The distilled model plan to use for the distillation
            process.
        base_training_job_id: The ID of the base training job to use for the
            distillation process.
    """

    def __init__(
        self,
        model_plan: DistilledModelPlan,
        base_training_job_id: TrainingJobID,
    ) -> None:
        self.model_plan: DistilledModelPlan = model_plan
        self.base_training_job_id: TrainingJobID = base_training_job_id

        # Cached from backend:
        self._training_job_id: Optional[TrainingJobID] = None

    # Metadata ################################################################

    @property
    def is_trained(self) -> bool:
        r"""Returns ``True`` if this trainer instance has successfully been
        trained (and is therefore ready for prediction); ``False`` otherwise.

        Raises:
            NotImplementedError: Always; this check is not implemented yet.
        """
        raise NotImplementedError(
            "Checking if a distilled trainer is trained is not "
            "implemented yet.")

    @overload
    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        custom_tags: Optional[Mapping[str, str]] = None,
    ) -> TrainingJobResult:
        pass

    @overload
    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        non_blocking: Literal[False],
        custom_tags: Optional[Mapping[str, str]] = None,
    ) -> TrainingJobResult:
        pass

    @overload
    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        non_blocking: Literal[True],
        custom_tags: Optional[Mapping[str, str]] = None,
    ) -> TrainingJob:
        pass

    @overload
    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        non_blocking: bool,
        custom_tags: Optional[Mapping[str, str]] = None,
    ) -> Union[TrainingJob, TrainingJobResult]:
        pass

    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        non_blocking: bool = False,
        custom_tags: Optional[Mapping[str, str]] = None,
    ) -> Union[TrainingJob, TrainingJobResult]:
        r"""Fits a model to the specified graph and training table, with the
        strategy defined by :class:`DistillationTrainer`'s :obj:`model_plan`.

        Args:
            graph: The :class:`~kumoai.graph.Graph` object that represents the
                tables and relationships that Kumo will learn from.
            train_table: The :class:`~kumoai.pquery.TrainingTable`, or
                in-progress :class:`~kumoai.pquery.TrainingTableJob`, that
                represents the training data produced by a
                :class:`~kumoai.pquery.PredictiveQuery` on :obj:`graph`.
            non_blocking: Whether this operation should return immediately
                after launching the training job, or await completion of the
                training job.
            custom_tags: Additional, customer defined k-v tags to be associated
                with the job to be launched. Job tags are useful for grouping
                and searching jobs. Defaults to no tags.

        Returns:
            Union[TrainingJobResult, TrainingJob]:
                If ``non_blocking=False``, waits for the job and returns its
                :class:`TrainingJobResult`. If ``non_blocking=True``, returns
                the :class:`TrainingJob` future immediately.
        """
        # TODO(manan, siyang): remove soon:
        job_id = train_table.job_id
        assert job_id is not None

        train_table_job_api = global_state.client.generate_train_table_job_api
        pq_id = train_table_job_api.get(job_id).config.pquery_id
        assert pq_id is not None

        custom_table = None
        if isinstance(train_table, TrainingTable):
            custom_table = train_table._custom_train_table

        # NOTE the backend implementation currently handles sequentialization
        # between a training table future and a training job; that is, if the
        # training table future is still executing, the backend will wait on
        # the job ID completion before executing a training job. This preserves
        # semantics for both futures, ensures that Kumo works as expected if
        # used only via REST API, and allows us to avoid chaining callbacks
        # in an ugly way here:
        api = global_state.client.distillation_job_api
        self._training_job_id = api.create(
            DistillationJobRequest(
                # Copy the tags so that the caller's mapping is never shared:
                dict(custom_tags or {}),
                pquery_id=pq_id,
                base_training_job_id=self.base_training_job_id,
                distilled_model_plan=self.model_plan,
                graph_snapshot_id=graph.snapshot(non_blocking=non_blocking),
                train_table_job_id=job_id,
                custom_train_table=custom_table,
            ))

        out = TrainingJob(job_id=self._training_job_id)
        if non_blocking:
            return out
        return out.attach()

    @classmethod
    def _load_from_job(
        cls,
        job: DistillationJobResource,
    ) -> 'DistillationTrainer':
        r"""Rebuilds a trainer from a distillation job resource, restoring
        its model plan, base training job ID and training job ID."""
        trainer = cls(job.config.distilled_model_plan,
                      job.config.base_training_job_id)
        trainer._training_job_id = job.job_id
        return trainer

    @classmethod
    def load(cls, job_id: TrainingJobID) -> 'DistillationTrainer':
        r"""Creates a :class:`DistillationTrainer` instance from a training
        job ID.

        Raises:
            NotImplementedError: Always; loading is not implemented yet.
        """
        raise NotImplementedError(
            "Loading a distilled trainer from a job ID is not implemented yet."
        )

    @classmethod
    def load_from_tags(cls, tags: Mapping[str, str]) -> 'DistillationTrainer':
        r"""Creates a :class:`DistillationTrainer` instance from job tags.

        Raises:
            NotImplementedError: Always; loading is not implemented yet.
        """
        raise NotImplementedError(
            "Loading a distilled trainer from tags is not implemented yet.")
kumoai/utils/__init__.py CHANGED
@@ -1,10 +1,11 @@
1
- from .progress_logger import ProgressLogger, InteractiveProgressLogger
1
+ from .sql import quote_ident
2
+ from .progress_logger import ProgressLogger
2
3
  from .forecasting import ForecastVisualizer
3
4
  from .datasets import from_relbench
4
5
 
5
6
  __all__ = [
7
+ 'quote_ident',
6
8
  'ProgressLogger',
7
- 'InteractiveProgressLogger',
8
9
  'ForecastVisualizer',
9
10
  'from_relbench',
10
11
  ]
@@ -1,6 +1,7 @@
1
+ import re
1
2
  import sys
2
3
  import time
3
- from typing import Any, List, Optional, Union
4
+ from typing import Any
4
5
 
5
6
  from rich.console import Console, ConsoleOptions, RenderResult
6
7
  from rich.live import Live
@@ -20,12 +21,22 @@ from typing_extensions import Self
20
21
 
21
22
 
22
23
  class ProgressLogger:
23
- def __init__(self, msg: str) -> None:
24
+ def __init__(self, msg: str, verbose: bool = True) -> None:
24
25
  self.msg = msg
25
- self.logs: List[str] = []
26
+ self.verbose = verbose
27
+
28
+ self.logs: list[str] = []
29
+
30
+ self.start_time: float | None = None
31
+ self.end_time: float | None = None
32
+
33
+ @classmethod
34
+ def default(cls, msg: str, verbose: bool = True) -> 'ProgressLogger':
35
+ from kumoai import in_snowflake_notebook
26
36
 
27
- self.start_time: Optional[float] = None
28
- self.end_time: Optional[float] = None
37
+ if in_snowflake_notebook():
38
+ return StreamlitProgressLogger(msg, verbose)
39
+ return RichProgressLogger(msg, verbose)
29
40
 
30
41
  @property
31
42
  def duration(self) -> float:
@@ -37,6 +48,12 @@ class ProgressLogger:
37
48
  def log(self, msg: str) -> None:
38
49
  self.logs.append(msg)
39
50
 
51
+ def init_progress(self, total: int, description: str) -> None:
52
+ pass
53
+
54
+ def step(self) -> None:
55
+ pass
56
+
40
57
  def __enter__(self) -> Self:
41
58
  self.start_time = time.perf_counter()
42
59
  return self
@@ -66,22 +83,21 @@ class ColoredTimeRemainingColumn(TimeRemainingColumn):
66
83
  return Text(str(super().render(task)), style=self.style)
67
84
 
68
85
 
69
- class InteractiveProgressLogger(ProgressLogger):
86
+ class RichProgressLogger(ProgressLogger):
70
87
  def __init__(
71
88
  self,
72
89
  msg: str,
73
90
  verbose: bool = True,
74
91
  refresh_per_second: int = 10,
75
92
  ) -> None:
76
- super().__init__(msg=msg)
93
+ super().__init__(msg=msg, verbose=verbose)
77
94
 
78
- self.verbose = verbose
79
95
  self.refresh_per_second = refresh_per_second
80
96
 
81
- self._progress: Optional[Progress] = None
82
- self._task: Optional[int] = None
97
+ self._progress: Progress | None = None
98
+ self._task: int | None = None
83
99
 
84
- self._live: Optional[Live] = None
100
+ self._live: Live | None = None
85
101
  self._exception: bool = False
86
102
 
87
103
  def init_progress(self, total: int, description: str) -> None:
@@ -151,7 +167,7 @@ class InteractiveProgressLogger(ProgressLogger):
151
167
 
152
168
  table = Table.grid(padding=(0, 1))
153
169
 
154
- icon: Union[Text, Padding]
170
+ icon: Text | Padding
155
171
  if self._exception:
156
172
  style = 'red'
157
173
  icon = Text('❌', style=style)
@@ -175,3 +191,153 @@ class InteractiveProgressLogger(ProgressLogger):
175
191
 
176
192
  if self.verbose and self._progress is not None:
177
193
  yield self._progress.get_renderable()
194
+
195
+
196
class StreamlitProgressLogger(ProgressLogger):
    r"""A :class:`ProgressLogger` that renders its messages and progress bar
    inside a Streamlit ``st.status`` container. This is the logger returned
    by :meth:`ProgressLogger.default` inside Snowflake notebooks.

    Args:
        msg: The title of the status container.
        verbose: Whether to render anything at all.
    """
    def __init__(
        self,
        msg: str,
        verbose: bool = True,
    ) -> None:
        super().__init__(msg=msg, verbose=verbose)

        # Streamlit status container (set in ``__enter__`` when verbose):
        self._status: Any = None

        self._total = 0
        self._current = 0
        self._description: str = ''
        # Streamlit progress bar (set in ``init_progress`` when verbose):
        self._progress: Any = None

    def __enter__(self) -> Self:
        super().__enter__()

        import streamlit as st

        # Adjust layout for prettier output:
        st.markdown(STREAMLIT_CSS, unsafe_allow_html=True)

        if self.verbose:
            self._status = st.status(
                f':blue[{self._sanitize_text(self.msg)}]',
                expanded=True,
            )

        return self

    def log(self, msg: str) -> None:
        super().log(msg)
        if self.verbose and self._status is not None:
            self._status.write(self._sanitize_text(msg))

    def init_progress(self, total: int, description: str) -> None:
        if self.verbose and self._status is not None:
            self._total = total
            self._current = 0
            self._description = self._sanitize_text(description)
            self._progress = self._status.progress(
                value=self._percent(),
                text=self._progress_text(),
            )

    def step(self) -> None:
        self._current += 1

        if self.verbose and self._progress is not None:
            self._progress.progress(
                value=self._percent(),
                text=self._progress_text(),
            )

    def _percent(self) -> float:
        r"""Returns the completed fraction, clamped to at most ``1.0``."""
        # ``total <= 0`` previously raised ``ZeroDivisionError``; treat an
        # empty workload as already complete instead:
        if self._total <= 0:
            return 1.0
        return min(self._current / self._total, 1.0)

    def _progress_text(self) -> str:
        r"""Returns the progress bar label, e.g. ``'desc [3/10]'``."""
        return f'{self._description} [{self._current}/{self._total}]'

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        super().__exit__(exc_type, exc_val, exc_tb)

        if not self.verbose or self._status is None:
            return

        label = f'{self._sanitize_text(self.msg)} ({self.duration:.2f}s)'

        if exc_type is not None:
            self._status.update(
                label=f':red[{label}]',
                state='error',
                expanded=True,
            )
        else:
            self._status.update(
                label=f':green[{label}]',
                state='complete',
                expanded=True,
            )

    @staticmethod
    def _sanitize_text(msg: str) -> str:
        r"""Converts rich-style ``[bold]``/``[/bold]`` tags into Markdown
        ``**`` markers."""
        return re.sub(r'\[/?bold\]', '**', msg)
277
+
278
+
279
# Raw CSS that ``StreamlitProgressLogger.__enter__`` injects through
# ``st.markdown(..., unsafe_allow_html=True)``. It restyles the Streamlit
# expander (``.stExpander``) that hosts the status container, covering
# padding, icons, borders, font sizes and progress-bar text. The selectors
# depend on Streamlit's internal DOM (``data-testid`` attributes), so they
# may need updating when Streamlit changes its markup.
STREAMLIT_CSS = """
<style>
/* Fix horizontal scrollbar */
.stExpander summary {
width: auto;
}

/* Fix paddings/margins */
.stExpander summary {
padding: 0.75rem 1rem 0.5rem;
}
.stExpander p {
margin: 0px 0px 0.2rem;
}
.stExpander [data-testid="stExpanderDetails"] {
padding-bottom: 1.45rem;
}
.stExpander .stProgress div:first-child {
padding-bottom: 4px;
}

/* Fix expand icon position */
.stExpander summary svg {
height: 1.5rem;
}

/* Fix summary icons */
.stExpander summary [data-testid="stExpanderIconCheck"] {
font-size: 1.8rem;
margin-top: -3px;
color: rgb(21, 130, 55);
}
.stExpander summary [data-testid="stExpanderIconError"] {
font-size: 1.8rem;
margin-top: -3px;
color: rgb(255, 43, 43);
}
.stExpander summary span:first-child span:first-child {
width: 1.6rem;
}

/* Add border between title and content */
.stExpander [data-testid="stExpanderDetails"] {
border-top: 1px solid rgba(30, 37, 47, 0.2);
padding-top: 0.5rem;
}

/* Fix title font size */
.stExpander summary p {
font-size: 1rem;
}

/* Gray out content */
.stExpander [data-testid="stExpanderDetails"] {
color: rgba(30, 37, 47, 0.5);
}

/* Fix progress bar font size */
.stExpander .stProgress p {
line-height: 1.6;
font-size: 1rem;
color: rgba(30, 37, 47, 0.5);
}
</style>
"""
kumoai/utils/sql.py ADDED
@@ -0,0 +1,3 @@
1
def quote_ident(name: str) -> str:
    r"""Returns ``name`` as a double-quoted SQL identifier.

    Any double quote inside ``name`` is escaped by doubling it, so the
    result is always a single delimited identifier.
    """
    escaped = name.replace('"', '""')
    return f'"{escaped}"'