kumoai 2.14.0.dev202601011731__cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kumoai might be problematic. Click here for more details.

Files changed (122) hide show
  1. kumoai/__init__.py +300 -0
  2. kumoai/_logging.py +29 -0
  3. kumoai/_singleton.py +25 -0
  4. kumoai/_version.py +1 -0
  5. kumoai/artifact_export/__init__.py +9 -0
  6. kumoai/artifact_export/config.py +209 -0
  7. kumoai/artifact_export/job.py +108 -0
  8. kumoai/client/__init__.py +5 -0
  9. kumoai/client/client.py +223 -0
  10. kumoai/client/connector.py +110 -0
  11. kumoai/client/endpoints.py +150 -0
  12. kumoai/client/graph.py +120 -0
  13. kumoai/client/jobs.py +471 -0
  14. kumoai/client/online.py +78 -0
  15. kumoai/client/pquery.py +207 -0
  16. kumoai/client/rfm.py +112 -0
  17. kumoai/client/source_table.py +53 -0
  18. kumoai/client/table.py +101 -0
  19. kumoai/client/utils.py +130 -0
  20. kumoai/codegen/__init__.py +19 -0
  21. kumoai/codegen/cli.py +100 -0
  22. kumoai/codegen/context.py +16 -0
  23. kumoai/codegen/edits.py +473 -0
  24. kumoai/codegen/exceptions.py +10 -0
  25. kumoai/codegen/generate.py +222 -0
  26. kumoai/codegen/handlers/__init__.py +4 -0
  27. kumoai/codegen/handlers/connector.py +118 -0
  28. kumoai/codegen/handlers/graph.py +71 -0
  29. kumoai/codegen/handlers/pquery.py +62 -0
  30. kumoai/codegen/handlers/table.py +109 -0
  31. kumoai/codegen/handlers/utils.py +42 -0
  32. kumoai/codegen/identity.py +114 -0
  33. kumoai/codegen/loader.py +93 -0
  34. kumoai/codegen/naming.py +94 -0
  35. kumoai/codegen/registry.py +121 -0
  36. kumoai/connector/__init__.py +31 -0
  37. kumoai/connector/base.py +153 -0
  38. kumoai/connector/bigquery_connector.py +200 -0
  39. kumoai/connector/databricks_connector.py +213 -0
  40. kumoai/connector/file_upload_connector.py +189 -0
  41. kumoai/connector/glue_connector.py +150 -0
  42. kumoai/connector/s3_connector.py +278 -0
  43. kumoai/connector/snowflake_connector.py +252 -0
  44. kumoai/connector/source_table.py +471 -0
  45. kumoai/connector/utils.py +1796 -0
  46. kumoai/databricks.py +14 -0
  47. kumoai/encoder/__init__.py +4 -0
  48. kumoai/exceptions.py +26 -0
  49. kumoai/experimental/__init__.py +0 -0
  50. kumoai/experimental/rfm/__init__.py +210 -0
  51. kumoai/experimental/rfm/authenticate.py +432 -0
  52. kumoai/experimental/rfm/backend/__init__.py +0 -0
  53. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  54. kumoai/experimental/rfm/backend/local/graph_store.py +297 -0
  55. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  56. kumoai/experimental/rfm/backend/local/table.py +113 -0
  57. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  58. kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
  59. kumoai/experimental/rfm/backend/snow/table.py +242 -0
  60. kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
  61. kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
  62. kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
  63. kumoai/experimental/rfm/base/__init__.py +30 -0
  64. kumoai/experimental/rfm/base/column.py +152 -0
  65. kumoai/experimental/rfm/base/expression.py +44 -0
  66. kumoai/experimental/rfm/base/sampler.py +761 -0
  67. kumoai/experimental/rfm/base/source.py +19 -0
  68. kumoai/experimental/rfm/base/sql_sampler.py +143 -0
  69. kumoai/experimental/rfm/base/table.py +736 -0
  70. kumoai/experimental/rfm/graph.py +1237 -0
  71. kumoai/experimental/rfm/infer/__init__.py +19 -0
  72. kumoai/experimental/rfm/infer/categorical.py +40 -0
  73. kumoai/experimental/rfm/infer/dtype.py +82 -0
  74. kumoai/experimental/rfm/infer/id.py +46 -0
  75. kumoai/experimental/rfm/infer/multicategorical.py +48 -0
  76. kumoai/experimental/rfm/infer/pkey.py +128 -0
  77. kumoai/experimental/rfm/infer/stype.py +35 -0
  78. kumoai/experimental/rfm/infer/time_col.py +61 -0
  79. kumoai/experimental/rfm/infer/timestamp.py +41 -0
  80. kumoai/experimental/rfm/pquery/__init__.py +7 -0
  81. kumoai/experimental/rfm/pquery/executor.py +102 -0
  82. kumoai/experimental/rfm/pquery/pandas_executor.py +530 -0
  83. kumoai/experimental/rfm/relbench.py +76 -0
  84. kumoai/experimental/rfm/rfm.py +1184 -0
  85. kumoai/experimental/rfm/sagemaker.py +138 -0
  86. kumoai/experimental/rfm/task_table.py +231 -0
  87. kumoai/formatting.py +30 -0
  88. kumoai/futures.py +99 -0
  89. kumoai/graph/__init__.py +12 -0
  90. kumoai/graph/column.py +106 -0
  91. kumoai/graph/graph.py +948 -0
  92. kumoai/graph/table.py +838 -0
  93. kumoai/jobs.py +80 -0
  94. kumoai/kumolib.cpython-310-x86_64-linux-gnu.so +0 -0
  95. kumoai/mixin.py +28 -0
  96. kumoai/pquery/__init__.py +25 -0
  97. kumoai/pquery/prediction_table.py +287 -0
  98. kumoai/pquery/predictive_query.py +641 -0
  99. kumoai/pquery/training_table.py +424 -0
  100. kumoai/spcs.py +121 -0
  101. kumoai/testing/__init__.py +8 -0
  102. kumoai/testing/decorators.py +57 -0
  103. kumoai/testing/snow.py +50 -0
  104. kumoai/trainer/__init__.py +42 -0
  105. kumoai/trainer/baseline_trainer.py +93 -0
  106. kumoai/trainer/config.py +2 -0
  107. kumoai/trainer/distilled_trainer.py +175 -0
  108. kumoai/trainer/job.py +1192 -0
  109. kumoai/trainer/online_serving.py +258 -0
  110. kumoai/trainer/trainer.py +475 -0
  111. kumoai/trainer/util.py +103 -0
  112. kumoai/utils/__init__.py +11 -0
  113. kumoai/utils/datasets.py +83 -0
  114. kumoai/utils/display.py +51 -0
  115. kumoai/utils/forecasting.py +209 -0
  116. kumoai/utils/progress_logger.py +343 -0
  117. kumoai/utils/sql.py +3 -0
  118. kumoai-2.14.0.dev202601011731.dist-info/METADATA +71 -0
  119. kumoai-2.14.0.dev202601011731.dist-info/RECORD +122 -0
  120. kumoai-2.14.0.dev202601011731.dist-info/WHEEL +6 -0
  121. kumoai-2.14.0.dev202601011731.dist-info/licenses/LICENSE +9 -0
  122. kumoai-2.14.0.dev202601011731.dist-info/top_level.txt +1 -0
kumoai/graph/graph.py ADDED
@@ -0,0 +1,948 @@
1
+ import copy
2
+ import io
3
+ import logging
4
+ import time
5
+ from dataclasses import dataclass
6
+ from importlib.util import find_spec
7
+ from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union
8
+
9
+ if TYPE_CHECKING:
10
+ import graphviz
11
+
12
+ import kumoapi.data_snapshot as snapshot_api
13
+ import kumoapi.graph as api
14
+ from kumoapi.common import JobStatus
15
+ from kumoapi.data_snapshot import GraphSnapshotID
16
+ from tqdm.auto import tqdm
17
+ from typing_extensions import Self
18
+
19
+ from kumoai import global_state
20
+ from kumoai.client.graph import GraphID
21
+ from kumoai.graph.table import Table
22
+ from kumoai.mixin import CastMixin
23
+
24
# Module-level logger, named after this module:
logger = logging.getLogger(__name__)

# Number of seconds to sleep between two consecutive status polls in the
# polling loops of this module (e.g. in `Graph.snapshot` and
# `Graph.get_table_stats`):
_DEFAULT_INTERVAL_S = 20
27
+
28
+
29
@dataclass(frozen=True)
class Edge(CastMixin, api.Edge):
    r"""A relationship between two tables of a :class:`~kumoai.graph.Graph`.
    Within the Kumo platform, edges are **always** treated as bidirectional.

    Args:
        src_table: The name of the source table of the edge. This table must
            have a foreign key with name :obj:`fkey` that links to the primary
            key in the destination table.
        fkey: The name of the foreign key in the source table.
        dst_table: The name of the destination table in the graph. This table
            must have a primary key that links to the source table's foreign
            key.

    Example:
        >>> import kumoai
        >>> edge = kumoai.Edge("table_with_fkey", "fkey", "table_with_pkey")
    """
    def __iter__(self) -> Iterator[str]:
        # Makes `src_table, fkey, dst_table = edge` possible:
        components = (self.src_table, self.fkey, self.dst_table)
        return iter(components)

    def __hash__(self) -> int:
        # Hash over the same triple that `__iter__` exposes:
        key = (self.src_table, self.fkey, self.dst_table)
        return hash(key)

    @property
    def _fully_qualified_name(self) -> str:
        # Dot-separated "<src_table>.<fkey>.<dst_table>" identifier:
        return "{}.{}.{}".format(self.src_table, self.fkey, self.dst_table)
58
+
59
+
60
@dataclass
class GraphHealthStats:
    r"""Container for the health statistics of every edge in a graph. The
    statistics are produced as part of a :class:`~kumoai.graph.Graph`
    snapshot, and the entry of a single edge is looked up by indexing this
    object with an :class:`~kumoai.graph.graph.Edge`.
    """
    # Maps an edge's fully qualified name ("<src>.<fkey>.<dst>") to its stats:
    _stats: Dict[str, api.EdgeHealthStatistics]

    def __init__(self, stats: Dict[str, api.EdgeHealthStatistics]):
        self._stats = stats

    def __getitem__(self, key: Edge) -> api.EdgeHealthStatistics:
        qualified_name = key._fully_qualified_name
        return self._stats[qualified_name]

    def __repr__(self) -> str:
        # Collect the individual lines and join them once at the end:
        lines = ["GraphHealthStats\n"]
        for qualified_name, edge_stats in self._stats.items():
            # The key is expected to hold exactly three dot-separated parts:
            src, fkey, dst = qualified_name.split('.')
            lines.append(f" - Edge({src} ({fkey})-> {dst}) \n")
            lines.append(f" - {edge_stats.total_num_edges} total edges\n")
            matched = edge_stats.absolute_match_stats
            percent = edge_stats.percent_match_stats
            lines.append(f" - {int(matched.src_in_dst)} "
                         f"({round(percent.src_in_dst, 2)}%) rows "
                         f"in {src} have valid edges to {dst}\n")
            lines.append(f" - {int(matched.dst_in_src)} "
                         f"({round(percent.dst_in_src, 2)}%) rows "
                         f"in {dst} have valid edges to {src}\n")
        return "".join(lines)
90
+
91
+
92
+ class Graph:
93
+ r"""A graph defines the relationships between a set of Kumo tables, akin
94
+ to relationships between tables in a relational database. Creating a graph
95
+ is the final step of data definition in Kumo; after a graph is created, you
96
+ are ready to write a :class:`~kumoai.pquery.PredictiveQuery` and train a
97
+ predictive model.
98
+
99
+
100
+ .. code-block:: python
101
+
102
+ import kumoai
103
+
104
+ # Define connector to source data:
105
+ connector = kumoai.S3Connector('s3://...')
106
+
107
+ # Create Kumo Tables. See Table documentation for more information:
108
+ customer = kumoai.Table(...)
109
+ stock = kumoai.Table(...)
110
+ transaction = kumoai.Table(...)
111
+
112
+ # Create a graph:
113
+ graph = kumoai.Graph(
114
+ # These are the tables that participate in the graph: the keys of this
115
+ # dictionary are the names of the tables, and the values are the Table
116
+ # objects that correspond to these names:
117
+ tables={
118
+ 'customer': customer,
119
+ 'stock': stock,
120
+ 'transaction': transaction,
121
+ },
122
+
123
+ # These are the edges that define the primary key / foreign key
124
+ # relationships between the tables defined above. Here, `src_table`
125
+ # is the table that has the foreign key `fkey`, which maps to the
126
+ # table `dst_table`'s primary key:
127
+ edges=[
128
+ dict(src_table='transaction', fkey='StockCode', dst_table='stock'),
129
+ dict(src_table='transaction', fkey='CustomerID', dst_table='customer'),
130
+ ],
131
+ )
132
+
133
+ # Validate the graph configuration, for use in Kumo downstream models:
134
+ graph.validate(verbose=True)
135
+
136
+ # Visualize the graph:
137
+ graph.visualize()
138
+
139
+ # Fetch the statistics of the tables in this graph (this method will
140
+ # take a graph snapshot, and as a result may have high latency):
141
+ graph.get_table_stats(wait_for="minimal")
142
+
143
+ # Fetch link health statistics (this method will
144
+ # take a graph snapshot, and as a result may have high latency):
145
+ graph.get_edge_stats(non_blocking=False)
146
+
147
+ Args:
148
+ tables: The tables in the graph, represented as a dictionary mapping
149
+ unique table names (within the context of this graph) to the
150
+ :class:`~kumoai.graph.Table` definition for the table.
151
+ edges: The edges (relationships) between the :obj:`tables` in the
152
+ graph. Edges must specify the source table, foreign key, and
153
+ destination table that they link.
154
+
155
+ .. # noqa: E501
156
+ """
157
def __init__(
    self,
    tables: Optional[Dict[str, Table]] = None,
    edges: Optional[List[Edge]] = None,
) -> None:
    r"""Creates a graph and registers the given tables and edges.

    Args:
        tables: Optional mapping from table name to table. Each entry is
            registered via :meth:`add_table`.
        edges: Optional edges. Each entry is cast to an :class:`Edge` and
            registered via :meth:`link`, after all tables were added.
    """
    self._tables: Dict[str, Table] = {}
    self._edges: List[Edge] = []

    # Tables go first, since `link` requires both endpoints to exist:
    if tables:
        for table_name, table in tables.items():
            self.add_table(table_name, table)

    if edges:
        for edge in edges:
            self.link(Edge._cast(edge))

    # Cached from backend:
    self._graph_snapshot_id: Optional[GraphSnapshotID] = None
173
+
174
def print_definition(self) -> None:
    r"""Prints a definition of this graph to standard output. Tables are
    shown as ``<table-name>`` placeholders rather than as
    `kumoai.graph.Table` variables; copy the output and substitute your own
    table variables to re-create the original graph.

    Example:
        >>> import kumoai
        >>> graph = kumoai.Graph(...)  # doctest: +SKIP
        >>> graph.print_definition()  # doctest: +SKIP
        Graph(
          tables={
            'table-1' : <table-1>,
            'table-2' : <table-2>,
            ...
            'table-N' : <table-N>,
          },
          edges=[
            Edge(src_table='table-A', fkey='fkey-AD', dst_table='table-D'),
            Edge(src_table='table-B', fkey='fkey-BE', dst_table='table-E'),
            ...
            Edge(src_table='table-C', fkey='fkey-CF', dst_table='table-F'),
          ],
        )
    """
    # Gather all fragments first and join them once before printing:
    pieces = [f"{self.__class__.__name__}(\n", " tables={"]
    pieces.extend(f"\n '{table_name}' : <{table_name}>,"
                  for table_name in self._tables)
    pieces.append("\n },\n")
    pieces.append(" edges=[")
    for edge in self._edges:
        src_table, fkey, dst_table = edge
        pieces.append(
            f"\n {edge.__class__.__name__}(src_table='{src_table}', "
            f"fkey='{fkey}', dst_table='{dst_table}'),")
    pieces.append("\n ],\n)")
    print("".join(pieces))
212
+
213
+ # Properties ##############################################################
214
+
215
@property
def id(self) -> str:
    r"""Returns the unique ID for this graph, determined from its
    schema and the schemas of the tables and columns that it contains. Two
    graphs with any differences in their constituent tables or columns are
    guaranteed to have unique identifiers.

    .. note::
        This property is not a cached value: every access calls
        :meth:`save` without arguments, i.e. the graph is validated and
        registered with the backend each time the ID is read.
    """
    return self.save()
223
+
224
+ # Save / load #############################################################
225
+
226
def _to_api_graph_definition(self) -> api.GraphDefinition:
    r"""Converts this graph into its API representation.

    All edges that point to the same destination table are collected into
    one column group. The group starts with the primary key column of the
    destination table, followed by one foreign key column per edge.

    Raises:
        ValueError: if the destination table of an edge has no primary key.
    """
    groups: Dict[str, List[api.ColumnKey]] = {}
    for edge in self.edges:
        dst_name = edge.dst_table
        dst_pkey = self[dst_name].primary_key
        if dst_pkey is None:
            raise ValueError(
                f"The destination table {edge.dst_table} of edge "
                f"{edge} does not have a primary key.")
        # The first edge into a table opens its group with the primary key:
        if dst_name not in groups:
            groups[dst_name] = [api.ColumnKey(dst_name, dst_pkey.name)]
        groups[dst_name].append(api.ColumnKey(edge.src_table, edge.fkey))

    table_definitions = {
        table_name: table._to_api_table_definition()
        for table_name, table in self.tables.items()
    }
    column_groups = [
        api.ColumnKeyGroup(columns=tuple(column_keys))
        for column_keys in groups.values()
    ]
    return api.GraphDefinition(
        tables=table_definitions,
        col_groups=column_groups,
    )
251
+
252
@staticmethod
def _edges_from_api_graph_definition(
        graph_definition: api.GraphDefinition) -> List[Edge]:
    r"""Reconstructs the edges encoded in the column groups of an API graph
    definition.

    Within each column group, the first column that is the primary key of
    its table is treated as the destination; every other column of the group
    becomes the foreign key of an edge pointing to that destination table.
    """
    edges: List[Edge] = []
    for col_group in graph_definition.col_groups:
        # First column of the group that is the primary key of its table:
        pkey_col = next(
            (col for col in col_group.columns if col.col_name ==
             graph_definition.tables[col.table_name].pkey),
            None,
        )
        assert pkey_col is not None
        edges.extend(
            Edge(src_table=col.table_name, fkey=col.col_name,
                 dst_table=pkey_col.table_name)
            for col in col_group.columns if col != pkey_col)

    return edges
271
+
272
@staticmethod
def _from_api_graph_definition(
        graph_definition: api.GraphDefinition) -> 'Graph':
    r"""Builds a :class:`Graph` from an API graph definition, by converting
    each table definition and recovering the edges from the column groups.
    """
    tables = {}
    for table_name, table_definition in graph_definition.tables.items():
        tables[table_name] = Table._from_api_table_definition(
            table_definition)
    return Graph(
        tables,
        Graph._edges_from_api_graph_definition(graph_definition),
    )
281
+
282
def save(
    self,
    name: Optional[str] = None,
    skip_validation: bool = False,
) -> Union[GraphID, str]:
    r"""Saves this graph in Kumo and returns the identifier reported by the
    backend. If a name is given, the graph is additionally stored as a named
    template that can later be fetched in the Kumo UI, or in the Kumo SDK
    with method :meth:`~kumoai.Graph.load`.

    Args:
        name: The name to associate with this graph definition. If the
            name is already associated with another graph, a warning is
            logged and that graph will be overridden.
        skip_validation: Whether to skip validation of the graph. If
            :obj:`True`, validation will be skipped, but saving an invalid
            graph may result in undefined behavior.
            If :obj:`False`, the graph will be validated before saving.

    Example:
        >>> import kumoai
        >>> graph = kumoai.Graph(...)  # doctest: +SKIP
        >>> graph.save()  # doctest: +SKIP
        graph-xxx
        >>> graph.save("template_name")  # doctest: +SKIP
        >>> loaded = kumoai.Graph.load("template_name")  # doctest: +SKIP
    """
    if not skip_validation:
        self.validate(verbose=False)

    graph_api = global_state.client.graph_api

    # Only named templates can clash with an existing graph:
    existing = None
    if name:
        existing = graph_api.get_graph_if_exists(graph_id_or_name=name)

    if existing is not None:
        previous = self._from_api_graph_definition(existing.graph)
        logger.warning(
            ("Graph template %s already exists, with configuration %s. "
             "This template will be overridden with configuration %s."),
            name, str(previous), str(self))

    # Save as named template
    return graph_api.create_graph(
        graph_def=self._to_api_graph_definition(),
        force_rename=bool(name),
        name_alias=name,
    )
326
+
327
@classmethod
def load(cls, graph_id_or_template: str) -> 'Graph':
    r"""Fetches a graph by its graph ID or by the name of a saved template,
    and returns it as a :class:`Graph` object together with its associated
    tables, columns, etc.

    Raises:
        ValueError: if no graph was found for the given ID or name.
    """
    resource = global_state.client.graph_api.get_graph_if_exists(
        graph_id_or_template)
    if not resource:
        raise ValueError(f"Graph {graph_id_or_template} was not found.")
    return cls._from_api_graph_definition(resource.graph)
339
+
340
+ # Snapshot ################################################################
341
+
342
@property
def snapshot_id(self) -> Optional[snapshot_api.GraphSnapshotID]:
    r"""Returns the snapshot ID of this graph's snapshot, if a snapshot
    has been taken. Returns `None` otherwise.

    .. warning::
        This function currently only returns a snapshot ID if a snapshot
        has been taken *in this session.*
    """
    # The value is only cached on this object; it is written by `snapshot()`
    # and no backend request is made here:
    return self._graph_snapshot_id
352
+
353
def snapshot(
    self,
    *,
    force_refresh: bool = False,
    non_blocking: bool = False,
) -> snapshot_api.GraphSnapshotID:
    r"""Takes a *snapshot* of this graph's underlying data, and returns a
    unique identifier for this snapshot.

    This is equivalent to taking a snapshot for each constituent table in
    the graph. For more information, please see the documentation for
    :meth:`~kumoai.graph.Table.snapshot`.

    If a snapshot was already taken in this session and ``force_refresh`` is
    :obj:`False`, a warning is logged and the cached identifier is returned
    without contacting the backend.

    .. warning::
        Please familiarize yourself with the warnings for this method in
        :class:`~kumoai.graph.Table` before proceeding.

    Args:
        force_refresh: Indicates whether a snapshot should be taken, if one
            already exists in Kumo. If :obj:`False`, a previously existing
            snapshot may be re-used. If :obj:`True`, a new snapshot is
            always taken.
        non_blocking: Whether this operation should return immediately
            after creating the snapshot, or await completion of the
            snapshot. If :obj:`True`, the snapshot will proceed in the
            background, and will be used for any downstream job.

    Returns:
        The identifier of the graph snapshot.

    Raises:
        RuntimeError: if ``non_blocking`` is set to :obj:`False` and the
            graph snapshot fails.
    """
    if self._graph_snapshot_id is not None and not force_refresh:
        logger.warning(
            "Graph snapshot with identifier %s already exists, and will "
            "not be refreshed.", self._graph_snapshot_id)
        return self._graph_snapshot_id

    graph_api = global_state.client.graph_api

    # Save exactly once and re-use the returned ID. Reading `self.id`
    # afterwards would call `save()` (validation plus a backend request) a
    # second time:
    graph_id = self.save()

    if not force_refresh:
        snapshotted_table_names: List[str] = [
            table_name for table_name, table in self.tables.items()
            if table.snapshot_id is not None
        ]
        if snapshotted_table_names:
            logger.warning(
                "Tables %s have already been snapshot, and will not "
                "be refreshed. If you would like to refresh all "
                "tables, please set 'force_refresh=True'.",
                snapshotted_table_names)

    self._graph_snapshot_id = graph_api.create_snapshot(
        graph_id=graph_id,
        refresh_source=True,
    )
    logger.info("Graph snapshot with identifier %s created.",
                self._graph_snapshot_id)

    # Perform initial GET to update table snapshot IDs:
    graph_resource: snapshot_api.GraphSnapshotResource = (
        graph_api.get_snapshot(snapshot_id=self._graph_snapshot_id))
    for table_name, table_id in graph_resource.table_ids.items():
        self[table_name]._table_snapshot_id = table_id

    if non_blocking:
        return self._graph_snapshot_id

    # NOTE we do not use a `KumoFuture` here as we do not want to treat
    # a graph refresh as having its own state; since we only ever
    # operate on the latest graph version (and do not let users to time
    # travel), there is no need for a separate Future object:
    stage = snapshot_api.GraphSnapshotStage.INGEST
    table_status: Dict[str, JobStatus] = {
        table_name: JobStatus.NOT_STARTED
        for table_name in self.tables
    }
    done = [status.is_terminal for status in table_status.values()]

    # Decide once whether to show progress, so that the progress bar either
    # exists for the whole wait or not at all:
    pbar = None
    if logger.isEnabledFor(logging.INFO):
        pbar = tqdm(total=len(done), unit="table", desc="Ingesting")

    try:
        while True:
            graph_resource = graph_api.get_snapshot(
                snapshot_id=self._graph_snapshot_id)
            for table_name, table_id in graph_resource.table_ids.items():
                resource = global_state.client.table_api.get_snapshot(
                    snapshot_id=table_id)
                # NOTE(review): the table snapshot stages are indexed with
                # the graph-level INGEST stage here; confirm that both
                # stage enums share this key.
                table_status[table_name] = resource.stages[stage].status
            done = [status.is_terminal for status in table_status.values()]
            graph_done = graph_resource.stages[stage].status.is_terminal
            if pbar is not None:
                # Increment progress bar with table refresh stages:
                pbar.update(sum(done) - pbar.n)
            if all(done) and graph_done:
                # Finished: return right away instead of sleeping for
                # another polling interval.
                break
            time.sleep(_DEFAULT_INTERVAL_S)
        if pbar is not None:
            pbar.update(len(done) - pbar.n)
    finally:
        # Also release the progress bar if polling raised:
        if pbar is not None:
            pbar.close()

    state = graph_resource.stages[stage]
    warnings = "\n".join([
        f"{i}. {message}" for i, message in enumerate(state.warnings)
    ])
    if state.status == JobStatus.FAILED:
        raise RuntimeError(
            f"Graph snapshot with identifier "
            f"{self._graph_snapshot_id} failed, with error "
            f"{state.error} and warnings {warnings}")
    if len(state.warnings) > 0:
        logger.warning(
            "Graph snapshot completed with the following "
            "warnings: %s", warnings)

    # <prefix>@<data_version>:
    assert self._graph_snapshot_id is not None
    return self._graph_snapshot_id
476
+
477
+ # Statistics ##############################################################
478
+
479
def get_table_stats(
    self,
    wait_for: Optional[str] = None,
) -> Dict[str, Dict[str, Any]]:
    r"""Returns all currently computed statistics on the latest snapshot of
    this graph. If a snapshot on this graph has not been taken, this method
    will take a snapshot (and block until it has finished).

    .. note::
        Graph statistics are computed in multiple stages after ingestion is
        complete. These stages are called *minimal* and *full*; minimal
        statistics are always computed before full statistics.

    Args:
        wait_for: Whether this operation should block on the existence of
            statistics availability. This argument can take one of three
            values: :obj:`None`, which indicates that the method should
            return immediately with whatever statistics are present,
            :obj:`"minimal"`, which indicates that the method should return
            when the minimum, maximum, and fraction of NA values
            statistics are present, or :obj:`"full"`, which indicates that
            the method should return when all computed statistics are
            present.

    Returns:
        A dictionary that maps each table name to a dictionary of its
        per-column statistics, keyed by column name.

    Raises:
        ValueError: if ``wait_for`` is not one of the supported values.
    """
    # Validate explicitly rather than with `assert`: assertions are removed
    # under `python -O`, in which case an unknown value would silently be
    # treated like "full".
    if wait_for not in (None, "minimal", "full"):
        raise ValueError(
            f"'wait_for' must be one of None, 'minimal' or 'full' "
            f"(got {wait_for!r}).")

    # Wait for graph ingestion to be done:
    if not self._graph_snapshot_id:
        self.snapshot(force_refresh=False, non_blocking=False)
    assert self._graph_snapshot_id is not None

    table_api = global_state.client.table_api

    # Wait for all table snapshots to match the `wait_for` stage, if
    # we support that:
    if wait_for:
        if wait_for == "minimal":
            stage = snapshot_api.TableSnapshotStage.MIN_COL_STATS
        else:
            stage = snapshot_api.TableSnapshotStage.FULL_COL_STATS

        table_status: Dict[str, JobStatus] = {
            table_name: JobStatus.NOT_STARTED
            for table_name in self.tables
        }
        done = [status.is_terminal for status in table_status.values()]

        # Decide once whether to show progress, so that the progress bar
        # either exists for the whole wait or not at all:
        pbar = None
        if logger.isEnabledFor(logging.INFO):
            pbar = tqdm(total=len(done), unit="table",
                        desc="Computing Statistics")
        try:
            while not all(done):
                for table_name, table in self.tables.items():
                    resource = table_api.get_snapshot(
                        snapshot_id=table._table_snapshot_id)
                    table_status[table_name] = (
                        resource.stages[stage].status)
                done = [
                    status.is_terminal for status in table_status.values()
                ]
                if pbar is not None:
                    pbar.update(sum(done) - pbar.n)
                if all(done):
                    # Finished: do not sleep for another polling interval.
                    break
                time.sleep(_DEFAULT_INTERVAL_S)
            if pbar is not None:
                pbar.update(len(done) - pbar.n)
        finally:
            # Also release the progress bar if polling raised:
            if pbar is not None:
                pbar.close()

    # Write out statistics:
    out: Dict[str, Dict[str, Any]] = {}
    for table_name, table in self.tables.items():
        resource = table_api.get_snapshot(
            snapshot_id=table._table_snapshot_id)
        out[table_name] = {
            stat.column_name: stat.stats
            for stat in resource.column_stats
        }
    return out
549
+
550
def get_edge_stats(
    self,
    *,
    non_blocking: bool = False,
) -> Optional[GraphHealthStats]:
    r"""Retrieves edge health statistics for the edges in a graph, if these
    statistics have been computed by a graph snapshot.

    Edge health statistics are returned in a
    :class:`~kumoai.graph.GraphHealthStats` object, and contain information
    about the match rate between primary key / foreign key relationships
    between the tables in the graph.

    Args:
        non_blocking: Whether this operation should return immediately
            after querying edge statistics (returning `None` if statistics
            are not available), or await completion of statistics
            computation.

    Returns:
        The edge health statistics, or `None` if ``non_blocking`` is
        :obj:`True` and the statistics are not ready yet.

    Raises:
        ValueError: if no snapshot of this graph has been taken yet.
    """
    if self._graph_snapshot_id is None:
        raise ValueError('In order to calculate edge health statistics, '
                         'you must first create a snapshot of the graph '
                         'on which to calculate match statistics for each '
                         'edge. Please call Graph.snapshot() and then '
                         'this function.')

    graph_api = global_state.client.graph_api
    edge_health_response = graph_api.get_edge_stats(
        graph_snapshot_id=self._graph_snapshot_id)

    if non_blocking:
        if not edge_health_response.is_ready:
            return None
    else:
        while not edge_health_response.is_ready:
            # Wait between polls, like the other polling loops of this
            # module, instead of sending requests in a tight loop:
            time.sleep(_DEFAULT_INTERVAL_S)
            edge_health_response = graph_api.get_edge_stats(
                graph_snapshot_id=self._graph_snapshot_id)

    return GraphHealthStats(edge_health_response.statistics)
589
+
590
+ # Tables ##################################################################
591
+
592
def has_table(self, name: str) -> bool:
    r"""Returns True if a table by `name` is present in this Graph.

    Args:
        name: The name under which the table was added to this graph.
    """
    return name in self._tables
595
+
596
def table(self, name: str) -> Table:
    r"""Looks up a table of this Kumo Graph by its name.

    Raises:
        KeyError: if no such table is present.
    """
    try:
        return self._tables[name]
    except KeyError:
        # Replace the bare lookup error by a descriptive message:
        raise KeyError(f"Table '{name}' not found in this graph.") from None
605
+
606
def add_table(self, name: str, table: Table) -> 'Graph':
    r"""Registers a table under the given name in this Kumo Graph, and
    returns the graph itself.

    Raises:
        KeyError: if a table with the same name already exists in this
            graph.
    """
    name_taken = name in self._tables
    if name_taken:
        message = (f"Cannot add table with name '{name}' to this graph; "
                   f"names must be globally unique within a graph.")
        raise KeyError(message)

    self._tables[name] = table
    return self
619
+
620
def remove_table(self, name: str) -> Self:
    r"""Removes a table from this graph, together with every edge that has
    this table as its source or destination, and returns the graph itself.

    Args:
        name: The name of the table to remove.

    Raises:
        KeyError: if no such table is present.
    """
    if not self.has_table(name):
        # Same message as in `table()` (without a stray trailing quote):
        raise KeyError(f"Table '{name}' not found in this graph.")

    del self._tables[name]
    # Edges that touch the removed table would be dangling, so drop them:
    self._edges = [
        edge for edge in self._edges
        if edge.src_table != name and edge.dst_table != name
    ]
    return self
635
+
636
@property
def tables(self) -> Dict[str, Table]:
    r"""Returns a dictionary of all :class:`~kumoai.graph.Table` objects
    that are contained in this graph, keyed by their table name.
    """
    # This is the internal dictionary itself, not a copy:
    return self._tables
642
+
643
def infer_metadata(self, inplace: bool = True) -> 'Graph':
    r"""Infers the metadata of every table in this Graph. For more
    information, please see the documentation for
    :meth:`~kumoai.graph.Table.infer_metadata`.

    Args:
        inplace: If :obj:`True`, the tables of this graph are updated and
            this graph is returned. If :obj:`False`, a deep copy of this
            graph is updated and returned instead.
    """
    target = self if inplace else copy.deepcopy(self)
    # `target` is already the object to modify, so tables update in place:
    for table in target.tables.values():
        table.infer_metadata(inplace=True)
    return target
656
+
657
+ # Edges ###################################################################
658
+
659
def infer_links(self) -> 'Graph':
    r"""Asks the backend to infer the edges between the tables of this
    Graph, and adds every inferred edge to the graph.

    Note that the function only works if the graph edges are empty.

    Raises:
        ValueError: if the graph already contains edges.
    """
    if self._edges:
        raise ValueError(
            "Cannot infer links if graph edges are not empty.")

    inferred_definition = global_state.client.graph_api.infer_links(
        graph=self._to_api_graph_definition())
    inferred_edges = Graph._edges_from_api_graph_definition(
        inferred_definition)

    for inferred_edge in inferred_edges or []:
        logger.info("Inferring edge: %s", inferred_edge)
        self.link(Edge._cast(inferred_edge))
    return self
679
+
680
def link(self, *args: Optional[Union[str, Edge]],
         **kwargs: str) -> 'Graph':
    r"""Adds an edge to this graph, connecting the foreign key :obj:`fkey`
    of the source table :obj:`src_table` with the primary key of the
    destination table :obj:`dst_table`. These edges are treated
    bidirectionally in Kumo.

    Args:
        *args: Any arguments to construct a
            :class:`kumoai.graph.Edge`, or a :class:`kumoai.graph.Edge`
            itself.
        **kwargs: Any keyword arguments to construct a
            :class:`kumoai.graph.Edge`.

    Raises:
        ValueError: if the edge is already present in the graph, if the
            source table does not exist in the graph, if the destination
            table does not exist in the graph, if the source key does not
            exist in the source table, or if the primary key of the source
            table is being treated as a foreign key.
    """
    new_edge = Edge._cast(*args, **kwargs)
    if new_edge is None:
        raise ValueError("Cannot add a 'None' edge to a graph.")

    src_table, fkey, dst_table = new_edge

    # The checks below run in a fixed order; the first failing one raises.
    if new_edge in self._edges:
        raise ValueError(f"Cannot add edge {new_edge} to graph; edge is "
                         f"already present.")

    if src_table not in self._tables:
        raise ValueError(
            f"Source table '{src_table}' does not exist in the graph. "
            f"Please add it via `Graph.add_table(...)` before proceeding.")

    if dst_table not in self._tables:
        raise ValueError(
            f"Destination table '{dst_table}' does not exist in the "
            f"graph. Please add it via `Graph.add_table(...)` before "
            f"proceeding.")

    if fkey not in self._tables[src_table]:
        raise ValueError(
            f"Source key '{fkey}' does not exist in source table "
            f"'{src_table}'; please check that you have added it as a "
            f"column.")

    # Backend limitations: ensure the source is not its primary key:
    src_pkey = self.table(src_table).primary_key
    if src_pkey is not None and src_pkey.name == fkey:
        raise ValueError(f"Cannot treat the primary key of table "
                         f"'{src_table}' as a foreign key; please "
                         f"select a different key.")

    self._edges.append(new_edge)
    return self
737
+
738
def unlink(self, *args: Optional[Union[str, Edge]],
           **kwargs: str) -> 'Graph':
    r"""Removes a previously added edge from this graph.

    Args:
        *args: Either a ready-made :class:`~kumoai.graph.Edge`, or the
            positional arguments needed to construct one.
        **kwargs: Keyword arguments needed to construct a
            :class:`~kumoai.graph.Edge`.

    Returns:
        This graph, without the removed edge.

    Raises:
        ValueError: if the edge is not present in the graph.
    """
    target = Edge._cast(*args, **kwargs)
    registered = self._edges
    if target in registered:
        registered.remove(target)
        return self
    raise ValueError(f"Edge {target} is not present in {registered}")
757
+
758
@property
def edges(self) -> List[Edge]:
    r"""The :class:`~kumoai.graph.Edge` objects describing the links of
    this graph.

    The internal list itself is returned, not a copy.
    """
    return self._edges
764
+
765
def validate(self, verbose: bool = True) -> Self:
    r"""Validates this graph, i.e. checks that the metadata of its tables
    and edges is fully and consistently specified.

    Every table is validated first through its own ``validate`` method.
    Afterwards, the graph definition is submitted to the graph API, which
    checks that edges properly link primary and foreign keys between
    valid tables and that linked keys share the same data type, so that
    unexpected mismatches do not occur within the Kumo platform.

    Example:
        >>> import kumoai
        >>> graph = kumoai.Graph(...)  # doctest: +SKIP
        >>> graph.validate()  # doctest: +SKIP

    Args:
        verbose: Whether to log non-error output of this validation.

    Returns:
        This graph, to allow chaining.

    Raises:
        ValueError: if validation of any table or of the graph fails.
    """
    # Tables come first: a graph definition can only be built properly
    # from valid table definitions.
    for name, tbl in self.tables.items():
        try:
            tbl.validate(verbose=verbose)
        except ValueError as e:
            raise ValueError(
                f"Validation of table {name} failed. {e}") from e

    validate_remote = global_state.client.graph_api.validate_graph
    resp = validate_remote(
        api.GraphValidationRequest(self._to_api_graph_definition()))

    if not resp.ok:
        raise ValueError(resp.error_message())
    if not verbose:
        return self

    if resp.empty():
        logger.info("Graph is configured correctly.")
    else:
        logger.warning(resp.message())
    return self
808
+
809
def visualize(
    self,
    path: Optional[Union[str, io.BytesIO]] = None,
    show_cols: bool = True,
) -> 'graphviz.Graph':
    r"""Visualizes the tables and edges in this graph using the
    ``graphviz`` library.

    Args:
        path: Where to send the produced image. If a string, the image
            is rendered to that local path; the text after the last
            ``.`` of the file name selects the output format (e.g.
            ``"graph.png"``). A path without a usable extension is
            rendered with the default ``graphviz`` format. If an
            :class:`io.BytesIO`, the image is written into the buffer
            as SVG. If :obj:`None`, ``graph.view()`` is attempted
            instead.
        show_cols: Whether to show all columns of every table in the graph.
            If False, will only show the primary key, foreign key(s),
            time column, and end time column of each table.

    Returns:
        A ``graphviz.Graph`` instance representing the visualized graph.

    Raises:
        ModuleNotFoundError: if the ``graphviz`` Python package is not
            installed.
        RuntimeError: if the ``graphviz`` executables are not installed.
    """
    def has_graphviz_executables() -> bool:
        import graphviz
        try:
            graphviz.Digraph().pipe()
        except graphviz.backend.ExecutableNotFound:
            return False

        return True

    # Check basic dependency:
    if not find_spec('graphviz'):
        raise ModuleNotFoundError(
            "The `graphviz` Python package is required for visualization.")
    elif not has_graphviz_executables():
        raise RuntimeError(
            "Could not visualize graph as `graphviz` executables have not "
            "been installed. These dependencies are required in addition "
            "to the `graphviz` Python package. Please install them to "
            "continue. Instructions at https://graphviz.org/download/.")
    else:
        import graphviz

    # Work out the output format and the extension-less render target.
    fmt = None
    render_path = None
    if isinstance(path, str):
        stem, dot, suffix = path.rpartition('.')
        # Only treat the text after the last dot as a format if it is a
        # real file extension: non-empty and not spanning a path
        # separator (as in "out.d/graph"). Otherwise keep the full path
        # and fall back to the default `graphviz` format.
        if dot and suffix and '/' not in suffix and '\\' not in suffix:
            fmt, render_path = suffix, stem
        else:
            render_path = path
    elif isinstance(path, io.BytesIO):
        fmt = 'svg'
    graph = graphviz.Graph(format=fmt)

    def left_align(list_of_text: List[str]) -> str:
        # '\l' is emitted after every entry to left-justify it in the
        # record label.
        return '\\l'.join(list_of_text) + '\\l'

    # Collect the foreign keys of every source table:
    table_to_fkey: Dict[str, List[str]] = {}
    for edge in self.edges:
        src, fkey, dst = edge
        table_to_fkey.setdefault(src, []).append(fkey)

    # One record-shaped node per table: name | keys [| columns]
    for table_name, table in self.tables.items():
        keys = []
        if table.has_primary_key():
            assert table.primary_key is not None
            keys += [f'{table.primary_key.name} (PK)']
        if table_name in table_to_fkey:
            keys += [f'{fkey} (FK)' for fkey in table_to_fkey[table_name]]
        if table.has_time_column():
            assert table.time_column is not None
            keys += [f'{table.time_column.name} (Time)']
        if table.has_end_time_column():
            assert table.end_time_column is not None
            keys += [f'{table.end_time_column.name} (End Time)']

        keys_aligned = left_align(keys)

        cols = []
        cols_aligned = ""
        if show_cols and len(table.columns) > 0:
            cols += [
                f'{col.name}: {col.stype or "???"} ({col.dtype or "???"})'
                for col in table.columns
            ]
            cols_aligned = left_align(cols)

        if cols:
            label = f'{{{table_name}|{keys_aligned}|{cols_aligned}}}'
        else:
            label = f'{{{table_name}|{keys_aligned}}}'

        graph.node(table_name, shape='record', label=label)

    # One edge per link, annotated as many-to-one (source '*', target '1'):
    for edge in self.edges:
        src, fkey, dst = edge
        pkey_obj = self[dst].primary_key
        assert pkey_obj is not None
        pkey = pkey_obj.name
        # Print both key names if different:
        if fkey != pkey:
            label = f' {fkey}\n< >\n{pkey} '
        else:
            label = f' {fkey} '
        headlabel, taillabel = '1', '*'
        graph.edge(src, dst, label=label, headlabel=headlabel,
                   taillabel=taillabel, minlen='2', fontsize='11pt',
                   labeldistance='1.5')

    if isinstance(path, str):
        graph.render(render_path, cleanup=True)
    elif isinstance(path, io.BytesIO):
        path.write(graph.pipe())
    else:
        # Best effort only: a failing viewer must not break the caller.
        try:
            graph.view()
        except Exception as e:
            logger.warning(
                "Could not visualize graph due to an unexpected error in "
                "`graphviz`. If you are in a notebook environment, "
                "consider calling `display()` on the returned object "
                "from `visualize()`. Error: %s", e)
    return graph
928
+
929
+ # Class properties ########################################################
930
+
931
def __hash__(self) -> int:
    r"""Returns a hash computed from the edges and tables of this graph.

    NOTE(review): this requires every edge and every table object to be
    hashable itself -- confirm for the table type.
    """
    # `dict.values()` returns a fresh view object that is hashed by
    # identity, so hashing the view would ignore the table contents and
    # is not guaranteed to be stable across calls. Materialize the
    # values into a tuple so that the tables themselves are hashed.
    return hash((tuple(self.edges), tuple(self.tables.values())))
933
+
934
def __contains__(self, name: str) -> bool:
    r"""Supports ``name in graph`` by delegating to ``has_table``."""
    found = self.has_table(name)
    return found
936
+
937
def __getitem__(self, name: str) -> Table:
    r"""Supports ``graph[name]`` by delegating to ``table``."""
    selected = self.table(name)
    return selected
939
+
940
def __delitem__(self, name: str) -> None:
    r"""Supports ``del graph[name]`` by delegating to ``remove_table``.

    Whatever ``remove_table`` returns is discarded.
    """
    self.remove_table(name)
    return None
942
+
943
def __repr__(self) -> str:
    r"""Returns a multi-line summary listing the class name, the table
    names (printed without quotes) and the edges of this graph.
    """
    names = str(list(self._tables.keys())).replace("'", "")
    parts = [
        f'{self.__class__.__name__}(',
        f' tables={names},',
        f' edges={self._edges},',
        ')',
    ]
    return '\n'.join(parts)