kumoai 2.14.0.dev202601011731__cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kumoai might be problematic. Click here for more details.

Files changed (122) hide show
  1. kumoai/__init__.py +300 -0
  2. kumoai/_logging.py +29 -0
  3. kumoai/_singleton.py +25 -0
  4. kumoai/_version.py +1 -0
  5. kumoai/artifact_export/__init__.py +9 -0
  6. kumoai/artifact_export/config.py +209 -0
  7. kumoai/artifact_export/job.py +108 -0
  8. kumoai/client/__init__.py +5 -0
  9. kumoai/client/client.py +223 -0
  10. kumoai/client/connector.py +110 -0
  11. kumoai/client/endpoints.py +150 -0
  12. kumoai/client/graph.py +120 -0
  13. kumoai/client/jobs.py +471 -0
  14. kumoai/client/online.py +78 -0
  15. kumoai/client/pquery.py +207 -0
  16. kumoai/client/rfm.py +112 -0
  17. kumoai/client/source_table.py +53 -0
  18. kumoai/client/table.py +101 -0
  19. kumoai/client/utils.py +130 -0
  20. kumoai/codegen/__init__.py +19 -0
  21. kumoai/codegen/cli.py +100 -0
  22. kumoai/codegen/context.py +16 -0
  23. kumoai/codegen/edits.py +473 -0
  24. kumoai/codegen/exceptions.py +10 -0
  25. kumoai/codegen/generate.py +222 -0
  26. kumoai/codegen/handlers/__init__.py +4 -0
  27. kumoai/codegen/handlers/connector.py +118 -0
  28. kumoai/codegen/handlers/graph.py +71 -0
  29. kumoai/codegen/handlers/pquery.py +62 -0
  30. kumoai/codegen/handlers/table.py +109 -0
  31. kumoai/codegen/handlers/utils.py +42 -0
  32. kumoai/codegen/identity.py +114 -0
  33. kumoai/codegen/loader.py +93 -0
  34. kumoai/codegen/naming.py +94 -0
  35. kumoai/codegen/registry.py +121 -0
  36. kumoai/connector/__init__.py +31 -0
  37. kumoai/connector/base.py +153 -0
  38. kumoai/connector/bigquery_connector.py +200 -0
  39. kumoai/connector/databricks_connector.py +213 -0
  40. kumoai/connector/file_upload_connector.py +189 -0
  41. kumoai/connector/glue_connector.py +150 -0
  42. kumoai/connector/s3_connector.py +278 -0
  43. kumoai/connector/snowflake_connector.py +252 -0
  44. kumoai/connector/source_table.py +471 -0
  45. kumoai/connector/utils.py +1796 -0
  46. kumoai/databricks.py +14 -0
  47. kumoai/encoder/__init__.py +4 -0
  48. kumoai/exceptions.py +26 -0
  49. kumoai/experimental/__init__.py +0 -0
  50. kumoai/experimental/rfm/__init__.py +210 -0
  51. kumoai/experimental/rfm/authenticate.py +432 -0
  52. kumoai/experimental/rfm/backend/__init__.py +0 -0
  53. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  54. kumoai/experimental/rfm/backend/local/graph_store.py +297 -0
  55. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  56. kumoai/experimental/rfm/backend/local/table.py +113 -0
  57. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  58. kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
  59. kumoai/experimental/rfm/backend/snow/table.py +242 -0
  60. kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
  61. kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
  62. kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
  63. kumoai/experimental/rfm/base/__init__.py +30 -0
  64. kumoai/experimental/rfm/base/column.py +152 -0
  65. kumoai/experimental/rfm/base/expression.py +44 -0
  66. kumoai/experimental/rfm/base/sampler.py +761 -0
  67. kumoai/experimental/rfm/base/source.py +19 -0
  68. kumoai/experimental/rfm/base/sql_sampler.py +143 -0
  69. kumoai/experimental/rfm/base/table.py +736 -0
  70. kumoai/experimental/rfm/graph.py +1237 -0
  71. kumoai/experimental/rfm/infer/__init__.py +19 -0
  72. kumoai/experimental/rfm/infer/categorical.py +40 -0
  73. kumoai/experimental/rfm/infer/dtype.py +82 -0
  74. kumoai/experimental/rfm/infer/id.py +46 -0
  75. kumoai/experimental/rfm/infer/multicategorical.py +48 -0
  76. kumoai/experimental/rfm/infer/pkey.py +128 -0
  77. kumoai/experimental/rfm/infer/stype.py +35 -0
  78. kumoai/experimental/rfm/infer/time_col.py +61 -0
  79. kumoai/experimental/rfm/infer/timestamp.py +41 -0
  80. kumoai/experimental/rfm/pquery/__init__.py +7 -0
  81. kumoai/experimental/rfm/pquery/executor.py +102 -0
  82. kumoai/experimental/rfm/pquery/pandas_executor.py +530 -0
  83. kumoai/experimental/rfm/relbench.py +76 -0
  84. kumoai/experimental/rfm/rfm.py +1184 -0
  85. kumoai/experimental/rfm/sagemaker.py +138 -0
  86. kumoai/experimental/rfm/task_table.py +231 -0
  87. kumoai/formatting.py +30 -0
  88. kumoai/futures.py +99 -0
  89. kumoai/graph/__init__.py +12 -0
  90. kumoai/graph/column.py +106 -0
  91. kumoai/graph/graph.py +948 -0
  92. kumoai/graph/table.py +838 -0
  93. kumoai/jobs.py +80 -0
  94. kumoai/kumolib.cpython-310-x86_64-linux-gnu.so +0 -0
  95. kumoai/mixin.py +28 -0
  96. kumoai/pquery/__init__.py +25 -0
  97. kumoai/pquery/prediction_table.py +287 -0
  98. kumoai/pquery/predictive_query.py +641 -0
  99. kumoai/pquery/training_table.py +424 -0
  100. kumoai/spcs.py +121 -0
  101. kumoai/testing/__init__.py +8 -0
  102. kumoai/testing/decorators.py +57 -0
  103. kumoai/testing/snow.py +50 -0
  104. kumoai/trainer/__init__.py +42 -0
  105. kumoai/trainer/baseline_trainer.py +93 -0
  106. kumoai/trainer/config.py +2 -0
  107. kumoai/trainer/distilled_trainer.py +175 -0
  108. kumoai/trainer/job.py +1192 -0
  109. kumoai/trainer/online_serving.py +258 -0
  110. kumoai/trainer/trainer.py +475 -0
  111. kumoai/trainer/util.py +103 -0
  112. kumoai/utils/__init__.py +11 -0
  113. kumoai/utils/datasets.py +83 -0
  114. kumoai/utils/display.py +51 -0
  115. kumoai/utils/forecasting.py +209 -0
  116. kumoai/utils/progress_logger.py +343 -0
  117. kumoai/utils/sql.py +3 -0
  118. kumoai-2.14.0.dev202601011731.dist-info/METADATA +71 -0
  119. kumoai-2.14.0.dev202601011731.dist-info/RECORD +122 -0
  120. kumoai-2.14.0.dev202601011731.dist-info/WHEEL +6 -0
  121. kumoai-2.14.0.dev202601011731.dist-info/licenses/LICENSE +9 -0
  122. kumoai-2.14.0.dev202601011731.dist-info/top_level.txt +1 -0
kumoai/trainer/job.py ADDED
@@ -0,0 +1,1192 @@
1
+ import asyncio
2
+ import concurrent
3
+ import concurrent.futures
4
+ import time
5
+ from datetime import datetime, timezone
6
+ from functools import cached_property
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Union
8
+ from urllib.parse import urlparse, urlunparse
9
+
10
+ import pandas as pd
11
+ from kumoapi.common import JobStatus
12
+ from kumoapi.data_snapshot import GraphSnapshotID
13
+ from kumoapi.jobs import (
14
+ ArtifactExportRequest,
15
+ AutoTrainerProgress,
16
+ BaselineEvaluationMetrics,
17
+ BaselineJobRequest,
18
+ BatchPredictionJobSummary,
19
+ BatchPredictionRequest,
20
+ JobStatusReport,
21
+ ModelEvaluationMetrics,
22
+ PredictionProgress,
23
+ TrainingJobRequest,
24
+ )
25
+ from kumoapi.model_plan import ModelPlan
26
+ from kumoapi.online_serving import (
27
+ OnlinePredictionOptions,
28
+ OnlineServingEndpointRequest,
29
+ )
30
+ from kumoapi.task import TaskType
31
+ from tqdm.auto import tqdm
32
+ from typing_extensions import override
33
+
34
+ from kumoai import global_state
35
+ from kumoai.artifact_export import (
36
+ ArtifactExportJob,
37
+ ArtifactExportResult,
38
+ OutputConfig,
39
+ )
40
+ from kumoai.client.jobs import (
41
+ BaselineJobAPI,
42
+ BaselineJobID,
43
+ BaselineJobResource,
44
+ BatchPredictionJobAPI,
45
+ BatchPredictionJobID,
46
+ BatchPredictionJobResource,
47
+ TrainingJobAPI,
48
+ TrainingJobID,
49
+ TrainingJobResource,
50
+ )
51
+ from kumoai.databricks import to_db_table_name
52
+ from kumoai.futures import KumoProgressFuture, create_future
53
+ from kumoai.jobs import JobInterface
54
+ from kumoai.trainer.online_serving import OnlineServingEndpointFuture
55
+ from kumoai.trainer.util import (
56
+ build_prediction_output_config,
57
+ validate_output_arguments,
58
+ )
59
+
60
+ if TYPE_CHECKING:
61
+ from kumoai.pquery import (
62
+ PredictionTable,
63
+ PredictionTableJob,
64
+ PredictiveQuery,
65
+ TrainingTable,
66
+ TrainingTableJob,
67
+ )
68
+
69
+
70
class BaselineJobResult:
    r"""Represents a completed baseline job.

    A :class:`BaselineJobResult` object can either be obtained by creating a
    :class:`~kumoai.trainer.BaselineJob` object and calling the
    :meth:`~kumoai.trainer.BaselineJob.result` method to await the job's
    completion, or by directly creating the object. The former approach is
    recommended, as it includes verification that the job finished
    successfully.

    Example:
        >>> import kumoai  # doctest: +SKIP
        >>> job_future = kumoai.BaselineJob(id=...)  # doctest: +SKIP
        >>> job = job_future.result()  # doctest: +SKIP
    """
    def __init__(self, job_id: BaselineJobID) -> None:
        # The unique identifier of this baseline job:
        self.job_id = job_id

        # A cached completed, finalized job resource:
        self._job_resource: Optional[BaselineJobResource] = None

    def metrics(self) -> Dict[str, BaselineEvaluationMetrics]:
        r"""Returns the metrics associated with this completed baseline job,
        or raises an exception if metrics cannot be obtained.

        Raises:
            RuntimeError: if the job cannot be found, or has not yet reached
                a terminal state.
            ValueError: if the job reached a terminal state other than
                :obj:`JobStatus.DONE`.
        """
        return self._get_job_resource(
            require_completed=True).result.eval_metrics

    def _get_job_resource(self,
                          require_completed: bool) -> BaselineJobResource:
        # Fetch (and cache) the backing job resource. Only successfully
        # completed resources are cached, so in-flight jobs are re-fetched
        # on every call:
        if self._job_resource:
            return self._job_resource

        try:
            api = global_state.client.baseline_job_api
            resource: BaselineJobResource = api.get(self.job_id)
        except Exception as e:
            raise RuntimeError(
                f"Baseline job {self.job_id} was not found in the Kumo "
                f"database. Please contact Kumo for further assistance. "
            ) from e

        if not require_completed:
            return resource

        status = resource.job_status_report.status
        if not status.is_terminal:
            raise RuntimeError(
                f"Baseline job {self.job_id} has not yet completed. Please "
                f"create a `BaselineJob` class and await its completion "
                f"before attempting to view metrics.")

        if status != JobStatus.DONE:
            # Should never happen, the future will not resolve:
            raise ValueError(
                f"Baseline job {self.job_id} completed with status {status}, "
                f"and was therefore unable to produce metrics. Please "
                f"re-train the job until it successfully completes.")

        self._job_resource = resource
        return self._job_resource
130
+
131
+
132
class TrainingJobResult:
    r"""Represents a completed training job.

    A :class:`TrainingJobResult` object can either be obtained by creating a
    :class:`~kumoai.trainer.TrainingJob` object and calling the
    :meth:`~kumoai.trainer.TrainingJob.result` method to await the job's
    completion, or by directly creating the object. The former approach is
    recommended, as it includes verification that the job finished
    successfully.

    .. code-block:: python

        import kumoai

        training_job = kumoai.TrainingJob("trainingjob-...")

        # Wait for a training job's completion, and get its result:
        training_job_result = training_job.result()

        # Alternatively, create the result directly, but be sure that the job
        # is completed:
        training_job_result = kumoai.TrainingJobResult("trainingjob-...")

        # Get associated objects:
        pquery = training_job_result.predictive_query
        training_table = training_job_result.training_table

        # Get holdout data:
        holdout_df = training_job_result.holdout_df()

    Example:
        >>> import kumoai  # doctest: +SKIP
        >>> job_future = kumoai.TrainingJob(id=...)  # doctest: +SKIP
        >>> job = job_future.result()  # doctest: +SKIP
    """
    def __init__(self, job_id: TrainingJobID) -> None:
        # The unique identifier of this training job:
        self.job_id = job_id

        # A cached completed, finalized job resource:
        self._job_resource: Optional[TrainingJobResource] = None

    @property
    def id(self) -> TrainingJobID:
        r"""The unique ID of this training job."""
        return self.job_id

    @property
    def model_plan(self) -> ModelPlan:
        r"""Returns the model plan associated with this training job."""
        return self._get_job_resource(
            require_completed=False).config.model_plan

    @property
    def training_table(self) -> Union['TrainingTableJob', 'TrainingTable']:
        r"""Returns the training table associated with this training job,
        either as a :class:`~kumoai.pquery.TrainingTable` or a
        :class:`~kumoai.pquery.TrainingTableJob` depending on the status of
        the training table generation job.
        """
        # Local import avoids a circular dependency with `kumoai.pquery`:
        from kumoai.pquery import TrainingTableJob
        training_table_job_id = self._get_job_resource(
            require_completed=False).config.train_table_job_id
        if training_table_job_id is None:
            raise RuntimeError(
                f"Unable to access the training table generation job ID for "
                f"job {self.job_id}. Did this job fail before generating its "
                f"training table?")
        fut = TrainingTableJob(training_table_job_id)
        # Resolve eagerly when the generation job is already done:
        if fut.status().status == JobStatus.DONE:
            return fut.result()
        return fut

    @property
    def predictive_query(self) -> 'PredictiveQuery':
        r"""Returns the :class:`~kumoai.pquery.PredictiveQuery` object that
        defined the training table for this training job.
        """
        from kumoai.pquery import PredictiveQuery
        return PredictiveQuery.load_from_training_job(self.job_id)

    @property
    def tracking_url(self) -> str:
        r"""Returns a tracking URL pointing to the UI display of this training
        job.
        """
        tracking_url = self._get_job_resource(
            require_completed=False).job_status_report.tracking_url
        return _rewrite_tracking_url(tracking_url)

    def metrics(self) -> ModelEvaluationMetrics:
        r"""Returns the metrics associated with this completed training job,
        or raises an exception if metrics cannot be obtained.
        """
        return self._get_job_resource(
            require_completed=True).result.eval_metrics

    def holdout_url(self) -> str:
        r"""Returns a URL for downloading or reading the holdout dataset.

        If Kumo is deployed as a SaaS application, the returned URL will be a
        presigned AWS S3 URL with a TTL of 1 hour. If Kumo is deployed as a
        Snowpark Container Services application, the returned URL will be a
        Snowflake stage path that can be directly accessed within a Snowflake
        worksheet.
        """
        api: TrainingJobAPI = global_state.client.training_job_api
        return api.holdout_data_url(self.job_id, presigned=True)

    def holdout_df(self) -> pd.DataFrame:
        r"""Reads the holdout dataset (parquet file) as pandas DataFrame.

        .. note::
            Please note that this function may be memory-intensive, depending
            on the size of your holdout dataframe. Please exercise caution.
        """
        holdout_url = self.holdout_url()

        if global_state.is_spcs:
            from kumoai.spcs import _get_session

            # TODO(dm): return type hint is wrong
            return _get_session().read.parquet(holdout_url)

        if holdout_url.startswith("dbfs:"):
            raise ValueError(f"holdout_df is unsupported for "
                             f"Databricks UC Volume path {holdout_url}")

        return pd.read_parquet(holdout_url)

    def launch_online_serving_endpoint(
        self,
        pred_options: Optional[OnlinePredictionOptions] = None,
        snapshot_id: Optional[GraphSnapshotID] = None,
    ) -> OnlineServingEndpointFuture:
        r"""Launches an online serving endpoint for this completed training
        job.

        Args:
            pred_options: The online prediction options to use; defaults to
                a fresh :class:`OnlinePredictionOptions` instance.
            snapshot_id: An optional graph snapshot to serve predictions
                against.

        Raises:
            ValueError: if required prediction options are missing, or if
                the task type does not support online serving.
        """
        # NOTE: a `None` sentinel avoids a single shared default
        # `OnlinePredictionOptions` instance across all calls:
        if pred_options is None:
            pred_options = OnlinePredictionOptions()
        self._get_job_resource(require_completed=True)
        pquery = self.predictive_query
        task_type = pquery.get_task_type()
        if task_type == TaskType.BINARY_CLASSIFICATION:
            if not pred_options.binary_classification_threshold:
                raise ValueError(
                    'Missing binary_classification_threshold option')
        if (not task_type.is_classification
                and task_type != TaskType.REGRESSION):
            raise ValueError(
                f'{task_type} does not yet support online serving')

        endpoint_id = global_state.client.online_serving_endpoint_api.create(
            OnlineServingEndpointRequest(self.id, pred_options, snapshot_id))
        return OnlineServingEndpointFuture(endpoint_id)

    def _get_job_resource(self,
                          require_completed: bool) -> TrainingJobResource:
        # Fetch (and cache) the backing job resource. Only successfully
        # completed resources are cached, so in-flight jobs are re-fetched
        # on every call:
        if self._job_resource:
            return self._job_resource

        try:
            api = global_state.client.training_job_api
            resource: TrainingJobResource = api.get(self.job_id)
        except Exception as e:
            raise RuntimeError(
                f"Training job {self.job_id} was not found in the Kumo "
                f"database. Please contact Kumo for further assistance. "
            ) from e

        if not require_completed:
            return resource

        status = resource.job_status_report.status
        if not status.is_terminal:
            raise RuntimeError(
                f"Training job {self.job_id} has not yet completed. Please "
                f"create a `TrainingJob` class and await its completion "
                f"before attempting to view metrics.")

        if status != JobStatus.DONE:
            # Should never happen, the future will not resolve:
            raise ValueError(
                f"Training job {self.job_id} completed with status {status}, "
                f"and was therefore unable to produce metrics. Please "
                f"re-train the job until it successfully completes.")

        self._job_resource = resource
        return self._job_resource
314
+
315
+
316
class BatchPredictionJobResult:
    r"""Represents a completed batch prediction job.

    A :class:`BatchPredictionJobResult` object can either be obtained by
    creating a :class:`~kumoai.trainer.BatchPredictionJob` object and calling
    the :meth:`~kumoai.trainer.BatchPredictionJob.result` method to await the
    job's completion, or by directly creating the object. The former approach
    is recommended, as it includes verification that the job finished
    successfully.

    .. code-block:: python

        import kumoai

        prediction_job = kumoai.BatchPredictionJob("bp-job-...")

        # Wait for a batch prediction job's completion, and get its result:
        prediction_job_result = prediction_job.result()

        # Alternatively, create the result directly, but be sure that the job
        # is completed:
        prediction_job_result = kumoai.BatchPredictionJobResult("bp-job-...")

        # Get associated objects:
        prediction_table = prediction_job_result.prediction_table

        # Get prediction data (in-memory):
        predictions_df = prediction_job_result.predictions_df()

        # Export prediction data to any output connector:
        prediction_job_result.export(
            output_type = ...,
            output_connector = ...,
            output_table_name = ...,
        )
    """  # noqa: E501

    def __init__(self, job_id: BatchPredictionJobID) -> None:
        # The unique identifier of this batch prediction job:
        self.job_id = job_id
        # A cached completed, finalized job resource:
        self._job_resource: Optional[BatchPredictionJobResource] = None

    @property
    def id(self) -> BatchPredictionJobID:
        r"""The unique ID of this batch prediction job."""
        return self.job_id

    @property
    def tracking_url(self) -> str:
        r"""Returns a tracking URL pointing to the UI display of this batch
        prediction job.
        """
        tracking_url = self._get_job_resource(
            require_completed=False).job_status_report.tracking_url
        return _rewrite_tracking_url(tracking_url)

    def summary(self) -> BatchPredictionJobSummary:
        r"""Returns summary statistics associated with the batch prediction
        job's output, or raises an exception if summary statistics cannot be
        obtained.
        """
        return self._get_job_resource(require_completed=True).result

    @property
    def prediction_table(
            self) -> Union['PredictionTableJob', 'PredictionTable']:
        r"""Returns the prediction table associated with this prediction job,
        either as a :class:`~kumoai.pquery.PredictionTable` or a
        :class:`~kumoai.pquery.PredictionTableJob` depending on the status
        of the prediction table generation job.
        """
        # Local import avoids a circular dependency with `kumoai.pquery`:
        from kumoai.pquery import PredictionTableJob
        prediction_table_job_id = self._get_job_resource(
            require_completed=False).config.pred_table_job_id
        if prediction_table_job_id is None:
            raise RuntimeError(
                f"Unable to access the prediction table generation job ID for "
                f"job {self.job_id}. Did this job fail before generating its "
                f"prediction table, or use a custom prediction table?")
        fut = PredictionTableJob(prediction_table_job_id)
        # Resolve eagerly when the generation job is already done:
        if fut.status().status == JobStatus.DONE:
            return fut.result()
        return fut

    def export(
        self,
        output_config: OutputConfig,
        non_blocking: bool = True,
    ) -> Union['ArtifactExportJob', 'ArtifactExportResult']:
        r"""Export the prediction output or the embedding output to the
        specific output location.

        Args:
            output_config: The output configuration to be used.
            non_blocking: If ``True``, the method will return a future object
                `ArtifactExportJob` representing the export job.
                If ``False``, the method will block until the export job is
                complete and return `ArtifactExportResult`.
        """
        output_table_name = to_db_table_name(output_config.output_table_name)
        validate_output_arguments(
            output_config.output_types,
            output_config.output_connector,
            output_table_name,
        )
        # Each export handles exactly one output type:
        if output_config.output_types is not None and len(
                output_config.output_types) > 1:
            raise ValueError(
                f'Each export request can only support one output_type, '
                f'received {output_config.output_types}. If you want to make '
                'multiple output_type exports, please make separate export() '
                'calls.')
        prediction_output_config = build_prediction_output_config(
            list(output_config.output_types)[0],
            output_config.output_connector,
            output_table_name,
            output_config.output_metadata_fields,
            output_config,
        )

        api = global_state.client.artifact_export_api
        request = ArtifactExportRequest(
            job_id=self.id, prediction_output=prediction_output_config)
        job_id = api.create(request)
        if non_blocking:
            return ArtifactExportJob(job_id)
        return ArtifactExportJob(job_id).attach()

    def predictions_urls(self) -> List[str]:
        r"""Returns a list of URLs for downloading or reading the predictions.

        If Kumo is deployed as a SaaS application, the returned URLs will be
        presigned AWS S3 URLs. If Kumo is deployed as a Snowpark Container
        Services application, the returned URLs will be Snowflake stage paths
        that can be directly accessed within a Snowflake worksheet. If Kumo is
        deployed as a Databricks application, Databricks UC volume paths.
        """
        api: BatchPredictionJobAPI = (
            global_state.client.batch_prediction_job_api)
        return api.get_batch_predictions_url(self.job_id)

    def predictions_df(self) -> pd.DataFrame:
        r"""Returns a :class:`~pandas.DataFrame` object representing the
        generated predictions.

        .. warning::

            This method will load the full prediction output into memory as a
            :class:`~pandas.DataFrame` object. If you are working on a machine
            with limited resources, please use
            :meth:`~kumoai.trainer.BatchPredictionJobResult.predictions_urls`
            instead to download the data and perform analysis per-partition.
        """
        urls = self.predictions_urls()
        try:
            # Each URL points to one parquet partition; concatenate them all:
            return pd.concat(pd.read_parquet(x) for x in urls)
        except Exception as e:
            raise ValueError(
                f"Could not create a Pandas DataFrame object from data paths "
                f"{urls}. Please construct the DataFrame manually.") from e

    def embeddings_urls(self) -> List[str]:
        r"""Returns a list of URLs for downloading or reading the embeddings.

        If Kumo is deployed as a SaaS application, the returned URLs will be
        presigned AWS S3 URLs. If Kumo is deployed as a Snowpark Container
        Services application, the returned URLs will be Snowflake stage paths
        that can be directly accessed within a Snowflake worksheet. If Kumo is
        deployed as a Databricks application, Databricks UC volume paths.
        """
        api: BatchPredictionJobAPI = (
            global_state.client.batch_prediction_job_api)
        return api.get_batch_embeddings_url(self.job_id)

    def embeddings_df(self) -> pd.DataFrame:
        r"""Returns a :class:`~pandas.DataFrame` object representing the
        generated embeddings.

        .. warning::

            This method will load the full prediction output into memory as a
            :class:`~pandas.DataFrame` object. If you are working on a machine
            with limited resources, please use
            :meth:`~kumoai.trainer.BatchPredictionJobResult.embeddings_urls`
            instead to download the data and perform analysis per-partition.
        """
        urls = self.embeddings_urls()
        try:
            # Each URL points to one parquet partition; concatenate them all:
            return pd.concat(pd.read_parquet(x) for x in urls)
        except Exception as e:
            raise ValueError(
                f"Could not create a Pandas DataFrame object from data paths "
                f"{urls}. Please construct the DataFrame manually.") from e

    def _get_job_resource(
            self, require_completed: bool) -> BatchPredictionJobResource:
        # Fetch (and cache) the backing job resource. Only successfully
        # completed resources are cached, so in-flight jobs are re-fetched
        # on every call:
        if self._job_resource:
            return self._job_resource

        try:
            api = global_state.client.batch_prediction_job_api
            resource: BatchPredictionJobResource = api.get(self.job_id)
        except Exception as e:
            raise RuntimeError(
                f"Batch prediction job {self.job_id} was not found in the "
                f"Kumo database. Please contact Kumo for further assistance. "
            ) from e

        if not require_completed:
            return resource

        status = resource.job_status_report.status
        if not status.is_terminal:
            raise RuntimeError(
                f"Batch prediction job {self.job_id} has not yet completed. "
                f"Please create a `BatchPredictionJob` class and await "
                "its completion before attempting to view stats.")

        if status != JobStatus.DONE:
            # Surface any backend validation message alongside the failure:
            validation_resp = resource.job_status_report.validation_response
            validation_message = ""
            if validation_resp:
                validation_message = validation_resp.message()
                if len(validation_message) > 0:
                    validation_message = f"Details: {validation_message}"

            raise ValueError(
                f"Batch prediction job {self.job_id} completed with status "
                f"{status}, and was therefore unable to produce metrics. "
                f"{validation_message}")

        self._job_resource = resource
        return self._job_resource
548
+
549
+
550
+ # Training Job Future #########################################################
551
+
552
+
553
+ class TrainingJob(JobInterface[TrainingJobID, TrainingJobRequest,
554
+ TrainingJobResource],
555
+ KumoProgressFuture[TrainingJobResult]):
556
+ r"""Represents an in-progress training job.
557
+
558
+ A :class:`TrainingJob` object can either be created as the result of
559
+ :meth:`~kumoai.trainer.Trainer.fit` with ``non_blocking=True``, or
560
+ directly with a training job ID (*e.g.* of a job created asynchronously in
561
+ a different environment).
562
+
563
+
564
+ .. code-block:: python
565
+
566
+ import kumoai
567
+
568
+ # See `Trainer` documentation:
569
+ trainer = kumoai.Trainer(...)
570
+
571
+ # If a Trainer is `fit` in nonblocking mode, the response will be of
572
+ # type `TrainingJob`:
573
+ training_job = trainer.fit(..., non_blocking=True)
574
+
575
+ # You can also construct a `TrainingJob` from a job ID, e.g. one that
576
+ # is present in the Kumo Jobs page:
577
+ training_job = kumoai.TrainingJob("trainingjob-...")
578
+
579
+ # Get the status of the job:
580
+ print(training_job.status())
581
+
582
+ # Attach to the job, and poll progress updates:
583
+ training_job.attach()
584
+ # Training: 70%|█████████ | [300s<90s, trial=4, train_loss=1.056, val_loss=0.682, val_mae=35.709, val_mse=7906.239, val_rmse=88.917
585
+
586
+ # Cancel the job:
587
+ training_job.cancel()
588
+
589
+ # Wait for the job to complete, and return a `TrainingJobResult`:
590
+ training_job.result()
591
+
592
+ Args:
593
+ job_id: The training job ID to await completion of.
594
+ """ # noqa
595
+
596
+ @override
597
+ @staticmethod
598
+ def _api() -> TrainingJobAPI:
599
+ return global_state.client.training_job_api
600
+
601
+ def __init__(self, job_id: TrainingJobID) -> None:
602
+ self.job_id = job_id
603
+
604
+ @cached_property
605
+ def _fut(self) -> concurrent.futures.Future:
606
+ return create_future(_poll_training(self.job_id))
607
+
608
+ @override
609
+ @property
610
+ def id(self) -> TrainingJobID:
611
+ r"""The unique ID of this training job."""
612
+ return self.job_id
613
+
614
+ @override
615
+ def result(self) -> TrainingJobResult:
616
+ return self._fut.result()
617
+
618
+ @override
619
+ def future(self) -> 'concurrent.futures.Future[TrainingJobResult]':
620
+ return self._fut
621
+
622
+ @property
623
+ def tracking_url(self) -> str:
624
+ r"""Returns a tracking URL pointing to the UI that can be used to
625
+ monitor the status of an ongoing or completed job.
626
+ """
627
+ return _rewrite_tracking_url(self.status().tracking_url)
628
+
629
+ @override
630
+ def _attach_internal(
631
+ self,
632
+ interval_s: float = 20.0,
633
+ ) -> TrainingJobResult:
634
+ r"""Allows a user to attach to a running training job, and view its
635
+ progress inline.
636
+
637
+ Args:
638
+ interval_s (float): Time interval (seconds) between polls, minimum
639
+ value allowed is 4 seconds.
640
+
641
+ Example:
642
+ >>> job_future = kumoai.TrainingJob(job_id="...") # doctest: +SKIP
643
+ >>> job_future.attach() # doctest: +SKIP
644
+ Attaching to training job <id>. To track this job...
645
+ Training: 70%|█████████ | [300s<90s, trial=4, train_loss=1.056, val_loss=0.682, val_mae=35.709, val_mse=7906.239, val_rmse=88.917
646
+ """ # noqa
647
+ assert interval_s >= 4.0
648
+ print(f"Attaching to training job {self.job_id}. To track this job in "
649
+ f"the Kumo UI, please visit {self.tracking_url}. To detach from "
650
+ f"this job, please enter Ctrl+C: the job will continue to run, "
651
+ f"and you can re-attach anytime by calling the `attach()` "
652
+ f"method on the `TrainingJob` object. For example: "
653
+ f"kumoai.TrainingJob(\"{self.job_id}\").attach()")
654
+
655
+ # TODO(manan): this is not perfect, the `asyncio.sleep` in the poller
656
+ # may cause a "DONE" status to be printed for up to
657
+ # interval_s*`timeout` seconds before the future resolves.
658
+ # That's probably fine:
659
+ if self.done():
660
+ return self.result()
661
+
662
+ # For every non-training stage, just show the stage and status:
663
+ print("Waiting for job to start.")
664
+ current_status = JobStatus.NOT_STARTED
665
+ while current_status == JobStatus.NOT_STARTED:
666
+ report = self.status()
667
+ current_status = report.status
668
+ current_stage = report.event_log[-1].stage_name
669
+ time.sleep(interval_s)
670
+
671
+ prev_stage = current_stage
672
+ print(f"Current stage: {current_stage}. In progress...", end="",
673
+ flush=True)
674
+ while not self.done():
675
+ # Print status of stage:
676
+ if current_stage != prev_stage:
677
+ print(" Done.")
678
+ print(f"Current stage: {current_stage}. In progress...",
679
+ end="", flush=True)
680
+ if current_stage == "Training":
681
+ _time = self.progress().estimated_training_time
682
+ if _time and _time != 0:
683
+ break
684
+ time.sleep(interval_s)
685
+ report = self.status()
686
+ prev_stage = current_stage
687
+ current_stage = report.event_log[-1].stage_name
688
+
689
+ # We are not on Training:
690
+ if self.done():
691
+ return self.result()
692
+
693
+ # We are training: print a progress bar
694
+ progress = self.progress()
695
+ bar_format = '{desc}: {percentage:3.0f}%|{bar}|{unit} '
696
+ total = int(progress.estimated_training_time)
697
+ elapsed = int(progress.elapsed_training_time)
698
+ pbar = tqdm(desc="Training", unit="% done", bar_format=bar_format,
699
+ total=total, dynamic_ncols=True)
700
+ pbar.update(elapsed)
701
+
702
+ while not self.done():
703
+ progress = self.progress()
704
+ trial_no = min(progress.completed_trials + 1,
705
+ progress.total_trials)
706
+
707
+ if f'{max(trial_no-1, 0)}' in progress.trial_progress:
708
+ trial_metrics = progress.trial_progress[
709
+ f'{max(trial_no-1, 0)}'].metrics
710
+ elif f'{max(trial_no-2, 0)}' in progress.trial_progress:
711
+ trial_metrics = progress.trial_progress[
712
+ f'{max(trial_no-2, 0)}'].metrics
713
+ else:
714
+ trial_metrics = {}
715
+
716
+ # If we don't have metrics, wait until we do:
717
+ if len(trial_metrics) == 0:
718
+ continue
719
+
720
+ # Show all metrics:
721
+ # TODO(manan): only show tune metric, trial, epoch, and loss:
722
+ last_epoch_metrics = trial_metrics[sorted(
723
+ trial_metrics.keys())[-1]]
724
+ train_metrics_s = ", ".join([
725
+ f"{key_name}={key_val:.3f}" for key_name, key_val in
726
+ last_epoch_metrics.train_metrics.items()
727
+ ])
728
+ val_metrics_s = ", ".join([
729
+ f"{key_name}={key_val:.3f}" for key_name, key_val in
730
+ last_epoch_metrics.validation_metrics.items()
731
+ ])
732
+
733
+ # Update numbers:
734
+ delta = int(progress.elapsed_training_time - pbar.n)
735
+ total = int(progress.estimated_training_time)
736
+ pbar.update(delta)
737
+ pbar.total = total
738
+ if pbar.n > pbar.total:
739
+ pbar.total = pbar.n
740
+
741
+ # NOTE we use `unit` here as a hack, instead of `set_postfix`,
742
+ # since `tqdm` defaults to adding a comma before the postfix
743
+ # (https://github.com/tqdm/tqdm/issues/712)
744
+ pbar.unit = (f"[{pbar.n}s<{pbar.total-pbar.n}s, trial={trial_no}, "
745
+ f"{train_metrics_s}, {val_metrics_s}]")
746
+ pbar.refresh()
747
+ time.sleep(interval_s)
748
+ pbar.update(pbar.total - pbar.n)
749
+ pbar.close()
750
+
751
+ # Future is done:
752
+ return self.result()
753
+
754
+ def progress(self) -> AutoTrainerProgress:
755
+ r"""Returns the progress of an ongoing or completed training job."""
756
+ return self._api().get_progress(self.job_id)
757
+
758
+ @override
759
+ def status(self) -> JobStatusReport:
760
+ r"""Returns the status of a running training job."""
761
+ return _get_training_status(self.job_id)
762
+
763
+ def cancel(self) -> bool:
764
+ r"""Cancels a running training job, and returns ``True`` if
765
+ cancellation succeeded.
766
+
767
+ Example:
768
+ >>> job_future = kumoai.TrainingJob(job_id="...") # doctest: +SKIP
769
+ >>> job_future.cancel() # doctest: +SKIP
770
+ """ # noqa
771
+ return self._api().cancel(self.job_id).is_cancelled
772
+
773
+ @override
774
+ def load_config(self) -> TrainingJobRequest:
775
+ r"""Load the full configuration for this training job.
776
+
777
+ Returns:
778
+ TrainingJobRequest: Complete configuration including model_plan,
779
+ pquery_id, graph_snapshot_id, train_table_job_id, etc.
780
+ """
781
+ return self._api().get_config(self.job_id)
782
+
783
+
784
def _get_training_status(job_id: str) -> JobStatusReport:
    # Fetch the full job resource via the global client and return only its
    # embedded status report.
    resource: TrainingJobResource = (
        global_state.client.training_job_api.get(job_id))
    return resource.job_status_report
788
+
789
+
790
async def _poll_training(job_id: str) -> TrainingJobResult:
    # Poll at a fixed cadence until the job reaches a terminal state.
    # TODO(manan): make asynchronous natively with aiohttp:
    while True:
        status = _get_training_status(job_id).status
        if status.is_terminal:
            break
        await asyncio.sleep(10)

    # TODO(manan, siyang): improve
    if status != JobStatus.DONE:
        # Surface any backend validation message alongside the failure:
        report = global_state.client.training_job_api.get(
            job_id).job_status_report
        validation = report.validation_response

        detail = ""
        if validation:
            detail = validation.message()
            if detail:
                detail = f"Details: {detail}"

        raise RuntimeError(f"Training job {job_id} completed with job status "
                           f"{status}. {detail}")

    # TODO(manan): improve
    return TrainingJobResult(job_id=job_id)
814
+
815
+
816
+ # Batch Prediction Job Future #################################################
817
+
818
+
819
class BatchPredictionJob(JobInterface[BatchPredictionJobID,
                                      BatchPredictionRequest,
                                      BatchPredictionJobResource],
                         KumoProgressFuture[BatchPredictionJobResult]):
    r"""Represents an in-progress batch prediction job.

    A :class:`BatchPredictionJob` object can either be created as the
    result of :meth:`~kumoai.trainer.Trainer.predict` with
    ``non_blocking=True``, or directly with a batch prediction job ID (*e.g.*
    of a job created asynchronously in a different environment).

    .. code-block:: python

        import kumoai

        # See `Trainer` documentation:
        trainer = kumoai.Trainer(...)

        # If a Trainer `predict` is called in nonblocking mode, the response
        # will be of type `BatchPredictionJob`:
        prediction_job = trainer.predict(..., non_blocking=True)

        # You can also construct a `BatchPredictionJob` from a job ID, e.g. one
        # that is present in the Kumo Jobs page:
        prediction_job = kumoai.BatchPredictionJob("bp-job-...")

        # Get the status of the job:
        print(prediction_job.status())

        # Attach to the job, and poll progress updates:
        prediction_job.attach()
        # Attaching to batch prediction job <id>. To track this job...
        # Predicting (job_id=..., start=..., elapsed=..., status=...). Stage: ...

        # Cancel the job:
        prediction_job.cancel()

        # Wait for the job to complete, and return a `BatchPredictionJobResult`:
        prediction_job.result()

    Args:
        job_id: The batch prediction job ID to await completion of.
    """  # noqa

    @override
    @staticmethod
    def _api() -> BatchPredictionJobAPI:
        return global_state.client.batch_prediction_job_api

    def __init__(self, job_id: BatchPredictionJobID) -> None:
        self.job_id = job_id

    @cached_property
    def _fut(self) -> concurrent.futures.Future:
        # Lazily create the polling future so that merely constructing this
        # object does not spawn background work:
        return create_future(_poll_batch_prediction(self.job_id))

    @override
    @property
    def id(self) -> BatchPredictionJobID:
        r"""The unique ID of this batch prediction job."""
        return self.job_id

    @override
    def result(self) -> BatchPredictionJobResult:
        return self._fut.result()

    @override
    def future(self) -> 'concurrent.futures.Future[BatchPredictionJobResult]':
        return self._fut

    @property
    def tracking_url(self) -> str:
        r"""Returns a tracking URL pointing to the UI that can be used to
        monitor the status of an ongoing or completed job.
        """
        return _rewrite_tracking_url(self.status().tracking_url)

    @override
    def _attach_internal(
        self,
        interval_s: float = 20.0,
    ) -> BatchPredictionJobResult:
        r"""Allows a user to attach to a running batch prediction job, and view
        its progress inline.

        Args:
            interval_s (float): Time interval (seconds) between polls, minimum
                value allowed is 4 seconds.

        """
        assert interval_s >= 4.0
        print(f"Attaching to batch prediction job {self.job_id}. To track "
              f"this job in the Kumo UI, please visit {self.tracking_url}. To "
              f"detach from this job, please enter Ctrl+C (the job will "
              f"continue to run, and you can re-attach anytime).")
        # TODO(manan): this is not perfect, the `asyncio.sleep` in the poller
        # may cause a "DONE" status to be printed for up to
        # interval_s*`timeout` seconds before the future resolves.
        # That's probably fine:
        if self.done():
            return self.result()

        # For every non-prediction stage, just show the stage and status:
        print("Waiting for job to start.")
        current_status = JobStatus.NOT_STARTED
        while current_status == JobStatus.NOT_STARTED:
            report = self.status()
            current_status = report.status
            current_stage = report.event_log[-1].stage_name
            time.sleep(interval_s)

        prev_stage = current_stage
        print(f"Current stage: {current_stage}. In progress...", end="",
              flush=True)
        while not self.done():
            # Print status of stage:
            if current_stage != prev_stage:
                print(" Done.")
                print(f"Current stage: {current_stage}. In progress...",
                      end="", flush=True)
            if current_stage == "Predicting":
                _time = self.progress().estimated_prediction_time
                if _time and _time != 0:
                    break

            time.sleep(interval_s)
            report = self.status()
            prev_stage = current_stage
            current_stage = report.event_log[-1].stage_name

        # We are not on Batch Prediction:
        if self.done():
            return self.result()

        # We are predicting: print a progress bar (percentage-based; the bar
        # total is 100 and `pbar.n` tracks the completed percentage):
        bar_format = '{desc}: {percentage:3.0f}%|{bar} '
        pbar = tqdm(desc="Predicting", unit="% done", bar_format=bar_format,
                    total=100, dynamic_ncols=True)

        while not self.done():
            progress = self.progress()
            # Progress may be momentarily unavailable (or report zero total
            # iterations) early in the stage; wait for a usable snapshot
            # instead of dividing by zero below:
            if progress is None or progress.total_iterations == 0:
                time.sleep(interval_s)
                continue
            # Advance the bar to the latest completed fraction, in percent:
            completed_pct = (progress.completed_iterations /
                             progress.total_iterations * 100)
            pbar.update(completed_pct - pbar.n)
            pbar.refresh()
            time.sleep(interval_s)

        # Snap the bar to 100% once the job resolves. NOTE the previous
        # implementation added a 0-1 fraction to a 0-100 bar (so the bar
        # never completed) and raised `NameError` if no progress snapshot
        # was ever observed before the loop exited:
        pbar.update(pbar.total - pbar.n)
        pbar.close()

        # Future is done:
        return self.result()

    def progress(self) -> PredictionProgress:
        r"""Returns the progress of an ongoing or completed batch prediction
        job.
        """
        return self._api().get_progress(self.job_id)

    @override
    def status(self) -> JobStatusReport:
        r"""Returns the status of a running batch prediction job."""
        return _get_batch_prediction_status(self.job_id)

    def cancel(self) -> bool:
        r"""Cancels a running batch prediction job, and returns ``True`` if
        cancellation succeeded.
        """
        return self._api().cancel(self.job_id).is_cancelled

    @override
    def load_config(self) -> BatchPredictionRequest:
        r"""Load the full configuration for this batch prediction job.

        Returns:
            BatchPredictionRequest: Complete
            configuration including predict_options,
            outputs, model_training_job_id, etc.
        """
        return self._api().get_config(self.job_id)
1006
+
1007
+
1008
def _get_batch_prediction_job(job_id: str) -> BatchPredictionJobResource:
    # Resolve the full job resource through the global API client.
    return global_state.client.batch_prediction_job_api.get(job_id)
1011
+
1012
+
1013
def _get_batch_prediction_status(job_id: str) -> JobStatusReport:
    # Only the status-report portion of the job resource is needed here:
    resource = _get_batch_prediction_job(job_id)
    return resource.job_status_report
1017
+
1018
+
1019
async def _poll_batch_prediction(job_id: str) -> BatchPredictionJobResult:
    # Poll at a fixed cadence until the job reaches a terminal state.
    # TODO(manan): make asynchronous natively with aiohttp:
    while True:
        job_resource = _get_batch_prediction_job(job_id)
        if job_resource.job_status_report.status.is_terminal:
            break
        await asyncio.sleep(10)

    status = job_resource.job_status_report.status
    # TODO(manan, siyang): improve
    if status != JobStatus.DONE:
        # Surface any backend validation message alongside the failure:
        validation = job_resource.job_status_report.validation_response
        detail = ""
        if validation:
            detail = validation.message()
            if detail:
                detail = f"Details: {detail}"

        raise ValueError(
            f"Batch prediction job {job_id} completed with status "
            f"{status}, and was therefore unable to produce metrics. "
            f"{detail}")

    # TODO(manan): improve
    return BatchPredictionJobResult(job_id=job_id)
1044
+
1045
+
1046
+ # Baseline Job Future #################################################
1047
+
1048
+
1049
class BaselineJob(JobInterface[BaselineJobID, BaselineJobRequest,
                               BaselineJobResource],
                  KumoProgressFuture[BaselineJobResult]):
    r"""Represents an in-progress baseline job.

    A :class:`BaselineJob` object can either be created as the result of
    :meth:`~kumoai.trainer.BaselineTrainer.run` with ``non_blocking=True``, or
    directly with a baseline job ID (*e.g.* of a job created asynchronously in
    a different environment).

    Args:
        job_id: The baseline job ID to await completion of.

    Example:
        >>> import kumoai # doctest: +SKIP
        >>> id = "some_baseline_job_id"
        >>> job_future = kumoai.BaselineJob(id) # doctest: +SKIP
        >>> job_future.attach() # doctest: +SKIP
        Attaching to baseline job <id>. To track this job...
    """  # noqa

    @override
    @staticmethod
    def _api() -> BaselineJobAPI:
        return global_state.client.baseline_job_api

    def __init__(self, job_id: BaselineJobID) -> None:
        self.job_id = job_id

    @cached_property
    def _fut(self) -> concurrent.futures.Future:
        # Lazily create the polling future so that merely constructing this
        # object does not spawn background work:
        return create_future(_poll_baseline(self.job_id))

    @override
    @property
    def id(self) -> BaselineJobID:
        r"""The unique ID of this baseline job."""
        return self.job_id

    @override
    def result(self) -> BaselineJobResult:
        return self._fut.result()

    @override
    def future(self) -> 'concurrent.futures.Future[BaselineJobResult]':
        return self._fut

    @property
    def tracking_url(self) -> str:
        r"""Returns a tracking URL pointing to the UI that can be used to
        monitor the status of an ongoing or completed job.
        """
        # Baseline jobs currently have no UI tracking page:
        return ""

    @override
    def _attach_internal(
        self,
        interval_s: float = 20.0,
    ) -> BaselineJobResult:
        r"""Allows a user to attach to a running baseline job, and view its
        progress inline.

        Args:
            interval_s (float): Time interval (seconds) between polls, minimum
                value allowed is 4 seconds.

        Example:
            >>> job_future = kumoai.BaselineJob(job_id="...") # doctest: +SKIP
            >>> job_future.attach() # doctest: +SKIP
            Attaching to baseline job <id>. To track this job...
        """  # noqa
        assert interval_s >= 4.0
        # NOTE the space after the first sentence: without it, the banner
        # previously rendered as "...job <id>.To detach...":
        print(f"Attaching to baseline job {self.job_id}. "
              f"To detach from "
              f"this job, please enter Ctrl+C (the job will continue to run, "
              f"and you can re-attach anytime).")

        while not self.done():
            report = self.status()
            status = report.status
            latest_event = report.event_log[-1]
            stage = latest_event.stage_name
            detail = ", " + latest_event.detail if latest_event.detail else ""

            start = report.start_time
            now = datetime.now(timezone.utc)
            # NOTE a comma now separates `job_id` and `start` (previously the
            # two fields ran together in the output):
            print(f"Baseline job (job_id={self.job_id}, start={start}, "
                  f"elapsed={now-start}, status={status}). "
                  f"Stage: {stage}{detail}")
            time.sleep(interval_s)

        # Future is done:
        return self.result()

    @override
    def status(self) -> JobStatusReport:
        r"""Returns the status of a running baseline job."""
        return _get_baseline_status(self.job_id)

    @override
    def load_config(self) -> BaselineJobRequest:
        r"""Load the full configuration for this baseline job.

        Returns:
            BaselineJobRequest: Complete configuration including metrics,
            pquery_id, graph_snapshot_id, etc.
        """
        return self._api().get_config(self.job_id)
1156
+
1157
+
1158
def _get_baseline_status(job_id: str) -> JobStatusReport:
    # Fetch the job resource via the global client and return its embedded
    # status report.
    resource: BaselineJobResource = (
        global_state.client.baseline_job_api.get(job_id))
    return resource.job_status_report
1162
+
1163
+
1164
async def _poll_baseline(job_id: str) -> BaselineJobResult:
    # Poll at a fixed cadence until the job reaches a terminal state:
    while not (status := _get_baseline_status(job_id).status).is_terminal:
        await asyncio.sleep(10)

    if status != JobStatus.DONE:
        raise RuntimeError(
            f"Baseline job {job_id} failed with job status {status}.")

    return BaselineJobResult(job_id=job_id)
1175
+
1176
+
1177
+ def _rewrite_tracking_url(tracking_url: str) -> str:
1178
+ r"""Rewrites tracking URL to account for deployment subdomains."""
1179
+ # TODO(manan): improve...
1180
+ if 'http' not in tracking_url:
1181
+ return tracking_url
1182
+ parsed_base = urlparse(global_state.client._url)
1183
+ parsed_tracking = urlparse(tracking_url)
1184
+ tracking_url = urlunparse((
1185
+ parsed_base.scheme,
1186
+ parsed_base.netloc,
1187
+ parsed_tracking.path,
1188
+ parsed_tracking.params,
1189
+ parsed_tracking.query,
1190
+ parsed_tracking.fragment,
1191
+ ))
1192
+ return tracking_url