kumoai 2.14.0.dev202512211732__cp313-cp313-macosx_11_0_arm64.whl → 2.15.0.dev202601121731__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. kumoai/__init__.py +23 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +26 -0
  5. kumoai/connector/utils.py +21 -7
  6. kumoai/experimental/rfm/__init__.py +24 -22
  7. kumoai/experimental/rfm/backend/local/graph_store.py +12 -21
  8. kumoai/experimental/rfm/backend/local/sampler.py +0 -3
  9. kumoai/experimental/rfm/backend/local/table.py +24 -25
  10. kumoai/experimental/rfm/backend/snow/sampler.py +190 -71
  11. kumoai/experimental/rfm/backend/snow/table.py +137 -64
  12. kumoai/experimental/rfm/backend/sqlite/sampler.py +192 -87
  13. kumoai/experimental/rfm/backend/sqlite/table.py +85 -55
  14. kumoai/experimental/rfm/base/__init__.py +6 -9
  15. kumoai/experimental/rfm/base/column.py +95 -11
  16. kumoai/experimental/rfm/base/expression.py +44 -0
  17. kumoai/experimental/rfm/base/mapper.py +69 -0
  18. kumoai/experimental/rfm/base/sampler.py +28 -18
  19. kumoai/experimental/rfm/base/source.py +1 -1
  20. kumoai/experimental/rfm/base/sql_sampler.py +320 -19
  21. kumoai/experimental/rfm/base/table.py +256 -109
  22. kumoai/experimental/rfm/base/utils.py +27 -0
  23. kumoai/experimental/rfm/graph.py +115 -107
  24. kumoai/experimental/rfm/infer/dtype.py +4 -1
  25. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  26. kumoai/experimental/rfm/infer/time_col.py +4 -2
  27. kumoai/experimental/rfm/relbench.py +76 -0
  28. kumoai/experimental/rfm/rfm.py +540 -306
  29. kumoai/experimental/rfm/task_table.py +292 -0
  30. kumoai/pquery/training_table.py +16 -2
  31. kumoai/testing/snow.py +3 -3
  32. kumoai/trainer/distilled_trainer.py +175 -0
  33. kumoai/utils/display.py +87 -0
  34. kumoai/utils/progress_logger.py +13 -1
  35. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/METADATA +2 -2
  36. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/RECORD +39 -34
  37. kumoai/experimental/rfm/base/column_expression.py +0 -50
  38. kumoai/experimental/rfm/base/sql_table.py +0 -229
  39. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/WHEEL +0 -0
  40. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/licenses/LICENSE +0 -0
  41. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/top_level.txt +0 -0
@@ -8,12 +8,9 @@ class DataBackend(StrEnum):
8
8
 
9
9
 
10
10
  from .source import SourceColumn, SourceForeignKey # noqa: E402
11
- from .column import Column # noqa: E402
12
- from .column_expression import ColumnExpressionSpec # noqa: E402
13
- from .column_expression import ColumnExpressionType # noqa: E402
14
- from .column_expression import ColumnExpression # noqa: E402
11
+ from .expression import Expression, LocalExpression # noqa: E402
12
+ from .column import ColumnSpec, ColumnSpecType, Column # noqa: E402
15
13
  from .table import Table # noqa: E402
16
- from .sql_table import SQLTable # noqa: E402
17
14
  from .sampler import SamplerOutput, Sampler # noqa: E402
18
15
  from .sql_sampler import SQLSampler # noqa: E402
19
16
 
@@ -21,12 +18,12 @@ __all__ = [
21
18
  'DataBackend',
22
19
  'SourceColumn',
23
20
  'SourceForeignKey',
21
+ 'Expression',
22
+ 'LocalExpression',
23
+ 'ColumnSpec',
24
+ 'ColumnSpecType',
24
25
  'Column',
25
- 'ColumnExpressionSpec',
26
- 'ColumnExpressionType',
27
- 'ColumnExpression',
28
26
  'Table',
29
- 'SQLTable',
30
27
  'SamplerOutput',
31
28
  'Sampler',
32
29
  'SQLSampler',
@@ -1,15 +1,89 @@
1
+ from __future__ import annotations
2
+
1
3
  from dataclasses import dataclass
2
- from typing import Any
4
+ from typing import Any, Mapping, TypeAlias
3
5
 
4
6
  from kumoapi.typing import Dtype, Stype
7
+ from typing_extensions import Self
8
+
9
+ from kumoai.experimental.rfm.base import Expression
10
+ from kumoai.mixin import CastMixin
11
+
12
+
13
@dataclass(init=False)
class ColumnSpec(CastMixin):
    r"""A column specification for adding a column to a table.

    A column specification can either refer to a physical column present in
    the data source, or be defined logically via an expression.

    Args:
        name: The name of the column.
        expr: A column expression to define logical columns. Accepts an
            :class:`Expression`, a mapping, or a string, which are coerced via
            :meth:`Expression.coerce`. ``None`` denotes a physical column.
        dtype: The data type of the column.
        stype: The semantic type of the column.
    """
    # Declared so that the dataclass-generated ``__eq__`` and ``__repr__``
    # take these attributes into account (``__init__`` is hand-written since
    # ``init=False``):
    name: str
    expr: Expression | None
    dtype: Dtype | None
    stype: Stype | None

    def __init__(
        self,
        name: str,
        expr: Expression | Mapping[str, str] | str | None = None,
        dtype: Dtype | str | None = None,
        stype: Stype | str | None = None,
    ) -> None:

        self.name = name
        self.expr = Expression.coerce(expr)
        self.dtype = Dtype(dtype) if dtype is not None else None
        # The semantic type is derived from ``stype`` (not ``dtype``):
        self.stype = Stype(stype) if stype is not None else None

    @classmethod
    def coerce(cls, spec: ColumnSpec | Mapping[str, Any] | str) -> Self:
        r"""Coerces a column specification into a :class:`ColumnSpec`.

        Args:
            spec: An existing :class:`ColumnSpec` (returned unchanged), a
                column name, or a mapping of constructor keyword arguments.

        Raises:
            TypeError: If ``spec`` cannot be converted.
        """
        if isinstance(spec, cls):
            return spec
        if isinstance(spec, str):
            return cls(name=spec)
        if isinstance(spec, Mapping):
            try:
                return cls(**spec)
            except TypeError as err:
                # Chain the underlying error (e.g., an unexpected key) so the
                # root cause stays visible in the traceback:
                raise TypeError(
                    f"Unable to coerce 'ColumnSpec' from '{spec}'") from err
        raise TypeError(f"Unable to coerce 'ColumnSpec' from '{spec}'")

    @property
    def is_source(self) -> bool:
        r"""Whether the column specification refers to a physical column
        present in the data source.
        """
        return self.expr is None


ColumnSpecType: TypeAlias = ColumnSpec | Mapping[str, Any] | str
5
61
 
6
62
 
7
63
  @dataclass(init=False, repr=False, eq=False)
8
64
  class Column:
65
+ r"""Column-level metadata information.
66
+
67
+ A column can either refer to a physical column present in the data source,
68
+ or be defined logically via an expression.
69
+
70
+ Args:
71
+ name: The name of the column.
72
+ expr: A column expression to define logical columns.
73
+ dtype: The data type of the column.
74
+ stype: The semantic type of the column.
75
+ """
9
76
  stype: Stype
10
77
 
11
- def __init__(self, name: str, stype: Stype, dtype: Dtype) -> None:
78
+ def __init__(
79
+ self,
80
+ name: str,
81
+ expr: Expression | None,
82
+ dtype: Dtype,
83
+ stype: Stype,
84
+ ) -> None:
12
85
  self._name = name
86
+ self._expr = expr
13
87
  self._dtype = Dtype(dtype)
14
88
 
15
89
  self._is_primary_key = False
@@ -20,19 +94,25 @@ class Column:
20
94
 
21
95
  @property
22
96
  def name(self) -> str:
97
+ r"""The name of the column."""
23
98
  return self._name
24
99
 
25
100
  @property
26
- def dtype(self) -> Dtype:
27
- return self._dtype
101
+ def expr(self) -> Expression | None:
102
+ r"""The expression of column (if logically)."""
103
+ return self._expr
28
104
 
29
105
  @property
30
- def is_physical(self) -> bool:
31
- return True
106
+ def dtype(self) -> Dtype:
107
+ r"""The data type of the column."""
108
+ return self._dtype
32
109
 
33
110
  @property
34
- def is_logical(self) -> bool:
35
- return not self.is_physical
111
+ def is_source(self) -> bool:
112
+ r"""Whether the column refers to a phyiscal column present in the data
113
+ source.
114
+ """
115
+ return self.expr is None
36
116
 
37
117
  def __setattr__(self, key: str, val: Any) -> None:
38
118
  if key == 'stype':
@@ -56,7 +136,7 @@ class Column:
56
136
  super().__setattr__(key, val)
57
137
 
58
138
  def __hash__(self) -> int:
59
- return hash((self.name, self.stype, self.dtype))
139
+ return hash((self.name, self.expr, self.dtype, self.stype))
60
140
 
61
141
  def __eq__(self, other: Any) -> bool:
62
142
  if not isinstance(other, Column):
@@ -64,5 +144,9 @@ class Column:
64
144
  return hash(self) == hash(other)
65
145
 
66
146
  def __repr__(self) -> str:
67
- return (f'{self.__class__.__name__}(name={self.name}, '
68
- f'stype={self.stype}, dtype={self.dtype})')
147
+ parts = [f'name={self.name}']
148
+ if self.expr is not None:
149
+ parts.append(f'expr={self.expr}')
150
+ parts.append(f'dtype={self.dtype}')
151
+ parts.append(f'stype={self.stype}')
152
+ return f"{self.__class__.__name__}({', '.join(parts)})"
@@ -0,0 +1,44 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC
4
+ from dataclasses import dataclass
5
+ from typing import Mapping
6
+
7
+
8
class Expression(ABC):
    """A base expression to define logical columns."""
    @classmethod
    def coerce(
        cls,
        spec: Expression | Mapping[str, str] | str | None,
    ) -> Expression | None:
        r"""Converts an expression specification into an :class:`Expression`.

        Args:
            spec: ``None`` (no expression), an existing :class:`Expression`
                (both returned unchanged), a string (wrapped into a
                :class:`LocalExpression`), or a mapping of keyword arguments
                for one of the concrete expression classes.

        Raises:
            TypeError: If no conversion applies.
        """
        # Already in final form -> hand back untouched:
        if spec is None or isinstance(spec, Expression):
            return spec

        if isinstance(spec, str):
            return LocalExpression(spec)

        if isinstance(spec, Mapping):
            # Try each concrete expression type; the first whose constructor
            # accepts the given keyword arguments wins.
            for expr_cls in (LocalExpression, ):
                try:
                    expr = expr_cls(**spec)
                except TypeError:
                    continue
                return expr

        raise TypeError(f"Unable to coerce 'Expression' from '{spec}'")


@dataclass(frozen=True, repr=False)
class LocalExpression(Expression):
    r"""A local expression to define a row-level logical attribute based on
    physical columns of the data source in the same row.

    Args:
        value: The value of the expression.
    """
    value: str

    def __repr__(self) -> str:
        # The raw expression text serves as its own representation.
        return self.value
@@ -0,0 +1,69 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+
4
+
5
class Mapper:
    r"""A mapper to map ``(pkey, batch)`` pairs to contiguous node IDs.

    Each pair is encoded into a single key ``num_examples * index + batch``.
    Here ``index`` is the primary key itself (non-string keys) or a stable
    categorical code (string keys). Node IDs are the positions of these keys
    in order of first registration.

    Args:
        num_examples: The maximum number of examples to add/retrieve.
    """
    def __init__(self, num_examples: int):
        # Categories of string primary keys seen so far (``None`` while keys
        # are non-string, in which case they are used as-is):
        self._pkey_dtype: pd.CategoricalDtype | None = None
        # Encoded keys of every ``add`` call, in registration order:
        self._indices: list[np.ndarray] = []
        # Lazily built lookup from encoded key to node ID:
        self._index_dtype: pd.CategoricalDtype | None = None
        self._num_examples = num_examples

    def _encode(self, pkey: pd.Series, batch: np.ndarray) -> np.ndarray:
        r"""Encodes ``(pkey, batch)`` pairs into single integer keys."""
        if self._pkey_dtype is not None:
            # Keys not in the categories get code ``-1``, which yields a
            # negative encoded key (assuming ``0 <= batch < num_examples``):
            index = pd.Categorical(pkey, dtype=self._pkey_dtype).codes
            index = index.astype('int64')
        else:
            index = pkey.to_numpy()
            if index.dtype.kind in 'iu' and index.dtype.itemsize < 8:
                # Upcast narrow integer keys (e.g., ``int32``) so that the
                # multiplication below cannot silently wrap around and make
                # distinct pairs collide:
                index = index.astype('int64')
        return self._num_examples * index + batch

    def add(self, pkey: pd.Series, batch: np.ndarray) -> None:
        r"""Adds a set of ``(pkey, batch)`` pairs to the mapper.

        Args:
            pkey: The primary keys.
            batch: The batch vector.
        """
        if self._pkey_dtype is not None:
            # New keys are appended *after* the existing categories, so that
            # previously assigned codes (and thus stored keys) stay valid:
            category = np.concatenate([
                self._pkey_dtype.categories.values,
                pkey,
            ], axis=0)
            self._pkey_dtype = pd.CategoricalDtype(pd.unique(category))
        elif pd.api.types.is_string_dtype(pkey):
            self._pkey_dtype = pd.CategoricalDtype(pd.unique(pkey))

        self._indices.append(self._encode(pkey, batch))
        self._index_dtype = None  # Invalidate the lazily built lookup.

    def get(self, pkey: pd.Series, batch: np.ndarray) -> np.ndarray:
        r"""Retrieves the node IDs for a set of ``(pkey, batch)`` pairs.

        Returns ``-1`` for any pair not registered in the mapping.

        Args:
            pkey: The primary keys.
            batch: The batch vector.
        """
        if len(self._indices) == 0:
            return np.full(len(pkey), -1, dtype=np.int64)

        if self._index_dtype is None:  # Lazy build index:
            category = pd.unique(np.concatenate(self._indices))
            self._index_dtype = pd.CategoricalDtype(category)

        index = self._encode(pkey, batch)
        out = pd.Categorical(index, dtype=self._index_dtype).codes
        return out.astype('int64')
@@ -13,7 +13,6 @@ from kumoapi.pquery.AST import Aggregation, ASTNode
13
13
  from kumoapi.rfm.context import EdgeLayout, Link, Subgraph, Table
14
14
  from kumoapi.typing import Stype
15
15
 
16
- from kumoai.experimental.rfm.base import SourceColumn
17
16
  from kumoai.utils import ProgressLogger
18
17
 
19
18
  if TYPE_CHECKING:
@@ -53,12 +52,24 @@ class Sampler(ABC):
53
52
  graph: 'Graph',
54
53
  verbose: bool | ProgressLogger = True,
55
54
  ) -> None:
55
+
56
56
  self._edge_types: list[tuple[str, str, str]] = []
57
57
  for edge in graph.edges:
58
58
  edge_type = (edge.src_table, edge.fkey, edge.dst_table)
59
59
  self._edge_types.append(edge_type)
60
60
  self._edge_types.append(Subgraph.rev_edge_type(edge_type))
61
61
 
62
+ # Source Table -> [(Foreign Key, Destination Table)]
63
+ self._foreign_key_dict: dict[str, list[tuple[str, str]]] = {}
64
+ # Destination Table -> [(Source Table, Foreign Key)]
65
+ self._rev_foreign_key_dict: dict[str, list[tuple[str, str]]] = {}
66
+ for table in graph.tables.values():
67
+ self._foreign_key_dict[table.name] = []
68
+ self._rev_foreign_key_dict[table.name] = []
69
+ for src_table, fkey, dst_table in graph.edges:
70
+ self._foreign_key_dict[src_table].append((fkey, dst_table))
71
+ self._rev_foreign_key_dict[dst_table].append((src_table, fkey))
72
+
62
73
  self._primary_key_dict: dict[str, str] = {
63
74
  table.name: table._primary_key
64
75
  for table in graph.tables.values()
@@ -88,10 +99,6 @@ class Sampler(ABC):
88
99
  continue
89
100
  self._table_stype_dict[table.name][column.name] = column.stype
90
101
 
91
- self._source_table_dict: dict[str, dict[str, SourceColumn]] = {}
92
- for table in graph.tables.values():
93
- self._source_table_dict[table.name] = table._source_column_dict
94
-
95
102
  self._min_time_dict: dict[str, pd.Timestamp] = {}
96
103
  self._max_time_dict: dict[str, pd.Timestamp] = {}
97
104
 
@@ -102,6 +109,16 @@ class Sampler(ABC):
102
109
  r"""All available edge types in the graph."""
103
110
  return self._edge_types
104
111
 
112
+ @property
113
+ def foreign_key_dict(self) -> dict[str, list[tuple[str, str]]]:
114
+ r"""The foreign keys for all tables in the graph."""
115
+ return self._foreign_key_dict
116
+
117
+ @property
118
+ def rev_foreign_key_dict(self) -> dict[str, list[tuple[str, str]]]:
119
+ r"""The foreign key back references for all tables in the graph."""
120
+ return self._rev_foreign_key_dict
121
+
105
122
  @property
106
123
  def primary_key_dict(self) -> dict[str, str]:
107
124
  r"""All available primary keys in the graph."""
@@ -119,16 +136,11 @@ class Sampler(ABC):
119
136
 
120
137
  @property
121
138
  def table_stype_dict(self) -> dict[str, dict[str, Stype]]:
122
- r"""The registered semantic types for all columns in all tables in
123
- the graph.
139
+ r"""The registered semantic types for all feature columns in all tables
140
+ in the graph.
124
141
  """
125
142
  return self._table_stype_dict
126
143
 
127
- @property
128
- def source_table_dict(self) -> dict[str, dict[str, SourceColumn]]:
129
- r"""Source column information for all tables in the graph."""
130
- return self._source_table_dict
131
-
132
144
  def get_min_time(
133
145
  self,
134
146
  table_names: list[str] | None = None,
@@ -189,7 +201,7 @@ class Sampler(ABC):
189
201
  exclude_cols_dict: The columns to exclude from the subgraph.
190
202
  """
191
203
  # Exclude all columns that leak target information:
192
- table_stype_dict: dict[str, dict[str, Stype]] = self._table_stype_dict
204
+ table_stype_dict: dict[str, dict[str, Stype]] = self.table_stype_dict
193
205
  if exclude_cols_dict is not None:
194
206
  table_stype_dict = copy.deepcopy(table_stype_dict)
195
207
  for table_name, exclude_cols in exclude_cols_dict.items():
@@ -237,11 +249,8 @@ class Sampler(ABC):
237
249
  # Set end time to NaT for all values greater than anchor time:
238
250
  assert table_name not in out.inverse_dict
239
251
  ser = df[end_time_column]
240
- if ser.dtype != 'datetime64[ns]':
241
- ser = ser.astype('datetime64[ns]')
242
252
  mask = ser.astype(int).to_numpy() > out.anchor_time[batch]
243
- ser.iloc[mask] = pd.NaT
244
- df[end_time_column] = ser
253
+ df.loc[mask, end_time_column] = pd.NaT
245
254
 
246
255
  stype_dict = table_stype_dict[table_name]
247
256
  for column_name, stype in stype_dict.items():
@@ -286,7 +295,8 @@ class Sampler(ABC):
286
295
 
287
296
  # Store in compressed representation if more efficient:
288
297
  num_cols = subgraph.table_dict[edge_type[2]].num_rows
289
- if col is not None and len(col) > num_cols + 1:
298
+ if (col is not None and len(col) > num_cols + 1
299
+ and ((col[1:] - col[:-1]) >= 0).all()):
290
300
  layout = EdgeLayout.CSC
291
301
  colcount = np.bincount(col, minlength=num_cols)
292
302
  col = np.empty(num_cols + 1, dtype=col.dtype)
@@ -6,7 +6,7 @@ from kumoapi.typing import Dtype
6
6
  @dataclass
7
7
  class SourceColumn:
8
8
  name: str
9
- dtype: Dtype
9
+ dtype: Dtype | None
10
10
  is_primary_key: bool
11
11
  is_unique_key: bool
12
12
  is_nullable: bool