kumoai 2.10.0.dev202509231831__cp313-cp313-macosx_11_0_arm64.whl → 2.14.0.dev202512161731__cp313-cp313-macosx_11_0_arm64.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kumoai might be problematic. (The original page linked to further details here; that link was not preserved.)

Files changed (53)
  1. kumoai/__init__.py +22 -11
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +17 -16
  4. kumoai/client/endpoints.py +1 -0
  5. kumoai/client/pquery.py +6 -2
  6. kumoai/client/rfm.py +37 -8
  7. kumoai/connector/utils.py +23 -2
  8. kumoai/experimental/rfm/__init__.py +164 -46
  9. kumoai/experimental/rfm/backend/__init__.py +0 -0
  10. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  11. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +49 -86
  12. kumoai/experimental/rfm/backend/local/sampler.py +315 -0
  13. kumoai/experimental/rfm/backend/local/table.py +119 -0
  14. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  15. kumoai/experimental/rfm/backend/snow/sampler.py +274 -0
  16. kumoai/experimental/rfm/backend/snow/table.py +135 -0
  17. kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
  18. kumoai/experimental/rfm/backend/sqlite/sampler.py +353 -0
  19. kumoai/experimental/rfm/backend/sqlite/table.py +126 -0
  20. kumoai/experimental/rfm/base/__init__.py +25 -0
  21. kumoai/experimental/rfm/base/column.py +66 -0
  22. kumoai/experimental/rfm/base/sampler.py +773 -0
  23. kumoai/experimental/rfm/base/source.py +19 -0
  24. kumoai/experimental/rfm/base/sql_sampler.py +60 -0
  25. kumoai/experimental/rfm/{local_table.py → base/table.py} +245 -156
  26. kumoai/experimental/rfm/{local_graph.py → graph.py} +425 -137
  27. kumoai/experimental/rfm/infer/__init__.py +6 -0
  28. kumoai/experimental/rfm/infer/dtype.py +79 -0
  29. kumoai/experimental/rfm/infer/pkey.py +126 -0
  30. kumoai/experimental/rfm/infer/time_col.py +62 -0
  31. kumoai/experimental/rfm/infer/timestamp.py +7 -4
  32. kumoai/experimental/rfm/pquery/__init__.py +4 -4
  33. kumoai/experimental/rfm/pquery/{backend.py → executor.py} +24 -58
  34. kumoai/experimental/rfm/pquery/{pandas_backend.py → pandas_executor.py} +278 -224
  35. kumoai/experimental/rfm/rfm.py +669 -246
  36. kumoai/experimental/rfm/sagemaker.py +138 -0
  37. kumoai/jobs.py +1 -0
  38. kumoai/pquery/predictive_query.py +10 -6
  39. kumoai/spcs.py +1 -3
  40. kumoai/testing/decorators.py +1 -1
  41. kumoai/testing/snow.py +50 -0
  42. kumoai/trainer/trainer.py +12 -10
  43. kumoai/utils/__init__.py +3 -2
  44. kumoai/utils/progress_logger.py +239 -4
  45. kumoai/utils/sql.py +3 -0
  46. {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/METADATA +15 -5
  47. {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/RECORD +50 -32
  48. kumoai/experimental/rfm/local_graph_sampler.py +0 -176
  49. kumoai/experimental/rfm/local_pquery_driver.py +0 -404
  50. kumoai/experimental/rfm/utils.py +0 -344
  51. {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/WHEEL +0 -0
  52. {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/licenses/LICENSE +0 -0
  53. {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/top_level.txt +0 -0
@@ -2,8 +2,9 @@ import contextlib
2
2
  import io
3
3
  import warnings
4
4
  from collections import defaultdict
5
- from importlib.util import find_spec
6
- from typing import TYPE_CHECKING, Dict, List, Optional, Union
5
+ from dataclasses import dataclass, field
6
+ from pathlib import Path
7
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
7
8
 
8
9
  import pandas as pd
9
10
  from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
@@ -11,160 +12,401 @@ from kumoapi.table import TableDefinition
11
12
  from kumoapi.typing import Stype
12
13
  from typing_extensions import Self
13
14
 
14
- from kumoai import in_notebook
15
- from kumoai.experimental.rfm import LocalTable
15
+ from kumoai import in_notebook, in_snowflake_notebook
16
+ from kumoai.experimental.rfm.base import DataBackend, Table
16
17
  from kumoai.graph import Edge
18
+ from kumoai.mixin import CastMixin
17
19
 
18
20
  if TYPE_CHECKING:
19
21
  import graphviz
22
+ from adbc_driver_sqlite.dbapi import AdbcSqliteConnection
23
+ from snowflake.connector import SnowflakeConnection
20
24
 
21
25
 
22
- class LocalGraph:
23
- r"""A graph of :class:`LocalTable` objects, akin to relationships between
26
+ @dataclass
27
+ class SqliteConnectionConfig(CastMixin):
28
+ uri: Union[str, Path]
29
+ kwargs: Dict[str, Any] = field(default_factory=dict)
30
+
31
+
32
+ class Graph:
33
+ r"""A graph of :class:`Table` objects, akin to relationships between
24
34
  tables in a relational database.
25
35
 
26
36
  Creating a graph is the final step of data definition; after a
27
- :class:`LocalGraph` is created, you can use it to initialize the
37
+ :class:`Graph` is created, you can use it to initialize the
28
38
  Kumo Relational Foundation Model (:class:`KumoRFM`).
29
39
 
30
40
  .. code-block:: python
31
41
 
32
- import pandas as pd
33
- import kumoai.experimental.rfm as rfm
42
+ >>> # doctest: +SKIP
43
+ >>> import pandas as pd
44
+ >>> import kumoai.experimental.rfm as rfm
34
45
 
35
- # Load data frames into memory:
36
- df1 = pd.DataFrame(...)
37
- df2 = pd.DataFrame(...)
38
- df3 = pd.DataFrame(...)
46
+ >>> # Load data frames into memory:
47
+ >>> df1 = pd.DataFrame(...)
48
+ >>> df2 = pd.DataFrame(...)
49
+ >>> df3 = pd.DataFrame(...)
39
50
 
40
- # Define tables from data frames:
41
- table1 = rfm.LocalTable(name="table1", data=df1)
42
- table2 = rfm.LocalTable(name="table2", data=df2)
43
- table3 = rfm.LocalTable(name="table3", data=df3)
51
+ >>> # Define tables from data frames:
52
+ >>> table1 = rfm.LocalTable(name="table1", data=df1)
53
+ >>> table2 = rfm.LocalTable(name="table2", data=df2)
54
+ >>> table3 = rfm.LocalTable(name="table3", data=df3)
44
55
 
45
- # Create a graph from a dictionary of tables:
46
- graph = rfm.LocalGraph({
47
- "table1": table1,
48
- "table2": table2,
49
- "table3": table3,
50
- })
56
+ >>> # Create a graph from a dictionary of tables:
57
+ >>> graph = rfm.Graph({
58
+ ... "table1": table1,
59
+ ... "table2": table2,
60
+ ... "table3": table3,
61
+ ... })
51
62
 
52
- # Infer table metadata:
53
- graph.infer_metadata()
63
+ >>> # Infer table metadata:
64
+ >>> graph.infer_metadata()
54
65
 
55
- # Infer links/edges:
56
- graph.infer_links()
66
+ >>> # Infer links/edges:
67
+ >>> graph.infer_links()
57
68
 
58
- # Inspect table metadata:
59
- for table in graph.tables.values():
60
- table.print_metadata()
69
+ >>> # Inspect table metadata:
70
+ >>> for table in graph.tables.values():
71
+ ... table.print_metadata()
61
72
 
62
- # Visualize graph (if graphviz is installed):
63
- graph.visualize()
73
+ >>> # Visualize graph (if graphviz is installed):
74
+ >>> graph.visualize()
64
75
 
65
- # Add/Remove edges between tables:
66
- graph.link(src_table="table1", fkey="id1", dst_table="table2")
67
- graph.unlink(src_table="table1", fkey="id1", dst_table="table2")
76
+ >>> # Add/Remove edges between tables:
77
+ >>> graph.link(src_table="table1", fkey="id1", dst_table="table2")
78
+ >>> graph.unlink(src_table="table1", fkey="id1", dst_table="table2")
68
79
 
69
- # Validate graph:
70
- graph.validate()
80
+ >>> # Validate graph:
81
+ >>> graph.validate()
71
82
  """
72
83
 
73
84
  # Constructors ############################################################
74
85
 
75
86
  def __init__(
76
87
  self,
77
- tables: List[LocalTable],
78
- edges: Optional[List[Edge]] = None,
88
+ tables: Sequence[Table],
89
+ edges: Optional[Sequence[Edge]] = None,
79
90
  ) -> None:
80
91
 
81
- self._tables: Dict[str, LocalTable] = {}
92
+ self._tables: Dict[str, Table] = {}
82
93
  self._edges: List[Edge] = []
83
94
 
84
95
  for table in tables:
85
96
  self.add_table(table)
86
97
 
98
+ for table in tables:
99
+ for fkey in table._source_foreign_key_dict.values():
100
+ if fkey.name not in table or fkey.dst_table not in self:
101
+ continue
102
+ if self[fkey.dst_table].primary_key is None:
103
+ self[fkey.dst_table].primary_key = fkey.primary_key
104
+ elif self[fkey.dst_table]._primary_key != fkey.primary_key:
105
+ raise ValueError(f"Found duplicate primary key definition "
106
+ f"'{self[fkey.dst_table]._primary_key}' "
107
+ f"and '{fkey.primary_key}' in table "
108
+ f"'{fkey.dst_table}'.")
109
+ self.link(table.name, fkey.name, fkey.dst_table)
110
+
87
111
  for edge in (edges or []):
88
112
  _edge = Edge._cast(edge)
89
113
  assert _edge is not None
90
- self.link(*_edge)
114
+ if _edge not in self._edges:
115
+ self.link(*_edge)
91
116
 
92
117
  @classmethod
93
118
  def from_data(
94
119
  cls,
95
120
  df_dict: Dict[str, pd.DataFrame],
96
- edges: Optional[List[Edge]] = None,
121
+ edges: Optional[Sequence[Edge]] = None,
97
122
  infer_metadata: bool = True,
98
123
  verbose: bool = True,
99
124
  ) -> Self:
100
- r"""Creates a :class:`LocalGraph` from a dictionary of
125
+ r"""Creates a :class:`Graph` from a dictionary of
101
126
  :class:`pandas.DataFrame` objects.
102
127
 
103
- Automatically infers table metadata and links.
128
+ Automatically infers table metadata and links by default.
104
129
 
105
130
  .. code-block:: python
106
131
 
107
- import pandas as pd
108
- import kumoai.experimental.rfm as rfm
109
-
110
- # Load data frames into memory:
111
- df1 = pd.DataFrame(...)
112
- df2 = pd.DataFrame(...)
113
- df3 = pd.DataFrame(...)
114
-
115
- # Create a graph from a dictionary of data frames:
116
- graph = rfm.LocalGraph.from_data({
117
- "table1": df1,
118
- "table2": df2,
119
- "table3": df3,
120
- })
132
+ >>> # doctest: +SKIP
133
+ >>> import pandas as pd
134
+ >>> import kumoai.experimental.rfm as rfm
121
135
 
122
- # Inspect table metadata:
123
- for table in graph.tables.values():
124
- table.print_metadata()
136
+ >>> # Load data frames into memory:
137
+ >>> df1 = pd.DataFrame(...)
138
+ >>> df2 = pd.DataFrame(...)
139
+ >>> df3 = pd.DataFrame(...)
125
140
 
126
- # Visualize graph (if graphviz is installed):
127
- graph.visualize()
141
+ >>> # Create a graph from a dictionary of data frames:
142
+ >>> graph = rfm.Graph.from_data({
143
+ ... "table1": df1,
144
+ ... "table2": df2,
145
+ ... "table3": df3,
146
+ ... })
128
147
 
129
148
  Args:
130
149
  df_dict: A dictionary of data frames, where the keys are the names
131
150
  of the tables and the values hold table data.
151
+ edges: An optional list of :class:`~kumoai.graph.Edge` objects to
152
+ add to the graph. If not provided, edges will be automatically
153
+ inferred from the data in case ``infer_metadata=True``.
132
154
  infer_metadata: Whether to infer metadata for all tables in the
133
155
  graph.
156
+ verbose: Whether to print verbose output.
157
+ """
158
+ from kumoai.experimental.rfm.backend.local import LocalTable
159
+ tables = [LocalTable(df, name) for name, df in df_dict.items()]
160
+
161
+ graph = cls(tables, edges=edges or [])
162
+
163
+ if infer_metadata:
164
+ graph.infer_metadata(False)
165
+
166
+ if edges is None:
167
+ graph.infer_links(False)
168
+
169
+ if verbose:
170
+ graph.print_metadata()
171
+ graph.print_links()
172
+
173
+ return graph
174
+
175
+ @classmethod
176
+ def from_sqlite(
177
+ cls,
178
+ connection: Union[
179
+ 'AdbcSqliteConnection',
180
+ SqliteConnectionConfig,
181
+ str,
182
+ Path,
183
+ Dict[str, Any],
184
+ ],
185
+ table_names: Optional[Sequence[str]] = None,
186
+ edges: Optional[Sequence[Edge]] = None,
187
+ infer_metadata: bool = True,
188
+ verbose: bool = True,
189
+ ) -> Self:
190
+ r"""Creates a :class:`Graph` from a :class:`sqlite` database.
191
+
192
+ Automatically infers table metadata and links by default.
193
+
194
+ .. code-block:: python
195
+
196
+ >>> # doctest: +SKIP
197
+ >>> import kumoai.experimental.rfm as rfm
198
+
199
+ >>> # Create a graph from a SQLite database:
200
+ >>> graph = rfm.Graph.from_sqlite('data.db')
201
+
202
+ Args:
203
+ connection: An open connection from
204
+ :meth:`~kumoai.experimental.rfm.backend.sqlite.connect` or the
205
+ path to the database file.
206
+ table_names: Set of table names to include. If ``None``, will add
207
+ all tables present in the database.
134
208
  edges: An optional list of :class:`~kumoai.graph.Edge` objects to
135
209
  add to the graph. If not provided, edges will be automatically
136
- inferred from the data.
210
+ inferred from the data in case ``infer_metadata=True``.
211
+ infer_metadata: Whether to infer metadata for all tables in the
212
+ graph.
137
213
  verbose: Whether to print verbose output.
214
+ """
215
+ from kumoai.experimental.rfm.backend.sqlite import (
216
+ Connection,
217
+ SQLiteTable,
218
+ connect,
219
+ )
220
+
221
+ internal_connection = False
222
+ if not isinstance(connection, Connection):
223
+ connection = SqliteConnectionConfig._cast(connection)
224
+ assert isinstance(connection, SqliteConnectionConfig)
225
+ connection = connect(connection.uri, **connection.kwargs)
226
+ internal_connection = True
227
+ assert isinstance(connection, Connection)
228
+
229
+ if table_names is None:
230
+ with connection.cursor() as cursor:
231
+ cursor.execute("SELECT name FROM sqlite_master "
232
+ "WHERE type='table'")
233
+ table_names = [row[0] for row in cursor.fetchall()]
234
+
235
+ tables = [SQLiteTable(connection, name) for name in table_names]
138
236
 
139
- Note:
140
- This method will automatically infer metadata and links for the
141
- graph.
237
+ graph = cls(tables, edges=edges or [])
142
238
 
143
- Example:
239
+ if internal_connection:
240
+ graph._connection = connection # type: ignore
241
+
242
+ if infer_metadata:
243
+ graph.infer_metadata(False)
244
+
245
+ if edges is None:
246
+ graph.infer_links(False)
247
+
248
+ if verbose:
249
+ graph.print_metadata()
250
+ graph.print_links()
251
+
252
+ return graph
253
+
254
+ @classmethod
255
+ def from_snowflake(
256
+ cls,
257
+ connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
258
+ database: Optional[str] = None,
259
+ schema: Optional[str] = None,
260
+ table_names: Optional[Sequence[str]] = None,
261
+ edges: Optional[Sequence[Edge]] = None,
262
+ infer_metadata: bool = True,
263
+ verbose: bool = True,
264
+ ) -> Self:
265
+ r"""Creates a :class:`Graph` from a :class:`snowflake` database and
266
+ schema.
267
+
268
+ Automatically infers table metadata and links by default.
269
+
270
+ .. code-block:: python
271
+
272
+ >>> # doctest: +SKIP
144
273
  >>> import kumoai.experimental.rfm as rfm
145
- >>> df1 = pd.DataFrame(...)
146
- >>> df2 = pd.DataFrame(...)
147
- >>> df3 = pd.DataFrame(...)
148
- >>> graph = rfm.LocalGraph.from_data(data={
149
- ... "table1": df1,
150
- ... "table2": df2,
151
- ... "table3": df3,
152
- ... })
153
- ... graph.validate()
274
+
275
+ >>> # Create a graph directly in a Snowflake notebook:
276
+ >>> graph = rfm.Graph.from_snowflake(schema='my_schema')
277
+
278
+ Args:
279
+ connection: An open connection from
280
+ :meth:`~kumoai.experimental.rfm.backend.snow.connect` or the
281
+ :class:`snowflake` connector keyword arguments to open a new
282
+ connection. If ``None``, will re-use an active session in case
283
+ it exists, or create a new connection from credentials stored
284
+ in environment variables.
285
+ database: The database.
286
+ schema: The schema.
287
+ table_names: Set of table names to include. If ``None``, will add
288
+ all tables present in the database.
289
+ edges: An optional list of :class:`~kumoai.graph.Edge` objects to
290
+ add to the graph. If not provided, edges will be automatically
291
+ inferred from the data in case ``infer_metadata=True``.
292
+ infer_metadata: Whether to infer metadata for all tables in the
293
+ graph.
294
+ verbose: Whether to print verbose output.
154
295
  """
155
- tables = [LocalTable(df, name) for name, df in df_dict.items()]
296
+ from kumoai.experimental.rfm.backend.snow import (
297
+ Connection,
298
+ SnowTable,
299
+ connect,
300
+ )
301
+
302
+ if not isinstance(connection, Connection):
303
+ connection = connect(**(connection or {}))
304
+ assert isinstance(connection, Connection)
305
+
306
+ if table_names is None:
307
+ with connection.cursor() as cursor:
308
+ if database is None and schema is None:
309
+ cursor.execute("SELECT CURRENT_DATABASE(), "
310
+ "CURRENT_SCHEMA()")
311
+ result = cursor.fetchone()
312
+ database = database or result[0]
313
+ schema = schema or result[1]
314
+ cursor.execute(f"""
315
+ SELECT TABLE_NAME
316
+ FROM {database}.INFORMATION_SCHEMA.TABLES
317
+ WHERE TABLE_SCHEMA = '{schema}'
318
+ """)
319
+ table_names = [row[0] for row in cursor.fetchall()]
320
+
321
+ tables = [
322
+ SnowTable(
323
+ connection,
324
+ name=table_name,
325
+ database=database,
326
+ schema=schema,
327
+ ) for table_name in table_names
328
+ ]
156
329
 
157
330
  graph = cls(tables, edges=edges or [])
158
331
 
159
332
  if infer_metadata:
160
- graph.infer_metadata(verbose)
333
+ graph.infer_metadata(False)
161
334
 
162
335
  if edges is None:
163
- graph.infer_links(verbose)
336
+ graph.infer_links(False)
337
+
338
+ if verbose:
339
+ graph.print_metadata()
340
+ graph.print_links()
164
341
 
165
342
  return graph
166
343
 
167
- # Tables ##############################################################
344
+ @classmethod
345
+ def from_snowflake_semantic_view(
346
+ cls,
347
+ semantic_view_name: str,
348
+ connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
349
+ verbose: bool = True,
350
+ ) -> Self:
351
+ import yaml
352
+
353
+ from kumoai.experimental.rfm.backend.snow import (
354
+ Connection,
355
+ SnowTable,
356
+ connect,
357
+ )
358
+
359
+ if not isinstance(connection, Connection):
360
+ connection = connect(**(connection or {}))
361
+ assert isinstance(connection, Connection)
362
+
363
+ with connection.cursor() as cursor:
364
+ cursor.execute(f"SELECT SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW("
365
+ f"'{semantic_view_name}')")
366
+ view = yaml.safe_load(cursor.fetchone()[0])
367
+
368
+ graph = cls(tables=[])
369
+
370
+ for table_desc in view['tables']:
371
+ primary_key: Optional[str] = None
372
+ if ('primary_key' in table_desc # NOTE No composite keys yet.
373
+ and len(table_desc['primary_key']['columns']) == 1):
374
+ primary_key = table_desc['primary_key']['columns'][0]
375
+
376
+ table = SnowTable(
377
+ connection,
378
+ name=table_desc['base_table']['table'],
379
+ database=table_desc['base_table']['database'],
380
+ schema=table_desc['base_table']['schema'],
381
+ primary_key=primary_key,
382
+ )
383
+ graph.add_table(table)
384
+
385
+ # TODO Find a solution to register time columns!
386
+
387
+ for relations in view['relationships']:
388
+ if len(relations['relationship_columns']) != 1:
389
+ continue # NOTE No composite keys yet.
390
+ graph.link(
391
+ src_table=relations['left_table'],
392
+ fkey=relations['relationship_columns'][0]['left_column'],
393
+ dst_table=relations['right_table'],
394
+ )
395
+
396
+ if verbose:
397
+ graph.print_metadata()
398
+ graph.print_links()
399
+
400
+ return graph
401
+
402
+ # Backend #################################################################
403
+
404
+ @property
405
+ def backend(self) -> DataBackend | None:
406
+ backends = [table.backend for table in self._tables.values()]
407
+ return backends[0] if len(backends) > 0 else None
408
+
409
+ # Tables ##################################################################
168
410
 
169
411
  def has_table(self, name: str) -> bool:
170
412
  r"""Returns ``True`` if the graph has a table with name ``name``;
@@ -172,7 +414,7 @@ class LocalGraph:
172
414
  """
173
415
  return name in self.tables
174
416
 
175
- def table(self, name: str) -> LocalTable:
417
+ def table(self, name: str) -> Table:
176
418
  r"""Returns the table with name ``name`` in the graph.
177
419
 
178
420
  Raises:
@@ -183,11 +425,11 @@ class LocalGraph:
183
425
  return self.tables[name]
184
426
 
185
427
  @property
186
- def tables(self) -> Dict[str, LocalTable]:
428
+ def tables(self) -> Dict[str, Table]:
187
429
  r"""Returns the dictionary of table objects."""
188
430
  return self._tables
189
431
 
190
- def add_table(self, table: LocalTable) -> Self:
432
+ def add_table(self, table: Table) -> Self:
191
433
  r"""Adds a table to the graph.
192
434
 
193
435
  Args:
@@ -196,11 +438,18 @@ class LocalGraph:
196
438
  Raises:
197
439
  KeyError: If a table with the same name already exists in the
198
440
  graph.
441
+ ValueError: If the table belongs to a different backend than the
442
+ rest of the tables in the graph.
199
443
  """
200
444
  if table.name in self._tables:
201
445
  raise KeyError(f"Cannot add table with name '{table.name}' to "
202
446
  f"this graph; table names must be globally unique.")
203
447
 
448
+ if self.backend is not None and table.backend != self.backend:
449
+ raise ValueError(f"Cannot register a table with backend "
450
+ f"'{table.backend}' to this graph since other "
451
+ f"tables have backend '{self.backend}'.")
452
+
204
453
  self._tables[table.name] = table
205
454
 
206
455
  return self
@@ -231,16 +480,17 @@ class LocalGraph:
231
480
  r"""Returns a :class:`pandas.DataFrame` object containing metadata
232
481
  information about the tables in this graph.
233
482
 
234
- The returned dataframe has columns ``name``, ``primary_key``, and
235
- ``time_column``, which provide an aggregate view of the properties of
236
- the tables of this graph.
483
+ The returned dataframe has columns ``name``, ``primary_key``,
484
+ ``time_column``, and ``end_time_column``, which provide an aggregate
485
+ view of the properties of the tables of this graph.
237
486
 
238
487
  Example:
488
+ >>> # doctest: +SKIP
239
489
  >>> import kumoai.experimental.rfm as rfm
240
- >>> graph = rfm.LocalGraph(tables=...).infer_metadata()
241
- >>> graph.metadata
242
- name primary_key time_column
243
- 0 users user_id -
490
+ >>> graph = rfm.Graph(tables=...).infer_metadata()
491
+ >>> graph.metadata # doctest: +SKIP
492
+ name primary_key time_column end_time_column
493
+ 0 users user_id - -
244
494
  """
245
495
  tables = list(self.tables.values())
246
496
 
@@ -251,13 +501,22 @@ class LocalGraph:
251
501
  pd.Series(dtype=str, data=[t._primary_key or '-' for t in tables]),
252
502
  'time_column':
253
503
  pd.Series(dtype=str, data=[t._time_column or '-' for t in tables]),
504
+ 'end_time_column':
505
+ pd.Series(
506
+ dtype=str,
507
+ data=[t._end_time_column or '-' for t in tables],
508
+ ),
254
509
  })
255
510
 
256
511
  def print_metadata(self) -> None:
257
- r"""Prints the :meth:`~LocalGraph.metadata` of the graph."""
258
- if in_notebook():
512
+ r"""Prints the :meth:`~Graph.metadata` of the graph."""
513
+ if in_snowflake_notebook():
514
+ import streamlit as st
515
+ st.markdown("### 🗂️ Graph Metadata")
516
+ st.dataframe(self.metadata, hide_index=True)
517
+ elif in_notebook():
259
518
  from IPython.display import Markdown, display
260
- display(Markdown('### 🗂️ Graph Metadata'))
519
+ display(Markdown("### 🗂️ Graph Metadata"))
261
520
  df = self.metadata
262
521
  try:
263
522
  if hasattr(df.style, 'hide'):
@@ -278,7 +537,7 @@ class LocalGraph:
278
537
 
279
538
  Note:
280
539
  For more information, please see
281
- :meth:`kumoai.experimental.rfm.LocalTable.infer_metadata`.
540
+ :meth:`kumoai.experimental.rfm.Table.infer_metadata`.
282
541
  """
283
542
  for table in self.tables.values():
284
543
  table.infer_metadata(verbose=False)
@@ -296,37 +555,47 @@ class LocalGraph:
296
555
  return self._edges
297
556
 
298
557
  def print_links(self) -> None:
299
- r"""Prints the :meth:`~LocalGraph.edges` of the graph."""
558
+ r"""Prints the :meth:`~Graph.edges` of the graph."""
300
559
  edges = [(edge.dst_table, self[edge.dst_table]._primary_key,
301
560
  edge.src_table, edge.fkey) for edge in self.edges]
302
561
  edges = sorted(edges)
303
562
 
304
- if in_notebook():
563
+ if in_snowflake_notebook():
564
+ import streamlit as st
565
+ st.markdown("### 🕸️ Graph Links (FK ↔️ PK)")
566
+ if len(edges) > 0:
567
+ st.markdown('\n'.join([
568
+ f"- {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
569
+ for edge in edges
570
+ ]))
571
+ else:
572
+ st.markdown("*No links registered*")
573
+ elif in_notebook():
305
574
  from IPython.display import Markdown, display
306
- display(Markdown('### 🕸️ Graph Links (FK ↔️ PK)'))
575
+ display(Markdown("### 🕸️ Graph Links (FK ↔️ PK)"))
307
576
  if len(edges) > 0:
308
577
  display(
309
578
  Markdown('\n'.join([
310
- f'- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`'
579
+ f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
311
580
  for edge in edges
312
581
  ])))
313
582
  else:
314
- display(Markdown('*No links registered*'))
583
+ display(Markdown("*No links registered*"))
315
584
  else:
316
585
  print("🕸️ Graph Links (FK ↔️ PK):")
317
586
  if len(edges) > 0:
318
587
  print('\n'.join([
319
- f'• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}'
588
+ f"• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
320
589
  for edge in edges
321
590
  ]))
322
591
  else:
323
- print('No links registered')
592
+ print("No links registered")
324
593
 
325
594
  def link(
326
595
  self,
327
- src_table: Union[str, LocalTable],
596
+ src_table: Union[str, Table],
328
597
  fkey: str,
329
- dst_table: Union[str, LocalTable],
598
+ dst_table: Union[str, Table],
330
599
  ) -> Self:
331
600
  r"""Links two tables (``src_table`` and ``dst_table``) from the foreign
332
601
  key ``fkey`` in the source table to the primary key in the destination
@@ -349,11 +618,11 @@ class LocalGraph:
349
618
  table does not exist in the graph, if the source key does not
350
619
  exist in the source table.
351
620
  """
352
- if isinstance(src_table, LocalTable):
621
+ if isinstance(src_table, Table):
353
622
  src_table = src_table.name
354
623
  assert isinstance(src_table, str)
355
624
 
356
- if isinstance(dst_table, LocalTable):
625
+ if isinstance(dst_table, Table):
357
626
  dst_table = dst_table.name
358
627
  assert isinstance(dst_table, str)
359
628
 
@@ -387,9 +656,9 @@ class LocalGraph:
387
656
 
388
657
  def unlink(
389
658
  self,
390
- src_table: Union[str, LocalTable],
659
+ src_table: Union[str, Table],
391
660
  fkey: str,
392
- dst_table: Union[str, LocalTable],
661
+ dst_table: Union[str, Table],
393
662
  ) -> Self:
394
663
  r"""Removes an :class:`~kumoai.graph.Edge` from the graph.
395
664
 
@@ -401,11 +670,11 @@ class LocalGraph:
401
670
  Raises:
402
671
  ValueError: if the edge is not present in the graph.
403
672
  """
404
- if isinstance(src_table, LocalTable):
673
+ if isinstance(src_table, Table):
405
674
  src_table = src_table.name
406
675
  assert isinstance(src_table, str)
407
676
 
408
- if isinstance(dst_table, LocalTable):
677
+ if isinstance(dst_table, Table):
409
678
  dst_table = dst_table.name
410
679
  assert isinstance(dst_table, str)
411
680
 
@@ -419,17 +688,13 @@ class LocalGraph:
419
688
  return self
420
689
 
421
690
  def infer_links(self, verbose: bool = True) -> Self:
422
- r"""Infers links for the tables and adds them as edges to the graph.
691
+ r"""Infers missing links for the tables and adds them as edges to the
692
+ graph.
423
693
 
424
694
  Args:
425
695
  verbose: Whether to print verbose output.
426
-
427
- Note:
428
- This function expects graph edges to be undefined upfront.
429
696
  """
430
- if len(self.edges) > 0:
431
- warnings.warn("Cannot infer links if graph edges already exist")
432
- return self
697
+ known_edges = {(edge.src_table, edge.fkey) for edge in self.edges}
433
698
 
434
699
  # A list of primary key candidates (+score) for every column:
435
700
  candidate_dict: dict[
@@ -454,6 +719,9 @@ class LocalGraph:
454
719
  src_table_name = src_table.name.lower()
455
720
 
456
721
  for src_key in src_table.columns:
722
+ if (src_table.name, src_key.name) in known_edges:
723
+ continue
724
+
457
725
  if src_key == src_table.primary_key:
458
726
  continue # Cannot link to primary key.
459
727
 
@@ -519,7 +787,9 @@ class LocalGraph:
519
787
  score += 1.0
520
788
 
521
789
  # Cardinality ratio:
522
- if len(src_table._data) > len(dst_table._data):
790
+ if (src_table._num_rows is not None
791
+ and dst_table._num_rows is not None
792
+ and src_table._num_rows > dst_table._num_rows):
523
793
  score += 1.0
524
794
 
525
795
  if score < 5.0:
@@ -565,6 +835,10 @@ class LocalGraph:
565
835
  raise ValueError("At least one table needs to be added to the "
566
836
  "graph")
567
837
 
838
+ backends = {table.backend for table in self._tables.values()}
839
+ if len(backends) != 1:
840
+ raise ValueError("Found multiple table backends in the graph")
841
+
568
842
  for edge in self.edges:
569
843
  src_table, fkey, dst_table = edge
570
844
 
@@ -602,8 +876,8 @@ class LocalGraph:
602
876
  raise ValueError(f"{edge} is invalid as foreign key "
603
877
  f"'{fkey}' and primary key '{dst_key.name}' "
604
878
  f"have incompatible data types (got "
605
- f"fkey.dtype '{dst_key.dtype}' and "
606
- f"pkey.dtype '{src_key.dtype}')")
879
+ f"fkey.dtype '{src_key.dtype}' and "
880
+ f"pkey.dtype '{dst_key.dtype}')")
607
881
 
608
882
  return self
609
883
 
@@ -636,19 +910,19 @@ class LocalGraph:
636
910
 
637
911
  return True
638
912
 
639
- # Check basic dependency:
640
- if not find_spec('graphviz'):
641
- raise ModuleNotFoundError("The 'graphviz' package is required for "
642
- "visualization")
643
- elif not has_graphviz_executables():
913
+ try: # Check basic dependency:
914
+ import graphviz
915
+ except ImportError as e:
916
+ raise ImportError("The 'graphviz' package is required for "
917
+ "visualization") from e
918
+
919
+ if not in_snowflake_notebook() and not has_graphviz_executables():
644
920
  raise RuntimeError("Could not visualize graph as 'graphviz' "
645
921
  "executables are not installed. These "
646
922
  "dependencies are required in addition to the "
647
923
  "'graphviz' Python package. Please install "
648
924
  "them as described at "
649
925
  "https://graphviz.org/download/.")
650
- else:
651
- import graphviz
652
926
 
653
927
  format: Optional[str] = None
654
928
  if isinstance(path, str):
@@ -676,6 +950,11 @@ class LocalGraph:
676
950
  ]
677
951
  if time_column := table.time_column:
678
952
  keys += [f'{time_column.name}: Time ({time_column.dtype})']
953
+ if end_time_column := table.end_time_column:
954
+ keys += [
955
+ f'{end_time_column.name}: '
956
+ f'End Time ({end_time_column.dtype})'
957
+ ]
679
958
  key_repr = left_align(keys)
680
959
 
681
960
  columns = []
@@ -683,9 +962,9 @@ class LocalGraph:
683
962
  columns += [
684
963
  f'{column.name}: {column.stype} ({column.dtype})'
685
964
  for column in table.columns
686
- if column.name not in fkeys_dict[table_name]
687
- and column.name != table._primary_key
688
- and column.name != table._time_column
965
+ if column.name not in fkeys_dict[table_name] and
966
+ column.name != table._primary_key and column.name != table.
967
+ _time_column and column.name != table._end_time_column
689
968
  ]
690
969
  column_repr = left_align(columns)
691
970
 
@@ -727,6 +1006,9 @@ class LocalGraph:
727
1006
  graph.render(path, cleanup=True)
728
1007
  elif isinstance(path, io.BytesIO):
729
1008
  path.write(graph.pipe())
1009
+ elif in_snowflake_notebook():
1010
+ import streamlit as st
1011
+ st.graphviz_chart(graph)
730
1012
  elif in_notebook():
731
1013
  from IPython.display import display
732
1014
  display(graph)
@@ -752,16 +1034,18 @@ class LocalGraph:
752
1034
  def _to_api_graph_definition(self) -> GraphDefinition:
753
1035
  tables: Dict[str, TableDefinition] = {}
754
1036
  col_groups: List[ColumnKeyGroup] = []
755
- for t_name, table in self.tables.items():
756
- tables[t_name] = table._to_api_table_definition()
1037
+ for table_name, table in self.tables.items():
1038
+ tables[table_name] = table._to_api_table_definition()
757
1039
  if table.primary_key is None:
758
1040
  continue
759
- keys = [ColumnKey(t_name, table.primary_key.name)]
1041
+ keys = [ColumnKey(table_name, table.primary_key.name)]
760
1042
  for edge in self.edges:
761
- if edge.dst_table == t_name:
1043
+ if edge.dst_table == table_name:
762
1044
  keys.append(ColumnKey(edge.src_table, edge.fkey))
763
- keys = sorted(list(set(keys)),
764
- key=lambda x: f'{x.table_name}.{x.col_name}')
1045
+ keys = sorted(
1046
+ list(set(keys)),
1047
+ key=lambda x: f'{x.table_name}.{x.col_name}',
1048
+ )
765
1049
  if len(keys) > 1:
766
1050
  col_groups.append(ColumnKeyGroup(keys))
767
1051
  return GraphDefinition(tables, col_groups)
@@ -774,7 +1058,7 @@ class LocalGraph:
774
1058
  def __contains__(self, name: str) -> bool:
775
1059
  return self.has_table(name)
776
1060
 
777
- def __getitem__(self, name: str) -> LocalTable:
1061
+ def __getitem__(self, name: str) -> Table:
778
1062
  return self.table(name)
779
1063
 
780
1064
  def __delitem__(self, name: str) -> None:
@@ -792,3 +1076,7 @@ class LocalGraph:
792
1076
  f' tables={tables},\n'
793
1077
  f' edges={edges},\n'
794
1078
  f')')
1079
+
1080
+ def __del__(self) -> None:
1081
+ if hasattr(self, '_connection'):
1082
+ self._connection.close()