kumoai 2.12.0.dev202511031731__cp313-cp313-macosx_11_0_arm64.whl → 2.13.0.dev202512061731__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. kumoai/__init__.py +18 -9
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +9 -13
  4. kumoai/client/endpoints.py +1 -0
  5. kumoai/client/rfm.py +35 -7
  6. kumoai/connector/utils.py +23 -2
  7. kumoai/experimental/rfm/__init__.py +164 -46
  8. kumoai/experimental/rfm/backend/__init__.py +0 -0
  9. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  10. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +20 -30
  11. kumoai/experimental/rfm/backend/local/sampler.py +131 -0
  12. kumoai/experimental/rfm/backend/local/table.py +109 -0
  13. kumoai/experimental/rfm/backend/snow/__init__.py +35 -0
  14. kumoai/experimental/rfm/backend/snow/table.py +117 -0
  15. kumoai/experimental/rfm/backend/sqlite/__init__.py +30 -0
  16. kumoai/experimental/rfm/backend/sqlite/table.py +101 -0
  17. kumoai/experimental/rfm/base/__init__.py +14 -0
  18. kumoai/experimental/rfm/base/column.py +66 -0
  19. kumoai/experimental/rfm/base/sampler.py +287 -0
  20. kumoai/experimental/rfm/base/source.py +18 -0
  21. kumoai/experimental/rfm/{local_table.py → base/table.py} +139 -139
  22. kumoai/experimental/rfm/{local_graph.py → graph.py} +334 -79
  23. kumoai/experimental/rfm/infer/__init__.py +6 -0
  24. kumoai/experimental/rfm/infer/dtype.py +79 -0
  25. kumoai/experimental/rfm/infer/pkey.py +126 -0
  26. kumoai/experimental/rfm/infer/time_col.py +62 -0
  27. kumoai/experimental/rfm/local_graph_sampler.py +43 -4
  28. kumoai/experimental/rfm/local_pquery_driver.py +222 -27
  29. kumoai/experimental/rfm/pquery/__init__.py +0 -4
  30. kumoai/experimental/rfm/pquery/pandas_executor.py +34 -8
  31. kumoai/experimental/rfm/rfm.py +153 -96
  32. kumoai/experimental/rfm/sagemaker.py +138 -0
  33. kumoai/spcs.py +1 -3
  34. kumoai/testing/decorators.py +1 -1
  35. kumoai/utils/progress_logger.py +10 -4
  36. {kumoai-2.12.0.dev202511031731.dist-info → kumoai-2.13.0.dev202512061731.dist-info}/METADATA +12 -2
  37. {kumoai-2.12.0.dev202511031731.dist-info → kumoai-2.13.0.dev202512061731.dist-info}/RECORD +40 -27
  38. kumoai/experimental/rfm/pquery/backend.py +0 -136
  39. kumoai/experimental/rfm/pquery/pandas_backend.py +0 -478
  40. kumoai/experimental/rfm/utils.py +0 -344
  41. {kumoai-2.12.0.dev202511031731.dist-info → kumoai-2.13.0.dev202512061731.dist-info}/WHEEL +0 -0
  42. {kumoai-2.12.0.dev202511031731.dist-info → kumoai-2.13.0.dev202512061731.dist-info}/licenses/LICENSE +0 -0
  43. {kumoai-2.12.0.dev202511031731.dist-info → kumoai-2.13.0.dev202512061731.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,79 @@
1
+ from typing import Dict
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ import pyarrow as pa
6
+ from kumoapi.typing import Dtype
7
+
8
# Lookup table from a lower-cased pandas dtype name to the corresponding
# `Dtype`. Entries are grouped by target type; key order matches the original
# literal.
PANDAS_TO_DTYPE: Dict[str, Dtype] = {
    **dict.fromkeys(['bool', 'boolean'], Dtype.bool),
    **dict.fromkeys(['int8', 'int16', 'int32', 'int64'], Dtype.int),
    **dict.fromkeys(['float16', 'float32', 'float64'], Dtype.float),
    **dict.fromkeys(
        ['object', 'string', 'string[python]', 'string[pyarrow]'],
        Dtype.string,
    ),
    'binary': Dtype.binary,
}
24
+
25
+
26
def infer_dtype(ser: pd.Series) -> Dtype:
    """Extracts the :class:`Dtype` from a :class:`pandas.Series`.

    Resolution order:

    1. Datetime, timedelta and categorical dtypes map directly to
       ``Dtype.date``, ``Dtype.timedelta`` and ``Dtype.string``.
    2. Object columns whose first non-null value (within the first 1000 rows)
       is list-like are converted to a pyarrow array, using up to 1000 rows
       starting at that value, so the element type can be inspected.
    3. Arrow-backed list columns map to ``intlist``/``floatlist``/
       ``stringlist`` depending on the element type.
    4. Everything else is resolved by dtype name through
       ``PANDAS_TO_DTYPE``.

    Args:
        ser: A :class:`pandas.Series` to analyze.

    Returns:
        The data type.

    Raises:
        ValueError: If the dtype name is not present in ``PANDAS_TO_DTYPE``.
    """
    if pd.api.types.is_datetime64_any_dtype(ser.dtype):
        return Dtype.date
    if pd.api.types.is_timedelta64_dtype(ser.dtype):
        return Dtype.timedelta
    if isinstance(ser.dtype, pd.CategoricalDtype):
        return Dtype.string

    if (pd.api.types.is_object_dtype(ser.dtype)
            and not isinstance(ser.dtype, pd.ArrowDtype)):
        # Locate the first non-null value *positionally*. Going through index
        # labels (`first_valid_index()` + `ser[label]` + `index.get_loc()`)
        # breaks on non-unique indices: `ser[label]` then returns a Series
        # (always list-like) and `get_loc` returns a slice/mask instead of an
        # integer position.
        is_valid = ser.iloc[:1000].notna().to_numpy()
        if is_valid.any():
            pos = int(is_valid.argmax())  # Position of first `True`.
            if pd.api.types.is_list_like(ser.iloc[pos]):
                sample = ser.iloc[pos:pos + 1000].dropna()
                arr = pa.array(sample.tolist())
                ser = pd.Series(arr, dtype=pd.ArrowDtype(arr.type))

    if isinstance(ser.dtype, pd.ArrowDtype):
        if pa.types.is_list(ser.dtype.pyarrow_dtype):
            elem_dtype = ser.dtype.pyarrow_dtype.value_type
            if pa.types.is_integer(elem_dtype):
                return Dtype.intlist
            if pa.types.is_floating(elem_dtype):
                return Dtype.floatlist
            if pa.types.is_decimal(elem_dtype):
                return Dtype.floatlist
            if pa.types.is_string(elem_dtype):
                return Dtype.stringlist
            if pa.types.is_null(elem_dtype):
                # Only empty/null lists were sampled; treated as float list.
                return Dtype.floatlist
            # Any other element type falls through to the name lookup below.

    if isinstance(ser.dtype, np.dtype):
        dtype_str = str(ser.dtype).lower()
    elif isinstance(ser.dtype, pd.api.extensions.ExtensionDtype):
        dtype_str = ser.dtype.name.lower()
        dtype_str = dtype_str.split('[')[0]  # Remove backend metadata
    elif isinstance(ser.dtype, pa.DataType):
        dtype_str = str(ser.dtype).lower()
    else:
        dtype_str = 'object'

    if dtype_str not in PANDAS_TO_DTYPE:
        raise ValueError(f"Unsupported data type '{ser.dtype}'")

    return PANDAS_TO_DTYPE[dtype_str]
@@ -0,0 +1,126 @@
1
+ import re
2
+ import warnings
3
+ from typing import Optional
4
+
5
+ import pandas as pd
6
+
7
+
8
def infer_primary_key(
    table_name: str,
    df: pd.DataFrame,
    candidates: list[str],
) -> Optional[str]:
    r"""Auto-detect potential primary key column.

    Every candidate column receives a score made of two parts:

    * a name score, awarded when the column is called ``id``, equals a
      variant of the table name, or is a table name variant followed by one
      of the suffixes ``id``, ``hash``, ``key``, ``code`` or ``uuid``;
    * a uniqueness score of up to 3 points, proportional to the fraction of
      distinct values among the first 1,000,000 rows.

    Candidates scoring below 4 are discarded.

    Args:
        table_name: The table name.
        df: The pandas DataFrame to analyze.
        candidates: A list of potential candidates.

    Returns:
        The name of the detected primary key, or ``None`` if not found. If
        several candidates share the highest score, a warning listing them is
        emitted and ``None`` is returned.
    """
    # A list of (potentially modified) table names that are eligible to match
    # with a primary key, i.e.:
    # - UserInfo -> User
    # - snakecase <-> camelcase
    # - camelcase <-> snakecase
    # - plural <-> singular (users -> user, eligibilities -> eligibility)
    # - verb -> noun (qualifying -> qualify)
    _table_names = {table_name}
    if table_name.lower().endswith('_info'):
        _table_names.add(table_name[:-5])
    elif table_name.lower().endswith('info'):
        _table_names.add(table_name[:-4])

    table_names: set[str] = set()
    for _table_name in _table_names:
        lowered = _table_name.lower()
        snakecase = re.sub(r'(.)([A-Z][a-z]+)', r'\1_\2', _table_name)
        snakecase = re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', snakecase)
        camelcase = _table_name.replace('_', '')
        forms = (lowered, snakecase.lower(), camelcase.lower())

        table_names.update(forms)
        if lowered.endswith('s'):
            table_names.update(form[:-1] for form in forms)
        else:
            table_names.update(form + 's' for form in forms)
        if lowered.endswith('ies'):
            table_names.update(form[:-3] + 'y' for form in forms)
        elif lowered.endswith('y'):
            table_names.update(form[:-1] + 'ies' for form in forms)
        if lowered.endswith('ing'):
            table_names.update(form[:-3] for form in forms)

    scores: list[tuple[str, float]] = []
    for col_name in candidates:
        col_name_lower = col_name.lower()

        score = 0.0

        if col_name_lower == 'id':
            score += 4

        for table_name_lower in table_names:

            if col_name_lower == table_name_lower:
                score += 4  # USER -> USER
                break

            for suffix in ['id', 'hash', 'key', 'code', 'uuid']:
                if not col_name_lower.endswith(suffix):
                    continue

                if col_name_lower == f'{table_name_lower}_{suffix}':
                    score += 5  # USER -> USER_ID
                    break

                if col_name_lower == f'{table_name_lower}{suffix}':
                    score += 5  # User -> UserId
                    break

                if col_name_lower.endswith(f'{table_name_lower}_{suffix}'):
                    score += 2

                if col_name_lower.endswith(f'{table_name_lower}{suffix}'):
                    score += 2

        # `rel-bench` hard-coding :(
        if table_name == 'studies' and col_name == 'nct_id':
            score += 1

        ser = df[col_name].iloc[:1_000_000]
        if len(ser) > 0:  # Avoid `0 / 0` on empty tables.
            score += 3 * (ser.nunique() / len(ser))

        scores.append((col_name, score))

    scores = [x for x in scores if x[-1] >= 4]
    scores.sort(key=lambda x: x[-1], reverse=True)

    if len(scores) == 0:
        return None

    if len(scores) == 1:
        return scores[0][0]

    # In case of multiple candidates, only return one if its score is unique:
    if scores[0][1] != scores[1][1]:
        return scores[0][0]

    # `scores` is sorted in descending order, so the first entry holds the
    # maximum. Compare against the score itself (not the `(name, score)`
    # tuple) so that the warning actually lists the tied columns:
    max_score = scores[0][1]
    tied = [col_name for col_name, score in scores if score == max_score]
    warnings.warn(f"Found multiple potential primary keys in table "
                  f"'{table_name}': {tied}. Please specify the primary "
                  f"key for this table manually.")

    return None
@@ -0,0 +1,62 @@
1
+ import re
2
+ import warnings
3
+ from typing import Optional
4
+
5
+ import pandas as pd
6
+
7
+
8
def infer_time_column(
    df: pd.DataFrame,
    candidates: list[str],
) -> Optional[str]:
    r"""Pick the most plausible time column among ``candidates``.

    Columns containing ``last`` as a separate ``_``-delimited token are
    ignored. If exactly one column remains, or exactly one remaining column
    starts with ``create``, it is returned directly. Otherwise, the column
    whose first 10,000 rows contain the earliest timestamp wins.

    Args:
        df: The pandas DataFrame to analyze.
        candidates: A list of potential candidates.

    Returns:
        The name of the detected time column, or ``None`` if not found.
    """
    # Exclude all candidates with `*last*` in column names:
    remaining = [
        name for name in candidates
        if re.search(r'(^|_)last(_|$)', name, re.IGNORECASE) is None
    ]

    if not remaining:
        return None
    if len(remaining) == 1:
        return remaining[0]

    # A dedicated `create*` column takes precedence over everything else:
    created = [name for name in remaining if name.lower().startswith('create')]
    if len(created) == 1:
        return created[0]
    if created:
        remaining = created

    # Otherwise, prefer the column pointing to the oldest timestamps:
    earliest = {}
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', message='Could not infer format')
        parsed = {
            name: pd.to_datetime(df[name].iloc[:10_000], errors='coerce')
            for name in remaining
        }
        for name, timestamps in parsed.items():
            oldest = timestamps.min().tz_localize(None)
            if not pd.isna(oldest):  # Skip columns without a valid value.
                earliest[name] = oldest

    if not earliest:
        return None

    return min(earliest, key=earliest.get)  # type: ignore
@@ -1,14 +1,54 @@
1
+ import re
1
2
  from typing import Dict, List, Optional, Tuple
2
3
 
3
4
  import numpy as np
4
5
  import pandas as pd
5
- from kumoapi.model_plan import RunMode
6
6
  from kumoapi.rfm.context import EdgeLayout, Link, Subgraph, Table
7
7
  from kumoapi.typing import Stype
8
8
 
9
9
  import kumoai.kumolib as kumolib
10
- from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
11
- from kumoai.experimental.rfm.utils import normalize_text
10
+ from kumoai.experimental.rfm.backend.local import LocalGraphStore
11
+
12
# Characters that are replaced by a space before splitting into words.
PUNCTUATION = re.compile(r"[\'\"\.,\(\)\!\?\;\:]")
# Runs of whitespace that are collapsed into a single space.
MULTISPACE = re.compile(r"\s+")


def normalize_text(
    ser: pd.Series,
    max_words: Optional[int] = 50,
) -> pd.Series:
    r"""Turns each text entry into a list of lower-case words.

    Missing values become empty lists. A series that is empty, or whose first
    element is already list-like, is returned untouched.

    Args:
        ser: The :class:`pandas.Series` to normalize.
        max_words: The maximum number of words to return.
            This will auto-shrink any large text column to avoid blowing up
            context size. ``None`` disables the limit.

    Returns:
        A :class:`pandas.Series` holding one list of words per row.
    """
    if not len(ser) or pd.api.types.is_list_like(ser.iloc[0]):
        return ser

    def to_words(text: str) -> list[str]:
        text = PUNCTUATION.sub(" ", text)
        text = re.sub(r"<br\s*/?>", " ", text)  # Handle <br /> or <br>
        tokens = MULTISPACE.sub(" ", text).split()
        return tokens if max_words is None else tokens[:max_words]

    text_ser = ser.fillna('').astype(str)

    if max_words is not None:
        # Cheap pre-filter: assume an average English word takes 5 characters
        # plus 1 space, so that giant texts are cut down before the (more
        # expensive) word splitting happens:
        text_ser = text_ser.str[:6 * max_words]

    return text_ser.str.lower().map(to_words)
12
52
 
13
53
 
14
54
  class LocalGraphSampler:
@@ -33,7 +73,6 @@ class LocalGraphSampler:
33
73
  entity_table_names: Tuple[str, ...],
34
74
  node: np.ndarray,
35
75
  time: np.ndarray,
36
- run_mode: RunMode,
37
76
  num_neighbors: List[int],
38
77
  exclude_cols_dict: Dict[str, List[str]],
39
78
  ) -> Subgraph: