kumoai 2.14.0.dev202601011731__cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kumoai might be problematic. Click here for more details.

Files changed (122) hide show
  1. kumoai/__init__.py +300 -0
  2. kumoai/_logging.py +29 -0
  3. kumoai/_singleton.py +25 -0
  4. kumoai/_version.py +1 -0
  5. kumoai/artifact_export/__init__.py +9 -0
  6. kumoai/artifact_export/config.py +209 -0
  7. kumoai/artifact_export/job.py +108 -0
  8. kumoai/client/__init__.py +5 -0
  9. kumoai/client/client.py +223 -0
  10. kumoai/client/connector.py +110 -0
  11. kumoai/client/endpoints.py +150 -0
  12. kumoai/client/graph.py +120 -0
  13. kumoai/client/jobs.py +471 -0
  14. kumoai/client/online.py +78 -0
  15. kumoai/client/pquery.py +207 -0
  16. kumoai/client/rfm.py +112 -0
  17. kumoai/client/source_table.py +53 -0
  18. kumoai/client/table.py +101 -0
  19. kumoai/client/utils.py +130 -0
  20. kumoai/codegen/__init__.py +19 -0
  21. kumoai/codegen/cli.py +100 -0
  22. kumoai/codegen/context.py +16 -0
  23. kumoai/codegen/edits.py +473 -0
  24. kumoai/codegen/exceptions.py +10 -0
  25. kumoai/codegen/generate.py +222 -0
  26. kumoai/codegen/handlers/__init__.py +4 -0
  27. kumoai/codegen/handlers/connector.py +118 -0
  28. kumoai/codegen/handlers/graph.py +71 -0
  29. kumoai/codegen/handlers/pquery.py +62 -0
  30. kumoai/codegen/handlers/table.py +109 -0
  31. kumoai/codegen/handlers/utils.py +42 -0
  32. kumoai/codegen/identity.py +114 -0
  33. kumoai/codegen/loader.py +93 -0
  34. kumoai/codegen/naming.py +94 -0
  35. kumoai/codegen/registry.py +121 -0
  36. kumoai/connector/__init__.py +31 -0
  37. kumoai/connector/base.py +153 -0
  38. kumoai/connector/bigquery_connector.py +200 -0
  39. kumoai/connector/databricks_connector.py +213 -0
  40. kumoai/connector/file_upload_connector.py +189 -0
  41. kumoai/connector/glue_connector.py +150 -0
  42. kumoai/connector/s3_connector.py +278 -0
  43. kumoai/connector/snowflake_connector.py +252 -0
  44. kumoai/connector/source_table.py +471 -0
  45. kumoai/connector/utils.py +1796 -0
  46. kumoai/databricks.py +14 -0
  47. kumoai/encoder/__init__.py +4 -0
  48. kumoai/exceptions.py +26 -0
  49. kumoai/experimental/__init__.py +0 -0
  50. kumoai/experimental/rfm/__init__.py +210 -0
  51. kumoai/experimental/rfm/authenticate.py +432 -0
  52. kumoai/experimental/rfm/backend/__init__.py +0 -0
  53. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  54. kumoai/experimental/rfm/backend/local/graph_store.py +297 -0
  55. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  56. kumoai/experimental/rfm/backend/local/table.py +113 -0
  57. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  58. kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
  59. kumoai/experimental/rfm/backend/snow/table.py +242 -0
  60. kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
  61. kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
  62. kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
  63. kumoai/experimental/rfm/base/__init__.py +30 -0
  64. kumoai/experimental/rfm/base/column.py +152 -0
  65. kumoai/experimental/rfm/base/expression.py +44 -0
  66. kumoai/experimental/rfm/base/sampler.py +761 -0
  67. kumoai/experimental/rfm/base/source.py +19 -0
  68. kumoai/experimental/rfm/base/sql_sampler.py +143 -0
  69. kumoai/experimental/rfm/base/table.py +736 -0
  70. kumoai/experimental/rfm/graph.py +1237 -0
  71. kumoai/experimental/rfm/infer/__init__.py +19 -0
  72. kumoai/experimental/rfm/infer/categorical.py +40 -0
  73. kumoai/experimental/rfm/infer/dtype.py +82 -0
  74. kumoai/experimental/rfm/infer/id.py +46 -0
  75. kumoai/experimental/rfm/infer/multicategorical.py +48 -0
  76. kumoai/experimental/rfm/infer/pkey.py +128 -0
  77. kumoai/experimental/rfm/infer/stype.py +35 -0
  78. kumoai/experimental/rfm/infer/time_col.py +61 -0
  79. kumoai/experimental/rfm/infer/timestamp.py +41 -0
  80. kumoai/experimental/rfm/pquery/__init__.py +7 -0
  81. kumoai/experimental/rfm/pquery/executor.py +102 -0
  82. kumoai/experimental/rfm/pquery/pandas_executor.py +530 -0
  83. kumoai/experimental/rfm/relbench.py +76 -0
  84. kumoai/experimental/rfm/rfm.py +1184 -0
  85. kumoai/experimental/rfm/sagemaker.py +138 -0
  86. kumoai/experimental/rfm/task_table.py +231 -0
  87. kumoai/formatting.py +30 -0
  88. kumoai/futures.py +99 -0
  89. kumoai/graph/__init__.py +12 -0
  90. kumoai/graph/column.py +106 -0
  91. kumoai/graph/graph.py +948 -0
  92. kumoai/graph/table.py +838 -0
  93. kumoai/jobs.py +80 -0
  94. kumoai/kumolib.cpython-310-x86_64-linux-gnu.so +0 -0
  95. kumoai/mixin.py +28 -0
  96. kumoai/pquery/__init__.py +25 -0
  97. kumoai/pquery/prediction_table.py +287 -0
  98. kumoai/pquery/predictive_query.py +641 -0
  99. kumoai/pquery/training_table.py +424 -0
  100. kumoai/spcs.py +121 -0
  101. kumoai/testing/__init__.py +8 -0
  102. kumoai/testing/decorators.py +57 -0
  103. kumoai/testing/snow.py +50 -0
  104. kumoai/trainer/__init__.py +42 -0
  105. kumoai/trainer/baseline_trainer.py +93 -0
  106. kumoai/trainer/config.py +2 -0
  107. kumoai/trainer/distilled_trainer.py +175 -0
  108. kumoai/trainer/job.py +1192 -0
  109. kumoai/trainer/online_serving.py +258 -0
  110. kumoai/trainer/trainer.py +475 -0
  111. kumoai/trainer/util.py +103 -0
  112. kumoai/utils/__init__.py +11 -0
  113. kumoai/utils/datasets.py +83 -0
  114. kumoai/utils/display.py +51 -0
  115. kumoai/utils/forecasting.py +209 -0
  116. kumoai/utils/progress_logger.py +343 -0
  117. kumoai/utils/sql.py +3 -0
  118. kumoai-2.14.0.dev202601011731.dist-info/METADATA +71 -0
  119. kumoai-2.14.0.dev202601011731.dist-info/RECORD +122 -0
  120. kumoai-2.14.0.dev202601011731.dist-info/WHEEL +6 -0
  121. kumoai-2.14.0.dev202601011731.dist-info/licenses/LICENSE +9 -0
  122. kumoai-2.14.0.dev202601011731.dist-info/top_level.txt +1 -0
@@ -0,0 +1,641 @@
1
+ import logging
2
+ from typing import List, Literal, Mapping, Optional, Tuple, Union, overload
3
+
4
+ from kumoapi.jobs import (
5
+ GeneratePredictionTableRequest,
6
+ GenerateTrainTableRequest,
7
+ )
8
+ from kumoapi.model_plan import (
9
+ InferredType,
10
+ PredictionTableGenerationPlan,
11
+ RunMode,
12
+ SuggestModelPlanRequest,
13
+ TrainingTableGenerationPlan,
14
+ )
15
+ from kumoapi.pquery import PQueryResource
16
+ from kumoapi.task import TaskType
17
+ from kumoapi.train import TrainingTableSpec
18
+ from typing_extensions import Self
19
+
20
+ from kumoai import global_state
21
+ from kumoai.client.jobs import (
22
+ GeneratePredictionTableJobID,
23
+ GenerateTrainTableJobID,
24
+ TrainingJobAPI,
25
+ )
26
+ from kumoai.graph import Graph
27
+ from kumoai.pquery.prediction_table import PredictionTable, PredictionTableJob
28
+ from kumoai.pquery.training_table import TrainingTable, TrainingTableJob
29
+ from kumoai.trainer import (
30
+ BaselineTrainer,
31
+ ModelPlan,
32
+ Trainer,
33
+ TrainingJob,
34
+ TrainingJobResult,
35
+ )
36
+ from kumoai.trainer.job import BaselineJob, BaselineJobResult
37
+
38
# Module-level logger, named after this module's import path.
+ logger = logging.getLogger(__name__)
39
+
40
# Type alias: predictive query identifiers are plain strings (this is the
# declared return type of `PredictiveQuery.save`).
+ PredictiveQueryID = str
41
+
42
+
43
+ class PredictiveQuery:
44
+ r"""The Kumo predictive query is a declarative syntax for describing a
45
+ machine learning task. Predictive queries are written using the predictive
46
+ query language (PQL), a concise SQL-like syntax that allows you to define a
47
+ model for a new business problem.
48
+
49
+ A predictive query object can be created from a
50
+ :class:`~kumoai.graph.Graph` and a query string. For information on the
51
+ construction of a query string, please visit the Kumo
52
+ `documentation <https://docs.kumo.ai/docs/pquery-structure/>`__.
53
+
54
+ .. code-block:: python
55
+
56
+ import kumoai
57
+
58
+ # See `Graph` documentation for more information:
59
+ graph = kumoai.Graph(...)
60
+
61
+ # Create a predictive query representing a machine learning problem
62
+ # over this Graph:
63
+ pquery = kumoai.PredictiveQuery(
64
+ graph=graph,
65
+ query=(
66
+ "PREDICT MAX(transaction.Quantity, 0, 30) "
67
+ "FOR EACH customer.CustomerID"
68
+ ),
69
+ )
70
+
71
+ # Validate the predictive query configuration, for syntax and
72
+ # correctness:
73
+ pquery.validate(verbose=True)
74
+
75
+ # Get the machine learning task type corresponding to this predictive
76
+ # query (e.g. binary classification, regression, link prediction, etc.)
77
+ print(pquery.get_task_type())
78
+
79
+ # Suggest a training table generation plan and use it to generate a
80
+ # training table from this query, to be used in `Trainer.fit`:
81
+ training_table_plan = pquery.suggest_training_table_plan()
82
+ training_table = pquery.generate_training_table(training_table_plan)
83
+
84
+ # Suggest a prediction table generation plan and use it to generate a
85
+ # prediction table from this query, to be used in `Trainer.predict`:
86
+ pred_table_plan = pquery.suggest_prediction_table_plan()
87
+ pred_table = pquery.generate_prediction_table(pred_table_plan)
88
+
89
+ Args:
90
+ graph: The :class:`~kumoai.graph.Graph` object which the predictive
91
+ query is defined over.
92
+ query: A string representation of the predictive query.
93
+ """
94
def __init__(
    self,
    graph: Graph,
    query: str,
) -> None:
    r"""Binds a predictive query string to the graph it is defined over.

    Args:
        graph: The :class:`~kumoai.graph.Graph` the query is defined over.
        query: A string representation of the predictive query.
    """
    self.graph, self.query = graph, query

    # The most recently generated tables (or their still-running jobs).
    # Both start out empty and are only exposed read-only through the
    # `train_table` and `prediction_table` properties; any advanced
    # configuration must be done directly via `Trainer`:
    self._train_table: Optional[Union[
        TrainingTable,
        TrainingTableJob,
    ]] = None
    self._prediction_table: Optional[Union[
        PredictionTable,
        PredictionTableJob,
    ]] = None
111
+
112
+ # Metadata ################################################################
113
+
114
@property
def id(self) -> str:
    r"""The unique ID of this predictive query.

    Reading this property calls :meth:`save` without a name and returns
    the ID it produces, so the query is validated and persisted as a side
    effect. Two queries that differ either in their syntax or in their
    graph will have different ids.
    """
    query_id = self.save()
    return query_id
122
+
123
@property
def train_table(self) -> Union[TrainingTable, TrainingTableJob]:
    r"""Returns the training table that was last generated by this
    predictive query.

    The result is a :class:`~kumoai.pquery.TrainingTable` when the table
    was generated while waiting for completion, and a
    :class:`~kumoai.pquery.TrainingTableJob` when it was generated with
    ``non_blocking=True``.

    Raises:
        ValueError: If this predictive query has not yet generated a
            training table.
    """
    # Compare against `None` explicitly instead of relying on truthiness:
    # an existing table/job object must never be reported as "missing"
    # merely because it evaluates as falsy. This also matches the check
    # performed in `fit`.
    if self._train_table is None:
        raise ValueError(
            "This predictive query has not yet generated a training "
            "table. Please call `generate_training_table` to generate "
            "a training table before proceeding.")

    return self._train_table
142
+
143
@property
def prediction_table(self) -> Union[PredictionTable, PredictionTableJob]:
    r"""Returns the prediction table that was last generated by this
    predictive query.

    The result is a :class:`~kumoai.pquery.PredictionTable` when the table
    was generated while waiting for completion, and a
    :class:`~kumoai.pquery.PredictionTableJob` when it was generated with
    ``non_blocking=True``.

    Raises:
        ValueError: If this predictive query has not yet generated a
            prediction table.
    """
    # Compare against `None` explicitly instead of relying on truthiness:
    # an existing table/job object must never be reported as "missing"
    # merely because it evaluates as falsy.
    if self._prediction_table is None:
        raise ValueError(
            "This predictive query has not yet generated a prediction "
            "table. Please call `generate_prediction_table` to generate a "
            "prediction table before proceeding.")

    return self._prediction_table
162
+
163
def get_task_type(self) -> TaskType:
    r"""Returns the machine learning task type this predictive query
    translates to in the Kumo platform. For more information about
    possible task types, please visit the Kumo `documentation
    <https://docs.kumo.ai/docs/task-types/>`__.

    The query is validated first (which also saves the graph); the task
    type is then inferred from the query string and the graph ID.

    Raises:
        ValueError: If the query fails validation.
    """
    try:
        self.validate(verbose=False)
    except ValueError as err:
        message = (
            f"Predictive query {self.query} is improperly configured, so "
            f"a task type cannot be obtained. Please ensure your query "
            f"has a valid configuration before proceeding. You can use "
            f"the `validate` method to verify the validity of your query.")
        raise ValueError(message) from err

    # The API returns a pair; only its first element is the task type.
    inferred_type, _ = global_state.client.pquery_api.infer_task_type(
        pquery_string=self.query,
        graph_id=self.graph.id,
    )
    return inferred_type
182
+
183
def validate(self, verbose: bool = True) -> Self:
    r"""Checks that this predictive query is formulated correctly in
    Kumo's Predictive Query Language and that it defines a suitable
    predictive problem on this :class:`~kumoai.graph.Graph`.

    Args:
        verbose: Whether to log non-error output of this validation.

    Returns:
        This predictive query object, so that calls can be chained.

    Raises:
        ValueError:
            if validation fails.

    Example:
        >>> import kumoai
        >>> query = kumoai.PredictiveQuery(...)  # doctest: +SKIP
        >>> query.validate()  # doctest: +SKIP
    """
    # Saving the graph yields a valid graph ID and validates the graph:
    self.graph.save()

    run_validation = global_state.client.pquery_api.validate
    response = run_validation(self._to_api_pquery_resource())

    if not response.ok:
        raise ValueError(response.error_message())

    if not verbose:
        return self

    # Nothing to report means a clean pass; anything else is surfaced as
    # a warning:
    if response.empty():
        logger.info("Query %s is configured correctly.", self.query)
    else:
        logger.warning(response.message())
    return self
216
+
217
+ # Persistence #############################################################
218
+
219
def _to_api_pquery_resource(
    self,
    name: Optional[str] = None,
) -> PQueryResource:
    r"""Builds the API resource describing this predictive query.

    Args:
        name: Optional name to store on the resource.

    Returns:
        A :class:`PQueryResource` carrying the name, the query string, the
        API definition of the graph, and an empty description.
    """
    resource = PQueryResource(
        name=name,
        query_string=self.query,
        graph=self.graph._to_api_graph_definition(),
        desc="",
    )
    return resource
229
+
230
def save(self, name: Optional[str] = None) -> PredictiveQueryID:
    r"""Validates this predictive query and saves it to Kumo, returning
    the identifier produced by the create call. If a name is provided, the
    query is saved as a named, re-usable template.

    Args:
        name: Optional name for the template. If provided, saves the
            query as a named template. If the name already exists,
            that template will be overwritten (a warning is logged).

    Raises:
        ValueError: If the query fails validation.

    Example:
        >>> import kumoai
        >>> query = kumoai.PredictiveQuery(...)  # doctest: +SKIP
        >>> query.save()  # doctest: +SKIP
        pquery-xxx
        >>> query.save("my_template")  # doctest: +SKIP
        my_template
    """
    try:
        self.validate(verbose=False)
    except ValueError as err:
        message = (
            f"Predictive query {self.query} is improperly configured, so "
            f"it cannot be saved. Please ensure your query "
            f"has a valid configuration before proceeding. You can use "
            f"the `validate` method to verify the validity of your query.")
        raise ValueError(message) from err

    # Saving under an existing template name replaces that template, so
    # tell the user what is about to be overwritten:
    existing = (None if name is None else
                global_state.client.pquery_api.get_if_exists(name))
    if existing is not None:
        logger.warning(
            ("Predictive query template %s already exists, with "
             "query string %s. This template will be overridden with "
             "configuration %s."), name, existing.query_string, self.query)

    self.graph.save()
    return global_state.client.pquery_api.create(
        pquery=self._to_api_pquery_resource(name))
270
+
271
@classmethod
def load(cls, pq_id_or_template: str) -> 'PredictiveQuery':
    r"""Loads a predictive query from either a predictive query ID or a
    named template.

    Args:
        pq_id_or_template: The predictive query ID or template name to
            look up.

    Returns:
        A new object built from the stored query string and the graph
        reconstructed from the stored graph definition.

    Raises:
        ValueError: If no matching predictive query was found.
    """
    resource = global_state.client.pquery_api.get_if_exists(
        pq_id_or_template)
    if resource:
        return cls(
            graph=Graph._from_api_graph_definition(resource.graph),
            query=resource.query_string,
        )
    raise ValueError(
        f"Predictive query {pq_id_or_template} was not found.")
287
+
288
@classmethod
def load_from_training_job(cls, training_job_id: str) -> 'PredictiveQuery':
    r"""Loads the predictive query that a training job was configured
    with. The job is only fetched to read its configuration, so its
    status is not inspected.

    Args:
        training_job_id: The ID of the training job to look up.

    Returns:
        The predictive query referenced by the job's ``pquery_id``, loaded
        via :meth:`load`.

    Raises:
        ValueError: If the referenced predictive query was not found
            (raised by :meth:`load`).
    """
    train_api: TrainingJobAPI = global_state.client.training_job_api
    job = train_api.get(training_job_id)
    id_or_name = job.config.pquery_id
    # Dispatch through `cls` (not the hard-coded base class) so that
    # subclasses get an instance of their own type, as with `load`:
    return cls.load(pq_id_or_template=id_or_name)
299
+
300
+ # Training & Prediction Table Generation ##################################
301
+
302
@overload
def generate_training_table(
    self,
    plan: Optional[TrainingTableGenerationPlan] = None,
    *,
    custom_tags: Mapping[str, str] = ...,
) -> TrainingTable:
    ...

@overload
def generate_training_table(
    self,
    plan: Optional[TrainingTableGenerationPlan] = None,
    *,
    non_blocking: Literal[False],
    custom_tags: Mapping[str, str] = ...,
) -> TrainingTable:
    ...

@overload
def generate_training_table(
    self,
    plan: Optional[TrainingTableGenerationPlan] = None,
    *,
    non_blocking: Literal[True],
    custom_tags: Mapping[str, str] = ...,
) -> TrainingTableJob:
    ...

@overload
def generate_training_table(
    self,
    plan: Optional[TrainingTableGenerationPlan] = None,
    *,
    non_blocking: bool,
    custom_tags: Mapping[str, str] = ...,
) -> Union[TrainingTable, TrainingTableJob]:
    ...

def generate_training_table(
    self,
    plan: Optional[TrainingTableGenerationPlan] = None,
    *,
    non_blocking: bool = False,
    custom_tags: Mapping[str, str] = {},
) -> Union[TrainingTable, TrainingTableJob]:
    r"""Generates a training table from the specified :attr:`query`
    string.

    The query is saved first (which validates it), and the resulting
    table or job is remembered and exposed via :attr:`train_table`.

    Args:
        plan: A specification of the parameters for training table
            generation. If not provided, the plan returned by
            ``suggest_training_table_plan()`` (called with its default
            arguments) is used.
        non_blocking: Whether this operation should return immediately
            after launching the training table generation job, or await
            completion of the generated training table.
        custom_tags: Additional, customer defined k-v tags to be associated
            with the job to be launched. Job tags are useful for grouping
            and searching jobs. The mapping is copied and never modified.

    Returns:
        Union[TrainingTable, TrainingTableJob]:
            If ``non_blocking=False``, returns a training table object. If
            ``non_blocking=True``, returns a training table future object.

    Raises:
        ValueError: If the query fails validation while being saved.
    """
    pq_id = self.save()

    # TODO(manan): improve this...
    if not plan:
        plan = self.suggest_training_table_plan()

    train_table_job_api = global_state.client.generate_train_table_job_api
    job_id: GenerateTrainTableJobID = train_table_job_api.create(
        GenerateTrainTableRequest(
            dict(custom_tags),  # copy, so the caller's mapping is untouched
            pq_id,
            plan,
            None,
        ))

    # Remember the job right away, so that it stays reachable through
    # `train_table` even if waiting for completion below fails:
    self._train_table = TrainingTableJob(job_id=job_id)
    if non_blocking:
        return self._train_table
    self._train_table = self._train_table.attach()
    return self._train_table
384
+
385
@overload
def generate_prediction_table(
    self,
    plan: Optional[PredictionTableGenerationPlan] = None,
    *,
    custom_tags: Mapping[str, str] = ...,
) -> PredictionTable:
    ...

@overload
def generate_prediction_table(
    self,
    plan: Optional[PredictionTableGenerationPlan] = None,
    *,
    non_blocking: Literal[False],
    custom_tags: Mapping[str, str] = ...,
) -> PredictionTable:
    ...

@overload
def generate_prediction_table(
    self,
    plan: Optional[PredictionTableGenerationPlan] = None,
    *,
    non_blocking: Literal[True],
    custom_tags: Mapping[str, str] = ...,
) -> PredictionTableJob:
    ...

@overload
def generate_prediction_table(
    self,
    plan: Optional[PredictionTableGenerationPlan] = None,
    *,
    non_blocking: bool,
    custom_tags: Mapping[str, str] = ...,
) -> Union[PredictionTable, PredictionTableJob]:
    ...

def generate_prediction_table(
    self,
    plan: Optional[PredictionTableGenerationPlan] = None,
    *,
    non_blocking: bool = False,
    custom_tags: Mapping[str, str] = {},
) -> Union[PredictionTable, PredictionTableJob]:
    r"""Generates a prediction table from the predictive query
    :attr:`query` string.

    The query is saved first (which validates it), and the resulting
    table or job is remembered and exposed via :attr:`prediction_table`.

    Args:
        plan: A specification of the parameters for prediction table
            generation. If not provided, the plan returned by
            ``suggest_prediction_table_plan()`` is used.
        non_blocking: Whether this operation should return immediately
            after launching the prediction table generation job, or await
            completion of the generated prediction table.
        custom_tags: Additional, customer defined k-v tags to be associated
            with the job to be launched. Job tags are useful for grouping
            and searching jobs. The mapping is copied and never modified.

    Returns:
        Union[PredictionTable, PredictionTableJob]:
            If ``non_blocking=False``, returns a prediction table object.
            If ``non_blocking=True``, returns a prediction table future
            object.

    Raises:
        ValueError: If the query fails validation while being saved.
    """
    pq_id = self.save()

    if not plan:
        plan = self.suggest_prediction_table_plan()

    bp_table_api = global_state.client.generate_prediction_table_job_api
    job_id: GeneratePredictionTableJobID = bp_table_api.create(
        GeneratePredictionTableRequest(
            dict(custom_tags),  # copy, so the caller's mapping is untouched
            pq_id,
            plan,
            None,
        ))

    # Remember the job right away, so that it stays reachable through
    # `prediction_table` even if waiting for completion below fails:
    self._prediction_table = PredictionTableJob(job_id=job_id)
    if non_blocking:
        return self._prediction_table
    self._prediction_table = self._prediction_table.result()
    return self._prediction_table
467
+
468
+ # Training & Prediction ###################################################
469
+
470
def suggest_training_table_plan(
    self,
    run_mode: RunMode = RunMode.FAST,
) -> TrainingTableGenerationPlan:
    r"""Asks Kumo for a suggested training table generation plan for this
    predictive query and graph. The returned plan can be adjusted to
    alter the approach Kumo uses to generate the training table.

    The graph is saved before the request is sent.

    Args:
        run_mode: A representation of how quickly you would like your
            predictive query to complete. Faster run modes correspond to
            lower training times, at the cost of potentially lower
            performance.
    """
    self.graph.save()
    request = SuggestModelPlanRequest(
        query_string=self.query,
        graph_id=self.graph.id,
        run_mode=run_mode,
    )
    pquery_api = global_state.client.pquery_api
    return pquery_api.suggest_training_table_plan(request)
492
+
493
def suggest_prediction_table_plan(self) -> PredictionTableGenerationPlan:
    r"""Returns a suggested prediction table generation plan. The returned
    plan can be adjusted to alter the approach Kumo uses to generate the
    prediction table for your predictive query.

    The plan is built locally; its anchor time is set to
    ``InferredType.VALUE``.
    """
    suggested_plan = PredictionTableGenerationPlan(
        anchor_time=InferredType.VALUE)
    return suggested_plan
500
+
501
def suggest_model_plan(
    self,
    run_mode: RunMode = RunMode.FAST,
    train_table_spec: Optional[TrainingTableSpec] = None,
) -> ModelPlan:
    r"""Asks Kumo for a suggested modeling plan for this predictive query
    and graph. The returned plan can be adjusted to alter the approach
    Kumo uses to train your machine learning model.

    The graph is saved before the request is sent.

    Args:
        run_mode: A representation of how quickly you would like your
            predictive query to complete. Faster run modes correspond to
            lower training times, at the cost of potentially lower
            performance.
        train_table_spec: Needed if the original train table has been
            modified by adding a weight column.
    """
    self.graph.save()
    request = SuggestModelPlanRequest(
        query_string=self.query,
        graph_id=self.graph.id,
        run_mode=run_mode,
        train_table_spec=train_table_spec,
    )
    pquery_api = global_state.client.pquery_api
    return pquery_api.suggest_model_plan(request)
526
+
527
@overload
def fit(
    self,
    training_table_plan: Optional[TrainingTableGenerationPlan] = None,
    model_plan: Optional[ModelPlan] = None,
) -> Tuple[Trainer, TrainingJobResult]:
    ...

@overload
def fit(
    self,
    training_table_plan: Optional[TrainingTableGenerationPlan] = None,
    model_plan: Optional[ModelPlan] = None,
    *,
    non_blocking: Literal[False],
) -> Tuple[Trainer, TrainingJobResult]:
    ...

@overload
def fit(
    self,
    training_table_plan: Optional[TrainingTableGenerationPlan] = None,
    model_plan: Optional[ModelPlan] = None,
    *,
    non_blocking: Literal[True],
) -> Tuple[Trainer, TrainingJob]:
    ...

@overload
def fit(
    self,
    training_table_plan: Optional[TrainingTableGenerationPlan] = None,
    model_plan: Optional[ModelPlan] = None,
    *,
    non_blocking: bool,
) -> Tuple[Trainer, Union[TrainingJobResult, TrainingJob]]:
    ...

def fit(
    self,
    training_table_plan: Optional[TrainingTableGenerationPlan] = None,
    model_plan: Optional[ModelPlan] = None,
    *,
    non_blocking: bool = False,
) -> Tuple[Trainer, Union[TrainingJobResult, TrainingJob]]:
    r"""Trains a Kumo model on this predictive query, given optional
    additional specifications of the training table generation plan and
    the model plan.

    If this predictive query has already generated a training table (or
    launched a job for one), that table is re-used and
    ``training_table_plan`` is ignored with a warning. Otherwise, a
    training table generation job is launched without waiting for it.

    Args:
        training_table_plan: A specification of the parameters for training
            table generation. If not provided, the plan returned by
            ``suggest_training_table_plan()`` (called with its default
            arguments) is used. Only used when no training table has been
            generated yet.
        model_plan: A specification of the parameters for model training.
            If not provided, the plan returned by ``suggest_model_plan()``
            (called with its default arguments) is used.
        non_blocking: Whether this operation should return immediately
            after launching the training job, or await completion of the
            training job.

    Returns:
        Tuple[Trainer, Union[TrainingJobResult, TrainingJob]]:
            A tuple with two elements. The first element is the trainer
            object used to launch the training job. The second element
            is either a training job future object
            (:class:`TrainingJob`, if ``non_blocking=True``) or a training
            job result object (:class:`TrainingJobResult`, if
            ``non_blocking=False``).
    """
    # If we have already generated the training table, use it with Trainer:
    if self._train_table is None:
        # Nonblocking generate:
        self._train_table = self.generate_training_table(
            training_table_plan, non_blocking=True)
    elif training_table_plan is not None:
        # Do not drop the user's plan silently; say how to apply it:
        logger.warning(
            "A training table has already been generated for this "
            "predictive query, so the provided `training_table_plan` is "
            "ignored. Call `generate_training_table` with the new plan "
            "before `fit` to train on a regenerated training table.")

    # TODO(manan): what if `self._train_table` represents a failed job?
    model_plan = model_plan or self.suggest_model_plan()
    trainer = Trainer(model_plan)
    # `self._train_table` is guaranteed to be set at this point, so it is
    # passed directly rather than through the raising property:
    return (trainer,
            trainer.fit(self.graph, self._train_table,
                        non_blocking=non_blocking))
610
+
611
def generate_baseline(
    self,
    metrics: List[str],
    train_table: Union[TrainingTable, TrainingTableJob],
    *,
    non_blocking: bool = False,
) -> Union[BaselineJob, BaselineJobResult]:
    r"""Runs a baseline model on this predictive query's graph, given the
    metrics to evaluate and a training table.

    Args:
        metrics (List[str]): A list of metrics that the baseline model will
            be evaluated on.
        train_table (Union[TrainingTable, TrainingTableJob]): The
            :class:`~kumoai.pquery.TrainingTable`, or in-progress
            :class:`~kumoai.pquery.TrainingTableJob` that represents
            the training data produced by a
            :class:`~kumoai.pquery.PredictiveQuery` on :obj:`graph`.
        non_blocking (bool): Whether this operation should
            return immediately after launching the baseline job, or await
            completion of the baseline job. Defaults to False.

    Returns:
        Union[BaselineJob, BaselineJobResult]: whatever
        ``BaselineTrainer.run`` returns; presumably a
        :class:`BaselineJob` future if ``non_blocking=True`` and a
        :class:`BaselineJobResult` otherwise (confirm against
        ``BaselineTrainer``).
    """  # noqa
    runner = BaselineTrainer(metrics)
    outcome = runner.run(self.graph, train_table, non_blocking=non_blocking)
    return outcome