kumoai 2.14.0.dev202601051732__cp311-cp311-macosx_11_0_arm64.whl → 2.15.0.dev202601141731__cp311-cp311-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/_version.py +1 -1
- kumoai/client/jobs.py +2 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +138 -28
- kumoai/experimental/rfm/backend/snow/table.py +16 -13
- kumoai/experimental/rfm/backend/sqlite/sampler.py +73 -15
- kumoai/experimental/rfm/base/mapper.py +69 -0
- kumoai/experimental/rfm/base/sampler.py +23 -1
- kumoai/experimental/rfm/base/sql_sampler.py +252 -11
- kumoai/experimental/rfm/base/table.py +15 -29
- kumoai/experimental/rfm/base/utils.py +36 -0
- kumoai/experimental/rfm/graph.py +9 -9
- kumoai/experimental/rfm/infer/dtype.py +3 -1
- kumoai/experimental/rfm/infer/time_col.py +4 -2
- kumoai/experimental/rfm/rfm.py +195 -114
- kumoai/experimental/rfm/task_table.py +2 -0
- kumoai/pquery/training_table.py +16 -2
- kumoai/testing/snow.py +3 -3
- kumoai/utils/display.py +44 -8
- kumoai/utils/progress_logger.py +2 -1
- kumoai/utils/sql.py +2 -2
- {kumoai-2.14.0.dev202601051732.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/METADATA +2 -2
- {kumoai-2.14.0.dev202601051732.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/RECORD +25 -23
- {kumoai-2.14.0.dev202601051732.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/WHEEL +0 -0
- {kumoai-2.14.0.dev202601051732.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.14.0.dev202601051732.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/rfm.py
CHANGED
|
@@ -8,7 +8,6 @@ from contextlib import contextmanager
|
|
|
8
8
|
from dataclasses import dataclass, replace
|
|
9
9
|
from typing import Any, Literal, overload
|
|
10
10
|
|
|
11
|
-
import numpy as np
|
|
12
11
|
import pandas as pd
|
|
13
12
|
from kumoapi.model_plan import RunMode
|
|
14
13
|
from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
|
|
@@ -28,7 +27,10 @@ from kumoapi.rfm import (
|
|
|
28
27
|
)
|
|
29
28
|
from kumoapi.task import TaskType
|
|
30
29
|
from kumoapi.typing import AggregationType, Stype
|
|
30
|
+
from rich.console import Console
|
|
31
|
+
from rich.markdown import Markdown
|
|
31
32
|
|
|
33
|
+
from kumoai import in_notebook
|
|
32
34
|
from kumoai.client.rfm import RFMAPI
|
|
33
35
|
from kumoai.exceptions import HTTPException
|
|
34
36
|
from kumoai.experimental.rfm import Graph, TaskTable
|
|
@@ -106,10 +108,20 @@ class Explanation:
|
|
|
106
108
|
def __repr__(self) -> str:
|
|
107
109
|
return str((self.prediction, self.summary))
|
|
108
110
|
|
|
111
|
+
def __str__(self) -> str:
|
|
112
|
+
console = Console(soft_wrap=True)
|
|
113
|
+
with console.capture() as cap:
|
|
114
|
+
console.print(display.to_rich_table(self.prediction))
|
|
115
|
+
console.print(Markdown(self.summary))
|
|
116
|
+
return cap.get()[:-1]
|
|
117
|
+
|
|
109
118
|
def print(self) -> None:
|
|
110
119
|
r"""Prints the explanation."""
|
|
111
|
-
|
|
112
|
-
|
|
120
|
+
if in_notebook():
|
|
121
|
+
display.dataframe(self.prediction)
|
|
122
|
+
display.message(self.summary)
|
|
123
|
+
else:
|
|
124
|
+
print(self)
|
|
113
125
|
|
|
114
126
|
def _ipython_display_(self) -> None:
|
|
115
127
|
self.print()
|
|
@@ -180,7 +192,7 @@ class KumoRFM:
|
|
|
180
192
|
self._client: RFMAPI | None = None
|
|
181
193
|
|
|
182
194
|
self._batch_size: int | Literal['max'] | None = None
|
|
183
|
-
self.
|
|
195
|
+
self._num_retries: int = 0
|
|
184
196
|
|
|
185
197
|
@property
|
|
186
198
|
def _api_client(self) -> RFMAPI:
|
|
@@ -194,6 +206,30 @@ class KumoRFM:
|
|
|
194
206
|
def __repr__(self) -> str:
|
|
195
207
|
return f'{self.__class__.__name__}()'
|
|
196
208
|
|
|
209
|
+
@contextmanager
|
|
210
|
+
def retry(
|
|
211
|
+
self,
|
|
212
|
+
num_retries: int = 1,
|
|
213
|
+
) -> Generator[None, None, None]:
|
|
214
|
+
"""Context manager to retry failed queries due to unexpected server
|
|
215
|
+
issues.
|
|
216
|
+
|
|
217
|
+
.. code-block:: python
|
|
218
|
+
|
|
219
|
+
with model.retry(num_retries=1):
|
|
220
|
+
df = model.predict(query, indices=...)
|
|
221
|
+
|
|
222
|
+
Args:
|
|
223
|
+
num_retries: The maximum number of retries.
|
|
224
|
+
"""
|
|
225
|
+
if num_retries < 0:
|
|
226
|
+
raise ValueError(f"'num_retries' must be greater than or equal to "
|
|
227
|
+
f"zero (got {num_retries})")
|
|
228
|
+
|
|
229
|
+
self._num_retries = num_retries
|
|
230
|
+
yield
|
|
231
|
+
self._num_retries = 0
|
|
232
|
+
|
|
197
233
|
@contextmanager
|
|
198
234
|
def batch_mode(
|
|
199
235
|
self,
|
|
@@ -217,15 +253,10 @@ class KumoRFM:
|
|
|
217
253
|
raise ValueError(f"'batch_size' must be greater than zero "
|
|
218
254
|
f"(got {batch_size})")
|
|
219
255
|
|
|
220
|
-
if num_retries < 0:
|
|
221
|
-
raise ValueError(f"'num_retries' must be greater than or equal to "
|
|
222
|
-
f"zero (got {num_retries})")
|
|
223
|
-
|
|
224
256
|
self._batch_size = batch_size
|
|
225
|
-
self.
|
|
226
|
-
|
|
257
|
+
with self.retry(self._num_retries or num_retries):
|
|
258
|
+
yield
|
|
227
259
|
self._batch_size = None
|
|
228
|
-
self.num_retries = 0
|
|
229
260
|
|
|
230
261
|
@overload
|
|
231
262
|
def predict(
|
|
@@ -265,6 +296,25 @@ class KumoRFM:
|
|
|
265
296
|
) -> Explanation:
|
|
266
297
|
pass
|
|
267
298
|
|
|
299
|
+
@overload
|
|
300
|
+
def predict(
|
|
301
|
+
self,
|
|
302
|
+
query: str,
|
|
303
|
+
indices: list[str] | list[float] | list[int] | None = None,
|
|
304
|
+
*,
|
|
305
|
+
explain: bool | ExplainConfig | dict[str, Any] = False,
|
|
306
|
+
anchor_time: pd.Timestamp | Literal['entity'] | None = None,
|
|
307
|
+
context_anchor_time: pd.Timestamp | None = None,
|
|
308
|
+
run_mode: RunMode | str = RunMode.FAST,
|
|
309
|
+
num_neighbors: list[int] | None = None,
|
|
310
|
+
num_hops: int = 2,
|
|
311
|
+
max_pq_iterations: int = 10,
|
|
312
|
+
random_seed: int | None = _RANDOM_SEED,
|
|
313
|
+
verbose: bool | ProgressLogger = True,
|
|
314
|
+
use_prediction_time: bool = False,
|
|
315
|
+
) -> pd.DataFrame | Explanation:
|
|
316
|
+
pass
|
|
317
|
+
|
|
268
318
|
def predict(
|
|
269
319
|
self,
|
|
270
320
|
query: str,
|
|
@@ -288,8 +338,7 @@ class KumoRFM:
|
|
|
288
338
|
indices: The entity primary keys to predict for. Will override the
|
|
289
339
|
indices given as part of the predictive query. Predictions will
|
|
290
340
|
be generated for all indices, independent of whether they
|
|
291
|
-
fulfill entity filter constraints.
|
|
292
|
-
:meth:`~KumoRFM.is_valid_entity`.
|
|
341
|
+
fulfill entity filter constraints.
|
|
293
342
|
explain: Configuration for explainability.
|
|
294
343
|
If set to ``True``, will additionally explain the prediction.
|
|
295
344
|
Passing in an :class:`ExplainConfig` instance provides control
|
|
@@ -329,8 +378,11 @@ class KumoRFM:
|
|
|
329
378
|
raise ValueError("Cannot find entities to predict for. Please "
|
|
330
379
|
"pass them via `predict(query, indices=...)`")
|
|
331
380
|
indices = query_def.get_rfm_entity_id_list()
|
|
332
|
-
|
|
333
|
-
query_def
|
|
381
|
+
query_def = replace(
|
|
382
|
+
query_def,
|
|
383
|
+
for_each='FOR EACH',
|
|
384
|
+
rfm_entity_ids=None,
|
|
385
|
+
)
|
|
334
386
|
|
|
335
387
|
if not isinstance(verbose, ProgressLogger):
|
|
336
388
|
query_repr = query_def.to_string(rich=True, exclude_predict=True)
|
|
@@ -351,11 +403,11 @@ class KumoRFM:
|
|
|
351
403
|
random_seed=random_seed,
|
|
352
404
|
logger=logger,
|
|
353
405
|
)
|
|
354
|
-
task_table._query = query_def.to_string()
|
|
406
|
+
task_table._query = query_def.to_string()
|
|
355
407
|
|
|
356
408
|
return self.predict_task(
|
|
357
409
|
task_table,
|
|
358
|
-
explain=explain,
|
|
410
|
+
explain=explain,
|
|
359
411
|
run_mode=run_mode,
|
|
360
412
|
num_neighbors=num_neighbors,
|
|
361
413
|
num_hops=num_hops,
|
|
@@ -397,6 +449,22 @@ class KumoRFM:
|
|
|
397
449
|
) -> Explanation:
|
|
398
450
|
pass
|
|
399
451
|
|
|
452
|
+
@overload
|
|
453
|
+
def predict_task(
|
|
454
|
+
self,
|
|
455
|
+
task: TaskTable,
|
|
456
|
+
*,
|
|
457
|
+
explain: bool | ExplainConfig | dict[str, Any] = False,
|
|
458
|
+
run_mode: RunMode | str = RunMode.FAST,
|
|
459
|
+
num_neighbors: list[int] | None = None,
|
|
460
|
+
num_hops: int = 2,
|
|
461
|
+
verbose: bool | ProgressLogger = True,
|
|
462
|
+
exclude_cols_dict: dict[str, list[str]] | None = None,
|
|
463
|
+
use_prediction_time: bool = False,
|
|
464
|
+
top_k: int | None = None,
|
|
465
|
+
) -> pd.DataFrame | Explanation:
|
|
466
|
+
pass
|
|
467
|
+
|
|
400
468
|
def predict_task(
|
|
401
469
|
self,
|
|
402
470
|
task: TaskTable,
|
|
@@ -477,9 +545,9 @@ class KumoRFM:
|
|
|
477
545
|
task_type_repr = str(task.task_type)
|
|
478
546
|
|
|
479
547
|
if explain_config is not None:
|
|
480
|
-
msg = f
|
|
548
|
+
msg = f"Explaining {task_type_repr} task"
|
|
481
549
|
else:
|
|
482
|
-
msg = f
|
|
550
|
+
msg = f"Predicting {task_type_repr} task"
|
|
483
551
|
verbose = ProgressLogger.default(msg=msg, verbose=verbose)
|
|
484
552
|
|
|
485
553
|
with verbose as logger:
|
|
@@ -525,7 +593,7 @@ class KumoRFM:
|
|
|
525
593
|
request = RFMPredictRequest(
|
|
526
594
|
context=context,
|
|
527
595
|
run_mode=RunMode(run_mode),
|
|
528
|
-
query=
|
|
596
|
+
query=task._query,
|
|
529
597
|
use_prediction_time=use_prediction_time,
|
|
530
598
|
)
|
|
531
599
|
with warnings.catch_warnings():
|
|
@@ -544,7 +612,7 @@ class KumoRFM:
|
|
|
544
612
|
num = math.ceil(task.num_prediction_examples / batch_size)
|
|
545
613
|
verbose.init_progress(total=num, description='Predicting')
|
|
546
614
|
|
|
547
|
-
for attempt in range(self.
|
|
615
|
+
for attempt in range(self._num_retries + 1):
|
|
548
616
|
try:
|
|
549
617
|
if explain_config is not None:
|
|
550
618
|
resp = self._api_client.explain(
|
|
@@ -582,7 +650,7 @@ class KumoRFM:
|
|
|
582
650
|
|
|
583
651
|
break
|
|
584
652
|
except HTTPException as e:
|
|
585
|
-
if attempt == self.
|
|
653
|
+
if attempt == self._num_retries:
|
|
586
654
|
try:
|
|
587
655
|
msg = json.loads(e.detail)['detail']
|
|
588
656
|
except Exception:
|
|
@@ -612,51 +680,6 @@ class KumoRFM:
|
|
|
612
680
|
|
|
613
681
|
return prediction
|
|
614
682
|
|
|
615
|
-
def is_valid_entity(
|
|
616
|
-
self,
|
|
617
|
-
query: str,
|
|
618
|
-
indices: list[str] | list[float] | list[int] | None = None,
|
|
619
|
-
*,
|
|
620
|
-
anchor_time: pd.Timestamp | Literal['entity'] | None = None,
|
|
621
|
-
) -> np.ndarray:
|
|
622
|
-
r"""Returns a mask that denotes which entities are valid for the
|
|
623
|
-
given predictive query, *i.e.*, which entities fulfill (temporal)
|
|
624
|
-
entity filter constraints.
|
|
625
|
-
|
|
626
|
-
Args:
|
|
627
|
-
query: The predictive query.
|
|
628
|
-
indices: The entity primary keys to predict for. Will override the
|
|
629
|
-
indices given as part of the predictive query.
|
|
630
|
-
anchor_time: The anchor timestamp for the prediction. If set to
|
|
631
|
-
``None``, will use the maximum timestamp in the data.
|
|
632
|
-
If set to ``"entity"``, will use the timestamp of the entity.
|
|
633
|
-
"""
|
|
634
|
-
query_def = self._parse_query(query)
|
|
635
|
-
|
|
636
|
-
if indices is None:
|
|
637
|
-
if query_def.rfm_entity_ids is None:
|
|
638
|
-
raise ValueError("Cannot find entities to predict for. Please "
|
|
639
|
-
"pass them via "
|
|
640
|
-
"`is_valid_entity(query, indices=...)`")
|
|
641
|
-
indices = query_def.get_rfm_entity_id_list()
|
|
642
|
-
|
|
643
|
-
if len(indices) == 0:
|
|
644
|
-
raise ValueError("At least one entity is required")
|
|
645
|
-
|
|
646
|
-
if anchor_time is None:
|
|
647
|
-
anchor_time = self._get_default_anchor_time(query_def)
|
|
648
|
-
|
|
649
|
-
if isinstance(anchor_time, pd.Timestamp):
|
|
650
|
-
self._validate_time(query_def, anchor_time, None, False)
|
|
651
|
-
else:
|
|
652
|
-
assert anchor_time == 'entity'
|
|
653
|
-
if query_def.entity_table not in self._sampler.time_column_dict:
|
|
654
|
-
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
655
|
-
f"table '{query_def.entity_table}' "
|
|
656
|
-
f"to have a time column.")
|
|
657
|
-
|
|
658
|
-
raise NotImplementedError
|
|
659
|
-
|
|
660
683
|
def evaluate(
|
|
661
684
|
self,
|
|
662
685
|
query: str,
|
|
@@ -701,29 +724,12 @@ class KumoRFM:
|
|
|
701
724
|
Returns:
|
|
702
725
|
The metrics as a :class:`pandas.DataFrame`
|
|
703
726
|
"""
|
|
704
|
-
query_def =
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
rfm_entity_ids=None,
|
|
709
|
-
)
|
|
710
|
-
|
|
711
|
-
task_type = self._get_task_type(
|
|
712
|
-
query=query_def,
|
|
713
|
-
edge_types=self._sampler.edge_types,
|
|
727
|
+
query_def = replace(
|
|
728
|
+
self._parse_query(query),
|
|
729
|
+
for_each='FOR EACH',
|
|
730
|
+
rfm_entity_ids=None,
|
|
714
731
|
)
|
|
715
732
|
|
|
716
|
-
if num_hops != 2 and num_neighbors is not None:
|
|
717
|
-
warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
|
|
718
|
-
f"custom 'num_hops={num_hops}' option")
|
|
719
|
-
if num_neighbors is None:
|
|
720
|
-
key = RunMode.FAST if task_type.is_link_pred else run_mode
|
|
721
|
-
num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:num_hops]
|
|
722
|
-
|
|
723
|
-
if metrics is not None and len(metrics) > 0:
|
|
724
|
-
self._validate_metrics(metrics, task_type)
|
|
725
|
-
metrics = list(dict.fromkeys(metrics))
|
|
726
|
-
|
|
727
733
|
if not isinstance(verbose, ProgressLogger):
|
|
728
734
|
query_repr = query_def.to_string(rich=True, exclude_predict=True)
|
|
729
735
|
msg = f'[bold]EVALUATE[/bold] {query_repr}'
|
|
@@ -740,11 +746,96 @@ class KumoRFM:
|
|
|
740
746
|
random_seed=random_seed,
|
|
741
747
|
logger=logger,
|
|
742
748
|
)
|
|
743
|
-
|
|
744
|
-
|
|
749
|
+
|
|
750
|
+
return self.evaluate_task(
|
|
751
|
+
task_table,
|
|
752
|
+
metrics=metrics,
|
|
745
753
|
run_mode=run_mode,
|
|
746
754
|
num_neighbors=num_neighbors,
|
|
755
|
+
num_hops=num_hops,
|
|
756
|
+
verbose=verbose,
|
|
747
757
|
exclude_cols_dict=query_def.get_exclude_cols_dict(),
|
|
758
|
+
use_prediction_time=use_prediction_time,
|
|
759
|
+
)
|
|
760
|
+
|
|
761
|
+
def evaluate_task(
|
|
762
|
+
self,
|
|
763
|
+
task: TaskTable,
|
|
764
|
+
*,
|
|
765
|
+
metrics: list[str] | None = None,
|
|
766
|
+
run_mode: RunMode | str = RunMode.FAST,
|
|
767
|
+
num_neighbors: list[int] | None = None,
|
|
768
|
+
num_hops: int = 2,
|
|
769
|
+
verbose: bool | ProgressLogger = True,
|
|
770
|
+
exclude_cols_dict: dict[str, list[str]] | None = None,
|
|
771
|
+
use_prediction_time: bool = False,
|
|
772
|
+
) -> pd.DataFrame:
|
|
773
|
+
"""Evaluates a custom task specification.
|
|
774
|
+
|
|
775
|
+
Args:
|
|
776
|
+
task: The custom :class:`TaskTable`.
|
|
777
|
+
metrics: The metrics to use.
|
|
778
|
+
run_mode: The :class:`RunMode` for the query.
|
|
779
|
+
num_neighbors: The number of neighbors to sample for each hop.
|
|
780
|
+
If specified, the ``num_hops`` option will be ignored.
|
|
781
|
+
num_hops: The number of hops to sample when generating the context.
|
|
782
|
+
verbose: Whether to print verbose output.
|
|
783
|
+
exclude_cols_dict: Any column in any table to exclude from the
|
|
784
|
+
model input.
|
|
785
|
+
use_prediction_time: Whether to use the anchor timestamp as an
|
|
786
|
+
additional feature during prediction. This is typically
|
|
787
|
+
beneficial for time series forecasting tasks.
|
|
788
|
+
|
|
789
|
+
Returns:
|
|
790
|
+
The metrics as a :class:`pandas.DataFrame`
|
|
791
|
+
"""
|
|
792
|
+
if num_hops != 2 and num_neighbors is not None:
|
|
793
|
+
warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
|
|
794
|
+
f"custom 'num_hops={num_hops}' option")
|
|
795
|
+
if num_neighbors is None:
|
|
796
|
+
key = RunMode.FAST if task.task_type.is_link_pred else run_mode
|
|
797
|
+
num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:num_hops]
|
|
798
|
+
|
|
799
|
+
if metrics is not None and len(metrics) > 0:
|
|
800
|
+
self._validate_metrics(metrics, task.task_type)
|
|
801
|
+
metrics = list(dict.fromkeys(metrics))
|
|
802
|
+
|
|
803
|
+
if not isinstance(verbose, ProgressLogger):
|
|
804
|
+
if task.task_type == TaskType.BINARY_CLASSIFICATION:
|
|
805
|
+
task_type_repr = 'binary classification'
|
|
806
|
+
elif task.task_type == TaskType.MULTICLASS_CLASSIFICATION:
|
|
807
|
+
task_type_repr = 'multi-class classification'
|
|
808
|
+
elif task.task_type == TaskType.REGRESSION:
|
|
809
|
+
task_type_repr = 'regression'
|
|
810
|
+
elif task.task_type == TaskType.TEMPORAL_LINK_PREDICTION:
|
|
811
|
+
task_type_repr = 'link prediction'
|
|
812
|
+
else:
|
|
813
|
+
task_type_repr = str(task.task_type)
|
|
814
|
+
|
|
815
|
+
msg = f"Evaluating {task_type_repr} task"
|
|
816
|
+
verbose = ProgressLogger.default(msg=msg, verbose=verbose)
|
|
817
|
+
|
|
818
|
+
with verbose as logger:
|
|
819
|
+
if task.num_context_examples > _MAX_CONTEXT_SIZE[run_mode]:
|
|
820
|
+
logger.log(f"Sub-sampled {_MAX_CONTEXT_SIZE[run_mode]:,} "
|
|
821
|
+
f"out of {task.num_context_examples:,} in-context "
|
|
822
|
+
f"examples")
|
|
823
|
+
task = task.narrow_context(0, _MAX_CONTEXT_SIZE[run_mode])
|
|
824
|
+
|
|
825
|
+
if task.num_prediction_examples > _MAX_TEST_SIZE[task.task_type]:
|
|
826
|
+
logger.log(f"Sub-sampled {_MAX_TEST_SIZE[task.task_type]:,} "
|
|
827
|
+
f"out of {task.num_prediction_examples:,} test "
|
|
828
|
+
f"examples")
|
|
829
|
+
task = task.narrow_prediction(
|
|
830
|
+
start=0,
|
|
831
|
+
length=_MAX_TEST_SIZE[task.task_type],
|
|
832
|
+
)
|
|
833
|
+
|
|
834
|
+
context = self._get_context(
|
|
835
|
+
task=task,
|
|
836
|
+
run_mode=run_mode,
|
|
837
|
+
num_neighbors=num_neighbors,
|
|
838
|
+
exclude_cols_dict=exclude_cols_dict,
|
|
748
839
|
)
|
|
749
840
|
|
|
750
841
|
request = RFMEvaluateRequest(
|
|
@@ -764,12 +855,12 @@ class KumoRFM:
|
|
|
764
855
|
stats_msg = Context.get_memory_stats(request_msg.context)
|
|
765
856
|
raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
|
|
766
857
|
|
|
767
|
-
for attempt in range(self.
|
|
858
|
+
for attempt in range(self._num_retries + 1):
|
|
768
859
|
try:
|
|
769
860
|
resp = self._api_client.evaluate(request_bytes)
|
|
770
861
|
break
|
|
771
862
|
except HTTPException as e:
|
|
772
|
-
if attempt == self.
|
|
863
|
+
if attempt == self._num_retries:
|
|
773
864
|
try:
|
|
774
865
|
msg = json.loads(e.detail)['detail']
|
|
775
866
|
except Exception:
|
|
@@ -865,12 +956,12 @@ class KumoRFM:
|
|
|
865
956
|
graph_definition=self._graph_def,
|
|
866
957
|
)
|
|
867
958
|
|
|
868
|
-
for attempt in range(self.
|
|
959
|
+
for attempt in range(self._num_retries + 1):
|
|
869
960
|
try:
|
|
870
961
|
resp = self._api_client.parse_query(request)
|
|
871
962
|
break
|
|
872
963
|
except HTTPException as e:
|
|
873
|
-
if attempt == self.
|
|
964
|
+
if attempt == self._num_retries:
|
|
874
965
|
try:
|
|
875
966
|
msg = json.loads(e.detail)['detail']
|
|
876
967
|
except Exception:
|
|
@@ -953,8 +1044,16 @@ class KumoRFM:
|
|
|
953
1044
|
if len(self._sampler.time_column_dict) == 0:
|
|
954
1045
|
return # Graph without timestamps
|
|
955
1046
|
|
|
956
|
-
|
|
957
|
-
|
|
1047
|
+
if query.query_type == QueryType.TEMPORAL:
|
|
1048
|
+
aggr_table_names = [
|
|
1049
|
+
aggr._get_target_column_name().split('.')[0]
|
|
1050
|
+
for aggr in query.get_all_target_aggregations()
|
|
1051
|
+
]
|
|
1052
|
+
min_time = self._sampler.get_min_time(aggr_table_names)
|
|
1053
|
+
max_time = self._sampler.get_max_time(aggr_table_names)
|
|
1054
|
+
else:
|
|
1055
|
+
min_time = self._sampler.get_min_time()
|
|
1056
|
+
max_time = self._sampler.get_max_time()
|
|
958
1057
|
|
|
959
1058
|
if anchor_time < min_time:
|
|
960
1059
|
raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
|
|
@@ -1193,24 +1292,6 @@ class KumoRFM:
|
|
|
1193
1292
|
top_k: int | None = None,
|
|
1194
1293
|
) -> Context:
|
|
1195
1294
|
|
|
1196
|
-
# TODO Remove all
|
|
1197
|
-
if task.num_context_examples > max(_MAX_CONTEXT_SIZE.values()):
|
|
1198
|
-
raise ValueError(f"Cannot process a context with more than "
|
|
1199
|
-
f"{max(_MAX_CONTEXT_SIZE.values()):,} samples "
|
|
1200
|
-
f"(got {task.num_context_examples:,})")
|
|
1201
|
-
if task.evaluate:
|
|
1202
|
-
if task.num_prediction_examples > _MAX_TEST_SIZE[task.task_type]:
|
|
1203
|
-
raise ValueError(f"Cannot process a test set with more than "
|
|
1204
|
-
f"{_MAX_TEST_SIZE[task.task_type]:,} samples "
|
|
1205
|
-
f"for evaluation "
|
|
1206
|
-
f"(got {task.num_prediction_examples:,})")
|
|
1207
|
-
else:
|
|
1208
|
-
if task.num_prediction_examples > _MAX_PRED_SIZE[task.task_type]:
|
|
1209
|
-
raise ValueError(f"Cannot predict for more than "
|
|
1210
|
-
f"{_MAX_PRED_SIZE[task.task_type]:,} "
|
|
1211
|
-
f"entities at once "
|
|
1212
|
-
f"(got {task.num_prediction_examples:,})")
|
|
1213
|
-
|
|
1214
1295
|
if num_neighbors is None:
|
|
1215
1296
|
key = RunMode.FAST if task.task_type.is_link_pred else run_mode
|
|
1216
1297
|
num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:2]
|
kumoai/pquery/training_table.py
CHANGED
|
@@ -199,6 +199,7 @@ class TrainingTable:
|
|
|
199
199
|
self,
|
|
200
200
|
source_table_type: SourceTableType,
|
|
201
201
|
train_table_mod: TrainingTableSpec,
|
|
202
|
+
extensive_validation: bool = False,
|
|
202
203
|
) -> None:
|
|
203
204
|
r"""Validates the modified training table.
|
|
204
205
|
|
|
@@ -206,6 +207,8 @@ class TrainingTable:
|
|
|
206
207
|
source_table_type: The source table to be used as the modified
|
|
207
208
|
training table.
|
|
208
209
|
train_table_mod: The modification specification.
|
|
210
|
+
extensive_validation: Enable extensive validation for custom
|
|
211
|
+
table.
|
|
209
212
|
|
|
210
213
|
Raises:
|
|
211
214
|
ValueError: If the modified training table is invalid.
|
|
@@ -215,7 +218,8 @@ class TrainingTable:
|
|
|
215
218
|
global_state.client.generate_train_table_job_api)
|
|
216
219
|
response = api.validate_custom_train_table(self.job_id,
|
|
217
220
|
source_table_type,
|
|
218
|
-
train_table_mod
|
|
221
|
+
train_table_mod,
|
|
222
|
+
extensive_validation)
|
|
219
223
|
if not response.ok:
|
|
220
224
|
raise ValueError("Invalid weighted train table",
|
|
221
225
|
response.error_message)
|
|
@@ -225,6 +229,7 @@ class TrainingTable:
|
|
|
225
229
|
source_table: SourceTable,
|
|
226
230
|
train_table_mod: TrainingTableSpec,
|
|
227
231
|
validate: bool = True,
|
|
232
|
+
extensive_validation: bool = False,
|
|
228
233
|
) -> Self:
|
|
229
234
|
r"""Sets the `source_table` as the modified training table.
|
|
230
235
|
|
|
@@ -243,6 +248,9 @@ class TrainingTable:
|
|
|
243
248
|
train_table_mod: The modification specification.
|
|
244
249
|
validate: Whether to validate the modified training table. This can
|
|
245
250
|
be slow for large tables.
|
|
251
|
+
extensive_validation: Whether to validate number of rows in
|
|
252
|
+
existing and modified training table.
|
|
253
|
+
It can be slow for large tables.
|
|
246
254
|
"""
|
|
247
255
|
if isinstance(source_table.connector, S3Connector):
|
|
248
256
|
# Special handling for s3 as `source_table._to_api_source_table`
|
|
@@ -252,7 +260,13 @@ class TrainingTable:
|
|
|
252
260
|
else:
|
|
253
261
|
source_table_type = source_table._to_api_source_table()
|
|
254
262
|
if validate:
|
|
255
|
-
|
|
263
|
+
if extensive_validation:
|
|
264
|
+
logger.warning(
|
|
265
|
+
"You have opted in to perform extensive validation on"
|
|
266
|
+
" your custom training table."
|
|
267
|
+
" This operation can be slow for large tables.")
|
|
268
|
+
self.validate_custom_table(source_table_type, train_table_mod,
|
|
269
|
+
extensive_validation)
|
|
256
270
|
self._custom_train_table = CustomTrainingTable(
|
|
257
271
|
source_table=source_table_type, table_mod_spec=train_table_mod,
|
|
258
272
|
validated=validate)
|
kumoai/testing/snow.py
CHANGED
|
@@ -10,7 +10,7 @@ def connect(
|
|
|
10
10
|
id: str,
|
|
11
11
|
account: str,
|
|
12
12
|
user: str,
|
|
13
|
-
warehouse: str,
|
|
13
|
+
warehouse: str | None = None,
|
|
14
14
|
database: str | None = None,
|
|
15
15
|
schema: str | None = None,
|
|
16
16
|
) -> Connection:
|
|
@@ -42,8 +42,8 @@ def connect(
|
|
|
42
42
|
return _connect(
|
|
43
43
|
account=account,
|
|
44
44
|
user=user,
|
|
45
|
-
warehouse='WH_XS',
|
|
46
|
-
database='KUMO',
|
|
45
|
+
warehouse=warehouse or 'WH_XS',
|
|
46
|
+
database=database or 'KUMO',
|
|
47
47
|
schema=schema,
|
|
48
48
|
session_parameters=dict(CLIENT_TELEMETRY_ENABLED=False),
|
|
49
49
|
**kwargs,
|
kumoai/utils/display.py
CHANGED
|
@@ -1,13 +1,15 @@
|
|
|
1
1
|
from collections.abc import Sequence
|
|
2
2
|
|
|
3
3
|
import pandas as pd
|
|
4
|
+
from rich import box
|
|
5
|
+
from rich.console import Console
|
|
6
|
+
from rich.table import Table
|
|
7
|
+
from rich.text import Text
|
|
4
8
|
|
|
5
9
|
from kumoai import in_notebook, in_snowflake_notebook
|
|
6
10
|
|
|
7
11
|
|
|
8
12
|
def message(msg: str) -> None:
|
|
9
|
-
msg = msg.replace("`", "'") if not in_notebook() else msg
|
|
10
|
-
|
|
11
13
|
if in_snowflake_notebook():
|
|
12
14
|
import streamlit as st
|
|
13
15
|
st.markdown(msg)
|
|
@@ -15,23 +17,40 @@ def message(msg: str) -> None:
|
|
|
15
17
|
from IPython.display import Markdown, display
|
|
16
18
|
display(Markdown(msg))
|
|
17
19
|
else:
|
|
18
|
-
print(msg)
|
|
20
|
+
print(msg.replace("`", "'"))
|
|
19
21
|
|
|
20
22
|
|
|
21
23
|
def title(msg: str) -> None:
|
|
22
|
-
|
|
24
|
+
if in_notebook():
|
|
25
|
+
message(f"### {msg}")
|
|
26
|
+
else:
|
|
27
|
+
msg = msg.replace("`", "'")
|
|
28
|
+
Console().print(f"[bold]{msg}[/bold]", highlight=False)
|
|
23
29
|
|
|
24
30
|
|
|
25
31
|
def italic(msg: str) -> None:
|
|
26
|
-
|
|
32
|
+
if in_notebook():
|
|
33
|
+
message(f"*{msg}*")
|
|
34
|
+
else:
|
|
35
|
+
msg = msg.replace("`", "'")
|
|
36
|
+
Console().print(
|
|
37
|
+
f"[italic]{msg}[/italic]",
|
|
38
|
+
highlight=False,
|
|
39
|
+
style='dim',
|
|
40
|
+
)
|
|
27
41
|
|
|
28
42
|
|
|
29
43
|
def unordered_list(items: Sequence[str]) -> None:
|
|
30
44
|
if in_notebook():
|
|
31
45
|
msg = '\n'.join([f"- {item}" for item in items])
|
|
46
|
+
message(msg)
|
|
32
47
|
else:
|
|
33
|
-
|
|
34
|
-
|
|
48
|
+
text = Text('\n').join(
|
|
49
|
+
Text.assemble(
|
|
50
|
+
Text(' • ', style='yellow'),
|
|
51
|
+
Text(item.replace('`', '')),
|
|
52
|
+
) for item in items)
|
|
53
|
+
Console().print(text, highlight=False)
|
|
35
54
|
|
|
36
55
|
|
|
37
56
|
def dataframe(df: pd.DataFrame) -> None:
|
|
@@ -48,4 +67,21 @@ def dataframe(df: pd.DataFrame) -> None:
|
|
|
48
67
|
except ImportError:
|
|
49
68
|
print(df.to_string(index=False)) # missing jinja2
|
|
50
69
|
else:
|
|
51
|
-
print(df
|
|
70
|
+
Console().print(to_rich_table(df))
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def to_rich_table(df: pd.DataFrame) -> Table:
|
|
74
|
+
table = Table(box=box.ROUNDED)
|
|
75
|
+
for column in df.columns:
|
|
76
|
+
table.add_column(str(column))
|
|
77
|
+
for _, row in df.iterrows():
|
|
78
|
+
values: list[str | Text] = []
|
|
79
|
+
for value in row:
|
|
80
|
+
if str(value) == 'True':
|
|
81
|
+
values.append('✅')
|
|
82
|
+
elif str(value) in {'False', '-'}:
|
|
83
|
+
values.append(Text('-', style='dim'))
|
|
84
|
+
else:
|
|
85
|
+
values.append(str(value))
|
|
86
|
+
table.add_row(*values)
|
|
87
|
+
return table
|
kumoai/utils/progress_logger.py
CHANGED
|
@@ -57,7 +57,8 @@ class ProgressLogger:
|
|
|
57
57
|
|
|
58
58
|
def __enter__(self) -> Self:
|
|
59
59
|
self.depth += 1
|
|
60
|
-
self.
|
|
60
|
+
if self.depth == 1:
|
|
61
|
+
self.start_time = time.perf_counter()
|
|
61
62
|
return self
|
|
62
63
|
|
|
63
64
|
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
kumoai/utils/sql.py
CHANGED
|
@@ -1,3 +1,3 @@
|
|
|
1
|
-
def quote_ident(
|
|
1
|
+
def quote_ident(ident: str, char: str = '"') -> str:
|
|
2
2
|
r"""Quotes a SQL identifier."""
|
|
3
|
-
return
|
|
3
|
+
return char + ident.replace(char, char + char) + char
|
{kumoai-2.14.0.dev202601051732.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/METADATA
RENAMED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: kumoai
|
|
3
|
-
Version: 2.
|
|
3
|
+
Version: 2.15.0.dev202601141731
|
|
4
4
|
Summary: AI on the Modern Data Stack
|
|
5
5
|
Author-email: "Kumo.AI" <hello@kumo.ai>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -23,7 +23,7 @@ Requires-Dist: requests>=2.28.2
|
|
|
23
23
|
Requires-Dist: urllib3
|
|
24
24
|
Requires-Dist: plotly
|
|
25
25
|
Requires-Dist: typing_extensions>=4.5.0
|
|
26
|
-
Requires-Dist: kumo-api
|
|
26
|
+
Requires-Dist: kumo-api<1.0.0,>=0.53.0
|
|
27
27
|
Requires-Dist: tqdm>=4.66.0
|
|
28
28
|
Requires-Dist: aiohttp>=3.10.0
|
|
29
29
|
Requires-Dist: pydantic>=1.10.21
|