kumoai 2.14.0.dev202512181731__cp312-cp312-macosx_11_0_arm64.whl → 2.14.0.dev202601041732__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. kumoai/__init__.py +23 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +24 -0
  5. kumoai/connector/utils.py +21 -7
  6. kumoai/experimental/rfm/__init__.py +24 -22
  7. kumoai/experimental/rfm/backend/local/graph_store.py +12 -21
  8. kumoai/experimental/rfm/backend/local/sampler.py +0 -3
  9. kumoai/experimental/rfm/backend/local/table.py +25 -24
  10. kumoai/experimental/rfm/backend/snow/sampler.py +106 -61
  11. kumoai/experimental/rfm/backend/snow/table.py +146 -51
  12. kumoai/experimental/rfm/backend/sqlite/sampler.py +127 -78
  13. kumoai/experimental/rfm/backend/sqlite/table.py +94 -47
  14. kumoai/experimental/rfm/base/__init__.py +6 -7
  15. kumoai/experimental/rfm/base/column.py +97 -5
  16. kumoai/experimental/rfm/base/expression.py +44 -0
  17. kumoai/experimental/rfm/base/sampler.py +5 -17
  18. kumoai/experimental/rfm/base/source.py +1 -1
  19. kumoai/experimental/rfm/base/sql_sampler.py +68 -9
  20. kumoai/experimental/rfm/base/table.py +291 -126
  21. kumoai/experimental/rfm/graph.py +139 -86
  22. kumoai/experimental/rfm/infer/__init__.py +6 -4
  23. kumoai/experimental/rfm/infer/dtype.py +6 -1
  24. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  25. kumoai/experimental/rfm/infer/stype.py +35 -0
  26. kumoai/experimental/rfm/relbench.py +76 -0
  27. kumoai/experimental/rfm/rfm.py +30 -42
  28. kumoai/experimental/rfm/task_table.py +247 -0
  29. kumoai/trainer/distilled_trainer.py +175 -0
  30. kumoai/utils/display.py +51 -0
  31. {kumoai-2.14.0.dev202512181731.dist-info → kumoai-2.14.0.dev202601041732.dist-info}/METADATA +1 -1
  32. {kumoai-2.14.0.dev202512181731.dist-info → kumoai-2.14.0.dev202601041732.dist-info}/RECORD +35 -31
  33. kumoai/experimental/rfm/base/column_expression.py +0 -16
  34. kumoai/experimental/rfm/base/sql_table.py +0 -113
  35. {kumoai-2.14.0.dev202512181731.dist-info → kumoai-2.14.0.dev202601041732.dist-info}/WHEEL +0 -0
  36. {kumoai-2.14.0.dev202512181731.dist-info → kumoai-2.14.0.dev202601041732.dist-info}/licenses/LICENSE +0 -0
  37. {kumoai-2.14.0.dev202512181731.dist-info → kumoai-2.14.0.dev202601041732.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/graph.py
@@ -1,3 +1,5 @@
+ from __future__ import annotations
+
  import contextlib
  import copy
  import io
@@ -16,9 +18,10 @@ from kumoapi.typing import Stype
  from typing_extensions import Self

  from kumoai import in_notebook, in_snowflake_notebook
- from kumoai.experimental.rfm.base import DataBackend, SQLTable, Table
+ from kumoai.experimental.rfm.base import ColumnSpec, DataBackend, Table
  from kumoai.graph import Edge
  from kumoai.mixin import CastMixin
+ from kumoai.utils import display

  if TYPE_CHECKING:
      import graphviz
@@ -98,24 +101,25 @@ class Graph:
          for table in tables:
              self.add_table(table)

-         for table in tables:
-             if not isinstance(table, SQLTable):
+         for table in tables:  # Use links from source metadata:
+             if not any(column.is_source for column in table.columns):
                  continue
              for fkey in table._source_foreign_key_dict.values():
                  if fkey.name not in table:
                      continue
-                 # TODO Skip for non-physical table[fkey.name].
+                 if not table[fkey.name].is_source:
+                     continue
                  dst_table_names = [
                      table.name for table in self.tables.values()
-                     if isinstance(table, SQLTable)
-                     and table._source_name == fkey.dst_table
+                     if table.source_name == fkey.dst_table
                  ]
                  if len(dst_table_names) != 1:
                      continue
                  dst_table = self[dst_table_names[0]]
                  if dst_table._primary_key != fkey.primary_key:
                      continue
-                 # TODO Skip for non-physical dst_table.primary_key.
+                 if not dst_table[fkey.primary_key].is_source:
+                     continue
                  self.link(table.name, fkey.name, dst_table.name)

          for edge in (edges or []):
@@ -418,6 +422,7 @@ class Graph:
          graph = cls(tables=[])

          msgs = []
+         table_names = {table_cfg['name'] for table_cfg in cfg['tables']}
          for table_cfg in cfg['tables']:
              table_name = table_cfg['name']
              source_table_name = table_cfg['base_table']['table']
@@ -434,14 +439,47 @@ class Graph:
                              f"'{table_name}' since composite primary keys "
                              f"are not yet supported")

-             columns: list[str] = []
+             columns: list[ColumnSpec] = []
+             unsupported_columns: list[str] = []
              for column_cfg in chain(
                      table_cfg.get('dimensions', []),
                      table_cfg.get('time_dimensions', []),
                      table_cfg.get('facts', []),
              ):
-                 # TODO Add support for derived columns.
-                 columns.append(column_cfg['name'])
+                 column_name = column_cfg['name']
+                 column_expr = column_cfg.get('expr', None)
+                 column_data_type = column_cfg.get('data_type', None)
+
+                 if column_expr is None:
+                     columns.append(ColumnSpec(name=column_name))
+                     continue
+
+                 column_expr = column_expr.replace(f'{table_name}.', '')
+
+                 if column_expr == column_name:
+                     columns.append(ColumnSpec(name=column_name))
+                     continue
+
+                 # Drop expressions that reference other tables (for now):
+                 if any(f'{name}.' in column_expr for name in table_names):
+                     unsupported_columns.append(column_name)
+                     continue
+
+                 column = ColumnSpec(
+                     name=column_name,
+                     expr=column_expr,
+                     dtype=SnowTable._to_dtype(column_data_type),
+                 )
+                 columns.append(column)
+
+             if len(unsupported_columns) == 1:
+                 msgs.append(f"Failed to add column '{unsupported_columns[0]}' "
+                             f"of table '{table_name}' since its expression "
+                             f"references other tables")
+             elif len(unsupported_columns) > 1:
+                 msgs.append(f"Failed to add columns '{unsupported_columns}' "
+                             f"of table '{table_name}' since their expressions "
+                             f"reference other tables")

              table = SnowTable(
                  connection,
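
Note: to make the new derived-column branch concrete, here is a hypothetical semantic-view column entry and the spec it produces. Table and column names are invented for illustration; the ColumnSpec keyword signature is inferred from this hunk only:

    # A derived dimension on a hypothetical 'ORDERS' table:
    column_cfg = {
        'name': 'NET_AMOUNT',
        'expr': 'ORDERS.GROSS_AMOUNT - ORDERS.TAX_AMOUNT',
        'data_type': 'NUMBER',
    }
    # Stripping the own-table prefix leaves 'GROSS_AMOUNT - TAX_AMOUNT',
    # which references no other table, so the column is kept as:
    #     ColumnSpec(name='NET_AMOUNT', expr='GROSS_AMOUNT - TAX_AMOUNT',
    #                dtype=SnowTable._to_dtype('NUMBER'))
    # An expr such as 'CUSTOMERS.SEGMENT' would instead land in
    # unsupported_columns and be reported via msgs.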
@@ -501,6 +539,35 @@ class Graph:

          return graph

+     @classmethod
+     def from_relbench(
+         cls,
+         dataset: str,
+         verbose: bool = True,
+     ) -> Graph:
+         r"""Loads a `RelBench <https://relbench.stanford.edu>`_ dataset into a
+         :class:`Graph` instance.
+
+         .. code-block:: python
+
+             >>> # doctest: +SKIP
+             >>> import kumoai.experimental.rfm as rfm
+
+             >>> graph = rfm.Graph.from_relbench("f1")
+
+         Args:
+             dataset: The RelBench dataset name.
+             verbose: Whether to print verbose output.
+         """
+         from kumoai.experimental.rfm.relbench import from_relbench
+         graph = from_relbench(dataset, verbose=verbose)
+
+         if verbose:
+             graph.print_metadata()
+             graph.print_links()
+
+         return graph
+
      # Backend #################################################################

      @property
@@ -612,24 +679,8 @@ class Graph:

      def print_metadata(self) -> None:
          r"""Prints the :meth:`~Graph.metadata` of the graph."""
-         if in_snowflake_notebook():
-             import streamlit as st
-             st.markdown("### 🗂️ Graph Metadata")
-             st.dataframe(self.metadata, hide_index=True)
-         elif in_notebook():
-             from IPython.display import Markdown, display
-             display(Markdown("### 🗂️ Graph Metadata"))
-             df = self.metadata
-             try:
-                 if hasattr(df.style, 'hide'):
-                     display(df.style.hide(axis='index'))  # pandas=2
-                 else:
-                     display(df.style.hide_index())  # pandas<1.3
-             except ImportError:
-                 print(df.to_string(index=False))  # missing jinja2
-         else:
-             print("🗂️ Graph Metadata:")
-             print(self.metadata.to_string(index=False))
+         display.title("🗂️ Graph Metadata")
+         display.dataframe(self.metadata)

      def infer_metadata(self, verbose: bool = True) -> Self:
          r"""Infers metadata for all tables in the graph.
@@ -658,40 +709,21 @@ class Graph:

      def print_links(self) -> None:
          r"""Prints the :meth:`~Graph.edges` of the graph."""
-         edges = [(edge.dst_table, self[edge.dst_table]._primary_key,
-                   edge.src_table, edge.fkey) for edge in self.edges]
-         edges = sorted(edges)
-
-         if in_snowflake_notebook():
-             import streamlit as st
-             st.markdown("### 🕸️ Graph Links (FK ↔️ PK)")
-             if len(edges) > 0:
-                 st.markdown('\n'.join([
-                     f"- {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
-                     for edge in edges
-                 ]))
-             else:
-                 st.markdown("*No links registered*")
-         elif in_notebook():
-             from IPython.display import Markdown, display
-             display(Markdown("### 🕸️ Graph Links (FK ↔️ PK)"))
-             if len(edges) > 0:
-                 display(
-                     Markdown('\n'.join([
-                         f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
-                         for edge in edges
-                     ])))
-             else:
-                 display(Markdown("*No links registered*"))
+         edges = sorted([(
+             edge.dst_table,
+             self[edge.dst_table]._primary_key,
+             edge.src_table,
+             edge.fkey,
+         ) for edge in self.edges])
+
+         display.title("🕸️ Graph Links (FK ↔️ PK)")
+         if len(edges) > 0:
+             display.unordered_list(items=[
+                 f"`{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
+                 for edge in edges
+             ])
          else:
-             print("🕸️ Graph Links (FK ↔️ PK):")
-             if len(edges) > 0:
-                 print('\n'.join([
-                     f"• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
-                     for edge in edges
-                 ]))
-             else:
-                 print("No links registered")
+             display.italic("No links registered")

      def link(
          self,
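
Note: the new kumoai/utils/display.py module (+51 lines) is not expanded in this diff. Below is a minimal sketch of what its helpers could look like, reconstructed from the call sites above and from the Streamlit/notebook/stdout branches they replace. The function names come from the call sites; the bodies are assumptions, not the actual implementation:

    import pandas as pd

    from kumoai import in_notebook, in_snowflake_notebook


    def title(text: str) -> None:
        if in_snowflake_notebook():
            import streamlit as st
            st.markdown(f"### {text}")
        elif in_notebook():
            from IPython.display import Markdown, display
            display(Markdown(f"### {text}"))
        else:
            print(f"{text}:")


    def dataframe(df: pd.DataFrame) -> None:
        if in_snowflake_notebook():
            import streamlit as st
            st.dataframe(df, hide_index=True)
        elif in_notebook():
            from IPython.display import display
            try:
                display(df.style.hide(axis='index'))
            except ImportError:  # missing jinja2
                print(df.to_string(index=False))
        else:
            print(df.to_string(index=False))

unordered_list, italic, and message (used in rfm.py below) would dispatch on the environment in the same way.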
@@ -798,6 +830,30 @@ class Graph:
          """
          known_edges = {(edge.src_table, edge.fkey) for edge in self.edges}

+         for table in self.tables.values():  # Use links from source metadata:
+             if not any(column.is_source for column in table.columns):
+                 continue
+             for fkey in table._source_foreign_key_dict.values():
+                 if fkey.name not in table:
+                     continue
+                 if not table[fkey.name].is_source:
+                     continue
+                 if (table.name, fkey.name) in known_edges:
+                     continue
+                 dst_table_names = [
+                     table.name for table in self.tables.values()
+                     if table.source_name == fkey.dst_table
+                 ]
+                 if len(dst_table_names) != 1:
+                     continue
+                 dst_table = self[dst_table_names[0]]
+                 if dst_table._primary_key != fkey.primary_key:
+                     continue
+                 if not dst_table[fkey.primary_key].is_source:
+                     continue
+                 self.link(table.name, fkey.name, dst_table.name)
+                 known_edges.add((table.name, fkey.name))
+
          # A list of primary key candidates (+score) for every column:
          candidate_dict: dict[
              tuple[str, str],
@@ -897,13 +953,8 @@ class Graph:
              if score < 5.0:
                  continue

-             candidate_dict[(
-                 src_table.name,
-                 src_key.name,
-             )].append((
-                 dst_table.name,
-                 score,
-             ))
+             candidate_dict[(src_table.name, src_key.name)].append(
+                 (dst_table.name, score))

          for (src_table_name, src_key_name), scores in candidate_dict.items():
              scores.sort(key=lambda x: x[-1], reverse=True)
@@ -962,24 +1013,26 @@ class Graph:
                               f"either the primary key or the link before "
                               f"before proceeding.")

-         # Check that fkey/pkey have valid and consistent data types:
-         assert src_key.dtype is not None
-         src_number = src_key.dtype.is_int() or src_key.dtype.is_float()
-         src_string = src_key.dtype.is_string()
-         assert dst_key.dtype is not None
-         dst_number = dst_key.dtype.is_int() or dst_key.dtype.is_float()
-         dst_string = dst_key.dtype.is_string()
-
-         if not src_number and not src_string:
-             raise ValueError(f"{edge} is invalid as foreign key must be a "
-                              f"number or string (got '{src_key.dtype}'")
-
-         if src_number != dst_number or src_string != dst_string:
-             raise ValueError(f"{edge} is invalid as foreign key "
-                              f"'{fkey}' and primary key '{dst_key.name}' "
-                              f"have incompatible data types (got "
-                              f"fkey.dtype '{src_key.dtype}' and "
-                              f"pkey.dtype '{dst_key.dtype}')")
+         if self.backend == DataBackend.LOCAL:
+             # Check that fkey/pkey have valid and consistent data types:
+             assert src_key.dtype is not None
+             src_number = src_key.dtype.is_int() or src_key.dtype.is_float()
+             src_string = src_key.dtype.is_string()
+             assert dst_key.dtype is not None
+             dst_number = dst_key.dtype.is_int() or dst_key.dtype.is_float()
+             dst_string = dst_key.dtype.is_string()
+
+             if not src_number and not src_string:
+                 raise ValueError(
+                     f"{edge} is invalid as foreign key must be a number "
+                     f"or string (got '{src_key.dtype}'")
+
+             if src_number != dst_number or src_string != dst_string:
+                 raise ValueError(
+                     f"{edge} is invalid as foreign key '{fkey}' and "
+                     f"primary key '{dst_key.name}' have incompatible data "
+                     f"types (got foreign key data type '{src_key.dtype}' "
+                     f"and primary key data type '{dst_key.dtype}')")

      return self

kumoai/experimental/rfm/infer/__init__.py
@@ -1,17 +1,19 @@
  from .dtype import infer_dtype
- from .pkey import infer_primary_key
- from .time_col import infer_time_column
  from .id import contains_id
  from .timestamp import contains_timestamp
  from .categorical import contains_categorical
  from .multicategorical import contains_multicategorical
+ from .stype import infer_stype
+ from .pkey import infer_primary_key
+ from .time_col import infer_time_column

  __all__ = [
      'infer_dtype',
-     'infer_primary_key',
-     'infer_time_column',
      'contains_id',
      'contains_timestamp',
      'contains_categorical',
      'contains_multicategorical',
+     'infer_stype',
+     'infer_primary_key',
+     'infer_time_column',
  ]
kumoai/experimental/rfm/infer/dtype.py
@@ -10,6 +10,8 @@ PANDAS_TO_DTYPE: dict[str, Dtype] = {
      'int16': Dtype.int,
      'int32': Dtype.int,
      'int64': Dtype.int,
+     'float': Dtype.float,
+     'double': Dtype.float,
      'float16': Dtype.float,
      'float32': Dtype.float,
      'float64': Dtype.float,
@@ -18,6 +20,8 @@ PANDAS_TO_DTYPE: dict[str, Dtype] = {
      'string[python]': Dtype.string,
      'string[pyarrow]': Dtype.string,
      'binary': Dtype.binary,
+     'binary[python]': Dtype.binary,
+     'binary[pyarrow]': Dtype.binary,
  }


@@ -48,7 +52,8 @@ def infer_dtype(ser: pd.Series) -> Dtype:
          ser = pd.Series(arr, dtype=pd.ArrowDtype(arr.type))

      if isinstance(ser.dtype, pd.ArrowDtype):
-         if pa.types.is_list(ser.dtype.pyarrow_dtype):
+         if (pa.types.is_list(ser.dtype.pyarrow_dtype)
+                 or pa.types.is_fixed_size_list(ser.dtype.pyarrow_dtype)):
              elem_dtype = ser.dtype.pyarrow_dtype.value_type
              if pa.types.is_integer(elem_dtype):
                  return Dtype.intlist
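
Note: the is_fixed_size_list addition matters for columns such as fixed-width embeddings, which previously fell through this branch (pa.types.is_list matches only variable-length lists). A small illustration with invented values:

    import pandas as pd
    import pyarrow as pa

    # A fixed_size_list<float>[3] column, e.g. a 3-dimensional embedding:
    arr = pa.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
                   type=pa.list_(pa.float32(), 3))
    ser = pd.Series(arr, dtype=pd.ArrowDtype(arr.type))

    pa.types.is_list(ser.dtype.pyarrow_dtype)             # False
    pa.types.is_fixed_size_list(ser.dtype.pyarrow_dtype)  # True

With the change, such a series takes the list path (e.g. Dtype.intlist for integer elements) instead of falling through to the scalar handling.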
kumoai/experimental/rfm/infer/multicategorical.py
@@ -40,7 +40,7 @@ def contains_multicategorical(
          sep = max(candidates, key=candidates.get)  # type: ignore
          ser = ser.str.split(sep)

-     num_unique_multi = ser.explode().nunique()
+     num_unique_multi = ser.astype('object').explode().nunique()

      if dtype.is_list():
          return num_unique_multi <= MAX_CAT
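
Note: a plausible reading of the astype('object') change (the motivation is inferred, not stated in the diff) is that ser may be Arrow-backed at this point via the dtype.is_list() path, and casting to object first makes explode()/nunique() behave like the plain-pandas string path:

    import pandas as pd
    import pyarrow as pa

    ser = pd.Series([['a', 'b'], ['b', 'c']],
                    dtype=pd.ArrowDtype(pa.list_(pa.string())))
    ser.astype('object').explode().nunique()  # 3 -> {'a', 'b', 'c'}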
kumoai/experimental/rfm/infer/stype.py (new file)
@@ -0,0 +1,35 @@
+ import pandas as pd
+ from kumoapi.typing import Dtype, Stype
+
+ from kumoai.experimental.rfm.infer import (
+     contains_categorical,
+     contains_id,
+     contains_multicategorical,
+     contains_timestamp,
+ )
+
+
+ def infer_stype(ser: pd.Series, column_name: str, dtype: Dtype) -> Stype:
+     """Infers the :class:`Stype` from a :class:`pandas.Series`.
+
+     Args:
+         ser: A :class:`pandas.Series` to analyze.
+         column_name: The column name.
+         dtype: The data type.
+
+     Returns:
+         The semantic type.
+     """
+     if contains_id(ser, column_name, dtype):
+         return Stype.ID
+
+     if contains_timestamp(ser, column_name, dtype):
+         return Stype.timestamp
+
+     if contains_multicategorical(ser, column_name, dtype):
+         return Stype.multicategorical
+
+     if contains_categorical(ser, column_name, dtype):
+         return Stype.categorical
+
+     return dtype.default_stype
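
Note: a minimal usage sketch for the new helper (the example values are hypothetical; which Stype wins depends on the contains_* heuristics):

    import pandas as pd
    from kumoapi.typing import Dtype

    from kumoai.experimental.rfm.infer import infer_stype

    ser = pd.Series(['red', 'green', 'red', 'blue', 'green'])
    stype = infer_stype(ser, column_name='color', dtype=Dtype.string)
    print(stype)  # likely Stype.categorical for a low-cardinality string column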
kumoai/experimental/rfm/relbench.py (new file)
@@ -0,0 +1,76 @@
+ import difflib
+ import json
+ from functools import lru_cache
+ from urllib.request import urlopen
+
+ import pooch
+ import pyarrow as pa
+
+ from kumoai.experimental.rfm import Graph
+ from kumoai.experimental.rfm.backend.local import LocalTable
+
+ PREFIX = 'rel-'
+ CACHE_DIR = pooch.os_cache('relbench')
+ HASH_URL = ('https://raw.githubusercontent.com/snap-stanford/relbench/main/'
+             'relbench/datasets/hashes.json')
+
+
+ @lru_cache
+ def get_registry() -> pooch.Pooch:
+     with urlopen(HASH_URL) as r:
+         hashes = json.load(r)
+
+     return pooch.create(
+         path=CACHE_DIR,
+         base_url='https://relbench.stanford.edu/download/',
+         registry=hashes,
+     )
+
+
+ def from_relbench(dataset: str, verbose: bool = True) -> Graph:
+     dataset = dataset.lower()
+     if dataset.startswith(PREFIX):
+         dataset = dataset[len(PREFIX):]
+
+     registry = get_registry()
+
+     datasets = [key.split('/')[0][len(PREFIX):] for key in registry.registry]
+     if dataset not in datasets:
+         matches = difflib.get_close_matches(dataset, datasets, n=1)
+         hint = f" Did you mean '{matches[0]}'?" if len(matches) > 0 else ''
+         raise ValueError(f"Unknown RelBench dataset '{dataset}'.{hint} Valid "
+                          f"datasets are {str(datasets)[1:-1]}.")
+
+     registry.fetch(
+         f'{PREFIX}{dataset}/db.zip',
+         processor=pooch.Unzip(extract_dir='.'),
+         progressbar=verbose,
+     )
+
+     graph = Graph(tables=[])
+     edges: list[tuple[str, str, str]] = []
+     for path in (CACHE_DIR / f'{PREFIX}{dataset}' / 'db').glob('*.parquet'):
+         data = pa.parquet.read_table(path)
+         metadata = {
+             key.decode('utf-8'): json.loads(value.decode('utf-8'))
+             for key, value in data.schema.metadata.items()
+             if key in [b"fkey_col_to_pkey_table", b"pkey_col", b"time_col"]
+         }
+
+         table = LocalTable(
+             df=data.to_pandas(),
+             name=path.stem,
+             primary_key=metadata['pkey_col'],
+             time_column=metadata['time_col'],
+         )
+         graph.add_table(table)
+
+         edges.extend([
+             (path.stem, fkey, dst_table)
+             for fkey, dst_table in metadata['fkey_col_to_pkey_table'].items()
+         ])
+
+     for edge in edges:
+         graph.link(*edge)
+
+     return graph
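
Note: a quick end-to-end check of the loader (requires network on first run; subsequent runs hit the pooch cache under pooch.os_cache('relbench')):

    from kumoai.experimental.rfm.relbench import from_relbench

    graph = from_relbench('f1', verbose=False)
    print(sorted(graph.tables))  # one LocalTable per parquet file in rel-f1/db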
kumoai/experimental/rfm/rfm.py
@@ -28,13 +28,12 @@ from kumoapi.rfm import (
  from kumoapi.task import TaskType
  from kumoapi.typing import AggregationType, Stype

- from kumoai import in_notebook, in_snowflake_notebook
  from kumoai.client.rfm import RFMAPI
  from kumoai.exceptions import HTTPException
  from kumoai.experimental.rfm import Graph
  from kumoai.experimental.rfm.base import DataBackend, Sampler
  from kumoai.mixin import CastMixin
- from kumoai.utils import ProgressLogger
+ from kumoai.utils import ProgressLogger, display

  _RANDOM_SEED = 42

@@ -104,23 +103,8 @@ class Explanation:

      def print(self) -> None:
          r"""Prints the explanation."""
-         if in_snowflake_notebook():
-             import streamlit as st
-             st.dataframe(self.prediction, hide_index=True)
-             st.markdown(self.summary)
-         elif in_notebook():
-             from IPython.display import Markdown, display
-             try:
-                 if hasattr(self.prediction.style, 'hide'):
-                     display(self.prediction.hide(axis='index'))  # pandas=2
-                 else:
-                     display(self.prediction.hide_index())  # pandas<1.3
-             except ImportError:
-                 print(self.prediction.to_string(index=False))  # missing jinja2
-             display(Markdown(self.summary))
-         else:
-             print(self.prediction.to_string(index=False))
-             print(self.summary)
+         display.dataframe(self.prediction)
+         display.message(self.summary)

      def _ipython_display_(self) -> None:
          self.print()
@@ -714,7 +698,7 @@ class KumoRFM:
                              f"to have a time column")

          train, test = self._sampler.sample_target(
-             query=query,
+             query=query_def,
              num_train_examples=0,
              train_anchor_time=anchor_time,
              num_train_trials=0,
@@ -742,30 +726,34 @@ class KumoRFM:
                           "`predict()` or `evaluate()` methods to perform "
                           "predictions or evaluations.")

-         try:
-             request = RFMParseQueryRequest(
-                 query=query,
-                 graph_definition=self._graph_def,
-             )
-
-             resp = self._api_client.parse_query(request)
-
-             if len(resp.validation_response.warnings) > 0:
-                 msg = '\n'.join([
-                     f'{i+1}. {warning.title}: {warning.message}' for i, warning
-                     in enumerate(resp.validation_response.warnings)
-                 ])
-                 warnings.warn(f"Encountered the following warnings during "
-                               f"parsing:\n{msg}")
+         request = RFMParseQueryRequest(
+             query=query,
+             graph_definition=self._graph_def,
+         )

-             return resp.query
-         except HTTPException as e:
+         for attempt in range(self.num_retries + 1):
              try:
-                 msg = json.loads(e.detail)['detail']
-             except Exception:
-                 msg = e.detail
-             raise ValueError(f"Failed to parse query '{query}'. "
-                              f"{msg}") from None
+                 resp = self._api_client.parse_query(request)
+                 break
+             except HTTPException as e:
+                 if attempt == self.num_retries:
+                     try:
+                         msg = json.loads(e.detail)['detail']
+                     except Exception:
+                         msg = e.detail
+                     raise ValueError(f"Failed to parse query '{query}'. {msg}")
+
+                 time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...
+
+         if len(resp.validation_response.warnings) > 0:
+             msg = '\n'.join([
+                 f'{i+1}. {warning.title}: {warning.message}'
+                 for i, warning in enumerate(resp.validation_response.warnings)
+             ])
+             warnings.warn(f"Encountered the following warnings during "
+                           f"parsing:\n{msg}")
+
+         return resp.query

      @staticmethod
      def _get_task_type(
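
Note: the resulting retry schedule of the parse loop above, assuming num_retries = 3 (the attribute's default is not visible in this diff):

    # attempt 0: parse_query fails -> sleep 2**0 = 1s
    # attempt 1: parse_query fails -> sleep 2**1 = 2s
    # attempt 2: parse_query fails -> sleep 2**2 = 4s
    # attempt 3: parse_query fails -> ValueError raised (no further sleep)

A successful parse_query breaks out of the loop immediately, after which any validation warnings are surfaced once and the parsed query is returned.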