kumoai 2.13.0.dev202511231731__cp312-cp312-macosx_11_0_arm64.whl → 2.13.0.dev202512011731__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,5 @@
1
- from dataclasses import dataclass
2
- from typing import Any, Dict, List, Optional
1
+ from abc import ABC, abstractmethod
2
+ from typing import Dict, List, Optional, Sequence
3
3
 
4
4
  import pandas as pd
5
5
  from kumoapi.source_table import UnavailableSourceTable
@@ -9,107 +9,17 @@ from kumoapi.typing import Dtype, Stype
9
9
  from typing_extensions import Self
10
10
 
11
11
  from kumoai import in_notebook
12
- from kumoai.experimental.rfm import utils
12
+ from kumoai.experimental.rfm.base import Column
13
13
 
14
14
 
15
- @dataclass(init=False, repr=False, eq=False)
16
- class Column:
17
- stype: Stype
18
-
19
- def __init__(
20
- self,
21
- name: str,
22
- dtype: Dtype,
23
- stype: Stype,
24
- is_primary_key: bool = False,
25
- is_time_column: bool = False,
26
- is_end_time_column: bool = False,
27
- ) -> None:
28
- self._name = name
29
- self._dtype = Dtype(dtype)
30
- self._is_primary_key = is_primary_key
31
- self._is_time_column = is_time_column
32
- self._is_end_time_column = is_end_time_column
33
- self.stype = Stype(stype)
34
-
35
- @property
36
- def name(self) -> str:
37
- return self._name
38
-
39
- @property
40
- def dtype(self) -> Dtype:
41
- return self._dtype
42
-
43
- def __setattr__(self, key: str, val: Any) -> None:
44
- if key == 'stype':
45
- if isinstance(val, str):
46
- val = Stype(val)
47
- assert isinstance(val, Stype)
48
- if not val.supports_dtype(self.dtype):
49
- raise ValueError(f"Column '{self.name}' received an "
50
- f"incompatible semantic type (got "
51
- f"dtype='{self.dtype}' and stype='{val}')")
52
- if self._is_primary_key and val != Stype.ID:
53
- raise ValueError(f"Primary key '{self.name}' must have 'ID' "
54
- f"semantic type (got '{val}')")
55
- if self._is_time_column and val != Stype.timestamp:
56
- raise ValueError(f"Time column '{self.name}' must have "
57
- f"'timestamp' semantic type (got '{val}')")
58
- if self._is_end_time_column and val != Stype.timestamp:
59
- raise ValueError(f"End time column '{self.name}' must have "
60
- f"'timestamp' semantic type (got '{val}')")
61
-
62
- super().__setattr__(key, val)
63
-
64
- def __hash__(self) -> int:
65
- return hash((self.name, self.stype, self.dtype))
66
-
67
- def __eq__(self, other: Any) -> bool:
68
- if not isinstance(other, Column):
69
- return False
70
- return hash(self) == hash(other)
71
-
72
- def __repr__(self) -> str:
73
- return (f'{self.__class__.__name__}(name={self.name}, '
74
- f'stype={self.stype}, dtype={self.dtype})')
75
-
76
-
77
- class LocalTable:
78
- r"""A table backed by a :class:`pandas.DataFrame`.
79
-
80
- A :class:`LocalTable` fully specifies the relevant metadata, *i.e.*
81
- selected columns, column semantic types, primary keys and time columns.
82
- :class:`LocalTable` is used to create a :class:`LocalGraph`.
83
-
84
- .. code-block:: python
85
-
86
- import pandas as pd
87
- import kumoai.experimental.rfm as rfm
88
-
89
- # Load data from a CSV file:
90
- df = pd.read_csv("data.csv")
91
-
92
- # Create a table from a `pandas.DataFrame` and infer its metadata ...
93
- table = rfm.LocalTable(df, name="my_table").infer_metadata()
94
-
95
- # ... or create a table explicitly:
96
- table = rfm.LocalTable(
97
- df=df,
98
- name="my_table",
99
- primary_key="id",
100
- time_column="time",
101
- end_time_column=None,
102
- )
103
-
104
- # Verify metadata:
105
- table.print_metadata()
106
-
107
- # Change the semantic type of a column:
108
- table[column].stype = "text"
15
+ class Table(ABC):
16
+ r"""A :class:`Table` fully specifies the relevant metadata of a single
17
+ table, *i.e.* its selected columns, data types, semantic types, primary
18
+ keys and time columns.
109
19
 
110
20
  Args:
111
- df: The data frame to create the table from.
112
- name: The name of the table.
21
+ name: The name of this table.
22
+ columns: The selected columns of this table.
113
23
  primary_key: The name of the primary key of this table, if it exists.
114
24
  time_column: The name of the time column of this table, if it exists.
115
25
  end_time_column: The name of the end time column of this table, if it
@@ -117,46 +27,40 @@ class LocalTable:
117
27
  """
118
28
  def __init__(
119
29
  self,
120
- df: pd.DataFrame,
121
30
  name: str,
31
+ columns: Optional[Sequence[str]] = None,
122
32
  primary_key: Optional[str] = None,
123
33
  time_column: Optional[str] = None,
124
34
  end_time_column: Optional[str] = None,
125
35
  ) -> None:
126
36
 
127
- if df.empty:
128
- raise ValueError("Data frame must have at least one row")
129
- if isinstance(df.columns, pd.MultiIndex):
130
- raise ValueError("Data frame must not have a multi-index")
131
- if not df.columns.is_unique:
132
- raise ValueError("Data frame must have unique column names")
133
- if any(col == '' for col in df.columns):
134
- raise ValueError("Data frame must have non-empty column names")
135
-
136
- df = df.copy(deep=False)
137
-
138
- self._data = df
139
37
  self._name = name
140
38
  self._primary_key: Optional[str] = None
141
39
  self._time_column: Optional[str] = None
142
40
  self._end_time_column: Optional[str] = None
143
41
 
144
42
  self._columns: Dict[str, Column] = {}
145
- for column_name in df.columns:
43
+ for column_name in columns or []:
146
44
  self.add_column(column_name)
147
45
 
148
46
  if primary_key is not None:
47
+ if primary_key not in self:
48
+ self.add_column(primary_key)
149
49
  self.primary_key = primary_key
150
50
 
151
51
  if time_column is not None:
52
+ if time_column not in self:
53
+ self.add_column(time_column)
152
54
  self.time_column = time_column
153
55
 
154
56
  if end_time_column is not None:
57
+ if end_time_column not in self:
58
+ self.add_column(end_time_column)
155
59
  self.end_time_column = end_time_column
156
60
 
157
61
  @property
158
62
  def name(self) -> str:
159
- r"""The name of the table."""
63
+ r"""The name of this table."""
160
64
  return self._name
161
65
 
162
66
  # Data column #############################################################
@@ -200,24 +104,25 @@ class LocalTable:
200
104
  raise KeyError(f"Column '{name}' already exists in table "
201
105
  f"'{self.name}'")
202
106
 
203
- if name not in self._data.columns:
204
- raise KeyError(f"Column '{name}' does not exist in the underyling "
205
- f"data frame")
107
+ if not self._has_source_column(name):
108
+ raise KeyError(f"Column '{name}' does not exist in the underlying "
109
+ f"source table")
206
110
 
207
111
  try:
208
- dtype = utils.to_dtype(self._data[name])
112
+ dtype = self._get_source_dtype(name)
209
113
  except Exception as e:
210
- raise RuntimeError(f"Data type inference for column '{name}' in "
211
- f"table '{self.name}' failed. Consider "
212
- f"changing the data type of the column or "
213
- f"removing it from the table.") from e
114
+ raise RuntimeError(f"Could not obtain data type for column "
115
+ f"'{name}' in table '{self.name}'. Change "
116
+ f"the data type of the column in the source "
117
+ f"table or remove it from the table.") from e
118
+
214
119
  try:
215
- stype = utils.infer_stype(self._data[name], name, dtype)
120
+ stype = self._get_source_stype(name, dtype)
216
121
  except Exception as e:
217
- raise RuntimeError(f"Semantic type inference for column '{name}' "
218
- f"in table '{self.name}' failed. Consider "
219
- f"changing the data type of the column or "
220
- f"removing it from the table.") from e
122
+ raise RuntimeError(f"Could not obtain semantic type for column "
123
+ f"'{name}' in table '{self.name}'. Change "
124
+ f"the data type of the column in the source "
125
+ f"table or remove it from the table.") from e
221
126
 
222
127
  self._columns[name] = Column(
223
128
  name=name,
@@ -432,12 +337,14 @@ class LocalTable:
432
337
  })
433
338
 
434
339
  def print_metadata(self) -> None:
435
- r"""Prints the :meth:`~LocalTable.metadata` of the table."""
340
+ r"""Prints the :meth:`~metadata` of this table."""
341
+ num_rows = self._num_rows()
342
+ num_rows_repr = ' ({num_rows:,} rows)' if num_rows is not None else ''
343
+
436
344
  if in_notebook():
437
345
  from IPython.display import Markdown, display
438
- display(
439
- Markdown(f"### 🏷️ Metadata of Table `{self.name}` "
440
- f"({len(self._data):,} rows)"))
346
+ md_repr = f"### 🏷️ Metadata of Table `{self.name}`{num_rows_repr}"
347
+ display(Markdown(md_repr))
441
348
  df = self.metadata
442
349
  try:
443
350
  if hasattr(df.style, 'hide'):
@@ -447,8 +354,7 @@ class LocalTable:
447
354
  except ImportError:
448
355
  print(df.to_string(index=False)) # missing jinja2
449
356
  else:
450
- print(f"🏷️ Metadata of Table '{self.name}' "
451
- f"({len(self._data):,} rows):")
357
+ print(f"🏷️ Metadata of Table '{self.name}'{num_rows_repr}")
452
358
  print(self.metadata.to_string(index=False))
453
359
 
454
360
  def infer_metadata(self, verbose: bool = True) -> Self:
@@ -478,11 +384,7 @@ class LocalTable:
478
384
  column.name for column in self.columns if is_candidate(column)
479
385
  ]
480
386
 
481
- if primary_key := utils.detect_primary_key(
482
- table_name=self.name,
483
- df=self._data,
484
- candidates=candidates,
485
- ):
387
+ if primary_key := self._infer_primary_key(candidates):
486
388
  self.primary_key = primary_key
487
389
  logs.append(f"primary key '{primary_key}'")
488
390
 
@@ -493,7 +395,7 @@ class LocalTable:
493
395
  if column.stype == Stype.timestamp
494
396
  and column.name != self._end_time_column
495
397
  ]
496
- if time_column := utils.detect_time_column(self._data, candidates):
398
+ if time_column := self._infer_time_column(candidates):
497
399
  self.time_column = time_column
498
400
  logs.append(f"time column '{time_column}'")
499
401
 
@@ -543,3 +445,29 @@ class LocalTable:
543
445
  f' time_column={self._time_column},\n'
544
446
  f' end_time_column={self._end_time_column},\n'
545
447
  f')')
448
+
449
+ # Abstract method #########################################################
450
+
451
+ @abstractmethod
452
+ def _has_source_column(self, name: str) -> bool:
453
+ pass
454
+
455
+ @abstractmethod
456
+ def _get_source_dtype(self, name: str) -> Dtype:
457
+ pass
458
+
459
+ @abstractmethod
460
+ def _get_source_stype(self, name: str, dtype: Dtype) -> Stype:
461
+ pass
462
+
463
+ @abstractmethod
464
+ def _infer_primary_key(self, candidates: List[str]) -> Optional[str]:
465
+ pass
466
+
467
+ @abstractmethod
468
+ def _infer_time_column(self, candidates: List[str]) -> Optional[str]:
469
+ pass
470
+
471
+ @abstractmethod
472
+ def _num_rows(self) -> Optional[int]:
473
+ pass
@@ -3,7 +3,7 @@ import io
3
3
  import warnings
4
4
  from collections import defaultdict
5
5
  from importlib.util import find_spec
6
- from typing import TYPE_CHECKING, Dict, List, Optional, Union
6
+ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union
7
7
 
8
8
  import pandas as pd
9
9
  from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
@@ -12,19 +12,19 @@ from kumoapi.typing import Stype
12
12
  from typing_extensions import Self
13
13
 
14
14
  from kumoai import in_notebook
15
- from kumoai.experimental.rfm import LocalTable
15
+ from kumoai.experimental.rfm import Table
16
16
  from kumoai.graph import Edge
17
17
 
18
18
  if TYPE_CHECKING:
19
19
  import graphviz
20
20
 
21
21
 
22
- class LocalGraph:
23
- r"""A graph of :class:`LocalTable` objects, akin to relationships between
22
+ class Graph:
23
+ r"""A graph of :class:`Table` objects, akin to relationships between
24
24
  tables in a relational database.
25
25
 
26
26
  Creating a graph is the final step of data definition; after a
27
- :class:`LocalGraph` is created, you can use it to initialize the
27
+ :class:`Graph` is created, you can use it to initialize the
28
28
  Kumo Relational Foundation Model (:class:`KumoRFM`).
29
29
 
30
30
  .. code-block:: python
@@ -44,7 +44,7 @@ class LocalGraph:
44
44
  >>> table3 = rfm.LocalTable(name="table3", data=df3)
45
45
 
46
46
  >>> # Create a graph from a dictionary of tables:
47
- >>> graph = rfm.LocalGraph({
47
+ >>> graph = rfm.Graph({
48
48
  ... "table1": table1,
49
49
  ... "table2": table2,
50
50
  ... "table3": table3,
@@ -75,11 +75,11 @@ class LocalGraph:
75
75
 
76
76
  def __init__(
77
77
  self,
78
- tables: List[LocalTable],
79
- edges: Optional[List[Edge]] = None,
78
+ tables: Sequence[Table],
79
+ edges: Optional[Sequence[Edge]] = None,
80
80
  ) -> None:
81
81
 
82
- self._tables: Dict[str, LocalTable] = {}
82
+ self._tables: Dict[str, Table] = {}
83
83
  self._edges: List[Edge] = []
84
84
 
85
85
  for table in tables:
@@ -94,11 +94,11 @@ class LocalGraph:
94
94
  def from_data(
95
95
  cls,
96
96
  df_dict: Dict[str, pd.DataFrame],
97
- edges: Optional[List[Edge]] = None,
97
+ edges: Optional[Sequence[Edge]] = None,
98
98
  infer_metadata: bool = True,
99
99
  verbose: bool = True,
100
100
  ) -> Self:
101
- r"""Creates a :class:`LocalGraph` from a dictionary of
101
+ r"""Creates a :class:`Graph` from a dictionary of
102
102
  :class:`pandas.DataFrame` objects.
103
103
 
104
104
  Automatically infers table metadata and links.
@@ -115,7 +115,7 @@ class LocalGraph:
115
115
  >>> df3 = pd.DataFrame(...)
116
116
 
117
117
  >>> # Create a graph from a dictionary of data frames:
118
- >>> graph = rfm.LocalGraph.from_data({
118
+ >>> graph = rfm.Graph.from_data({
119
119
  ... "table1": df1,
120
120
  ... "table2": df2,
121
121
  ... "table3": df3,
@@ -148,13 +148,14 @@ class LocalGraph:
148
148
  >>> df1 = pd.DataFrame(...)
149
149
  >>> df2 = pd.DataFrame(...)
150
150
  >>> df3 = pd.DataFrame(...)
151
- >>> graph = rfm.LocalGraph.from_data(data={
151
+ >>> graph = rfm.Graph.from_data(data={
152
152
  ... "table1": df1,
153
153
  ... "table2": df2,
154
154
  ... "table3": df3,
155
155
  ... })
156
156
  >>> graph.validate()
157
157
  """
158
+ from kumoai.experimental.rfm import LocalTable
158
159
  tables = [LocalTable(df, name) for name, df in df_dict.items()]
159
160
 
160
161
  graph = cls(tables, edges=edges or [])
@@ -175,7 +176,7 @@ class LocalGraph:
175
176
  """
176
177
  return name in self.tables
177
178
 
178
- def table(self, name: str) -> LocalTable:
179
+ def table(self, name: str) -> Table:
179
180
  r"""Returns the table with name ``name`` in the graph.
180
181
 
181
182
  Raises:
@@ -186,11 +187,11 @@ class LocalGraph:
186
187
  return self.tables[name]
187
188
 
188
189
  @property
189
- def tables(self) -> Dict[str, LocalTable]:
190
+ def tables(self) -> Dict[str, Table]:
190
191
  r"""Returns the dictionary of table objects."""
191
192
  return self._tables
192
193
 
193
- def add_table(self, table: LocalTable) -> Self:
194
+ def add_table(self, table: Table) -> Self:
194
195
  r"""Adds a table to the graph.
195
196
 
196
197
  Args:
@@ -199,11 +200,21 @@ class LocalGraph:
199
200
  Raises:
200
201
  KeyError: If a table with the same name already exists in the
201
202
  graph.
203
+ ValueError: If the table belongs to a different backend than the
204
+ rest of the tables in the graph.
202
205
  """
203
206
  if table.name in self._tables:
204
207
  raise KeyError(f"Cannot add table with name '{table.name}' to "
205
208
  f"this graph; table names must be globally unique.")
206
209
 
210
+ if len(self._tables) > 0:
211
+ cls = next(iter(self._tables.values())).__class__
212
+ if table.__class__ != cls:
213
+ raise ValueError(f"Cannot register a "
214
+ f"'{table.__class__.__name__}' to this "
215
+ f"graph since other tables are of type "
216
+ f"'{cls.__name__}'.")
217
+
207
218
  self._tables[table.name] = table
208
219
 
209
220
  return self
@@ -241,7 +252,7 @@ class LocalGraph:
241
252
  Example:
242
253
  >>> # doctest: +SKIP
243
254
  >>> import kumoai.experimental.rfm as rfm
244
- >>> graph = rfm.LocalGraph(tables=...).infer_metadata()
255
+ >>> graph = rfm.Graph(tables=...).infer_metadata()
245
256
  >>> graph.metadata # doctest: +SKIP
246
257
  name primary_key time_column end_time_column
247
258
  0 users user_id - -
@@ -263,7 +274,7 @@ class LocalGraph:
263
274
  })
264
275
 
265
276
  def print_metadata(self) -> None:
266
- r"""Prints the :meth:`~LocalGraph.metadata` of the graph."""
277
+ r"""Prints the :meth:`~Graph.metadata` of the graph."""
267
278
  if in_notebook():
268
279
  from IPython.display import Markdown, display
269
280
  display(Markdown('### 🗂️ Graph Metadata'))
@@ -287,7 +298,7 @@ class LocalGraph:
287
298
 
288
299
  Note:
289
300
  For more information, please see
290
- :meth:`kumoai.experimental.rfm.LocalTable.infer_metadata`.
301
+ :meth:`kumoai.experimental.rfm.Table.infer_metadata`.
291
302
  """
292
303
  for table in self.tables.values():
293
304
  table.infer_metadata(verbose=False)
@@ -305,7 +316,7 @@ class LocalGraph:
305
316
  return self._edges
306
317
 
307
318
  def print_links(self) -> None:
308
- r"""Prints the :meth:`~LocalGraph.edges` of the graph."""
319
+ r"""Prints the :meth:`~Graph.edges` of the graph."""
309
320
  edges = [(edge.dst_table, self[edge.dst_table]._primary_key,
310
321
  edge.src_table, edge.fkey) for edge in self.edges]
311
322
  edges = sorted(edges)
@@ -333,9 +344,9 @@ class LocalGraph:
333
344
 
334
345
  def link(
335
346
  self,
336
- src_table: Union[str, LocalTable],
347
+ src_table: Union[str, Table],
337
348
  fkey: str,
338
- dst_table: Union[str, LocalTable],
349
+ dst_table: Union[str, Table],
339
350
  ) -> Self:
340
351
  r"""Links two tables (``src_table`` and ``dst_table``) from the foreign
341
352
  key ``fkey`` in the source table to the primary key in the destination
@@ -358,11 +369,11 @@ class LocalGraph:
358
369
  table does not exist in the graph, if the source key does not
359
370
  exist in the source table.
360
371
  """
361
- if isinstance(src_table, LocalTable):
372
+ if isinstance(src_table, Table):
362
373
  src_table = src_table.name
363
374
  assert isinstance(src_table, str)
364
375
 
365
- if isinstance(dst_table, LocalTable):
376
+ if isinstance(dst_table, Table):
366
377
  dst_table = dst_table.name
367
378
  assert isinstance(dst_table, str)
368
379
 
@@ -396,9 +407,9 @@ class LocalGraph:
396
407
 
397
408
  def unlink(
398
409
  self,
399
- src_table: Union[str, LocalTable],
410
+ src_table: Union[str, Table],
400
411
  fkey: str,
401
- dst_table: Union[str, LocalTable],
412
+ dst_table: Union[str, Table],
402
413
  ) -> Self:
403
414
  r"""Removes an :class:`~kumoai.graph.Edge` from the graph.
404
415
 
@@ -410,11 +421,11 @@ class LocalGraph:
410
421
  Raises:
411
422
  ValueError: if the edge is not present in the graph.
412
423
  """
413
- if isinstance(src_table, LocalTable):
424
+ if isinstance(src_table, Table):
414
425
  src_table = src_table.name
415
426
  assert isinstance(src_table, str)
416
427
 
417
- if isinstance(dst_table, LocalTable):
428
+ if isinstance(dst_table, Table):
418
429
  dst_table = dst_table.name
419
430
  assert isinstance(dst_table, str)
420
431
 
@@ -528,7 +539,10 @@ class LocalGraph:
528
539
  score += 1.0
529
540
 
530
541
  # Cardinality ratio:
531
- if len(src_table._data) > len(dst_table._data):
542
+ src_num_rows = src_table._num_rows()
543
+ dst_num_rows = dst_table._num_rows()
544
+ if (src_num_rows is not None and dst_num_rows is not None
545
+ and src_num_rows > dst_num_rows):
532
546
  score += 1.0
533
547
 
534
548
  if score < 5.0:
@@ -790,7 +804,7 @@ class LocalGraph:
790
804
  def __contains__(self, name: str) -> bool:
791
805
  return self.has_table(name)
792
806
 
793
- def __getitem__(self, name: str) -> LocalTable:
807
+ def __getitem__(self, name: str) -> Table:
794
808
  return self.table(name)
795
809
 
796
810
  def __delitem__(self, name: str) -> None:
@@ -6,7 +6,7 @@ import pandas as pd
6
6
  from kumoapi.rfm.context import Subgraph
7
7
  from kumoapi.typing import Stype
8
8
 
9
- from kumoai.experimental.rfm import LocalGraph
9
+ from kumoai.experimental.rfm import Graph, LocalTable
10
10
  from kumoai.experimental.rfm.utils import normalize_text
11
11
  from kumoai.utils import InteractiveProgressLogger, ProgressLogger
12
12
 
@@ -20,7 +20,7 @@ except ImportError:
20
20
  class LocalGraphStore:
21
21
  def __init__(
22
22
  self,
23
- graph: LocalGraph,
23
+ graph: Graph,
24
24
  preprocess: bool = False,
25
25
  verbose: Union[bool, ProgressLogger] = True,
26
26
  ) -> None:
@@ -105,7 +105,7 @@ class LocalGraphStore:
105
105
 
106
106
  def sanitize(
107
107
  self,
108
- graph: LocalGraph,
108
+ graph: Graph,
109
109
  preprocess: bool = False,
110
110
  ) -> Tuple[Dict[str, pd.DataFrame], Dict[str, np.ndarray]]:
111
111
  r"""Sanitizes raw data according to table schema definition:
@@ -120,10 +120,11 @@ class LocalGraphStore:
120
120
  data for faster model processing. In particular, it:
121
121
  * tokenizes any text column that is not a foreign key
122
122
  """
123
- df_dict: Dict[str, pd.DataFrame] = {
124
- table_name: table._data.copy(deep=False).reset_index(drop=True)
125
- for table_name, table in graph.tables.items()
126
- }
123
+ df_dict: Dict[str, pd.DataFrame] = {}
124
+ for table_name, table in graph.tables.items():
125
+ assert isinstance(table, LocalTable)
126
+ df = table._data
127
+ df_dict[table_name] = df.copy(deep=False).reset_index(drop=True)
127
128
 
128
129
  foreign_keys = {(edge.src_table, edge.fkey) for edge in graph.edges}
129
130
 
@@ -165,7 +166,7 @@ class LocalGraphStore:
165
166
 
166
167
  return df_dict, mask_dict
167
168
 
168
- def get_stype_dict(self, graph: LocalGraph) -> Dict[str, Dict[str, Stype]]:
169
+ def get_stype_dict(self, graph: Graph) -> Dict[str, Dict[str, Stype]]:
169
170
  stype_dict: Dict[str, Dict[str, Stype]] = {}
170
171
  foreign_keys = {(edge.src_table, edge.fkey) for edge in graph.edges}
171
172
  for table in graph.tables.values():
@@ -180,7 +181,7 @@ class LocalGraphStore:
180
181
 
181
182
  def get_pkey_data(
182
183
  self,
183
- graph: LocalGraph,
184
+ graph: Graph,
184
185
  ) -> Tuple[
185
186
  Dict[str, str],
186
187
  Dict[str, pd.DataFrame],
@@ -218,7 +219,7 @@ class LocalGraphStore:
218
219
 
219
220
  def get_time_data(
220
221
  self,
221
- graph: LocalGraph,
222
+ graph: Graph,
222
223
  ) -> Tuple[
223
224
  Dict[str, str],
224
225
  Dict[str, str],
@@ -259,7 +260,7 @@ class LocalGraphStore:
259
260
 
260
261
  def get_csc(
261
262
  self,
262
- graph: LocalGraph,
263
+ graph: Graph,
263
264
  ) -> Tuple[
264
265
  Dict[Tuple[str, str, str], np.ndarray],
265
266
  Dict[Tuple[str, str, str], np.ndarray],
@@ -32,7 +32,7 @@ from kumoapi.task import TaskType
32
32
 
33
33
  from kumoai.client.rfm import RFMAPI
34
34
  from kumoai.exceptions import HTTPException
35
- from kumoai.experimental.rfm import LocalGraph
35
+ from kumoai.experimental.rfm import Graph
36
36
  from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
37
37
  from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
38
38
  from kumoai.experimental.rfm.local_pquery_driver import (
@@ -123,17 +123,17 @@ class KumoRFM:
123
123
  :class:`KumoRFM` is a foundation model to generate predictions for any
124
124
  relational dataset without training.
125
125
  The model is pre-trained and the class provides an interface to query the
126
- model from a :class:`LocalGraph` object.
126
+ model from a :class:`Graph` object.
127
127
 
128
128
  .. code-block:: python
129
129
 
130
- from kumoai.experimental.rfm import LocalGraph, KumoRFM
130
+ from kumoai.experimental.rfm import Graph, KumoRFM
131
131
 
132
132
  df_users = pd.DataFrame(...)
133
133
  df_items = pd.DataFrame(...)
134
134
  df_orders = pd.DataFrame(...)
135
135
 
136
- graph = LocalGraph.from_data({
136
+ graph = Graph.from_data({
137
137
  'users': df_users,
138
138
  'items': df_items,
139
139
  'orders': df_orders,
@@ -163,7 +163,7 @@ class KumoRFM:
163
163
  """
164
164
  def __init__(
165
165
  self,
166
- graph: LocalGraph,
166
+ graph: Graph,
167
167
  preprocess: bool = False,
168
168
  verbose: Union[bool, ProgressLogger] = True,
169
169
  ) -> None:
@@ -172,10 +172,19 @@ class KumoRFM:
172
172
  self._graph_store = LocalGraphStore(graph, preprocess, verbose)
173
173
  self._graph_sampler = LocalGraphSampler(self._graph_store)
174
174
 
175
+ self._client: Optional[RFMAPI] = None
176
+
175
177
  self._batch_size: Optional[int | Literal['max']] = None
176
178
  self.num_retries: int = 0
179
+
180
+ @property
181
+ def _api_client(self) -> RFMAPI:
182
+ if self._client is not None:
183
+ return self._client
184
+
177
185
  from kumoai.experimental.rfm import global_state
178
- self._api_client = RFMAPI(global_state.client)
186
+ self._client = RFMAPI(global_state.client)
187
+ return self._client
179
188
 
180
189
  def __repr__(self) -> str:
181
190
  return f'{self.__class__.__name__}()'