kumoai 2.13.0.dev202512021731__cp310-cp310-win_amd64.whl → 2.13.0.dev202512041731__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,18 @@
1
+ from dataclasses import dataclass
2
+
3
+ from kumoapi.typing import Dtype
4
+
5
+
6
@dataclass
class SourceColumn:
    r"""Plain record describing one column of an underlying source table.

    Instances carry no behavior and perform no validation; they only bundle
    the column name, its data type and its key flags.
    """
    # Name of the column:
    name: str
    # Data type of the column:
    dtype: Dtype
    # Whether the column is flagged as a primary key:
    is_primary_key: bool
    # Whether the column is flagged as a unique key:
    is_unique_key: bool
12
+
13
+
14
@dataclass
class SourceForeignKey:
    r"""Plain record describing one foreign key of an underlying source table.

    Instances carry no behavior and perform no validation.
    """
    # Name of the foreign key column:
    name: str
    # Name of the table the foreign key points to:
    dst_table: str
    # Name of the primary key column referenced in ``dst_table``:
    primary_key: str
@@ -1,15 +1,25 @@
1
1
  from abc import ABC, abstractmethod
2
- from typing import Dict, List, Optional, Sequence, Tuple
2
+ from collections import defaultdict
3
+ from functools import cached_property
4
+ from typing import Dict, List, Optional, Sequence, Set
3
5
 
4
6
  import pandas as pd
5
7
  from kumoapi.source_table import UnavailableSourceTable
6
8
  from kumoapi.table import Column as ColumnDefinition
7
9
  from kumoapi.table import TableDefinition
8
- from kumoapi.typing import Dtype, Stype
10
+ from kumoapi.typing import Stype
9
11
  from typing_extensions import Self
10
12
 
11
- from kumoai import in_notebook
12
- from kumoai.experimental.rfm.base import Column
13
+ from kumoai import in_notebook, in_snowflake_notebook
14
+ from kumoai.experimental.rfm.base import Column, SourceColumn, SourceForeignKey
15
+ from kumoai.experimental.rfm.infer import (
16
+ contains_categorical,
17
+ contains_id,
18
+ contains_multicategorical,
19
+ contains_timestamp,
20
+ infer_primary_key,
21
+ infer_time_column,
22
+ )
13
23
 
14
24
 
15
25
  class Table(ABC):
@@ -39,8 +49,30 @@ class Table(ABC):
39
49
  self._time_column: Optional[str] = None
40
50
  self._end_time_column: Optional[str] = None
41
51
 
52
+ if len(self._source_column_dict) == 0:
53
+ raise ValueError(f"Table '{name}' does not hold any column with "
54
+ f"a supported data type")
55
+
56
+ primary_keys = [
57
+ column.name for column in self._source_column_dict.values()
58
+ if column.is_primary_key
59
+ ]
60
+ if len(primary_keys) == 1: # NOTE No composite keys yet.
61
+ if primary_key is not None and primary_key != primary_keys[0]:
62
+ raise ValueError(f"Found duplicate primary key "
63
+ f"definition '{primary_key}' and "
64
+ f"'{primary_keys[0]}' in table '{name}'")
65
+ primary_key = primary_keys[0]
66
+
67
+ unique_keys = [
68
+ column.name for column in self._source_column_dict.values()
69
+ if column.is_unique_key
70
+ ]
71
+ if primary_key is None and len(unique_keys) == 1:
72
+ primary_key = unique_keys[0]
73
+
42
74
  self._columns: Dict[str, Column] = {}
43
- for column_name in columns or []:
75
+ for column_name in columns or list(self._source_column_dict.keys()):
44
76
  self.add_column(column_name)
45
77
 
46
78
  if primary_key is not None:
@@ -104,12 +136,12 @@ class Table(ABC):
104
136
  raise KeyError(f"Column '{name}' already exists in table "
105
137
  f"'{self.name}'")
106
138
 
107
- if not self._has_source_column(name):
139
+ if name not in self._source_column_dict:
108
140
  raise KeyError(f"Column '{name}' does not exist in the underlying "
109
141
  f"source table")
110
142
 
111
143
  try:
112
- dtype = self._get_source_dtype(name)
144
+ dtype = self._source_column_dict[name].dtype
113
145
  except Exception as e:
114
146
  raise RuntimeError(f"Could not obtain data type for column "
115
147
  f"'{name}' in table '{self.name}'. Change "
@@ -117,7 +149,17 @@ class Table(ABC):
117
149
  f"table or remove it from the table.") from e
118
150
 
119
151
  try:
120
- stype = self._get_source_stype(name, dtype)
152
+ ser = self._sample_df[name]
153
+ if contains_id(ser, name, dtype):
154
+ stype = Stype.ID
155
+ elif contains_timestamp(ser, name, dtype):
156
+ stype = Stype.timestamp
157
+ elif contains_multicategorical(ser, name, dtype):
158
+ stype = Stype.multicategorical
159
+ elif contains_categorical(ser, name, dtype):
160
+ stype = Stype.categorical
161
+ else:
162
+ stype = dtype.default_stype
121
163
  except Exception as e:
122
164
  raise RuntimeError(f"Could not obtain semantic type for column "
123
165
  f"'{name}' in table '{self.name}'. Change "
@@ -338,10 +380,16 @@ class Table(ABC):
338
380
 
339
381
  def print_metadata(self) -> None:
340
382
  r"""Prints the :meth:`~metadata` of this table."""
341
- num_rows = self._num_rows()
342
- num_rows_repr = ' ({num_rows:,} rows)' if num_rows is not None else ''
383
+ num_rows_repr = ''
384
+ if self._num_rows is not None:
385
+ num_rows_repr = ' ({self._num_rows:,} rows)'
343
386
 
344
- if in_notebook():
387
+ if in_snowflake_notebook():
388
+ import streamlit as st
389
+ md_repr = f"### 🏷️ Metadata of Table `{self.name}`{num_rows_repr}"
390
+ st.markdown(md_repr)
391
+ st.dataframe(self.metadata, hide_index=True)
392
+ elif in_notebook():
345
393
  from IPython.display import Markdown, display
346
394
  md_repr = f"### 🏷️ Metadata of Table `{self.name}`{num_rows_repr}"
347
395
  display(Markdown(md_repr))
@@ -384,7 +432,11 @@ class Table(ABC):
384
432
  column.name for column in self.columns if is_candidate(column)
385
433
  ]
386
434
 
387
- if primary_key := self._infer_primary_key(candidates):
435
+ if primary_key := infer_primary_key(
436
+ table_name=self.name,
437
+ df=self._sample_df,
438
+ candidates=candidates,
439
+ ):
388
440
  self.primary_key = primary_key
389
441
  logs.append(f"primary key '{primary_key}'")
390
442
 
@@ -395,7 +447,10 @@ class Table(ABC):
395
447
  if column.stype == Stype.timestamp
396
448
  and column.name != self._end_time_column
397
449
  ]
398
- if time_column := self._infer_time_column(candidates):
450
+ if time_column := infer_time_column(
451
+ df=self._sample_df,
452
+ candidates=candidates,
453
+ ):
399
454
  self.time_column = time_column
400
455
  logs.append(f"time column '{time_column}'")
401
456
 
@@ -446,32 +501,45 @@ class Table(ABC):
446
501
  f' end_time_column={self._end_time_column},\n'
447
502
  f')')
448
503
 
449
- # Abstract method #########################################################
504
+ # Abstract Methods ########################################################
450
505
 
451
- @abstractmethod
452
- def _has_source_column(self, name: str) -> bool:
453
- pass
506
@cached_property
def _source_column_dict(self) -> Dict[str, SourceColumn]:
    r"""Maps each source column name to its :class:`SourceColumn`.

    The mapping is built from :meth:`_get_source_columns` on first access
    and cached afterwards. If two columns share a name, the later one wins.
    """
    lookup: Dict[str, SourceColumn] = {}
    for source_column in self._get_source_columns():
        lookup[source_column.name] = source_column
    return lookup
454
509
 
455
510
  @abstractmethod
456
- def _get_source_dtype(self, name: str) -> Dtype:
511
+ def _get_source_columns(self) -> List[SourceColumn]:
457
512
  pass
458
513
 
459
- @abstractmethod
460
- def _get_source_stype(self, name: str, dtype: Dtype) -> Stype:
461
- pass
514
@cached_property
def _source_foreign_key_dict(self) -> Dict[str, SourceForeignKey]:
    r"""Maps each usable foreign key name to its :class:`SourceForeignKey`.

    The foreign keys are read from :meth:`_get_source_foreign_keys` on first
    access and the result is cached afterwards.
    """
    foreign_keys = self._get_source_foreign_keys()

    # Collect, per destination table, every primary key that is referenced:
    pkeys_per_table: Dict[str, Set[str]] = defaultdict(set)
    for foreign_key in foreign_keys:
        pkeys_per_table[foreign_key.dst_table].add(foreign_key.primary_key)

    # NOTE Composite keys are not supported yet, so every foreign key whose
    # destination table is referenced through more than one primary key is
    # dropped:
    usable: Dict[str, SourceForeignKey] = {}
    for foreign_key in foreign_keys:
        if len(pkeys_per_table[foreign_key.dst_table]) != 1:
            continue
        usable[foreign_key.name] = foreign_key
    return usable
462
526
 
463
527
  @abstractmethod
464
- def _get_source_foreign_keys(self) -> List[Tuple[str, str, str]]:
528
+ def _get_source_foreign_keys(self) -> List[SourceForeignKey]:
465
529
  pass
466
530
 
467
- @abstractmethod
468
- def _infer_primary_key(self, candidates: List[str]) -> Optional[str]:
469
- pass
531
@cached_property
def _sample_df(self) -> pd.DataFrame:
    r"""Returns the data frame produced by :meth:`_get_sample_df`.

    The data frame is fetched on first access and cached afterwards.
    """
    sample = self._get_sample_df()
    return sample
470
534
 
471
535
  @abstractmethod
472
- def _infer_time_column(self, candidates: List[str]) -> Optional[str]:
536
+ def _get_sample_df(self) -> pd.DataFrame:
473
537
  pass
474
538
 
475
- @abstractmethod
539
@cached_property
def _num_rows(self) -> Optional[int]:
    r"""Returns the value produced by :meth:`_get_num_rows`.

    The value is fetched on first access and cached afterwards.
    """
    row_count = self._get_num_rows()
    return row_count
542
+
543
@abstractmethod
def _get_num_rows(self) -> Optional[int]:
    r"""Abstract hook that backs the cached :obj:`_num_rows` property.

    Implementations return an integer row count or ``None``; this base
    definition has no behavior of its own.
    """
@@ -2,7 +2,8 @@ import contextlib
2
2
  import io
3
3
  import warnings
4
4
  from collections import defaultdict
5
- from importlib.util import find_spec
5
+ from dataclasses import dataclass, field
6
+ from pathlib import Path
6
7
  from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
7
8
 
8
9
  import pandas as pd
@@ -11,12 +12,21 @@ from kumoapi.table import TableDefinition
11
12
  from kumoapi.typing import Stype
12
13
  from typing_extensions import Self
13
14
 
14
- from kumoai import in_notebook
15
+ from kumoai import in_notebook, in_snowflake_notebook
15
16
  from kumoai.experimental.rfm import Table
16
17
  from kumoai.graph import Edge
18
+ from kumoai.mixin import CastMixin
17
19
 
18
20
  if TYPE_CHECKING:
19
21
  import graphviz
22
+ from adbc_driver_sqlite.dbapi import AdbcSqliteConnection
23
+ from snowflake.connector import SnowflakeConnection
24
+
25
+
26
@dataclass
class SqliteConnectionConfig(CastMixin):
    r"""Settings needed to open a connection to a SQLite database.

    :meth:`Graph.from_sqlite` casts non-connection inputs to this class and
    then opens the connection via ``connect(uri, **kwargs)``.
    """
    # Location of the database, given as a string or a path object:
    uri: Union[str, Path]
    # Extra keyword arguments forwarded to ``connect``; every instance gets
    # its own fresh dictionary by default:
    kwargs: Dict[str, Any] = field(default_factory=dict)
20
30
 
21
31
 
22
32
  class Graph:
@@ -86,14 +96,17 @@ class Graph:
86
96
  self.add_table(table)
87
97
 
88
98
  for table in tables:
89
- for fkey, dst_table, pkey in table._get_source_foreign_keys():
90
- if self[dst_table].primary_key is None:
91
- self[dst_table].primary_key = pkey
92
- elif self[dst_table]._primary_key != pkey:
99
+ for fkey in table._source_foreign_key_dict.values():
100
+ if fkey.name not in table or fkey.dst_table not in self:
101
+ continue
102
+ if self[fkey.dst_table].primary_key is None:
103
+ self[fkey.dst_table].primary_key = fkey.primary_key
104
+ elif self[fkey.dst_table]._primary_key != fkey.primary_key:
93
105
  raise ValueError(f"Found duplicate primary key definition "
94
- f"'{self[dst_table]._primary_key}' and "
95
- f"'{pkey}' in table '{dst_table}'.")
96
- self.link(table.name, fkey, dst_table)
106
+ f"'{self[fkey.dst_table]._primary_key}' "
107
+ f"and '{fkey.primary_key}' in table "
108
+ f"'{fkey.dst_table}'.")
109
+ self.link(table.name, fkey.name, fkey.dst_table)
97
110
 
98
111
  for edge in (edges or []):
99
112
  _edge = Edge._cast(edge)
@@ -132,13 +145,6 @@ class Graph:
132
145
  ... "table3": df3,
133
146
  ... })
134
147
 
135
- >>> # Inspect table metadata:
136
- >>> for table in graph.tables.values():
137
- ... table.print_metadata()
138
-
139
- >>> # Visualize graph (if graphviz is installed):
140
- >>> graph.visualize()
141
-
142
148
  Args:
143
149
  df_dict: A dictionary of data frames, where the keys are the names
144
150
  of the tables and the values hold table data.
@@ -169,12 +175,17 @@ class Graph:
169
175
  @classmethod
170
176
  def from_sqlite(
171
177
  cls,
172
- uri: Any,
178
+ connection: Union[
179
+ 'AdbcSqliteConnection',
180
+ SqliteConnectionConfig,
181
+ str,
182
+ Path,
183
+ Dict[str, Any],
184
+ ],
173
185
  table_names: Optional[Sequence[str]] = None,
174
186
  edges: Optional[Sequence[Edge]] = None,
175
187
  infer_metadata: bool = True,
176
188
  verbose: bool = True,
177
- conn_kwargs: Optional[Dict[str, Any]] = None,
178
189
  ) -> Self:
179
190
  r"""Creates a :class:`Graph` from a :class:`sqlite` database.
180
191
 
@@ -188,16 +199,10 @@ class Graph:
188
199
  >>> # Create a graph from a SQLite database:
189
200
  >>> graph = rfm.Graph.from_sqlite('data.db')
190
201
 
191
- >>> # Inspect table metadata:
192
- >>> for table in graph.tables.values():
193
- ... table.print_metadata()
194
-
195
- >>> # Visualize graph (if graphviz is installed):
196
- >>> graph.visualize()
197
-
198
202
  Args:
199
- uri: The path to the database file or an open connection obtained
200
- from :meth:`~kumoai.experimental.rfm.backend.sqlite.connect`.
203
+ connection: An open connection from
204
+ :meth:`~kumoai.experimental.rfm.backend.sqlite.connect` or the
205
+ path to the database file.
201
206
  table_names: Set of table names to include. If ``None``, will add
202
207
  all tables present in the database.
203
208
  edges: An optional list of :class:`~kumoai.graph.Edge` objects to
@@ -206,8 +211,6 @@ class Graph:
206
211
  infer_metadata: Whether to infer metadata for all tables in the
207
212
  graph.
208
213
  verbose: Whether to print verbose output.
209
- conn_kwargs: Additional connection arguments, following the
210
- :class:`adbc_driver_sqlite` protocol.
211
214
  """
212
215
  from kumoai.experimental.rfm.backend.sqlite import (
213
216
  Connection,
@@ -215,10 +218,11 @@ class Graph:
215
218
  connect,
216
219
  )
217
220
 
218
- if not isinstance(uri, Connection):
219
- connection = connect(uri, **(conn_kwargs or {}))
220
- else:
221
- connection = uri
221
+ if not isinstance(connection, Connection):
222
+ connection = SqliteConnectionConfig._cast(connection)
223
+ assert isinstance(connection, SqliteConnectionConfig)
224
+ connection = connect(connection.uri, **connection.kwargs)
225
+ assert isinstance(connection, Connection)
222
226
 
223
227
  if table_names is None:
224
228
  with connection.cursor() as cursor:
@@ -242,6 +246,154 @@ class Graph:
242
246
 
243
247
  return graph
244
248
 
249
@classmethod
def from_snowflake(
    cls,
    connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
    database: Optional[str] = None,
    schema: Optional[str] = None,
    table_names: Optional[Sequence[str]] = None,
    edges: Optional[Sequence[Edge]] = None,
    infer_metadata: bool = True,
    verbose: bool = True,
) -> Self:
    r"""Creates a :class:`Graph` from a :class:`snowflake` database and
    schema.

    Automatically infers table metadata and links by default.

    .. code-block:: python

        >>> # doctest: +SKIP
        >>> import kumoai.experimental.rfm as rfm

        >>> # Create a graph directly in a Snowflake notebook:
        >>> graph = rfm.Graph.from_snowflake(schema='my_schema')

    Args:
        connection: An open connection from
            :meth:`~kumoai.experimental.rfm.backend.snow.connect` or the
            :class:`snowflake` connector keyword arguments to open a new
            connection. If ``None``, will re-use an active session in case
            it exists, or create a new connection from credentials stored
            in environment variables.
        database: The database. If ``None`` while tables are being
            discovered, the current database of the session is used.
        schema: The schema. If ``None`` while tables are being
            discovered, the current schema of the session is used.
        table_names: Set of table names to include. If ``None``, will add
            all tables present in the schema.
        edges: An optional list of :class:`~kumoai.graph.Edge` objects to
            add to the graph. If not provided, edges will be automatically
            inferred from the data in case ``infer_metadata=True``.
        infer_metadata: Whether to infer metadata for all tables in the
            graph.
        verbose: Whether to print verbose output.

    Raises:
        ValueError: If tables need to be discovered but neither the
            arguments nor the session provide a database and a schema.
    """
    from kumoai.experimental.rfm.backend.snow import (
        Connection,
        SnowTable,
        connect,
    )

    if not isinstance(connection, Connection):
        connection = connect(**(connection or {}))
    assert isinstance(connection, Connection)

    if table_names is None:
        with connection.cursor() as cursor:
            # Fill in whichever of `database`/`schema` is missing from the
            # session. Checking with `or` matters: if only one of them is
            # given, the other one still needs to be resolved.
            if database is None or schema is None:
                cursor.execute("SELECT CURRENT_DATABASE(), "
                               "CURRENT_SCHEMA()")
                result = cursor.fetchone()
                database = database or result[0]
                schema = schema or result[1]

            if database is None or schema is None:
                raise ValueError(
                    "Could not determine the database and schema to read "
                    "tables from. Please specify 'database' and 'schema' "
                    "explicitly")

            # Escape the schema name since it is embedded as a SQL string
            # literal (backslashes first, then single quotes):
            schema_literal = schema.replace("\\", "\\\\")
            schema_literal = schema_literal.replace("'", "''")
            # NOTE `database` is embedded as an identifier and cannot be
            # escaped the same way; it is expected to be a trusted name.
            cursor.execute(f"""
                SELECT TABLE_NAME
                FROM {database}.INFORMATION_SCHEMA.TABLES
                WHERE TABLE_SCHEMA = '{schema_literal}'
            """)
            table_names = [row[0] for row in cursor.fetchall()]

    tables = [
        SnowTable(
            connection,
            name=table_name,
            database=database,
            schema=schema,
        ) for table_name in table_names
    ]

    graph = cls(tables, edges=edges or [])

    if infer_metadata:
        graph.infer_metadata(False)

        # Only infer links if the user did not specify edges explicitly:
        if edges is None:
            graph.infer_links(False)

    if verbose:
        graph.print_metadata()
        graph.print_links()

    return graph
338
+
339
@classmethod
def from_snowflake_semantic_view(
    cls,
    semantic_view_name: str,
    connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
    verbose: bool = True,
) -> Self:
    r"""Creates a :class:`Graph` from a :class:`snowflake` semantic view.

    The YAML definition of the semantic view is read from Snowflake. Every
    table listed in the view is added to the graph, and every relationship
    that spans exactly one column pair is registered as a link. Primary
    keys and relationships over multiple columns are ignored since
    composite keys are not supported yet. No metadata inference is run.

    Args:
        semantic_view_name: The name of the semantic view.
        connection: An open connection from
            :meth:`~kumoai.experimental.rfm.backend.snow.connect` or the
            :class:`snowflake` connector keyword arguments to open a new
            connection. If ``None``, ``connect()`` is called without
            arguments.
        verbose: Whether to print verbose output.
    """
    import yaml

    from kumoai.experimental.rfm.backend.snow import (
        Connection,
        SnowTable,
        connect,
    )

    if not isinstance(connection, Connection):
        connection = connect(**(connection or {}))
    assert isinstance(connection, Connection)

    # Escape the view name since it is embedded as a SQL string literal
    # (backslashes first, then single quotes):
    view_literal = semantic_view_name.replace("\\", "\\\\")
    view_literal = view_literal.replace("'", "''")

    with connection.cursor() as cursor:
        cursor.execute(f"SELECT SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW("
                       f"'{view_literal}')")
        view = yaml.safe_load(cursor.fetchone()[0])

    graph = cls(tables=[])

    # Relationships reference tables by the name used inside of the view,
    # while tables are registered under their base table name. Remember
    # the mapping so that links resolve even if both names differ:
    registered_names: Dict[str, str] = {}

    for table_desc in view['tables']:
        primary_key: Optional[str] = None
        if ('primary_key' in table_desc  # NOTE No composite keys yet.
                and len(table_desc['primary_key']['columns']) == 1):
            primary_key = table_desc['primary_key']['columns'][0]

        base_table = table_desc['base_table']
        table = SnowTable(
            connection,
            name=base_table['table'],
            database=base_table['database'],
            schema=base_table['schema'],
            primary_key=primary_key,
        )
        graph.add_table(table)

        if 'name' in table_desc:
            registered_names[table_desc['name']] = table.name

    # TODO Find a solution to register time columns!

    # A semantic view is allowed to define no relationships at all:
    for relations in view.get('relationships') or []:
        if len(relations['relationship_columns']) != 1:
            continue  # NOTE No composite keys yet.
        left_table = relations['left_table']
        right_table = relations['right_table']
        graph.link(
            src_table=registered_names.get(left_table, left_table),
            fkey=relations['relationship_columns'][0]['left_column'],
            dst_table=registered_names.get(right_table, right_table),
        )

    if verbose:
        graph.print_metadata()
        graph.print_links()

    return graph
396
+
245
397
  # Tables ##############################################################
246
398
 
247
399
  def has_table(self, name: str) -> bool:
@@ -349,9 +501,13 @@ class Graph:
349
501
 
350
502
  def print_metadata(self) -> None:
351
503
  r"""Prints the :meth:`~Graph.metadata` of the graph."""
352
- if in_notebook():
504
+ if in_snowflake_notebook():
505
+ import streamlit as st
506
+ st.markdown("### 🗂️ Graph Metadata")
507
+ st.dataframe(self.metadata, hide_index=True)
508
+ elif in_notebook():
353
509
  from IPython.display import Markdown, display
354
- display(Markdown('### 🗂️ Graph Metadata'))
510
+ display(Markdown("### 🗂️ Graph Metadata"))
355
511
  df = self.metadata
356
512
  try:
357
513
  if hasattr(df.style, 'hide'):
@@ -395,26 +551,36 @@ class Graph:
395
551
  edge.src_table, edge.fkey) for edge in self.edges]
396
552
  edges = sorted(edges)
397
553
 
398
- if in_notebook():
554
+ if in_snowflake_notebook():
555
+ import streamlit as st
556
+ st.markdown("### 🕸️ Graph Links (FK ↔️ PK)")
557
+ if len(edges) > 0:
558
+ st.markdown('\n'.join([
559
+ f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
560
+ for edge in edges
561
+ ]))
562
+ else:
563
+ st.markdown("*No links registered*")
564
+ elif in_notebook():
399
565
  from IPython.display import Markdown, display
400
- display(Markdown('### 🕸️ Graph Links (FK ↔️ PK)'))
566
+ display(Markdown("### 🕸️ Graph Links (FK ↔️ PK)"))
401
567
  if len(edges) > 0:
402
568
  display(
403
569
  Markdown('\n'.join([
404
- f'- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`'
570
+ f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
405
571
  for edge in edges
406
572
  ])))
407
573
  else:
408
- display(Markdown('*No links registered*'))
574
+ display(Markdown("*No links registered*"))
409
575
  else:
410
576
  print("🕸️ Graph Links (FK ↔️ PK):")
411
577
  if len(edges) > 0:
412
578
  print('\n'.join([
413
- f'• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}'
579
+ f"• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
414
580
  for edge in edges
415
581
  ]))
416
582
  else:
417
- print('No links registered')
583
+ print("No links registered")
418
584
 
419
585
  def link(
420
586
  self,
@@ -612,10 +778,9 @@ class Graph:
612
778
  score += 1.0
613
779
 
614
780
  # Cardinality ratio:
615
- src_num_rows = src_table._num_rows()
616
- dst_num_rows = dst_table._num_rows()
617
- if (src_num_rows is not None and dst_num_rows is not None
618
- and src_num_rows > dst_num_rows):
781
+ if (src_table._num_rows is not None
782
+ and dst_table._num_rows is not None
783
+ and src_table._num_rows > dst_table._num_rows):
619
784
  score += 1.0
620
785
 
621
786
  if score < 5.0:
@@ -732,19 +897,19 @@ class Graph:
732
897
 
733
898
  return True
734
899
 
735
- # Check basic dependency:
736
- if not find_spec('graphviz'):
737
- raise ModuleNotFoundError("The 'graphviz' package is required for "
738
- "visualization")
739
- elif not has_graphviz_executables():
900
+ try: # Check basic dependency:
901
+ import graphviz
902
+ except ImportError as e:
903
+ raise ImportError("The 'graphviz' package is required for "
904
+ "visualization") from e
905
+
906
+ if not in_snowflake_notebook() and not has_graphviz_executables():
740
907
  raise RuntimeError("Could not visualize graph as 'graphviz' "
741
908
  "executables are not installed. These "
742
909
  "dependencies are required in addition to the "
743
910
  "'graphviz' Python package. Please install "
744
911
  "them as described at "
745
912
  "https://graphviz.org/download/.")
746
- else:
747
- import graphviz
748
913
 
749
914
  format: Optional[str] = None
750
915
  if isinstance(path, str):
@@ -828,6 +993,9 @@ class Graph:
828
993
  graph.render(path, cleanup=True)
829
994
  elif isinstance(path, io.BytesIO):
830
995
  path.write(graph.pipe())
996
+ elif in_snowflake_notebook():
997
+ import streamlit as st
998
+ st.graphviz_chart(graph)
831
999
  elif in_notebook():
832
1000
  from IPython.display import display
833
1001
  display(graph)
@@ -1,13 +1,17 @@
1
+ from .dtype import infer_dtype
2
+ from .pkey import infer_primary_key
3
+ from .time_col import infer_time_column
1
4
  from .id import contains_id
2
5
  from .timestamp import contains_timestamp
3
6
  from .categorical import contains_categorical
4
7
  from .multicategorical import contains_multicategorical
5
- from .stype import infer_stype
6
8
 
7
9
  __all__ = [
10
+ 'infer_dtype',
11
+ 'infer_primary_key',
12
+ 'infer_time_column',
8
13
  'contains_id',
9
14
  'contains_timestamp',
10
15
  'contains_categorical',
11
16
  'contains_multicategorical',
12
- 'infer_stype',
13
17
  ]