kumoai 2.13.0.dev202511261731__cp310-cp310-win_amd64.whl → 2.13.0.dev202512021731__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (26)
  1. kumoai/_version.py +1 -1
  2. kumoai/connector/utils.py +23 -2
  3. kumoai/experimental/rfm/__init__.py +20 -45
  4. kumoai/experimental/rfm/backend/__init__.py +0 -0
  5. kumoai/experimental/rfm/backend/local/__init__.py +38 -0
  6. kumoai/experimental/rfm/backend/local/table.py +244 -0
  7. kumoai/experimental/rfm/backend/snow/__init__.py +32 -0
  8. kumoai/experimental/rfm/backend/sqlite/__init__.py +30 -0
  9. kumoai/experimental/rfm/backend/sqlite/table.py +124 -0
  10. kumoai/experimental/rfm/base/__init__.py +7 -0
  11. kumoai/experimental/rfm/base/column.py +66 -0
  12. kumoai/experimental/rfm/{local_table.py → base/table.py} +71 -139
  13. kumoai/experimental/rfm/{local_graph.py → graph.py} +144 -57
  14. kumoai/experimental/rfm/infer/__init__.py +2 -0
  15. kumoai/experimental/rfm/infer/stype.py +35 -0
  16. kumoai/experimental/rfm/local_graph_store.py +12 -11
  17. kumoai/experimental/rfm/rfm.py +5 -5
  18. kumoai/experimental/rfm/sagemaker.py +11 -3
  19. kumoai/experimental/rfm/utils.py +1 -120
  20. kumoai/kumolib.cp310-win_amd64.pyd +0 -0
  21. kumoai/testing/decorators.py +1 -1
  22. {kumoai-2.13.0.dev202511261731.dist-info → kumoai-2.13.0.dev202512021731.dist-info}/METADATA +8 -8
  23. {kumoai-2.13.0.dev202511261731.dist-info → kumoai-2.13.0.dev202512021731.dist-info}/RECORD +26 -17
  24. {kumoai-2.13.0.dev202511261731.dist-info → kumoai-2.13.0.dev202512021731.dist-info}/WHEEL +0 -0
  25. {kumoai-2.13.0.dev202511261731.dist-info → kumoai-2.13.0.dev202512021731.dist-info}/licenses/LICENSE +0 -0
  26. {kumoai-2.13.0.dev202511261731.dist-info → kumoai-2.13.0.dev202512021731.dist-info}/top_level.txt +0 -0
@@ -3,7 +3,7 @@ import io
3
3
  import warnings
4
4
  from collections import defaultdict
5
5
  from importlib.util import find_spec
6
- from typing import TYPE_CHECKING, Dict, List, Optional, Union
6
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
7
7
 
8
8
  import pandas as pd
9
9
  from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
@@ -12,19 +12,19 @@ from kumoapi.typing import Stype
12
12
  from typing_extensions import Self
13
13
 
14
14
  from kumoai import in_notebook
15
- from kumoai.experimental.rfm import LocalTable
15
+ from kumoai.experimental.rfm import Table
16
16
  from kumoai.graph import Edge
17
17
 
18
18
  if TYPE_CHECKING:
19
19
  import graphviz
20
20
 
21
21
 
22
- class LocalGraph:
23
- r"""A graph of :class:`LocalTable` objects, akin to relationships between
22
+ class Graph:
23
+ r"""A graph of :class:`Table` objects, akin to relationships between
24
24
  tables in a relational database.
25
25
 
26
26
  Creating a graph is the final step of data definition; after a
27
- :class:`LocalGraph` is created, you can use it to initialize the
27
+ :class:`Graph` is created, you can use it to initialize the
28
28
  Kumo Relational Foundation Model (:class:`KumoRFM`).
29
29
 
30
30
  .. code-block:: python
@@ -44,7 +44,7 @@ class LocalGraph:
44
44
  >>> table3 = rfm.LocalTable(name="table3", data=df3)
45
45
 
46
46
  >>> # Create a graph from a dictionary of tables:
47
- >>> graph = rfm.LocalGraph({
47
+ >>> graph = rfm.Graph({
48
48
  ... "table1": table1,
49
49
  ... "table2": table2,
50
50
  ... "table3": table3,
@@ -75,33 +75,44 @@ class LocalGraph:
75
75
 
76
76
  def __init__(
77
77
  self,
78
- tables: List[LocalTable],
79
- edges: Optional[List[Edge]] = None,
78
+ tables: Sequence[Table],
79
+ edges: Optional[Sequence[Edge]] = None,
80
80
  ) -> None:
81
81
 
82
- self._tables: Dict[str, LocalTable] = {}
82
+ self._tables: Dict[str, Table] = {}
83
83
  self._edges: List[Edge] = []
84
84
 
85
85
  for table in tables:
86
86
  self.add_table(table)
87
87
 
88
+ for table in tables:
89
+ for fkey, dst_table, pkey in table._get_source_foreign_keys():
90
+ if self[dst_table].primary_key is None:
91
+ self[dst_table].primary_key = pkey
92
+ elif self[dst_table]._primary_key != pkey:
93
+ raise ValueError(f"Found duplicate primary key definition "
94
+ f"'{self[dst_table]._primary_key}' and "
95
+ f"'{pkey}' in table '{dst_table}'.")
96
+ self.link(table.name, fkey, dst_table)
97
+
88
98
  for edge in (edges or []):
89
99
  _edge = Edge._cast(edge)
90
100
  assert _edge is not None
91
- self.link(*_edge)
101
+ if _edge not in self._edges:
102
+ self.link(*_edge)
92
103
 
93
104
  @classmethod
94
105
  def from_data(
95
106
  cls,
96
107
  df_dict: Dict[str, pd.DataFrame],
97
- edges: Optional[List[Edge]] = None,
108
+ edges: Optional[Sequence[Edge]] = None,
98
109
  infer_metadata: bool = True,
99
110
  verbose: bool = True,
100
111
  ) -> Self:
101
- r"""Creates a :class:`LocalGraph` from a dictionary of
112
+ r"""Creates a :class:`Graph` from a dictionary of
102
113
  :class:`pandas.DataFrame` objects.
103
114
 
104
- Automatically infers table metadata and links.
115
+ Automatically infers table metadata and links by default.
105
116
 
106
117
  .. code-block:: python
107
118
 
@@ -115,7 +126,7 @@ class LocalGraph:
115
126
  >>> df3 = pd.DataFrame(...)
116
127
 
117
128
  >>> # Create a graph from a dictionary of data frames:
118
- >>> graph = rfm.LocalGraph.from_data({
129
+ >>> graph = rfm.Graph.from_data({
119
130
  ... "table1": df1,
120
131
  ... "table2": df2,
121
132
  ... "table3": df3,
@@ -131,39 +142,103 @@ class LocalGraph:
131
142
  Args:
132
143
  df_dict: A dictionary of data frames, where the keys are the names
133
144
  of the tables and the values hold table data.
134
- infer_metadata: Whether to infer metadata for all tables in the
135
- graph.
136
145
  edges: An optional list of :class:`~kumoai.graph.Edge` objects to
137
146
  add to the graph. If not provided, edges will be automatically
138
- inferred from the data.
147
+ inferred from the data in case ``infer_metadata=True``.
148
+ infer_metadata: Whether to infer metadata for all tables in the
149
+ graph.
139
150
  verbose: Whether to print verbose output.
151
+ """
152
+ from kumoai.experimental.rfm.backend.local import LocalTable
153
+ tables = [LocalTable(df, name) for name, df in df_dict.items()]
140
154
 
141
- Note:
142
- This method will automatically infer metadata and links for the
143
- graph.
155
+ graph = cls(tables, edges=edges or [])
156
+
157
+ if infer_metadata:
158
+ graph.infer_metadata(False)
159
+
160
+ if edges is None:
161
+ graph.infer_links(False)
162
+
163
+ if verbose:
164
+ graph.print_metadata()
165
+ graph.print_links()
166
+
167
+ return graph
168
+
169
+ @classmethod
170
+ def from_sqlite(
171
+ cls,
172
+ uri: Any,
173
+ table_names: Optional[Sequence[str]] = None,
174
+ edges: Optional[Sequence[Edge]] = None,
175
+ infer_metadata: bool = True,
176
+ verbose: bool = True,
177
+ conn_kwargs: Optional[Dict[str, Any]] = None,
178
+ ) -> Self:
179
+ r"""Creates a :class:`Graph` from a :class:`sqlite` database.
180
+
181
+ Automatically infers table metadata and links by default.
182
+
183
+ .. code-block:: python
144
184
 
145
- Example:
146
185
  >>> # doctest: +SKIP
147
186
  >>> import kumoai.experimental.rfm as rfm
148
- >>> df1 = pd.DataFrame(...)
149
- >>> df2 = pd.DataFrame(...)
150
- >>> df3 = pd.DataFrame(...)
151
- >>> graph = rfm.LocalGraph.from_data(data={
152
- ... "table1": df1,
153
- ... "table2": df2,
154
- ... "table3": df3,
155
- ... })
156
- >>> graph.validate()
187
+
188
+ >>> # Create a graph from a SQLite database:
189
+ >>> graph = rfm.Graph.from_sqlite('data.db')
190
+
191
+ >>> # Inspect table metadata:
192
+ >>> for table in graph.tables.values():
193
+ ... table.print_metadata()
194
+
195
+ >>> # Visualize graph (if graphviz is installed):
196
+ >>> graph.visualize()
197
+
198
+ Args:
199
+ uri: The path to the database file or an open connection obtained
200
+ from :meth:`~kumoai.experimental.rfm.backend.sqlite.connect`.
201
+ table_names: Set of table names to include. If ``None``, will add
202
+ all tables present in the database.
203
+ edges: An optional list of :class:`~kumoai.graph.Edge` objects to
204
+ add to the graph. If not provided, edges will be automatically
205
+ inferred from the data in case ``infer_metadata=True``.
206
+ infer_metadata: Whether to infer metadata for all tables in the
207
+ graph.
208
+ verbose: Whether to print verbose output.
209
+ conn_kwargs: Additional connection arguments, following the
210
+ :class:`adbc_driver_sqlite` protocol.
157
211
  """
158
- tables = [LocalTable(df, name) for name, df in df_dict.items()]
212
+ from kumoai.experimental.rfm.backend.sqlite import (
213
+ Connection,
214
+ SQLiteTable,
215
+ connect,
216
+ )
217
+
218
+ if not isinstance(uri, Connection):
219
+ connection = connect(uri, **(conn_kwargs or {}))
220
+ else:
221
+ connection = uri
222
+
223
+ if table_names is None:
224
+ with connection.cursor() as cursor:
225
+ cursor.execute("SELECT name FROM sqlite_master "
226
+ "WHERE type='table'")
227
+ table_names = [row[0] for row in cursor.fetchall()]
228
+
229
+ tables = [SQLiteTable(connection, name) for name in table_names]
159
230
 
160
231
  graph = cls(tables, edges=edges or [])
161
232
 
162
233
  if infer_metadata:
163
- graph.infer_metadata(verbose)
234
+ graph.infer_metadata(False)
164
235
 
165
236
  if edges is None:
166
- graph.infer_links(verbose)
237
+ graph.infer_links(False)
238
+
239
+ if verbose:
240
+ graph.print_metadata()
241
+ graph.print_links()
167
242
 
168
243
  return graph
169
244
 
@@ -175,7 +250,7 @@ class LocalGraph:
175
250
  """
176
251
  return name in self.tables
177
252
 
178
- def table(self, name: str) -> LocalTable:
253
+ def table(self, name: str) -> Table:
179
254
  r"""Returns the table with name ``name`` in the graph.
180
255
 
181
256
  Raises:
@@ -186,11 +261,11 @@ class LocalGraph:
186
261
  return self.tables[name]
187
262
 
188
263
  @property
189
- def tables(self) -> Dict[str, LocalTable]:
264
+ def tables(self) -> Dict[str, Table]:
190
265
  r"""Returns the dictionary of table objects."""
191
266
  return self._tables
192
267
 
193
- def add_table(self, table: LocalTable) -> Self:
268
+ def add_table(self, table: Table) -> Self:
194
269
  r"""Adds a table to the graph.
195
270
 
196
271
  Args:
@@ -199,11 +274,21 @@ class LocalGraph:
199
274
  Raises:
200
275
  KeyError: If a table with the same name already exists in the
201
276
  graph.
277
+ ValueError: If the table belongs to a different backend than the
278
+ rest of the tables in the graph.
202
279
  """
203
280
  if table.name in self._tables:
204
281
  raise KeyError(f"Cannot add table with name '{table.name}' to "
205
282
  f"this graph; table names must be globally unique.")
206
283
 
284
+ if len(self._tables) > 0:
285
+ cls = next(iter(self._tables.values())).__class__
286
+ if table.__class__ != cls:
287
+ raise ValueError(f"Cannot register a "
288
+ f"'{table.__class__.__name__}' to this "
289
+ f"graph since other tables are of type "
290
+ f"'{cls.__name__}'.")
291
+
207
292
  self._tables[table.name] = table
208
293
 
209
294
  return self
@@ -241,7 +326,7 @@ class LocalGraph:
241
326
  Example:
242
327
  >>> # doctest: +SKIP
243
328
  >>> import kumoai.experimental.rfm as rfm
244
- >>> graph = rfm.LocalGraph(tables=...).infer_metadata()
329
+ >>> graph = rfm.Graph(tables=...).infer_metadata()
245
330
  >>> graph.metadata # doctest: +SKIP
246
331
  name primary_key time_column end_time_column
247
332
  0 users user_id - -
@@ -263,7 +348,7 @@ class LocalGraph:
263
348
  })
264
349
 
265
350
  def print_metadata(self) -> None:
266
- r"""Prints the :meth:`~LocalGraph.metadata` of the graph."""
351
+ r"""Prints the :meth:`~Graph.metadata` of the graph."""
267
352
  if in_notebook():
268
353
  from IPython.display import Markdown, display
269
354
  display(Markdown('### 🗂️ Graph Metadata'))
@@ -287,7 +372,7 @@ class LocalGraph:
287
372
 
288
373
  Note:
289
374
  For more information, please see
290
- :meth:`kumoai.experimental.rfm.LocalTable.infer_metadata`.
375
+ :meth:`kumoai.experimental.rfm.Table.infer_metadata`.
291
376
  """
292
377
  for table in self.tables.values():
293
378
  table.infer_metadata(verbose=False)
@@ -305,7 +390,7 @@ class LocalGraph:
305
390
  return self._edges
306
391
 
307
392
  def print_links(self) -> None:
308
- r"""Prints the :meth:`~LocalGraph.edges` of the graph."""
393
+ r"""Prints the :meth:`~Graph.edges` of the graph."""
309
394
  edges = [(edge.dst_table, self[edge.dst_table]._primary_key,
310
395
  edge.src_table, edge.fkey) for edge in self.edges]
311
396
  edges = sorted(edges)
@@ -333,9 +418,9 @@ class LocalGraph:
333
418
 
334
419
  def link(
335
420
  self,
336
- src_table: Union[str, LocalTable],
421
+ src_table: Union[str, Table],
337
422
  fkey: str,
338
- dst_table: Union[str, LocalTable],
423
+ dst_table: Union[str, Table],
339
424
  ) -> Self:
340
425
  r"""Links two tables (``src_table`` and ``dst_table``) from the foreign
341
426
  key ``fkey`` in the source table to the primary key in the destination
@@ -358,11 +443,11 @@ class LocalGraph:
358
443
  table does not exist in the graph, if the source key does not
359
444
  exist in the source table.
360
445
  """
361
- if isinstance(src_table, LocalTable):
446
+ if isinstance(src_table, Table):
362
447
  src_table = src_table.name
363
448
  assert isinstance(src_table, str)
364
449
 
365
- if isinstance(dst_table, LocalTable):
450
+ if isinstance(dst_table, Table):
366
451
  dst_table = dst_table.name
367
452
  assert isinstance(dst_table, str)
368
453
 
@@ -396,9 +481,9 @@ class LocalGraph:
396
481
 
397
482
  def unlink(
398
483
  self,
399
- src_table: Union[str, LocalTable],
484
+ src_table: Union[str, Table],
400
485
  fkey: str,
401
- dst_table: Union[str, LocalTable],
486
+ dst_table: Union[str, Table],
402
487
  ) -> Self:
403
488
  r"""Removes an :class:`~kumoai.graph.Edge` from the graph.
404
489
 
@@ -410,11 +495,11 @@ class LocalGraph:
410
495
  Raises:
411
496
  ValueError: if the edge is not present in the graph.
412
497
  """
413
- if isinstance(src_table, LocalTable):
498
+ if isinstance(src_table, Table):
414
499
  src_table = src_table.name
415
500
  assert isinstance(src_table, str)
416
501
 
417
- if isinstance(dst_table, LocalTable):
502
+ if isinstance(dst_table, Table):
418
503
  dst_table = dst_table.name
419
504
  assert isinstance(dst_table, str)
420
505
 
@@ -428,17 +513,13 @@ class LocalGraph:
428
513
  return self
429
514
 
430
515
  def infer_links(self, verbose: bool = True) -> Self:
431
- r"""Infers links for the tables and adds them as edges to the graph.
516
+ r"""Infers missing links for the tables and adds them as edges to the
517
+ graph.
432
518
 
433
519
  Args:
434
520
  verbose: Whether to print verbose output.
435
-
436
- Note:
437
- This function expects graph edges to be undefined upfront.
438
521
  """
439
- if len(self.edges) > 0:
440
- warnings.warn("Cannot infer links if graph edges already exist")
441
- return self
522
+ known_edges = {(edge.src_table, edge.fkey) for edge in self.edges}
442
523
 
443
524
  # A list of primary key candidates (+score) for every column:
444
525
  candidate_dict: dict[
@@ -463,6 +544,9 @@ class LocalGraph:
463
544
  src_table_name = src_table.name.lower()
464
545
 
465
546
  for src_key in src_table.columns:
547
+ if (src_table.name, src_key.name) in known_edges:
548
+ continue
549
+
466
550
  if src_key == src_table.primary_key:
467
551
  continue # Cannot link to primary key.
468
552
 
@@ -528,7 +612,10 @@ class LocalGraph:
528
612
  score += 1.0
529
613
 
530
614
  # Cardinality ratio:
531
- if len(src_table._data) > len(dst_table._data):
615
+ src_num_rows = src_table._num_rows()
616
+ dst_num_rows = dst_table._num_rows()
617
+ if (src_num_rows is not None and dst_num_rows is not None
618
+ and src_num_rows > dst_num_rows):
532
619
  score += 1.0
533
620
 
534
621
  if score < 5.0:
@@ -790,7 +877,7 @@ class LocalGraph:
790
877
  def __contains__(self, name: str) -> bool:
791
878
  return self.has_table(name)
792
879
 
793
- def __getitem__(self, name: str) -> LocalTable:
880
+ def __getitem__(self, name: str) -> Table:
794
881
  return self.table(name)
795
882
 
796
883
  def __delitem__(self, name: str) -> None:
@@ -2,10 +2,12 @@ from .id import contains_id
2
2
  from .timestamp import contains_timestamp
3
3
  from .categorical import contains_categorical
4
4
  from .multicategorical import contains_multicategorical
5
+ from .stype import infer_stype
5
6
 
6
7
  __all__ = [
7
8
  'contains_id',
8
9
  'contains_timestamp',
9
10
  'contains_categorical',
10
11
  'contains_multicategorical',
12
+ 'infer_stype',
11
13
  ]
@@ -0,0 +1,35 @@
1
+ import pandas as pd
2
+ from kumoapi.typing import Dtype, Stype
3
+
4
+ from kumoai.experimental.rfm.infer import (
5
+ contains_categorical,
6
+ contains_id,
7
+ contains_multicategorical,
8
+ contains_timestamp,
9
+ )
10
+
11
+
12
+ def infer_stype(ser: pd.Series, column_name: str, dtype: Dtype) -> Stype:
13
+ r"""Infers the semantic type of a column.
14
+
15
+ Args:
16
+ ser: A :class:`pandas.Series` to analyze.
17
+ column_name: The name of the column (used for pattern matching).
18
+ dtype: The data type.
19
+
20
+ Returns:
21
+ The semantic type.
22
+ """
23
+ if contains_id(ser, column_name, dtype):
24
+ return Stype.ID
25
+
26
+ if contains_timestamp(ser, column_name, dtype):
27
+ return Stype.timestamp
28
+
29
+ if contains_multicategorical(ser, column_name, dtype):
30
+ return Stype.multicategorical
31
+
32
+ if contains_categorical(ser, column_name, dtype):
33
+ return Stype.categorical
34
+
35
+ return dtype.default_stype
@@ -6,7 +6,7 @@ import pandas as pd
6
6
  from kumoapi.rfm.context import Subgraph
7
7
  from kumoapi.typing import Stype
8
8
 
9
- from kumoai.experimental.rfm import LocalGraph
9
+ from kumoai.experimental.rfm import Graph, LocalTable
10
10
  from kumoai.experimental.rfm.utils import normalize_text
11
11
  from kumoai.utils import InteractiveProgressLogger, ProgressLogger
12
12
 
@@ -20,7 +20,7 @@ except ImportError:
20
20
  class LocalGraphStore:
21
21
  def __init__(
22
22
  self,
23
- graph: LocalGraph,
23
+ graph: Graph,
24
24
  preprocess: bool = False,
25
25
  verbose: Union[bool, ProgressLogger] = True,
26
26
  ) -> None:
@@ -105,7 +105,7 @@ class LocalGraphStore:
105
105
 
106
106
  def sanitize(
107
107
  self,
108
- graph: LocalGraph,
108
+ graph: Graph,
109
109
  preprocess: bool = False,
110
110
  ) -> Tuple[Dict[str, pd.DataFrame], Dict[str, np.ndarray]]:
111
111
  r"""Sanitizes raw data according to table schema definition:
@@ -120,10 +120,11 @@ class LocalGraphStore:
120
120
  data for faster model processing. In particular, it:
121
121
  * tokenizes any text column that is not a foreign key
122
122
  """
123
- df_dict: Dict[str, pd.DataFrame] = {
124
- table_name: table._data.copy(deep=False).reset_index(drop=True)
125
- for table_name, table in graph.tables.items()
126
- }
123
+ df_dict: Dict[str, pd.DataFrame] = {}
124
+ for table_name, table in graph.tables.items():
125
+ assert isinstance(table, LocalTable)
126
+ df = table._data
127
+ df_dict[table_name] = df.copy(deep=False).reset_index(drop=True)
127
128
 
128
129
  foreign_keys = {(edge.src_table, edge.fkey) for edge in graph.edges}
129
130
 
@@ -165,7 +166,7 @@ class LocalGraphStore:
165
166
 
166
167
  return df_dict, mask_dict
167
168
 
168
- def get_stype_dict(self, graph: LocalGraph) -> Dict[str, Dict[str, Stype]]:
169
+ def get_stype_dict(self, graph: Graph) -> Dict[str, Dict[str, Stype]]:
169
170
  stype_dict: Dict[str, Dict[str, Stype]] = {}
170
171
  foreign_keys = {(edge.src_table, edge.fkey) for edge in graph.edges}
171
172
  for table in graph.tables.values():
@@ -180,7 +181,7 @@ class LocalGraphStore:
180
181
 
181
182
  def get_pkey_data(
182
183
  self,
183
- graph: LocalGraph,
184
+ graph: Graph,
184
185
  ) -> Tuple[
185
186
  Dict[str, str],
186
187
  Dict[str, pd.DataFrame],
@@ -218,7 +219,7 @@ class LocalGraphStore:
218
219
 
219
220
  def get_time_data(
220
221
  self,
221
- graph: LocalGraph,
222
+ graph: Graph,
222
223
  ) -> Tuple[
223
224
  Dict[str, str],
224
225
  Dict[str, str],
@@ -259,7 +260,7 @@ class LocalGraphStore:
259
260
 
260
261
  def get_csc(
261
262
  self,
262
- graph: LocalGraph,
263
+ graph: Graph,
263
264
  ) -> Tuple[
264
265
  Dict[Tuple[str, str, str], np.ndarray],
265
266
  Dict[Tuple[str, str, str], np.ndarray],
@@ -32,7 +32,7 @@ from kumoapi.task import TaskType
32
32
 
33
33
  from kumoai.client.rfm import RFMAPI
34
34
  from kumoai.exceptions import HTTPException
35
- from kumoai.experimental.rfm import LocalGraph
35
+ from kumoai.experimental.rfm import Graph
36
36
  from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
37
37
  from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
38
38
  from kumoai.experimental.rfm.local_pquery_driver import (
@@ -123,17 +123,17 @@ class KumoRFM:
123
123
  :class:`KumoRFM` is a foundation model to generate predictions for any
124
124
  relational dataset without training.
125
125
  The model is pre-trained and the class provides an interface to query the
126
- model from a :class:`LocalGraph` object.
126
+ model from a :class:`Graph` object.
127
127
 
128
128
  .. code-block:: python
129
129
 
130
- from kumoai.experimental.rfm import LocalGraph, KumoRFM
130
+ from kumoai.experimental.rfm import Graph, KumoRFM
131
131
 
132
132
  df_users = pd.DataFrame(...)
133
133
  df_items = pd.DataFrame(...)
134
134
  df_orders = pd.DataFrame(...)
135
135
 
136
- graph = LocalGraph.from_data({
136
+ graph = Graph.from_data({
137
137
  'users': df_users,
138
138
  'items': df_items,
139
139
  'orders': df_orders,
@@ -163,7 +163,7 @@ class KumoRFM:
163
163
  """
164
164
  def __init__(
165
165
  self,
166
- graph: LocalGraph,
166
+ graph: Graph,
167
167
  preprocess: bool = False,
168
168
  verbose: Union[bool, ProgressLogger] = True,
169
169
  ) -> None:
@@ -2,15 +2,22 @@ import base64
2
2
  import json
3
3
  from typing import Any, Dict, List, Tuple
4
4
 
5
- import boto3
6
5
  import requests
7
- from mypy_boto3_sagemaker_runtime.client import SageMakerRuntimeClient
8
- from mypy_boto3_sagemaker_runtime.type_defs import InvokeEndpointOutputTypeDef
9
6
 
10
7
  from kumoai.client import KumoClient
11
8
  from kumoai.client.endpoints import Endpoint, HTTPMethod
12
9
  from kumoai.exceptions import HTTPException
13
10
 
11
+ try:
12
+ # isort: off
13
+ from mypy_boto3_sagemaker_runtime.client import SageMakerRuntimeClient
14
+ from mypy_boto3_sagemaker_runtime.type_defs import (
15
+ InvokeEndpointOutputTypeDef, )
16
+ # isort: on
17
+ except ImportError:
18
+ SageMakerRuntimeClient = Any
19
+ InvokeEndpointOutputTypeDef = Any
20
+
14
21
 
15
22
  class SageMakerResponseAdapter(requests.Response):
16
23
  def __init__(self, sm_response: InvokeEndpointOutputTypeDef):
@@ -34,6 +41,7 @@ class SageMakerResponseAdapter(requests.Response):
34
41
 
35
42
  class KumoClient_SageMakerAdapter(KumoClient):
36
43
  def __init__(self, region: str, endpoint_name: str):
44
+ import boto3
37
45
  self._client: SageMakerRuntimeClient = boto3.client(
38
46
  service_name="sagemaker-runtime", region_name=region)
39
47
  self._endpoint_name = endpoint_name