kumoai 2.8.0.dev202508221830__cp312-cp312-win_amd64.whl → 2.13.0.dev202512041141__cp312-cp312-win_amd64.whl
This diff compares the contents of two publicly released versions of this package, as published to their public registry. It is provided for informational purposes only.
Potentially problematic release. This version of kumoai might be problematic.
- kumoai/__init__.py +22 -11
- kumoai/_version.py +1 -1
- kumoai/client/client.py +17 -16
- kumoai/client/endpoints.py +1 -0
- kumoai/client/rfm.py +37 -8
- kumoai/connector/file_upload_connector.py +94 -85
- kumoai/connector/utils.py +1399 -210
- kumoai/experimental/rfm/__init__.py +164 -46
- kumoai/experimental/rfm/authenticate.py +8 -5
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +38 -0
- kumoai/experimental/rfm/backend/local/table.py +109 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +35 -0
- kumoai/experimental/rfm/backend/snow/table.py +117 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +30 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +101 -0
- kumoai/experimental/rfm/base/__init__.py +10 -0
- kumoai/experimental/rfm/base/column.py +66 -0
- kumoai/experimental/rfm/base/source.py +18 -0
- kumoai/experimental/rfm/base/table.py +545 -0
- kumoai/experimental/rfm/{local_graph.py → graph.py} +413 -144
- kumoai/experimental/rfm/infer/__init__.py +6 -0
- kumoai/experimental/rfm/infer/dtype.py +79 -0
- kumoai/experimental/rfm/infer/pkey.py +126 -0
- kumoai/experimental/rfm/infer/time_col.py +62 -0
- kumoai/experimental/rfm/infer/timestamp.py +7 -4
- kumoai/experimental/rfm/local_graph_sampler.py +58 -11
- kumoai/experimental/rfm/local_graph_store.py +45 -37
- kumoai/experimental/rfm/local_pquery_driver.py +342 -46
- kumoai/experimental/rfm/pquery/__init__.py +4 -4
- kumoai/experimental/rfm/pquery/{backend.py → executor.py} +28 -58
- kumoai/experimental/rfm/pquery/pandas_executor.py +532 -0
- kumoai/experimental/rfm/rfm.py +559 -148
- kumoai/experimental/rfm/sagemaker.py +138 -0
- kumoai/jobs.py +27 -1
- kumoai/kumolib.cp312-win_amd64.pyd +0 -0
- kumoai/pquery/prediction_table.py +5 -3
- kumoai/pquery/training_table.py +5 -3
- kumoai/spcs.py +1 -3
- kumoai/testing/decorators.py +1 -1
- kumoai/trainer/job.py +9 -30
- kumoai/trainer/trainer.py +19 -10
- kumoai/utils/__init__.py +2 -1
- kumoai/utils/progress_logger.py +96 -16
- {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/METADATA +14 -5
- {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/RECORD +49 -36
- kumoai/experimental/rfm/local_table.py +0 -448
- kumoai/experimental/rfm/pquery/pandas_backend.py +0 -437
- kumoai/experimental/rfm/utils.py +0 -347
- {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/WHEEL +0 -0
- {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/infer/__init__.py
@@ -1,9 +1,15 @@
+from .dtype import infer_dtype
+from .pkey import infer_primary_key
+from .time_col import infer_time_column
 from .id import contains_id
 from .timestamp import contains_timestamp
 from .categorical import contains_categorical
 from .multicategorical import contains_multicategorical
 
 __all__ = [
+    'infer_dtype',
+    'infer_primary_key',
+    'infer_time_column',
     'contains_id',
     'contains_timestamp',
    'contains_categorical',
kumoai/experimental/rfm/infer/dtype.py
@@ -0,0 +1,79 @@
+from typing import Dict
+
+import numpy as np
+import pandas as pd
+import pyarrow as pa
+from kumoapi.typing import Dtype
+
+PANDAS_TO_DTYPE: Dict[str, Dtype] = {
+    'bool': Dtype.bool,
+    'boolean': Dtype.bool,
+    'int8': Dtype.int,
+    'int16': Dtype.int,
+    'int32': Dtype.int,
+    'int64': Dtype.int,
+    'float16': Dtype.float,
+    'float32': Dtype.float,
+    'float64': Dtype.float,
+    'object': Dtype.string,
+    'string': Dtype.string,
+    'string[python]': Dtype.string,
+    'string[pyarrow]': Dtype.string,
+    'binary': Dtype.binary,
+}
+
+
+def infer_dtype(ser: pd.Series) -> Dtype:
+    """Extracts the :class:`Dtype` from a :class:`pandas.Series`.
+
+    Args:
+        ser: A :class:`pandas.Series` to analyze.
+
+    Returns:
+        The data type.
+    """
+    if pd.api.types.is_datetime64_any_dtype(ser.dtype):
+        return Dtype.date
+    if pd.api.types.is_timedelta64_dtype(ser.dtype):
+        return Dtype.timedelta
+    if isinstance(ser.dtype, pd.CategoricalDtype):
+        return Dtype.string
+
+    if (pd.api.types.is_object_dtype(ser.dtype)
+            and not isinstance(ser.dtype, pd.ArrowDtype)):
+        index = ser.iloc[:1000].first_valid_index()
+        if index is not None and pd.api.types.is_list_like(ser[index]):
+            pos = ser.index.get_loc(index)
+            assert isinstance(pos, int)
+            ser = ser.iloc[pos:pos + 1000].dropna()
+            arr = pa.array(ser.tolist())
+            ser = pd.Series(arr, dtype=pd.ArrowDtype(arr.type))
+
+    if isinstance(ser.dtype, pd.ArrowDtype):
+        if pa.types.is_list(ser.dtype.pyarrow_dtype):
+            elem_dtype = ser.dtype.pyarrow_dtype.value_type
+            if pa.types.is_integer(elem_dtype):
+                return Dtype.intlist
+            if pa.types.is_floating(elem_dtype):
+                return Dtype.floatlist
+            if pa.types.is_decimal(elem_dtype):
+                return Dtype.floatlist
+            if pa.types.is_string(elem_dtype):
+                return Dtype.stringlist
+            if pa.types.is_null(elem_dtype):
+                return Dtype.floatlist
+
+    if isinstance(ser.dtype, np.dtype):
+        dtype_str = str(ser.dtype).lower()
+    elif isinstance(ser.dtype, pd.api.extensions.ExtensionDtype):
+        dtype_str = ser.dtype.name.lower()
+        dtype_str = dtype_str.split('[')[0]  # Remove backend metadata
+    elif isinstance(ser.dtype, pa.DataType):
+        dtype_str = str(ser.dtype).lower()
+    else:
+        dtype_str = 'object'
+
+    if dtype_str not in PANDAS_TO_DTYPE:
+        raise ValueError(f"Unsupported data type '{ser.dtype}'")
+
+    return PANDAS_TO_DTYPE[dtype_str]
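
Together with the __init__.py change above, the new helper is importable from kumoai.experimental.rfm.infer. A short usage sketch based only on the code in this hunk; the expected results follow from PANDAS_TO_DTYPE and the pyarrow list handling shown above:

    import pandas as pd

    from kumoai.experimental.rfm.infer import infer_dtype

    print(infer_dtype(pd.Series([1, 2, 3])))       # Dtype.int
    print(infer_dtype(pd.Series([1.5, 2.5])))      # Dtype.float
    print(infer_dtype(pd.Series([True, False])))   # Dtype.bool

    # Object columns holding lists are re-typed through pyarrow, so
    # integer elements map to a list dtype:
    print(infer_dtype(pd.Series([[1, 2], [3]])))   # Dtype.intlist

    # Unmapped dtypes raise, e.g. complex128:
    # infer_dtype(pd.Series([1 + 2j]))  # ValueError: Unsupported data type
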
kumoai/experimental/rfm/infer/pkey.py
@@ -0,0 +1,126 @@
+import re
+import warnings
+from typing import Optional
+
+import pandas as pd
+
+
+def infer_primary_key(
+    table_name: str,
+    df: pd.DataFrame,
+    candidates: list[str],
+) -> Optional[str]:
+    r"""Auto-detect potential primary key column.
+
+    Args:
+        table_name: The table name.
+        df: The pandas DataFrame to analyze.
+        candidates: A list of potential candidates.
+
+    Returns:
+        The name of the detected primary key, or ``None`` if not found.
+    """
+    # A list of (potentially modified) table names that are eligible to match
+    # with a primary key, i.e.:
+    # - UserInfo -> User
+    # - snakecase <-> camelcase
+    # - camelcase <-> snakecase
+    # - plural <-> singular (users -> user, eligibilities -> eligibility)
+    # - verb -> noun (qualifying -> qualify)
+    _table_names = {table_name}
+    if table_name.lower().endswith('_info'):
+        _table_names.add(table_name[:-5])
+    elif table_name.lower().endswith('info'):
+        _table_names.add(table_name[:-4])
+
+    table_names = set()
+    for _table_name in _table_names:
+        table_names.add(_table_name.lower())
+        snakecase = re.sub(r'(.)([A-Z][a-z]+)', r'\1_\2', _table_name)
+        snakecase = re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', snakecase)
+        table_names.add(snakecase.lower())
+        camelcase = _table_name.replace('_', '')
+        table_names.add(camelcase.lower())
+        if _table_name.lower().endswith('s'):
+            table_names.add(_table_name.lower()[:-1])
+            table_names.add(snakecase.lower()[:-1])
+            table_names.add(camelcase.lower()[:-1])
+        else:
+            table_names.add(_table_name.lower() + 's')
+            table_names.add(snakecase.lower() + 's')
+            table_names.add(camelcase.lower() + 's')
+        if _table_name.lower().endswith('ies'):
+            table_names.add(_table_name.lower()[:-3] + 'y')
+            table_names.add(snakecase.lower()[:-3] + 'y')
+            table_names.add(camelcase.lower()[:-3] + 'y')
+        elif _table_name.lower().endswith('y'):
+            table_names.add(_table_name.lower()[:-1] + 'ies')
+            table_names.add(snakecase.lower()[:-1] + 'ies')
+            table_names.add(camelcase.lower()[:-1] + 'ies')
+        if _table_name.lower().endswith('ing'):
+            table_names.add(_table_name.lower()[:-3])
+            table_names.add(snakecase.lower()[:-3])
+            table_names.add(camelcase.lower()[:-3])
+
+    scores: list[tuple[str, int]] = []
+    for col_name in candidates:
+        col_name_lower = col_name.lower()
+
+        score = 0
+
+        if col_name_lower == 'id':
+            score += 4
+
+        for table_name_lower in table_names:
+
+            if col_name_lower == table_name_lower:
+                score += 4  # USER -> USER
+                break
+
+            for suffix in ['id', 'hash', 'key', 'code', 'uuid']:
+                if not col_name_lower.endswith(suffix):
+                    continue
+
+                if col_name_lower == f'{table_name_lower}_{suffix}':
+                    score += 5  # USER -> USER_ID
+                    break
+
+                if col_name_lower == f'{table_name_lower}{suffix}':
+                    score += 5  # User -> UserId
+                    break
+
+                if col_name_lower.endswith(f'{table_name_lower}_{suffix}'):
+                    score += 2
+
+                if col_name_lower.endswith(f'{table_name_lower}{suffix}'):
+                    score += 2
+
+        # `rel-bench` hard-coding :(
+        if table_name == 'studies' and col_name == 'nct_id':
+            score += 1
+
+        ser = df[col_name].iloc[:1_000_000]
+        score += 3 * (ser.nunique() / len(ser))
+
+        scores.append((col_name, score))
+
+    scores = [x for x in scores if x[-1] >= 4]
+    scores.sort(key=lambda x: x[-1], reverse=True)
+
+    if len(scores) == 0:
+        return None
+
+    if len(scores) == 1:
+        return scores[0][0]
+
+    # In case of multiple candidates, only return one if its score is unique:
+    if scores[0][1] != scores[1][1]:
+        return scores[0][0]
+
+    max_score = max(scores, key=lambda x: x[1])
+    candidates = [col_name for col_name, score in scores if score == max_score]
+    warnings.warn(f"Found multiple potential primary keys in table "
+                  f"'{table_name}': {candidates}. Please specify the primary "
+                  f"key for this table manually.")
+
+    return None
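
A hypothetical call illustrating the scoring heuristic (table and column names are invented): an exact singular-table-plus-suffix match earns 5 points, full uniqueness adds up to 3 more, and only candidates scoring at least 4 survive the cutoff:

    import pandas as pd

    from kumoai.experimental.rfm.infer import infer_primary_key

    df = pd.DataFrame({
        'user_id': [1, 2, 3, 4],    # unique, exact 'user' + '_id' match
        'signup_id': [9, 9, 9, 9],  # repeated values, no table-name match
        'name': ['a', 'b', 'c', 'd'],
    })
    # 'user_id' scores 5 (singularized 'users' -> 'user_id') + 3 (fully
    # unique) = 8, well above the cutoff; 'signup_id' scores only 0.75:
    print(infer_primary_key('users', df, candidates=['user_id', 'signup_id']))
    # 'user_id'
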
kumoai/experimental/rfm/infer/time_col.py
@@ -0,0 +1,62 @@
+import re
+import warnings
+from typing import Optional
+
+import pandas as pd
+
+
+def infer_time_column(
+    df: pd.DataFrame,
+    candidates: list[str],
+) -> Optional[str]:
+    r"""Auto-detect potential time column.
+
+    Args:
+        df: The pandas DataFrame to analyze.
+        candidates: A list of potential candidates.
+
+    Returns:
+        The name of the detected time column, or ``None`` if not found.
+    """
+    candidates = [  # Exclude all candidates with `*last*` in column names:
+        col_name for col_name in candidates
+        if not re.search(r'(^|_)last(_|$)', col_name, re.IGNORECASE)
+    ]
+
+    if len(candidates) == 0:
+        return None
+
+    if len(candidates) == 1:
+        return candidates[0]
+
+    # If there exists a dedicated `create*` column, use it as time column:
+    create_candidates = [
+        candidate for candidate in candidates
+        if candidate.lower().startswith('create')
+    ]
+    if len(create_candidates) == 1:
+        return create_candidates[0]
+    if len(create_candidates) > 1:
+        candidates = create_candidates
+
+    # Find the most optimal time column. Usually, it is the one pointing to
+    # the oldest timestamps:
+    with warnings.catch_warnings():
+        warnings.filterwarnings('ignore', message='Could not infer format')
+        min_timestamp_dict = {
+            key: pd.to_datetime(df[key].iloc[:10_000], 'coerce')
+            for key in candidates
+        }
+    min_timestamp_dict = {
+        key: value.min().tz_localize(None)
+        for key, value in min_timestamp_dict.items()
+    }
+    min_timestamp_dict = {
+        key: value
+        for key, value in min_timestamp_dict.items() if not pd.isna(value)
+    }
+
+    if len(min_timestamp_dict) == 0:
+        return None
+
+    return min(min_timestamp_dict, key=min_timestamp_dict.get)  # type: ignore
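
A sketch of the selection order with invented column names: `*last*` columns are filtered out first, a single remaining `create*` column wins outright, and otherwise the candidate with the oldest minimum timestamp is chosen:

    import pandas as pd

    from kumoai.experimental.rfm.infer import infer_time_column

    df = pd.DataFrame({
        'created_at': pd.to_datetime(['2020-01-01', '2020-06-01']),
        'updated_at': pd.to_datetime(['2021-01-01', '2021-06-01']),
        'last_login': pd.to_datetime(['2022-01-01', '2022-06-01']),
    })
    # 'last_login' is dropped by the `*last*` filter, and 'created_at' is
    # then the single remaining `create*` candidate:
    print(infer_time_column(df, candidates=list(df.columns)))  # 'created_at'
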
kumoai/experimental/rfm/infer/timestamp.py
@@ -2,6 +2,7 @@ import re
 import warnings
 
 import pandas as pd
+from dateutil.parser import UnknownTimezoneWarning
 from kumoapi.typing import Dtype, Stype
 
 
@@ -20,9 +21,7 @@ def contains_timestamp(ser: pd.Series, column_name: str, dtype: Dtype) -> bool:
         column_name,
         re.IGNORECASE,
     )
-
-    if match is not None:
-        return True
+    score = 0.3 if match is not None else 0.0
 
     ser = ser.iloc[:100]
     ser = ser.dropna()
@@ -34,5 +33,9 @@ def contains_timestamp(ser: pd.Series, column_name: str, dtype: Dtype) -> bool:
     ser = ser.astype(str)  # Avoid parsing numbers as unix timestamps.
 
     with warnings.catch_warnings():
+        warnings.simplefilter('ignore', UnknownTimezoneWarning)
         warnings.filterwarnings('ignore', message='Could not infer format')
-
+        mask = pd.to_datetime(ser, errors='coerce').notna()
+    score += int(mask.sum()) / len(mask)
+
+    return score >= 1.0
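
The net effect of these hunks: a timestamp-looking column name no longer short-circuits to True. A name match now contributes a fixed 0.3, the fraction of sampled values that pd.to_datetime can actually parse supplies the rest, and the column qualifies only when the combined score reaches 1.0. A minimal standalone sketch of the new rule (timestamp_score is an illustrative helper, not kumoai API, and it skips the dtype checks and sampling of the real function):

    import pandas as pd


    def timestamp_score(values: list, name_matches: bool) -> float:
        # Mirrors the new scoring in `contains_timestamp` (simplified):
        ser = pd.Series(values).astype(str)
        mask = pd.to_datetime(ser, errors='coerce').notna()
        return (0.3 if name_matches else 0.0) + int(mask.sum()) / len(mask)


    # A 'created_at'-style name needs at least 70% of values to parse:
    print(timestamp_score(['2024-01-01', '2024-02-01', 'n/a', '2024-03-01'],
                          name_matches=True) >= 1.0)   # True (0.3 + 0.75)
    # Without a name match, every sampled value must parse:
    print(timestamp_score(['2024-01-01', 'n/a'],
                          name_matches=False) >= 1.0)  # False (0.5)
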
kumoai/experimental/rfm/local_graph_sampler.py
@@ -1,14 +1,54 @@
+import re
 from typing import Dict, List, Optional, Tuple
 
 import numpy as np
 import pandas as pd
-from kumoapi.model_plan import RunMode
 from kumoapi.rfm.context import EdgeLayout, Link, Subgraph, Table
 from kumoapi.typing import Stype
 
 import kumoai.kumolib as kumolib
 from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
-
+
+PUNCTUATION = re.compile(r"[\'\"\.,\(\)\!\?\;\:]")
+MULTISPACE = re.compile(r"\s+")
+
+
+def normalize_text(
+    ser: pd.Series,
+    max_words: Optional[int] = 50,
+) -> pd.Series:
+    r"""Normalizes text into a list of lower-case words.
+
+    Args:
+        ser: The :class:`pandas.Series` to normalize.
+        max_words: The maximum number of words to return.
+            This will auto-shrink any large text column to avoid blowing up
+            context size.
+    """
+    if len(ser) == 0 or pd.api.types.is_list_like(ser.iloc[0]):
+        return ser
+
+    def normalize_fn(line: str) -> list[str]:
+        line = PUNCTUATION.sub(" ", line)
+        line = re.sub(r"<br\s*/?>", " ", line)  # Handle <br /> or <br>
+        line = MULTISPACE.sub(" ", line)
+        words = line.split()
+        if max_words is not None:
+            words = words[:max_words]
+        return words
+
+    ser = ser.fillna('').astype(str)
+
+    if max_words is not None:
+        # We estimate the number of words as 5 characters + 1 space in an
+        # English text on average. We need this pre-filter here, as word
+        # splitting on a giant text can be very expensive:
+        ser = ser.str[:6 * max_words]
+
+    ser = ser.str.lower()
+    ser = ser.map(normalize_fn)
+
+    return ser
 
 
 class LocalGraphSampler:
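
The new module-level normalize_text helper is self-contained pandas; a small usage sketch whose expected output follows from the regexes shown above:

    import pandas as pd

    from kumoai.experimental.rfm.local_graph_sampler import normalize_text

    ser = pd.Series(["Great movie!<br />Would watch again.", None])
    print(normalize_text(ser).tolist())
    # [['great', 'movie', 'would', 'watch', 'again'], []]
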
kumoai/experimental/rfm/local_graph_sampler.py (continued)
@@ -33,7 +73,6 @@ class LocalGraphSampler:
         entity_table_names: Tuple[str, ...],
         node: np.ndarray,
         time: np.ndarray,
-        run_mode: RunMode,
         num_neighbors: List[int],
         exclude_cols_dict: Dict[str, List[str]],
     ) -> Subgraph:
@@ -92,15 +131,23 @@
             )
             continue
 
-
-
-
-        df = df.iloc[
-
+        row: Optional[np.ndarray] = None
+        if table_name in self._graph_store.end_time_column_dict:
+            # Set end time to NaT for all values greater than anchor time:
+            df = df.iloc[node].reset_index(drop=True)
+            col_name = self._graph_store.end_time_column_dict[table_name]
+            ser = df[col_name]
+            value = ser.astype('datetime64[ns]').astype(int).to_numpy()
+            mask = value > time[batch]
+            df.loc[mask, col_name] = pd.NaT
         else:
-            df
-
-
+            # Only store unique rows in `df` above a certain threshold:
+            unique_node, inverse = np.unique(node, return_inverse=True)
+            if len(node) > 1.05 * len(unique_node):
+                df = df.iloc[unique_node].reset_index(drop=True)
+                row = inverse
+            else:
+                df = df.iloc[node].reset_index(drop=True)
 
         # Filter data frame to minimal set of columns:
         df = df[columns]
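
The deduplication branch leans on the np.unique(..., return_inverse=True) identity: sampled node IDs frequently repeat, so once duplication exceeds the 1.05x threshold the frame stores each unique row only once plus an inverse index that reconstructs the original ordering. A standalone illustration of that identity:

    import numpy as np

    node = np.array([7, 3, 7, 7, 3])  # sampled node IDs with duplicates
    unique_node, inverse = np.unique(node, return_inverse=True)
    # unique_node == [3, 7]; `inverse` re-expands the unique rows back to
    # the original ordering, which is what the sampler keeps as `row`:
    assert (unique_node[inverse] == node).all()
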
kumoai/experimental/rfm/local_graph_store.py
@@ -1,14 +1,13 @@
 import warnings
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import pandas as pd
 from kumoapi.rfm.context import Subgraph
 from kumoapi.typing import Stype
 
-from kumoai.experimental.rfm import
-from kumoai.
-from kumoai.utils import ProgressLogger
+from kumoai.experimental.rfm import Graph, LocalTable
+from kumoai.utils import InteractiveProgressLogger, ProgressLogger
 
 try:
     import torch
@@ -20,13 +19,18 @@ except ImportError:
 class LocalGraphStore:
     def __init__(
         self,
-        graph:
-
-        verbose: bool = True,
+        graph: Graph,
+        verbose: Union[bool, ProgressLogger] = True,
     ) -> None:
 
-
-
+        if not isinstance(verbose, ProgressLogger):
+            verbose = InteractiveProgressLogger(
+                "Materializing graph",
+                verbose=verbose,
+            )
+
+        with verbose as logger:
+            self.df_dict, self.mask_dict = self.sanitize(graph)
             self.stype_dict = self.get_stype_dict(graph)
             logger.log("Sanitized input data")
 
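
The widened verbose parameter is a normalize-then-use pattern: callers may pass a plain flag or an existing ProgressLogger, and a bare flag is wrapped before entering the with block. A generic sketch of the same pattern; the Logger classes below are illustrative stand-ins, not the kumoai implementations:

    from typing import Union


    class Logger:
        def __enter__(self) -> "Logger":
            return self

        def __exit__(self, *exc) -> bool:
            return False

        def log(self, msg: str) -> None:
            print(msg)


    class SilentLogger(Logger):
        def log(self, msg: str) -> None:
            pass


    def materialize(verbose: Union[bool, Logger] = True) -> None:
        # Normalize the flag into a logger before entering the context:
        if not isinstance(verbose, Logger):
            verbose = Logger() if verbose else SilentLogger()
        with verbose as logger:
            logger.log("Sanitized input data")
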
kumoai/experimental/rfm/local_graph_store.py (continued)
@@ -39,6 +43,7 @@
 
         (
             self.time_column_dict,
+            self.end_time_column_dict,
             self.time_dict,
             self.min_time,
             self.max_time,
@@ -98,8 +103,7 @@
 
     def sanitize(
         self,
-        graph:
-        preprocess: bool = False,
+        graph: Graph,
     ) -> Tuple[Dict[str, pd.DataFrame], Dict[str, np.ndarray]]:
         r"""Sanitizes raw data according to table schema definition:
 
@@ -108,17 +112,12 @@
         * drops timezone information from timestamps
         * drops duplicate primary keys
         * removes rows with missing primary keys or time values
-
-        If ``preprocess`` is set to ``True``, it will additionally pre-process
-        data for faster model processing. In particular, it:
-        * tokenizes any text column that is not a foreign key
         """
-        df_dict: Dict[str, pd.DataFrame] = {
-
-
-
-
-        foreign_keys = {(edge.src_table, edge.fkey) for edge in graph.edges}
+        df_dict: Dict[str, pd.DataFrame] = {}
+        for table_name, table in graph.tables.items():
+            assert isinstance(table, LocalTable)
+            df = table._data
+            df_dict[table_name] = df.copy(deep=False).reset_index(drop=True)
 
         mask_dict: Dict[str, np.ndarray] = {}
         for table in graph.tables.values():
@@ -137,12 +136,6 @@
                 ser = ser.dt.tz_localize(None)
             df_dict[table.name][col.name] = ser
 
-            # Normalize text in advance (but exclude foreign keys):
-            if (preprocess and col.stype == Stype.text
-                    and (table.name, col.name) not in foreign_keys):
-                ser = df_dict[table.name][col.name]
-                df_dict[table.name][col.name] = normalize_text(ser)
-
         mask: Optional[np.ndarray] = None
         if table._time_column is not None:
             ser = df_dict[table.name][table._time_column]
@@ -158,7 +151,7 @@
 
         return df_dict, mask_dict
 
-    def get_stype_dict(self, graph:
+    def get_stype_dict(self, graph: Graph) -> Dict[str, Dict[str, Stype]]:
         stype_dict: Dict[str, Dict[str, Stype]] = {}
         foreign_keys = {(edge.src_table, edge.fkey) for edge in graph.edges}
         for table in graph.tables.values():
@@ -173,7 +166,7 @@
 
     def get_pkey_data(
         self,
-        graph:
+        graph: Graph,
     ) -> Tuple[
         Dict[str, str],
         Dict[str, pd.DataFrame],
@@ -195,11 +188,15 @@
                 pkey_map = pkey_map[self.mask_dict[table.name]]
 
             if len(pkey_map) == 0:
-
-
-
-
-
+                error_msg = f"Found no valid rows in table '{table.name}'. "
+                if table.has_time_column():
+                    error_msg += ("Please make sure that there exists valid "
+                                  "non-N/A primary key and time column pairs "
+                                  "in this table.")
+                else:
+                    error_msg += ("Please make sure that there exists valid "
+                                  "non-N/A primary keys in this table.")
+                raise ValueError(error_msg)
 
             pkey_map_dict[table.name] = pkey_map
 
@@ -207,18 +204,23 @@
 
     def get_time_data(
         self,
-        graph:
+        graph: Graph,
     ) -> Tuple[
+        Dict[str, str],
         Dict[str, str],
         Dict[str, np.ndarray],
         pd.Timestamp,
         pd.Timestamp,
     ]:
         time_column_dict: Dict[str, str] = {}
+        end_time_column_dict: Dict[str, str] = {}
         time_dict: Dict[str, np.ndarray] = {}
         min_time = pd.Timestamp.max
         max_time = pd.Timestamp.min
         for table in graph.tables.values():
+            if table._end_time_column is not None:
+                end_time_column_dict[table.name] = table._end_time_column
+
             if table._time_column is None:
                 continue
 
@@ -233,11 +235,17 @@
             min_time = min(min_time, time.min())
             max_time = max(max_time, time.max())
 
-        return
+        return (
+            time_column_dict,
+            end_time_column_dict,
+            time_dict,
+            min_time,
+            max_time,
+        )
 
     def get_csc(
         self,
-        graph:
+        graph: Graph,
     ) -> Tuple[
         Dict[Tuple[str, str, str], np.ndarray],
         Dict[Tuple[str, str, str], np.ndarray],