kumoai 2.14.0.dev202601051732__cp311-cp311-macosx_11_0_arm64.whl → 2.15.0.dev202601141731__cp311-cp311-macosx_11_0_arm64.whl

@@ -1,8 +1,10 @@
 from abc import abstractmethod
+from collections import defaultdict
 from typing import TYPE_CHECKING, Literal
 
 import numpy as np
 import pandas as pd
+from kumoapi.rfm.context import Subgraph
 from kumoapi.typing import Dtype
 
 from kumoai.experimental.rfm.base import (
@@ -11,11 +13,14 @@ from kumoai.experimental.rfm.base import (
     SamplerOutput,
     SourceColumn,
 )
+from kumoai.experimental.rfm.base.mapper import Mapper
 from kumoai.utils import ProgressLogger, quote_ident
 
 if TYPE_CHECKING:
     from kumoai.experimental.rfm import Graph
 
+EdgeType = tuple[str, str, str]
+
 
 class SQLSampler(Sampler):
     def __init__(
@@ -100,11 +105,28 @@ class SQLSampler(Sampler):
         num_neighbors: list[int],
     ) -> SamplerOutput:
 
-        # TODO Add entity time column to `columns_dict`.
+        # Make sure to always include primary key, foreign key and time columns
+        # during data fetching since these are needed for graph traversal:
+        sample_columns_dict: dict[str, set[str]] = {}
+        for table, columns in columns_dict.items():
+            sample_columns = columns | {
+                foreign_key
+                for foreign_key, _ in self.foreign_key_dict[table]
+            }
+            if primary_key := self.primary_key_dict.get(table):
+                sample_columns |= {primary_key}
+            sample_columns_dict[table] = sample_columns
+        if not isinstance(anchor_time, pd.Series):
+            sample_columns_dict[entity_table_name] |= {
+                self.time_column_dict[entity_table_name]
+            }
+
+        # Sample Entity Table #################################################
+
         df, batch = self._by_pkey(
             table_name=entity_table_name,
-            pkey=entity_pkey,
-            columns=columns_dict[entity_table_name],
+            index=entity_pkey,
+            columns=sample_columns_dict[entity_table_name],
         )
         if len(batch) != len(entity_pkey):
             mask = np.ones(len(entity_pkey), dtype=bool)
@@ -113,23 +135,230 @@ class SQLSampler(Sampler):
                 f"{entity_pkey.iloc[mask].tolist()} do not exist "
                 f"in the '{entity_table_name}' table")
 
+        # Make sure that entities are returned in expected order:
         perm = batch.argsort()
         batch = batch[perm]
         df = df.iloc[perm].reset_index(drop=True)
 
+        # Fill 'entity' anchor times with actual values:
         if not isinstance(anchor_time, pd.Series):
             time_column = self.time_column_dict[entity_table_name]
             anchor_time = df[time_column]
+        assert isinstance(anchor_time, pd.Series)
+
+        # Recursive Neighbor Sampling #########################################
+
+        mapper_dict: dict[str, Mapper] = defaultdict(
+            lambda: Mapper(num_examples=len(entity_pkey)))
+        mapper_dict[entity_table_name].add(
+            pkey=df[self.primary_key_dict[entity_table_name]],
+            batch=batch,
+        )
+
+        dfs_dict: dict[str, list[pd.DataFrame]] = defaultdict(list)
+        dfs_dict[entity_table_name].append(df)
+        batches_dict: dict[str, list[np.ndarray]] = defaultdict(list)
+        batches_dict[entity_table_name].append(batch)
+        num_sampled_nodes_dict: dict[str, list[int]] = defaultdict(
+            lambda: [0] * (len(num_neighbors) + 1))
+        num_sampled_nodes_dict[entity_table_name][0] = len(entity_pkey)
+
+        rows_dict: dict[EdgeType, list[np.ndarray]] = defaultdict(list)
+        cols_dict: dict[EdgeType, list[np.ndarray]] = defaultdict(list)
+        num_sampled_edges_dict: dict[EdgeType, list[int]] = defaultdict(
+            lambda: [0] * len(num_neighbors))
+
+        # The start index of data frame slices of the previous hop:
+        offset_dict: dict[str, int] = defaultdict(int)
+
+        for hop, neighbors in enumerate(num_neighbors):
+            if neighbors == 0:
+                break  # Abort early.
+
+            for table in list(num_sampled_nodes_dict.keys()):
+                # Only sample from tables that have been visited in the
+                # previous hop:
+                if num_sampled_nodes_dict[table][hop] == 0:
+                    continue
+
+                # Collect the slices of data sampled in the previous hop
+                # (but maintain only required key information):
+                cols = [fkey for fkey, _ in self.foreign_key_dict[table]]
+                if table in self.primary_key_dict:
+                    cols.append(self.primary_key_dict[table])
+                dfs = [df[cols] for df in dfs_dict[table][offset_dict[table]:]]
+                df = pd.concat(
+                    dfs,
+                    axis=0,
+                    ignore_index=True,
+                ) if len(dfs) > 1 else dfs[0]
+                batches = batches_dict[table][offset_dict[table]:]
+                batch = (np.concatenate(batches)
+                         if len(batches) > 1 else batches[0])
+                offset_dict[table] = len(batches_dict[table])  # Increase.
+
+                pkey: pd.Series | None = None
+                index: np.ndarray | None = None
+                if table in self.primary_key_dict:
+                    pkey = df[self.primary_key_dict[table]]
+                    index = mapper_dict[table].get(pkey, batch)
+
+                # Iterate over foreign keys in the current table:
+                for fkey, dst_table in self.foreign_key_dict[table]:
+                    row = mapper_dict[dst_table].get(df[fkey], batch)
+                    mask = row == -1
+                    if mask.any():
+                        key_df = pd.DataFrame({
+                            'fkey': df[fkey],
+                            'batch': batch,
+                        }).iloc[mask]
+                        # Only maintain unique keys per example:
+                        unique_key_df = key_df.drop_duplicates()
+                        # Fully de-duplicate keys across examples:
+                        code, fkey_index = pd.factorize(unique_key_df['fkey'])
+
+                        _df, _batch = self._by_pkey(
+                            table_name=dst_table,
+                            index=fkey_index,
+                            columns=sample_columns_dict[dst_table],
+                        )  # Ensure result is sorted according to input order:
+                        _df = _df.iloc[_batch.argsort()]
+
+                        # Compute valid entries (without dangling foreign keys)
+                        # in `unique_key_df`:
+                        _mask = np.full(len(fkey_index), fill_value=False)
+                        _mask[_batch] = True
+                        _mask = _mask[code]
+
+                        # Reconstruct unique (key, batch) pairs:
+                        code, _ = pd.factorize(unique_key_df['fkey'][_mask])
+                        _df = _df.iloc[code].reset_index(drop=True)
+                        _batch = unique_key_df['batch'].to_numpy()[_mask]
+
+                        # Register node IDs:
+                        mapper_dict[dst_table].add(
+                            pkey=_df[self.primary_key_dict[dst_table]],
+                            batch=_batch,
+                        )
+                        row[mask] = mapper_dict[dst_table].get(
+                            pkey=key_df['fkey'],
+                            batch=key_df['batch'].to_numpy(),
+                        )  # NOTE `row` may still hold `-1` for dangling fkeys.
+
+                        dfs_dict[dst_table].append(_df)
+                        batches_dict[dst_table].append(_batch)
+                        num_sampled_nodes_dict[dst_table][hop + 1] += (  #
+                            len(_batch))
+
+                    mask = row != -1
+
+                    col = index
+                    if col is None:
+                        start = sum(num_sampled_nodes_dict[table][:hop])
+                        end = sum(num_sampled_nodes_dict[table][:hop + 1])
+                        col = np.arange(start, end)
+
+                    row = row[mask]
+                    col = col[mask]
+
+                    edge_type = (table, fkey, dst_table)
+                    edge_type = Subgraph.rev_edge_type(edge_type)
+                    rows_dict[edge_type].append(row)
+                    cols_dict[edge_type].append(col)
+                    num_sampled_edges_dict[edge_type][hop] = len(col)
+
+                # Iterate over foreign keys that reference the current table:
+                for src_table, fkey in self.rev_foreign_key_dict[table]:
+                    assert pkey is not None and index is not None
+                    _df, _batch = self._by_fkey(
+                        table_name=src_table,
+                        foreign_key=fkey,
+                        index=pkey,
+                        num_neighbors=neighbors,
+                        anchor_time=anchor_time.iloc[batch],
+                        columns=sample_columns_dict[src_table],
+                    )
+
+                    edge_type = (src_table, fkey, table)
+                    cols_dict[edge_type].append(index[_batch])
+                    num_sampled_edges_dict[edge_type][hop] = len(_batch)
+
+                    _batch = batch[_batch]
+                    num_nodes = sum(num_sampled_nodes_dict[src_table])
+                    if src_table in self.primary_key_dict:
+                        _pkey = _df[self.primary_key_dict[src_table]]
+                        mapper_dict[src_table].add(_pkey, _batch)
+                        row = mapper_dict[src_table].get(_pkey, _batch)
+
+                        # Only preserve unknown rows:
+                        mask = row >= num_nodes  # type: ignore
+                        mask[pd.Index(row).duplicated()] = False
+                        _df = _df.iloc[mask]
+                        _batch = _batch[mask]
+                    else:
+                        row = np.arange(num_nodes, num_nodes + len(_batch))
+
+                    rows_dict[edge_type].append(row)
+                    num_sampled_nodes_dict[src_table][hop + 1] += len(_batch)
+
+                    dfs_dict[src_table].append(_df)
+                    batches_dict[src_table].append(_batch)
+
+        # Post-Processing #####################################################
+
+        df_dict = {
+            table:
+            pd.concat(dfs, axis=0, ignore_index=True)
+            if len(dfs) > 1 else dfs[0]
+            for table, dfs in dfs_dict.items()
+        }
+
+        # Only store unique rows in `df` above a certain threshold:
+        inverse_dict: dict[str, np.ndarray] = {}
+        for table, df in df_dict.items():
+            if table not in self.primary_key_dict:
+                continue
+            unique, index, inverse = np.unique(
+                df_dict[table][self.primary_key_dict[table]],
+                return_index=True,
+                return_inverse=True,
+            )
+            if len(df) > 1.05 * len(unique):
+                df_dict[table] = df.iloc[index].reset_index(drop=True)
+                inverse_dict[table] = inverse
+
+        df_dict = {  # Post-filter column set:
+            table: df[list(columns_dict[table])]
+            for table, df in df_dict.items()
+        }
+        batch_dict = {
+            table: np.concatenate(batches) if len(batches) > 1 else batches[0]
+            for table, batches in batches_dict.items()
+        }
+        row_dict = {
+            edge_type: np.concatenate(rows)
+            for edge_type, rows in rows_dict.items()
+        }
+        col_dict = {
+            edge_type: np.concatenate(cols)
+            for edge_type, cols in cols_dict.items()
+        }
+
+        if len(num_sampled_edges_dict) == 0:  # Single table:
+            num_sampled_nodes_dict = {
+                key: value[:1]
+                for key, value in num_sampled_nodes_dict.items()
+            }
 
         return SamplerOutput(
             anchor_time=anchor_time.astype(int).to_numpy(),
-            df_dict={entity_table_name: df},
-            inverse_dict={},
-            batch_dict={entity_table_name: batch},
-            num_sampled_nodes_dict={entity_table_name: [len(batch)]},
-            row_dict={},
-            col_dict={},
-            num_sampled_edges_dict={},
+            df_dict=df_dict,
+            inverse_dict=inverse_dict,
+            batch_dict=batch_dict,
+            num_sampled_nodes_dict=num_sampled_nodes_dict,
+            row_dict=row_dict,
+            col_dict=col_dict,
+            num_sampled_edges_dict=num_sampled_edges_dict,
         )
 
     # Abstract Methods ########################################################
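Note on the `col` computation in the hunk above: for tables without a primary key, local node IDs are assigned contiguously in sampling order, so the slice of IDs belonging to the hop currently being expanded can be recovered from the per-hop counts alone. A small illustration with made-up counts:

    import numpy as np

    # Hypothetical per-hop node counts for one table (entry 0 = seed entities):
    num_sampled_nodes = [3, 4, 6]

    hop = 1  # Currently expanding the nodes discovered at hop 1.
    start = sum(num_sampled_nodes[:hop])    # 3
    end = sum(num_sampled_nodes[:hop + 1])  # 7
    col = np.arange(start, end)             # array([3, 4, 5, 6])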
@@ -138,7 +367,19 @@ class SQLSampler(Sampler):
     def _by_pkey(
         self,
         table_name: str,
-        pkey: pd.Series,
+        index: pd.Series,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        pass
+
+    @abstractmethod
+    def _by_fkey(
+        self,
+        table_name: str,
+        foreign_key: str,
+        index: pd.Series,
+        num_neighbors: int,
+        anchor_time: pd.Series | None,
         columns: set[str],
     ) -> tuple[pd.DataFrame, np.ndarray]:
         pass
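The recursive sampler keys all of its bookkeeping on one `Mapper` per table, imported above from `kumoai.experimental.rfm.base.mapper`. The `Mapper` implementation itself is not part of this diff; the sketch below is a hypothetical reconstruction of the contract the sampling loop relies on: `add()` registers (primary key, example) pairs and assigns contiguous local node IDs, and `get()` returns those IDs, or `-1` for pairs that were never registered (e.g., dangling foreign keys).

    import numpy as np
    import pandas as pd

    class Mapper:
        """Hypothetical sketch: maps (pkey, example) pairs to local node IDs."""
        def __init__(self, num_examples: int) -> None:
            self.num_examples = num_examples
            self._id_dict: dict[tuple, int] = {}

        def add(self, pkey: pd.Series, batch: np.ndarray) -> None:
            # Assign the next free local node ID to each unseen pair:
            for key, example in zip(pkey.tolist(), batch.tolist()):
                self._id_dict.setdefault((key, example), len(self._id_dict))

        def get(self, pkey: pd.Series, batch: np.ndarray) -> np.ndarray:
            # Unknown pairs map to -1, which the sampler treats as
            # "not yet fetched" (or as dangling after fetching):
            return np.array([
                self._id_dict.get((key, example), -1)
                for key, example in zip(pkey.tolist(), batch.tolist())
            ], dtype=np.int64)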
@@ -1,4 +1,3 @@
-import warnings
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
 from functools import cached_property
@@ -20,6 +19,7 @@ from kumoai.experimental.rfm.base import (
     SourceColumn,
     SourceForeignKey,
 )
+from kumoai.experimental.rfm.base.utils import to_datetime
 from kumoai.experimental.rfm.infer import (
     infer_dtype,
     infer_primary_key,
@@ -399,39 +399,39 @@ class Table(ABC):
         r"""Returns a :class:`pandas.DataFrame` object containing metadata
         information about the columns in this table.
 
-        The returned dataframe has columns ``name``, ``dtype``, ``stype``,
-        ``is_primary_key``, ``is_time_column`` and ``is_end_time_column``,
-        which provide an aggregate view of the properties of the columns of
-        this table.
+        The returned dataframe has columns ``"Name"``, ``"Data Type"``,
+        ``"Semantic Type"``, ``"Primary Key"``, ``"Time Column"`` and
+        ``"End Time Column"``, which provide an aggregated view of the
+        properties of the columns of this table.
 
         Example:
             >>> # doctest: +SKIP
            >>> import kumoai.experimental.rfm as rfm
            >>> table = rfm.LocalTable(df=..., name=...).infer_metadata()
            >>> table.metadata
-                     name    dtype stype  is_primary_key  is_time_column  is_end_time_column
-            0  CustomerID  float64    ID            True           False               False
+                     Name Data Type Semantic Type  Primary Key  Time Column  End Time Column
+            0  CustomerID   float64            ID         True        False            False
        """  # noqa: E501
        cols = self.columns
 
        return pd.DataFrame({
-            'name':
+            'Name':
            pd.Series(dtype=str, data=[c.name for c in cols]),
-            'dtype':
+            'Data Type':
            pd.Series(dtype=str, data=[c.dtype for c in cols]),
-            'stype':
+            'Semantic Type':
            pd.Series(dtype=str, data=[c.stype for c in cols]),
-            'is_primary_key':
+            'Primary Key':
            pd.Series(
                dtype=bool,
                data=[self._primary_key == c.name for c in cols],
            ),
-            'is_time_column':
+            'Time Column':
            pd.Series(
                dtype=bool,
                data=[self._time_column == c.name for c in cols],
            ),
-            'is_end_time_column':
+            'End Time Column':
            pd.Series(
                dtype=bool,
                data=[self._end_time_column == c.name for c in cols],
@@ -623,20 +623,6 @@ class Table(ABC):
         r"""Sanitizes a :class:`pandas.DataFrame` in-place such that its data
         types match table data and semantic type specification.
         """
-        def _to_datetime(ser: pd.Series) -> pd.Series:
-            if not pd.api.types.is_datetime64_any_dtype(ser):
-                with warnings.catch_warnings():
-                    warnings.filterwarnings(
-                        'ignore',
-                        message='Could not infer format',
-                    )
-                    ser = pd.to_datetime(ser, errors='coerce')
-            if isinstance(ser.dtype, pd.DatetimeTZDtype):
-                ser = ser.dt.tz_localize(None)
-            if ser.dtype != 'datetime64[ns]':
-                ser = ser.astype('datetime64[ns]')
-            return ser
-
         def _to_list(ser: pd.Series, dtype: Dtype | None) -> pd.Series:
             if (pd.api.types.is_string_dtype(ser)
                     and dtype in {Dtype.intlist, Dtype.floatlist}):
@@ -667,9 +653,9 @@ class Table(ABC):
         stype = (stype_dict or {}).get(column_name)
 
         if dtype == Dtype.time:
-            df[column_name] = _to_datetime(df[column_name])
+            df[column_name] = to_datetime(df[column_name])
         elif stype == Stype.timestamp:
-            df[column_name] = _to_datetime(df[column_name])
+            df[column_name] = to_datetime(df[column_name])
         elif dtype is not None and dtype.is_list():
             df[column_name] = _to_list(df[column_name], dtype)
         elif stype == Stype.sequence:
@@ -0,0 +1,36 @@
+import warnings
+
+import pandas as pd
+import pyarrow as pa
+
+
+def is_datetime(ser: pd.Series) -> bool:
+    r"""Check whether a :class:`pandas.Series` holds datetime values."""
+    if isinstance(ser.dtype, pd.ArrowDtype):
+        dtype = ser.dtype.pyarrow_dtype
+        return (pa.types.is_timestamp(dtype) or pa.types.is_date(dtype)
+                or pa.types.is_time(dtype))
+
+    return pd.api.types.is_datetime64_any_dtype(ser)
+
+
+def to_datetime(ser: pd.Series) -> pd.Series:
+    """Converts a :class:`pandas.Series` to ``datetime64[ns]`` format."""
+    if isinstance(ser.dtype, pd.ArrowDtype):
+        ser = pd.Series(ser.to_numpy(), index=ser.index, name=ser.name)
+
+    if not pd.api.types.is_datetime64_any_dtype(ser):
+        with warnings.catch_warnings():
+            warnings.filterwarnings(
+                'ignore',
+                message='Could not infer format',
+            )
+            ser = pd.to_datetime(ser, errors='coerce')
+
+    if isinstance(ser.dtype, pd.DatetimeTZDtype):
+        ser = ser.dt.tz_localize(None)
+
+    if ser.dtype != 'datetime64[ns]':
+        ser = ser.astype('datetime64[ns]')
+
+    return ser
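A brief usage sketch of the two new helpers (the import path matches the one used elsewhere in this diff; the sample values are made up):

    import pandas as pd
    import pyarrow as pa

    from kumoai.experimental.rfm.base.utils import is_datetime, to_datetime

    # String columns are coerced; unparsable entries become NaT:
    ser = pd.Series(['2024-01-01', 'not a date'])
    print(to_datetime(ser).dtype)  # datetime64[ns]

    # Arrow-backed timestamp columns are detected and normalized as well:
    arrow = pd.Series(pd.to_datetime(['2024-01-01'])).astype(
        pd.ArrowDtype(pa.timestamp('ms')))
    print(is_datetime(arrow))        # True
    print(to_datetime(arrow).dtype)  # datetime64[ns]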
@@ -649,28 +649,28 @@ class Graph:
         r"""Returns a :class:`pandas.DataFrame` object containing metadata
         information about the tables in this graph.
 
-        The returned dataframe has columns ``name``, ``primary_key``,
-        ``time_column``, and ``end_time_column``, which provide an aggregate
-        view of the properties of the tables of this graph.
+        The returned dataframe has columns ``"Name"``, ``"Primary Key"``,
+        ``"Time Column"``, and ``"End Time Column"``, which provide an
+        aggregated view of the properties of the tables of this graph.
 
         Example:
            >>> # doctest: +SKIP
            >>> import kumoai.experimental.rfm as rfm
            >>> graph = rfm.Graph(tables=...).infer_metadata()
            >>> graph.metadata  # doctest: +SKIP
-                name primary_key time_column end_time_column
-            0  users     user_id           -               -
+                Name Primary Key Time Column End Time Column
+            0  users     user_id           -               -
        """
        tables = list(self.tables.values())
 
        return pd.DataFrame({
-            'name':
+            'Name':
            pd.Series(dtype=str, data=[t.name for t in tables]),
-            'primary_key':
+            'Primary Key':
            pd.Series(dtype=str, data=[t._primary_key or '-' for t in tables]),
-            'time_column':
+            'Time Column':
            pd.Series(dtype=str, data=[t._time_column or '-' for t in tables]),
-            'end_time_column':
+            'End Time Column':
            pd.Series(
                dtype=str,
                data=[t._end_time_column or '-' for t in tables],
@@ -3,6 +3,8 @@ import pandas as pd
 import pyarrow as pa
 from kumoapi.typing import Dtype
 
+from kumoai.experimental.rfm.base.utils import is_datetime
+
 PANDAS_TO_DTYPE: dict[str, Dtype] = {
     'bool': Dtype.bool,
     'boolean': Dtype.bool,
@@ -34,7 +36,7 @@ def infer_dtype(ser: pd.Series) -> Dtype:
     Returns:
         The data type.
     """
-    if pd.api.types.is_datetime64_any_dtype(ser.dtype):
+    if is_datetime(ser):
         return Dtype.date
     if pd.api.types.is_timedelta64_dtype(ser.dtype):
         return Dtype.timedelta
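The practical effect of the `is_datetime` switch: Arrow-backed date and timestamp columns are now classified as `Dtype.date`, whereas `pd.api.types.is_datetime64_any_dtype` only recognizes NumPy datetime dtypes. A sketch (import path as shown in the diff above; sample data made up):

    import datetime

    import pandas as pd
    import pyarrow as pa
    from kumoai.experimental.rfm.infer import infer_dtype

    ser = pd.Series([datetime.date(2024, 1, 1)],
                    dtype=pd.ArrowDtype(pa.date32()))

    pd.api.types.is_datetime64_any_dtype(ser.dtype)  # False: Arrow, not NumPy.
    infer_dtype(ser)                                 # Now returns Dtype.date.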
@@ -3,6 +3,8 @@ import warnings
 
 import pandas as pd
 
+from kumoai.experimental.rfm.base.utils import to_datetime
+
 
 def infer_time_column(
     df: pd.DataFrame,
@@ -43,11 +45,11 @@ def infer_time_column(
     with warnings.catch_warnings():
         warnings.filterwarnings('ignore', message='Could not infer format')
         min_timestamp_dict = {
-            key: pd.to_datetime(df[key].iloc[:10_000], 'coerce')
+            key: to_datetime(df[key].iloc[:10_000])
             for key in candidates
         }
         min_timestamp_dict = {
-            key: value.min().tz_localize(None)
+            key: value.min()
             for key, value in min_timestamp_dict.items()
         }
         min_timestamp_dict = {
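Since `to_datetime` already strips timezone information (see the new `utils` module above), the explicit `.tz_localize(None)` on the per-column minima became redundant. Under that assumption:

    import pandas as pd

    from kumoai.experimental.rfm.base.utils import to_datetime

    ser = pd.Series(['2024-01-01T00:00:00+02:00'])
    to_datetime(ser).min()  # Timestamp('2024-01-01 00:00:00'), already tz-naive.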