kumoai 2.8.0.dev202508221830__cp312-cp312-win_amd64.whl → 2.13.0.dev202512041141__cp312-cp312-win_amd64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kumoai may be problematic; see the package's registry page for more details.

Files changed (52)
  1. kumoai/__init__.py +22 -11
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +17 -16
  4. kumoai/client/endpoints.py +1 -0
  5. kumoai/client/rfm.py +37 -8
  6. kumoai/connector/file_upload_connector.py +94 -85
  7. kumoai/connector/utils.py +1399 -210
  8. kumoai/experimental/rfm/__init__.py +164 -46
  9. kumoai/experimental/rfm/authenticate.py +8 -5
  10. kumoai/experimental/rfm/backend/__init__.py +0 -0
  11. kumoai/experimental/rfm/backend/local/__init__.py +38 -0
  12. kumoai/experimental/rfm/backend/local/table.py +109 -0
  13. kumoai/experimental/rfm/backend/snow/__init__.py +35 -0
  14. kumoai/experimental/rfm/backend/snow/table.py +117 -0
  15. kumoai/experimental/rfm/backend/sqlite/__init__.py +30 -0
  16. kumoai/experimental/rfm/backend/sqlite/table.py +101 -0
  17. kumoai/experimental/rfm/base/__init__.py +10 -0
  18. kumoai/experimental/rfm/base/column.py +66 -0
  19. kumoai/experimental/rfm/base/source.py +18 -0
  20. kumoai/experimental/rfm/base/table.py +545 -0
  21. kumoai/experimental/rfm/{local_graph.py → graph.py} +413 -144
  22. kumoai/experimental/rfm/infer/__init__.py +6 -0
  23. kumoai/experimental/rfm/infer/dtype.py +79 -0
  24. kumoai/experimental/rfm/infer/pkey.py +126 -0
  25. kumoai/experimental/rfm/infer/time_col.py +62 -0
  26. kumoai/experimental/rfm/infer/timestamp.py +7 -4
  27. kumoai/experimental/rfm/local_graph_sampler.py +58 -11
  28. kumoai/experimental/rfm/local_graph_store.py +45 -37
  29. kumoai/experimental/rfm/local_pquery_driver.py +342 -46
  30. kumoai/experimental/rfm/pquery/__init__.py +4 -4
  31. kumoai/experimental/rfm/pquery/{backend.py → executor.py} +28 -58
  32. kumoai/experimental/rfm/pquery/pandas_executor.py +532 -0
  33. kumoai/experimental/rfm/rfm.py +559 -148
  34. kumoai/experimental/rfm/sagemaker.py +138 -0
  35. kumoai/jobs.py +27 -1
  36. kumoai/kumolib.cp312-win_amd64.pyd +0 -0
  37. kumoai/pquery/prediction_table.py +5 -3
  38. kumoai/pquery/training_table.py +5 -3
  39. kumoai/spcs.py +1 -3
  40. kumoai/testing/decorators.py +1 -1
  41. kumoai/trainer/job.py +9 -30
  42. kumoai/trainer/trainer.py +19 -10
  43. kumoai/utils/__init__.py +2 -1
  44. kumoai/utils/progress_logger.py +96 -16
  45. {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/METADATA +14 -5
  46. {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/RECORD +49 -36
  47. kumoai/experimental/rfm/local_table.py +0 -448
  48. kumoai/experimental/rfm/pquery/pandas_backend.py +0 -437
  49. kumoai/experimental/rfm/utils.py +0 -347
  50. {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/WHEEL +0 -0
  51. {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/licenses/LICENSE +0 -0
  52. {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/top_level.txt +0 -0
@@ -2,8 +2,9 @@ import contextlib
2
2
  import io
3
3
  import warnings
4
4
  from collections import defaultdict
5
- from importlib.util import find_spec
6
- from typing import TYPE_CHECKING, Dict, List, Optional, Union
5
+ from dataclasses import dataclass, field
6
+ from pathlib import Path
7
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
7
8
 
8
9
  import pandas as pd
9
10
  from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
@@ -11,156 +12,385 @@ from kumoapi.table import TableDefinition
11
12
  from kumoapi.typing import Stype
12
13
  from typing_extensions import Self
13
14
 
14
- from kumoai import in_notebook
15
- from kumoai.experimental.rfm import LocalTable
15
+ from kumoai import in_notebook, in_snowflake_notebook
16
+ from kumoai.experimental.rfm import Table
16
17
  from kumoai.graph import Edge
18
+ from kumoai.mixin import CastMixin
17
19
 
18
20
  if TYPE_CHECKING:
19
21
  import graphviz
22
+ from adbc_driver_sqlite.dbapi import AdbcSqliteConnection
23
+ from snowflake.connector import SnowflakeConnection
20
24
 
21
25
 
22
- class LocalGraph:
23
- r"""A graph of :class:`LocalTable` objects, akin to relationships between
26
+ @dataclass
27
+ class SqliteConnectionConfig(CastMixin):
28
+ uri: Union[str, Path]
29
+ kwargs: Dict[str, Any] = field(default_factory=dict)
30
+
31
+
32
+ class Graph:
33
+ r"""A graph of :class:`Table` objects, akin to relationships between
24
34
  tables in a relational database.
25
35
 
26
36
  Creating a graph is the final step of data definition; after a
27
- :class:`LocalGraph` is created, you can use it to initialize the
37
+ :class:`Graph` is created, you can use it to initialize the
28
38
  Kumo Relational Foundation Model (:class:`KumoRFM`).
29
39
 
30
40
  .. code-block:: python
31
41
 
32
- import pandas as pd
33
- import kumoai.experimental.rfm as rfm
42
+ >>> # doctest: +SKIP
43
+ >>> import pandas as pd
44
+ >>> import kumoai.experimental.rfm as rfm
34
45
 
35
- # Load data frames into memory:
36
- df1 = pd.DataFrame(...)
37
- df2 = pd.DataFrame(...)
38
- df3 = pd.DataFrame(...)
46
+ >>> # Load data frames into memory:
47
+ >>> df1 = pd.DataFrame(...)
48
+ >>> df2 = pd.DataFrame(...)
49
+ >>> df3 = pd.DataFrame(...)
39
50
 
40
- # Define tables from data frames:
41
- table1 = rfm.LocalTable(name="table1", data=df1)
42
- table2 = rfm.LocalTable(name="table2", data=df2)
43
- table3 = rfm.LocalTable(name="table3", data=df3)
51
+ >>> # Define tables from data frames:
52
+ >>> table1 = rfm.LocalTable(name="table1", data=df1)
53
+ >>> table2 = rfm.LocalTable(name="table2", data=df2)
54
+ >>> table3 = rfm.LocalTable(name="table3", data=df3)
44
55
 
45
- # Create a graph from a dictionary of tables:
46
- graph = rfm.LocalGraph({
47
- "table1": table1,
48
- "table2": table2,
49
- "table3": table3,
50
- })
56
+ >>> # Create a graph from a dictionary of tables:
57
+ >>> graph = rfm.Graph({
58
+ ... "table1": table1,
59
+ ... "table2": table2,
60
+ ... "table3": table3,
61
+ ... })
51
62
 
52
- # Infer table metadata:
53
- graph.infer_metadata()
63
+ >>> # Infer table metadata:
64
+ >>> graph.infer_metadata()
54
65
 
55
- # Infer links/edges:
56
- graph.infer_links()
66
+ >>> # Infer links/edges:
67
+ >>> graph.infer_links()
57
68
 
58
- # Inspect table metadata:
59
- for table in graph.tables.values():
60
- table.print_metadata()
69
+ >>> # Inspect table metadata:
70
+ >>> for table in graph.tables.values():
71
+ ... table.print_metadata()
61
72
 
62
- # Visualize graph (if graphviz is installed):
63
- graph.visualize()
73
+ >>> # Visualize graph (if graphviz is installed):
74
+ >>> graph.visualize()
64
75
 
65
- # Add/Remove edges between tables:
66
- graph.link(src_table="table1", fkey="id1", dst_table="table2")
67
- graph.unlink(src_table="table1", fkey="id1", dst_table="table2")
76
+ >>> # Add/Remove edges between tables:
77
+ >>> graph.link(src_table="table1", fkey="id1", dst_table="table2")
78
+ >>> graph.unlink(src_table="table1", fkey="id1", dst_table="table2")
68
79
 
69
- # Validate graph:
70
- graph.validate()
80
+ >>> # Validate graph:
81
+ >>> graph.validate()
71
82
  """
72
83
 
73
84
  # Constructors ############################################################
74
85
 
75
86
  def __init__(
76
87
  self,
77
- tables: List[LocalTable],
78
- edges: Optional[List[Edge]] = None,
88
+ tables: Sequence[Table],
89
+ edges: Optional[Sequence[Edge]] = None,
79
90
  ) -> None:
80
91
 
81
- self._tables: Dict[str, LocalTable] = {}
92
+ self._tables: Dict[str, Table] = {}
82
93
  self._edges: List[Edge] = []
83
94
 
84
95
  for table in tables:
85
96
  self.add_table(table)
86
97
 
98
+ for table in tables:
99
+ for fkey in table._source_foreign_key_dict.values():
100
+ if fkey.name not in table or fkey.dst_table not in self:
101
+ continue
102
+ if self[fkey.dst_table].primary_key is None:
103
+ self[fkey.dst_table].primary_key = fkey.primary_key
104
+ elif self[fkey.dst_table]._primary_key != fkey.primary_key:
105
+ raise ValueError(f"Found duplicate primary key definition "
106
+ f"'{self[fkey.dst_table]._primary_key}' "
107
+ f"and '{fkey.primary_key}' in table "
108
+ f"'{fkey.dst_table}'.")
109
+ self.link(table.name, fkey.name, fkey.dst_table)
110
+
87
111
  for edge in (edges or []):
88
112
  _edge = Edge._cast(edge)
89
113
  assert _edge is not None
90
- self.link(*_edge)
114
+ if _edge not in self._edges:
115
+ self.link(*_edge)
91
116
 
92
117
  @classmethod
93
118
  def from_data(
94
119
  cls,
95
120
  df_dict: Dict[str, pd.DataFrame],
96
- edges: Optional[List[Edge]] = None,
121
+ edges: Optional[Sequence[Edge]] = None,
97
122
  infer_metadata: bool = True,
98
123
  verbose: bool = True,
99
124
  ) -> Self:
100
- r"""Creates a :class:`LocalGraph` from a dictionary of
125
+ r"""Creates a :class:`Graph` from a dictionary of
101
126
  :class:`pandas.DataFrame` objects.
102
127
 
103
- Automatically infers table metadata and links.
128
+ Automatically infers table metadata and links by default.
104
129
 
105
130
  .. code-block:: python
106
131
 
107
- import pandas as pd
108
- import kumoai.experimental.rfm as rfm
109
-
110
- # Load data frames into memory:
111
- df1 = pd.DataFrame(...)
112
- df2 = pd.DataFrame(...)
113
- df3 = pd.DataFrame(...)
114
-
115
- # Create a graph from a dictionary of data frames:
116
- graph = rfm.LocalGraph.from_data({
117
- "table1": df1,
118
- "table2": df2,
119
- "table3": df3,
120
- })
132
+ >>> # doctest: +SKIP
133
+ >>> import pandas as pd
134
+ >>> import kumoai.experimental.rfm as rfm
121
135
 
122
- # Inspect table metadata:
123
- for table in graph.tables.values():
124
- table.print_metadata()
136
+ >>> # Load data frames into memory:
137
+ >>> df1 = pd.DataFrame(...)
138
+ >>> df2 = pd.DataFrame(...)
139
+ >>> df3 = pd.DataFrame(...)
125
140
 
126
- # Visualize graph (if graphviz is installed):
127
- graph.visualize()
141
+ >>> # Create a graph from a dictionary of data frames:
142
+ >>> graph = rfm.Graph.from_data({
143
+ ... "table1": df1,
144
+ ... "table2": df2,
145
+ ... "table3": df3,
146
+ ... })
128
147
 
129
148
  Args:
130
149
  df_dict: A dictionary of data frames, where the keys are the names
131
150
  of the tables and the values hold table data.
151
+ edges: An optional list of :class:`~kumoai.graph.Edge` objects to
152
+ add to the graph. If not provided, edges will be automatically
153
+ inferred from the data in case ``infer_metadata=True``.
132
154
  infer_metadata: Whether to infer metadata for all tables in the
133
155
  graph.
156
+ verbose: Whether to print verbose output.
157
+ """
158
+ from kumoai.experimental.rfm.backend.local import LocalTable
159
+ tables = [LocalTable(df, name) for name, df in df_dict.items()]
160
+
161
+ graph = cls(tables, edges=edges or [])
162
+
163
+ if infer_metadata:
164
+ graph.infer_metadata(False)
165
+
166
+ if edges is None:
167
+ graph.infer_links(False)
168
+
169
+ if verbose:
170
+ graph.print_metadata()
171
+ graph.print_links()
172
+
173
+ return graph
174
+
175
+ @classmethod
176
+ def from_sqlite(
177
+ cls,
178
+ connection: Union[
179
+ 'AdbcSqliteConnection',
180
+ SqliteConnectionConfig,
181
+ str,
182
+ Path,
183
+ Dict[str, Any],
184
+ ],
185
+ table_names: Optional[Sequence[str]] = None,
186
+ edges: Optional[Sequence[Edge]] = None,
187
+ infer_metadata: bool = True,
188
+ verbose: bool = True,
189
+ ) -> Self:
190
+ r"""Creates a :class:`Graph` from a :class:`sqlite` database.
191
+
192
+ Automatically infers table metadata and links by default.
193
+
194
+ .. code-block:: python
195
+
196
+ >>> # doctest: +SKIP
197
+ >>> import kumoai.experimental.rfm as rfm
198
+
199
+ >>> # Create a graph from a SQLite database:
200
+ >>> graph = rfm.Graph.from_sqlite('data.db')
201
+
202
+ Args:
203
+ connection: An open connection from
204
+ :meth:`~kumoai.experimental.rfm.backend.sqlite.connect` or the
205
+ path to the database file.
206
+ table_names: Set of table names to include. If ``None``, will add
207
+ all tables present in the database.
134
208
  edges: An optional list of :class:`~kumoai.graph.Edge` objects to
135
209
  add to the graph. If not provided, edges will be automatically
136
- inferred from the data.
210
+ inferred from the data in case ``infer_metadata=True``.
211
+ infer_metadata: Whether to infer metadata for all tables in the
212
+ graph.
137
213
  verbose: Whether to print verbose output.
214
+ """
215
+ from kumoai.experimental.rfm.backend.sqlite import (
216
+ Connection,
217
+ SQLiteTable,
218
+ connect,
219
+ )
220
+
221
+ if not isinstance(connection, Connection):
222
+ connection = SqliteConnectionConfig._cast(connection)
223
+ assert isinstance(connection, SqliteConnectionConfig)
224
+ connection = connect(connection.uri, **connection.kwargs)
225
+ assert isinstance(connection, Connection)
226
+
227
+ if table_names is None:
228
+ with connection.cursor() as cursor:
229
+ cursor.execute("SELECT name FROM sqlite_master "
230
+ "WHERE type='table'")
231
+ table_names = [row[0] for row in cursor.fetchall()]
232
+
233
+ tables = [SQLiteTable(connection, name) for name in table_names]
138
234
 
139
- Note:
140
- This method will automatically infer metadata and links for the
141
- graph.
235
+ graph = cls(tables, edges=edges or [])
142
236
 
143
- Example:
237
+ if infer_metadata:
238
+ graph.infer_metadata(False)
239
+
240
+ if edges is None:
241
+ graph.infer_links(False)
242
+
243
+ if verbose:
244
+ graph.print_metadata()
245
+ graph.print_links()
246
+
247
+ return graph
248
+
249
+ @classmethod
250
+ def from_snowflake(
251
+ cls,
252
+ connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
253
+ database: Optional[str] = None,
254
+ schema: Optional[str] = None,
255
+ table_names: Optional[Sequence[str]] = None,
256
+ edges: Optional[Sequence[Edge]] = None,
257
+ infer_metadata: bool = True,
258
+ verbose: bool = True,
259
+ ) -> Self:
260
+ r"""Creates a :class:`Graph` from a :class:`snowflake` database and
261
+ schema.
262
+
263
+ Automatically infers table metadata and links by default.
264
+
265
+ .. code-block:: python
266
+
267
+ >>> # doctest: +SKIP
144
268
  >>> import kumoai.experimental.rfm as rfm
145
- >>> df1 = pd.DataFrame(...)
146
- >>> df2 = pd.DataFrame(...)
147
- >>> df3 = pd.DataFrame(...)
148
- >>> graph = rfm.LocalGraph.from_data(data={
149
- ... "table1": df1,
150
- ... "table2": df2,
151
- ... "table3": df3,
152
- ... })
153
- ... graph.validate()
269
+
270
+ >>> # Create a graph directly in a Snowflake notebook:
271
+ >>> graph = rfm.Graph.from_snowflake(schema='my_schema')
272
+
273
+ Args:
274
+ connection: An open connection from
275
+ :meth:`~kumoai.experimental.rfm.backend.snow.connect` or the
276
+ :class:`snowflake` connector keyword arguments to open a new
277
+ connection. If ``None``, will re-use an active session in case
278
+ it exists, or create a new connection from credentials stored
279
+ in environment variables.
280
+ database: The database.
281
+ schema: The schema.
282
+ table_names: Set of table names to include. If ``None``, will add
283
+ all tables present in the database.
284
+ edges: An optional list of :class:`~kumoai.graph.Edge` objects to
285
+ add to the graph. If not provided, edges will be automatically
286
+ inferred from the data in case ``infer_metadata=True``.
287
+ infer_metadata: Whether to infer metadata for all tables in the
288
+ graph.
289
+ verbose: Whether to print verbose output.
154
290
  """
155
- tables = [LocalTable(df, name) for name, df in df_dict.items()]
291
+ from kumoai.experimental.rfm.backend.snow import (
292
+ Connection,
293
+ SnowTable,
294
+ connect,
295
+ )
296
+
297
+ if not isinstance(connection, Connection):
298
+ connection = connect(**(connection or {}))
299
+ assert isinstance(connection, Connection)
300
+
301
+ if table_names is None:
302
+ with connection.cursor() as cursor:
303
+ if database is None and schema is None:
304
+ cursor.execute("SELECT CURRENT_DATABASE(), "
305
+ "CURRENT_SCHEMA()")
306
+ result = cursor.fetchone()
307
+ database = database or result[0]
308
+ schema = schema or result[1]
309
+ cursor.execute(f"""
310
+ SELECT TABLE_NAME
311
+ FROM {database}.INFORMATION_SCHEMA.TABLES
312
+ WHERE TABLE_SCHEMA = '{schema}'
313
+ """)
314
+ table_names = [row[0] for row in cursor.fetchall()]
315
+
316
+ tables = [
317
+ SnowTable(
318
+ connection,
319
+ name=table_name,
320
+ database=database,
321
+ schema=schema,
322
+ ) for table_name in table_names
323
+ ]
156
324
 
157
325
  graph = cls(tables, edges=edges or [])
158
326
 
159
327
  if infer_metadata:
160
- graph.infer_metadata(verbose)
328
+ graph.infer_metadata(False)
161
329
 
162
330
  if edges is None:
163
- graph.infer_links(verbose)
331
+ graph.infer_links(False)
332
+
333
+ if verbose:
334
+ graph.print_metadata()
335
+ graph.print_links()
336
+
337
+ return graph
338
+
339
+ @classmethod
340
+ def from_snowflake_semantic_view(
341
+ cls,
342
+ semantic_view_name: str,
343
+ connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
344
+ verbose: bool = True,
345
+ ) -> Self:
346
+ import yaml
347
+
348
+ from kumoai.experimental.rfm.backend.snow import (
349
+ Connection,
350
+ SnowTable,
351
+ connect,
352
+ )
353
+
354
+ if not isinstance(connection, Connection):
355
+ connection = connect(**(connection or {}))
356
+ assert isinstance(connection, Connection)
357
+
358
+ with connection.cursor() as cursor:
359
+ cursor.execute(f"SELECT SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW("
360
+ f"'{semantic_view_name}')")
361
+ view = yaml.safe_load(cursor.fetchone()[0])
362
+
363
+ graph = cls(tables=[])
364
+
365
+ for table_desc in view['tables']:
366
+ primary_key: Optional[str] = None
367
+ if ('primary_key' in table_desc # NOTE No composite keys yet.
368
+ and len(table_desc['primary_key']['columns']) == 1):
369
+ primary_key = table_desc['primary_key']['columns'][0]
370
+
371
+ table = SnowTable(
372
+ connection,
373
+ name=table_desc['base_table']['table'],
374
+ database=table_desc['base_table']['database'],
375
+ schema=table_desc['base_table']['schema'],
376
+ primary_key=primary_key,
377
+ )
378
+ graph.add_table(table)
379
+
380
+ # TODO Find a solution to register time columns!
381
+
382
+ for relations in view['relationships']:
383
+ if len(relations['relationship_columns']) != 1:
384
+ continue # NOTE No composite keys yet.
385
+ graph.link(
386
+ src_table=relations['left_table'],
387
+ fkey=relations['relationship_columns'][0]['left_column'],
388
+ dst_table=relations['right_table'],
389
+ )
390
+
391
+ if verbose:
392
+ graph.print_metadata()
393
+ graph.print_links()
164
394
 
165
395
  return graph
166
396
 
@@ -172,7 +402,7 @@ class LocalGraph:
172
402
  """
173
403
  return name in self.tables
174
404
 
175
- def table(self, name: str) -> LocalTable:
405
+ def table(self, name: str) -> Table:
176
406
  r"""Returns the table with name ``name`` in the graph.
177
407
 
178
408
  Raises:
@@ -183,11 +413,11 @@ class LocalGraph:
183
413
  return self.tables[name]
184
414
 
185
415
  @property
186
- def tables(self) -> Dict[str, LocalTable]:
416
+ def tables(self) -> Dict[str, Table]:
187
417
  r"""Returns the dictionary of table objects."""
188
418
  return self._tables
189
419
 
190
- def add_table(self, table: LocalTable) -> Self:
420
+ def add_table(self, table: Table) -> Self:
191
421
  r"""Adds a table to the graph.
192
422
 
193
423
  Args:
@@ -196,17 +426,21 @@ class LocalGraph:
196
426
  Raises:
197
427
  KeyError: If a table with the same name already exists in the
198
428
  graph.
429
+ ValueError: If the table belongs to a different backend than the
430
+ rest of the tables in the graph.
199
431
  """
200
- if len(self.tables) >= 15:
201
- raise ValueError("Cannot create a graph with more than 15 "
202
- "tables. Please create a feature request at "
203
- "'https://github.com/kumo-ai/kumo-rfm' if you "
204
- "must go beyond this for your use-case.")
205
-
206
432
  if table.name in self._tables:
207
433
  raise KeyError(f"Cannot add table with name '{table.name}' to "
208
434
  f"this graph; table names must be globally unique.")
209
435
 
436
+ if len(self._tables) > 0:
437
+ cls = next(iter(self._tables.values())).__class__
438
+ if table.__class__ != cls:
439
+ raise ValueError(f"Cannot register a "
440
+ f"'{table.__class__.__name__}' to this "
441
+ f"graph since other tables are of type "
442
+ f"'{cls.__name__}'.")
443
+
210
444
  self._tables[table.name] = table
211
445
 
212
446
  return self
@@ -237,16 +471,17 @@ class LocalGraph:
237
471
  r"""Returns a :class:`pandas.DataFrame` object containing metadata
238
472
  information about the tables in this graph.
239
473
 
240
- The returned dataframe has columns ``name``, ``primary_key``, and
241
- ``time_column``, which provide an aggregate view of the properties of
242
- the tables of this graph.
474
+ The returned dataframe has columns ``name``, ``primary_key``,
475
+ ``time_column``, and ``end_time_column``, which provide an aggregate
476
+ view of the properties of the tables of this graph.
243
477
 
244
478
  Example:
479
+ >>> # doctest: +SKIP
245
480
  >>> import kumoai.experimental.rfm as rfm
246
- >>> graph = rfm.LocalGraph(tables=...).infer_metadata()
247
- >>> graph.metadata
248
- name primary_key time_column
249
- 0 users user_id -
481
+ >>> graph = rfm.Graph(tables=...).infer_metadata()
482
+ >>> graph.metadata # doctest: +SKIP
483
+ name primary_key time_column end_time_column
484
+ 0 users user_id - -
250
485
  """
251
486
  tables = list(self.tables.values())
252
487
 
@@ -257,13 +492,22 @@ class LocalGraph:
257
492
  pd.Series(dtype=str, data=[t._primary_key or '-' for t in tables]),
258
493
  'time_column':
259
494
  pd.Series(dtype=str, data=[t._time_column or '-' for t in tables]),
495
+ 'end_time_column':
496
+ pd.Series(
497
+ dtype=str,
498
+ data=[t._end_time_column or '-' for t in tables],
499
+ ),
260
500
  })
261
501
 
262
502
  def print_metadata(self) -> None:
263
- r"""Prints the :meth:`~LocalGraph.metadata` of the graph."""
264
- if in_notebook():
503
+ r"""Prints the :meth:`~Graph.metadata` of the graph."""
504
+ if in_snowflake_notebook():
505
+ import streamlit as st
506
+ st.markdown("### 🗂️ Graph Metadata")
507
+ st.dataframe(self.metadata, hide_index=True)
508
+ elif in_notebook():
265
509
  from IPython.display import Markdown, display
266
- display(Markdown('### 🗂️ Graph Metadata'))
510
+ display(Markdown("### 🗂️ Graph Metadata"))
267
511
  df = self.metadata
268
512
  try:
269
513
  if hasattr(df.style, 'hide'):
@@ -284,7 +528,7 @@ class LocalGraph:
284
528
 
285
529
  Note:
286
530
  For more information, please see
287
- :meth:`kumoai.experimental.rfm.LocalTable.infer_metadata`.
531
+ :meth:`kumoai.experimental.rfm.Table.infer_metadata`.
288
532
  """
289
533
  for table in self.tables.values():
290
534
  table.infer_metadata(verbose=False)
@@ -302,37 +546,47 @@ class LocalGraph:
302
546
  return self._edges
303
547
 
304
548
  def print_links(self) -> None:
305
- r"""Prints the :meth:`~LocalGraph.edges` of the graph."""
549
+ r"""Prints the :meth:`~Graph.edges` of the graph."""
306
550
  edges = [(edge.dst_table, self[edge.dst_table]._primary_key,
307
551
  edge.src_table, edge.fkey) for edge in self.edges]
308
552
  edges = sorted(edges)
309
553
 
310
- if in_notebook():
554
+ if in_snowflake_notebook():
555
+ import streamlit as st
556
+ st.markdown("### 🕸️ Graph Links (FK ↔️ PK)")
557
+ if len(edges) > 0:
558
+ st.markdown('\n'.join([
559
+ f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
560
+ for edge in edges
561
+ ]))
562
+ else:
563
+ st.markdown("*No links registered*")
564
+ elif in_notebook():
311
565
  from IPython.display import Markdown, display
312
- display(Markdown('### 🕸️ Graph Links (FK ↔️ PK)'))
566
+ display(Markdown("### 🕸️ Graph Links (FK ↔️ PK)"))
313
567
  if len(edges) > 0:
314
568
  display(
315
569
  Markdown('\n'.join([
316
- f'- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`'
570
+ f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
317
571
  for edge in edges
318
572
  ])))
319
573
  else:
320
- display(Markdown('*No links registered*'))
574
+ display(Markdown("*No links registered*"))
321
575
  else:
322
576
  print("🕸️ Graph Links (FK ↔️ PK):")
323
577
  if len(edges) > 0:
324
578
  print('\n'.join([
325
- f'• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}'
579
+ f"• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
326
580
  for edge in edges
327
581
  ]))
328
582
  else:
329
- print('No links registered')
583
+ print("No links registered")
330
584
 
331
585
  def link(
332
586
  self,
333
- src_table: Union[str, LocalTable],
587
+ src_table: Union[str, Table],
334
588
  fkey: str,
335
- dst_table: Union[str, LocalTable],
589
+ dst_table: Union[str, Table],
336
590
  ) -> Self:
337
591
  r"""Links two tables (``src_table`` and ``dst_table``) from the foreign
338
592
  key ``fkey`` in the source table to the primary key in the destination
@@ -355,11 +609,11 @@ class LocalGraph:
355
609
  table does not exist in the graph, if the source key does not
356
610
  exist in the source table.
357
611
  """
358
- if isinstance(src_table, LocalTable):
612
+ if isinstance(src_table, Table):
359
613
  src_table = src_table.name
360
614
  assert isinstance(src_table, str)
361
615
 
362
- if isinstance(dst_table, LocalTable):
616
+ if isinstance(dst_table, Table):
363
617
  dst_table = dst_table.name
364
618
  assert isinstance(dst_table, str)
365
619
 
@@ -393,9 +647,9 @@ class LocalGraph:
393
647
 
394
648
  def unlink(
395
649
  self,
396
- src_table: Union[str, LocalTable],
650
+ src_table: Union[str, Table],
397
651
  fkey: str,
398
- dst_table: Union[str, LocalTable],
652
+ dst_table: Union[str, Table],
399
653
  ) -> Self:
400
654
  r"""Removes an :class:`~kumoai.graph.Edge` from the graph.
401
655
 
@@ -407,11 +661,11 @@ class LocalGraph:
407
661
  Raises:
408
662
  ValueError: if the edge is not present in the graph.
409
663
  """
410
- if isinstance(src_table, LocalTable):
664
+ if isinstance(src_table, Table):
411
665
  src_table = src_table.name
412
666
  assert isinstance(src_table, str)
413
667
 
414
- if isinstance(dst_table, LocalTable):
668
+ if isinstance(dst_table, Table):
415
669
  dst_table = dst_table.name
416
670
  assert isinstance(dst_table, str)
417
671
 
@@ -425,17 +679,13 @@ class LocalGraph:
425
679
  return self
426
680
 
427
681
  def infer_links(self, verbose: bool = True) -> Self:
428
- r"""Infers links for the tables and adds them as edges to the graph.
682
+ r"""Infers missing links for the tables and adds them as edges to the
683
+ graph.
429
684
 
430
685
  Args:
431
686
  verbose: Whether to print verbose output.
432
-
433
- Note:
434
- This function expects graph edges to be undefined upfront.
435
687
  """
436
- if len(self.edges) > 0:
437
- warnings.warn("Cannot infer links if graph edges already exist")
438
- return self
688
+ known_edges = {(edge.src_table, edge.fkey) for edge in self.edges}
439
689
 
440
690
  # A list of primary key candidates (+score) for every column:
441
691
  candidate_dict: dict[
@@ -460,6 +710,9 @@ class LocalGraph:
460
710
  src_table_name = src_table.name.lower()
461
711
 
462
712
  for src_key in src_table.columns:
713
+ if (src_table.name, src_key.name) in known_edges:
714
+ continue
715
+
463
716
  if src_key == src_table.primary_key:
464
717
  continue # Cannot link to primary key.
465
718
 
@@ -525,7 +778,9 @@ class LocalGraph:
525
778
  score += 1.0
526
779
 
527
780
  # Cardinality ratio:
528
- if len(src_table._data) > len(dst_table._data):
781
+ if (src_table._num_rows is not None
782
+ and dst_table._num_rows is not None
783
+ and src_table._num_rows > dst_table._num_rows):
529
784
  score += 1.0
530
785
 
531
786
  if score < 5.0:
@@ -580,13 +835,17 @@ class LocalGraph:
580
835
  # Check that the destination table defines a primary key:
581
836
  if dst_key is None:
582
837
  raise ValueError(f"Edge {edge} is invalid since table "
583
- f"'{dst_table}' does not have a primary key")
838
+ f"'{dst_table}' does not have a primary key. "
839
+ f"Add either a primary key or remove the "
840
+ f"link before proceeding.")
584
841
 
585
842
  # Ensure that foreign key is not a primary key:
586
843
  src_pkey = self[src_table].primary_key
587
844
  if src_pkey is not None and src_pkey.name == fkey:
588
845
  raise ValueError(f"Cannot treat the primary key of table "
589
- f"'{src_table}' as a foreign key")
846
+ f"'{src_table}' as a foreign key. Remove "
847
+ f"either the primary key or the link before "
848
+ f"before proceeding.")
590
849
 
591
850
  # Check that fkey/pkey have valid and consistent data types:
592
851
  assert src_key.dtype is not None
@@ -604,8 +863,8 @@ class LocalGraph:
604
863
  raise ValueError(f"{edge} is invalid as foreign key "
605
864
  f"'{fkey}' and primary key '{dst_key.name}' "
606
865
  f"have incompatible data types (got "
607
- f"fkey.dtype '{dst_key.dtype}' and "
608
- f"pkey.dtype '{src_key.dtype}')")
866
+ f"fkey.dtype '{src_key.dtype}' and "
867
+ f"pkey.dtype '{dst_key.dtype}')")
609
868
 
610
869
  return self
611
870
 
@@ -638,19 +897,19 @@ class LocalGraph:
638
897
 
639
898
  return True
640
899
 
641
- # Check basic dependency:
642
- if not find_spec('graphviz'):
643
- raise ModuleNotFoundError("The 'graphviz' package is required for "
644
- "visualization")
645
- elif not has_graphviz_executables():
900
+ try: # Check basic dependency:
901
+ import graphviz
902
+ except ImportError as e:
903
+ raise ImportError("The 'graphviz' package is required for "
904
+ "visualization") from e
905
+
906
+ if not in_snowflake_notebook() and not has_graphviz_executables():
646
907
  raise RuntimeError("Could not visualize graph as 'graphviz' "
647
908
  "executables are not installed. These "
648
909
  "dependencies are required in addition to the "
649
910
  "'graphviz' Python package. Please install "
650
911
  "them as described at "
651
912
  "https://graphviz.org/download/.")
652
- else:
653
- import graphviz
654
913
 
655
914
  format: Optional[str] = None
656
915
  if isinstance(path, str):
@@ -678,6 +937,11 @@ class LocalGraph:
678
937
  ]
679
938
  if time_column := table.time_column:
680
939
  keys += [f'{time_column.name}: Time ({time_column.dtype})']
940
+ if end_time_column := table.end_time_column:
941
+ keys += [
942
+ f'{end_time_column.name}: '
943
+ f'End Time ({end_time_column.dtype})'
944
+ ]
681
945
  key_repr = left_align(keys)
682
946
 
683
947
  columns = []
@@ -685,9 +949,9 @@ class LocalGraph:
685
949
  columns += [
686
950
  f'{column.name}: {column.stype} ({column.dtype})'
687
951
  for column in table.columns
688
- if column.name not in fkeys_dict[table_name]
689
- and column.name != table._primary_key
690
- and column.name != table._time_column
952
+ if column.name not in fkeys_dict[table_name] and
953
+ column.name != table._primary_key and column.name != table.
954
+ _time_column and column.name != table._end_time_column
691
955
  ]
692
956
  column_repr = left_align(columns)
693
957
 
@@ -729,6 +993,9 @@ class LocalGraph:
729
993
  graph.render(path, cleanup=True)
730
994
  elif isinstance(path, io.BytesIO):
731
995
  path.write(graph.pipe())
996
+ elif in_snowflake_notebook():
997
+ import streamlit as st
998
+ st.graphviz_chart(graph)
732
999
  elif in_notebook():
733
1000
  from IPython.display import display
734
1001
  display(graph)
@@ -754,16 +1021,18 @@ class LocalGraph:
754
1021
  def _to_api_graph_definition(self) -> GraphDefinition:
755
1022
  tables: Dict[str, TableDefinition] = {}
756
1023
  col_groups: List[ColumnKeyGroup] = []
757
- for t_name, table in self.tables.items():
758
- tables[t_name] = table._to_api_table_definition()
1024
+ for table_name, table in self.tables.items():
1025
+ tables[table_name] = table._to_api_table_definition()
759
1026
  if table.primary_key is None:
760
1027
  continue
761
- keys = [ColumnKey(t_name, table.primary_key.name)]
1028
+ keys = [ColumnKey(table_name, table.primary_key.name)]
762
1029
  for edge in self.edges:
763
- if edge.dst_table == t_name:
1030
+ if edge.dst_table == table_name:
764
1031
  keys.append(ColumnKey(edge.src_table, edge.fkey))
765
- keys = sorted(list(set(keys)),
766
- key=lambda x: f'{x.table_name}.{x.col_name}')
1032
+ keys = sorted(
1033
+ list(set(keys)),
1034
+ key=lambda x: f'{x.table_name}.{x.col_name}',
1035
+ )
767
1036
  if len(keys) > 1:
768
1037
  col_groups.append(ColumnKeyGroup(keys))
769
1038
  return GraphDefinition(tables, col_groups)
@@ -776,7 +1045,7 @@ class LocalGraph:
776
1045
  def __contains__(self, name: str) -> bool:
777
1046
  return self.has_table(name)
778
1047
 
779
- def __getitem__(self, name: str) -> LocalTable:
1048
+ def __getitem__(self, name: str) -> Table:
780
1049
  return self.table(name)
781
1050
 
782
1051
  def __delitem__(self, name: str) -> None: