kumoai 2.14.0.dev202512271732__cp310-cp310-macosx_11_0_arm64.whl → 2.14.0rc2__cp310-cp310-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -59,6 +59,17 @@ class Sampler(ABC):
             self._edge_types.append(edge_type)
             self._edge_types.append(Subgraph.rev_edge_type(edge_type))
 
+        # Source Table -> [(Foreign Key, Destination Table)]
+        self._foreign_key_dict: dict[str, list[tuple[str, str]]] = {}
+        # Destination Table -> [(Source Table, Foreign Key)]
+        self._rev_foreign_key_dict: dict[str, list[tuple[str, str]]] = {}
+        for table in graph.tables.values():
+            self._foreign_key_dict[table.name] = []
+            self._rev_foreign_key_dict[table.name] = []
+        for src_table, fkey, dst_table in graph.edges:
+            self._foreign_key_dict[src_table].append((fkey, dst_table))
+            self._rev_foreign_key_dict[dst_table].append((src_table, fkey))
+
         self._primary_key_dict: dict[str, str] = {
             table.name: table._primary_key
             for table in graph.tables.values()
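To make the new bookkeeping concrete: for a hypothetical two-table schema (the `users`/`orders` names and the edge triple below are illustrative only, not part of the package), the two dictionaries resolve as in this minimal runnable sketch:

# Hypothetical schema: orders.user_id references users.id.
edges = [('orders', 'user_id', 'users')]

foreign_key_dict = {'users': [], 'orders': []}
rev_foreign_key_dict = {'users': [], 'orders': []}
for src_table, fkey, dst_table in edges:
    # Forward: foreign keys held by `src_table` and the tables they point to:
    foreign_key_dict[src_table].append((fkey, dst_table))
    # Reverse: tables referencing `dst_table`, and via which foreign key:
    rev_foreign_key_dict[dst_table].append((src_table, fkey))

assert foreign_key_dict == {'users': [], 'orders': [('user_id', 'users')]}
assert rev_foreign_key_dict == {'users': [('orders', 'user_id')], 'orders': []}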
@@ -98,6 +109,16 @@ class Sampler(ABC):
         r"""All available edge types in the graph."""
         return self._edge_types
 
+    @property
+    def foreign_key_dict(self) -> dict[str, list[tuple[str, str]]]:
+        r"""The foreign keys for all tables in the graph."""
+        return self._foreign_key_dict
+
+    @property
+    def rev_foreign_key_dict(self) -> dict[str, list[tuple[str, str]]]:
+        r"""The foreign key back references for all tables in the graph."""
+        return self._rev_foreign_key_dict
+
     @property
     def primary_key_dict(self) -> dict[str, str]:
         r"""All available primary keys in the graph."""
@@ -1,8 +1,10 @@
 from abc import abstractmethod
+from collections import defaultdict
 from typing import TYPE_CHECKING, Literal
 
 import numpy as np
 import pandas as pd
+from kumoapi.rfm.context import Subgraph
 from kumoapi.typing import Dtype
 
 from kumoai.experimental.rfm.base import (
@@ -11,11 +13,14 @@ from kumoai.experimental.rfm.base import (
     SamplerOutput,
     SourceColumn,
 )
+from kumoai.experimental.rfm.base.mapper import Mapper
 from kumoai.utils import ProgressLogger, quote_ident
 
 if TYPE_CHECKING:
     from kumoai.experimental.rfm import Graph
 
+EdgeType = tuple[str, str, str]
+
 
 class SQLSampler(Sampler):
     def __init__(
@@ -100,10 +105,28 @@ class SQLSampler(Sampler):
         num_neighbors: list[int],
     ) -> SamplerOutput:
 
+        # Make sure to always include primary key, foreign key and time columns
+        # during data fetching since these are needed for graph traversal:
+        sample_columns_dict: dict[str, set[str]] = {}
+        for table, columns in columns_dict.items():
+            sample_columns = columns | {
+                foreign_key
+                for foreign_key, _ in self.foreign_key_dict[table]
+            }
+            if primary_key := self.primary_key_dict.get(table):
+                sample_columns |= {primary_key}
+            sample_columns_dict[table] = sample_columns
+        if not isinstance(anchor_time, pd.Series):
+            sample_columns_dict[entity_table_name] |= {
+                self.time_column_dict[entity_table_name]
+            }
+
+        # Sample Entity Table #################################################
+
         df, batch = self._by_pkey(
             table_name=entity_table_name,
-            pkey=entity_pkey,
-            columns=columns_dict[entity_table_name],
+            index=entity_pkey,
+            columns=sample_columns_dict[entity_table_name],
         )
         if len(batch) != len(entity_pkey):
             mask = np.ones(len(entity_pkey), dtype=bool)
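The column-set augmentation above is plain set arithmetic over the graph metadata. A self-contained sketch under assumed metadata (all table, key, and column names below are hypothetical):

columns_dict = {'orders': {'amount'}, 'users': {'age'}}
foreign_key_dict = {'orders': [('user_id', 'users')], 'users': []}
primary_key_dict = {'orders': 'order_id', 'users': 'id'}

sample_columns_dict = {}
for table, columns in columns_dict.items():
    # Requested feature columns plus every foreign key of the table:
    sample_columns = columns | {fk for fk, _ in foreign_key_dict[table]}
    # Plus the primary key, if the table has one:
    if primary_key := primary_key_dict.get(table):
        sample_columns |= {primary_key}
    sample_columns_dict[table] = sample_columns

assert sample_columns_dict['orders'] == {'amount', 'user_id', 'order_id'}
assert sample_columns_dict['users'] == {'age', 'id'}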
@@ -112,23 +135,211 @@ class SQLSampler(Sampler):
                 f"{entity_pkey.iloc[mask].tolist()} do not exist "
                 f"in the '{entity_table_name}' table")
 
+        # Make sure that entities are returned in expected order:
         perm = batch.argsort()
         batch = batch[perm]
         df = df.iloc[perm].reset_index(drop=True)
 
+        # Fill 'entity' anchor times with actual values:
         if not isinstance(anchor_time, pd.Series):
             time_column = self.time_column_dict[entity_table_name]
             anchor_time = df[time_column]
+        assert isinstance(anchor_time, pd.Series)
+
+        # Recursive Neighbor Sampling #########################################
+
+        mapper_dict: dict[str, Mapper] = defaultdict(
+            lambda: Mapper(num_examples=len(entity_pkey)))
+        mapper_dict[entity_table_name].add(
+            pkey=df[self.primary_key_dict[entity_table_name]],
+            batch=batch,
+        )
+
+        dfs_dict: dict[str, list[pd.DataFrame]] = defaultdict(list)
+        dfs_dict[entity_table_name].append(df)
+        batches_dict: dict[str, list[np.ndarray]] = defaultdict(list)
+        batches_dict[entity_table_name].append(batch)
+        num_sampled_nodes_dict: dict[str, list[int]] = defaultdict(
+            lambda: [0] * (len(num_neighbors) + 1))
+        num_sampled_nodes_dict[entity_table_name][0] = len(entity_pkey)
+
+        rows_dict: dict[EdgeType, list[np.ndarray]] = defaultdict(list)
+        cols_dict: dict[EdgeType, list[np.ndarray]] = defaultdict(list)
+        num_sampled_edges_dict: dict[EdgeType, list[int]] = defaultdict(
+            lambda: [0] * len(num_neighbors))
+
+        # The start index of data frame slices of the previous hop:
+        offset_dict: dict[str, int] = defaultdict(int)
+
+        for hop, neighbors in enumerate(num_neighbors):
+            if neighbors == 0:
+                break  # Abort early.
+
+            for table in list(num_sampled_nodes_dict.keys()):
+                # Only sample from tables that have been visited in the
+                # previous hop:
+                if num_sampled_nodes_dict[table][hop] == 0:
+                    continue
+
+                # Collect the slices of data sampled in the previous hop
+                # (but maintain only required key information):
+                cols = [fkey for fkey, _ in self.foreign_key_dict[table]]
+                if table in self.primary_key_dict:
+                    cols.append(self.primary_key_dict[table])
+                dfs = [df[cols] for df in dfs_dict[table][offset_dict[table]:]]
+                df = pd.concat(
+                    dfs,
+                    axis=0,
+                    ignore_index=True,
+                ) if len(dfs) > 1 else dfs[0]
+                batches = batches_dict[table][offset_dict[table]:]
+                batch = (np.concatenate(batches)
+                         if len(batches) > 1 else batches[0])
+                offset_dict[table] = len(batches_dict[table])  # Increase.
+
+                pkey: pd.Series | None = None
+                index: np.ndarray | None = None
+                if table in self.primary_key_dict:
+                    pkey = df[self.primary_key_dict[table]]
+                    index = mapper_dict[table].get(pkey, batch)
+
+                # Iterate over foreign keys in the current table:
+                for fkey, dst_table in self.foreign_key_dict[table]:
+                    row = mapper_dict[dst_table].get(df[fkey], batch)
+                    mask = row == -1
+                    if mask.any():
+                        key_df = pd.DataFrame({
+                            'fkey': df[fkey],
+                            'batch': batch,
+                        }).iloc[mask]
+                        # Only maintain unique keys per example:
+                        unique_key_df = key_df.drop_duplicates()
+                        # Fully de-duplicate keys across examples:
+                        code, fkey_index = pd.factorize(unique_key_df['fkey'])
+
+                        _df, _batch = self._by_pkey(
+                            table_name=dst_table,
+                            index=fkey_index,
+                            columns=sample_columns_dict[dst_table],
+                        )  # Ensure result is sorted according to input order:
+                        _df = _df.iloc[_batch.argsort()]
+
+                        # Compute valid entries (without dangling foreign keys)
+                        # in `unique_key_df`:
+                        _mask = np.full(len(fkey_index), fill_value=False)
+                        _mask[_batch] = True
+                        _mask = _mask[code]
+
+                        # Reconstruct unique (key, batch) pairs:
+                        code, _ = pd.factorize(unique_key_df['fkey'][_mask])
+                        _df = _df.iloc[code].reset_index(drop=True)
+                        _batch = unique_key_df['batch'].to_numpy()[_mask]
+
+                        # Register node IDs:
+                        mapper_dict[dst_table].add(_df[fkey], _batch)
+                        row[mask] = mapper_dict[dst_table].get(
+                            pkey=key_df['fkey'],
+                            batch=key_df['batch'].to_numpy(),
+                        )  # NOTE `row` may still hold `-1` for dangling fkeys.
+
+                        dfs_dict[dst_table].append(_df)
+                        batches_dict[dst_table].append(_batch)
+                        num_sampled_nodes_dict[dst_table][hop + 1] += (  #
+                            len(_batch))
+
+                    mask = row != -1
+
+                    col = index
+                    if col is None:
+                        num_nodes = num_sampled_nodes_dict[table][hop]
+                        col = np.arange(num_nodes, num_nodes + len(row))
+
+                    row = row[mask]
+                    col = col[mask]
+
+                    edge_type = (table, fkey, dst_table)
+                    edge_type = Subgraph.rev_edge_type(edge_type)
+                    rows_dict[edge_type].append(row)
+                    cols_dict[edge_type].append(col)
+                    num_sampled_edges_dict[edge_type][hop] = len(col)
+
+                # Iterate over foreign keys that reference the current table:
+                for src_table, fkey in self.rev_foreign_key_dict[table]:
+                    assert pkey is not None and index is not None
+                    _df, _batch = self._by_fkey(
+                        table_name=src_table,
+                        foreign_key=fkey,
+                        index=pkey,
+                        num_neighbors=neighbors,
+                        anchor_time=anchor_time.iloc[batch],
+                        columns=sample_columns_dict[src_table],
+                    )
+
+                    edge_type = (src_table, fkey, table)
+                    cols_dict[edge_type].append(index[_batch])
+                    num_sampled_edges_dict[edge_type][hop] = len(_batch)
+
+                    _batch = batch[_batch]
+                    num_nodes = sum(num_sampled_nodes_dict[src_table])
+                    if src_table in self.primary_key_dict:
+                        _pkey = _df[self.primary_key_dict[src_table]]
+                        mapper_dict[src_table].add(_pkey, _batch)
+                        row = mapper_dict[src_table].get(_pkey, _batch)
+
+                        # Only preserve unknown rows:
+                        mask = row >= num_nodes  # type: ignore
+                        mask[pd.Index(row).duplicated()] = False
+                        _df = _df.iloc[mask]
+                        _batch = _batch[mask]
+                    else:
+                        row = np.arange(num_nodes, num_nodes + len(_batch))
+
+                    rows_dict[edge_type].append(row)
+                    num_sampled_nodes_dict[src_table][hop + 1] += len(_batch)
+
+                    dfs_dict[src_table].append(_df)
+                    batches_dict[src_table].append(_batch)
+
+        # Post-Processing #####################################################
+
+        df_dict = {
+            table:
+            pd.concat(dfs, axis=0, ignore_index=True)
+            if len(dfs) > 1 else dfs[0]
+            for table, dfs in dfs_dict.items()
+        }
+        df_dict = {  # Post-filter column set:
+            table: df[list(columns_dict[table])]
+            for table, df in df_dict.items()
+        }
+        batch_dict = {
+            table: np.concatenate(batches) if len(batches) > 1 else batches[0]
+            for table, batches in batches_dict.items()
+        }
+        row_dict = {
+            edge_type: np.concatenate(rows)
+            for edge_type, rows in rows_dict.items()
+        }
+        col_dict = {
+            edge_type: np.concatenate(cols)
+            for edge_type, cols in cols_dict.items()
+        }
+
+        if len(num_sampled_edges_dict) == 0:  # Single table:
+            num_sampled_nodes_dict = {
+                key: value[:1]
+                for key, value in num_sampled_nodes_dict.items()
+            }
 
         return SamplerOutput(
             anchor_time=anchor_time.astype(int).to_numpy(),
-            df_dict={entity_table_name: df},
-            inverse_dict={},
-            batch_dict={entity_table_name: batch},
-            num_sampled_nodes_dict={entity_table_name: [len(batch)]},
-            row_dict={},
-            col_dict={},
-            num_sampled_edges_dict={},
+            df_dict=df_dict,
+            inverse_dict={},  # TODO
+            batch_dict=batch_dict,
+            num_sampled_nodes_dict=num_sampled_nodes_dict,
+            row_dict=row_dict,
+            col_dict=col_dict,
+            num_sampled_edges_dict=num_sampled_edges_dict,
         )
 
     # Abstract Methods ########################################################
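The loop above relies on `Mapper` (imported from `kumoai.experimental.rfm.base.mapper`) to translate `(primary key, example)` pairs into contiguous local node IDs, returning `-1` for unseen pairs; that contract is what the `row == -1` check and the later `row[mask] = ...` assignment exploit. The real class is not shown in this diff; the following is only a minimal dictionary-based stand-in for the `add`/`get` interface used here:

import numpy as np
import pandas as pd

class ToyMapper:
    """Dictionary-backed sketch of the assumed Mapper contract."""
    def __init__(self, num_examples: int) -> None:
        self.num_examples = num_examples
        self._id_dict: dict[tuple, int] = {}

    def add(self, pkey: pd.Series, batch: np.ndarray) -> None:
        # Assign the next free node ID to every unseen (key, example) pair:
        for key, example in zip(pkey.tolist(), batch.tolist()):
            self._id_dict.setdefault((key, example), len(self._id_dict))

    def get(self, pkey: pd.Series, batch: np.ndarray) -> np.ndarray:
        # Look up node IDs; unknown pairs map to -1 (dangling keys):
        return np.array([
            self._id_dict.get((key, example), -1)
            for key, example in zip(pkey.tolist(), batch.tolist())
        ])

mapper = ToyMapper(num_examples=2)
mapper.add(pkey=pd.Series([10, 20]), batch=np.array([0, 1]))
assert mapper.get(pd.Series([20, 30]), np.array([1, 1])).tolist() == [1, -1]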
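The de-duplication inside the hop loop leans on `pd.factorize`: `fkey_index` holds each key exactly once (so every key is fetched from the backend at most once), while `code` maps each unique `(key, example)` row back to its position in `fkey_index`. A standalone illustration with made-up values:

import pandas as pd

key_df = pd.DataFrame({
    'fkey': [7, 7, 9, 7],
    'batch': [0, 0, 1, 1],
})
# Only maintain unique keys per example:
unique_key_df = key_df.drop_duplicates()  # rows (7, 0), (9, 1), (7, 1)
# Fully de-duplicate keys across examples:
code, fkey_index = pd.factorize(unique_key_df['fkey'])
assert fkey_index.tolist() == [7, 9]  # each key hits the backend once
assert code.tolist() == [0, 1, 0]     # unique rows -> fetched key positions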
@@ -137,7 +348,19 @@ class SQLSampler(Sampler):
     def _by_pkey(
         self,
         table_name: str,
-        pkey: pd.Series,
+        index: pd.Series,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        pass
+
+    @abstractmethod
+    def _by_fkey(
+        self,
+        table_name: str,
+        foreign_key: str,
+        index: pd.Series,
+        num_neighbors: int,
+        anchor_time: pd.Series | None,
         columns: set[str],
     ) -> tuple[pd.DataFrame, np.ndarray]:
         pass
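Concrete subclasses implement both hooks with backend-specific SQL. For orientation, a minimal in-memory sketch of the `_by_pkey` contract as the sampling loop uses it: return the matching rows plus, for each returned row, the position of its key in the input `index`. The `self.tables` store is hypothetical, and `numpy`/`pandas` are assumed imported as in the module above:

    def _by_pkey(
        self,
        table_name: str,
        index: pd.Series,
        columns: set[str],
    ) -> tuple[pd.DataFrame, np.ndarray]:
        # In-memory stand-in for a `SELECT ... WHERE pkey IN (...)` query:
        df = self.tables[table_name]  # hypothetical raw data frame store
        pkey = self.primary_key_dict[table_name]
        pos = pd.Index(df[pkey]).get_indexer(index)  # -1 => missing key
        batch = np.nonzero(pos >= 0)[0]  # input positions that matched
        out = df.iloc[pos[pos >= 0]][list(columns)].reset_index(drop=True)
        return out, batch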
@@ -5,6 +5,7 @@ from functools import cached_property
 
 import numpy as np
 import pandas as pd
+import pyarrow as pa
 from kumoapi.model_plan import MissingType
 from kumoapi.source_table import UnavailableSourceTable
 from kumoapi.table import Column as ColumnDefinition
@@ -12,7 +13,6 @@ from kumoapi.table import TableDefinition
 from kumoapi.typing import Dtype, Stype
 from typing_extensions import Self
 
-from kumoai import in_notebook, in_snowflake_notebook
 from kumoai.experimental.rfm.base import (
     Column,
     ColumnSpec,
@@ -27,7 +27,7 @@ from kumoai.experimental.rfm.infer import (
     infer_stype,
     infer_time_column,
 )
-from kumoai.utils import quote_ident
+from kumoai.utils import display, quote_ident
 
 
 class Table(ABC):
@@ -196,7 +196,7 @@ class Table(ABC):
                 raise RuntimeError(
                     f"Encountered unsupported data type '{ser.dtype}' for "
                     f"column '{column_spec.name}' in table '{self.name}'. "
-                    f"Please either manually specify the columns's data "
+                    f"Please either manually override the column's data "
                     f"type or remove the column from this table.") from e
 
         if stype is None:
@@ -272,8 +272,8 @@ class Table(ABC):
         no such primary key is present.
 
         The setter sets a column as a primary key on this table, and raises a
-        :class:`ValueError` if the primary key has a non-ID semantic type or
-        if the column name does not match a column in the data frame.
+        :class:`ValueError` if the primary key has a non-ID compatible data
+        type or if the column name does not match a column in the data frame.
         """
         if self._primary_key is None:
             return None
@@ -317,8 +317,9 @@ class Table(ABC):
         such time column is present.
 
         The setter sets a column as a time column on this table, and raises a
-        :class:`ValueError` if the time column has a non-timestamp semantic
-        type or if the column name does not match a column in the data frame.
+        :class:`ValueError` if the time column has a non-timestamp compatible
+        data type or if the column name does not match a column in the data
+        frame.
         """
         if self._time_column is None:
             return None
@@ -363,8 +364,8 @@ class Table(ABC):
 
         The setter sets a column as an end time column on this table, and
         raises a :class:`ValueError` if the end time column has a non-timestamp
-        semantic type or if the column name does not match a column in the data
-        frame.
+        compatible data type or if the column name does not match a column in
+        the data frame.
         """
         if self._end_time_column is None:
             return None
@@ -399,39 +400,39 @@ class Table(ABC):
         r"""Returns a :class:`pandas.DataFrame` object containing metadata
         information about the columns in this table.
 
-        The returned dataframe has columns ``name``, ``dtype``, ``stype``,
-        ``is_primary_key``, ``is_time_column`` and ``is_end_time_column``,
-        which provide an aggregate view of the properties of the columns of
-        this table.
+        The returned dataframe has columns ``"Name"``, ``"Data Type"``,
+        ``"Semantic Type"``, ``"Primary Key"``, ``"Time Column"`` and
+        ``"End Time Column"``, which provide an aggregated view of the
+        properties of the columns of this table.
 
         Example:
             >>> # doctest: +SKIP
             >>> import kumoai.experimental.rfm as rfm
            >>> table = rfm.LocalTable(df=..., name=...).infer_metadata()
            >>> table.metadata
-                     name    dtype stype  is_primary_key  is_time_column  is_end_time_column
-            0  CustomerID  float64    ID            True           False               False
+                     Name  Data Type Semantic Type  Primary Key  Time Column  End Time Column
+            0  CustomerID    float64            ID         True        False            False
         """  # noqa: E501
         cols = self.columns
 
         return pd.DataFrame({
-            'name':
+            'Name':
             pd.Series(dtype=str, data=[c.name for c in cols]),
-            'dtype':
+            'Data Type':
             pd.Series(dtype=str, data=[c.dtype for c in cols]),
-            'stype':
+            'Semantic Type':
             pd.Series(dtype=str, data=[c.stype for c in cols]),
-            'is_primary_key':
+            'Primary Key':
             pd.Series(
                 dtype=bool,
                 data=[self._primary_key == c.name for c in cols],
             ),
-            'is_time_column':
+            'Time Column':
             pd.Series(
                 dtype=bool,
                 data=[self._time_column == c.name for c in cols],
             ),
-            'is_end_time_column':
+            'End Time Column':
             pd.Series(
                 dtype=bool,
                 data=[self._end_time_column == c.name for c in cols],
@@ -440,30 +441,12 @@ class Table(ABC):
 
     def print_metadata(self) -> None:
         r"""Prints the :meth:`~metadata` of this table."""
-        num_rows_repr = ''
+        msg = f"🏷️ Metadata of Table `{self.name}`"
         if num := self._num_rows:
-            num_rows_repr = f' ({num} row)' if num == 1 else f' ({num:,} rows)'
-
-        if in_snowflake_notebook():
-            import streamlit as st
-            md_repr = f"### 🏷️ Metadata of Table `{self.name}`{num_rows_repr}"
-            st.markdown(md_repr)
-            st.dataframe(self.metadata, hide_index=True)
-        elif in_notebook():
-            from IPython.display import Markdown, display
-            md_repr = f"### 🏷️ Metadata of Table `{self.name}`{num_rows_repr}"
-            display(Markdown(md_repr))
-            df = self.metadata
-            try:
-                if hasattr(df.style, 'hide'):
-                    display(df.style.hide(axis='index'))  # pandas=2
-                else:
-                    display(df.style.hide_index())  # pandas<1.3
-            except ImportError:
-                print(df.to_string(index=False))  # missing jinja2
-        else:
-            print(f"🏷️ Metadata of Table '{self.name}'{num_rows_repr}")
-            print(self.metadata.to_string(index=False))
+            msg += " (1 row)" if num == 1 else f" ({num:,} rows)"
+
+        display.title(msg)
+        display.dataframe(self.metadata)
 
     def infer_primary_key(self, verbose: bool = True) -> Self:
         r"""Infers the primary key in this table.
@@ -477,8 +460,8 @@ class Table(ABC):
         def _set_primary_key(primary_key: str) -> None:
             self.primary_key = primary_key
             if verbose:
-                print(f"Inferred primary key '{primary_key}' for table "
-                      f"'{self.name}'")
+                display.message(f"Inferred primary key `{primary_key}` for "
+                                f"table `{self.name}`")
 
         # Inference from source column metadata:
         if any(column.is_source for column in self.columns):
@@ -543,8 +526,8 @@ class Table(ABC):
             self.time_column = time_column
 
             if verbose:
-                print(f"Inferred time column '{time_column}' for table "
-                      f"'{self.name}'")
+                display.message(f"Inferred time column `{time_column}` for "
+                                f"table `{self.name}`")
 
         return self
 
@@ -560,15 +543,16 @@ class Table(ABC):
         if not self.has_primary_key():
             self.infer_primary_key(verbose=False)
         if self.has_primary_key():
-            logs.append(f"primary key '{self._primary_key}'")
+            logs.append(f"primary key `{self._primary_key}`")
 
         if not self.has_time_column():
             self.infer_time_column(verbose=False)
         if self.has_time_column():
-            logs.append(f"time column '{self._time_column}'")
+            logs.append(f"time column `{self._time_column}`")
 
         if verbose and len(logs) > 0:
-            print(f"Inferred {' and '.join(logs)} for table '{self.name}'")
+            display.message(f"Inferred {' and '.join(logs)} for table "
+                            f"`{self.name}`")
 
         return self
 
@@ -641,14 +625,18 @@ class Table(ABC):
         types match table data and semantic type specification.
         """
         def _to_datetime(ser: pd.Series) -> pd.Series:
-            if not pd.api.types.is_datetime64_any_dtype(ser):
+            if (not pd.api.types.is_datetime64_any_dtype(ser)
+                    and not (isinstance(ser.dtype, pd.ArrowDtype) and
+                             pa.types.is_timestamp(ser.dtype.pyarrow_dtype))):
                 with warnings.catch_warnings():
                     warnings.filterwarnings(
                         'ignore',
                         message='Could not infer format',
                     )
                     ser = pd.to_datetime(ser, errors='coerce')
-            if isinstance(ser.dtype, pd.DatetimeTZDtype):
+            if (isinstance(ser.dtype, pd.DatetimeTZDtype)
+                    or (isinstance(ser.dtype, pd.ArrowDtype)
+                        and ser.dtype.pyarrow_dtype.tz is not None)):
                 ser = ser.dt.tz_localize(None)
             if ser.dtype != 'datetime64[ns]':
                 ser = ser.astype('datetime64[ns]')
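The two widened conditions make `_to_datetime` recognize pyarrow-backed timestamp columns: such columns must skip the `pd.to_datetime` round-trip but still need their timezone stripped and their dtype normalized. A standalone check of both branches, assuming pandas >= 2.1 with pyarrow installed (example data made up):

import pandas as pd
import pyarrow as pa

ser = pd.Series(pd.to_datetime(['2024-01-01'], utc=True)).astype(
    pd.ArrowDtype(pa.timestamp('us', tz='UTC')))
# Arrow timestamps already count as datetimes, so parsing is skipped:
assert pa.types.is_timestamp(ser.dtype.pyarrow_dtype)
# ...but the timezone still needs to be dropped:
assert ser.dtype.pyarrow_dtype.tz is not None
ser = ser.dt.tz_localize(None).astype('datetime64[ns]')
assert ser.dtype == 'datetime64[ns]'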