kumoai 2.14.0.dev202601011731__cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kumoai might be problematic.
- kumoai/__init__.py +300 -0
- kumoai/_logging.py +29 -0
- kumoai/_singleton.py +25 -0
- kumoai/_version.py +1 -0
- kumoai/artifact_export/__init__.py +9 -0
- kumoai/artifact_export/config.py +209 -0
- kumoai/artifact_export/job.py +108 -0
- kumoai/client/__init__.py +5 -0
- kumoai/client/client.py +223 -0
- kumoai/client/connector.py +110 -0
- kumoai/client/endpoints.py +150 -0
- kumoai/client/graph.py +120 -0
- kumoai/client/jobs.py +471 -0
- kumoai/client/online.py +78 -0
- kumoai/client/pquery.py +207 -0
- kumoai/client/rfm.py +112 -0
- kumoai/client/source_table.py +53 -0
- kumoai/client/table.py +101 -0
- kumoai/client/utils.py +130 -0
- kumoai/codegen/__init__.py +19 -0
- kumoai/codegen/cli.py +100 -0
- kumoai/codegen/context.py +16 -0
- kumoai/codegen/edits.py +473 -0
- kumoai/codegen/exceptions.py +10 -0
- kumoai/codegen/generate.py +222 -0
- kumoai/codegen/handlers/__init__.py +4 -0
- kumoai/codegen/handlers/connector.py +118 -0
- kumoai/codegen/handlers/graph.py +71 -0
- kumoai/codegen/handlers/pquery.py +62 -0
- kumoai/codegen/handlers/table.py +109 -0
- kumoai/codegen/handlers/utils.py +42 -0
- kumoai/codegen/identity.py +114 -0
- kumoai/codegen/loader.py +93 -0
- kumoai/codegen/naming.py +94 -0
- kumoai/codegen/registry.py +121 -0
- kumoai/connector/__init__.py +31 -0
- kumoai/connector/base.py +153 -0
- kumoai/connector/bigquery_connector.py +200 -0
- kumoai/connector/databricks_connector.py +213 -0
- kumoai/connector/file_upload_connector.py +189 -0
- kumoai/connector/glue_connector.py +150 -0
- kumoai/connector/s3_connector.py +278 -0
- kumoai/connector/snowflake_connector.py +252 -0
- kumoai/connector/source_table.py +471 -0
- kumoai/connector/utils.py +1796 -0
- kumoai/databricks.py +14 -0
- kumoai/encoder/__init__.py +4 -0
- kumoai/exceptions.py +26 -0
- kumoai/experimental/__init__.py +0 -0
- kumoai/experimental/rfm/__init__.py +210 -0
- kumoai/experimental/rfm/authenticate.py +432 -0
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/backend/local/graph_store.py +297 -0
- kumoai/experimental/rfm/backend/local/sampler.py +312 -0
- kumoai/experimental/rfm/backend/local/table.py +113 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
- kumoai/experimental/rfm/backend/snow/table.py +242 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
- kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
- kumoai/experimental/rfm/base/__init__.py +30 -0
- kumoai/experimental/rfm/base/column.py +152 -0
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +761 -0
- kumoai/experimental/rfm/base/source.py +19 -0
- kumoai/experimental/rfm/base/sql_sampler.py +143 -0
- kumoai/experimental/rfm/base/table.py +736 -0
- kumoai/experimental/rfm/graph.py +1237 -0
- kumoai/experimental/rfm/infer/__init__.py +19 -0
- kumoai/experimental/rfm/infer/categorical.py +40 -0
- kumoai/experimental/rfm/infer/dtype.py +82 -0
- kumoai/experimental/rfm/infer/id.py +46 -0
- kumoai/experimental/rfm/infer/multicategorical.py +48 -0
- kumoai/experimental/rfm/infer/pkey.py +128 -0
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +61 -0
- kumoai/experimental/rfm/infer/timestamp.py +41 -0
- kumoai/experimental/rfm/pquery/__init__.py +7 -0
- kumoai/experimental/rfm/pquery/executor.py +102 -0
- kumoai/experimental/rfm/pquery/pandas_executor.py +530 -0
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +1184 -0
- kumoai/experimental/rfm/sagemaker.py +138 -0
- kumoai/experimental/rfm/task_table.py +231 -0
- kumoai/formatting.py +30 -0
- kumoai/futures.py +99 -0
- kumoai/graph/__init__.py +12 -0
- kumoai/graph/column.py +106 -0
- kumoai/graph/graph.py +948 -0
- kumoai/graph/table.py +838 -0
- kumoai/jobs.py +80 -0
- kumoai/kumolib.cpython-310-x86_64-linux-gnu.so +0 -0
- kumoai/mixin.py +28 -0
- kumoai/pquery/__init__.py +25 -0
- kumoai/pquery/prediction_table.py +287 -0
- kumoai/pquery/predictive_query.py +641 -0
- kumoai/pquery/training_table.py +424 -0
- kumoai/spcs.py +121 -0
- kumoai/testing/__init__.py +8 -0
- kumoai/testing/decorators.py +57 -0
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/__init__.py +42 -0
- kumoai/trainer/baseline_trainer.py +93 -0
- kumoai/trainer/config.py +2 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/trainer/job.py +1192 -0
- kumoai/trainer/online_serving.py +258 -0
- kumoai/trainer/trainer.py +475 -0
- kumoai/trainer/util.py +103 -0
- kumoai/utils/__init__.py +11 -0
- kumoai/utils/datasets.py +83 -0
- kumoai/utils/display.py +51 -0
- kumoai/utils/forecasting.py +209 -0
- kumoai/utils/progress_logger.py +343 -0
- kumoai/utils/sql.py +3 -0
- kumoai-2.14.0.dev202601011731.dist-info/METADATA +71 -0
- kumoai-2.14.0.dev202601011731.dist-info/RECORD +122 -0
- kumoai-2.14.0.dev202601011731.dist-info/WHEEL +6 -0
- kumoai-2.14.0.dev202601011731.dist-info/licenses/LICENSE +9 -0
- kumoai-2.14.0.dev202601011731.dist-info/top_level.txt +1 -0
kumoai/pquery/training_table.py
ADDED
@@ -0,0 +1,424 @@
from __future__ import annotations

import asyncio
import logging
import os
import time
from concurrent.futures import Future
from functools import cached_property
from typing import List, Optional, Tuple, Union

import pandas as pd
from kumoapi.common import JobStatus
from kumoapi.jobs import (
    ArtifactExportRequest,
    CustomTrainingTable,
    GenerateTrainTableJobResource,
    GenerateTrainTableRequest,
    JobStatusReport,
    SourceTableType,
    TrainingTableOutputConfig,
    TrainingTableSpec,
    WriteMode,
)
from kumoapi.source_table import S3SourceTable
from tqdm.auto import tqdm
from typing_extensions import Self, override

from kumoai import global_state
from kumoai.artifact_export import (
    ArtifactExportJob,
    ArtifactExportResult,
    TrainingTableExportConfig,
)
from kumoai.client.jobs import (
    GenerateTrainTableJobAPI,
    GenerateTrainTableJobID,
)
from kumoai.connector import S3Connector, SourceTable
from kumoai.databricks import to_db_table_name
from kumoai.formatting import pretty_print_error_details
from kumoai.futures import KumoProgressFuture, create_future
from kumoai.jobs import JobInterface

logger = logging.getLogger(__name__)

_DEFAULT_INTERVAL_S = 20


class TrainingTable:
    r"""A training table in the Kumo platform. A training table can be
    initialized from the job ID of a completed training table generation job.

    .. code-block:: python

        import kumoai

        # Create a Training Table from a training table generation job. Note
        # that the job ID passed here must be in a completed state:
        training_table = kumoai.TrainingTable("gen-traintable-job-...")

        # Read the training table as a Pandas DataFrame:
        training_df = training_table.data_df()

        # Get URLs to download the training table:
        training_download_urls = training_table.data_urls()

        # Add a weight column to the training table; see
        # `kumo-sdk.examples.datasets.weighted_train_table.py` for a more
        # detailed example.
        # 1. Export the train table:
        connector = kumoai.S3Connector("s3_path")
        training_table.export(TrainingTableExportConfig(
            output_types={'training_table'},
            output_connector=connector,
            output_table_name="<any_name>"))
        # 2. Assume the weight column was added to the train table and it
        # was saved to the same S3 path as "<mod_table>":
        training_table.update(SourceTable("<mod_table>", connector),
                              TrainingTableSpec(weight_col="weight"))

    Args:
        job_id: ID of the training table generation job which generated this
            training table.
    """
    def __init__(self, job_id: GenerateTrainTableJobID):
        self.job_id = job_id
        status = _get_status(job_id).status
        self._custom_train_table: Optional[CustomTrainingTable] = None
        if status != JobStatus.DONE:
            raise ValueError(
                f"Job {job_id} is not yet complete (status: {status}). If you "
                f"would like to create a future (waiting for training table "
                f"generation success), please use `TrainingTableJob`.")

    def data_urls(self) -> List[str]:
        r"""Returns a list of URLs that can be used to view generated
        training table data. The list will contain more than one element
        if the table is partitioned; paths will be relative to the location of
        the Kumo data plane.
        """
        api: GenerateTrainTableJobAPI = (
            global_state.client.generate_train_table_job_api)
        return api._get_table_data(self.job_id, presigned=True, raw_path=True)

    def data_df(self) -> pd.DataFrame:
        r"""Returns a :class:`~pandas.DataFrame` object representing the
        generated training data.

        .. warning::

            This method will load the full training table into memory as a
            :class:`~pandas.DataFrame` object. If you are working on a machine
            with limited resources, please use
            :meth:`~kumoai.pquery.TrainingTable.data_urls` instead to download
            the data and perform analysis per-partition.
        """
        urls = self.data_urls()
        if global_state.is_spcs:
            from kumoai.spcs import _parquet_dataset_to_df

            # TODO(dm): return type hint is wrong
            return _parquet_dataset_to_df(self.data_urls())

        try:
            return pd.concat([pd.read_parquet(x) for x in urls])
        except Exception as e:
            raise ValueError(
                f"Could not create a Pandas DataFrame object from data paths "
                f"{urls}. Please construct the DataFrame manually.") from e

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(job_id={self.job_id})'

    def _to_s3_api_source_table(self,
                                source_table: SourceTable) -> S3SourceTable:
        assert isinstance(source_table.connector, S3Connector)
        source_type = source_table._to_api_source_table()
        root_dir: str = source_table.connector.root_dir  # type: ignore
        if root_dir[-1] != os.sep:
            root_dir = root_dir + os.sep
        return S3SourceTable(
            s3_path=root_dir,
            source_table_name=source_table.name,
            file_type=source_type.file_type,
        )

    def export(
        self,
        output_config: TrainingTableExportConfig,
        non_blocking: bool = True,
    ) -> Union[ArtifactExportJob, ArtifactExportResult]:
        r"""Exports the training table to the connector specified in the
        output config. Use the exported table to add a weight column, then
        use `update` to update the training table.

        Args:
            output_config: The output configuration to write the training
                table.
            non_blocking: If ``True``, the method will return a future object
                `ArtifactExportJob` representing the export job.
                If ``False``, the method will block until the export job is
                complete and return `ArtifactExportResult`.
        """
        assert output_config.output_connector is not None
        assert output_config.output_types == {'training_table'}
        output_table_name = to_db_table_name(output_config.output_table_name)
        assert output_table_name is not None
        s3_path = None
        connector_id = None
        table_name = ""
        write_mode = WriteMode.OVERWRITE

        if isinstance(output_config.output_connector, S3Connector):
            assert output_config.output_connector.root_dir is not None
            s3_path = output_config.output_connector.root_dir
            s3_path = os.path.join(s3_path, output_table_name)
        else:
            connector_id = output_config.output_connector.name
            table_name = output_table_name
            if output_config.connector_specific_config:
                write_mode = output_config.connector_specific_config.write_mode

        api = global_state.client.artifact_export_api
        output_config = TrainingTableOutputConfig(
            s3_path=s3_path,
            connector_id=connector_id,
            table_name=table_name,
            write_mode=write_mode,
        )

        request = ArtifactExportRequest(job_id=self.job_id,
                                        training_table_output=output_config)
        job_id = api.create(request)
        if non_blocking:
            return ArtifactExportJob(job_id)
        return ArtifactExportJob(job_id).attach()

    def validate_custom_table(
        self,
        source_table_type: SourceTableType,
        train_table_mod: TrainingTableSpec,
    ) -> None:
        r"""Validates the modified training table.

        Args:
            source_table_type: The source table to be used as the modified
                training table.
            train_table_mod: The modification specification.

        Raises:
            ValueError: If the modified training table is invalid.
        """
        api: GenerateTrainTableJobAPI = (
            global_state.client.generate_train_table_job_api)
        response = api.validate_custom_train_table(self.job_id,
                                                   source_table_type,
                                                   train_table_mod)
        if not response.ok:
            raise ValueError("Invalid weighted train table",
                             response.error_message)

    def update(
        self,
        source_table: SourceTable,
        train_table_mod: TrainingTableSpec,
        validate: bool = True,
    ) -> Self:
        r"""Sets the `source_table` as the modified training table.

        .. note::
            The only allowed modification is the addition of a weight column.
            Any other modification might lead to unintended errors while
            training. Further, negative/NA weight values are not supported.

            The custom training table is ingested during `trainer.fit()`
            and is used as the training table.

        Args:
            source_table: The source table to be used as the modified training
                table.
            train_table_mod: The modification specification.
            validate: Whether to validate the modified training table. This
                can be slow for large tables.
        """
        if isinstance(source_table.connector, S3Connector):
            # Special handling for S3, as `source_table._to_api_source_table`
            # concatenates root_dir and file name, but the backend expects
            # these to be separate:
            source_table_type = self._to_s3_api_source_table(source_table)
        else:
            source_table_type = source_table._to_api_source_table()
        if validate:
            self.validate_custom_table(source_table_type, train_table_mod)
        self._custom_train_table = CustomTrainingTable(
            source_table=source_table_type, table_mod_spec=train_table_mod,
            validated=validate)
        return self


# Training Table Future #######################################################


class TrainingTableJob(JobInterface[GenerateTrainTableJobID,
                                    GenerateTrainTableRequest,
                                    GenerateTrainTableJobResource],
                       KumoProgressFuture[TrainingTable]):
    r"""A representation of an ongoing training table generation job in the
    Kumo platform.

    .. code-block:: python

        import kumoai

        # See `PredictiveQuery` documentation:
        pquery = kumoai.PredictiveQuery(...)

        # If a training table is generated in nonblocking mode, the response
        # will be of type `TrainingTableJob`:
        training_table_job = pquery.generate_training_table(non_blocking=True)

        # You can also construct a `TrainingTableJob` from a job ID, e.g.
        # one that is present in the Kumo Jobs page:
        training_table_job = kumoai.TrainingTableJob("trainingjob-...")

        # Get the status of the job:
        print(training_table_job.status())

        # Attach to the job, and poll progress updates:
        training_table_job.attach()

        # Cancel the job:
        training_table_job.cancel()

        # Wait for the job to complete, and return a `TrainingTable`:
        training_table_job.result()

    Args:
        job_id: ID of the training table generation job.
    """
    @override
    @staticmethod
    def _api() -> GenerateTrainTableJobAPI:
        return global_state.client.generate_train_table_job_api

    def __init__(
        self,
        job_id: GenerateTrainTableJobID,
    ) -> None:
        self.job_id = job_id

    @cached_property
    def _fut(self) -> Future[TrainingTable]:
        return create_future(_poll(self.job_id))

    @override
    @property
    def id(self) -> GenerateTrainTableJobID:
        r"""The unique ID of this training table generation process."""
        return self.job_id

    @override
    def result(self) -> TrainingTable:
        return self._fut.result()

    @override
    def future(self) -> Future[TrainingTable]:
        return self._fut

    @override
    def status(self) -> JobStatusReport:
        r"""Returns the status of a running training table generation job."""
        return _get_status(self.job_id)

    @override
    def _attach_internal(self, interval_s: float = 20.0) -> TrainingTable:
        assert interval_s >= 4.0
        print(f"Attaching to training table generation job {self.job_id}. "
              f"Tracking this job in the Kumo UI is coming soon. To detach "
              f"from this job, please enter Ctrl+C (the job will continue to "
              f"run, and you can re-attach anytime).")

        def _get_progress() -> Optional[Tuple[int, int]]:
            progress = self._api().get_progress(self.job_id)
            if len(progress) == 0:
                return None
            expected_iter = progress['num_expected_iterations']
            completed_iter = progress['num_finished_iterations']
            return (expected_iter, completed_iter)

        # Print progress bar:
        print("Training table generation is in progress. If your task is "
              "temporal, progress per timeframe will be loaded shortly.")

        # Wait for either timeframes to become available, or the job is done:
        progress = _get_progress()
        while progress is None:
            progress = _get_progress()
            # Not a temporal task, and it's done:
            if self.status().status.is_terminal:
                return self.result()
            time.sleep(interval_s)

        # Wait for timeframes to become available:
        progress = _get_progress()
        assert progress is not None
        total, prog = progress
        pbar = tqdm(total=total, unit="timeframe",
                    desc="Generating Training Table")
        pbar.update(pbar.n - prog)
        while not self.done():
            progress = _get_progress()
            assert progress is not None
            total, prog = progress
            pbar.reset(total)
            pbar.update(prog)
            time.sleep(interval_s)
        pbar.update(pbar.total)
        pbar.close()

        # Future is done:
        return self.result()

    def cancel(self) -> None:
        r"""Cancels a running training table generation job, and raises an
        error if cancellation failed.
        """
        return self._api().cancel(self.job_id)

    @override
    def load_config(self) -> GenerateTrainTableRequest:
        r"""Loads the full configuration for this training table generation
        job.

        Returns:
            GenerateTrainTableRequest: Complete configuration including plan,
                pquery_id, graph_snapshot_id, etc.
        """
        return self._api().get_config(self.job_id)


def _get_status(job_id: str) -> JobStatusReport:
    api = global_state.client.generate_train_table_job_api
    resource: GenerateTrainTableJobResource = api.get(job_id)
    return resource.job_status_report


async def _poll(job_id: str) -> TrainingTable:
    # TODO(manan): make asynchronous natively with aiohttp:
    status = _get_status(job_id).status
    while not status.is_terminal:
        await asyncio.sleep(_DEFAULT_INTERVAL_S)
        status = _get_status(job_id).status

    if status != JobStatus.DONE:
        api = global_state.client.generate_train_table_job_api
        error_details = api.get_job_error(job_id)
        error_str = pretty_print_error_details(error_details)
        raise RuntimeError(
            f"Training table generation for job {job_id} failed with "
            f"job status {status}. Encountered below error(s):"
            f'{error_str}')

    return TrainingTable(job_id)
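For orientation, the sketch below strings together the weighted-training-table workflow that the `TrainingTable` docstring describes: export the generated table, add a weight column out of band, then register the modified copy with `update`. It only uses classes and calls that appear in this diff; the job ID, S3 prefix, and table names are placeholders, and the offline step that actually adds the weight column is not shown.

import kumoai
from kumoai.artifact_export import TrainingTableExportConfig
from kumoai.connector import S3Connector, SourceTable
from kumoapi.jobs import TrainingTableSpec

# 1. Load the training table of a *completed* generation job (placeholder ID):
training_table = kumoai.TrainingTable("gen-traintable-job-...")

# 2. Export it so a weight column can be added offline (placeholder names):
connector = S3Connector("s3://my-bucket/kumo-exports/")
training_table.export(
    TrainingTableExportConfig(
        output_types={'training_table'},
        output_connector=connector,
        output_table_name="train_table_export",
    ),
    non_blocking=False,  # block until the export completes
)

# 3. After writing a modified copy with a "weight" column to the same prefix,
#    register it; validation can be skipped for very large tables:
training_table.update(
    SourceTable("train_table_weighted", connector),
    TrainingTableSpec(weight_col="weight"),
    validate=True,
)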
kumoai/spcs.py
ADDED
@@ -0,0 +1,121 @@
import asyncio
import os
from functools import reduce
from typing import TYPE_CHECKING, Dict, List, Optional

if TYPE_CHECKING:
    from snowflake.snowpark import DataFrame, Session


def _get_spcs_token(snowflake_credentials: Dict[str, str]) -> str:
    r"""Fetches a token to access a Kumo application deployed in Snowflake
    Snowpark Container Services (SPCS). This token is valid for 1 hour, after
    which the token must be re-generated.
    """
    # Create a request to the ingress endpoint with authz:
    active_session = _get_active_session()
    if active_session is not None:
        ctx = active_session.connection
    else:
        user = snowflake_credentials["user"]
        password = snowflake_credentials["password"]
        account = snowflake_credentials["account"]
        import snowflake.connector
        ctx = snowflake.connector.connect(
            user=user,
            password=password,
            account=account,
            session_parameters={
                'PYTHON_CONNECTOR_QUERY_RESULT_FORMAT': 'json'
            },
        )

    # Obtain a session token:
    token_data = ctx._rest._token_request('ISSUE')
    token_extract = token_data['data']['sessionToken']
    return f'\"{token_extract}\"'


def _refresh_spcs_token() -> None:
    r"""Refreshes the SPCS token in global state to avoid expiration."""
    from kumoai import KumoClient, global_state
    if (not global_state.initialized
            or (not global_state._snowflake_credentials
                and not global_state._snowpark_session)):
        raise ValueError(
            "Please initialize the Kumo application with snowflake "
            "credentials before attempting to refresh this token.")
    spcs_token = _get_spcs_token(global_state._snowflake_credentials or {})

    # Verify token validity:
    assert global_state._url is not None
    client = KumoClient(
        url=global_state._url,
        api_key=global_state._api_key,
        spcs_token=spcs_token,
    )
    client.authenticate()

    # Update state:
    global_state.set_spcs_token(spcs_token)


async def _run_refresh_spcs_token(minutes: int) -> None:
    r"""Runs the SPCS token refresh loop every `minutes` minutes."""
    while True:
        await asyncio.sleep(minutes * 60)
        _refresh_spcs_token()


def _get_active_session() -> 'Optional[Session]':
    try:
        from snowflake.snowpark.context import get_active_session
        return get_active_session()
    except Exception:
        return None


def _get_session() -> 'Session':
    import snowflake.snowpark as snowpark

    from kumoai import global_state
    params = global_state._snowflake_credentials
    assert params is not None

    database = os.getenv("SNOWFLAKE_DATABASE")
    schema = os.getenv("SNOWFLAKE_SCHEMA")
    if not database or not schema:
        raise ValueError("Please set the SNOWFLAKE_DATABASE and "
                         "SNOWFLAKE_SCHEMA environment variables.")
    params['database'] = database
    params['schema'] = schema
    params['client_session_keep_alive'] = True

    return snowpark.Session.builder.configs(params).create()


def _remove_path(session: 'Session', stage_path: str, file_path: str) -> None:
    stage_prefix = '.'.join(stage_path.split('.')[:2])
    name_remove = '.'.join([stage_prefix, file_path])
    session.sql(f"REMOVE {name_remove}").collect()


def _parquet_to_df(path: str) -> 'DataFrame':
    r"""Reads parquet from the given path and returns a snowpark DataFrame."""
    session = _get_session()
    if not path.endswith(os.path.sep):
        path += os.path.sep
    file_list = session.sql(f"LIST {path}").collect()
    for file_row in file_list:
        if file_row.name.endswith('.parquet'):
            continue
        _remove_path(session, path, file_row.name)
    df = session.read.parquet(path)
    return df


def _parquet_dataset_to_df(paths: List[str]) -> 'DataFrame':
    r"""Reads parquet from the given paths and returns a snowpark DataFrame."""
    from snowflake.snowpark import DataFrame
    df_list = [_parquet_to_df(url) for url in paths]
    return reduce(DataFrame.union_all, df_list)
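As a rough illustration of how the refresh coroutine above could be driven, the following sketch schedules `_run_refresh_spcs_token` on an asyncio event loop. The 50-minute interval is an assumption chosen to stay under the 1-hour token lifetime noted in `_get_spcs_token`, not a value taken from this package, and the surrounding application work is a placeholder.

import asyncio

from kumoai.spcs import _run_refresh_spcs_token


async def main() -> None:
    # Keep the SPCS token fresh in the background (assumed 50-minute cadence):
    refresher = asyncio.create_task(_run_refresh_spcs_token(50))
    try:
        await asyncio.sleep(3600)  # placeholder for real work against Kumo
    finally:
        refresher.cancel()


# asyncio.run(main())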
kumoai/testing/decorators.py
ADDED
@@ -0,0 +1,57 @@
import importlib
import os
from typing import Callable

import packaging
from packaging.requirements import Requirement


def is_full_test() -> bool:
    r"""Whether to run the full but time-consuming test suite."""
    return os.getenv('FULL_TEST', '0') == '1'


def onlyFullTest(func: Callable) -> Callable:
    r"""A decorator to specify that this function belongs to the full test
    suite.
    """
    import pytest
    return pytest.mark.skipif(
        not is_full_test(),
        reason="Fast test run",
    )(func)


def has_package(package: str) -> bool:
    r"""Returns ``True`` in case ``package`` is installed."""
    req = Requirement(package)
    if importlib.util.find_spec(req.name) is None:  # type: ignore
        return False

    try:
        module = importlib.import_module(req.name)
        if not hasattr(module, '__version__'):
            return True

        version = packaging.version.Version(module.__version__).base_version
        return version in req.specifier
    except Exception:
        return False


def withPackage(*args: str) -> Callable:
    r"""A decorator to skip tests if certain packages are not installed.
    Also supports version specification.
    """
    na_packages = {package for package in args if not has_package(package)}

    if len(na_packages) == 1:
        reason = f"Package '{list(na_packages)[0]}' not found"
    else:
        reason = f"Packages {na_packages} not found"

    def decorator(func: Callable) -> Callable:
        import pytest
        return pytest.mark.skipif(len(na_packages) > 0, reason=reason)(func)

    return decorator
kumoai/testing/snow.py
ADDED
@@ -0,0 +1,50 @@
import json
import os

from kumoai.experimental.rfm.backend.snow import Connection
from kumoai.experimental.rfm.backend.snow import connect as _connect


def connect(
    region: str,
    id: str,
    account: str,
    user: str,
    warehouse: str,
    database: str | None = None,
    schema: str | None = None,
) -> Connection:

    kwargs = dict(password=os.getenv('SNOWFLAKE_PASSWORD'))
    if kwargs['password'] is None:
        import boto3
        from cryptography.hazmat.primitives import serialization

        client = boto3.client(
            service_name='secretsmanager',
            region_name=region,
        )
        secret_id = (f'arn:aws:secretsmanager:{region}:{id}:secret:'
                     f'{account}.snowflakecomputing.com')
        response = client.get_secret_value(SecretId=secret_id)['SecretString']
        secret = json.loads(response)

        private_key = serialization.load_pem_private_key(
            secret['kumo_user_secretkey'].encode(),
            password=None,
        )
        kwargs['private_key'] = private_key.private_bytes(
            encoding=serialization.Encoding.DER,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        )

    return _connect(
        account=account,
        user=user,
        warehouse='WH_XS',
        database='KUMO',
        schema=schema,
        session_parameters=dict(CLIENT_TELEMETRY_ENABLED=False),
        **kwargs,
    )
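Finally, a hedged example of calling this test-only helper; every identifier below is a placeholder, and the call assumes either a `SNOWFLAKE_PASSWORD` environment variable or AWS credentials that can read the secret referenced above. Note that, as written, the helper passes `warehouse='WH_XS'` and `database='KUMO'` to `_connect` regardless of the arguments supplied.

from kumoai.testing.snow import connect

conn = connect(
    region='us-west-2',          # AWS region holding the secret (placeholder)
    id='123456789012',           # AWS account ID in the secret ARN (placeholder)
    account='myorg-myaccount',   # Snowflake account locator (placeholder)
    user='KUMO_TEST_USER',       # placeholder user
    warehouse='WH_XS',
    schema='PUBLIC',
)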