kumoai 2.14.0.dev202512181731__cp312-cp312-macosx_11_0_arm64.whl → 2.14.0.dev202512301731__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35) hide show
  1. kumoai/__init__.py +23 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +24 -0
  5. kumoai/experimental/rfm/__init__.py +22 -22
  6. kumoai/experimental/rfm/backend/local/graph_store.py +12 -21
  7. kumoai/experimental/rfm/backend/local/sampler.py +0 -3
  8. kumoai/experimental/rfm/backend/local/table.py +25 -24
  9. kumoai/experimental/rfm/backend/snow/sampler.py +106 -61
  10. kumoai/experimental/rfm/backend/snow/table.py +146 -51
  11. kumoai/experimental/rfm/backend/sqlite/sampler.py +127 -78
  12. kumoai/experimental/rfm/backend/sqlite/table.py +94 -47
  13. kumoai/experimental/rfm/base/__init__.py +6 -7
  14. kumoai/experimental/rfm/base/column.py +97 -5
  15. kumoai/experimental/rfm/base/expression.py +44 -0
  16. kumoai/experimental/rfm/base/sampler.py +5 -17
  17. kumoai/experimental/rfm/base/source.py +1 -1
  18. kumoai/experimental/rfm/base/sql_sampler.py +68 -9
  19. kumoai/experimental/rfm/base/table.py +284 -120
  20. kumoai/experimental/rfm/graph.py +139 -86
  21. kumoai/experimental/rfm/infer/__init__.py +6 -4
  22. kumoai/experimental/rfm/infer/dtype.py +6 -1
  23. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  24. kumoai/experimental/rfm/infer/stype.py +35 -0
  25. kumoai/experimental/rfm/relbench.py +76 -0
  26. kumoai/experimental/rfm/rfm.py +4 -20
  27. kumoai/trainer/distilled_trainer.py +175 -0
  28. kumoai/utils/display.py +51 -0
  29. {kumoai-2.14.0.dev202512181731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/METADATA +1 -1
  30. {kumoai-2.14.0.dev202512181731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/RECORD +33 -30
  31. kumoai/experimental/rfm/base/column_expression.py +0 -16
  32. kumoai/experimental/rfm/base/sql_table.py +0 -113
  33. {kumoai-2.14.0.dev202512181731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/WHEEL +0 -0
  34. {kumoai-2.14.0.dev202512181731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/licenses/LICENSE +0 -0
  35. {kumoai-2.14.0.dev202512181731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/top_level.txt +0 -0
@@ -1,15 +1,89 @@
1
+ from __future__ import annotations
2
+
1
3
  from dataclasses import dataclass
2
- from typing import Any
4
+ from typing import Any, Mapping, TypeAlias
3
5
 
4
6
  from kumoapi.typing import Dtype, Stype
7
+ from typing_extensions import Self
8
+
9
+ from kumoai.experimental.rfm.base import Expression
10
+ from kumoai.mixin import CastMixin
11
+
12
+
13
+ @dataclass(init=False)
14
+ class ColumnSpec(CastMixin):
15
+ r"""A column specification for adding a column to a table.
16
+
17
+ A column specification can either refer to a physical column present in
18
+ the data source, or be defined logically via an expression.
19
+
20
+ Args:
21
+ name: The name of the column.
22
+ expr: A column expression to define logical columns.
23
+ dtype: The data type of the column.
24
+ """
25
+ def __init__(
26
+ self,
27
+ name: str,
28
+ expr: Expression | Mapping[str, str] | str | None = None,
29
+ dtype: Dtype | str | None = None,
30
+ stype: Stype | str | None = None,
31
+ ) -> None:
32
+
33
+ self.name = name
34
+ self.expr = Expression.coerce(expr)
35
+ self.dtype = Dtype(dtype) if dtype is not None else None
36
+ self.stype = Stype(dtype) if stype is not None else None
37
+
38
+ @classmethod
39
+ def coerce(cls, spec: ColumnSpec | Mapping[str, Any] | str) -> Self:
40
+ r"""Coerces a column specification into a :class:`ColumnSpec`."""
41
+ if isinstance(spec, cls):
42
+ return spec
43
+ if isinstance(spec, str):
44
+ return cls(name=spec)
45
+ if isinstance(spec, Mapping):
46
+ try:
47
+ return cls(**spec)
48
+ except TypeError:
49
+ pass
50
+ raise TypeError(f"Unable to coerce 'ColumnSpec' from '{spec}'")
51
+
52
+ @property
53
+ def is_source(self) -> bool:
54
+ r"""Whether the column specification refers to a phyiscal column
55
+ present in the data source.
56
+ """
57
+ return self.expr is None
58
+
59
+
60
+ ColumnSpecType: TypeAlias = ColumnSpec | Mapping[str, Any] | str
5
61
 
6
62
 
7
63
  @dataclass(init=False, repr=False, eq=False)
8
64
  class Column:
65
+ r"""Column-level metadata information.
66
+
67
+ A column can either refer to a physical column present in the data source,
68
+ or be defined logically via an expression.
69
+
70
+ Args:
71
+ name: The name of the column.
72
+ expr: A column expression to define logical columns.
73
+ dtype: The data type of the column.
74
+ stype: The semantic type of the column.
75
+ """
9
76
  stype: Stype
10
77
 
11
- def __init__(self, name: str, stype: Stype, dtype: Dtype) -> None:
78
+ def __init__(
79
+ self,
80
+ name: str,
81
+ expr: Expression | None,
82
+ dtype: Dtype,
83
+ stype: Stype,
84
+ ) -> None:
12
85
  self._name = name
86
+ self._expr = expr
13
87
  self._dtype = Dtype(dtype)
14
88
 
15
89
  self._is_primary_key = False
@@ -20,12 +94,26 @@ class Column:
20
94
 
21
95
  @property
22
96
  def name(self) -> str:
97
+ r"""The name of the column."""
23
98
  return self._name
24
99
 
100
+ @property
101
+ def expr(self) -> Expression | None:
102
+ r"""The expression of column (if logically)."""
103
+ return self._expr
104
+
25
105
  @property
26
106
  def dtype(self) -> Dtype:
107
+ r"""The data type of the column."""
27
108
  return self._dtype
28
109
 
110
+ @property
111
+ def is_source(self) -> bool:
112
+ r"""Whether the column refers to a phyiscal column present in the data
113
+ source.
114
+ """
115
+ return self.expr is None
116
+
29
117
  def __setattr__(self, key: str, val: Any) -> None:
30
118
  if key == 'stype':
31
119
  if isinstance(val, str):
@@ -48,7 +136,7 @@ class Column:
48
136
  super().__setattr__(key, val)
49
137
 
50
138
  def __hash__(self) -> int:
51
- return hash((self.name, self.stype, self.dtype))
139
+ return hash((self.name, self.expr, self.dtype, self.stype))
52
140
 
53
141
  def __eq__(self, other: Any) -> bool:
54
142
  if not isinstance(other, Column):
@@ -56,5 +144,9 @@ class Column:
56
144
  return hash(self) == hash(other)
57
145
 
58
146
  def __repr__(self) -> str:
59
- return (f'{self.__class__.__name__}(name={self.name}, '
60
- f'stype={self.stype}, dtype={self.dtype})')
147
+ parts = [f'name={self.name}']
148
+ if self.expr is not None:
149
+ parts.append(f'expr={self.expr}')
150
+ parts.append(f'dtype={self.dtype}')
151
+ parts.append(f'stype={self.stype}')
152
+ return f"{self.__class__.__name__}({', '.join(parts)})"
@@ -0,0 +1,44 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC
4
+ from dataclasses import dataclass
5
+ from typing import Mapping
6
+
7
+
8
+ class Expression(ABC):
9
+ """A base expression to define logical columns."""
10
+ @classmethod
11
+ def coerce(
12
+ cls,
13
+ spec: Expression | Mapping[str, str] | str | None,
14
+ ) -> Expression | None:
15
+ r"""Coerces an expression specification into an :class:`Expression`, if
16
+ possible.
17
+ """
18
+ if spec is None:
19
+ return None
20
+ if isinstance(spec, Expression):
21
+ return spec
22
+ if isinstance(spec, str):
23
+ return LocalExpression(spec)
24
+ if isinstance(spec, Mapping):
25
+ for sub_cls in (LocalExpression, ):
26
+ try:
27
+ return sub_cls(**spec)
28
+ except TypeError:
29
+ pass
30
+ raise TypeError(f"Unable to coerce 'Expression' from '{spec}'")
31
+
32
+
33
+ @dataclass(frozen=True, repr=False)
34
+ class LocalExpression(Expression):
35
+ r"""A local expression to define a row-level logical attribute based on
36
+ physical columns of the data source in the same row.
37
+
38
+ Args:
39
+ value: The value of the expression.
40
+ """
41
+ value: str
42
+
43
+ def __repr__(self) -> str:
44
+ return self.value
@@ -13,7 +13,6 @@ from kumoapi.pquery.AST import Aggregation, ASTNode
13
13
  from kumoapi.rfm.context import EdgeLayout, Link, Subgraph, Table
14
14
  from kumoapi.typing import Stype
15
15
 
16
- from kumoai.experimental.rfm.base import SourceColumn
17
16
  from kumoai.utils import ProgressLogger
18
17
 
19
18
  if TYPE_CHECKING:
@@ -53,6 +52,7 @@ class Sampler(ABC):
53
52
  graph: 'Graph',
54
53
  verbose: bool | ProgressLogger = True,
55
54
  ) -> None:
55
+
56
56
  self._edge_types: list[tuple[str, str, str]] = []
57
57
  for edge in graph.edges:
58
58
  edge_type = (edge.src_table, edge.fkey, edge.dst_table)
@@ -88,10 +88,6 @@ class Sampler(ABC):
88
88
  continue
89
89
  self._table_stype_dict[table.name][column.name] = column.stype
90
90
 
91
- self._source_table_dict: dict[str, dict[str, SourceColumn]] = {}
92
- for table in graph.tables.values():
93
- self._source_table_dict[table.name] = table._source_column_dict
94
-
95
91
  self._min_time_dict: dict[str, pd.Timestamp] = {}
96
92
  self._max_time_dict: dict[str, pd.Timestamp] = {}
97
93
 
@@ -119,16 +115,11 @@ class Sampler(ABC):
119
115
 
120
116
  @property
121
117
  def table_stype_dict(self) -> dict[str, dict[str, Stype]]:
122
- r"""The registered semantic types for all columns in all tables in
123
- the graph.
118
+ r"""The registered semantic types for all feature columns in all tables
119
+ in the graph.
124
120
  """
125
121
  return self._table_stype_dict
126
122
 
127
- @property
128
- def source_table_dict(self) -> dict[str, dict[str, SourceColumn]]:
129
- r"""Source column information for all tables in the graph."""
130
- return self._source_table_dict
131
-
132
123
  def get_min_time(
133
124
  self,
134
125
  table_names: list[str] | None = None,
@@ -189,7 +180,7 @@ class Sampler(ABC):
189
180
  exclude_cols_dict: The columns to exclude from the subgraph.
190
181
  """
191
182
  # Exclude all columns that leak target information:
192
- table_stype_dict: dict[str, dict[str, Stype]] = self._table_stype_dict
183
+ table_stype_dict: dict[str, dict[str, Stype]] = self.table_stype_dict
193
184
  if exclude_cols_dict is not None:
194
185
  table_stype_dict = copy.deepcopy(table_stype_dict)
195
186
  for table_name, exclude_cols in exclude_cols_dict.items():
@@ -237,11 +228,8 @@ class Sampler(ABC):
237
228
  # Set end time to NaT for all values greater than anchor time:
238
229
  assert table_name not in out.inverse_dict
239
230
  ser = df[end_time_column]
240
- if ser.dtype != 'datetime64[ns]':
241
- ser = ser.astype('datetime64[ns]')
242
231
  mask = ser.astype(int).to_numpy() > out.anchor_time[batch]
243
- ser.iloc[mask] = pd.NaT
244
- df[end_time_column] = ser
232
+ df.loc[mask, end_time_column] = pd.NaT
245
233
 
246
234
  stype_dict = table_stype_dict[table_name]
247
235
  for column_name, stype in stype_dict.items():
@@ -6,7 +6,7 @@ from kumoapi.typing import Dtype
6
6
  @dataclass
7
7
  class SourceColumn:
8
8
  name: str
9
- dtype: Dtype
9
+ dtype: Dtype | None
10
10
  is_primary_key: bool
11
11
  is_unique_key: bool
12
12
  is_nullable: bool
@@ -3,9 +3,15 @@ from typing import TYPE_CHECKING, Literal
3
3
 
4
4
  import numpy as np
5
5
  import pandas as pd
6
+ from kumoapi.typing import Dtype
6
7
 
7
- from kumoai.experimental.rfm.base import Sampler, SamplerOutput, SQLTable
8
- from kumoai.utils import ProgressLogger
8
+ from kumoai.experimental.rfm.base import (
9
+ LocalExpression,
10
+ Sampler,
11
+ SamplerOutput,
12
+ SourceColumn,
13
+ )
14
+ from kumoai.utils import ProgressLogger, quote_ident
9
15
 
10
16
  if TYPE_CHECKING:
11
17
  from kumoai.experimental.rfm import Graph
@@ -19,18 +25,71 @@ class SQLSampler(Sampler):
19
25
  ) -> None:
20
26
  super().__init__(graph=graph, verbose=verbose)
21
27
 
22
- self._fqn_dict: dict[str, str] = {}
28
+ self._source_name_dict: dict[str, str] = {
29
+ table.name: table._quoted_source_name
30
+ for table in graph.tables.values()
31
+ }
32
+
33
+ self._source_table_dict: dict[str, dict[str, SourceColumn]] = {}
34
+ for table in graph.tables.values():
35
+ self._source_table_dict[table.name] = {}
36
+ for column in table.columns:
37
+ if not column.is_source:
38
+ continue
39
+ src_column = table._source_column_dict[column.name]
40
+ self._source_table_dict[table.name][column.name] = src_column
41
+
42
+ self._table_dtype_dict: dict[str, dict[str, Dtype]] = {}
43
+ for table in graph.tables.values():
44
+ self._table_dtype_dict[table.name] = {}
45
+ for column in table.columns:
46
+ self._table_dtype_dict[table.name][column.name] = column.dtype
47
+
48
+ self._table_column_ref_dict: dict[str, dict[str, str]] = {}
49
+ self._table_column_proj_dict: dict[str, dict[str, str]] = {}
23
50
  for table in graph.tables.values():
24
- assert isinstance(table, SQLTable)
25
- self._connection = table._connection
26
- self._fqn_dict[table.name] = table.fqn
51
+ column_ref_dict: dict[str, str] = {}
52
+ column_proj_dict: dict[str, str] = {}
53
+ for column in table.columns:
54
+ if column.expr is not None:
55
+ assert isinstance(column.expr, LocalExpression)
56
+ column_ref_dict[column.name] = column.expr.value
57
+ column_proj_dict[column.name] = (
58
+ f'{column.expr} AS {quote_ident(column.name)}')
59
+ else:
60
+ column_ref_dict[column.name] = quote_ident(column.name)
61
+ column_proj_dict[column.name] = quote_ident(column.name)
62
+ self._table_column_ref_dict[table.name] = column_ref_dict
63
+ self._table_column_proj_dict[table.name] = column_proj_dict
64
+
65
+ @property
66
+ def source_name_dict(self) -> dict[str, str]:
67
+ r"""The source table names for all tables in the graph."""
68
+ return self._source_name_dict
69
+
70
+ @property
71
+ def source_table_dict(self) -> dict[str, dict[str, SourceColumn]]:
72
+ r"""The source column information for all tables in the graph."""
73
+ return self._source_table_dict
74
+
75
+ @property
76
+ def table_dtype_dict(self) -> dict[str, dict[str, Dtype]]:
77
+ r"""The data types for all columns in all tables in the graph."""
78
+ return self._table_dtype_dict
79
+
80
+ @property
81
+ def table_column_ref_dict(self) -> dict[str, dict[str, str]]:
82
+ r"""The SQL reference expression for all columns in all tables in the
83
+ graph.
84
+ """
85
+ return self._table_column_ref_dict
27
86
 
28
87
  @property
29
- def fqn_dict(self) -> dict[str, str]:
30
- r"""The fully-qualified quoted source name for all table names in the
88
+ def table_column_proj_dict(self) -> dict[str, dict[str, str]]:
89
+ r"""The SQL projection expressions for all columns in all tables in the
31
90
  graph.
32
91
  """
33
- return self._fqn_dict
92
+ return self._table_column_proj_dict
34
93
 
35
94
  def _sample_subgraph(
36
95
  self,