kumoai 2.13.0.dev202512040649__cp313-cp313-win_amd64.whl → 2.14.0.dev202601081732__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. kumoai/__init__.py +35 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +26 -0
  5. kumoai/client/pquery.py +6 -2
  6. kumoai/connector/utils.py +21 -7
  7. kumoai/experimental/rfm/__init__.py +51 -24
  8. kumoai/experimental/rfm/authenticate.py +3 -4
  9. kumoai/experimental/rfm/backend/local/__init__.py +4 -0
  10. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +62 -110
  11. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  12. kumoai/experimental/rfm/backend/local/table.py +35 -31
  13. kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
  14. kumoai/experimental/rfm/backend/snow/sampler.py +366 -0
  15. kumoai/experimental/rfm/backend/snow/table.py +177 -50
  16. kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
  17. kumoai/experimental/rfm/backend/sqlite/sampler.py +454 -0
  18. kumoai/experimental/rfm/backend/sqlite/table.py +131 -48
  19. kumoai/experimental/rfm/base/__init__.py +23 -3
  20. kumoai/experimental/rfm/base/column.py +96 -10
  21. kumoai/experimental/rfm/base/expression.py +44 -0
  22. kumoai/experimental/rfm/base/sampler.py +782 -0
  23. kumoai/experimental/rfm/base/source.py +2 -1
  24. kumoai/experimental/rfm/base/sql_sampler.py +247 -0
  25. kumoai/experimental/rfm/base/table.py +404 -203
  26. kumoai/experimental/rfm/graph.py +374 -172
  27. kumoai/experimental/rfm/infer/__init__.py +6 -4
  28. kumoai/experimental/rfm/infer/dtype.py +7 -4
  29. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  30. kumoai/experimental/rfm/infer/pkey.py +4 -2
  31. kumoai/experimental/rfm/infer/stype.py +35 -0
  32. kumoai/experimental/rfm/infer/time_col.py +1 -2
  33. kumoai/experimental/rfm/pquery/executor.py +27 -27
  34. kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
  35. kumoai/experimental/rfm/relbench.py +76 -0
  36. kumoai/experimental/rfm/rfm.py +762 -467
  37. kumoai/experimental/rfm/sagemaker.py +4 -4
  38. kumoai/experimental/rfm/task_table.py +292 -0
  39. kumoai/kumolib.cp313-win_amd64.pyd +0 -0
  40. kumoai/pquery/predictive_query.py +10 -6
  41. kumoai/pquery/training_table.py +16 -2
  42. kumoai/testing/snow.py +50 -0
  43. kumoai/trainer/distilled_trainer.py +175 -0
  44. kumoai/utils/__init__.py +3 -2
  45. kumoai/utils/display.py +87 -0
  46. kumoai/utils/progress_logger.py +190 -12
  47. kumoai/utils/sql.py +3 -0
  48. {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/METADATA +3 -2
  49. {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/RECORD +52 -41
  50. kumoai/experimental/rfm/local_graph_sampler.py +0 -223
  51. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  52. {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/WHEEL +0 -0
  53. {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/licenses/LICENSE +0 -0
  54. {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/infer/__init__.py
@@ -1,17 +1,19 @@
 from .dtype import infer_dtype
-from .pkey import infer_primary_key
-from .time_col import infer_time_column
 from .id import contains_id
 from .timestamp import contains_timestamp
 from .categorical import contains_categorical
 from .multicategorical import contains_multicategorical
+from .stype import infer_stype
+from .pkey import infer_primary_key
+from .time_col import infer_time_column

 __all__ = [
     'infer_dtype',
-    'infer_primary_key',
-    'infer_time_column',
     'contains_id',
     'contains_timestamp',
     'contains_categorical',
     'contains_multicategorical',
+    'infer_stype',
+    'infer_primary_key',
+    'infer_time_column',
 ]
kumoai/experimental/rfm/infer/dtype.py
@@ -1,17 +1,17 @@
-from typing import Dict
-
 import numpy as np
 import pandas as pd
 import pyarrow as pa
 from kumoapi.typing import Dtype

-PANDAS_TO_DTYPE: Dict[str, Dtype] = {
+PANDAS_TO_DTYPE: dict[str, Dtype] = {
     'bool': Dtype.bool,
     'boolean': Dtype.bool,
     'int8': Dtype.int,
     'int16': Dtype.int,
     'int32': Dtype.int,
     'int64': Dtype.int,
+    'float': Dtype.float,
+    'double': Dtype.float,
     'float16': Dtype.float,
     'float32': Dtype.float,
     'float64': Dtype.float,
@@ -20,6 +20,8 @@ PANDAS_TO_DTYPE: Dict[str, Dtype] = {
     'string[python]': Dtype.string,
     'string[pyarrow]': Dtype.string,
     'binary': Dtype.binary,
+    'binary[python]': Dtype.binary,
+    'binary[pyarrow]': Dtype.binary,
 }


@@ -50,7 +52,8 @@ def infer_dtype(ser: pd.Series) -> Dtype:
         ser = pd.Series(arr, dtype=pd.ArrowDtype(arr.type))

     if isinstance(ser.dtype, pd.ArrowDtype):
-        if pa.types.is_list(ser.dtype.pyarrow_dtype):
+        if (pa.types.is_list(ser.dtype.pyarrow_dtype)
+                or pa.types.is_fixed_size_list(ser.dtype.pyarrow_dtype)):
             elem_dtype = ser.dtype.pyarrow_dtype.value_type
             if pa.types.is_integer(elem_dtype):
                 return Dtype.intlist
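Taken together, the dtype.py changes add Arrow's plain 'float'/'double' and the 'binary[python]'/'binary[pyarrow]' type names to PANDAS_TO_DTYPE, and widen the list branch of infer_dtype so that fixed-size lists are handled like variable-length lists. A minimal sketch of the new fixed-size-list path (not from the package; it only assumes the import path exposed via infer/__init__.py):

import pandas as pd
import pyarrow as pa

from kumoai.experimental.rfm.infer import infer_dtype

# A fixed-size-list column, e.g. a 2-element integer vector per row.
arr = pa.array([[1, 2], [3, 4]], type=pa.list_(pa.int64(), 2))
ser = pd.Series(arr, dtype=pd.ArrowDtype(arr.type))

# With the change above, pa.types.is_fixed_size_list() also enters the list
# branch, so an integer element type should resolve to Dtype.intlist.
print(infer_dtype(ser))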
kumoai/experimental/rfm/infer/multicategorical.py
@@ -40,7 +40,7 @@ def contains_multicategorical(
     sep = max(candidates, key=candidates.get)  # type: ignore
     ser = ser.str.split(sep)

-    num_unique_multi = ser.explode().nunique()
+    num_unique_multi = ser.astype('object').explode().nunique()

     if dtype.is_list():
         return num_unique_multi <= MAX_CAT
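The only functional change here is the explicit cast to 'object' before exploding, presumably so that explode()/nunique() run on a plain column of Python lists even when the split result is backed by an Arrow extension dtype. A small sketch of the pattern (not from the package; assumes pandas 2.x Arrow-backed strings):

import pandas as pd
import pyarrow as pa

ser = pd.Series(['a,b', 'b,c', 'a'], dtype=pd.ArrowDtype(pa.string()))
parts = ser.str.split(',')

# Casting to object gives a plain Series of Python lists, so the exploded
# values count the individual categories 'a', 'b', 'c' -> 3.
print(parts.astype('object').explode().nunique())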
kumoai/experimental/rfm/infer/pkey.py
@@ -1,6 +1,5 @@
 import re
 import warnings
-from typing import Optional

 import pandas as pd

@@ -9,7 +8,7 @@ def infer_primary_key(
     table_name: str,
     df: pd.DataFrame,
     candidates: list[str],
-) -> Optional[str]:
+) -> str | None:
     r"""Auto-detect potential primary key column.

     Args:
@@ -20,6 +19,9 @@
     Returns:
         The name of the detected primary key, or ``None`` if not found.
     """
+    if len(candidates) == 0:
+        return None
+
     # A list of (potentially modified) table names that are eligible to match
     # with a primary key, i.e.:
     # - UserInfo -> User
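The new guard makes the empty-candidate case explicit: with no candidate columns there is nothing to match, so the function returns None immediately. A hypothetical call (the 'users'/'user_id' names are illustrative, and the second result depends on the table-name heuristic described in the comments above):

import pandas as pd

from kumoai.experimental.rfm.infer import infer_primary_key

df = pd.DataFrame({'user_id': [1, 2, 3], 'age': [20, 30, 40]})

print(infer_primary_key('users', df, candidates=[]))           # None (new guard)
print(infer_primary_key('users', df, candidates=['user_id']))  # likely 'user_id'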
kumoai/experimental/rfm/infer/stype.py
@@ -0,0 +1,35 @@
+import pandas as pd
+from kumoapi.typing import Dtype, Stype
+
+from kumoai.experimental.rfm.infer import (
+    contains_categorical,
+    contains_id,
+    contains_multicategorical,
+    contains_timestamp,
+)
+
+
+def infer_stype(ser: pd.Series, column_name: str, dtype: Dtype) -> Stype:
+    """Infers the :class:`Stype` from a :class:`pandas.Series`.
+
+    Args:
+        ser: A :class:`pandas.Series` to analyze.
+        column_name: The column name.
+        dtype: The data type.
+
+    Returns:
+        The semantic type.
+    """
+    if contains_id(ser, column_name, dtype):
+        return Stype.ID
+
+    if contains_timestamp(ser, column_name, dtype):
+        return Stype.timestamp
+
+    if contains_multicategorical(ser, column_name, dtype):
+        return Stype.multicategorical
+
+    if contains_categorical(ser, column_name, dtype):
+        return Stype.categorical
+
+    return dtype.default_stype
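infer_stype is a new convenience wrapper that runs the existing contains_* heuristics in a fixed precedence order (ID, timestamp, multicategorical, categorical) and falls back to the Dtype's default semantic type. A minimal usage sketch, pairing it with infer_dtype (the column data and name are illustrative, not taken from the package):

import pandas as pd

from kumoai.experimental.rfm.infer import infer_dtype, infer_stype

ser = pd.Series(['2024-01-01', '2024-02-01', '2024-03-01'])
dtype = infer_dtype(ser)

# Expected to resolve to Stype.timestamp via the contains_timestamp check.
print(infer_stype(ser, column_name='created_at', dtype=dtype))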
kumoai/experimental/rfm/infer/time_col.py
@@ -1,6 +1,5 @@
 import re
 import warnings
-from typing import Optional

 import pandas as pd

@@ -8,7 +7,7 @@ import pandas as pd
 def infer_time_column(
     df: pd.DataFrame,
     candidates: list[str],
-) -> Optional[str]:
+) -> str | None:
     r"""Auto-detect potential time column.

     Args:
kumoai/experimental/rfm/pquery/executor.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Dict, Generic, Tuple, TypeVar
+from typing import Generic, TypeVar

 from kumoapi.pquery import ValidatedPredictiveQuery
 from kumoapi.pquery.AST import (
@@ -21,82 +21,82 @@ class PQueryExecutor(Generic[TableData, ColumnData, IndexData], ABC):
     def execute_column(
         self,
         column: Column,
-        feat_dict: Dict[str, TableData],
+        feat_dict: dict[str, TableData],
         filter_na: bool = True,
-    ) -> Tuple[ColumnData, IndexData]:
+    ) -> tuple[ColumnData, IndexData]:
         pass

     @abstractmethod
     def execute_aggregation(
         self,
         aggr: Aggregation,
-        feat_dict: Dict[str, TableData],
-        time_dict: Dict[str, ColumnData],
-        batch_dict: Dict[str, IndexData],
+        feat_dict: dict[str, TableData],
+        time_dict: dict[str, ColumnData],
+        batch_dict: dict[str, IndexData],
         anchor_time: ColumnData,
         filter_na: bool = True,
         num_forecasts: int = 1,
-    ) -> Tuple[ColumnData, IndexData]:
+    ) -> tuple[ColumnData, IndexData]:
         pass

     @abstractmethod
     def execute_condition(
         self,
         condition: Condition,
-        feat_dict: Dict[str, TableData],
-        time_dict: Dict[str, ColumnData],
-        batch_dict: Dict[str, IndexData],
+        feat_dict: dict[str, TableData],
+        time_dict: dict[str, ColumnData],
+        batch_dict: dict[str, IndexData],
         anchor_time: ColumnData,
         filter_na: bool = True,
         num_forecasts: int = 1,
-    ) -> Tuple[ColumnData, IndexData]:
+    ) -> tuple[ColumnData, IndexData]:
         pass

     @abstractmethod
     def execute_logical_operation(
         self,
         logical_operation: LogicalOperation,
-        feat_dict: Dict[str, TableData],
-        time_dict: Dict[str, ColumnData],
-        batch_dict: Dict[str, IndexData],
+        feat_dict: dict[str, TableData],
+        time_dict: dict[str, ColumnData],
+        batch_dict: dict[str, IndexData],
         anchor_time: ColumnData,
         filter_na: bool = True,
         num_forecasts: int = 1,
-    ) -> Tuple[ColumnData, IndexData]:
+    ) -> tuple[ColumnData, IndexData]:
         pass

     @abstractmethod
     def execute_join(
         self,
         join: Join,
-        feat_dict: Dict[str, TableData],
-        time_dict: Dict[str, ColumnData],
-        batch_dict: Dict[str, IndexData],
+        feat_dict: dict[str, TableData],
+        time_dict: dict[str, ColumnData],
+        batch_dict: dict[str, IndexData],
         anchor_time: ColumnData,
         filter_na: bool = True,
         num_forecasts: int = 1,
-    ) -> Tuple[ColumnData, IndexData]:
+    ) -> tuple[ColumnData, IndexData]:
         pass

     @abstractmethod
     def execute_filter(
         self,
         filter: Filter,
-        feat_dict: Dict[str, TableData],
-        time_dict: Dict[str, ColumnData],
-        batch_dict: Dict[str, IndexData],
+        feat_dict: dict[str, TableData],
+        time_dict: dict[str, ColumnData],
+        batch_dict: dict[str, IndexData],
         anchor_time: ColumnData,
-    ) -> Tuple[ColumnData, IndexData]:
+    ) -> tuple[ColumnData, IndexData]:
         pass

     @abstractmethod
     def execute(
         self,
         query: ValidatedPredictiveQuery,
-        feat_dict: Dict[str, TableData],
-        time_dict: Dict[str, ColumnData],
-        batch_dict: Dict[str, IndexData],
+        feat_dict: dict[str, TableData],
+        time_dict: dict[str, ColumnData],
+        batch_dict: dict[str, IndexData],
         anchor_time: ColumnData,
         num_forecasts: int = 1,
-    ) -> Tuple[ColumnData, IndexData]:
+    ) -> tuple[ColumnData, IndexData]:
         pass
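All of the executor.py edits are typing modernization: typing.Dict/Tuple (and, in the infer modules above, Optional) are replaced with the built-in generics and the X | None union syntax of PEP 585/604. Both forms are available on the cp313 interpreter this wheel targets; the sketch below is not from the diff and only illustrates that the annotations are interchangeable:

from typing import Dict, Optional, Tuple


def old_style(feat_dict: Dict[str, int]) -> Tuple[int, Optional[str]]:
    ...


def new_style(feat_dict: dict[str, int]) -> tuple[int, str | None]:
    ...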
kumoai/experimental/rfm/pquery/pandas_executor.py
@@ -1,5 +1,3 @@
-from typing import Dict, List, Tuple
-
 import numpy as np
 import pandas as pd
 from kumoapi.pquery import ValidatedPredictiveQuery
@@ -22,9 +20,9 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
     def execute_column(
         self,
         column: Column,
-        feat_dict: Dict[str, pd.DataFrame],
+        feat_dict: dict[str, pd.DataFrame],
         filter_na: bool = True,
-    ) -> Tuple[pd.Series, np.ndarray]:
+    ) -> tuple[pd.Series, np.ndarray]:
         table_name, column_name = column.fqn.split(".")
         if column_name == '*':
             out = pd.Series(np.ones(len(feat_dict[table_name]), dtype='int64'))
@@ -60,7 +58,7 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
         batch: np.ndarray,
         batch_size: int,
         filter_na: bool = True,
-    ) -> Tuple[pd.Series, np.ndarray]:
+    ) -> tuple[pd.Series, np.ndarray]:

         mask = feat.notna()
         feat, batch = feat[mask], batch[mask]
@@ -104,13 +102,13 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
     def execute_aggregation(
         self,
         aggr: Aggregation,
-        feat_dict: Dict[str, pd.DataFrame],
-        time_dict: Dict[str, pd.Series],
-        batch_dict: Dict[str, np.ndarray],
+        feat_dict: dict[str, pd.DataFrame],
+        time_dict: dict[str, pd.Series],
+        batch_dict: dict[str, np.ndarray],
         anchor_time: pd.Series,
         filter_na: bool = True,
         num_forecasts: int = 1,
-    ) -> Tuple[pd.Series, np.ndarray]:
+    ) -> tuple[pd.Series, np.ndarray]:
         target_table = aggr._get_target_column_name().split('.')[0]
         target_batch = batch_dict[target_table]
         target_time = time_dict[target_table]
@@ -131,10 +129,10 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
             filter_na=True,
         )

-        outs: List[pd.Series] = []
-        masks: List[np.ndarray] = []
+        outs: list[pd.Series] = []
+        masks: list[np.ndarray] = []
         for _ in range(num_forecasts):
-            anchor_target_time = anchor_time[target_batch]
+            anchor_target_time = anchor_time.iloc[target_batch]
             anchor_target_time = anchor_target_time.reset_index(drop=True)

             time_filter_mask = (target_time <= anchor_target_time +
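Beyond the Dict/List/Tuple renames, this hunk switches anchor_time[target_batch] to anchor_time.iloc[target_batch], i.e. explicitly positional indexing. A small sketch (not from the package) of why that distinction matters when the Series index is not the default 0..n-1 range:

import numpy as np
import pandas as pd

anchor_time = pd.Series(pd.to_datetime(['2024-01-01', '2024-02-01']),
                        index=[10, 20])
target_batch = np.array([0, 0, 1])

# anchor_time[target_batch] would look up the *labels* 0 and 1 and fail with
# a KeyError; .iloc selects by position and returns rows 0, 0, 1 as intended.
print(anchor_time.iloc[target_batch])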
@@ -226,13 +224,13 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
     def execute_condition(
         self,
         condition: Condition,
-        feat_dict: Dict[str, pd.DataFrame],
-        time_dict: Dict[str, pd.Series],
-        batch_dict: Dict[str, np.ndarray],
+        feat_dict: dict[str, pd.DataFrame],
+        time_dict: dict[str, pd.Series],
+        batch_dict: dict[str, np.ndarray],
         anchor_time: pd.Series,
         filter_na: bool = True,
         num_forecasts: int = 1,
-    ) -> Tuple[pd.Series, np.ndarray]:
+    ) -> tuple[pd.Series, np.ndarray]:
         if num_forecasts > 1:
             raise NotImplementedError("Forecasting not yet implemented for "
                                       "non-regression tasks")
@@ -306,13 +304,13 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
     def execute_logical_operation(
         self,
         logical_operation: LogicalOperation,
-        feat_dict: Dict[str, pd.DataFrame],
-        time_dict: Dict[str, pd.Series],
-        batch_dict: Dict[str, np.ndarray],
+        feat_dict: dict[str, pd.DataFrame],
+        time_dict: dict[str, pd.Series],
+        batch_dict: dict[str, np.ndarray],
         anchor_time: pd.Series,
         filter_na: bool = True,
         num_forecasts: int = 1,
-    ) -> Tuple[pd.Series, np.ndarray]:
+    ) -> tuple[pd.Series, np.ndarray]:
         if num_forecasts > 1:
             raise NotImplementedError("Forecasting not yet implemented for "
                                       "non-regression tasks")
@@ -370,13 +368,13 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
     def execute_join(
         self,
         join: Join,
-        feat_dict: Dict[str, pd.DataFrame],
-        time_dict: Dict[str, pd.Series],
-        batch_dict: Dict[str, np.ndarray],
+        feat_dict: dict[str, pd.DataFrame],
+        time_dict: dict[str, pd.Series],
+        batch_dict: dict[str, np.ndarray],
         anchor_time: pd.Series,
         filter_na: bool = True,
         num_forecasts: int = 1,
-    ) -> Tuple[pd.Series, np.ndarray]:
+    ) -> tuple[pd.Series, np.ndarray]:
         if isinstance(join.rhs_target, Aggregation):
             return self.execute_aggregation(
                 aggr=join.rhs_target,
@@ -393,12 +391,12 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
     def execute_filter(
         self,
         filter: Filter,
-        feat_dict: Dict[str, pd.DataFrame],
-        time_dict: Dict[str, pd.Series],
-        batch_dict: Dict[str, np.ndarray],
+        feat_dict: dict[str, pd.DataFrame],
+        time_dict: dict[str, pd.Series],
+        batch_dict: dict[str, np.ndarray],
         anchor_time: pd.Series,
         filter_na: bool = True,
-    ) -> Tuple[pd.Series, np.ndarray]:
+    ) -> tuple[pd.Series, np.ndarray]:
         out, mask = self.execute_column(
             column=filter.target,
             feat_dict=feat_dict,
@@ -431,12 +429,12 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
     def execute(
         self,
         query: ValidatedPredictiveQuery,
-        feat_dict: Dict[str, pd.DataFrame],
-        time_dict: Dict[str, pd.Series],
-        batch_dict: Dict[str, np.ndarray],
+        feat_dict: dict[str, pd.DataFrame],
+        time_dict: dict[str, pd.Series],
+        batch_dict: dict[str, np.ndarray],
         anchor_time: pd.Series,
         num_forecasts: int = 1,
-    ) -> Tuple[pd.Series, np.ndarray]:
+    ) -> tuple[pd.Series, np.ndarray]:
         if isinstance(query.entity_ast, Column):
             out, mask = self.execute_column(
                 column=query.entity_ast,
kumoai/experimental/rfm/relbench.py
@@ -0,0 +1,76 @@
+import difflib
+import json
+from functools import lru_cache
+from urllib.request import urlopen
+
+import pooch
+import pyarrow as pa
+
+from kumoai.experimental.rfm import Graph
+from kumoai.experimental.rfm.backend.local import LocalTable
+
+PREFIX = 'rel-'
+CACHE_DIR = pooch.os_cache('relbench')
+HASH_URL = ('https://raw.githubusercontent.com/snap-stanford/relbench/main/'
+            'relbench/datasets/hashes.json')
+
+
+@lru_cache
+def get_registry() -> pooch.Pooch:
+    with urlopen(HASH_URL) as r:
+        hashes = json.load(r)
+
+    return pooch.create(
+        path=CACHE_DIR,
+        base_url='https://relbench.stanford.edu/download/',
+        registry=hashes,
+    )
+
+
+def from_relbench(dataset: str, verbose: bool = True) -> Graph:
+    dataset = dataset.lower()
+    if dataset.startswith(PREFIX):
+        dataset = dataset[len(PREFIX):]
+
+    registry = get_registry()
+
+    datasets = [key.split('/')[0][len(PREFIX):] for key in registry.registry]
+    if dataset not in datasets:
+        matches = difflib.get_close_matches(dataset, datasets, n=1)
+        hint = f" Did you mean '{matches[0]}'?" if len(matches) > 0 else ''
+        raise ValueError(f"Unknown RelBench dataset '{dataset}'.{hint} Valid "
+                         f"datasets are {str(datasets)[1:-1]}.")
+
+    registry.fetch(
+        f'{PREFIX}{dataset}/db.zip',
+        processor=pooch.Unzip(extract_dir='.'),
+        progressbar=verbose,
+    )
+
+    graph = Graph(tables=[])
+    edges: list[tuple[str, str, str]] = []
+    for path in (CACHE_DIR / f'{PREFIX}{dataset}' / 'db').glob('*.parquet'):
+        data = pa.parquet.read_table(path)
+        metadata = {
+            key.decode('utf-8'): json.loads(value.decode('utf-8'))
+            for key, value in data.schema.metadata.items()
+            if key in [b"fkey_col_to_pkey_table", b"pkey_col", b"time_col"]
+        }
+
+        table = LocalTable(
+            df=data.to_pandas(),
+            name=path.stem,
+            primary_key=metadata['pkey_col'],
+            time_column=metadata['time_col'],
+        )
+        graph.add_table(table)
+
+        edges.extend([
+            (path.stem, fkey, dst_table)
+            for fkey, dst_table in metadata['fkey_col_to_pkey_table'].items()
+        ])
+
+    for edge in edges:
+        graph.link(*edge)
+
+    return graph
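The new relbench.py module downloads a RelBench database with pooch, builds one LocalTable per parquet file (primary key and time column come from the parquet metadata), and wires up the foreign-key links on the resulting Graph. A hypothetical usage sketch ('rel-f1' is an illustrative dataset name, not something this diff guarantees):

from kumoai.experimental.rfm.relbench import from_relbench

# The 'rel-' prefix is optional; unknown names raise a ValueError with a
# closest-match hint.
graph = from_relbench('rel-f1')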