kumoai 2.14.0.dev202601011731__cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kumoai might be problematic.
- kumoai/__init__.py +300 -0
- kumoai/_logging.py +29 -0
- kumoai/_singleton.py +25 -0
- kumoai/_version.py +1 -0
- kumoai/artifact_export/__init__.py +9 -0
- kumoai/artifact_export/config.py +209 -0
- kumoai/artifact_export/job.py +108 -0
- kumoai/client/__init__.py +5 -0
- kumoai/client/client.py +223 -0
- kumoai/client/connector.py +110 -0
- kumoai/client/endpoints.py +150 -0
- kumoai/client/graph.py +120 -0
- kumoai/client/jobs.py +471 -0
- kumoai/client/online.py +78 -0
- kumoai/client/pquery.py +207 -0
- kumoai/client/rfm.py +112 -0
- kumoai/client/source_table.py +53 -0
- kumoai/client/table.py +101 -0
- kumoai/client/utils.py +130 -0
- kumoai/codegen/__init__.py +19 -0
- kumoai/codegen/cli.py +100 -0
- kumoai/codegen/context.py +16 -0
- kumoai/codegen/edits.py +473 -0
- kumoai/codegen/exceptions.py +10 -0
- kumoai/codegen/generate.py +222 -0
- kumoai/codegen/handlers/__init__.py +4 -0
- kumoai/codegen/handlers/connector.py +118 -0
- kumoai/codegen/handlers/graph.py +71 -0
- kumoai/codegen/handlers/pquery.py +62 -0
- kumoai/codegen/handlers/table.py +109 -0
- kumoai/codegen/handlers/utils.py +42 -0
- kumoai/codegen/identity.py +114 -0
- kumoai/codegen/loader.py +93 -0
- kumoai/codegen/naming.py +94 -0
- kumoai/codegen/registry.py +121 -0
- kumoai/connector/__init__.py +31 -0
- kumoai/connector/base.py +153 -0
- kumoai/connector/bigquery_connector.py +200 -0
- kumoai/connector/databricks_connector.py +213 -0
- kumoai/connector/file_upload_connector.py +189 -0
- kumoai/connector/glue_connector.py +150 -0
- kumoai/connector/s3_connector.py +278 -0
- kumoai/connector/snowflake_connector.py +252 -0
- kumoai/connector/source_table.py +471 -0
- kumoai/connector/utils.py +1796 -0
- kumoai/databricks.py +14 -0
- kumoai/encoder/__init__.py +4 -0
- kumoai/exceptions.py +26 -0
- kumoai/experimental/__init__.py +0 -0
- kumoai/experimental/rfm/__init__.py +210 -0
- kumoai/experimental/rfm/authenticate.py +432 -0
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/backend/local/graph_store.py +297 -0
- kumoai/experimental/rfm/backend/local/sampler.py +312 -0
- kumoai/experimental/rfm/backend/local/table.py +113 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
- kumoai/experimental/rfm/backend/snow/table.py +242 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
- kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
- kumoai/experimental/rfm/base/__init__.py +30 -0
- kumoai/experimental/rfm/base/column.py +152 -0
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +761 -0
- kumoai/experimental/rfm/base/source.py +19 -0
- kumoai/experimental/rfm/base/sql_sampler.py +143 -0
- kumoai/experimental/rfm/base/table.py +736 -0
- kumoai/experimental/rfm/graph.py +1237 -0
- kumoai/experimental/rfm/infer/__init__.py +19 -0
- kumoai/experimental/rfm/infer/categorical.py +40 -0
- kumoai/experimental/rfm/infer/dtype.py +82 -0
- kumoai/experimental/rfm/infer/id.py +46 -0
- kumoai/experimental/rfm/infer/multicategorical.py +48 -0
- kumoai/experimental/rfm/infer/pkey.py +128 -0
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +61 -0
- kumoai/experimental/rfm/infer/timestamp.py +41 -0
- kumoai/experimental/rfm/pquery/__init__.py +7 -0
- kumoai/experimental/rfm/pquery/executor.py +102 -0
- kumoai/experimental/rfm/pquery/pandas_executor.py +530 -0
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +1184 -0
- kumoai/experimental/rfm/sagemaker.py +138 -0
- kumoai/experimental/rfm/task_table.py +231 -0
- kumoai/formatting.py +30 -0
- kumoai/futures.py +99 -0
- kumoai/graph/__init__.py +12 -0
- kumoai/graph/column.py +106 -0
- kumoai/graph/graph.py +948 -0
- kumoai/graph/table.py +838 -0
- kumoai/jobs.py +80 -0
- kumoai/kumolib.cpython-310-x86_64-linux-gnu.so +0 -0
- kumoai/mixin.py +28 -0
- kumoai/pquery/__init__.py +25 -0
- kumoai/pquery/prediction_table.py +287 -0
- kumoai/pquery/predictive_query.py +641 -0
- kumoai/pquery/training_table.py +424 -0
- kumoai/spcs.py +121 -0
- kumoai/testing/__init__.py +8 -0
- kumoai/testing/decorators.py +57 -0
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/__init__.py +42 -0
- kumoai/trainer/baseline_trainer.py +93 -0
- kumoai/trainer/config.py +2 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/trainer/job.py +1192 -0
- kumoai/trainer/online_serving.py +258 -0
- kumoai/trainer/trainer.py +475 -0
- kumoai/trainer/util.py +103 -0
- kumoai/utils/__init__.py +11 -0
- kumoai/utils/datasets.py +83 -0
- kumoai/utils/display.py +51 -0
- kumoai/utils/forecasting.py +209 -0
- kumoai/utils/progress_logger.py +343 -0
- kumoai/utils/sql.py +3 -0
- kumoai-2.14.0.dev202601011731.dist-info/METADATA +71 -0
- kumoai-2.14.0.dev202601011731.dist-info/RECORD +122 -0
- kumoai-2.14.0.dev202601011731.dist-info/WHEEL +6 -0
- kumoai-2.14.0.dev202601011731.dist-info/licenses/LICENSE +9 -0
- kumoai-2.14.0.dev202601011731.dist-info/top_level.txt +1 -0
kumoai/trainer/job.py
ADDED
@@ -0,0 +1,1192 @@
import asyncio
import concurrent
import concurrent.futures
import time
from datetime import datetime, timezone
from functools import cached_property
from typing import TYPE_CHECKING, Dict, List, Optional, Union
from urllib.parse import urlparse, urlunparse

import pandas as pd
from kumoapi.common import JobStatus
from kumoapi.data_snapshot import GraphSnapshotID
from kumoapi.jobs import (
    ArtifactExportRequest,
    AutoTrainerProgress,
    BaselineEvaluationMetrics,
    BaselineJobRequest,
    BatchPredictionJobSummary,
    BatchPredictionRequest,
    JobStatusReport,
    ModelEvaluationMetrics,
    PredictionProgress,
    TrainingJobRequest,
)
from kumoapi.model_plan import ModelPlan
from kumoapi.online_serving import (
    OnlinePredictionOptions,
    OnlineServingEndpointRequest,
)
from kumoapi.task import TaskType
from tqdm.auto import tqdm
from typing_extensions import override

from kumoai import global_state
from kumoai.artifact_export import (
    ArtifactExportJob,
    ArtifactExportResult,
    OutputConfig,
)
from kumoai.client.jobs import (
    BaselineJobAPI,
    BaselineJobID,
    BaselineJobResource,
    BatchPredictionJobAPI,
    BatchPredictionJobID,
    BatchPredictionJobResource,
    TrainingJobAPI,
    TrainingJobID,
    TrainingJobResource,
)
from kumoai.databricks import to_db_table_name
from kumoai.futures import KumoProgressFuture, create_future
from kumoai.jobs import JobInterface
from kumoai.trainer.online_serving import OnlineServingEndpointFuture
from kumoai.trainer.util import (
    build_prediction_output_config,
    validate_output_arguments,
)

if TYPE_CHECKING:
    from kumoai.pquery import (
        PredictionTable,
        PredictionTableJob,
        PredictiveQuery,
        TrainingTable,
        TrainingTableJob,
    )


class BaselineJobResult:
    r"""Represents a completed baseline job.

    A :class:`BaselineJobResult` object can either be obtained by creating a
    :class:`~kumoai.trainer.BaselineJob` object and calling the
    :meth:`~kumoai.trainer.BaselineJob.result` method to await the job's
    completion, or by directly creating the object. The former approach is
    recommended, as it includes verification that the job finished
    successfully.

    Example:
        >>> import kumoai  # doctest: +SKIP
        >>> job_future = kumoai.BaselineJob(id=...)  # doctest: +SKIP
        >>> job = job_future.result()  # doctest: +SKIP
    """
    def __init__(self, job_id: BaselineJobID) -> None:
        self.job_id = job_id

        # A cached completed, finalized job resource:
        self._job_resource: Optional[BaselineJobResource] = None

    def metrics(self) -> Dict[str, BaselineEvaluationMetrics]:
        r"""Returns the metrics associated with this completed baseline job,
        or raises an exception if metrics cannot be obtained.
        """
        return self._get_job_resource(
            require_completed=True).result.eval_metrics

    def _get_job_resource(self,
                          require_completed: bool) -> BaselineJobResource:
        if self._job_resource:
            return self._job_resource

        try:
            api = global_state.client.baseline_job_api
            resource: BaselineJobResource = api.get(self.job_id)
        except Exception as e:
            raise RuntimeError(
                f"Baseline job {self.job_id} was not found in the Kumo "
                f"database. Please contact Kumo for further assistance. "
            ) from e

        if not require_completed:
            return resource

        status = resource.job_status_report.status
        if not status.is_terminal:
            raise RuntimeError(
                f"Baseline job {self.job_id} has not yet completed. Please "
                f"create a `BaselineJob` class and await its completion "
                f"before attempting to view metrics.")

        if status != JobStatus.DONE:
            # Should never happen, the future will not resolve:
            raise ValueError(
                f"Baseline job {self.job_id} completed with status {status}, "
                f"and was therefore unable to produce metrics. Please "
                f"re-train the job until it successfully completes.")

        self._job_resource = resource
        return self._job_resource


class TrainingJobResult:
    r"""Represents a completed training job.

    A :class:`TrainingJobResult` object can either be obtained by creating a
    :class:`~kumoai.trainer.TrainingJob` object and calling the
    :meth:`~kumoai.trainer.TrainingJob.result` method to await the job's
    completion, or by directly creating the object. The former approach is
    recommended, as it includes verification that the job finished
    successfully.

    .. code-block:: python

        import kumoai

        training_job = kumoai.TrainingJob("trainingjob-...")

        # Wait for a training job's completion, and get its result:
        training_job_result = training_job.result()

        # Alternatively, create the result directly, but be sure that the job
        # is completed:
        training_job_result = kumoai.TrainingJobResult("trainingjob-...")

        # Get associated objects:
        pquery = training_job_result.predictive_query
        training_table = training_job_result.training_table

        # Get holdout data:
        holdout_df = training_job_result.holdout_df()

    Example:
        >>> import kumoai  # doctest: +SKIP
        >>> job_future = kumoai.TrainingJob(id=...)  # doctest: +SKIP
        >>> job = job_future.result()  # doctest: +SKIP
    """
    def __init__(self, job_id: TrainingJobID) -> None:
        self.job_id = job_id

        # A cached completed, finalized job resource:
        self._job_resource: Optional[TrainingJobResource] = None

    @property
    def id(self) -> TrainingJobID:
        r"""The unique ID of this training job."""
        return self.job_id

    @property
    def model_plan(self) -> ModelPlan:
        r"""Returns the model plan associated with this training job."""
        return self._get_job_resource(
            require_completed=False).config.model_plan

    @property
    def training_table(self) -> Union['TrainingTableJob', 'TrainingTable']:
        r"""Returns the training table associated with this training job,
        either as a :class:`~kumoai.pquery.TrainingTable` or a
        :class:`~kumoai.pquery.TrainingTableJob` depending on the status of
        the training table generation job.
        """
        from kumoai.pquery import TrainingTableJob
        training_table_job_id = self._get_job_resource(
            require_completed=False).config.train_table_job_id
        if training_table_job_id is None:
            raise RuntimeError(
                f"Unable to access the training table generation job ID for "
                f"job {self.job_id}. Did this job fail before generating its "
                f"training table?")
        fut = TrainingTableJob(training_table_job_id)
        if fut.status().status == JobStatus.DONE:
            return fut.result()
        return fut

    @property
    def predictive_query(self) -> 'PredictiveQuery':
        r"""Returns the :class:`~kumoai.pquery.PredictiveQuery` object that
        defined the training table for this training job.
        """
        from kumoai.pquery import PredictiveQuery
        return PredictiveQuery.load_from_training_job(self.job_id)

    @property
    def tracking_url(self) -> str:
        r"""Returns a tracking URL pointing to the UI display of this training
        job.
        """
        tracking_url = self._get_job_resource(
            require_completed=False).job_status_report.tracking_url
        return _rewrite_tracking_url(tracking_url)

    def metrics(self) -> ModelEvaluationMetrics:
        r"""Returns the metrics associated with this completed training job,
        or raises an exception if metrics cannot be obtained.
        """
        return self._get_job_resource(
            require_completed=True).result.eval_metrics

    def holdout_url(self) -> str:
        r"""Returns a URL for downloading or reading the holdout dataset.

        If Kumo is deployed as a SaaS application, the returned URL will be a
        presigned AWS S3 URL with a TTL of 1 hour. If Kumo is deployed as a
        Snowpark Container Services application, the returned URL will be a
        Snowflake stage path that can be directly accessed within a Snowflake
        worksheet.
        """
        api: TrainingJobAPI = global_state.client.training_job_api
        return api.holdout_data_url(self.job_id, presigned=True)

    def holdout_df(self) -> pd.DataFrame:
        r"""Reads the holdout dataset (parquet file) as a pandas DataFrame.

        .. note::
            Please note that this function may be memory-intensive, depending
            on the size of your holdout dataframe. Please exercise caution.
        """
        holdout_url = self.holdout_url()

        if global_state.is_spcs:
            from kumoai.spcs import _get_session

            # TODO(dm): return type hint is wrong
            return _get_session().read.parquet(holdout_url)

        if holdout_url.startswith("dbfs:"):
            raise ValueError(f"holdout_df is unsupported for "
                             f"Databricks UC Volume path {holdout_url}")

        return pd.read_parquet(holdout_url)

    def launch_online_serving_endpoint(
        self,
        pred_options: OnlinePredictionOptions = OnlinePredictionOptions(),
        snapshot_id: Optional[GraphSnapshotID] = None,
    ) -> OnlineServingEndpointFuture:
        self._get_job_resource(require_completed=True)
        pquery = self.predictive_query
        task_type = pquery.get_task_type()
        if task_type == TaskType.BINARY_CLASSIFICATION:
            if not pred_options.binary_classification_threshold:
                raise ValueError(
                    'Missing binary_classification_threshold option')
        if (not task_type.is_classification
                and task_type != TaskType.REGRESSION):
            raise ValueError(
                f'{task_type} does not yet support online serving')

        endpoint_id = global_state.client.online_serving_endpoint_api.create(
            OnlineServingEndpointRequest(self.id, pred_options, snapshot_id))
        return OnlineServingEndpointFuture(endpoint_id)

    def _get_job_resource(self,
                          require_completed: bool) -> TrainingJobResource:
        if self._job_resource:
            return self._job_resource

        try:
            api = global_state.client.training_job_api
            resource: TrainingJobResource = api.get(self.job_id)
        except Exception as e:
            raise RuntimeError(
                f"Training job {self.job_id} was not found in the Kumo "
                f"database. Please contact Kumo for further assistance. "
            ) from e

        if not require_completed:
            return resource

        status = resource.job_status_report.status
        if not status.is_terminal:
            raise RuntimeError(
                f"Training job {self.job_id} has not yet completed. Please "
                f"create a `TrainingJob` class and await its completion "
                f"before attempting to view metrics.")

        if status != JobStatus.DONE:
            # Should never happen, the future will not resolve:
            raise ValueError(
                f"Training job {self.job_id} completed with status {status}, "
                f"and was therefore unable to produce metrics. Please "
                f"re-train the job until it successfully completes.")

        self._job_resource = resource
        return self._job_resource


class BatchPredictionJobResult:
    r"""Represents a completed batch prediction job.

    A :class:`BatchPredictionJobResult` object can either be obtained by
    creating a :class:`~kumoai.trainer.BatchPredictionJob` object and calling
    the :meth:`~kumoai.trainer.BatchPredictionJob.result` method to await the
    job's completion, or by directly creating the object. The former approach
    is recommended, as it includes verification that the job finished
    successfully.

    .. code-block:: python

        import kumoai

        prediction_job = kumoai.BatchPredictionJob("bp-job-...")

        # Wait for a batch prediction job's completion, and get its result:
        prediction_job_result = prediction_job.result()

        # Alternatively, create the result directly, but be sure that the job
        # is completed:
        prediction_job_result = kumoai.BatchPredictionJobResult("bp-job-...")

        # Get associated objects:
        prediction_table = prediction_job_result.prediction_table

        # Get prediction data (in-memory):
        predictions_df = prediction_job_result.predictions_df()

        # Export prediction data to any output connector:
        prediction_job_result.export(
            output_type = ...,
            output_connector = ...,
            output_table_name = ...,
        )
    """  # noqa: E501

    def __init__(self, job_id: BatchPredictionJobID) -> None:
        self.job_id = job_id
        self._job_resource: Optional[BatchPredictionJobResource] = None

    @property
    def id(self) -> BatchPredictionJobID:
        r"""The unique ID of this batch prediction job."""
        return self.job_id

    @property
    def tracking_url(self) -> str:
        r"""Returns a tracking URL pointing to the UI display of this batch
        prediction job.
        """
        tracking_url = self._get_job_resource(
            require_completed=False).job_status_report.tracking_url
        return _rewrite_tracking_url(tracking_url)

    def summary(self) -> BatchPredictionJobSummary:
        r"""Returns summary statistics associated with the batch prediction
        job's output, or raises an exception if summary statistics cannot be
        obtained.
        """
        return self._get_job_resource(require_completed=True).result

    @property
    def prediction_table(
            self) -> Union['PredictionTableJob', 'PredictionTable']:
        r"""Returns the prediction table associated with this prediction job,
        either as a :class:`~kumoai.pquery.PredictionTable` or a
        :class:`~kumoai.pquery.PredictionTableJob` depending on the status
        of the prediction table generation job.
        """
        from kumoai.pquery import PredictionTableJob
        prediction_table_job_id = self._get_job_resource(
            require_completed=False).config.pred_table_job_id
        if prediction_table_job_id is None:
            raise RuntimeError(
                f"Unable to access the prediction table generation job ID for "
                f"job {self.job_id}. Did this job fail before generating its "
                f"prediction table, or use a custom prediction table?")
        fut = PredictionTableJob(prediction_table_job_id)
        if fut.status().status == JobStatus.DONE:
            return fut.result()
        return fut

    def export(
        self,
        output_config: OutputConfig,
        non_blocking: bool = True,
    ) -> Union['ArtifactExportJob', 'ArtifactExportResult']:
        r"""Export the prediction output or the embedding output to the
        specific output location.

        Args:
            output_config: The output configuration to be used.
            non_blocking: If ``True``, the method will return a future object
                `ArtifactExportJob` representing the export job.
                If ``False``, the method will block until the export job is
                complete and return `ArtifactExportResult`.
        """
        output_table_name = to_db_table_name(output_config.output_table_name)
        validate_output_arguments(
            (output_config.output_types),
            output_config.output_connector,
            output_table_name,
        )
        if output_config.output_types is not None and len(
                output_config.output_types) > 1:
            raise ValueError(
                f'Each export request can only support one output_type, '
                f'received {output_config.output_types}. If you want to make '
                'multiple output_type exports, please make separate export() '
                'calls.')
        prediction_output_config = build_prediction_output_config(
            list(output_config.output_types)[0],
            output_config.output_connector,
            output_table_name,
            output_config.output_metadata_fields,
            output_config,
        )

        api = global_state.client.artifact_export_api
        request = ArtifactExportRequest(
            job_id=self.id, prediction_output=prediction_output_config)
        job_id = api.create(request)
        if non_blocking:
            return ArtifactExportJob(job_id)
        return ArtifactExportJob(job_id).attach()

    def predictions_urls(self) -> List[str]:
        r"""Returns a list of URLs for downloading or reading the predictions.

        If Kumo is deployed as a SaaS application, the returned URLs will be
        presigned AWS S3 URLs. If Kumo is deployed as a Snowpark Container
        Services application, the returned URLs will be Snowflake stage paths
        that can be directly accessed within a Snowflake worksheet. If Kumo is
        deployed as a Databricks application, they will be Databricks UC
        volume paths.
        """
        api: BatchPredictionJobAPI = (
            global_state.client.batch_prediction_job_api)
        return api.get_batch_predictions_url(self.job_id)

    def predictions_df(self) -> pd.DataFrame:
        r"""Returns a :class:`~pandas.DataFrame` object representing the
        generated predictions.

        .. warning::

            This method will load the full prediction output into memory as a
            :class:`~pandas.DataFrame` object. If you are working on a machine
            with limited resources, please use
            :meth:`~kumoai.trainer.BatchPredictionResult.predictions_urls`
            instead to download the data and perform analysis per-partition.
        """
        urls = self.predictions_urls()
        try:
            return pd.concat(pd.read_parquet(x) for x in urls)
        except Exception as e:
            raise ValueError(
                f"Could not create a Pandas DataFrame object from data paths "
                f"{urls}. Please construct the DataFrame manually.") from e

    def embeddings_urls(self) -> List[str]:
        r"""Returns a list of URLs for downloading or reading the embeddings.

        If Kumo is deployed as a SaaS application, the returned URLs will be
        presigned AWS S3 URLs. If Kumo is deployed as a Snowpark Container
        Services application, the returned URLs will be Snowflake stage paths
        that can be directly accessed within a Snowflake worksheet. If Kumo is
        deployed as a Databricks application, they will be Databricks UC
        volume paths.
        """
        api: BatchPredictionJobAPI = (
            global_state.client.batch_prediction_job_api)
        return api.get_batch_embeddings_url(self.job_id)

    def embeddings_df(self) -> pd.DataFrame:
        r"""Returns a :class:`~pandas.DataFrame` object representing the
        generated embeddings.

        .. warning::

            This method will load the full prediction output into memory as a
            :class:`~pandas.DataFrame` object. If you are working on a machine
            with limited resources, please use
            :meth:`~kumoai.trainer.BatchPredictionResult.embeddings_urls`
            instead to download the data and perform analysis per-partition.
        """
        urls = self.embeddings_urls()
        try:
            return pd.concat(pd.read_parquet(x) for x in urls)
        except Exception as e:
            raise ValueError(
                f"Could not create a Pandas DataFrame object from data paths "
                f"{urls}. Please construct the DataFrame manually.") from e

    def _get_job_resource(
            self, require_completed: bool) -> BatchPredictionJobResource:
        if self._job_resource:
            return self._job_resource

        try:
            api = global_state.client.batch_prediction_job_api
            resource: BatchPredictionJobResource = api.get(self.job_id)
        except Exception as e:
            raise RuntimeError(
                f"Batch prediction job {self.job_id} was not found in the "
                f"Kumo database. Please contact Kumo for further assistance. "
            ) from e

        if not require_completed:
            return resource

        status = resource.job_status_report.status
        if not status.is_terminal:
            raise RuntimeError(
                f"Batch prediction job {self.job_id} has not yet completed. "
                f"Please create a `BatchPredictionJob` class and await "
                "its completion before attempting to view stats.")

        if status != JobStatus.DONE:
            validation_resp = resource.job_status_report.validation_response
            validation_message = ""
            if validation_resp:
                validation_message = validation_resp.message()
            if len(validation_message) > 0:
                validation_message = f"Details: {validation_message}"

            raise ValueError(
                f"Batch prediction job {self.job_id} completed with status "
                f"{status}, and was therefore unable to produce metrics. "
                f"{validation_message}")

        self._job_resource = resource
        return resource


# Training Job Future #########################################################


class TrainingJob(JobInterface[TrainingJobID, TrainingJobRequest,
                               TrainingJobResource],
                  KumoProgressFuture[TrainingJobResult]):
    r"""Represents an in-progress training job.

    A :class:`TrainingJob` object can either be created as the result of
    :meth:`~kumoai.trainer.Trainer.fit` with ``non_blocking=True``, or
    directly with a training job ID (*e.g.* of a job created asynchronously in
    a different environment).

    .. code-block:: python

        import kumoai

        # See `Trainer` documentation:
        trainer = kumoai.Trainer(...)

        # If a Trainer is `fit` in nonblocking mode, the response will be of
        # type `TrainingJob`:
        training_job = trainer.fit(..., non_blocking=True)

        # You can also construct a `TrainingJob` from a job ID, e.g. one that
        # is present in the Kumo Jobs page:
        training_job = kumoai.TrainingJob("trainingjob-...")

        # Get the status of the job:
        print(training_job.status())

        # Attach to the job, and poll progress updates:
        training_job.attach()
        # Training: 70%|█████████ | [300s<90s, trial=4, train_loss=1.056, val_loss=0.682, val_mae=35.709, val_mse=7906.239, val_rmse=88.917

        # Cancel the job:
        training_job.cancel()

        # Wait for the job to complete, and return a `TrainingJobResult`:
        training_job.result()

    Args:
        job_id: The training job ID to await completion of.
    """  # noqa

    @override
    @staticmethod
    def _api() -> TrainingJobAPI:
        return global_state.client.training_job_api

    def __init__(self, job_id: TrainingJobID) -> None:
        self.job_id = job_id

    @cached_property
    def _fut(self) -> concurrent.futures.Future:
        return create_future(_poll_training(self.job_id))

    @override
    @property
    def id(self) -> TrainingJobID:
        r"""The unique ID of this training job."""
        return self.job_id

    @override
    def result(self) -> TrainingJobResult:
        return self._fut.result()

    @override
    def future(self) -> 'concurrent.futures.Future[TrainingJobResult]':
        return self._fut

    @property
    def tracking_url(self) -> str:
        r"""Returns a tracking URL pointing to the UI that can be used to
        monitor the status of an ongoing or completed job.
        """
        return _rewrite_tracking_url(self.status().tracking_url)

    @override
    def _attach_internal(
        self,
        interval_s: float = 20.0,
    ) -> TrainingJobResult:
        r"""Allows a user to attach to a running training job, and view its
        progress inline.

        Args:
            interval_s (float): Time interval (seconds) between polls, minimum
                value allowed is 4 seconds.

        Example:
            >>> job_future = kumoai.TrainingJob(job_id="...")  # doctest: +SKIP
            >>> job_future.attach()  # doctest: +SKIP
            Attaching to training job <id>. To track this job...
            Training: 70%|█████████ | [300s<90s, trial=4, train_loss=1.056, val_loss=0.682, val_mae=35.709, val_mse=7906.239, val_rmse=88.917
        """  # noqa
        assert interval_s >= 4.0
        print(f"Attaching to training job {self.job_id}. To track this job in "
              f"the Kumo UI, please visit {self.tracking_url}. To detach from "
              f"this job, please enter Ctrl+C: the job will continue to run, "
              f"and you can re-attach anytime by calling the `attach()` "
              f"method on the `TrainingJob` object. For example: "
              f"kumoai.TrainingJob(\"{self.job_id}\").attach()")

        # TODO(manan): this is not perfect, the `asyncio.sleep` in the poller
        # may cause a "DONE" status to be printed for up to
        # interval_s*`timeout` seconds before the future resolves.
        # That's probably fine:
        if self.done():
            return self.result()

        # For every non-training stage, just show the stage and status:
        print("Waiting for job to start.")
        current_status = JobStatus.NOT_STARTED
        while current_status == JobStatus.NOT_STARTED:
            report = self.status()
            current_status = report.status
            current_stage = report.event_log[-1].stage_name
            time.sleep(interval_s)

        prev_stage = current_stage
        print(f"Current stage: {current_stage}. In progress...", end="",
              flush=True)
        while not self.done():
            # Print status of stage:
            if current_stage != prev_stage:
                print(" Done.")
                print(f"Current stage: {current_stage}. In progress...",
                      end="", flush=True)
            if current_stage == "Training":
                _time = self.progress().estimated_training_time
                if _time and _time != 0:
                    break
            time.sleep(interval_s)
            report = self.status()
            prev_stage = current_stage
            current_stage = report.event_log[-1].stage_name

        # We are not on Training:
        if self.done():
            return self.result()

        # We are training: print a progress bar
        progress = self.progress()
        bar_format = '{desc}: {percentage:3.0f}%|{bar}|{unit} '
        total = int(progress.estimated_training_time)
        elapsed = int(progress.elapsed_training_time)
        pbar = tqdm(desc="Training", unit="% done", bar_format=bar_format,
                    total=total, dynamic_ncols=True)
        pbar.update(elapsed)

        while not self.done():
            progress = self.progress()
            trial_no = min(progress.completed_trials + 1,
                           progress.total_trials)

            if f'{max(trial_no-1, 0)}' in progress.trial_progress:
                trial_metrics = progress.trial_progress[
                    f'{max(trial_no-1, 0)}'].metrics
            elif f'{max(trial_no-2, 0)}' in progress.trial_progress:
                trial_metrics = progress.trial_progress[
                    f'{max(trial_no-2, 0)}'].metrics
            else:
                trial_metrics = {}

            # If we don't have metrics, wait until we do:
            if len(trial_metrics) == 0:
                continue

            # Show all metrics:
            # TODO(manan): only show tune metric, trial, epoch, and loss:
            last_epoch_metrics = trial_metrics[sorted(
                trial_metrics.keys())[-1]]
            train_metrics_s = ", ".join([
                f"{key_name}={key_val:.3f}" for key_name, key_val in
                last_epoch_metrics.train_metrics.items()
            ])
            val_metrics_s = ", ".join([
                f"{key_name}={key_val:.3f}" for key_name, key_val in
                last_epoch_metrics.validation_metrics.items()
            ])

            # Update numbers:
            delta = int(progress.elapsed_training_time - pbar.n)
            total = int(progress.estimated_training_time)
            pbar.update(delta)
            pbar.total = total
            if pbar.n > pbar.total:
                pbar.total = pbar.n

            # NOTE we use `unit` here as a hack, instead of `set_postfix`,
            # since `tqdm` defaults to adding a comma before the postfix
            # (https://github.com/tqdm/tqdm/issues/712)
            pbar.unit = (f"[{pbar.n}s<{pbar.total-pbar.n}s, trial={trial_no}, "
                         f"{train_metrics_s}, {val_metrics_s}]")
            pbar.refresh()
            time.sleep(interval_s)
        pbar.update(pbar.total - pbar.n)
        pbar.close()

        # Future is done:
        return self.result()

    def progress(self) -> AutoTrainerProgress:
        r"""Returns the progress of an ongoing or completed training job."""
        return self._api().get_progress(self.job_id)

    @override
    def status(self) -> JobStatusReport:
        r"""Returns the status of a running training job."""
        return _get_training_status(self.job_id)

    def cancel(self) -> bool:
        r"""Cancels a running training job, and returns ``True`` if
        cancellation succeeded.

        Example:
            >>> job_future = kumoai.TrainingJob(job_id="...")  # doctest: +SKIP
            >>> job_future.cancel()  # doctest: +SKIP
        """  # noqa
        return self._api().cancel(self.job_id).is_cancelled

    @override
    def load_config(self) -> TrainingJobRequest:
        r"""Load the full configuration for this training job.

        Returns:
            TrainingJobRequest: Complete configuration including model_plan,
                pquery_id, graph_snapshot_id, train_table_job_id, etc.
        """
        return self._api().get_config(self.job_id)


def _get_training_status(job_id: str) -> JobStatusReport:
    api = global_state.client.training_job_api
    resource: TrainingJobResource = api.get(job_id)
    return resource.job_status_report


async def _poll_training(job_id: str) -> TrainingJobResult:
    # TODO(manan): make asynchronous natively with aiohttp:
    status = _get_training_status(job_id).status
    while not status.is_terminal:
        await asyncio.sleep(10)
        status = _get_training_status(job_id).status

    # TODO(manan, siyang): improve
    if status != JobStatus.DONE:
        api = global_state.client.training_job_api
        job_resource = api.get(job_id)
        validation_resp = (job_resource.job_status_report.validation_response)

        validation_message = ""
        if validation_resp:
            validation_message = validation_resp.message()
        if len(validation_message) > 0:
            validation_message = f"Details: {validation_message}"

        raise RuntimeError(f"Training job {job_id} completed with job status "
                           f"{status}. {validation_message}")

    # TODO(manan): improve
    return TrainingJobResult(job_id=job_id)


# Batch Prediction Job Future #################################################


class BatchPredictionJob(JobInterface[BatchPredictionJobID,
                                      BatchPredictionRequest,
                                      BatchPredictionJobResource],
                         KumoProgressFuture[BatchPredictionJobResult]):
    r"""Represents an in-progress batch prediction job.

    A :class:`BatchPredictionJob` object can either be created as the
    result of :meth:`~kumoai.trainer.Trainer.predict` with
    ``non_blocking=True``, or directly with a batch prediction job ID (*e.g.*
    of a job created asynchronously in a different environment).

    .. code-block:: python

        import kumoai

        # See `Trainer` documentation:
        trainer = kumoai.Trainer(...)

        # If a Trainer `predict` is called in nonblocking mode, the response
        # will be of type `BatchPredictionJob`:
        prediction_job = trainer.predict(..., non_blocking=True)

        # You can also construct a `BatchPredictionJob` from a job ID, e.g. one
        # that is present in the Kumo Jobs page:
        prediction_job = kumoai.BatchPredictionJob("bp-job-...")

        # Get the status of the job:
        print(prediction_job.status())

        # Attach to the job, and poll progress updates:
        prediction_job.attach()
        # Attaching to batch prediction job <id>. To track this job...
        # Predicting (job_id=..., start=..., elapsed=..., status=...). Stage: ...

        # Cancel the job:
        prediction_job.cancel()

        # Wait for the job to complete, and return a `BatchPredictionJobResult`:
        prediction_job.result()

    Args:
        job_id: The batch prediction job ID to await completion of.
    """  # noqa

    @override
    @staticmethod
    def _api() -> BatchPredictionJobAPI:
        return global_state.client.batch_prediction_job_api

    def __init__(self, job_id: BatchPredictionJobID) -> None:
        self.job_id = job_id

    @cached_property
    def _fut(self) -> concurrent.futures.Future:
        return create_future(_poll_batch_prediction(self.job_id))

    @override
    @property
    def id(self) -> BatchPredictionJobID:
        r"""The unique ID of this batch prediction job."""
        return self.job_id

    @override
    def result(self) -> BatchPredictionJobResult:
        return self._fut.result()

    @override
    def future(self) -> 'concurrent.futures.Future[BatchPredictionJobResult]':
        return self._fut

    @property
    def tracking_url(self) -> str:
        r"""Returns a tracking URL pointing to the UI that can be used to
        monitor the status of an ongoing or completed job.
        """
        return _rewrite_tracking_url(self.status().tracking_url)

    @override
    def _attach_internal(
        self,
        interval_s: float = 20.0,
    ) -> BatchPredictionJobResult:
        r"""Allows a user to attach to a running batch prediction job, and view
        its progress inline.

        Args:
            interval_s (float): Time interval (seconds) between polls, minimum
                value allowed is 4 seconds.

        """
        assert interval_s >= 4.0
        print(f"Attaching to batch prediction job {self.job_id}. To track "
              f"this job in the Kumo UI, please visit {self.tracking_url}. To "
              f"detach from this job, please enter Ctrl+C (the job will "
              f"continue to run, and you can re-attach anytime).")
        # TODO(manan): this is not perfect, the `asyncio.sleep` in the poller
        # may cause a "DONE" status to be printed for up to
        # interval_s*`timeout` seconds before the future resolves.
        # That's probably fine:
        if self.done():
            return self.result()

        print("Waiting for job to start.")
        current_status = JobStatus.NOT_STARTED
        while current_status == JobStatus.NOT_STARTED:
            report = self.status()
            current_status = report.status
            current_stage = report.event_log[-1].stage_name
            time.sleep(interval_s)

        prev_stage = current_stage
        print(f"Current stage: {current_stage}. In progress...", end="",
              flush=True)
        while not self.done():
            # Print status of stage:
            if current_stage != prev_stage:
                print(" Done.")
                print(f"Current stage: {current_stage}. In progress...",
                      end="", flush=True)
            if current_stage == "Predicting":
                _time = self.progress().estimated_prediction_time
                if _time and _time != 0:
                    break

            time.sleep(interval_s)
            report = self.status()
            prev_stage = current_stage
            current_stage = report.event_log[-1].stage_name

        # We are not on Batch Prediction:
        if self.done():
            return self.result()

        # We are predicting: print a progress bar
        bar_format = '{desc}: {percentage:3.0f}%|{bar} '
        total_iterations, elapsed = 0, 0
        pbar = tqdm(desc="Predicting", unit="% done", bar_format=bar_format,
                    total=100, dynamic_ncols=True)
        pbar.update(elapsed)

        while not self.done():
            progress = self.progress()
            if progress is None:
                time.sleep(interval_s)
                continue
            total_iterations = progress.total_iterations
            completed_iterations = progress.completed_iterations
            pbar.update(
                (completed_iterations - elapsed) / total_iterations * 100)
            elapsed = completed_iterations
            elapsed_pct = completed_iterations / total_iterations
            pbar.refresh()
            time.sleep(interval_s)

        pbar.update(1.0 - elapsed_pct)
        pbar.close()

        # Future is done:
        return self.result()

    def progress(self) -> PredictionProgress:
        r"""Returns the progress of an ongoing or completed batch prediction
        job.
        """
        return self._api().get_progress(self.job_id)

    @override
    def status(self) -> JobStatusReport:
        r"""Returns the status of a running batch prediction job."""
        return _get_batch_prediction_status(self.job_id)

    def cancel(self) -> bool:
        r"""Cancels a running batch prediction job, and returns ``True`` if
        cancellation succeeded.
        """
        return self._api().cancel(self.job_id).is_cancelled

    @override
    def load_config(self) -> BatchPredictionRequest:
        r"""Load the full configuration for this batch prediction job.

        Returns:
            BatchPredictionRequest: Complete
                configuration including predict_options,
                outputs, model_training_job_id, etc.
        """
        return self._api().get_config(self.job_id)


def _get_batch_prediction_job(job_id: str) -> BatchPredictionJobResource:
    api = global_state.client.batch_prediction_job_api
    return api.get(job_id)


def _get_batch_prediction_status(job_id: str) -> JobStatusReport:
    api = global_state.client.batch_prediction_job_api
    resource: BatchPredictionJobResource = api.get(job_id)
    return resource.job_status_report


async def _poll_batch_prediction(job_id: str) -> BatchPredictionJobResult:
    # TODO(manan): make asynchronous natively with aiohttp:
    job_resource = _get_batch_prediction_job(job_id)
    status = job_resource.job_status_report.status
    while not status.is_terminal:
        await asyncio.sleep(10)
        job_resource = _get_batch_prediction_job(job_id)
        status = job_resource.job_status_report.status

    # TODO(manan, siyang): improve
    if status != JobStatus.DONE:
        validation_resp = job_resource.job_status_report.validation_response
        validation_message = ""
        if validation_resp:
            validation_message = validation_resp.message()
        if len(validation_message) > 0:
            validation_message = f"Details: {validation_message}"

        raise ValueError(
            f"Batch prediction job {job_id} completed with status "
            f"{status}, and was therefore unable to produce metrics. "
            f"{validation_message}")

    # TODO(manan): improve
    return BatchPredictionJobResult(job_id=job_id)


# Baseline Job Future #################################################


class BaselineJob(JobInterface[BaselineJobID, BaselineJobRequest,
                               BaselineJobResource],
                  KumoProgressFuture[BaselineJobResult]):
    r"""Represents an in-progress baseline job.

    A :class:`BaselineJob` object can either be created as the result of
    :meth:`~kumoai.trainer.BaselineTrainer.run` with ``non_blocking=True``, or
    directly with a baseline job ID (*e.g.* of a job created asynchronously in
    a different environment).

    Args:
        job_id: The baseline job ID to await completion of.

    Example:
        >>> import kumoai  # doctest: +SKIP
        >>> id = "some_baseline_job_id"
        >>> job_future = kumoai.BaselineJob(id)  # doctest: +SKIP
        >>> job_future.attach()  # doctest: +SKIP
        Attaching to baseline job <id>. To track this job...
    """  # noqa

    @override
    @staticmethod
    def _api() -> BaselineJobAPI:
        return global_state.client.baseline_job_api

    def __init__(self, job_id: BaselineJobID) -> None:
        self.job_id = job_id

    @cached_property
    def _fut(self) -> concurrent.futures.Future:
        return create_future(_poll_baseline(self.job_id))

    @override
    @property
    def id(self) -> BaselineJobID:
        r"""The unique ID of this baseline job."""
        return self.job_id

    @override
    def result(self) -> BaselineJobResult:
        return self._fut.result()

    @override
    def future(self) -> 'concurrent.futures.Future[BaselineJobResult]':
        return self._fut

    @property
    def tracking_url(self) -> str:
        r"""Returns a tracking URL pointing to the UI that can be used to
        monitor the status of an ongoing or completed job.
        """
        return ""

    @override
    def _attach_internal(
        self,
        interval_s: float = 20.0,
    ) -> BaselineJobResult:
        r"""Allows a user to attach to a running baseline job, and view its
        progress inline.

        Args:
            interval_s (float): Time interval (seconds) between polls, minimum
                value allowed is 4 seconds.

        Example:
            >>> job_future = kumoai.BaselineJob(job_id="...")  # doctest: +SKIP
            >>> job_future.attach()  # doctest: +SKIP
            Attaching to baseline job <id>. To track this job...
        """  # noqa
        assert interval_s >= 4.0
        print(f"Attaching to baseline job {self.job_id}. "
              f"To detach from "
              f"this job, please enter Ctrl+C (the job will continue to run, "
              f"and you can re-attach anytime).")

        while not self.done():
            report = self.status()
            status = report.status
            latest_event = report.event_log[-1]
            stage = latest_event.stage_name
            detail = ", " + latest_event.detail if latest_event.detail else ""

            start = report.start_time
            now = datetime.now(timezone.utc)
            print(f"Baseline job (job_id={self.job_id} start={start}, elapsed="
                  f"{now-start}, status={status}). Stage: {stage}{detail}")
            time.sleep(interval_s)

        # Future is done:
        return self.result()

    @override
    def status(self) -> JobStatusReport:
        r"""Returns the status of a running baseline job."""
        return _get_baseline_status(self.job_id)

    @override
    def load_config(self) -> BaselineJobRequest:
        r"""Load the full configuration for this baseline job.

        Returns:
            BaselineJobRequest: Complete configuration including metrics,
                pquery_id, graph_snapshot_id, etc.
        """
        return self._api().get_config(self.job_id)


def _get_baseline_status(job_id: str) -> JobStatusReport:
    api = global_state.client.baseline_job_api
    resource: BaselineJobResource = api.get(job_id)
    return resource.job_status_report


async def _poll_baseline(job_id: str) -> BaselineJobResult:
    status = _get_baseline_status(job_id).status
    while not status.is_terminal:
        await asyncio.sleep(10)
        status = _get_baseline_status(job_id).status

    if status != JobStatus.DONE:
        raise RuntimeError(
            f"Baseline job {job_id} failed with job status {status}.")

    return BaselineJobResult(job_id=job_id)


def _rewrite_tracking_url(tracking_url: str) -> str:
    r"""Rewrites a tracking URL to account for deployment subdomains."""
    # TODO(manan): improve...
    if 'http' not in tracking_url:
        return tracking_url
    parsed_base = urlparse(global_state.client._url)
    parsed_tracking = urlparse(tracking_url)
    tracking_url = urlunparse((
        parsed_base.scheme,
        parsed_base.netloc,
        parsed_tracking.path,
        parsed_tracking.params,
        parsed_tracking.query,
        parsed_tracking.fragment,
    ))
    return tracking_url