kumoai 2.14.0.dev202601011731__cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kumoai might be problematic. Click here for more details.
- kumoai/__init__.py +300 -0
- kumoai/_logging.py +29 -0
- kumoai/_singleton.py +25 -0
- kumoai/_version.py +1 -0
- kumoai/artifact_export/__init__.py +9 -0
- kumoai/artifact_export/config.py +209 -0
- kumoai/artifact_export/job.py +108 -0
- kumoai/client/__init__.py +5 -0
- kumoai/client/client.py +223 -0
- kumoai/client/connector.py +110 -0
- kumoai/client/endpoints.py +150 -0
- kumoai/client/graph.py +120 -0
- kumoai/client/jobs.py +471 -0
- kumoai/client/online.py +78 -0
- kumoai/client/pquery.py +207 -0
- kumoai/client/rfm.py +112 -0
- kumoai/client/source_table.py +53 -0
- kumoai/client/table.py +101 -0
- kumoai/client/utils.py +130 -0
- kumoai/codegen/__init__.py +19 -0
- kumoai/codegen/cli.py +100 -0
- kumoai/codegen/context.py +16 -0
- kumoai/codegen/edits.py +473 -0
- kumoai/codegen/exceptions.py +10 -0
- kumoai/codegen/generate.py +222 -0
- kumoai/codegen/handlers/__init__.py +4 -0
- kumoai/codegen/handlers/connector.py +118 -0
- kumoai/codegen/handlers/graph.py +71 -0
- kumoai/codegen/handlers/pquery.py +62 -0
- kumoai/codegen/handlers/table.py +109 -0
- kumoai/codegen/handlers/utils.py +42 -0
- kumoai/codegen/identity.py +114 -0
- kumoai/codegen/loader.py +93 -0
- kumoai/codegen/naming.py +94 -0
- kumoai/codegen/registry.py +121 -0
- kumoai/connector/__init__.py +31 -0
- kumoai/connector/base.py +153 -0
- kumoai/connector/bigquery_connector.py +200 -0
- kumoai/connector/databricks_connector.py +213 -0
- kumoai/connector/file_upload_connector.py +189 -0
- kumoai/connector/glue_connector.py +150 -0
- kumoai/connector/s3_connector.py +278 -0
- kumoai/connector/snowflake_connector.py +252 -0
- kumoai/connector/source_table.py +471 -0
- kumoai/connector/utils.py +1796 -0
- kumoai/databricks.py +14 -0
- kumoai/encoder/__init__.py +4 -0
- kumoai/exceptions.py +26 -0
- kumoai/experimental/__init__.py +0 -0
- kumoai/experimental/rfm/__init__.py +210 -0
- kumoai/experimental/rfm/authenticate.py +432 -0
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/backend/local/graph_store.py +297 -0
- kumoai/experimental/rfm/backend/local/sampler.py +312 -0
- kumoai/experimental/rfm/backend/local/table.py +113 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
- kumoai/experimental/rfm/backend/snow/table.py +242 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
- kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
- kumoai/experimental/rfm/base/__init__.py +30 -0
- kumoai/experimental/rfm/base/column.py +152 -0
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +761 -0
- kumoai/experimental/rfm/base/source.py +19 -0
- kumoai/experimental/rfm/base/sql_sampler.py +143 -0
- kumoai/experimental/rfm/base/table.py +736 -0
- kumoai/experimental/rfm/graph.py +1237 -0
- kumoai/experimental/rfm/infer/__init__.py +19 -0
- kumoai/experimental/rfm/infer/categorical.py +40 -0
- kumoai/experimental/rfm/infer/dtype.py +82 -0
- kumoai/experimental/rfm/infer/id.py +46 -0
- kumoai/experimental/rfm/infer/multicategorical.py +48 -0
- kumoai/experimental/rfm/infer/pkey.py +128 -0
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +61 -0
- kumoai/experimental/rfm/infer/timestamp.py +41 -0
- kumoai/experimental/rfm/pquery/__init__.py +7 -0
- kumoai/experimental/rfm/pquery/executor.py +102 -0
- kumoai/experimental/rfm/pquery/pandas_executor.py +530 -0
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +1184 -0
- kumoai/experimental/rfm/sagemaker.py +138 -0
- kumoai/experimental/rfm/task_table.py +231 -0
- kumoai/formatting.py +30 -0
- kumoai/futures.py +99 -0
- kumoai/graph/__init__.py +12 -0
- kumoai/graph/column.py +106 -0
- kumoai/graph/graph.py +948 -0
- kumoai/graph/table.py +838 -0
- kumoai/jobs.py +80 -0
- kumoai/kumolib.cpython-310-x86_64-linux-gnu.so +0 -0
- kumoai/mixin.py +28 -0
- kumoai/pquery/__init__.py +25 -0
- kumoai/pquery/prediction_table.py +287 -0
- kumoai/pquery/predictive_query.py +641 -0
- kumoai/pquery/training_table.py +424 -0
- kumoai/spcs.py +121 -0
- kumoai/testing/__init__.py +8 -0
- kumoai/testing/decorators.py +57 -0
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/__init__.py +42 -0
- kumoai/trainer/baseline_trainer.py +93 -0
- kumoai/trainer/config.py +2 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/trainer/job.py +1192 -0
- kumoai/trainer/online_serving.py +258 -0
- kumoai/trainer/trainer.py +475 -0
- kumoai/trainer/util.py +103 -0
- kumoai/utils/__init__.py +11 -0
- kumoai/utils/datasets.py +83 -0
- kumoai/utils/display.py +51 -0
- kumoai/utils/forecasting.py +209 -0
- kumoai/utils/progress_logger.py +343 -0
- kumoai/utils/sql.py +3 -0
- kumoai-2.14.0.dev202601011731.dist-info/METADATA +71 -0
- kumoai-2.14.0.dev202601011731.dist-info/RECORD +122 -0
- kumoai-2.14.0.dev202601011731.dist-info/WHEEL +6 -0
- kumoai-2.14.0.dev202601011731.dist-info/licenses/LICENSE +9 -0
- kumoai-2.14.0.dev202601011731.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,475 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import (
|
|
3
|
+
List,
|
|
4
|
+
Literal,
|
|
5
|
+
Mapping,
|
|
6
|
+
Optional,
|
|
7
|
+
Set,
|
|
8
|
+
Tuple,
|
|
9
|
+
Union,
|
|
10
|
+
overload,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
from kumoapi.jobs import (
|
|
14
|
+
BatchPredictionOptions,
|
|
15
|
+
BatchPredictionRequest,
|
|
16
|
+
JobStatus,
|
|
17
|
+
MetadataField,
|
|
18
|
+
PredictionOutputConfig,
|
|
19
|
+
TrainingJobRequest,
|
|
20
|
+
TrainingJobResource,
|
|
21
|
+
)
|
|
22
|
+
from kumoapi.model_plan import ModelPlan
|
|
23
|
+
|
|
24
|
+
from kumoai import global_state
|
|
25
|
+
from kumoai.artifact_export.config import OutputConfig
|
|
26
|
+
from kumoai.client.jobs import (
|
|
27
|
+
GeneratePredictionTableJobID,
|
|
28
|
+
TrainingJobAPI,
|
|
29
|
+
TrainingJobID,
|
|
30
|
+
)
|
|
31
|
+
from kumoai.connector.base import Connector
|
|
32
|
+
from kumoai.connector.s3_connector import S3URI
|
|
33
|
+
from kumoai.databricks import to_db_table_name
|
|
34
|
+
from kumoai.graph import Graph
|
|
35
|
+
from kumoai.pquery.prediction_table import PredictionTable, PredictionTableJob
|
|
36
|
+
from kumoai.pquery.training_table import TrainingTable, TrainingTableJob
|
|
37
|
+
from kumoai.trainer.job import (
|
|
38
|
+
BatchPredictionJob,
|
|
39
|
+
BatchPredictionJobResult,
|
|
40
|
+
TrainingJob,
|
|
41
|
+
TrainingJobResult,
|
|
42
|
+
)
|
|
43
|
+
from kumoai.trainer.util import (
|
|
44
|
+
build_prediction_output_config,
|
|
45
|
+
validate_output_arguments,
|
|
46
|
+
)
|
|
47
|
+
|
|
48
|
+
logger = logging.getLogger(__name__)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class Trainer:
    r"""A trainer supports creating a Kumo machine learning model on a
    :class:`~kumoai.pquery.PredictiveQuery`. It is primarily oriented around
    two methods: :meth:`~kumoai.trainer.Trainer.fit`, which accepts a
    :class:`~kumoai.graph.Graph` and :class:`~kumoai.pquery.TrainingTable` and
    produces a :class:`~kumoai.trainer.TrainingJobResult`, and
    :meth:`~kumoai.trainer.Trainer.predict`, which accepts a
    :class:`~kumoai.graph.Graph` and :class:`~kumoai.pquery.PredictionTable`
    and produces a :class:`~kumoai.trainer.BatchPredictionJobResult`.

    A :class:`~kumoai.trainer.Trainer` can also be loaded from a training job,
    with :meth:`~kumoai.trainer.Trainer.load`.

    .. code-block:: python

        import kumoai

        # See `Graph` and `PredictiveQuery` documentation:
        graph = kumoai.Graph(...)
        pquery = kumoai.PredictiveQuery(graph=graph, query=...)

        # Create a `Trainer` object, using a suggested model plan given the
        # predictive query:
        model_plan = pquery.suggest_model_plan()
        trainer = kumoai.Trainer(model_plan)

        # Create a training table from the predictive query:
        training_table = pquery.generate_training_table()

        # Fit a model:
        training_job = trainer.fit(
            graph=graph,
            train_table=training_table,
        )

        # Create a prediction table from the predictive query:
        prediction_table = pquery.generate_prediction_table()

        # Generate predictions:
        prediction_job = trainer.predict(
            graph=graph,
            prediction_table=prediction_table,
            # other arguments here...
        )

        # Load a trained query to generate predictions:
        pquery = kumoai.PredictiveQuery.load_from_training_job("trainingjob-...")
        trainer = kumoai.Trainer.load("trainingjob-...")
        prediction_job = trainer.predict(
            pquery.graph,
            pquery.generate_prediction_table(),
        )

    Args:
        model_plan: A model plan that the trainer should follow when fitting a
            Kumo model to a predictive query. This model plan can either be
            generated from a predictive query, with
            :meth:`~kumoai.pquery.PredictiveQuery.suggest_model_plan`, or can
            be manually specified.
    """  # noqa: E501

    def __init__(self, model_plan: ModelPlan) -> None:
        self._model_plan: Optional[ModelPlan] = model_plan

        # Cached from backend:
        self._training_job_id: Optional[TrainingJobID] = None
        # ID of the most recent batch prediction job launched by `predict`;
        # declared here so the attribute exists before the first prediction:
        self._batch_prediction_job_id = None

    @property
    def model_plan(self) -> Optional[ModelPlan]:
        r"""The model plan this trainer passes to the training job request
        in :meth:`fit`.
        """
        return self._model_plan

    @model_plan.setter
    def model_plan(self, model_plan: ModelPlan) -> None:
        self._model_plan = model_plan

    # Metadata ################################################################

    @property
    def is_trained(self) -> bool:
        r"""Returns ``True`` if this trainer instance has successfully been
        trained (and is therefore ready for prediction); ``False`` otherwise.

        ``False`` is also returned when no training job ID is cached on this
        trainer, or when fetching the job status fails (the failure is logged
        rather than raised).
        """
        if not self._training_job_id:
            return False
        try:
            api = global_state.client.training_job_api
            res: TrainingJobResource = api.get(self._training_job_id)
        except Exception as e:  # noqa
            # Best-effort status check: log the failure and report
            # "not trained" instead of propagating the error.
            logger.exception(
                "Failed to fetch training status for training job with ID %s",
                self._training_job_id, exc_info=e)
            return False
        return res.job_status_report.status == JobStatus.DONE

    # Fit / predict ###########################################################

    # NOTE the overloads mirror every keyword accepted by the implementation
    # (including `custom_tags` and `warm_start_job_id`) so that type checkers
    # accept all calls that are valid at runtime.

    @overload
    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        custom_tags: Mapping[str, str] = ...,
        warm_start_job_id: Optional[TrainingJobID] = ...,
    ) -> TrainingJobResult:
        pass

    @overload
    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        non_blocking: Literal[False],
        custom_tags: Mapping[str, str] = ...,
        warm_start_job_id: Optional[TrainingJobID] = ...,
    ) -> TrainingJobResult:
        pass

    @overload
    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        non_blocking: Literal[True],
        custom_tags: Mapping[str, str] = ...,
        warm_start_job_id: Optional[TrainingJobID] = ...,
    ) -> TrainingJob:
        pass

    @overload
    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        non_blocking: bool,
        custom_tags: Mapping[str, str] = ...,
        warm_start_job_id: Optional[TrainingJobID] = ...,
    ) -> Union[TrainingJob, TrainingJobResult]:
        pass

    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        non_blocking: bool = False,
        custom_tags: Mapping[str, str] = {},
        warm_start_job_id: Optional[TrainingJobID] = None,
    ) -> Union[TrainingJob, TrainingJobResult]:
        r"""Fits a model to the specified graph and training table, with the
        strategy defined by this :class:`Trainer`'s :obj:`model_plan`.

        The ID of the launched training job is cached on this trainer, so a
        subsequent :meth:`predict` call uses it by default.

        Args:
            graph: The :class:`~kumoai.graph.Graph` object that represents the
                tables and relationships that Kumo will learn from.
            train_table: The :class:`~kumoai.pquery.TrainingTable`, or
                in-progress :class:`~kumoai.pquery.TrainingTableJob`, that
                represents the training data produced by a
                :class:`~kumoai.pquery.PredictiveQuery` on :obj:`graph`.
            non_blocking: Whether this operation should return immediately
                after launching the training job, or await completion of the
                training job.
            custom_tags: Additional, customer defined k-v tags to be associated
                with the job to be launched. Job tags are useful for grouping
                and searching jobs.
            warm_start_job_id: Optional job ID of a completed training job to
                warm start from. Initializes the new model with the best
                weights from the specified job, using its model
                architecture, column processing, and neighbor sampling
                configurations.

        Returns:
            Union[TrainingJobResult, TrainingJob]:
                If ``non_blocking=False``, returns a training job result
                object. If ``non_blocking=True``, returns a training job
                future object.
        """
        # TODO(manan, siyang): remove soon:
        job_id = train_table.job_id
        assert job_id is not None

        train_table_job_api = global_state.client.generate_train_table_job_api
        pq_id = train_table_job_api.get(job_id).config.pquery_id
        assert pq_id is not None

        # Only a materialized `TrainingTable` may carry a custom train table:
        custom_table = None
        if isinstance(train_table, TrainingTable):
            custom_table = train_table._custom_train_table

        # NOTE the backend implementation currently handles sequentialization
        # between a training table future and a training job; that is, if the
        # training table future is still executing, the backend will wait on
        # the job ID completion before executing a training job. This preserves
        # semantics for both futures, ensures that Kumo works as expected if
        # used only via REST API, and allows us to avoid chaining callbacks
        # in an ugly way here:
        api = global_state.client.training_job_api
        self._training_job_id = api.create(
            TrainingJobRequest(
                # Copy into a fresh dict; the (shared) default mapping is
                # therefore never handed out or mutated:
                dict(custom_tags),
                pquery_id=pq_id,
                model_plan=self._model_plan,
                graph_snapshot_id=graph.snapshot(non_blocking=non_blocking),
                train_table_job_id=job_id,
                custom_train_table=custom_table,
                warm_start_job_id=warm_start_job_id,
            ))

        out = TrainingJob(job_id=self._training_job_id)
        if non_blocking:
            return out
        return out.attach()

    def predict(
        self,
        graph: Graph,
        prediction_table: Union[PredictionTable, PredictionTableJob],
        output_types: Optional[Set[str]] = None,
        output_connector: Optional[Connector] = None,
        output_table_name: Optional[Union[str, Tuple]] = None,
        output_metadata_fields: Optional[List[MetadataField]] = None,
        output_config: Optional[OutputConfig] = None,
        training_job_id: Optional[TrainingJobID] = None,
        binary_classification_threshold: Optional[float] = None,
        num_classes_to_return: Optional[int] = None,
        num_workers: int = 1,
        non_blocking: bool = False,
        custom_tags: Mapping[str, str] = {},
    ) -> Union[BatchPredictionJob, BatchPredictionJobResult]:
        """Using the trained model specified by :obj:`training_job_id` (or
        the model corresponding to the last invocation of
        :meth:`~kumoai.trainer.Trainer.fit`, if not present), predicts the
        future values of the entities in :obj:`prediction_table` leveraging
        current information from :obj:`graph`.

        Args:
            graph: The :class:`~kumoai.graph.Graph` object that represents the
                tables and relationships that Kumo will use to make
                predictions.
            prediction_table: The :class:`~kumoai.pquery.PredictionTable`, or
                in-progress :class:`~kumoai.pquery.PredictionTableJob`, that
                represents the prediction data produced by a
                :class:`~kumoai.pquery.PredictiveQuery` on :obj:`graph`. This
                table may also be custom-generated.
            output_config: Output configuration defining properties of the
                generated batch prediction outputs. This argument is required;
                it may also be passed as a dictionary of
                :class:`~kumoai.artifact_export.config.OutputConfig` keyword
                arguments. This includes:

                - output_types: The types of outputs that should be produced by
                  the prediction job. Can include either ``'predictions'``,
                  ``'embeddings'``, or both.
                - output_connector: The output data source that Kumo should
                  write batch predictions to, if it is None, produce
                  local download output only.
                - output_table_name: The name of the table in the output data
                  source that Kumo should write batch predictions to. In the
                  case of a Databricks connector, this should be a tuple of
                  two strings, the schema name and the output prediction
                  table name.
                - output_metadata_fields: Any additional metadata fields to
                  include as new columns in the produced ``'predictions'``
                  output. Currently, allowed options are ``JOB_TIMESTAMP``
                  and ``ANCHOR_TIMESTAMP``.
                - connector_specific_config: Custom connector specific output
                  configuration, such as whether to append or overwrite
                  existing tables.

            output_types: *(Deprecated)* The types of outputs that should be
                produced by the prediction job. Can include either
                ``'predictions'``, ``'embeddings'``, or both. Use
                :obj:`output_config` instead.
            output_connector: *(Deprecated)* The output data source that Kumo
                should write batch predictions to, if it is None, produce local
                download output only. Use :obj:`output_config` instead.
            output_table_name: *(Deprecated)* The name of the table in the
                output data source that Kumo should write batch predictions to.
                In the case of a Databricks connector, this should be a tuple
                of two strings: the schema name and the output prediction
                table name. Use :obj:`output_config` instead.
            output_metadata_fields: *(Deprecated)* Any additional metadata
                fields to include as new columns in the produced
                ``'predictions'`` output. Currently, allowed options are
                ``JOB_TIMESTAMP`` and ``ANCHOR_TIMESTAMP``. Use
                :obj:`output_config` instead.
            training_job_id: The job ID of the training job whose model will be
                used for making predictions.
            binary_classification_threshold: If this model corresponds to
                a binary classification task, the threshold for which higher
                values correspond to ``1``, and lower values correspond to
                ``0``.
            num_classes_to_return: If this model corresponds to a ranking task,
                the number of classes to return in the prediction output.
            num_workers: Number of workers to use when generating batch
                predictions. When set to a value greater than 1, the prediction
                table is partitioned into smaller parts and processed in
                parallel. The default is 1, which implies sequential inference
                over the prediction table.
            non_blocking: Whether this operation should return immediately
                after launching the batch prediction job, or await
                completion of the batch prediction job.
            custom_tags: Additional, customer defined k-v tags to be associated
                with the job to be launched. Job tags are useful for grouping
                and searching jobs.

        Returns:
            Union[BatchPredictionJob, BatchPredictionJobResult]:
                If ``non_blocking=False``, returns a batch prediction job
                result object. If ``non_blocking=True``, returns a batch
                prediction job future object.

        Raises:
            ValueError: If any of the deprecated output arguments is passed
                directly, or if neither :obj:`training_job_id` nor a
                previously fitted/loaded training job is available.
            RuntimeError: If the backend rejects the batch prediction request.
        """
        if (output_types is not None or output_connector is not None
                or output_table_name is not None
                or output_metadata_fields is not None):
            raise ValueError(
                'Specifying output_types, output_connector, '
                'output_metadata_fields '
                'and output_table_name as direct inputs to predict() is '
                'deprecated. Please use output_config to specify these '
                'parameters.')
        assert output_config is not None
        # Be able to pass output_config as a dictionary
        if isinstance(output_config, dict):
            output_config = OutputConfig(**output_config)
        output_table_name = to_db_table_name(output_config.output_table_name)
        validate_output_arguments(
            output_config.output_types,
            output_config.output_connector,
            output_table_name,
        )

        # Create outputs:
        outputs: List[PredictionOutputConfig] = []
        for output_type in output_config.output_types:
            if output_config.output_connector is None:
                # Predictions are generated to the Kumo dataplane, and can
                # only be exported via the UI for now:
                pass
            else:
                outputs.append(
                    build_prediction_output_config(
                        output_type,
                        output_config.output_connector,
                        output_table_name,
                        output_config.output_metadata_fields,
                        output_config,
                    ))

        # An explicitly passed job ID takes precedence over the cached one:
        training_job_id = training_job_id or self._training_job_id
        if training_job_id is None:
            raise ValueError(
                "Cannot run batch prediction without a specified or saved "
                "training job ID. Please either call `fit(...)` or pass a "
                "job ID of a completed training job to proceed.")

        pred_table_job_id: Optional[GeneratePredictionTableJobID] = \
            prediction_table.job_id
        pred_table_data_path = None
        if pred_table_job_id is None:
            # Without a generating job, the prediction table has to be
            # referenced through its data location instead:
            assert isinstance(prediction_table, PredictionTable)
            if isinstance(prediction_table.table_data_uri, S3URI):
                pred_table_data_path = prediction_table.table_data_uri.uri
            else:
                pred_table_data_path = prediction_table.table_data_uri

        api = global_state.client.batch_prediction_job_api
        # Remove to resolve https://github.com/kumo-ai/kumo/issues/24250
        # from kumoai.pquery.predictive_query import PredictiveQuery
        # pquery = PredictiveQuery.load_from_training_job(training_job_id)
        # if pquery.get_task_type() == TaskType.BINARY_CLASSIFICATION:
        #     if binary_classification_threshold is None:
        #         logger.warning(
        #             "No binary classification threshold provided. "
        #             "Using default threshold of 0.5.")
        #         binary_classification_threshold = 0.5
        job_id, response = api.maybe_create(
            BatchPredictionRequest(
                # Copy into a fresh dict; the (shared) default mapping is
                # therefore never handed out or mutated:
                dict(custom_tags),
                model_training_job_id=training_job_id,
                predict_options=BatchPredictionOptions(
                    binary_classification_threshold=(
                        binary_classification_threshold),
                    num_classes_to_return=num_classes_to_return,
                    num_workers=num_workers,
                ),
                outputs=outputs,
                graph_snapshot_id=graph.snapshot(non_blocking=non_blocking),
                pred_table_job_id=pred_table_job_id,
                pred_table_path=pred_table_data_path,
            ))

        message = response.message()
        if not response.ok:
            raise RuntimeError(f"Prediction failed. {message}")
        elif not response.empty():
            logger.warning("Prediction produced the following warnings: %s",
                           message)
        assert job_id is not None

        self._batch_prediction_job_id = job_id
        out = BatchPredictionJob(job_id=self._batch_prediction_job_id)
        if non_blocking:
            return out
        return out.attach()

    # Persistence #############################################################

    @classmethod
    def _load_from_job(cls, job: TrainingJobResource) -> 'Trainer':
        r"""Builds a trainer from a training job resource, reusing the job's
        model plan and caching its job ID.
        """
        trainer = cls(job.config.model_plan)
        trainer._training_job_id = job.job_id
        return trainer

    @classmethod
    def load(cls, job_id: TrainingJobID) -> 'Trainer':
        r"""Creates a :class:`~kumoai.trainer.Trainer` instance from a training
        job ID.
        """
        api: TrainingJobAPI = global_state.client.training_job_api
        job = api.get(job_id)
        return cls._load_from_job(job)

    # TODO(siyang): load trainer by searching training job via tags.
    @classmethod
    def load_from_tags(cls, tags: Mapping[str, str]) -> 'Trainer':
        r"""Creates a :class:`~kumoai.trainer.Trainer` instance from a set of
        job tags. If multiple jobs match the list of tags, only one will be
        selected.

        Raises:
            RuntimeError: If no training job matches :obj:`tags`.
        """
        api = global_state.client.training_job_api
        jobs = api.list(limit=1, additional_tags=tags)
        if not jobs:
            raise RuntimeError(f'No successful training job found for {tags}')
        assert len(jobs) == 1
        return cls._load_from_job(jobs[0])
|
kumoai/trainer/util.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import List, Optional, Set, Tuple, Union
|
|
3
|
+
|
|
4
|
+
from kumoapi.jobs import (
|
|
5
|
+
BigQueryPredictionOutput,
|
|
6
|
+
DatabricksPredictionOutput,
|
|
7
|
+
MetadataField,
|
|
8
|
+
PredictionArtifactType,
|
|
9
|
+
PredictionOutputConfig,
|
|
10
|
+
PredictionStorageType,
|
|
11
|
+
S3PredictionOutput,
|
|
12
|
+
SnowflakePredictionOutput,
|
|
13
|
+
WriteMode,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
from kumoai.artifact_export.config import OutputConfig
|
|
17
|
+
from kumoai.connector import (
|
|
18
|
+
BigQueryConnector,
|
|
19
|
+
Connector,
|
|
20
|
+
DatabricksConnector,
|
|
21
|
+
S3Connector,
|
|
22
|
+
SnowflakeConnector,
|
|
23
|
+
)
|
|
24
|
+
from kumoai.databricks import DB_SEP
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def validate_output_arguments(
    output_types: Set[str],
    output_connector: Optional[Connector] = None,
    output_table_name: Optional[str] = None,
) -> None:
    r"""Check the output arguments of a prediction job or an export job.

    Args:
        output_types: Requested output types; compared case-insensitively
            against ``'predictions'`` and ``'embeddings'``.
        output_connector: The connector outputs are written to, if any. When
            it is :obj:`None`, only :obj:`output_types` is checked.
        output_table_name: The output table name; required whenever
            :obj:`output_connector` is given.

    Raises:
        ValueError: If the connector type is not one of S3, Snowflake,
            Databricks or BigQuery, or if the table name is not a string.
        AssertionError: If an unknown output type is requested, the table
            name is missing, an S3 connector has no root directory, or a
            Databricks table name does not contain ``DB_SEP``.
    """
    requested = {output_type.lower() for output_type in output_types}
    assert requested.issubset({'predictions', 'embeddings'})

    # Nothing else to check when no external destination is configured:
    if output_connector is None:
        return

    assert output_table_name is not None

    supported_connectors = (
        S3Connector,
        SnowflakeConnector,
        DatabricksConnector,
        BigQueryConnector,
    )
    if not isinstance(output_connector, supported_connectors):
        raise ValueError(
            f"Connector type {type(output_connector)} is not supported for"
            f" outputs. Supported output connector types are S3, "
            f"Snowflake, Databricks, and BigQuery.")

    if not isinstance(output_table_name, str):
        raise ValueError(
            f"The output table name must be a string for all "
            f"non-Databricks connectors. Got '{output_table_name}'.")

    # Connector-specific requirements:
    if isinstance(output_connector, S3Connector):
        assert output_connector.root_dir is not None
    if isinstance(output_connector, DatabricksConnector):
        assert DB_SEP in output_table_name
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def build_prediction_output_config(
    output_type: str,
    output_connector: Optional[Connector] = None,
    output_table_name: Optional[Union[str, Tuple]] = None,
    output_metadata_fields: Optional[List[MetadataField]] = None,
    output_config: Optional[OutputConfig] = None,
) -> PredictionOutputConfig:
    r"""Build the prediction output config for a single output type.

    The output is named ``"{output_table_name}_{output_type}"``. For an S3
    connector this name is appended to the connector's root directory; for
    the other connectors it is used as the destination table name.

    Args:
        output_type: The output type; its upper-cased form is used to look
            up the :class:`PredictionArtifactType`.
        output_connector: The connector the output is written to.
        output_table_name: The base name of the output table.
        output_metadata_fields: Additional metadata fields to attach to the
            output. :obj:`None` is treated as an empty list.
        output_config: The output configuration. Required; its
            ``connector_specific_config.write_mode`` is used for Snowflake
            and BigQuery outputs, falling back to ``WriteMode.OVERWRITE``
            when no connector-specific config is set.

    Returns:
        The connector-specific prediction output config.

    Raises:
        NotImplementedError: If the connector is not an S3, Snowflake,
            Databricks or BigQuery connector.
    """
    # S3 object paths always use forward slashes, so join with `posixpath`
    # instead of the platform-dependent `os.path` (which would insert a
    # backslash on Windows):
    import posixpath

    assert output_config is not None
    artifact_type = PredictionArtifactType(output_type.upper())
    output_name = f"{output_table_name}_{output_type}"
    output_metadata_fields = output_metadata_fields or []
    if isinstance(output_connector, S3Connector):
        assert output_connector.root_dir is not None
        return S3PredictionOutput(
            artifact_type=artifact_type,
            file_path=posixpath.join(output_connector.root_dir, output_name),
            extra_fields=output_metadata_fields,
        )
    elif isinstance(output_connector, SnowflakeConnector):
        return SnowflakePredictionOutput(
            artifact_type=artifact_type,
            connector_id=output_connector.name,
            table_name=output_name,
            extra_fields=output_metadata_fields,
            write_mode=output_config.connector_specific_config.write_mode
            if output_config.connector_specific_config is not None else
            WriteMode.OVERWRITE,
        )
    elif isinstance(output_connector, DatabricksConnector):
        return DatabricksPredictionOutput(
            artifact_type=artifact_type,
            connector_id=output_connector.name,
            table_name=output_name,
            extra_fields=output_metadata_fields,
        )
    elif isinstance(output_connector, BigQueryConnector):
        return BigQueryPredictionOutput(
            storage_type=PredictionStorageType.BIGQUERY,
            artifact_type=artifact_type,
            connector_id=output_connector.name,
            table_name=output_name,
            extra_fields=output_metadata_fields,
            write_mode=output_config.connector_specific_config.write_mode
            if output_config.connector_specific_config is not None else
            WriteMode.OVERWRITE,
        )
    else:
        raise NotImplementedError()
|
kumoai/utils/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
from .sql import quote_ident
|
|
2
|
+
from .progress_logger import ProgressLogger
|
|
3
|
+
from .forecasting import ForecastVisualizer
|
|
4
|
+
from .datasets import from_relbench
|
|
5
|
+
|
|
6
|
+
__all__ = [
|
|
7
|
+
'quote_ident',
|
|
8
|
+
'ProgressLogger',
|
|
9
|
+
'ForecastVisualizer',
|
|
10
|
+
'from_relbench',
|
|
11
|
+
]
|
kumoai/utils/datasets.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
from kumoai.connector import FileUploadConnector
|
|
2
|
+
from kumoai.connector.utils import replace_table
|
|
3
|
+
from kumoai.graph import Edge, Graph, Table
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def from_relbench(dataset_name: str) -> Graph:
    r"""Creates a Kumo graph from a RelBench dataset. This function processes
    the specified RelBench dataset, uploads its tables to the Kumo data plane,
    and registers them as part of a Kumo graph, including inferred table
    metadata and edges.

    .. note::

        Please note that this method is subject to the limitations for file
        upload in :class:`~kumoai.connector.FileUploadConnector`.

    .. code-block:: python

        import kumoai
        from kumoai.utils.datasets import from_relbench

        # Assume dataset `example_dataset` in the RelBench repository:
        graph = from_relbench(dataset_name="example_dataset")

    Args:
        dataset_name: The name of the RelBench dataset to be processed.

    Returns:
        A :class:`~kumoai.Graph` object containing the tables and edges
        derived from the RelBench dataset.

    Raises:
        RuntimeError: If the ``relbench`` package is not installed. Errors
            raised while downloading, converting or uploading the dataset
            are propagated unchanged.
    """
    import os
    import tempfile

    try:
        import relbench
    except ImportError as exc:
        # Keep the original import failure attached as the cause:
        raise RuntimeError(
            "Creating a Kumo Graph from a RelBench dataset requires the "
            "'relbench' package to be installed. Please install the package "
            "before proceeding.") from exc

    connector = FileUploadConnector(file_type="parquet")
    dataset = relbench.datasets.get_dataset(dataset_name, download=True)
    db = dataset.get_db(upto_test_timestamp=False)

    # Kumo tables and foreign-key maps, keyed by table name:
    tables = {}
    fkey_maps = {}

    # Stage the parquet files in a temporary directory so that they are
    # removed afterwards instead of piling up in the working directory.
    # NOTE(review): assumes `replace_table` has finished reading the file
    # by the time it returns -- confirm against its implementation.
    with tempfile.TemporaryDirectory() as staging_dir:
        for table_key, table in db.table_dict.items():
            # Save the table locally as a parquet file:
            parquet_path = os.path.join(
                staging_dir, f"tmp_{table_key}.parquet")
            table.df.to_parquet(parquet_path, index=False)

            # Replace the table on the Kumo data plane:
            replace_table(name=table_key, path=parquet_path,
                          file_type="parquet")

            # Register the table with inferred metadata and collect edge
            # information:
            tables[table_key] = Table.from_source_table(
                source_table=connector[table_key],
                primary_key=table.pkey_col,
                time_column=table.time_col,
            ).infer_metadata()
            fkey_maps[table_key] = table.fkey_col_to_pkey_table

    # One edge per (foreign key column -> referenced table) entry:
    edges = [
        Edge(src_table=table_key, fkey=fkey, dst_table=dst_table)
        for table_key, fkey_map in fkey_maps.items()
        for fkey, dst_table in fkey_map.items()
    ]

    # Create and return the graph
    return Graph(
        tables=tables,
        edges=edges,
    )
|