kumoai 2.12.0.dev202510231830__cp311-cp311-win_amd64.whl → 2.14.0.dev202512311733__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. kumoai/__init__.py +41 -35
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +15 -13
  4. kumoai/client/endpoints.py +1 -0
  5. kumoai/client/jobs.py +24 -0
  6. kumoai/client/pquery.py +6 -2
  7. kumoai/client/rfm.py +35 -7
  8. kumoai/connector/utils.py +23 -2
  9. kumoai/experimental/rfm/__init__.py +191 -48
  10. kumoai/experimental/rfm/authenticate.py +3 -4
  11. kumoai/experimental/rfm/backend/__init__.py +0 -0
  12. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  13. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +65 -127
  14. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  15. kumoai/experimental/rfm/backend/local/table.py +113 -0
  16. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  17. kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
  18. kumoai/experimental/rfm/backend/snow/table.py +242 -0
  19. kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
  20. kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
  21. kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
  22. kumoai/experimental/rfm/base/__init__.py +30 -0
  23. kumoai/experimental/rfm/base/column.py +152 -0
  24. kumoai/experimental/rfm/base/expression.py +44 -0
  25. kumoai/experimental/rfm/base/sampler.py +761 -0
  26. kumoai/experimental/rfm/base/source.py +19 -0
  27. kumoai/experimental/rfm/base/sql_sampler.py +143 -0
  28. kumoai/experimental/rfm/base/table.py +735 -0
  29. kumoai/experimental/rfm/graph.py +1237 -0
  30. kumoai/experimental/rfm/infer/__init__.py +8 -0
  31. kumoai/experimental/rfm/infer/dtype.py +82 -0
  32. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  33. kumoai/experimental/rfm/infer/pkey.py +128 -0
  34. kumoai/experimental/rfm/infer/stype.py +35 -0
  35. kumoai/experimental/rfm/infer/time_col.py +61 -0
  36. kumoai/experimental/rfm/pquery/__init__.py +0 -4
  37. kumoai/experimental/rfm/pquery/executor.py +27 -27
  38. kumoai/experimental/rfm/pquery/pandas_executor.py +64 -40
  39. kumoai/experimental/rfm/relbench.py +76 -0
  40. kumoai/experimental/rfm/rfm.py +386 -276
  41. kumoai/experimental/rfm/sagemaker.py +138 -0
  42. kumoai/kumolib.cp311-win_amd64.pyd +0 -0
  43. kumoai/pquery/predictive_query.py +10 -6
  44. kumoai/spcs.py +1 -3
  45. kumoai/testing/decorators.py +1 -1
  46. kumoai/testing/snow.py +50 -0
  47. kumoai/trainer/distilled_trainer.py +175 -0
  48. kumoai/trainer/trainer.py +9 -10
  49. kumoai/utils/__init__.py +3 -2
  50. kumoai/utils/display.py +51 -0
  51. kumoai/utils/progress_logger.py +188 -16
  52. kumoai/utils/sql.py +3 -0
  53. {kumoai-2.12.0.dev202510231830.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/METADATA +13 -2
  54. {kumoai-2.12.0.dev202510231830.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/RECORD +57 -36
  55. kumoai/experimental/rfm/local_graph.py +0 -810
  56. kumoai/experimental/rfm/local_graph_sampler.py +0 -184
  57. kumoai/experimental/rfm/local_pquery_driver.py +0 -494
  58. kumoai/experimental/rfm/local_table.py +0 -545
  59. kumoai/experimental/rfm/pquery/backend.py +0 -136
  60. kumoai/experimental/rfm/pquery/pandas_backend.py +0 -478
  61. kumoai/experimental/rfm/utils.py +0 -344
  62. {kumoai-2.12.0.dev202510231830.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/WHEEL +0 -0
  63. {kumoai-2.12.0.dev202510231830.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/licenses/LICENSE +0 -0
  64. {kumoai-2.12.0.dev202510231830.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/top_level.txt +0 -0
@@ -1,810 +0,0 @@
1
- import contextlib
2
- import io
3
- import warnings
4
- from collections import defaultdict
5
- from importlib.util import find_spec
6
- from typing import TYPE_CHECKING, Dict, List, Optional, Union
7
-
8
- import pandas as pd
9
- from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
10
- from kumoapi.table import TableDefinition
11
- from kumoapi.typing import Stype
12
- from typing_extensions import Self
13
-
14
- from kumoai import in_notebook
15
- from kumoai.experimental.rfm import LocalTable
16
- from kumoai.graph import Edge
17
-
18
- if TYPE_CHECKING:
19
- import graphviz
20
-
21
-
22
class LocalGraph:
    r"""A graph of :class:`LocalTable` objects, akin to relationships between
    tables in a relational database.

    Creating a graph is the final step of data definition; after a
    :class:`LocalGraph` is created, you can use it to initialize the
    Kumo Relational Foundation Model (:class:`KumoRFM`).

    .. code-block:: python

        >>> # doctest: +SKIP
        >>> import pandas as pd
        >>> import kumoai.experimental.rfm as rfm

        >>> # Load data frames into memory:
        >>> df1 = pd.DataFrame(...)
        >>> df2 = pd.DataFrame(...)
        >>> df3 = pd.DataFrame(...)

        >>> # Define tables from data frames:
        >>> table1 = rfm.LocalTable(name="table1", data=df1)
        >>> table2 = rfm.LocalTable(name="table2", data=df2)
        >>> table3 = rfm.LocalTable(name="table3", data=df3)

        >>> # Create a graph from a list of tables (each table is registered
        >>> # under its own ``name``):
        >>> graph = rfm.LocalGraph(tables=[table1, table2, table3])

        >>> # Infer table metadata:
        >>> graph.infer_metadata()

        >>> # Infer links/edges:
        >>> graph.infer_links()

        >>> # Inspect table metadata:
        >>> for table in graph.tables.values():
        ...     table.print_metadata()

        >>> # Visualize graph (if graphviz is installed):
        >>> graph.visualize()

        >>> # Add/Remove edges between tables:
        >>> graph.link(src_table="table1", fkey="id1", dst_table="table2")
        >>> graph.unlink(src_table="table1", fkey="id1", dst_table="table2")

        >>> # Validate graph:
        >>> graph.validate()
    """

    # Constructors ############################################################

    def __init__(
        self,
        tables: List[LocalTable],
        edges: Optional[List[Edge]] = None,
    ) -> None:
        r"""Creates a graph from a list of tables and optional edges.

        Args:
            tables: The tables to add to the graph. Each table is registered
                under its ``name``, which must be unique within the graph.
            edges: An optional list of edges. Each entry is passed through
                ``Edge._cast`` and then registered via :meth:`link`.

        Raises:
            KeyError: If two tables share the same name.
            ValueError: If an edge cannot be linked (see :meth:`link`).
        """
        self._tables: Dict[str, LocalTable] = {}
        self._edges: List[Edge] = []

        for table in tables:
            self.add_table(table)

        for edge in (edges or []):
            _edge = Edge._cast(edge)
            assert _edge is not None
            self.link(*_edge)

    @classmethod
    def from_data(
        cls,
        df_dict: Dict[str, pd.DataFrame],
        edges: Optional[List[Edge]] = None,
        infer_metadata: bool = True,
        verbose: bool = True,
    ) -> Self:
        r"""Creates a :class:`LocalGraph` from a dictionary of
        :class:`pandas.DataFrame` objects.

        By default, table metadata and links are inferred automatically.

        .. code-block:: python

            >>> # doctest: +SKIP
            >>> import pandas as pd
            >>> import kumoai.experimental.rfm as rfm

            >>> # Load data frames into memory:
            >>> df1 = pd.DataFrame(...)
            >>> df2 = pd.DataFrame(...)
            >>> df3 = pd.DataFrame(...)

            >>> # Create a graph from a dictionary of data frames:
            >>> graph = rfm.LocalGraph.from_data({
            ...     "table1": df1,
            ...     "table2": df2,
            ...     "table3": df3,
            ... })

            >>> # Inspect table metadata:
            >>> for table in graph.tables.values():
            ...     table.print_metadata()

            >>> # Visualize graph (if graphviz is installed):
            >>> graph.visualize()

        Args:
            df_dict: A dictionary of data frames, where the keys are the names
                of the tables and the values hold table data.
            edges: An optional list of :class:`~kumoai.graph.Edge` objects to
                add to the graph. If not provided (``None``), edges will be
                automatically inferred from the data. Pass an empty list to
                create a graph without edges and skip link inference.
            infer_metadata: Whether to infer metadata for all tables in the
                graph.
            verbose: Whether to print verbose output.

        Note:
            Metadata is inferred unless ``infer_metadata=False``; links are
            inferred only when ``edges`` is ``None``.

        Example:
            >>> # doctest: +SKIP
            >>> import kumoai.experimental.rfm as rfm
            >>> df1 = pd.DataFrame(...)
            >>> df2 = pd.DataFrame(...)
            >>> df3 = pd.DataFrame(...)
            >>> graph = rfm.LocalGraph.from_data({
            ...     "table1": df1,
            ...     "table2": df2,
            ...     "table3": df3,
            ... })
            >>> graph.validate()
        """
        tables = [LocalTable(df, name) for name, df in df_dict.items()]

        graph = cls(tables, edges=edges or [])

        if infer_metadata:
            graph.infer_metadata(verbose)

        # Only infer links when the caller did not specify edges explicitly:
        if edges is None:
            graph.infer_links(verbose)

        return graph

    # Tables ##############################################################

    def has_table(self, name: str) -> bool:
        r"""Returns ``True`` if the graph has a table with name ``name``;
        ``False`` otherwise.
        """
        return name in self.tables

    def table(self, name: str) -> LocalTable:
        r"""Returns the table with name ``name`` in the graph.

        Raises:
            KeyError: If ``name`` is not present in the graph.
        """
        if not self.has_table(name):
            raise KeyError(f"Table '{name}' not found in graph")
        return self.tables[name]

    @property
    def tables(self) -> Dict[str, LocalTable]:
        r"""Returns the dictionary of table objects, keyed by table name."""
        return self._tables

    def add_table(self, table: LocalTable) -> Self:
        r"""Adds a table to the graph.

        Args:
            table: The table to add.

        Raises:
            KeyError: If a table with the same name already exists in the
                graph.
        """
        if table.name in self._tables:
            raise KeyError(f"Cannot add table with name '{table.name}' to "
                           f"this graph; table names must be globally unique.")

        self._tables[table.name] = table

        return self

    def remove_table(self, name: str) -> Self:
        r"""Removes a table with ``name`` from the graph, together with all
        edges that start or end at that table.

        Args:
            name: The table to remove.

        Raises:
            KeyError: If no such table is present in the graph.
        """
        if not self.has_table(name):
            raise KeyError(f"Table '{name}' not found in the graph")

        del self._tables[name]

        # Drop dangling edges that reference the removed table:
        self._edges = [
            edge for edge in self._edges
            if edge.src_table != name and edge.dst_table != name
        ]

        return self

    @property
    def metadata(self) -> pd.DataFrame:
        r"""Returns a :class:`pandas.DataFrame` object containing metadata
        information about the tables in this graph.

        The returned dataframe has columns ``name``, ``primary_key``,
        ``time_column``, and ``end_time_column``, which provide an aggregate
        view of the properties of the tables of this graph. Missing values
        are rendered as ``'-'``.

        Example:
            >>> # doctest: +SKIP
            >>> import kumoai.experimental.rfm as rfm
            >>> graph = rfm.LocalGraph(tables=...).infer_metadata()
            >>> graph.metadata  # doctest: +SKIP
                name  primary_key  time_column  end_time_column
            0  users      user_id            -                -
        """
        tables = list(self.tables.values())

        return pd.DataFrame({
            'name':
            pd.Series(dtype=str, data=[t.name for t in tables]),
            'primary_key':
            pd.Series(dtype=str, data=[t._primary_key or '-' for t in tables]),
            'time_column':
            pd.Series(dtype=str, data=[t._time_column or '-' for t in tables]),
            'end_time_column':
            pd.Series(
                dtype=str,
                data=[t._end_time_column or '-' for t in tables],
            ),
        })

    def print_metadata(self) -> None:
        r"""Prints the :meth:`~LocalGraph.metadata` of the graph.

        In a notebook, the metadata is rendered as a styled table (falling
        back to plain text if styling is unavailable); otherwise it is
        printed as plain text.
        """
        if in_notebook():
            from IPython.display import Markdown, display
            display(Markdown('### 🗂️ Graph Metadata'))
            df = self.metadata
            try:
                if hasattr(df.style, 'hide'):
                    display(df.style.hide(axis='index'))  # pandas=2
                else:
                    display(df.style.hide_index())  # pandas<1.3
            except ImportError:
                print(df.to_string(index=False))  # missing jinja2
        else:
            print("🗂️ Graph Metadata:")
            print(self.metadata.to_string(index=False))

    def infer_metadata(self, verbose: bool = True) -> Self:
        r"""Infers metadata for all tables in the graph.

        Args:
            verbose: Whether to print verbose output.

        Note:
            For more information, please see
            :meth:`kumoai.experimental.rfm.LocalTable.infer_metadata`.
        """
        for table in self.tables.values():
            table.infer_metadata(verbose=False)

        if verbose:
            self.print_metadata()

        return self

    # Edges ###################################################################

    @property
    def edges(self) -> List[Edge]:
        r"""Returns the edges of the graph."""
        return self._edges

    def print_links(self) -> None:
        r"""Prints the :meth:`~LocalGraph.edges` of the graph as
        ``FK <-> PK`` pairs, sorted by destination table.
        """
        edges = sorted(
            (edge.dst_table, self[edge.dst_table]._primary_key,
             edge.src_table, edge.fkey) for edge in self.edges)

        if in_notebook():
            from IPython.display import Markdown, display
            display(Markdown('### 🕸️ Graph Links (FK ↔️ PK)'))
            if len(edges) > 0:
                display(
                    Markdown('\n'.join([
                        f'- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`'
                        for edge in edges
                    ])))
            else:
                display(Markdown('*No links registered*'))
        else:
            print("🕸️ Graph Links (FK ↔️ PK):")
            if len(edges) > 0:
                print('\n'.join([
                    f'• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}'
                    for edge in edges
                ]))
            else:
                print('No links registered')

    def link(
        self,
        src_table: Union[str, LocalTable],
        fkey: str,
        dst_table: Union[str, LocalTable],
    ) -> Self:
        r"""Links two tables (``src_table`` and ``dst_table``) from the foreign
        key ``fkey`` in the source table to the primary key in the destination
        table.

        The link is treated as bidirectional.

        Args:
            src_table: The name of the source table of the edge. This table
                must have a foreign key with name :obj:`fkey` that links to the
                primary key in the destination table.
            fkey: The name of the foreign key in the source table.
            dst_table: The name of the destination table of the edge. This
                table must have a primary key that links to the source table's
                foreign key.

        Raises:
            ValueError: if the edge is already present in the graph, if the
                source table does not exist in the graph, if the destination
                table does not exist in the graph, if the source key does not
                exist in the source table, or if the source key's data type
                cannot be used as a foreign key.
        """
        if isinstance(src_table, LocalTable):
            src_table = src_table.name
        assert isinstance(src_table, str)

        if isinstance(dst_table, LocalTable):
            dst_table = dst_table.name
        assert isinstance(dst_table, str)

        edge = Edge(src_table, fkey, dst_table)

        if edge in self.edges:
            raise ValueError(f"{edge} already exists in the graph")

        if not self.has_table(src_table):
            raise ValueError(f"Source table '{src_table}' does not exist in "
                             f"the graph")

        if not self.has_table(dst_table):
            raise ValueError(f"Destination table '{dst_table}' does not exist "
                             f"in the graph")

        if not self[src_table].has_column(fkey):
            raise ValueError(f"Source key '{fkey}' does not exist as a column "
                             f"in source table '{src_table}'")

        if not Stype.ID.supports_dtype(self[src_table][fkey].dtype):
            raise ValueError(f"Cannot use '{fkey}' in source table "
                             f"'{src_table}' as a foreign key due to its "
                             f"incompatible data type. Foreign keys must have "
                             f"data type 'int', 'float' or 'string' "
                             f"(got '{self[src_table][fkey].dtype}')")

        self._edges.append(edge)

        return self

    def unlink(
        self,
        src_table: Union[str, LocalTable],
        fkey: str,
        dst_table: Union[str, LocalTable],
    ) -> Self:
        r"""Removes an :class:`~kumoai.graph.Edge` from the graph.

        Args:
            src_table: The name of the source table of the edge.
            fkey: The name of the foreign key in the source table.
            dst_table: The name of the destination table of the edge.

        Raises:
            ValueError: if the edge is not present in the graph.
        """
        if isinstance(src_table, LocalTable):
            src_table = src_table.name
        assert isinstance(src_table, str)

        if isinstance(dst_table, LocalTable):
            dst_table = dst_table.name
        assert isinstance(dst_table, str)

        edge = Edge(src_table, fkey, dst_table)

        if edge not in self.edges:
            raise ValueError(f"{edge} is not present in the graph")

        self._edges.remove(edge)

        return self

    def infer_links(self, verbose: bool = True) -> Self:
        r"""Infers links for the tables and adds them as edges to the graph.

        Every non-primary-key column of every table is scored against the
        primary key of every table whose key has a compatible (numeric vs.
        string) data type. The score combines name similarity, whether the
        column has semantic type ID, exact data type equality, and whether
        the source table has more rows than the destination table. A column
        is linked to the destination table with the highest score of at least
        ``5.0``; if the top two candidates tie, the column is left unlinked.

        Args:
            verbose: Whether to print verbose output.

        Note:
            This function expects graph edges to be undefined upfront. If any
            edges exist, a warning is emitted and the graph is returned
            unchanged.
        """
        if len(self.edges) > 0:
            warnings.warn("Cannot infer links if graph edges already exist")
            return self

        # A list of primary key candidates (+score) for every column:
        candidate_dict: dict[
            tuple[str, str],
            list[tuple[str, float]],
        ] = defaultdict(list)

        for dst_table in self.tables.values():
            dst_key = dst_table.primary_key

            if dst_key is None:
                continue

            assert dst_key.dtype is not None
            dst_number = dst_key.dtype.is_int() or dst_key.dtype.is_float()
            dst_string = dst_key.dtype.is_string()

            dst_table_name = dst_table.name.lower()
            dst_key_name = dst_key.name.lower()

            for src_table in self.tables.values():
                src_table_name = src_table.name.lower()

                for src_key in src_table.columns:
                    if src_key == src_table.primary_key:
                        continue  # Cannot link to primary key.

                    src_number = (src_key.dtype.is_int()
                                  or src_key.dtype.is_float())
                    src_string = src_key.dtype.is_string()

                    if src_number != dst_number or src_string != dst_string:
                        continue  # Non-compatible data types.

                    src_key_name = src_key.name.lower()

                    score = 0.0

                    # Name similarity:
                    if src_key_name == dst_key_name:
                        score += 7.0
                    elif (dst_key_name != 'id'
                          and src_key_name.endswith(dst_key_name)):
                        score += 4.0
                    elif src_key_name.endswith(  # e.g., user.id -> user_id
                            f'{dst_table_name}_{dst_key_name}'):
                        score += 4.0
                    elif src_key_name.endswith(  # e.g., user.id -> userid
                            f'{dst_table_name}{dst_key_name}'):
                        score += 4.0
                    elif (dst_table_name.endswith('s') and
                          src_key_name.endswith(  # e.g., users.id -> user_id
                              f'{dst_table_name[:-1]}_{dst_key_name}')):
                        score += 4.0
                    elif (dst_table_name.endswith('s') and
                          src_key_name.endswith(  # e.g., users.id -> userid
                              f'{dst_table_name[:-1]}{dst_key_name}')):
                        score += 4.0
                    elif src_key_name.endswith(dst_table_name):
                        score += 4.0  # e.g., users -> users
                    elif (dst_table_name.endswith('s')  # e.g., users -> user
                          and src_key_name.endswith(dst_table_name[:-1])):
                        score += 4.0
                    elif ((src_key_name == 'parentid'
                           or src_key_name == 'parent_id')
                          and src_table_name == dst_table_name):
                        score += 2.0

                    # `rel-bench` hard-coding :(
                    elif (src_table.name == 'posts'
                          and src_key.name == 'AcceptedAnswerId'
                          and dst_table.name == 'posts'):
                        score += 2.0
                    elif (src_table.name == 'user_friends'
                          and src_key.name == 'friend'
                          and dst_table.name == 'users'):
                        score += 3.0

                    # For non-exact matching, at least one additional
                    # requirement needs to be met.

                    # Exact data type compatibility:
                    if src_key.stype == Stype.ID:
                        score += 2.0

                    if src_key.dtype == dst_key.dtype:
                        score += 1.0

                    # Cardinality ratio:
                    if len(src_table._data) > len(dst_table._data):
                        score += 1.0

                    if score < 5.0:
                        continue

                    candidate_dict[(
                        src_table.name,
                        src_key.name,
                    )].append((
                        dst_table.name,
                        score,
                    ))

        for (src_table_name, src_key_name), scores in candidate_dict.items():
            scores.sort(key=lambda x: x[-1], reverse=True)

            if len(scores) > 1 and scores[0][1] == scores[1][1]:
                continue  # Cannot uniquely infer link.

            dst_table_name = scores[0][0]
            self.link(src_table_name, src_key_name, dst_table_name)

        if verbose:
            self.print_links()

        return self

    # Metadata ################################################################

    def validate(self) -> Self:
        r"""Validates the graph to ensure that all relevant metadata is
        specified for its tables and edges.

        Concretely, validation ensures that edges properly link foreign keys to
        primary keys between valid tables.
        It additionally ensures that primary and foreign keys between tables
        in an :class:`~kumoai.graph.Edge` are of the same data type.

        Raises:
            ValueError: if validation fails.
        """
        if len(self.tables) == 0:
            raise ValueError("At least one table needs to be added to the "
                             "graph")

        for edge in self.edges:
            src_table, fkey, dst_table = edge

            # Report a missing foreign key column as a validation error
            # (as documented) rather than letting the lookup below fail:
            if not self[src_table].has_column(fkey):
                raise ValueError(f"Edge {edge} is invalid since column "
                                 f"'{fkey}' does not exist in table "
                                 f"'{src_table}'. Remove the link before "
                                 f"proceeding.")

            src_key = self[src_table][fkey]
            dst_key = self[dst_table].primary_key

            # Check that the destination table defines a primary key:
            if dst_key is None:
                raise ValueError(f"Edge {edge} is invalid since table "
                                 f"'{dst_table}' does not have a primary key. "
                                 f"Add either a primary key or remove the "
                                 f"link before proceeding.")

            # Ensure that foreign key is not a primary key:
            src_pkey = self[src_table].primary_key
            if src_pkey is not None and src_pkey.name == fkey:
                raise ValueError(f"Cannot treat the primary key of table "
                                 f"'{src_table}' as a foreign key. Remove "
                                 f"either the primary key or the link before "
                                 f"proceeding.")

            # Check that fkey/pkey have valid and consistent data types:
            assert src_key.dtype is not None
            src_number = src_key.dtype.is_int() or src_key.dtype.is_float()
            src_string = src_key.dtype.is_string()
            assert dst_key.dtype is not None
            dst_number = dst_key.dtype.is_int() or dst_key.dtype.is_float()
            dst_string = dst_key.dtype.is_string()

            if not src_number and not src_string:
                raise ValueError(f"{edge} is invalid as foreign key must be a "
                                 f"number or string (got '{src_key.dtype}')")

            if src_number != dst_number or src_string != dst_string:
                raise ValueError(f"{edge} is invalid as foreign key "
                                 f"'{fkey}' and primary key '{dst_key.name}' "
                                 f"have incompatible data types (got "
                                 f"fkey.dtype '{src_key.dtype}' and "
                                 f"pkey.dtype '{dst_key.dtype}')")

        return self

    # Visualization ###########################################################

    def visualize(
        self,
        path: Optional[Union[str, io.BytesIO]] = None,
        show_columns: bool = True,
    ) -> 'graphviz.Graph':
        r"""Visualizes the tables and edges in this graph using the
        :class:`graphviz` library.

        Args:
            path: A path to write the produced image to. The image format is
                taken from the file extension (e.g., ``'graph.png'``). If a
                :class:`io.BytesIO` buffer is given, an SVG image is written
                into it. If ``None``, the image will not be written to disk;
                instead it is displayed in a notebook or opened via
                ``graphviz``'s viewer.
            show_columns: Whether to show all columns of every table in the
                graph. If ``False``, will only show the primary key, foreign
                key(s), and time column of each table.

        Returns:
            A ``graphviz.Graph`` instance representing the visualized graph.

        Raises:
            ModuleNotFoundError: If the ``graphviz`` package is not installed.
            RuntimeError: If the ``graphviz`` executables are not installed.
            ValueError: If ``path`` is a string without a file extension.
        """
        import os

        def has_graphviz_executables() -> bool:
            import graphviz
            try:
                graphviz.Digraph().pipe()
            except graphviz.backend.ExecutableNotFound:
                return False

            return True

        # Check basic dependency:
        if not find_spec('graphviz'):
            raise ModuleNotFoundError("The 'graphviz' package is required for "
                                      "visualization")
        elif not has_graphviz_executables():
            raise RuntimeError("Could not visualize graph as 'graphviz' "
                               "executables are not installed. These "
                               "dependencies are required in addition to the "
                               "'graphviz' Python package. Please install "
                               "them as described at "
                               "https://graphviz.org/download/.")
        else:
            import graphviz

        image_format: Optional[str] = None
        render_root: Optional[str] = None
        if isinstance(path, str):
            # `os.path.splitext` only splits off the final extension, so dots
            # in directory names (e.g., `my.dir/graph.png`) are handled:
            render_root, extension = os.path.splitext(path)
            image_format = extension[1:]
            if not image_format:
                raise ValueError(f"Could not infer the image format from "
                                 f"path '{path}'. Please provide a path with "
                                 f"a file extension (e.g., 'graph.png').")
        elif isinstance(path, io.BytesIO):
            image_format = 'svg'
        graph = graphviz.Graph(format=image_format)

        def left_align(keys: List[str]) -> str:
            # `\l` left-justifies each line inside a graphviz record label.
            if len(keys) == 0:
                return ""
            return '\\l'.join(keys) + '\\l'

        fkeys_dict: Dict[str, List[str]] = defaultdict(list)
        for src_table_name, fkey_name, _ in self.edges:
            fkeys_dict[src_table_name].append(fkey_name)

        for table_name, table in self.tables.items():
            keys = []
            if primary_key := table.primary_key:
                keys += [f'{primary_key.name}: PK ({primary_key.dtype})']
            keys += [
                f'{fkey_name}: FK ({self[table_name][fkey_name].dtype})'
                for fkey_name in fkeys_dict[table_name]
            ]
            if time_column := table.time_column:
                keys += [f'{time_column.name}: Time ({time_column.dtype})']
            if end_time_column := table.end_time_column:
                keys += [
                    f'{end_time_column.name}: '
                    f'End Time ({end_time_column.dtype})'
                ]
            key_repr = left_align(keys)

            columns = []
            if show_columns:
                # Key and time columns are already listed in `keys` above:
                columns += [
                    f'{column.name}: {column.stype} ({column.dtype})'
                    for column in table.columns
                    if column.name not in fkeys_dict[table_name]
                    and column.name != table._primary_key
                    and column.name != table._time_column
                    and column.name != table._end_time_column
                ]
            column_repr = left_align(columns)

            if len(keys) > 0 and len(columns) > 0:
                label = f'{{{table_name}|{key_repr}|{column_repr}}}'
            elif len(keys) > 0:
                label = f'{{{table_name}|{key_repr}}}'
            elif len(columns) > 0:
                label = f'{{{table_name}|{column_repr}}}'
            else:
                label = f'{{{table_name}}}'

            graph.node(table_name, shape='record', label=label)

        for src_table_name, fkey_name, dst_table_name in self.edges:
            if self[dst_table_name]._primary_key is None:
                continue  # Invalid edge.

            pkey_name = self[dst_table_name]._primary_key

            if fkey_name != pkey_name:
                label = f' {fkey_name}\n< >\n{pkey_name} '
            else:
                label = f' {fkey_name} '

            graph.edge(
                src_table_name,
                dst_table_name,
                label=label,
                headlabel='1',
                taillabel='*',
                minlen='2',
                fontsize='11pt',
                labeldistance='1.5',
            )

        if isinstance(path, str):
            # graphviz appends `.{format}` to the rendered file name itself:
            graph.render(render_root, cleanup=True)
        elif isinstance(path, io.BytesIO):
            path.write(graph.pipe())
        elif in_notebook():
            from IPython.display import display
            display(graph)
        else:
            try:
                stderr_buffer = io.StringIO()
                with contextlib.redirect_stderr(stderr_buffer):
                    graph.view(cleanup=True)
                if stderr_buffer.getvalue():
                    warnings.warn("Could not visualize graph since your "
                                  "system does not know how to open or "
                                  "display PDF files from the command line. "
                                  "Please specify 'visualize(path=...)' and "
                                  "open the generated file yourself.")
            except Exception as e:
                warnings.warn(f"Could not visualize graph due to an "
                              f"unexpected error in 'graphviz'. Error: {e}")

        return graph

    # Helpers #################################################################

    def _to_api_graph_definition(self) -> GraphDefinition:
        r"""Converts this graph into a :class:`GraphDefinition`.

        Every table is converted via ``_to_api_table_definition``. For every
        table with a primary key, the primary key and all foreign keys that
        link to it are collected into a :class:`ColumnKeyGroup`; only groups
        with more than one key are emitted.
        """
        # Collect foreign keys per destination table once, instead of
        # re-scanning all edges for every table:
        fkeys_by_dst: Dict[str, List[ColumnKey]] = defaultdict(list)
        for edge in self.edges:
            fkeys_by_dst[edge.dst_table].append(
                ColumnKey(edge.src_table, edge.fkey))

        tables: Dict[str, TableDefinition] = {}
        col_groups: List[ColumnKeyGroup] = []
        for table_name, table in self.tables.items():
            tables[table_name] = table._to_api_table_definition()
            if table.primary_key is None:
                continue
            keys = [ColumnKey(table_name, table.primary_key.name)]
            keys += fkeys_by_dst[table_name]
            # De-duplicate and sort for a deterministic group order:
            keys = sorted(
                set(keys),
                key=lambda x: f'{x.table_name}.{x.col_name}',
            )
            if len(keys) > 1:
                col_groups.append(ColumnKeyGroup(keys))
        return GraphDefinition(tables, col_groups)

    # Class properties ########################################################

    def __hash__(self) -> int:
        # NOTE(review): the hash depends on mutable state (edges and table
        # names), so it changes when the graph is modified.
        return hash((tuple(self.edges), tuple(sorted(self.tables.keys()))))

    def __contains__(self, name: str) -> bool:
        return self.has_table(name)

    def __getitem__(self, name: str) -> LocalTable:
        return self.table(name)

    def __delitem__(self, name: str) -> None:
        self.remove_table(name)

    def __repr__(self) -> str:
        tables = '\n'.join(f' {table},' for table in self.tables)
        tables = f'[\n{tables}\n ]' if len(tables) > 0 else '[]'
        edges = '\n'.join(
            f' {edge.src_table}.{edge.fkey}'
            f' ⇔ {edge.dst_table}.{self[edge.dst_table]._primary_key},'
            for edge in self.edges)
        edges = f'[\n{edges}\n ]' if len(edges) > 0 else '[]'
        return (f'{self.__class__.__name__}(\n'
                f' tables={tables},\n'
                f' edges={edges},\n'
                f')')