kumoai-2.13.0.dev202511131731-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kumoai might be problematic; see the package's registry page for more details.

Files changed (98):
  1. kumoai/__init__.py +294 -0
  2. kumoai/_logging.py +29 -0
  3. kumoai/_singleton.py +25 -0
  4. kumoai/_version.py +1 -0
  5. kumoai/artifact_export/__init__.py +9 -0
  6. kumoai/artifact_export/config.py +209 -0
  7. kumoai/artifact_export/job.py +108 -0
  8. kumoai/client/__init__.py +5 -0
  9. kumoai/client/client.py +221 -0
  10. kumoai/client/connector.py +110 -0
  11. kumoai/client/endpoints.py +150 -0
  12. kumoai/client/graph.py +120 -0
  13. kumoai/client/jobs.py +447 -0
  14. kumoai/client/online.py +78 -0
  15. kumoai/client/pquery.py +203 -0
  16. kumoai/client/rfm.py +112 -0
  17. kumoai/client/source_table.py +53 -0
  18. kumoai/client/table.py +101 -0
  19. kumoai/client/utils.py +130 -0
  20. kumoai/codegen/__init__.py +19 -0
  21. kumoai/codegen/cli.py +100 -0
  22. kumoai/codegen/context.py +16 -0
  23. kumoai/codegen/edits.py +473 -0
  24. kumoai/codegen/exceptions.py +10 -0
  25. kumoai/codegen/generate.py +222 -0
  26. kumoai/codegen/handlers/__init__.py +4 -0
  27. kumoai/codegen/handlers/connector.py +118 -0
  28. kumoai/codegen/handlers/graph.py +71 -0
  29. kumoai/codegen/handlers/pquery.py +62 -0
  30. kumoai/codegen/handlers/table.py +109 -0
  31. kumoai/codegen/handlers/utils.py +42 -0
  32. kumoai/codegen/identity.py +114 -0
  33. kumoai/codegen/loader.py +93 -0
  34. kumoai/codegen/naming.py +94 -0
  35. kumoai/codegen/registry.py +121 -0
  36. kumoai/connector/__init__.py +31 -0
  37. kumoai/connector/base.py +153 -0
  38. kumoai/connector/bigquery_connector.py +200 -0
  39. kumoai/connector/databricks_connector.py +213 -0
  40. kumoai/connector/file_upload_connector.py +189 -0
  41. kumoai/connector/glue_connector.py +150 -0
  42. kumoai/connector/s3_connector.py +278 -0
  43. kumoai/connector/snowflake_connector.py +252 -0
  44. kumoai/connector/source_table.py +471 -0
  45. kumoai/connector/utils.py +1775 -0
  46. kumoai/databricks.py +14 -0
  47. kumoai/encoder/__init__.py +4 -0
  48. kumoai/exceptions.py +26 -0
  49. kumoai/experimental/__init__.py +0 -0
  50. kumoai/experimental/rfm/__init__.py +67 -0
  51. kumoai/experimental/rfm/authenticate.py +433 -0
  52. kumoai/experimental/rfm/infer/__init__.py +11 -0
  53. kumoai/experimental/rfm/infer/categorical.py +40 -0
  54. kumoai/experimental/rfm/infer/id.py +46 -0
  55. kumoai/experimental/rfm/infer/multicategorical.py +48 -0
  56. kumoai/experimental/rfm/infer/timestamp.py +41 -0
  57. kumoai/experimental/rfm/local_graph.py +810 -0
  58. kumoai/experimental/rfm/local_graph_sampler.py +184 -0
  59. kumoai/experimental/rfm/local_graph_store.py +359 -0
  60. kumoai/experimental/rfm/local_pquery_driver.py +689 -0
  61. kumoai/experimental/rfm/local_table.py +545 -0
  62. kumoai/experimental/rfm/pquery/__init__.py +7 -0
  63. kumoai/experimental/rfm/pquery/executor.py +102 -0
  64. kumoai/experimental/rfm/pquery/pandas_executor.py +532 -0
  65. kumoai/experimental/rfm/rfm.py +1130 -0
  66. kumoai/experimental/rfm/utils.py +344 -0
  67. kumoai/formatting.py +30 -0
  68. kumoai/futures.py +99 -0
  69. kumoai/graph/__init__.py +12 -0
  70. kumoai/graph/column.py +106 -0
  71. kumoai/graph/graph.py +948 -0
  72. kumoai/graph/table.py +838 -0
  73. kumoai/jobs.py +80 -0
  74. kumoai/kumolib.cpython-313-x86_64-linux-gnu.so +0 -0
  75. kumoai/mixin.py +28 -0
  76. kumoai/pquery/__init__.py +25 -0
  77. kumoai/pquery/prediction_table.py +287 -0
  78. kumoai/pquery/predictive_query.py +637 -0
  79. kumoai/pquery/training_table.py +424 -0
  80. kumoai/spcs.py +123 -0
  81. kumoai/testing/__init__.py +8 -0
  82. kumoai/testing/decorators.py +57 -0
  83. kumoai/trainer/__init__.py +42 -0
  84. kumoai/trainer/baseline_trainer.py +93 -0
  85. kumoai/trainer/config.py +2 -0
  86. kumoai/trainer/job.py +1192 -0
  87. kumoai/trainer/online_serving.py +258 -0
  88. kumoai/trainer/trainer.py +475 -0
  89. kumoai/trainer/util.py +103 -0
  90. kumoai/utils/__init__.py +10 -0
  91. kumoai/utils/datasets.py +83 -0
  92. kumoai/utils/forecasting.py +209 -0
  93. kumoai/utils/progress_logger.py +177 -0
  94. kumoai-2.13.0.dev202511131731.dist-info/METADATA +60 -0
  95. kumoai-2.13.0.dev202511131731.dist-info/RECORD +98 -0
  96. kumoai-2.13.0.dev202511131731.dist-info/WHEEL +6 -0
  97. kumoai-2.13.0.dev202511131731.dist-info/licenses/LICENSE +9 -0
  98. kumoai-2.13.0.dev202511131731.dist-info/top_level.txt +1 -0
@@ -0,0 +1,637 @@
1
+ import logging
2
+ from typing import List, Literal, Mapping, Optional, Tuple, Union, overload
3
+
4
+ from kumoapi.jobs import (
5
+ GeneratePredictionTableRequest,
6
+ GenerateTrainTableRequest,
7
+ )
8
+ from kumoapi.model_plan import (
9
+ InferredType,
10
+ PredictionTableGenerationPlan,
11
+ RunMode,
12
+ SuggestModelPlanRequest,
13
+ TrainingTableGenerationPlan,
14
+ )
15
+ from kumoapi.pquery import PQueryResource
16
+ from kumoapi.task import TaskType
17
+ from kumoapi.train import TrainingTableSpec
18
+ from typing_extensions import Self
19
+
20
+ from kumoai import global_state
21
+ from kumoai.client.jobs import (
22
+ GeneratePredictionTableJobID,
23
+ GenerateTrainTableJobID,
24
+ TrainingJobAPI,
25
+ )
26
+ from kumoai.graph import Graph
27
+ from kumoai.pquery.prediction_table import PredictionTable, PredictionTableJob
28
+ from kumoai.pquery.training_table import TrainingTable, TrainingTableJob
29
+ from kumoai.trainer import (
30
+ BaselineTrainer,
31
+ ModelPlan,
32
+ Trainer,
33
+ TrainingJob,
34
+ TrainingJobResult,
35
+ )
36
+ from kumoai.trainer.job import BaselineJob, BaselineJobResult
37
+
38
# Module-level logger for this file (used for validation and template
# overwrite messages).
logger = logging.getLogger(name=__name__)

# Identifier type returned when a predictive query is saved to Kumo; it is
# a plain string alias.
PredictiveQueryID = str
41
+
42
+
43
class PredictiveQuery:
    r"""The Kumo predictive query is a declarative syntax for describing a
    machine learning task. Predictive queries are written using the predictive
    query language (PQL), a concise SQL-like syntax that allows you to define a
    model for a new business problem.

    A predictive query object can be created from a
    :class:`~kumoai.graph.Graph` and a query string. For information on the
    construction of a query string, please visit the Kumo
    `documentation <https://docs.kumo.ai/docs/pquery-structure/>`__.

    .. code-block:: python

        import kumoai

        # See `Graph` documentation for more information:
        graph = kumoai.Graph(...)

        # Create a predictive query representing a machine learning problem
        # over this Graph:
        pquery = kumoai.PredictiveQuery(
            graph=graph,
            query=(
                "PREDICT MAX(transaction.Quantity, 0, 30) "
                "FOR EACH customer.CustomerID"
            ),
        )

        # Validate the predictive query configuration, for syntax and
        # correctness:
        pquery.validate(verbose=True)

        # Get the machine learning task type corresponding to this predictive
        # query (e.g. binary classification, regression, link prediction, etc.)
        print(pquery.get_task_type())

        # Suggest a training table generation plan and use it to generate a
        # training table from this query, to be used in `Trainer.fit`:
        training_table_plan = pquery.suggest_training_table_plan()
        training_table = pquery.generate_training_table(training_table_plan)

        # Suggest a prediction table generation plan and use it to generate a
        # prediction table from this query, to be used in `Trainer.predict`:
        pred_table_plan = pquery.suggest_prediction_table_plan()
        pred_table = pquery.generate_prediction_table(pred_table_plan)

    Args:
        graph: The :class:`~kumoai.graph.Graph` object which the predictive
            query is defined over.
        query: A string representation of the predictive query.
    """
    def __init__(
        self,
        graph: Graph,
        query: str,
    ) -> None:
        self.graph = graph
        self.query = query

        # The most recently generated training / prediction table, or the
        # job producing it. These are set by `generate_training_table` and
        # `generate_prediction_table` and are exposed read-only through the
        # `train_table` and `prediction_table` properties; `fit` reuses an
        # existing training table. Any advanced training configuration must
        # be done directly via `Trainer`:
        self._train_table: Optional[Union[TrainingTable,
                                          TrainingTableJob]] = None
        self._prediction_table: Optional[Union[PredictionTable,
                                               PredictionTableJob]] = None

    # Metadata ################################################################

    @property
    def id(self) -> str:
        r"""Returns the unique ID for this predictive query, determined from
        its schema and the schema of its associated graph. Two queries that
        differ either in their syntax or in their graph will have different
        ids.

        .. note::
            Accessing this property calls :meth:`save`, which validates the
            query and saves it (and its graph) to Kumo, so it performs
            remote calls and raises :class:`ValueError` on an invalid query.
        """
        return self.save()

    @property
    def train_table(self) -> Union[TrainingTable, TrainingTableJob]:
        r"""Returns the training table that was last generated by this
        predictive query. If the predictive query has not yet generated a
        training table, raises a :class:`ValueError`.

        Note that the training table may be of type
        :class:`~kumoai.pquery.TrainingTable` or
        :class:`~kumoai.pquery.TrainingTableJob`, depending on whether the
        training table was generated with or without waiting for its
        completion, respectively.

        Raises:
            ValueError: if no training table has been generated yet.
        """
        if self._train_table is None:
            raise ValueError(
                "This predictive query has not yet generated a training "
                "table. Please call `generate_training_table` to generate "
                "a training table before proceeding.")

        return self._train_table

    @property
    def prediction_table(self) -> Union[PredictionTable, PredictionTableJob]:
        r"""Returns the prediction table that was last generated by this
        predictive query. If the predictive query has not yet generated a
        prediction table, raises a :class:`ValueError`.

        Note that the prediction table may be of type
        :class:`~kumoai.pquery.PredictionTable` or
        :class:`~kumoai.pquery.PredictionTableJob`, depending on whether the
        prediction table was generated with or without waiting for its
        completion, respectively.

        Raises:
            ValueError: if no prediction table has been generated yet.
        """
        if self._prediction_table is None:
            raise ValueError(
                "This predictive query has not yet generated a prediction "
                "table. Please call `generate_prediction_table` to generate a "
                "prediction table before proceeding.")

        return self._prediction_table

    def get_task_type(self) -> TaskType:
        r"""Returns the task type of this predictive query. The task type of
        the query corresponds to the machine learning problem that this query
        translates to in the Kumo platform; for more information about possible
        task types, please visit the Kumo `documentation
        <https://docs.kumo.ai/docs/task-types/>`__.

        The query is validated first (without logging), then the task type is
        inferred remotely from the query string and the graph's ID.

        Raises:
            ValueError: if the query fails validation.
        """
        try:
            self.validate(verbose=False)
        except ValueError as e:
            raise ValueError(
                f"Predictive query {self.query} is improperly configured, so "
                "a task type cannot be obtained. Please ensure your query "
                "has a valid configuration before proceeding. You can use "
                "the `validate` method to verify the validity of your query."
            ) from e
        # The second element of the returned pair is not needed here.
        task_type, _ = global_state.client.pquery_api.infer_task_type(
            pquery_string=self.query, graph_id=self.graph.id)
        return task_type

    def validate(self, verbose: bool = True) -> Self:
        r"""Validates the syntax of this predictive query, ensuring that
        the query is formulated correctly in Kumo's Predictive Query Language
        and that the query makes semantic sense (defines a suitable predictive
        problem) on this :class:`~kumoai.graph.Graph`.

        Args:
            verbose: Whether to log non-error output of this validation. If
                ``True``, a success message is logged at INFO level when the
                validation response is empty, and any warnings are logged at
                WARNING level otherwise.

        Returns:
            This predictive query object, to allow method chaining.

        Raises:
            ValueError:
                if validation fails.

        Example:
            >>> import kumoai
            >>> query = kumoai.PredictiveQuery(...)  # doctest: +SKIP
            >>> query = query.validate()  # doctest: +SKIP
        """
        self.graph.save()  # Need a valid graph ID; also validates graph.

        resp = global_state.client.pquery_api.validate(
            self._to_api_pquery_resource())

        if not resp.ok:
            raise ValueError(resp.error_message())
        if verbose:
            if resp.empty():
                logger.info("Query %s is configured correctly.", self.query)
            else:
                logger.warning(resp.message())

        return self

    # Persistence #############################################################

    def _to_api_pquery_resource(
        self,
        name: Optional[str] = None,
    ) -> PQueryResource:
        r"""Builds the API resource describing this query (query string plus
        the graph definition), optionally tagged with a template ``name``.
        """
        return PQueryResource(
            name=name,
            query_string=self.query,
            graph=self.graph._to_api_graph_definition(),
            desc="",
        )

    def save(self, name: Optional[str] = None) -> PredictiveQueryID:
        r"""Saves a predictive query to Kumo, returning a unique ID for this
        query. If a name is provided, saves it as a named, re-usable template.

        Args:
            name: Optional name for the template. If provided, saves the
                query as a named template. If the name already exists,
                that template will be overwritten (a warning is logged).

        Returns:
            The ID (or template name) returned by Kumo for the saved query.

        Raises:
            ValueError: if the query fails validation.

        Example:
            >>> import kumoai
            >>> query = kumoai.PredictiveQuery(...)  # doctest: +SKIP
            >>> query.save()  # doctest: +SKIP
            pquery-xxx
            >>> query.save("my_template")  # doctest: +SKIP
            my_template
        """
        try:
            self.validate(verbose=False)
        except ValueError as e:
            raise ValueError(
                f"Predictive query {self.query} is improperly configured, so "
                "it cannot be saved. Please ensure your query "
                "has a valid configuration before proceeding. You can use "
                "the `validate` method to verify the validity of your query."
            ) from e

        if name is not None:
            # Warn before silently replacing an existing named template.
            template_resource = global_state.client.pquery_api.get_if_exists(
                name)
            if template_resource is not None:
                template_string = template_resource.query_string
                logger.warning(
                    ("Predictive query template %s already exists, with "
                     "query string %s. This template will be overridden with "
                     "configuration %s."), name, template_string, self.query)

        self.graph.save()
        return global_state.client.pquery_api.create(
            pquery=self._to_api_pquery_resource(name))

    @classmethod
    def load(cls, pq_id_or_template: str) -> 'PredictiveQuery':
        r"""Loads a predictive query from either a predictive query ID or a
        named template. Returns a :class:`~kumoai.pquery.PredictiveQuery`
        object that contains the loaded query along with its associated graph,
        tables, etc.

        Raises:
            ValueError: if no predictive query with the given ID or template
                name exists.
        """
        api = global_state.client.pquery_api
        res = api.get_if_exists(pq_id_or_template)
        if res is None:
            raise ValueError(
                f"Predictive query {pq_id_or_template} was not found.")
        return cls(
            graph=Graph._from_api_graph_definition(res.graph),
            query=res.query_string,
        )

    @classmethod
    def load_from_training_job(cls, training_job_id: str) -> 'PredictiveQuery':
        r"""Loads a predictive query from a training job, regardless of the
        training job's status. Returns a
        :class:`~kumoai.pquery.PredictiveQuery` object that contains the loaded
        query along with its associated graph, tables, etc.
        """
        train_api: TrainingJobAPI = global_state.client.training_job_api
        job = train_api.get(training_job_id)
        id_or_name = job.config.pquery_id
        # Use `cls` so subclasses load instances of themselves.
        return cls.load(pq_id_or_template=id_or_name)

    # Training & Prediction Table Generation ##################################

    @overload
    def generate_training_table(
        self,
        plan: Optional[TrainingTableGenerationPlan] = None,
        *,
        custom_tags: Optional[Mapping[str, str]] = None,
    ) -> TrainingTable:
        pass

    @overload
    def generate_training_table(
        self,
        plan: Optional[TrainingTableGenerationPlan] = None,
        *,
        non_blocking: Literal[False],
        custom_tags: Optional[Mapping[str, str]] = None,
    ) -> TrainingTable:
        pass

    @overload
    def generate_training_table(
        self,
        plan: Optional[TrainingTableGenerationPlan] = None,
        *,
        non_blocking: Literal[True],
        custom_tags: Optional[Mapping[str, str]] = None,
    ) -> TrainingTableJob:
        pass

    @overload
    def generate_training_table(
        self,
        plan: Optional[TrainingTableGenerationPlan] = None,
        *,
        non_blocking: bool,
        custom_tags: Optional[Mapping[str, str]] = None,
    ) -> Union[TrainingTable, TrainingTableJob]:
        pass

    def generate_training_table(
        self,
        plan: Optional[TrainingTableGenerationPlan] = None,
        *,
        non_blocking: bool = False,
        custom_tags: Optional[Mapping[str, str]] = None,
    ) -> Union[TrainingTable, TrainingTableJob]:
        r"""Generates a training table from the specified :attr:`query`
        string.

        The query is saved first (which validates it), and the resulting
        table or job is remembered as :attr:`train_table`.

        Args:
            plan: A specification of the parameters for training table
                generation. If not provided, will use an intelligently
                generated default plan based on the query and graph. This plan
                is equivalent to the plan inferred with
                ``suggest_training_table_plan()`` (i.e. with its default
                ``run_mode=RunMode.FAST``).
            non_blocking: Whether this operation should return immediately
                after launching the training table generation job, or await
                completion of the generated training table.
            custom_tags: Additional, customer defined k-v tags to be associated
                with the job to be launched. Job tags are useful for grouping
                and searching jobs. Defaults to no tags.

        Returns:
            Union[TrainingTable, TrainingTableJob]:
                If ``non_blocking=False``, returns a training table object. If
                ``non_blocking=True``, returns a training table future object.

        Raises:
            ValueError: if the query fails validation.
        """
        pq_id = self.save()

        # TODO(manan): improve this...
        if plan is None:
            plan = self.suggest_training_table_plan()

        train_table_job_api = global_state.client.generate_train_table_job_api
        job_id: GenerateTrainTableJobID = train_table_job_api.create(
            GenerateTrainTableRequest(
                # Copy into a fresh dict so the caller's mapping is never
                # shared with the request object.
                dict(custom_tags or {}), pq_id, plan,
                graph_snapshot_id=self.graph.snapshot(
                    non_blocking=non_blocking)))

        self._train_table = TrainingTableJob(job_id=job_id)
        if non_blocking:
            return self._train_table
        self._train_table = self._train_table.attach()
        return self._train_table

    @overload
    def generate_prediction_table(
        self,
        plan: Optional[PredictionTableGenerationPlan] = None,
        *,
        custom_tags: Optional[Mapping[str, str]] = None,
    ) -> PredictionTable:
        pass

    @overload
    def generate_prediction_table(
        self,
        plan: Optional[PredictionTableGenerationPlan] = None,
        *,
        non_blocking: Literal[False],
        custom_tags: Optional[Mapping[str, str]] = None,
    ) -> PredictionTable:
        pass

    @overload
    def generate_prediction_table(
        self,
        plan: Optional[PredictionTableGenerationPlan] = None,
        *,
        non_blocking: Literal[True],
        custom_tags: Optional[Mapping[str, str]] = None,
    ) -> PredictionTableJob:
        pass

    @overload
    def generate_prediction_table(
        self,
        plan: Optional[PredictionTableGenerationPlan] = None,
        *,
        non_blocking: bool,
        custom_tags: Optional[Mapping[str, str]] = None,
    ) -> Union[PredictionTable, PredictionTableJob]:
        pass

    def generate_prediction_table(
        self,
        plan: Optional[PredictionTableGenerationPlan] = None,
        *,
        non_blocking: bool = False,
        custom_tags: Optional[Mapping[str, str]] = None,
    ) -> Union[PredictionTable, PredictionTableJob]:
        r"""Generates a prediction table from the predictive query
        :attr:`query` string.

        The query is saved first (which validates it), and the resulting
        table or job is remembered as :attr:`prediction_table`.

        Args:
            plan: A specification of the parameters for prediction table
                generation. If not provided, will use the default plan
                returned by ``suggest_prediction_table_plan()``.
            non_blocking: Whether this operation should return immediately
                after launching the prediction table generation job, or await
                completion of the generated prediction table.
            custom_tags: Additional, customer defined k-v tags to be associated
                with the job to be launched. Job tags are useful for grouping
                and searching jobs. Defaults to no tags.

        Returns:
            Union[PredictionTable, PredictionTableJob]:
                If ``non_blocking=False``, returns a prediction table object.
                If ``non_blocking=True``, returns a prediction table future
                object.

        Raises:
            ValueError: if the query fails validation.
        """
        pq_id = self.save()

        if plan is None:
            plan = self.suggest_prediction_table_plan()

        bp_table_api = global_state.client.generate_prediction_table_job_api
        job_id: GeneratePredictionTableJobID = bp_table_api.create(
            GeneratePredictionTableRequest(
                # Copy into a fresh dict so the caller's mapping is never
                # shared with the request object.
                dict(custom_tags or {}), pq_id, plan,
                graph_snapshot_id=self.graph.snapshot(
                    non_blocking=non_blocking)))

        self._prediction_table = PredictionTableJob(job_id=job_id)
        if non_blocking:
            return self._prediction_table
        self._prediction_table = self._prediction_table.result()
        return self._prediction_table

    # Training & Prediction ###################################################

    def suggest_training_table_plan(
        self,
        run_mode: RunMode = RunMode.FAST,
    ) -> TrainingTableGenerationPlan:
        r"""Suggests a training table generation plan given the predictive
        query and graph. This training table generation plan can be used to
        alter the approach Kumo uses to generate the training table for your
        predictive query.

        Args:
            run_mode: A representation of how quickly you would like your
                predictive query to complete. Faster run modes correspond to
                lower training times, at the cost of potentially lower
                performance.
        """
        self.graph.save()  # Ensures the graph has an ID to reference.
        req = SuggestModelPlanRequest(
            query_string=self.query,
            graph_id=self.graph.id,
            run_mode=run_mode,
        )
        return global_state.client.pquery_api.suggest_training_table_plan(req)

    def suggest_prediction_table_plan(self) -> PredictionTableGenerationPlan:
        r"""Suggests a prediction table generation plan given the predictive
        query and graph. This prediction table generation plan can be used to
        alter the approach Kumo uses to generate the prediction table for your
        predictive query.

        This does not contact the Kumo API; it returns a default plan with
        ``anchor_time=InferredType.VALUE``.
        """
        return PredictionTableGenerationPlan(anchor_time=InferredType.VALUE)

    def suggest_model_plan(
        self,
        run_mode: RunMode = RunMode.FAST,
        train_table_spec: Optional[TrainingTableSpec] = None,
    ) -> ModelPlan:
        r"""Suggests a modeling plan given the predictive query and graph. This
        model plan can be used to alter the approach Kumo uses to train your
        machine learning model.

        Args:
            run_mode: A representation of how quickly you would like your
                predictive query to complete. Faster run modes correspond to
                lower training times, at the cost of potentially lower
                performance.
            train_table_spec: Needed if the original train table has been
                modified by adding a weight column.
        """
        self.graph.save()  # Ensures the graph has an ID to reference.
        req = SuggestModelPlanRequest(
            query_string=self.query,
            graph_id=self.graph.id,
            run_mode=run_mode,
            train_table_spec=train_table_spec,
        )
        return global_state.client.pquery_api.suggest_model_plan(req)

    @overload
    def fit(
        self,
        training_table_plan: Optional[TrainingTableGenerationPlan] = None,
        model_plan: Optional[ModelPlan] = None,
    ) -> Tuple[Trainer, TrainingJobResult]:
        pass

    @overload
    def fit(
        self,
        training_table_plan: Optional[TrainingTableGenerationPlan] = None,
        model_plan: Optional[ModelPlan] = None,
        *,
        non_blocking: Literal[False],
    ) -> Tuple[Trainer, TrainingJobResult]:
        pass

    @overload
    def fit(
        self,
        training_table_plan: Optional[TrainingTableGenerationPlan] = None,
        model_plan: Optional[ModelPlan] = None,
        *,
        non_blocking: Literal[True],
    ) -> Tuple[Trainer, TrainingJob]:
        pass

    @overload
    def fit(
        self,
        training_table_plan: Optional[TrainingTableGenerationPlan] = None,
        model_plan: Optional[ModelPlan] = None,
        *,
        non_blocking: bool,
    ) -> Tuple[Trainer, Union[TrainingJobResult, TrainingJob]]:
        pass

    def fit(
        self,
        training_table_plan: Optional[TrainingTableGenerationPlan] = None,
        model_plan: Optional[ModelPlan] = None,
        *,
        non_blocking: bool = False,
    ) -> Tuple[Trainer, Union[TrainingJobResult, TrainingJob]]:
        r"""Trains a Kumo model on this predictive query, given optional
        additional specifications of the training table generation plan and
        the model plan.

        If this predictive query has already generated a training table, that
        table is reused and ``training_table_plan`` is ignored (a warning is
        logged if one is passed).

        Args:
            training_table_plan: A specification of the parameters for training
                table generation. If not provided, will use an intelligently
                generated default plan based on the query and graph. This plan
                is equivalent to the plan inferred with
                ``suggest_training_table_plan()`` (i.e. with its default
                ``run_mode=RunMode.FAST``).
            model_plan: A specification of the parameters for model training.
                If not provided, will use an intelligently generated default
                plan based on the query and graph. This plan
                is equivalent to the plan inferred with
                ``suggest_model_plan()`` (i.e. with its default
                ``run_mode=RunMode.FAST``).
            non_blocking: Whether this operation should return immediately
                after launching the training job, or await completion of the
                training job.

        Returns:
            Tuple[Trainer, Union[TrainingJobResult, TrainingJob]]:
                A tuple with two elements. The first element is the trainer
                object used to launch the training job. The second element
                is either a training job result (if ``non_blocking=False``)
                or a training job future object (if ``non_blocking=True``).
        """
        if self._train_table is None:
            # Generate without blocking; the (possibly still running) job is
            # handed to `Trainer.fit` below.
            self._train_table = self.generate_training_table(
                training_table_plan, non_blocking=True)
        elif training_table_plan is not None:
            # Reuse the existing table, but don't drop the caller's plan
            # silently.
            logger.warning(
                "A training table was already generated by this predictive "
                "query and will be reused; the provided "
                "`training_table_plan` is ignored. Call "
                "`generate_training_table` with the new plan before `fit` "
                "to regenerate the training table.")

        # TODO(manan): what if `self._train_table` represents a failed job?
        if model_plan is None:
            model_plan = self.suggest_model_plan()
        trainer = Trainer(model_plan)
        return (trainer,
                trainer.fit(self.graph, self.train_table,
                            non_blocking=non_blocking))

    def generate_baseline(
        self,
        metrics: List[str],
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        non_blocking: bool = False,
    ) -> Union[BaselineJob, BaselineJobResult]:
        r"""Runs a baseline model on this predictive query, given metrics and
        the training table to evaluate on.

        Args:
            metrics (List[str]): A list of metrics that the baseline model will
                be evaluated on.
            train_table (Union[TrainingTable, TrainingTableJob]): The
                :class:`~kumoai.pquery.TrainingTable`, or in-progress
                :class:`~kumoai.pquery.TrainingTableJob` that represents
                the training data produced by a
                :class:`~kumoai.pquery.PredictiveQuery` on :obj:`graph`.
            non_blocking (bool): Whether this operation should
                return immediately after launching the baseline job, or await
                completion of the baseline job. Defaults to False.

        Returns:
            Union[BaselineJob, BaselineJobResult]: either a baseline job
            future object (:class:`BaselineJob`, if ``non_blocking=True``) or
            a baseline job result (:class:`BaselineJobResult`, if
            ``non_blocking=False``).
        """  # noqa
        baseline_trainer = BaselineTrainer(metrics)
        return baseline_trainer.run(self.graph, train_table,
                                    non_blocking=non_blocking)