kumoai 2.14.0.dev202601011731__cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kumoai might be problematic. Click here for more details.

Files changed (122):
  1. kumoai/__init__.py +300 -0
  2. kumoai/_logging.py +29 -0
  3. kumoai/_singleton.py +25 -0
  4. kumoai/_version.py +1 -0
  5. kumoai/artifact_export/__init__.py +9 -0
  6. kumoai/artifact_export/config.py +209 -0
  7. kumoai/artifact_export/job.py +108 -0
  8. kumoai/client/__init__.py +5 -0
  9. kumoai/client/client.py +223 -0
  10. kumoai/client/connector.py +110 -0
  11. kumoai/client/endpoints.py +150 -0
  12. kumoai/client/graph.py +120 -0
  13. kumoai/client/jobs.py +471 -0
  14. kumoai/client/online.py +78 -0
  15. kumoai/client/pquery.py +207 -0
  16. kumoai/client/rfm.py +112 -0
  17. kumoai/client/source_table.py +53 -0
  18. kumoai/client/table.py +101 -0
  19. kumoai/client/utils.py +130 -0
  20. kumoai/codegen/__init__.py +19 -0
  21. kumoai/codegen/cli.py +100 -0
  22. kumoai/codegen/context.py +16 -0
  23. kumoai/codegen/edits.py +473 -0
  24. kumoai/codegen/exceptions.py +10 -0
  25. kumoai/codegen/generate.py +222 -0
  26. kumoai/codegen/handlers/__init__.py +4 -0
  27. kumoai/codegen/handlers/connector.py +118 -0
  28. kumoai/codegen/handlers/graph.py +71 -0
  29. kumoai/codegen/handlers/pquery.py +62 -0
  30. kumoai/codegen/handlers/table.py +109 -0
  31. kumoai/codegen/handlers/utils.py +42 -0
  32. kumoai/codegen/identity.py +114 -0
  33. kumoai/codegen/loader.py +93 -0
  34. kumoai/codegen/naming.py +94 -0
  35. kumoai/codegen/registry.py +121 -0
  36. kumoai/connector/__init__.py +31 -0
  37. kumoai/connector/base.py +153 -0
  38. kumoai/connector/bigquery_connector.py +200 -0
  39. kumoai/connector/databricks_connector.py +213 -0
  40. kumoai/connector/file_upload_connector.py +189 -0
  41. kumoai/connector/glue_connector.py +150 -0
  42. kumoai/connector/s3_connector.py +278 -0
  43. kumoai/connector/snowflake_connector.py +252 -0
  44. kumoai/connector/source_table.py +471 -0
  45. kumoai/connector/utils.py +1796 -0
  46. kumoai/databricks.py +14 -0
  47. kumoai/encoder/__init__.py +4 -0
  48. kumoai/exceptions.py +26 -0
  49. kumoai/experimental/__init__.py +0 -0
  50. kumoai/experimental/rfm/__init__.py +210 -0
  51. kumoai/experimental/rfm/authenticate.py +432 -0
  52. kumoai/experimental/rfm/backend/__init__.py +0 -0
  53. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  54. kumoai/experimental/rfm/backend/local/graph_store.py +297 -0
  55. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  56. kumoai/experimental/rfm/backend/local/table.py +113 -0
  57. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  58. kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
  59. kumoai/experimental/rfm/backend/snow/table.py +242 -0
  60. kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
  61. kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
  62. kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
  63. kumoai/experimental/rfm/base/__init__.py +30 -0
  64. kumoai/experimental/rfm/base/column.py +152 -0
  65. kumoai/experimental/rfm/base/expression.py +44 -0
  66. kumoai/experimental/rfm/base/sampler.py +761 -0
  67. kumoai/experimental/rfm/base/source.py +19 -0
  68. kumoai/experimental/rfm/base/sql_sampler.py +143 -0
  69. kumoai/experimental/rfm/base/table.py +736 -0
  70. kumoai/experimental/rfm/graph.py +1237 -0
  71. kumoai/experimental/rfm/infer/__init__.py +19 -0
  72. kumoai/experimental/rfm/infer/categorical.py +40 -0
  73. kumoai/experimental/rfm/infer/dtype.py +82 -0
  74. kumoai/experimental/rfm/infer/id.py +46 -0
  75. kumoai/experimental/rfm/infer/multicategorical.py +48 -0
  76. kumoai/experimental/rfm/infer/pkey.py +128 -0
  77. kumoai/experimental/rfm/infer/stype.py +35 -0
  78. kumoai/experimental/rfm/infer/time_col.py +61 -0
  79. kumoai/experimental/rfm/infer/timestamp.py +41 -0
  80. kumoai/experimental/rfm/pquery/__init__.py +7 -0
  81. kumoai/experimental/rfm/pquery/executor.py +102 -0
  82. kumoai/experimental/rfm/pquery/pandas_executor.py +530 -0
  83. kumoai/experimental/rfm/relbench.py +76 -0
  84. kumoai/experimental/rfm/rfm.py +1184 -0
  85. kumoai/experimental/rfm/sagemaker.py +138 -0
  86. kumoai/experimental/rfm/task_table.py +231 -0
  87. kumoai/formatting.py +30 -0
  88. kumoai/futures.py +99 -0
  89. kumoai/graph/__init__.py +12 -0
  90. kumoai/graph/column.py +106 -0
  91. kumoai/graph/graph.py +948 -0
  92. kumoai/graph/table.py +838 -0
  93. kumoai/jobs.py +80 -0
  94. kumoai/kumolib.cpython-310-x86_64-linux-gnu.so +0 -0
  95. kumoai/mixin.py +28 -0
  96. kumoai/pquery/__init__.py +25 -0
  97. kumoai/pquery/prediction_table.py +287 -0
  98. kumoai/pquery/predictive_query.py +641 -0
  99. kumoai/pquery/training_table.py +424 -0
  100. kumoai/spcs.py +121 -0
  101. kumoai/testing/__init__.py +8 -0
  102. kumoai/testing/decorators.py +57 -0
  103. kumoai/testing/snow.py +50 -0
  104. kumoai/trainer/__init__.py +42 -0
  105. kumoai/trainer/baseline_trainer.py +93 -0
  106. kumoai/trainer/config.py +2 -0
  107. kumoai/trainer/distilled_trainer.py +175 -0
  108. kumoai/trainer/job.py +1192 -0
  109. kumoai/trainer/online_serving.py +258 -0
  110. kumoai/trainer/trainer.py +475 -0
  111. kumoai/trainer/util.py +103 -0
  112. kumoai/utils/__init__.py +11 -0
  113. kumoai/utils/datasets.py +83 -0
  114. kumoai/utils/display.py +51 -0
  115. kumoai/utils/forecasting.py +209 -0
  116. kumoai/utils/progress_logger.py +343 -0
  117. kumoai/utils/sql.py +3 -0
  118. kumoai-2.14.0.dev202601011731.dist-info/METADATA +71 -0
  119. kumoai-2.14.0.dev202601011731.dist-info/RECORD +122 -0
  120. kumoai-2.14.0.dev202601011731.dist-info/WHEEL +6 -0
  121. kumoai-2.14.0.dev202601011731.dist-info/licenses/LICENSE +9 -0
  122. kumoai-2.14.0.dev202601011731.dist-info/top_level.txt +1 -0
@@ -0,0 +1,475 @@
1
+ import logging
2
+ from typing import (
3
+ List,
4
+ Literal,
5
+ Mapping,
6
+ Optional,
7
+ Set,
8
+ Tuple,
9
+ Union,
10
+ overload,
11
+ )
12
+
13
+ from kumoapi.jobs import (
14
+ BatchPredictionOptions,
15
+ BatchPredictionRequest,
16
+ JobStatus,
17
+ MetadataField,
18
+ PredictionOutputConfig,
19
+ TrainingJobRequest,
20
+ TrainingJobResource,
21
+ )
22
+ from kumoapi.model_plan import ModelPlan
23
+
24
+ from kumoai import global_state
25
+ from kumoai.artifact_export.config import OutputConfig
26
+ from kumoai.client.jobs import (
27
+ GeneratePredictionTableJobID,
28
+ TrainingJobAPI,
29
+ TrainingJobID,
30
+ )
31
+ from kumoai.connector.base import Connector
32
+ from kumoai.connector.s3_connector import S3URI
33
+ from kumoai.databricks import to_db_table_name
34
+ from kumoai.graph import Graph
35
+ from kumoai.pquery.prediction_table import PredictionTable, PredictionTableJob
36
+ from kumoai.pquery.training_table import TrainingTable, TrainingTableJob
37
+ from kumoai.trainer.job import (
38
+ BatchPredictionJob,
39
+ BatchPredictionJobResult,
40
+ TrainingJob,
41
+ TrainingJobResult,
42
+ )
43
+ from kumoai.trainer.util import (
44
+ build_prediction_output_config,
45
+ validate_output_arguments,
46
+ )
47
+
48
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
49
+
50
+
51
class Trainer:
    r"""A trainer supports creating a Kumo machine learning model on a
    :class:`~kumoai.pquery.PredictiveQuery`. It is primarily oriented around
    two methods: :meth:`~kumoai.trainer.Trainer.fit`, which accepts a
    :class:`~kumoai.graph.Graph` and :class:`~kumoai.pquery.TrainingTable` and
    produces a :class:`~kumoai.trainer.TrainingJobResult`, and
    :meth:`~kumoai.trainer.Trainer.predict`, which accepts a
    :class:`~kumoai.graph.Graph` and :class:`~kumoai.pquery.PredictionTable`
    and produces a :class:`~kumoai.trainer.BatchPredictionJobResult`.

    A :class:`~kumoai.trainer.Trainer` can also be loaded from a training job,
    with :meth:`~kumoai.trainer.Trainer.load`.

    .. code-block:: python

        import kumoai

        # See `Graph` and `PredictiveQuery` documentation:
        graph = kumoai.Graph(...)
        pquery = kumoai.PredictiveQuery(graph=graph, query=...)

        # Create a `Trainer` object, using a suggested model plan given the
        # predictive query:
        model_plan = pquery.suggest_model_plan()
        trainer = kumoai.Trainer(model_plan)

        # Create a training table from the predictive query:
        training_table = pquery.generate_training_table()

        # Fit a model:
        training_job = trainer.fit(
            graph=graph,
            train_table=training_table,
        )

        # Create a prediction table from the predictive query:
        prediction_table = pquery.generate_prediction_table()

        # Generate predictions:
        prediction_job = trainer.predict(
            graph=graph,
            prediction_table=prediction_table,
            output_config=...,
            # other arguments here...
        )

        # Load a trained query to generate predictions:
        pquery = kumoai.PredictiveQuery.load_from_training_job("trainingjob-...")
        trainer = kumoai.Trainer.load("trainingjob-...")
        prediction_job = trainer.predict(
            pquery.graph,
            pquery.generate_prediction_table(),
            output_config=...,
        )

    Args:
        model_plan: A model plan that the trainer should follow when fitting a
            Kumo model to a predictive query. This model plan can either be
            generated from a predictive query, with
            :meth:`~kumoai.pquery.PredictiveQuery.suggest_model_plan`, or can
            be manually specified.
    """  # noqa: E501

    def __init__(self, model_plan: ModelPlan) -> None:
        self._model_plan: Optional[ModelPlan] = model_plan

        # Cached from backend; set by `fit` or when loading from a job:
        self._training_job_id: Optional[TrainingJobID] = None

    @property
    def model_plan(self) -> Optional[ModelPlan]:
        r"""The model plan this trainer follows when fitting a model."""
        return self._model_plan

    @model_plan.setter
    def model_plan(self, model_plan: ModelPlan) -> None:
        self._model_plan = model_plan

    # Metadata ################################################################

    @property
    def is_trained(self) -> bool:
        r"""Returns ``True`` if this trainer instance has successfully been
        trained (and is therefore ready for prediction); ``False`` otherwise.

        Failures to reach the backend are logged and reported as ``False``
        rather than raised.
        """
        if not self._training_job_id:
            return False
        try:
            api = global_state.client.training_job_api
            res: TrainingJobResource = api.get(self._training_job_id)
        except Exception as e:  # noqa
            # Deliberately best-effort: a status lookup failure means "not
            # known to be trained", not a hard error.
            logger.exception(
                "Failed to fetch training status for training job with ID %s",
                self._training_job_id, exc_info=e)
            return False
        return res.job_status_report.status == JobStatus.DONE

    # Fit / predict ###########################################################

    # The overloads below let type checkers narrow the return type from the
    # value of `non_blocking`; they must list every keyword the
    # implementation accepts, or calls using those keywords are rejected.
    @overload
    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        non_blocking: Literal[False] = ...,
        custom_tags: Optional[Mapping[str, str]] = ...,
        warm_start_job_id: Optional[TrainingJobID] = ...,
    ) -> TrainingJobResult:
        pass

    @overload
    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        non_blocking: Literal[True],
        custom_tags: Optional[Mapping[str, str]] = ...,
        warm_start_job_id: Optional[TrainingJobID] = ...,
    ) -> TrainingJob:
        pass

    @overload
    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        non_blocking: bool,
        custom_tags: Optional[Mapping[str, str]] = ...,
        warm_start_job_id: Optional[TrainingJobID] = ...,
    ) -> Union[TrainingJob, TrainingJobResult]:
        pass

    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        non_blocking: bool = False,
        custom_tags: Optional[Mapping[str, str]] = None,
        warm_start_job_id: Optional[TrainingJobID] = None,
    ) -> Union[TrainingJob, TrainingJobResult]:
        r"""Fits a model to the specified graph and training table, with the
        strategy defined by this :class:`Trainer`'s :obj:`model_plan`.

        Args:
            graph: The :class:`~kumoai.graph.Graph` object that represents the
                tables and relationships that Kumo will learn from.
            train_table: The :class:`~kumoai.pquery.TrainingTable`, or
                in-progress :class:`~kumoai.pquery.TrainingTableJob`, that
                represents the training data produced by a
                :class:`~kumoai.pquery.PredictiveQuery` on :obj:`graph`.
            non_blocking: Whether this operation should return immediately
                after launching the training job, or await completion of the
                training job.
            custom_tags: Additional, customer defined k-v tags to be associated
                with the job to be launched. Job tags are useful for grouping
                and searching jobs.
            warm_start_job_id: Optional job ID of a completed training job to
                warm start from. Initializes the new model with the best
                weights from the specified job, using its model
                architecture, column processing, and neighbor sampling
                configurations.

        Returns:
            Union[TrainingJobResult, TrainingJob]:
                If ``non_blocking=False``, awaits the training job and returns
                its :class:`~kumoai.trainer.TrainingJobResult`. If
                ``non_blocking=True``, immediately returns the
                :class:`~kumoai.trainer.TrainingJob` future.

        Raises:
            ValueError: If :obj:`train_table` is not associated with a
                training table generation job.
        """
        # TODO(manan, siyang): remove soon:
        job_id = train_table.job_id
        if job_id is None:
            raise ValueError(
                "Cannot fit a model on a training table that is not "
                "associated with a training table generation job. Please "
                "generate the training table from a predictive query (e.g. "
                "with `PredictiveQuery.generate_training_table()`).")

        # The predictive query is recovered from the training table job:
        train_table_job_api = global_state.client.generate_train_table_job_api
        pq_id = train_table_job_api.get(job_id).config.pquery_id
        assert pq_id is not None

        custom_table = None
        if isinstance(train_table, TrainingTable):
            custom_table = train_table._custom_train_table

        # NOTE the backend implementation currently handles sequentialization
        # between a training table future and a training job; that is, if the
        # training table future is still executing, the backend will wait on
        # the job ID completion before executing a training job. This preserves
        # semantics for both futures, ensures that Kumo works as expected if
        # used only via REST API, and allows us to avoid chaining callbacks
        # in an ugly way here:
        api = global_state.client.training_job_api
        self._training_job_id = api.create(
            TrainingJobRequest(
                dict(custom_tags) if custom_tags is not None else {},
                pquery_id=pq_id,
                model_plan=self._model_plan,
                graph_snapshot_id=graph.snapshot(non_blocking=non_blocking),
                train_table_job_id=job_id,
                custom_train_table=custom_table,
                warm_start_job_id=warm_start_job_id,
            ))

        out = TrainingJob(job_id=self._training_job_id)
        if non_blocking:
            return out
        return out.attach()

    def predict(
        self,
        graph: Graph,
        prediction_table: Union[PredictionTable, PredictionTableJob],
        output_types: Optional[Set[str]] = None,
        output_connector: Optional[Connector] = None,
        output_table_name: Optional[Union[str, Tuple]] = None,
        output_metadata_fields: Optional[List[MetadataField]] = None,
        output_config: Optional[OutputConfig] = None,
        training_job_id: Optional[TrainingJobID] = None,
        binary_classification_threshold: Optional[float] = None,
        num_classes_to_return: Optional[int] = None,
        num_workers: int = 1,
        non_blocking: bool = False,
        custom_tags: Optional[Mapping[str, str]] = None,
    ) -> Union[BatchPredictionJob, BatchPredictionJobResult]:
        """Using the trained model specified by :obj:`training_job_id` (or
        the model corresponding to the last invocation of
        :meth:`~kumoai.trainer.Trainer.fit`, if not present), predicts the
        future values of the entities in :obj:`prediction_table` leveraging
        current information from :obj:`graph`.

        Args:
            graph: The :class:`~kumoai.graph.Graph` object that represents the
                tables and relationships that Kumo will use to make
                predictions.
            prediction_table: The :class:`~kumoai.pquery.PredictionTable`, or
                in-progress :class:`~kumoai.pquery.PredictionTableJob`, that
                represents the prediction data produced by a
                :class:`~kumoai.pquery.PredictiveQuery` on :obj:`graph`. This
                table may also be custom-generated.
            output_config: (Required) Output configuration defining
                properties of the generated batch prediction outputs, either
                as an :class:`OutputConfig` or a dictionary of its fields.
                This includes:

                - output_types: The types of outputs that should be produced by
                  the prediction job. Can include either ``'predictions'``,
                  ``'embeddings'``, or both.
                - output_connector: The output data source that Kumo should
                  write batch predictions to, if it is None, produce
                  local download output only.
                - output_table_name: The name of the table in the output data
                  source that Kumo should write batch predictions to. In the
                  case of a Databricks connector, this should be a tuple of
                  two strings, the schema name and the output prediction
                  table name.
                - output_metadata_fields: Any additional metadata fields to
                  include as new columns in the produced ``'predictions'``
                  output. Currently, allowed options are ``JOB_TIMESTAMP``
                  and ``ANCHOR_TIMESTAMP``.
                - connector_specific_config: Custom connector specific output
                  configuration, such as whether to append or overwrite
                  existing tables.

            output_types: *(Deprecated, no longer supported)* Passing any
                value raises a :class:`ValueError`; set ``output_types`` on
                :obj:`output_config` instead.
            output_connector: *(Deprecated, no longer supported)* Passing any
                value raises a :class:`ValueError`; set ``output_connector``
                on :obj:`output_config` instead.
            output_table_name: *(Deprecated, no longer supported)* Passing
                any value raises a :class:`ValueError`; set
                ``output_table_name`` on :obj:`output_config` instead.
            output_metadata_fields: *(Deprecated, no longer supported)*
                Passing any value raises a :class:`ValueError`; set
                ``output_metadata_fields`` on :obj:`output_config` instead.
            training_job_id: The job ID of the training job whose model will be
                used for making predictions. Defaults to the job launched by
                the last :meth:`fit` call on this trainer, or the job it was
                loaded from.
            binary_classification_threshold: If this model corresponds to
                a binary classification task, the threshold for which higher
                values correspond to ``1``, and lower values correspond to
                ``0``.
            num_classes_to_return: If this model corresponds to a ranking task,
                the number of classes to return in the prediction output.
            num_workers: Number of workers to use when generating batch
                predictions. When set to a value greater than 1, the prediction
                table is partitioned into smaller parts and processed in
                parallel. The default is 1, which implies sequential inference
                over the prediction table.
            non_blocking: Whether this operation should return immediately
                after launching the batch prediction job, or await
                completion of the batch prediction job.
            custom_tags: Additional, customer defined k-v tags to be associated
                with the job to be launched. Job tags are useful for grouping
                and searching jobs.

        Returns:
            Union[BatchPredictionJob, BatchPredictionJobResult]:
                If ``non_blocking=False``, awaits the batch prediction job and
                returns its :class:`~kumoai.trainer.BatchPredictionJobResult`.
                If ``non_blocking=True``, immediately returns the
                :class:`~kumoai.trainer.BatchPredictionJob` future.

        Raises:
            ValueError: If a deprecated output argument is passed, if
                :obj:`output_config` is missing or invalid, or if no training
                job ID is available.
            RuntimeError: If the backend rejects the prediction request.
        """
        if (output_types is not None or output_connector is not None
                or output_table_name is not None
                or output_metadata_fields is not None):
            raise ValueError(
                'Specifying output_types, output_connector, '
                'output_metadata_fields '
                'and output_table_name as direct inputs to predict() is '
                'deprecated. Please use output_config to specify these '
                'parameters.')
        if output_config is None:
            raise ValueError(
                'An output_config is required to run batch prediction. '
                'Please pass an OutputConfig (or a dictionary of its '
                'fields) to predict().')
        # Be able to pass output_config as a dictionary:
        if isinstance(output_config, dict):
            output_config = OutputConfig(**output_config)
        db_table_name = to_db_table_name(output_config.output_table_name)
        validate_output_arguments(
            output_config.output_types,
            output_config.output_connector,
            db_table_name,
        )

        # Create outputs. Without an output connector, predictions are
        # generated to the Kumo dataplane, and can only be exported via the
        # UI for now, so no output configs are sent:
        outputs: List[PredictionOutputConfig] = []
        if output_config.output_connector is not None:
            outputs = [
                build_prediction_output_config(
                    output_type,
                    output_config.output_connector,
                    db_table_name,
                    output_config.output_metadata_fields,
                    output_config,
                ) for output_type in output_config.output_types
            ]

        training_job_id = training_job_id or self._training_job_id
        if training_job_id is None:
            raise ValueError(
                "Cannot run batch prediction without a specified or saved "
                "training job ID. Please either call `fit(...)` or pass a "
                "job ID of a completed training job to proceed.")

        # A prediction table without a generation job must be a (custom)
        # `PredictionTable` that is referenced by its data path instead:
        pred_table_job_id: Optional[GeneratePredictionTableJobID] = \
            prediction_table.job_id
        pred_table_data_path = None
        if pred_table_job_id is None:
            assert isinstance(prediction_table, PredictionTable)
            if isinstance(prediction_table.table_data_uri, S3URI):
                pred_table_data_path = prediction_table.table_data_uri.uri
            else:
                pred_table_data_path = prediction_table.table_data_uri

        # NOTE: A default binary classification threshold of 0.5 (with a
        # warning) used to be applied here for binary classification tasks;
        # it was removed to resolve https://github.com/kumo-ai/kumo/issues/24250
        api = global_state.client.batch_prediction_job_api
        job_id, response = api.maybe_create(
            BatchPredictionRequest(
                dict(custom_tags) if custom_tags is not None else {},
                model_training_job_id=training_job_id,
                predict_options=BatchPredictionOptions(
                    binary_classification_threshold=(
                        binary_classification_threshold),
                    num_classes_to_return=num_classes_to_return,
                    num_workers=num_workers,
                ),
                outputs=outputs,
                graph_snapshot_id=graph.snapshot(non_blocking=non_blocking),
                pred_table_job_id=pred_table_job_id,
                pred_table_path=pred_table_data_path,
            ))

        message = response.message()
        if not response.ok:
            raise RuntimeError(f"Prediction failed. {message}")
        elif not response.empty():
            logger.warning("Prediction produced the following warnings: %s",
                           message)
        assert job_id is not None

        self._batch_prediction_job_id = job_id
        out = BatchPredictionJob(job_id=self._batch_prediction_job_id)
        if non_blocking:
            return out
        return out.attach()

    # Persistence #############################################################

    @classmethod
    def _load_from_job(cls, job: TrainingJobResource) -> 'Trainer':
        r"""Builds a trainer from a fetched training job resource, reusing
        its model plan and remembering its job ID for later predictions.
        """
        trainer = cls(job.config.model_plan)
        trainer._training_job_id = job.job_id
        return trainer

    @classmethod
    def load(cls, job_id: TrainingJobID) -> 'Trainer':
        r"""Creates a :class:`~kumoai.trainer.Trainer` instance from a training
        job ID.
        """
        api: TrainingJobAPI = global_state.client.training_job_api
        job = api.get(job_id)
        return cls._load_from_job(job)

    # TODO(siyang): load trainer by searching training job via tags.
    @classmethod
    def load_from_tags(cls, tags: Mapping[str, str]) -> 'Trainer':
        r"""Creates a :class:`~kumoai.trainer.Trainer` instance from a set of
        job tags. If multiple jobs match the list of tags, only one will be
        selected.

        Raises:
            RuntimeError: If no training job matches :obj:`tags`.
        """
        api = global_state.client.training_job_api
        jobs = api.list(limit=1, additional_tags=tags)
        if not jobs:
            raise RuntimeError(f'No successful training job found for {tags}')
        assert len(jobs) == 1
        return cls._load_from_job(jobs[0])
kumoai/trainer/util.py ADDED
@@ -0,0 +1,103 @@
1
+ import os
2
+ from typing import List, Optional, Set, Tuple, Union
3
+
4
+ from kumoapi.jobs import (
5
+ BigQueryPredictionOutput,
6
+ DatabricksPredictionOutput,
7
+ MetadataField,
8
+ PredictionArtifactType,
9
+ PredictionOutputConfig,
10
+ PredictionStorageType,
11
+ S3PredictionOutput,
12
+ SnowflakePredictionOutput,
13
+ WriteMode,
14
+ )
15
+
16
+ from kumoai.artifact_export.config import OutputConfig
17
+ from kumoai.connector import (
18
+ BigQueryConnector,
19
+ Connector,
20
+ DatabricksConnector,
21
+ S3Connector,
22
+ SnowflakeConnector,
23
+ )
24
+ from kumoai.databricks import DB_SEP
25
+
26
+
27
def validate_output_arguments(
    output_types: Set[str],
    output_connector: Optional[Connector] = None,
    output_table_name: Optional[str] = None,
) -> None:
    r"""Validate the output arguments for a prediction job or an export job.

    Validation uses explicit exceptions rather than ``assert`` so that it
    still runs when Python is started with ``-O``.

    Args:
        output_types: The requested output types. Matching is
            case-insensitive; only ``'predictions'`` and ``'embeddings'`` are
            allowed.
        output_connector: The connector to write outputs to, or ``None`` if
            no connector output is requested (in which case no further
            checks are made).
        output_table_name: The output table name. Required whenever
            :obj:`output_connector` is given; for Databricks connectors it
            must contain ``DB_SEP``.

    Raises:
        ValueError: If any of the arguments is invalid.
    """
    allowed_types = {'predictions', 'embeddings'}
    unknown_types = {x.lower() for x in output_types} - allowed_types
    if unknown_types:
        raise ValueError(
            f"Unsupported output types {sorted(unknown_types)}. Supported "
            f"output types are 'predictions' and 'embeddings'.")

    if output_connector is None:
        return

    if output_table_name is None:
        raise ValueError(
            "An output table name must be specified when an output "
            "connector is given.")
    if not isinstance(output_connector,
                      (S3Connector, SnowflakeConnector,
                       DatabricksConnector, BigQueryConnector)):
        raise ValueError(
            f"Connector type {type(output_connector)} is not supported for"
            f" outputs. Supported output connector types are S3, "
            f"Snowflake, Databricks, and BigQuery.")
    if not isinstance(output_table_name, str):
        raise ValueError(
            f"The output table name must be a string for all "
            f"non-Databricks connectors. Got '{output_table_name}'.")

    if (isinstance(output_connector, S3Connector)
            and output_connector.root_dir is None):
        raise ValueError(
            "The S3 output connector must have a root directory configured.")
    if (isinstance(output_connector, DatabricksConnector)
            and DB_SEP not in output_table_name):
        raise ValueError(
            f"The Databricks output table name must contain both a schema "
            f"and a table name separated by '{DB_SEP}'. "
            f"Got '{output_table_name}'.")
53
+
54
+
55
def _resolve_write_mode(output_config: OutputConfig) -> WriteMode:
    r"""Returns the write mode from the connector-specific output config,
    defaulting to :obj:`WriteMode.OVERWRITE` when none is configured.
    """
    connector_config = output_config.connector_specific_config
    if connector_config is None:
        return WriteMode.OVERWRITE
    return connector_config.write_mode


def build_prediction_output_config(
    output_type: str,
    output_connector: Optional[Connector] = None,
    output_table_name: Optional[Union[str, Tuple]] = None,
    output_metadata_fields: Optional[List[MetadataField]] = None,
    output_config: Optional[OutputConfig] = None,
) -> PredictionOutputConfig:
    r"""Build the prediction output config for one output type.

    The output is named ``"{output_table_name}_{output_type}"``; for S3 it is
    written under the connector's root directory, for the other connectors
    it is used as the table name.

    Args:
        output_type: The output type (e.g. ``'predictions'`` or
            ``'embeddings'``); upper-cased to look up its
            :class:`PredictionArtifactType`.
        output_connector: The S3, Snowflake, Databricks, or BigQuery
            connector to write to.
        output_table_name: The base name of the output table.
        output_metadata_fields: Extra metadata columns to include; defaults
            to none.
        output_config: The full output configuration. Required; its
            connector-specific write mode is used for Snowflake and BigQuery
            outputs.

    Raises:
        NotImplementedError: If :obj:`output_connector` is not one of the
            supported connector types (including ``None``).
    """
    assert output_config is not None
    artifact_type = PredictionArtifactType(output_type.upper())
    output_name = f"{output_table_name}_{output_type}"
    extra_fields = output_metadata_fields or []

    if isinstance(output_connector, S3Connector):
        assert output_connector.root_dir is not None
        # NOTE(review): os.path.join uses the platform separator, while S3
        # keys always use '/'; consider posixpath.join for non-POSIX hosts.
        return S3PredictionOutput(
            artifact_type=artifact_type,
            file_path=os.path.join(output_connector.root_dir, output_name),
            extra_fields=extra_fields,
        )
    if isinstance(output_connector, SnowflakeConnector):
        return SnowflakePredictionOutput(
            artifact_type=artifact_type,
            connector_id=output_connector.name,
            table_name=output_name,
            extra_fields=extra_fields,
            write_mode=_resolve_write_mode(output_config),
        )
    if isinstance(output_connector, DatabricksConnector):
        return DatabricksPredictionOutput(
            artifact_type=artifact_type,
            connector_id=output_connector.name,
            table_name=output_name,
            extra_fields=extra_fields,
        )
    if isinstance(output_connector, BigQueryConnector):
        return BigQueryPredictionOutput(
            storage_type=PredictionStorageType.BIGQUERY,
            artifact_type=artifact_type,
            connector_id=output_connector.name,
            table_name=output_name,
            extra_fields=extra_fields,
            write_mode=_resolve_write_mode(output_config),
        )
    raise NotImplementedError(
        f"Prediction outputs are not supported for connector type "
        f"{type(output_connector)}.")
@@ -0,0 +1,11 @@
1
+ from .sql import quote_ident
2
+ from .progress_logger import ProgressLogger
3
+ from .forecasting import ForecastVisualizer
4
+ from .datasets import from_relbench
5
+
6
# Public names exported by `kumoai.utils` (and by `from kumoai.utils import *`).
__all__ = [
    'quote_ident',
    'ProgressLogger',
    'ForecastVisualizer',
    'from_relbench',
]
@@ -0,0 +1,83 @@
1
+ from kumoai.connector import FileUploadConnector
2
+ from kumoai.connector.utils import replace_table
3
+ from kumoai.graph import Edge, Graph, Table
4
+
5
+
6
def from_relbench(dataset_name: str) -> Graph:
    r"""Creates a Kumo graph from a RelBench dataset. This function processes
    the specified RelBench dataset, uploads its tables to the Kumo data plane,
    and registers them as part of a Kumo graph, including inferred table
    metadata and edges.

    Each table is staged as a parquet file in a temporary directory, which is
    removed once the graph has been built.

    .. note::

        Please note that this method is subject to the limitations for file
        upload in :class:`~kumoai.connector.FileUploadConnector`.

    .. code-block:: python

        import kumoai
        from kumoai.utils.datasets import from_relbench

        # Assume dataset `example_dataset` in the RelBench repository:
        graph = from_relbench(dataset_name="example_dataset")

    Args:
        dataset_name: The name of the RelBench dataset to be processed.

    Returns:
        A :class:`~kumoai.Graph` object containing the tables and edges
        derived from the RelBench dataset.

    Raises:
        RuntimeError: If the ``relbench`` package is not installed. Errors
            raised by RelBench while retrieving the dataset are propagated.
    """
    import os
    import tempfile

    try:
        # Import the submodule explicitly rather than relying on the
        # `relbench` package `__init__` to expose `relbench.datasets`:
        from relbench.datasets import get_dataset
    except ImportError as e:
        raise RuntimeError(
            "Creating a Kumo Graph from a RelBench dataset requires the "
            "'relbench' package to be installed. Please install the package "
            "before proceeding.") from e

    connector = FileUploadConnector(file_type="parquet")
    dataset = get_dataset(dataset_name, download=True)
    db = dataset.get_db(upto_test_timestamp=False)

    tables = {}
    edges = []

    # Stage parquet files in a temporary directory instead of the current
    # working directory, so that nothing is left behind (or overwritten):
    with tempfile.TemporaryDirectory() as tmp_dir:
        for table_key, rb_table in db.table_dict.items():
            parquet_path = os.path.join(tmp_dir, f"tmp_{table_key}.parquet")
            rb_table.df.to_parquet(parquet_path, index=False)

            # Replace the table on the Kumo data plane:
            replace_table(name=table_key, path=parquet_path,
                          file_type="parquet")

            # Register the uploaded table with inferred metadata:
            tables[table_key] = Table.from_source_table(
                source_table=connector[table_key],
                primary_key=rb_table.pkey_col,
                time_column=rb_table.time_col,
            ).infer_metadata()

            # One edge per foreign key column of this table:
            edges.extend(
                Edge(src_table=table_key, fkey=fkey, dst_table=dst_table)
                for fkey, dst_table in rb_table.fkey_col_to_pkey_table.items())

    return Graph(
        tables=tables,
        edges=edges,
    )
+ )