kumoai 2.14.0.dev202601011731__cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kumoai might be problematic.
- kumoai/__init__.py +300 -0
- kumoai/_logging.py +29 -0
- kumoai/_singleton.py +25 -0
- kumoai/_version.py +1 -0
- kumoai/artifact_export/__init__.py +9 -0
- kumoai/artifact_export/config.py +209 -0
- kumoai/artifact_export/job.py +108 -0
- kumoai/client/__init__.py +5 -0
- kumoai/client/client.py +223 -0
- kumoai/client/connector.py +110 -0
- kumoai/client/endpoints.py +150 -0
- kumoai/client/graph.py +120 -0
- kumoai/client/jobs.py +471 -0
- kumoai/client/online.py +78 -0
- kumoai/client/pquery.py +207 -0
- kumoai/client/rfm.py +112 -0
- kumoai/client/source_table.py +53 -0
- kumoai/client/table.py +101 -0
- kumoai/client/utils.py +130 -0
- kumoai/codegen/__init__.py +19 -0
- kumoai/codegen/cli.py +100 -0
- kumoai/codegen/context.py +16 -0
- kumoai/codegen/edits.py +473 -0
- kumoai/codegen/exceptions.py +10 -0
- kumoai/codegen/generate.py +222 -0
- kumoai/codegen/handlers/__init__.py +4 -0
- kumoai/codegen/handlers/connector.py +118 -0
- kumoai/codegen/handlers/graph.py +71 -0
- kumoai/codegen/handlers/pquery.py +62 -0
- kumoai/codegen/handlers/table.py +109 -0
- kumoai/codegen/handlers/utils.py +42 -0
- kumoai/codegen/identity.py +114 -0
- kumoai/codegen/loader.py +93 -0
- kumoai/codegen/naming.py +94 -0
- kumoai/codegen/registry.py +121 -0
- kumoai/connector/__init__.py +31 -0
- kumoai/connector/base.py +153 -0
- kumoai/connector/bigquery_connector.py +200 -0
- kumoai/connector/databricks_connector.py +213 -0
- kumoai/connector/file_upload_connector.py +189 -0
- kumoai/connector/glue_connector.py +150 -0
- kumoai/connector/s3_connector.py +278 -0
- kumoai/connector/snowflake_connector.py +252 -0
- kumoai/connector/source_table.py +471 -0
- kumoai/connector/utils.py +1796 -0
- kumoai/databricks.py +14 -0
- kumoai/encoder/__init__.py +4 -0
- kumoai/exceptions.py +26 -0
- kumoai/experimental/__init__.py +0 -0
- kumoai/experimental/rfm/__init__.py +210 -0
- kumoai/experimental/rfm/authenticate.py +432 -0
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/backend/local/graph_store.py +297 -0
- kumoai/experimental/rfm/backend/local/sampler.py +312 -0
- kumoai/experimental/rfm/backend/local/table.py +113 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
- kumoai/experimental/rfm/backend/snow/table.py +242 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
- kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
- kumoai/experimental/rfm/base/__init__.py +30 -0
- kumoai/experimental/rfm/base/column.py +152 -0
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +761 -0
- kumoai/experimental/rfm/base/source.py +19 -0
- kumoai/experimental/rfm/base/sql_sampler.py +143 -0
- kumoai/experimental/rfm/base/table.py +736 -0
- kumoai/experimental/rfm/graph.py +1237 -0
- kumoai/experimental/rfm/infer/__init__.py +19 -0
- kumoai/experimental/rfm/infer/categorical.py +40 -0
- kumoai/experimental/rfm/infer/dtype.py +82 -0
- kumoai/experimental/rfm/infer/id.py +46 -0
- kumoai/experimental/rfm/infer/multicategorical.py +48 -0
- kumoai/experimental/rfm/infer/pkey.py +128 -0
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +61 -0
- kumoai/experimental/rfm/infer/timestamp.py +41 -0
- kumoai/experimental/rfm/pquery/__init__.py +7 -0
- kumoai/experimental/rfm/pquery/executor.py +102 -0
- kumoai/experimental/rfm/pquery/pandas_executor.py +530 -0
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +1184 -0
- kumoai/experimental/rfm/sagemaker.py +138 -0
- kumoai/experimental/rfm/task_table.py +231 -0
- kumoai/formatting.py +30 -0
- kumoai/futures.py +99 -0
- kumoai/graph/__init__.py +12 -0
- kumoai/graph/column.py +106 -0
- kumoai/graph/graph.py +948 -0
- kumoai/graph/table.py +838 -0
- kumoai/jobs.py +80 -0
- kumoai/kumolib.cpython-310-x86_64-linux-gnu.so +0 -0
- kumoai/mixin.py +28 -0
- kumoai/pquery/__init__.py +25 -0
- kumoai/pquery/prediction_table.py +287 -0
- kumoai/pquery/predictive_query.py +641 -0
- kumoai/pquery/training_table.py +424 -0
- kumoai/spcs.py +121 -0
- kumoai/testing/__init__.py +8 -0
- kumoai/testing/decorators.py +57 -0
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/__init__.py +42 -0
- kumoai/trainer/baseline_trainer.py +93 -0
- kumoai/trainer/config.py +2 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/trainer/job.py +1192 -0
- kumoai/trainer/online_serving.py +258 -0
- kumoai/trainer/trainer.py +475 -0
- kumoai/trainer/util.py +103 -0
- kumoai/utils/__init__.py +11 -0
- kumoai/utils/datasets.py +83 -0
- kumoai/utils/display.py +51 -0
- kumoai/utils/forecasting.py +209 -0
- kumoai/utils/progress_logger.py +343 -0
- kumoai/utils/sql.py +3 -0
- kumoai-2.14.0.dev202601011731.dist-info/METADATA +71 -0
- kumoai-2.14.0.dev202601011731.dist-info/RECORD +122 -0
- kumoai-2.14.0.dev202601011731.dist-info/WHEEL +6 -0
- kumoai-2.14.0.dev202601011731.dist-info/licenses/LICENSE +9 -0
- kumoai-2.14.0.dev202601011731.dist-info/top_level.txt +1 -0
kumoai/__init__.py
ADDED
@@ -0,0 +1,300 @@
import warnings
import os
import sys
import threading
from dataclasses import dataclass
from typing import Any, Dict, Optional
import logging

from kumoapi.typing import Dtype, Stype

from kumoai.client.client import KumoClient
from kumoai._logging import initialize_logging, _ENV_KUMO_LOG
from kumoai._singleton import Singleton
from kumoai.futures import create_future, initialize_event_loop
from kumoai.spcs import (
    _get_active_session,
    _get_spcs_token,
    _run_refresh_spcs_token,
)

initialize_logging()
initialize_event_loop()


@dataclass
class GlobalState(metaclass=Singleton):
    r"""Global storage of the state needed to create a Kumo client object. A
    singleton so its initialized state can be referenced elsewhere for free.
    """

    # NOTE fork semantics: CoW on Linux, and re-execed on Windows. So this will
    # likely not work on Windows unless we have special handling for the shared
    # state:
    _url: Optional[str] = None
    _api_key: Optional[str] = None
    _snowflake_credentials: Optional[Dict[str, Any]] = None
    _spcs_token: Optional[str] = None
    _snowpark_session: Optional[Any] = None

    thread_local: threading.local = threading.local()

    def clear(self) -> None:
        if hasattr(self.thread_local, '_client'):
            del self.thread_local._client
        self._url = None
        self._api_key = None
        self._snowflake_credentials = None
        self._spcs_token = None

    def set_spcs_token(self, spcs_token: str) -> None:
        # Set the spcs token in the global state. This will be picked up the
        # next time client() is accessed.
        self._spcs_token = spcs_token

    @property
    def initialized(self) -> bool:
        return self._url is not None and (
            self._api_key is not None or self._snowflake_credentials
            is not None or self._snowpark_session is not None)

    @property
    def client(self) -> KumoClient:
        r"""Accessor for the Kumo client. Note that clients are stored as
        thread-local variables as the requests Session library is not
        guaranteed to be thread-safe.

        For more information, see https://github.com/psf/requests/issues/1871.
        """
        if self._url is None or (self._api_key is None
                                 and self._spcs_token is None
                                 and self._snowpark_session is None):
            raise ValueError("Client creation or authentication failed. "
                             "Please re-create your client before proceeding.")

        if hasattr(self.thread_local, '_client'):
            # Set the spcs token in the client to ensure it has the latest.
            self.thread_local._client.set_spcs_token(self._spcs_token)
            return self.thread_local._client

        client = KumoClient(self._url, self._api_key, self._spcs_token)
        self.thread_local._client = client
        return client

    @property
    def is_spcs(self) -> bool:
        return (self._api_key is None
                and (self._snowflake_credentials is not None
                     or self._snowpark_session is not None))


global_state: GlobalState = GlobalState()


def init(
    url: Optional[str] = None,
    api_key: Optional[str] = None,
    snowflake_credentials: Optional[Dict[str, str]] = None,
    snowflake_application: Optional[str] = None,
    log_level: str = "INFO",
) -> None:
    r"""Initializes and authenticates the API key against the Kumo service.
    Successful authentication is required to use the SDK.

    Example:
    >>> import kumoai
    >>> kumoai.init(url="<api_url>", api_key="<api_key>")  # doctest: +SKIP

    Args:
        url: The Kumo API endpoint. Can also be provided via the
            ``KUMO_API_ENDPOINT`` environment variable. Will be inferred from
            the provided API key, if not provided.
        api_key: The Kumo API key. Can also be provided via the
            ``KUMO_API_KEY`` environment variable.
        snowflake_credentials: The Snowflake credentials to authenticate
            against the Kumo service. The dictionary should contain the keys
            ``"user"``, ``"password"``, and ``"account"``. This should only be
            provided for SPCS.
        snowflake_application: The Snowflake application.
        log_level: The logging level that Kumo operates under. Defaults to
            INFO; for more information, please see
            :class:`~kumoai.set_log_level`. Can also be set with the
            environment variable ``KUMOAI_LOG``.
    """  # noqa
    # Avoid mutations to the global state after it is set:
    if global_state.initialized:
        warnings.warn("Kumo SDK already initialized. To re-initialize the "
                      "SDK, please start a new interpreter. No changes will "
                      "be made to the current session.")
        return

    set_log_level(os.getenv(_ENV_KUMO_LOG, log_level))

    # Get API key:
    api_key = api_key or os.getenv("KUMO_API_KEY")

    snowpark_session = None
    if snowflake_application:
        if url is not None:
            raise ValueError(
                "Kumo SDK initialization failed. Both 'snowflake_application' "
                "and 'url' are specified. If running from a Snowflake "
                "notebook, specify only 'snowflake_application'.")
        snowpark_session = _get_active_session()
        if not snowpark_session:
            raise ValueError(
                "Kumo SDK initialization failed. 'snowflake_application' is "
                "specified without an active Snowpark session. If running "
                "outside a Snowflake notebook, specify a URL and credentials.")
        description = snowpark_session.sql(
            f"DESCRIBE SERVICE {snowflake_application}."
            "USER_SCHEMA.KUMO_SERVICE").collect()[0]
        url = f"http://{description.dns_name}:8888/public_api"

    if api_key is None and not snowflake_application:
        if snowflake_credentials is None:
            raise ValueError(
                "Kumo SDK initialization failed. Neither an API key nor "
                "Snowflake credentials provided. Please either set the "
                "'KUMO_API_KEY' or explicitly call `kumoai.init(...)`.")
        if (set(snowflake_credentials.keys())
                != {'user', 'password', 'account'}):
            raise ValueError(
                f"Provided Snowflake credentials should be a dictionary with "
                f"keys 'user', 'password', and 'account'. Only "
                f"{set(snowflake_credentials.keys())} were provided.")

    # Get or infer URL:
    url = url or os.getenv("KUMO_API_ENDPOINT")
    try:
        if api_key:
            url = url or f"http://{api_key.split(':')[0]}.kumoai.cloud/api"
    except KeyError:
        pass
    if url is None:
        raise ValueError("Kumo SDK initialization failed since no endpoint "
                         "URL was provided. Please either set the "
                         "'KUMO_API_ENDPOINT' environment variable or "
                         "explicitly call `kumoai.init(...)`.")

    # Assign global state after verification that client can be created and
    # authenticated successfully:
    spcs_token = _get_spcs_token(
        snowflake_credentials
    ) if not api_key and snowflake_credentials else None
    client = KumoClient(url=url, api_key=api_key, spcs_token=spcs_token)
    client.authenticate()
    global_state._url = client._url
    global_state._api_key = client._api_key
    global_state._snowflake_credentials = snowflake_credentials
    global_state._spcs_token = client._spcs_token
    global_state._snowpark_session = snowpark_session

    if not api_key and snowflake_credentials:
        # Refresh token every 10 minutes (expires in 1 hour):
        create_future(_run_refresh_spcs_token(minutes=10))

    logger = logging.getLogger('kumoai')
    log_level = logging.getLevelName(logger.getEffectiveLevel())

    logger.info(f"Initialized Kumo SDK v{__version__} against deployment "
                f"'{url}'")


def set_log_level(level: str) -> None:
    r"""Sets the Kumo logging level, which defines the amount of output that
    methods produce.

    Example:
    >>> import kumoai
    >>> kumoai.set_log_level("INFO")  # doctest: +SKIP

    Args:
        level: the logging level. Can be one of (in order of lowest to highest
            log output) :obj:`DEBUG`, :obj:`INFO`, :obj:`WARNING`,
            :obj:`ERROR`, :obj:`FATAL`, :obj:`CRITICAL`.
    """
    # logging library will ensure `level` is a valid string, and raise a
    # warning if not:
    logging.getLogger('kumoai').setLevel(level)


# Try to initialize purely with environment variables:
if ("pytest" not in sys.modules and "KUMO_API_KEY" in os.environ
        and "KUMO_API_ENDPOINT" in os.environ):
    init()

import kumoai.connector  # noqa
import kumoai.encoder  # noqa
import kumoai.graph  # noqa
import kumoai.pquery  # noqa
import kumoai.trainer  # noqa
import kumoai.utils  # noqa
import kumoai.databricks  # noqa

from kumoai.connector import (  # noqa
    SourceTable, SourceTableFuture, S3Connector, SnowflakeConnector,
    DatabricksConnector, BigQueryConnector, FileUploadConnector, GlueConnector)
from kumoai.graph import Column, Edge, Graph, Table  # noqa
from kumoai.pquery import (  # noqa
    PredictionTableGenerationPlan, PredictiveQuery,
    TrainingTableGenerationPlan, TrainingTable, TrainingTableJob,
    PredictionTable, PredictionTableJob)
from kumoai.trainer import (  # noqa
    ModelPlan, Trainer, TrainingJobResult, TrainingJob,
    BatchPredictionJobResult, BatchPredictionJob)
from kumoai._version import __version__  # noqa

__all__ = [
    'Dtype',
    'Stype',
    'SourceTable',
    'SourceTableFuture',
    'S3Connector',
    'SnowflakeConnector',
    'DatabricksConnector',
    'BigQueryConnector',
    'FileUploadConnector',
    'GlueConnector',
    'Column',
    'Table',
    'Graph',
    'Edge',
    'PredictiveQuery',
    'TrainingTable',
    'TrainingTableJob',
    'TrainingTableGenerationPlan',
    'PredictionTable',
    'PredictionTableJob',
    'PredictionTableGenerationPlan',
    'Trainer',
    'TrainingJobResult',
    'TrainingJob',
    'BatchPredictionJobResult',
    'BatchPredictionJob',
    'ModelPlan',
    '__version__',
]


def in_snowflake_notebook() -> bool:
    try:
        from snowflake.snowpark.context import get_active_session
        import streamlit  # noqa: F401
        get_active_session()
        return True
    except Exception:
        return False


def in_notebook() -> bool:
    if in_snowflake_notebook():
        return True
    try:
        from IPython import get_ipython
        shell = get_ipython()
        if 'google.colab' in str(shell.__class__):
            return True
        return shell.__class__.__name__ == 'ZMQInteractiveShell'
    except Exception:
        return False
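Taken together, `init()` and the module-level auto-initialization mean the SDK can be configured either through environment variables or an explicit call. A minimal usage sketch (not part of the package source; the endpoint and key below are hypothetical placeholders and must be valid for authentication to succeed):

import os

# Option 1: set both environment variables before the first import; the module
# then calls init() by itself (skipped when running under pytest).
os.environ["KUMO_API_ENDPOINT"] = "https://example.kumoai.cloud/api"  # hypothetical
os.environ["KUMO_API_KEY"] = "<api_key>"                              # hypothetical
import kumoai  # logs "Initialized Kumo SDK v... against deployment '...'"

# Option 2: explicit initialization. Since the global state was already set by
# the auto-init above, this call only emits an "already initialized" warning.
kumoai.init(url="https://example.kumoai.cloud/api", api_key="<api_key>",
            log_level="DEBUG")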
kumoai/_logging.py
ADDED
@@ -0,0 +1,29 @@
import logging
import os

_ENV_KUMO_LOG = "KUMO_LOG"


def initialize_logging() -> None:
    r"""Initializes Kumo logging."""
    logger: logging.Logger = logging.getLogger('kumoai')

    # From openai-python/blob/main/src/openai/_utils/_logs.py#L4
    logging.basicConfig(
        format=(
            "[%(asctime)s - %(name)s:%(lineno)d - %(levelname)s] %(message)s"),
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    default_level = os.getenv(_ENV_KUMO_LOG, "INFO")
    try:
        logger.setLevel(default_level)
    except (TypeError, ValueError):
        logger.setLevel(logging.INFO)
        logger.warning(
            "Logging level %s could not be properly parsed. "
            "Defaulting to INFO log level.", default_level)

    for name in ["matplotlib", "urllib3", "snowflake"]:
        # TODO(dm) required for spcs
        logging.getLogger(name).setLevel(logging.ERROR)
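Because `initialize_logging()` runs when `kumoai` is first imported, the `KUMO_LOG` environment variable must be set before that import to take effect; afterwards the level can still be changed through `kumoai.set_log_level`. A small sketch of both paths (no Kumo credentials are needed just to configure logging):

import logging
import os

os.environ["KUMO_LOG"] = "DEBUG"  # read once by initialize_logging()

import kumoai  # the 'kumoai' logger now starts at DEBUG

kumoai.set_log_level("WARNING")   # adjust later via the public helper
assert logging.getLogger("kumoai").level == logging.WARNING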
kumoai/_singleton.py
ADDED
@@ -0,0 +1,25 @@
from abc import ABCMeta
from typing import Any, Dict


class Singleton(ABCMeta):
    r"""A per-process singleton definition."""
    _instances: Dict[type, Any] = {}

    def __call__(cls, *args: Any, **kwargs: Any) -> Any:
        if cls not in cls._instances:
            # Calls the `__init__` method of the subclass and returns a
            # reference, which is stored to prevent multiple instantiations.
            instance = super(Singleton, cls).__call__(*args, **kwargs)
            cls._instances[cls] = instance
            return instance
        return cls._instances[cls]

    def clear(cls) -> None:
        r"""Clears the singleton class instance, so the next construction
        will re-initialize the class.
        """
        try:
            del Singleton._instances[cls]
        except KeyError:
            pass
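Since `Singleton` is a metaclass, every class that uses it (such as `GlobalState` above) yields one instance per process; repeated constructions return the cached object until `clear()` is called on the class. A minimal illustration with a hypothetical `Config` class that is not part of the package:

from kumoai._singleton import Singleton


class Config(metaclass=Singleton):  # hypothetical example class
    def __init__(self, value: int = 0) -> None:
        self.value = value


a = Config(value=1)
b = Config(value=2)        # arguments ignored; the cached instance is returned
assert a is b and b.value == 1

Config.clear()             # drop the cached instance for this class
c = Config(value=3)        # the next construction re-runs __init__
assert c is not a and c.value == 3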
kumoai/_version.py
ADDED
@@ -0,0 +1 @@
__version__ = '2.14.0.dev202601011731'
kumoai/artifact_export/config.py
ADDED
@@ -0,0 +1,209 @@
import functools
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple, Union

from kumoapi.jobs import MetadataField, WriteMode
from kumoapi.typing import WITH_PYDANTIC_V2
from pydantic.dataclasses import dataclass as pydantic_dataclass

if WITH_PYDANTIC_V2:
    from pydantic import field_validator, model_validator  # type: ignore
else:
    from pydantic import root_validator, validator

from kumoai.connector.base import Connector


def compatible_field_validator(field_name: str):  # type: ignore
    """Decorator factory that creates a field validator compatible with both
    Pydantic v1 and v2.

    Usage:
        @compatible_field_validator('field_name')
        def validate_field(cls, v, values_or_info):
            # Your validation logic here
            return v
    """
    def decorator(func):  # type: ignore
        if WITH_PYDANTIC_V2:

            @field_validator(field_name)
            @classmethod
            @functools.wraps(func)
            def wrapper(cls, v, info):  # type: ignore
                # Convert info to values dict for compatibility
                values = info.data if hasattr(info, 'data') else {}
                return func(cls, v, values)

            return wrapper
        else:

            @validator(field_name)
            @functools.wraps(func)
            def wrapper(cls, v, values):  # type: ignore
                return func(cls, v, values)

            return wrapper

    return decorator


# TODO: probably will need to be removed b/c using __post_init__ instead
def compatible_model_validator(mode='before'):  # type: ignore
    """Decorator factory that creates a model validator compatible with both
    Pydantic v1 and v2.

    Usage:
        @compatible_model_validator()
        def validate_model(cls, values):
            # Your validation logic here
            return values
    """
    def decorator(func):  # type: ignore
        if WITH_PYDANTIC_V2:

            @model_validator(mode=mode)
            @classmethod
            @functools.wraps(func)
            def wrapper(cls, values):  # type: ignore
                return func(cls, values)

            return wrapper
        else:

            @root_validator
            @functools.wraps(func)
            def wrapper(cls, values):  # type: ignore
                return func(cls, values)

            return wrapper

    return decorator


@dataclass(frozen=True)
class QueryConnectorConfig:
    # If using OVERWRITE, big query connector will first write to a staging
    # table followed by overwriting to the destination table.
    # When using APPEND, it is strongly recommended to use
    # MetadataField.JOB_TIMESTAMP to indicate the timestamp of the job.
    write_mode: WriteMode = WriteMode.APPEND


@dataclass(frozen=True)
class BigQueryOutputConfig(QueryConnectorConfig):
    pass


@dataclass(frozen=True)
class SnowflakeConnectorConfig(QueryConnectorConfig):
    pass


CONNECTOR_CONFIG_MAPPING = {
    'BigQueryConnector': BigQueryOutputConfig,
    'SnowflakeConnector': SnowflakeConnectorConfig,
    # 'DatabricksConnector': DatabricksOutputConfig,
    # 'S3Connector': S3OutputConfig,
}


@pydantic_dataclass(frozen=True, config={'arbitrary_types_allowed': True})
class OutputConfig:
    """Output configuration associated with a Batch Prediction Job.
    Specifies the output types and optionally output data source
    configuration.

    Args:
        output_types(`Set[str]`): The types of outputs that should be produced
            by the prediction job. Can include either ``'predictions'``,
            ``'embeddings'``, or both.
        output_connector(`Connector` or None): The output data source that Kumo
            should write batch predictions to. If it is None,
            produce local download output only.
        output_table_name(`str` or `Tuple[str, str]` or None): The name of the
            table in the output data source
            that Kumo should write batch predictions to. In the case of
            a Databricks connector, this should be a tuple of two strings:
            the schema name and the output prediction table name.
        output_metadata_fields(`List[MetadataField]` or None): Any additional
            metadata fields to include as new columns in the produced
            ``'predictions'`` output. Currently, allowed options are
            ``JOB_TIMESTAMP`` and ``ANCHOR_TIMESTAMP``.
        connector_specific_config(`QueryConnectorConfig` or None): The custom
            connector specific output config for predictions, for
            example whether to append or overwrite existing table.
    """
    output_types: Set[str]
    output_connector: Optional[Connector] = None
    output_table_name: Optional[Union[str, Tuple]] = None
    output_metadata_fields: Optional[List[MetadataField]] = None
    connector_specific_config: Optional[Union[
        BigQueryOutputConfig,
        SnowflakeConnectorConfig,
    ]] = None

    @compatible_field_validator('connector_specific_config')
    def validate_connector_config(cls, v: Any, values: Dict) -> Any:
        """Validate the connector specific output config. Raises ValueError if
        there is a mismatch between the connector type and the config type.
        """
        # Skip validation if no connector or no specific config
        if values.get('output_connector') is None or v is None:
            return v

        connector_type = type(values['output_connector']).__name__
        expected_config_type = CONNECTOR_CONFIG_MAPPING.get(connector_type)

        # If we don't have a mapping for this connector type, it doesn't
        # support specific configs yet
        if expected_config_type is None:
            raise ValueError(
                f"Connector type '{connector_type}' does not support "
                f"specific output configurations")

        # Check if the provided config is of the correct type
        if not isinstance(v, expected_config_type):
            raise ValueError(
                f"Connector type '{connector_type}' requires output "
                f"config of type '{expected_config_type.__name__}', but "
                f"got '{type(v).__name__}'")

        return v


@pydantic_dataclass(frozen=True, config={'arbitrary_types_allowed': True})
class TrainingTableExportConfig(OutputConfig):
    """Export configuration associated with a Training Table.

    Args:
        output_types(`Set[str]`): The artifact to export from the training
            table job. Currently only `'training_table'` is supported,
            which exports the full training table to the output connector.
        output_connector(`Connector`): The output data source that Kumo should
            write training table artifacts to.
        output_table_name(str): The name of the table in the output data source
            that Kumo should write batch predictions to. In the case of
            a Databricks connector, this should be a tuple of two strings:
            the schema name and the output prediction table name.
        connector_specific_config(QueryConnectorConfig or None):
            Defines custom connector specific output,
            for example whether to append or overwrite
            existing table. This is currently only supported for BigQuery and
            Snowflake.
    """
    output_connector: Connector
    output_table_name: str

    def __post_init__(self) -> None:
        if self.output_types != {'training_table'}:
            raise ValueError("output_type must be set(['training_table'])"
                             f" (got {self.output_types})")
        if self.output_connector is None:
            raise ValueError("output_connector is required")
        if self.output_table_name is None:
            raise ValueError("output_table_name is required")
        if self.output_metadata_fields is not None:
            raise ValueError(
                "output_metadata_fields is not supported for training"
                "table export")
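The field validator above ties `connector_specific_config` to the concrete connector class, so a Snowflake write mode can only accompany a `SnowflakeConnector`. A hedged sketch of how the config might be used (the warehouse variant is left commented out because the connector constructor signature is not shown in this diff):

from kumoai.artifact_export.config import OutputConfig

# Download-only output: no connector, so no connector-specific config applies.
local_only = OutputConfig(output_types={'predictions'})

# Writing to a warehouse instead: `snowflake_connector` is assumed to be a
# previously created kumoai.SnowflakeConnector, and WriteMode comes from
# kumoapi.jobs.
#
# warehouse = OutputConfig(
#     output_types={'predictions', 'embeddings'},
#     output_connector=snowflake_connector,
#     output_table_name='KUMO_PREDICTIONS',
#     connector_specific_config=SnowflakeConnectorConfig(
#         write_mode=WriteMode.OVERWRITE),
# )
#
# Pairing the same connector with a BigQueryOutputConfig would raise a
# ValueError from validate_connector_config.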
kumoai/artifact_export/job.py
ADDED
@@ -0,0 +1,108 @@
import asyncio
import concurrent
import concurrent.futures
import time

from kumoapi.common import JobStatus
from typing_extensions import override

from kumoai import global_state
from kumoai.futures import KumoProgressFuture, create_future


class ArtifactExportResult:
    r"""Represents a completed artifact export job."""
    def __init__(self, job_id: str) -> None:
        self.job_id = job_id

    def tracking_url(self) -> str:
        r"""Returns a tracking URL pointing to the UI display of
        this prediction export job.
        """
        raise NotImplementedError

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(job_id={self.job_id})"


class ArtifactExportJob(KumoProgressFuture[ArtifactExportResult]):
    """Represents an in-progress artifact export job."""
    def __init__(self, job_id: str) -> None:
        self.job_id = job_id
        self._fut: concurrent.futures.Future = create_future(
            _poll_export(job_id))

    @property
    def id(self) -> str:
        """The unique ID of this export job."""
        return self.job_id

    @override
    def result(self) -> ArtifactExportResult:
        return self._fut.result()

    @override
    def future(self) -> 'concurrent.futures.Future[ArtifactExportResult]':
        return self._fut

    @override
    def _attach_internal(
        self,
        interval_s: float = 20.0,
    ) -> ArtifactExportResult:
        """Allows a user to attach to a running export job and view
        its progress.

        Args:
            interval_s (float): Time interval (seconds) between polls, minimum
                value allowed is 4 seconds.
        """
        assert interval_s >= 4.0
        print(f"Attaching to export job {self.job_id}. To detach from "
              f"this job, please enter Ctrl+C (the job will continue to run, "
              f"and you can re-attach anytime).")

        # TODO improve print statements.
        # Will require changes to status to return
        # JobStatusReport instead of JobStatus.
        while not self.done():
            status = self.status()
            print(f"Export job {self.job_id} status: {status}")
            time.sleep(interval_s)

        return self.result()

    def status(self) -> JobStatus:
        """Returns the status of a running export job."""
        return get_export_status(self.job_id)

    def cancel(self) -> bool:
        """Cancels a running export job.

        Returns:
            bool: True if the job is in a terminal state.
        """
        api = global_state.client.artifact_export_api
        status = api.cancel(self.job_id)
        if status == JobStatus.CANCELLED:
            return True
        return False


def get_export_status(job_id: str) -> JobStatus:
    api = global_state.client.artifact_export_api
    resource = api.get(job_id)
    return resource


async def _poll_export(job_id: str) -> ArtifactExportResult:
    status = get_export_status(job_id)
    while not status.is_terminal:
        await asyncio.sleep(10)
        status = get_export_status(job_id)

    if status != JobStatus.DONE:
        raise RuntimeError(f"Export job {job_id} failed "
                           f"with job status {status}.")

    return ArtifactExportResult(job_id=job_id)
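Given an initialized SDK, an existing export job can be wrapped in `ArtifactExportJob` to poll or cancel it. A minimal sketch, assuming `kumoai.init(...)` has already succeeded and that the job ID below is a hypothetical placeholder:

from kumoai.artifact_export.job import ArtifactExportJob, get_export_status

job = ArtifactExportJob(job_id="artifact-export-123")  # hypothetical job ID

print(get_export_status(job.id))  # one-off status lookup
print(job.status())               # same lookup via the handle

result = job.result()             # blocks until _poll_export sees a terminal state
print(result)                     # ArtifactExportResult(job_id=...)

# job.cancel() requests cancellation server-side and returns True only when the
# reported status is CANCELLED.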