kumoai 2.13.0.dev202512061731__cp313-cp313-macosx_11_0_arm64.whl → 2.14.0.dev202512151351__cp313-cp313-macosx_11_0_arm64.whl

This diff shows the changes between two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
Files changed (30)
  1. kumoai/_version.py +1 -1
  2. kumoai/client/pquery.py +6 -2
  3. kumoai/experimental/rfm/backend/local/graph_store.py +19 -62
  4. kumoai/experimental/rfm/backend/local/sampler.py +229 -45
  5. kumoai/experimental/rfm/backend/local/table.py +12 -2
  6. kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
  7. kumoai/experimental/rfm/backend/snow/sampler.py +264 -0
  8. kumoai/experimental/rfm/backend/snow/table.py +35 -17
  9. kumoai/experimental/rfm/backend/sqlite/__init__.py +2 -0
  10. kumoai/experimental/rfm/backend/sqlite/sampler.py +354 -0
  11. kumoai/experimental/rfm/backend/sqlite/table.py +36 -11
  12. kumoai/experimental/rfm/base/__init__.py +16 -5
  13. kumoai/experimental/rfm/base/sampler.py +538 -52
  14. kumoai/experimental/rfm/base/source.py +1 -0
  15. kumoai/experimental/rfm/base/sql_sampler.py +56 -0
  16. kumoai/experimental/rfm/base/table.py +12 -1
  17. kumoai/experimental/rfm/graph.py +26 -9
  18. kumoai/experimental/rfm/pquery/pandas_executor.py +1 -1
  19. kumoai/experimental/rfm/rfm.py +214 -151
  20. kumoai/pquery/predictive_query.py +10 -6
  21. kumoai/testing/snow.py +50 -0
  22. kumoai/utils/__init__.py +2 -0
  23. kumoai/utils/sql.py +3 -0
  24. {kumoai-2.13.0.dev202512061731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/METADATA +2 -2
  25. {kumoai-2.13.0.dev202512061731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/RECORD +28 -25
  26. kumoai/experimental/rfm/local_graph_sampler.py +0 -223
  27. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  28. {kumoai-2.13.0.dev202512061731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/WHEEL +0 -0
  29. {kumoai-2.13.0.dev202512061731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/licenses/LICENSE +0 -0
  30. {kumoai-2.13.0.dev202512061731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/base/source.py
@@ -9,6 +9,7 @@ class SourceColumn:
     dtype: Dtype
     is_primary_key: bool
     is_unique_key: bool
+    is_nullable: bool
 
 
 @dataclass
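
The only change to source.py is the new `is_nullable` flag on `SourceColumn`. As a rough illustration of where such a flag comes from, here is a minimal sketch of schema introspection against SQLite's `PRAGMA table_info` output; the `name` field and import path follow the surrounding diff, while the SQL-type-to-`Dtype` mapping is a placeholder:

    import sqlite3

    from kumoai.experimental.rfm.base import SourceColumn

    def introspect_columns(con: sqlite3.Connection, table: str) -> list:
        # PRAGMA table_info yields (cid, name, type, notnull, dflt_value, pk).
        out = []
        for _, name, sql_type, notnull, _, pk in con.execute(
                f'PRAGMA table_info("{table}")'):
            out.append(SourceColumn(
                name=name,
                dtype=sql_type,  # real code would map the SQL type to a Dtype
                is_primary_key=pk > 0,
                is_unique_key=pk > 0,  # simplification; ignores unique indexes
                is_nullable=not notnull and pk == 0,
            ))
        return out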
kumoai/experimental/rfm/base/sql_sampler.py
@@ -0,0 +1,56 @@
+from abc import abstractmethod
+from typing import Literal
+
+import numpy as np
+import pandas as pd
+
+from kumoai.experimental.rfm.base import Sampler, SamplerOutput
+
+
+class SQLSampler(Sampler):
+    def _sample_subgraph(
+        self,
+        entity_table_name: str,
+        entity_pkey: pd.Series,
+        anchor_time: pd.Series | Literal['entity'],
+        columns_dict: dict[str, set[str]],
+        num_neighbors: list[int],
+    ) -> SamplerOutput:
+
+        df, batch = self._by_pkey(
+            table_name=entity_table_name,
+            pkey=entity_pkey,
+            columns=columns_dict[entity_table_name],
+        )
+        if len(batch) != len(entity_pkey):
+            raise KeyError("Invalid primary keys")  # TODO
+
+        perm = batch.argsort()
+        batch = batch[perm]
+        df = df.iloc[perm].reset_index(drop=True)
+
+        if not isinstance(anchor_time, pd.Series):
+            time_column = self.time_column_dict[entity_table_name]
+            anchor_time = df[time_column]
+
+        return SamplerOutput(
+            anchor_time=anchor_time.astype(int).to_numpy(),
+            df_dict={entity_table_name: df},
+            inverse_dict={},
+            batch_dict={entity_table_name: batch},
+            num_sampled_nodes_dict={entity_table_name: [len(batch)]},
+            row_dict={},
+            col_dict={},
+            num_sampled_edges_dict={},
+        )
+
+    # Abstract Methods ########################################################
+
+    @abstractmethod
+    def _by_pkey(
+        self,
+        table_name: str,
+        pkey: pd.Series,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        pass
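
The new `SQLSampler` base class reduces seed lookup to a single hook: `_by_pkey` must return the fetched rows together with, for each row, the position of its primary key in the requested batch; `_sample_subgraph` then sorts by those positions and treats a short batch as invalid keys. A minimal sketch of a SQLite-flavored implementation, assuming a `_connection` attribute and a `pkey_column_dict` mapping that the shipped backend samplers may expose differently (other `Sampler` hooks elided):

    import sqlite3

    import numpy as np
    import pandas as pd

    from kumoai.experimental.rfm.base.sql_sampler import SQLSampler

    class SqliteSamplerSketch(SQLSampler):
        _connection: sqlite3.Connection
        pkey_column_dict: dict[str, str]  # assumed attribute

        def _by_pkey(
            self,
            table_name: str,
            pkey: pd.Series,
            columns: set[str],
        ) -> tuple[pd.DataFrame, np.ndarray]:
            pkey_col = self.pkey_column_dict[table_name]
            cols = ', '.join(f'"{c}"' for c in sorted(columns | {pkey_col}))
            marks = ', '.join('?' * len(pkey))
            df = pd.read_sql(
                f'SELECT {cols} FROM "{table_name}" '
                f'WHERE "{pkey_col}" IN ({marks})',
                self._connection, params=pkey.tolist())
            # Map each returned row back to its position in the requested
            # batch (assumes unique primary keys):
            pos = pd.Series(np.arange(len(pkey)), index=pkey.to_numpy())
            batch = pos[df[pkey_col]].to_numpy()
            return df, batch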
kumoai/experimental/rfm/base/table.py
@@ -11,7 +11,12 @@ from kumoapi.typing import Stype
 from typing_extensions import Self
 
 from kumoai import in_notebook, in_snowflake_notebook
-from kumoai.experimental.rfm.base import Column, SourceColumn, SourceForeignKey
+from kumoai.experimental.rfm.base import (
+    Column,
+    DataBackend,
+    SourceColumn,
+    SourceForeignKey,
+)
 from kumoai.experimental.rfm.infer import (
     contains_categorical,
     contains_id,
@@ -503,6 +508,12 @@ class Table(ABC):
 
     # Abstract Methods ########################################################
 
+    @property
+    @abstractmethod
+    def backend(self) -> DataBackend:
+        r"""The data backend of this table."""
+        pass
+
     @cached_property
     def _source_column_dict(self) -> Dict[str, SourceColumn]:
         return {col.name: col for col in self._get_source_columns()}
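
With `backend` now abstract on `Table`, every concrete table self-reports its `DataBackend`, which the `Graph` changes below rely on to reject mixed-backend graphs. A sketch of the contract; the enum members of `DataBackend` are not shown in this diff, so `DataBackend.SQLITE` is an assumption, and the other abstract `Table` methods are elided:

    from kumoai.experimental.rfm.base import DataBackend, Table

    class MyTable(Table):  # illustrative only
        @property
        def backend(self) -> DataBackend:
            return DataBackend.SQLITE  # hypothetical member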
kumoai/experimental/rfm/graph.py
@@ -13,7 +13,7 @@ from kumoapi.typing import Stype
 from typing_extensions import Self
 
 from kumoai import in_notebook, in_snowflake_notebook
-from kumoai.experimental.rfm import Table
+from kumoai.experimental.rfm.base import DataBackend, Table
 from kumoai.graph import Edge
 from kumoai.mixin import CastMixin
 
@@ -218,10 +218,12 @@ class Graph:
            connect,
        )
 
+        internal_connection = False
        if not isinstance(connection, Connection):
            connection = SqliteConnectionConfig._cast(connection)
            assert isinstance(connection, SqliteConnectionConfig)
            connection = connect(connection.uri, **connection.kwargs)
+            internal_connection = True
        assert isinstance(connection, Connection)
 
        if table_names is None:
@@ -234,6 +236,9 @@ class Graph:
 
        graph = cls(tables, edges=edges or [])
 
+        if internal_connection:
+            graph._connection = connection  # type: ignore
+
        if infer_metadata:
            graph.infer_metadata(False)
 
@@ -394,7 +399,14 @@ class Graph:
 
        return graph
 
-    # Tables ##############################################################
+    # Backend #################################################################
+
+    @property
+    def backend(self) -> DataBackend | None:
+        backends = [table.backend for table in self._tables.values()]
+        return backends[0] if len(backends) > 0 else None
+
+    # Tables ##################################################################
 
    def has_table(self, name: str) -> bool:
        r"""Returns ``True`` if the graph has a table with name ``name``;
@@ -433,13 +445,10 @@ class Graph:
            raise KeyError(f"Cannot add table with name '{table.name}' to "
                           f"this graph; table names must be globally unique.")
 
-        if len(self._tables) > 0:
-            cls = next(iter(self._tables.values())).__class__
-            if table.__class__ != cls:
-                raise ValueError(f"Cannot register a "
-                                 f"'{table.__class__.__name__}' to this "
-                                 f"graph since other tables are of type "
-                                 f"'{cls.__name__}'.")
+        if self.backend is not None and table.backend != self.backend:
+            raise ValueError(f"Cannot register a table with backend "
+                             f"'{table.backend}' to this graph since other "
+                             f"tables have backend '{self.backend}'.")
 
        self._tables[table.name] = table
 
@@ -826,6 +835,10 @@ class Graph:
            raise ValueError("At least one table needs to be added to the "
                             "graph")
 
+        backends = {table.backend for table in self._tables.values()}
+        if len(backends) != 1:
+            raise ValueError("Found multiple table backends in the graph")
+
        for edge in self.edges:
            src_table, fkey, dst_table = edge
 
@@ -1063,3 +1076,7 @@ class Graph:
                f' tables={tables},\n'
                f' edges={edges},\n'
                f')')
+
+    def __del__(self) -> None:
+        if hasattr(self, '_connection'):
+            self._connection.close()
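
Two behavioral changes land in graph.py. First, table registration now compares `DataBackend` values via the new `Graph.backend` property instead of comparing table classes, and validation rejects graphs that mix backends. Second, the SQLite constructor path tracks whether it opened the connection itself (`internal_connection`); only then is the connection stored on the graph and closed in `__del__`, so caller-supplied connections remain caller-owned. A usage sketch of that ownership rule; the `from_sqlite` name and its signature are assumptions inferred from the hunk, not read from it:

    import sqlite3

    from kumoai.experimental.rfm import Graph  # assumed import path

    # URI given: the graph opens (and thus owns) the connection, and
    # closes it when the graph is garbage-collected.
    graph = Graph.from_sqlite('data.db', table_names=['users', 'orders'])

    # Existing connection given: `internal_connection` stays False, so the
    # graph never stores it and the caller must close it.
    con = sqlite3.connect('data.db')
    graph = Graph.from_sqlite(con, table_names=['users', 'orders'])
    con.close()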
kumoai/experimental/rfm/pquery/pandas_executor.py
@@ -134,7 +134,7 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
        outs: List[pd.Series] = []
        masks: List[np.ndarray] = []
        for _ in range(num_forecasts):
-            anchor_target_time = anchor_time[target_batch]
+            anchor_target_time = anchor_time.iloc[target_batch]
            anchor_target_time = anchor_target_time.reset_index(drop=True)
 
            time_filter_mask = (target_time <= anchor_target_time +
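
This one-line executor fix matters whenever `anchor_time` carries a non-default index: plain `[]` with an integer array indexes a pandas Series by label, while `target_batch` holds positions, so `.iloc` makes the positional intent explicit. A self-contained repro of the difference:

    import numpy as np
    import pandas as pd

    anchor_time = pd.Series([10, 20, 30], index=[2, 0, 1])
    target_batch = np.array([0, 0, 2])

    print(anchor_time[target_batch].tolist())       # label-based -> [20, 20, 10]
    print(anchor_time.iloc[target_batch].tolist())  # positional  -> [10, 10, 30]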