kumoai 2.8.0.dev202508221830__cp312-cp312-win_amd64.whl → 2.13.0.dev202512041141__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of kumoai has been flagged as potentially problematic.

Files changed (52)
  1. kumoai/__init__.py +22 -11
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +17 -16
  4. kumoai/client/endpoints.py +1 -0
  5. kumoai/client/rfm.py +37 -8
  6. kumoai/connector/file_upload_connector.py +94 -85
  7. kumoai/connector/utils.py +1399 -210
  8. kumoai/experimental/rfm/__init__.py +164 -46
  9. kumoai/experimental/rfm/authenticate.py +8 -5
  10. kumoai/experimental/rfm/backend/__init__.py +0 -0
  11. kumoai/experimental/rfm/backend/local/__init__.py +38 -0
  12. kumoai/experimental/rfm/backend/local/table.py +109 -0
  13. kumoai/experimental/rfm/backend/snow/__init__.py +35 -0
  14. kumoai/experimental/rfm/backend/snow/table.py +117 -0
  15. kumoai/experimental/rfm/backend/sqlite/__init__.py +30 -0
  16. kumoai/experimental/rfm/backend/sqlite/table.py +101 -0
  17. kumoai/experimental/rfm/base/__init__.py +10 -0
  18. kumoai/experimental/rfm/base/column.py +66 -0
  19. kumoai/experimental/rfm/base/source.py +18 -0
  20. kumoai/experimental/rfm/base/table.py +545 -0
  21. kumoai/experimental/rfm/{local_graph.py → graph.py} +413 -144
  22. kumoai/experimental/rfm/infer/__init__.py +6 -0
  23. kumoai/experimental/rfm/infer/dtype.py +79 -0
  24. kumoai/experimental/rfm/infer/pkey.py +126 -0
  25. kumoai/experimental/rfm/infer/time_col.py +62 -0
  26. kumoai/experimental/rfm/infer/timestamp.py +7 -4
  27. kumoai/experimental/rfm/local_graph_sampler.py +58 -11
  28. kumoai/experimental/rfm/local_graph_store.py +45 -37
  29. kumoai/experimental/rfm/local_pquery_driver.py +342 -46
  30. kumoai/experimental/rfm/pquery/__init__.py +4 -4
  31. kumoai/experimental/rfm/pquery/{backend.py → executor.py} +28 -58
  32. kumoai/experimental/rfm/pquery/pandas_executor.py +532 -0
  33. kumoai/experimental/rfm/rfm.py +559 -148
  34. kumoai/experimental/rfm/sagemaker.py +138 -0
  35. kumoai/jobs.py +27 -1
  36. kumoai/kumolib.cp312-win_amd64.pyd +0 -0
  37. kumoai/pquery/prediction_table.py +5 -3
  38. kumoai/pquery/training_table.py +5 -3
  39. kumoai/spcs.py +1 -3
  40. kumoai/testing/decorators.py +1 -1
  41. kumoai/trainer/job.py +9 -30
  42. kumoai/trainer/trainer.py +19 -10
  43. kumoai/utils/__init__.py +2 -1
  44. kumoai/utils/progress_logger.py +96 -16
  45. {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/METADATA +14 -5
  46. {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/RECORD +49 -36
  47. kumoai/experimental/rfm/local_table.py +0 -448
  48. kumoai/experimental/rfm/pquery/pandas_backend.py +0 -437
  49. kumoai/experimental/rfm/utils.py +0 -347
  50. {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/WHEEL +0 -0
  51. {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/licenses/LICENSE +0 -0
  52. {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/infer/__init__.py
@@ -1,9 +1,15 @@
+from .dtype import infer_dtype
+from .pkey import infer_primary_key
+from .time_col import infer_time_column
 from .id import contains_id
 from .timestamp import contains_timestamp
 from .categorical import contains_categorical
 from .multicategorical import contains_multicategorical
 
 __all__ = [
+    'infer_dtype',
+    'infer_primary_key',
+    'infer_time_column',
     'contains_id',
     'contains_timestamp',
     'contains_categorical',
kumoai/experimental/rfm/infer/dtype.py
@@ -0,0 +1,79 @@
+from typing import Dict
+
+import numpy as np
+import pandas as pd
+import pyarrow as pa
+from kumoapi.typing import Dtype
+
+PANDAS_TO_DTYPE: Dict[str, Dtype] = {
+    'bool': Dtype.bool,
+    'boolean': Dtype.bool,
+    'int8': Dtype.int,
+    'int16': Dtype.int,
+    'int32': Dtype.int,
+    'int64': Dtype.int,
+    'float16': Dtype.float,
+    'float32': Dtype.float,
+    'float64': Dtype.float,
+    'object': Dtype.string,
+    'string': Dtype.string,
+    'string[python]': Dtype.string,
+    'string[pyarrow]': Dtype.string,
+    'binary': Dtype.binary,
+}
+
+
+def infer_dtype(ser: pd.Series) -> Dtype:
+    """Extracts the :class:`Dtype` from a :class:`pandas.Series`.
+
+    Args:
+        ser: A :class:`pandas.Series` to analyze.
+
+    Returns:
+        The data type.
+    """
+    if pd.api.types.is_datetime64_any_dtype(ser.dtype):
+        return Dtype.date
+    if pd.api.types.is_timedelta64_dtype(ser.dtype):
+        return Dtype.timedelta
+    if isinstance(ser.dtype, pd.CategoricalDtype):
+        return Dtype.string
+
+    if (pd.api.types.is_object_dtype(ser.dtype)
+            and not isinstance(ser.dtype, pd.ArrowDtype)):
+        index = ser.iloc[:1000].first_valid_index()
+        if index is not None and pd.api.types.is_list_like(ser[index]):
+            pos = ser.index.get_loc(index)
+            assert isinstance(pos, int)
+            ser = ser.iloc[pos:pos + 1000].dropna()
+            arr = pa.array(ser.tolist())
+            ser = pd.Series(arr, dtype=pd.ArrowDtype(arr.type))
+
+    if isinstance(ser.dtype, pd.ArrowDtype):
+        if pa.types.is_list(ser.dtype.pyarrow_dtype):
+            elem_dtype = ser.dtype.pyarrow_dtype.value_type
+            if pa.types.is_integer(elem_dtype):
+                return Dtype.intlist
+            if pa.types.is_floating(elem_dtype):
+                return Dtype.floatlist
+            if pa.types.is_decimal(elem_dtype):
+                return Dtype.floatlist
+            if pa.types.is_string(elem_dtype):
+                return Dtype.stringlist
+            if pa.types.is_null(elem_dtype):
+                return Dtype.floatlist
+
+    if isinstance(ser.dtype, np.dtype):
+        dtype_str = str(ser.dtype).lower()
+    elif isinstance(ser.dtype, pd.api.extensions.ExtensionDtype):
+        dtype_str = ser.dtype.name.lower()
+        dtype_str = dtype_str.split('[')[0]  # Remove backend metadata
+    elif isinstance(ser.dtype, pa.DataType):
+        dtype_str = str(ser.dtype).lower()
+    else:
+        dtype_str = 'object'
+
+    if dtype_str not in PANDAS_TO_DTYPE:
+        raise ValueError(f"Unsupported data type '{ser.dtype}'")
+
+    return PANDAS_TO_DTYPE[dtype_str]
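
The new helper can be exercised directly; a minimal sketch, assuming the wheel and its kumoapi dependency are installed (expected results follow the mapping table above):

import pandas as pd
from kumoapi.typing import Dtype

from kumoai.experimental.rfm.infer import infer_dtype

assert infer_dtype(pd.Series([1, 2, 3])) == Dtype.int      # np.dtype 'int64'
assert infer_dtype(pd.Series(['a', 'b'])) == Dtype.string  # 'object' fallback
# Object columns holding lists are re-inferred through pyarrow:
assert infer_dtype(pd.Series([[1.0], [2.0, 3.0]])) == Dtype.floatlist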
kumoai/experimental/rfm/infer/pkey.py
@@ -0,0 +1,126 @@
+import re
+import warnings
+from typing import Optional
+
+import pandas as pd
+
+
+def infer_primary_key(
+    table_name: str,
+    df: pd.DataFrame,
+    candidates: list[str],
+) -> Optional[str]:
+    r"""Auto-detect potential primary key column.
+
+    Args:
+        table_name: The table name.
+        df: The pandas DataFrame to analyze.
+        candidates: A list of potential candidates.
+
+    Returns:
+        The name of the detected primary key, or ``None`` if not found.
+    """
+    # A list of (potentially modified) table names that are eligible to match
+    # with a primary key, i.e.:
+    # - UserInfo -> User
+    # - snakecase <-> camelcase
+    # - camelcase <-> snakecase
+    # - plural <-> singular (users -> user, eligibilities -> eligibility)
+    # - verb -> noun (qualifying -> qualify)
+    _table_names = {table_name}
+    if table_name.lower().endswith('_info'):
+        _table_names.add(table_name[:-5])
+    elif table_name.lower().endswith('info'):
+        _table_names.add(table_name[:-4])
+
+    table_names = set()
+    for _table_name in _table_names:
+        table_names.add(_table_name.lower())
+        snakecase = re.sub(r'(.)([A-Z][a-z]+)', r'\1_\2', _table_name)
+        snakecase = re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', snakecase)
+        table_names.add(snakecase.lower())
+        camelcase = _table_name.replace('_', '')
+        table_names.add(camelcase.lower())
+        if _table_name.lower().endswith('s'):
+            table_names.add(_table_name.lower()[:-1])
+            table_names.add(snakecase.lower()[:-1])
+            table_names.add(camelcase.lower()[:-1])
+        else:
+            table_names.add(_table_name.lower() + 's')
+            table_names.add(snakecase.lower() + 's')
+            table_names.add(camelcase.lower() + 's')
+        if _table_name.lower().endswith('ies'):
+            table_names.add(_table_name.lower()[:-3] + 'y')
+            table_names.add(snakecase.lower()[:-3] + 'y')
+            table_names.add(camelcase.lower()[:-3] + 'y')
+        elif _table_name.lower().endswith('y'):
+            table_names.add(_table_name.lower()[:-1] + 'ies')
+            table_names.add(snakecase.lower()[:-1] + 'ies')
+            table_names.add(camelcase.lower()[:-1] + 'ies')
+        if _table_name.lower().endswith('ing'):
+            table_names.add(_table_name.lower()[:-3])
+            table_names.add(snakecase.lower()[:-3])
+            table_names.add(camelcase.lower()[:-3])
+
+    scores: list[tuple[str, int]] = []
+    for col_name in candidates:
+        col_name_lower = col_name.lower()
+
+        score = 0
+
+        if col_name_lower == 'id':
+            score += 4
+
+        for table_name_lower in table_names:
+
+            if col_name_lower == table_name_lower:
+                score += 4  # USER -> USER
+                break
+
+            for suffix in ['id', 'hash', 'key', 'code', 'uuid']:
+                if not col_name_lower.endswith(suffix):
+                    continue
+
+                if col_name_lower == f'{table_name_lower}_{suffix}':
+                    score += 5  # USER -> USER_ID
+                    break
+
+                if col_name_lower == f'{table_name_lower}{suffix}':
+                    score += 5  # User -> UserId
+                    break
+
+                if col_name_lower.endswith(f'{table_name_lower}_{suffix}'):
+                    score += 2
+
+                if col_name_lower.endswith(f'{table_name_lower}{suffix}'):
+                    score += 2
+
+        # `rel-bench` hard-coding :(
+        if table_name == 'studies' and col_name == 'nct_id':
+            score += 1
+
+        ser = df[col_name].iloc[:1_000_000]
+        score += 3 * (ser.nunique() / len(ser))
+
+        scores.append((col_name, score))
+
+    scores = [x for x in scores if x[-1] >= 4]
+    scores.sort(key=lambda x: x[-1], reverse=True)
+
+    if len(scores) == 0:
+        return None
+
+    if len(scores) == 1:
+        return scores[0][0]
+
+    # In case of multiple candidates, only return one if its score is unique:
+    if scores[0][1] != scores[1][1]:
+        return scores[0][0]
+
+    max_score = scores[0][1]
+    candidates = [col_name for col_name, score in scores if score == max_score]
+    warnings.warn(f"Found multiple potential primary keys in table "
+                  f"'{table_name}': {candidates}. Please specify the primary "
+                  f"key for this table manually.")
+
+    return None
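
To illustrate the scoring heuristic, a hedged sketch with a hypothetical 'users' table: 'user_id' matches the '<table>_<suffix>' pattern (+5) and is fully unique (+3), clearing the cutoff of 4, while 'age' only earns a fractional uniqueness score (3 * 2/3 = 2) and is filtered out:

import pandas as pd

from kumoai.experimental.rfm.infer import infer_primary_key

df = pd.DataFrame({
    'user_id': [1, 2, 3],
    'age': [25, 31, 25],
})
assert infer_primary_key('users', df, ['user_id', 'age']) == 'user_id'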
kumoai/experimental/rfm/infer/time_col.py
@@ -0,0 +1,62 @@
+import re
+import warnings
+from typing import Optional
+
+import pandas as pd
+
+
+def infer_time_column(
+    df: pd.DataFrame,
+    candidates: list[str],
+) -> Optional[str]:
+    r"""Auto-detect potential time column.
+
+    Args:
+        df: The pandas DataFrame to analyze.
+        candidates: A list of potential candidates.
+
+    Returns:
+        The name of the detected time column, or ``None`` if not found.
+    """
+    candidates = [  # Exclude all candidates with `*last*` in column names:
+        col_name for col_name in candidates
+        if not re.search(r'(^|_)last(_|$)', col_name, re.IGNORECASE)
+    ]
+
+    if len(candidates) == 0:
+        return None
+
+    if len(candidates) == 1:
+        return candidates[0]
+
+    # If there exists a dedicated `create*` column, use it as time column:
+    create_candidates = [
+        candidate for candidate in candidates
+        if candidate.lower().startswith('create')
+    ]
+    if len(create_candidates) == 1:
+        return create_candidates[0]
+    if len(create_candidates) > 1:
+        candidates = create_candidates
+
+    # Find the most optimal time column. Usually, it is the one pointing to
+    # the oldest timestamps:
+    with warnings.catch_warnings():
+        warnings.filterwarnings('ignore', message='Could not infer format')
+        min_timestamp_dict = {
+            key: pd.to_datetime(df[key].iloc[:10_000], errors='coerce')
+            for key in candidates
+        }
+    min_timestamp_dict = {
+        key: value.min().tz_localize(None)
+        for key, value in min_timestamp_dict.items()
+    }
+    min_timestamp_dict = {
+        key: value
+        for key, value in min_timestamp_dict.items() if not pd.isna(value)
+    }
+
+    if len(min_timestamp_dict) == 0:
+        return None
+
+    return min(min_timestamp_dict, key=min_timestamp_dict.get)  # type: ignore
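
A usage sketch with hypothetical column names: 'last_login' is excluded up front by the '*last*' filter, and 'created_at' then wins as the single dedicated 'create*' column, so the oldest-timestamp comparison never runs:

import pandas as pd

from kumoai.experimental.rfm.infer import infer_time_column

df = pd.DataFrame({
    'created_at': ['2019-01-01', '2020-05-02'],
    'updated_at': ['2021-01-01', '2021-06-01'],
    'last_login': ['2022-01-01', '2022-02-01'],
})
assert infer_time_column(df, list(df.columns)) == 'created_at'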
kumoai/experimental/rfm/infer/timestamp.py
@@ -2,6 +2,7 @@ import re
 import warnings
 
 import pandas as pd
+from dateutil.parser import UnknownTimezoneWarning
 from kumoapi.typing import Dtype, Stype
 
 
@@ -20,9 +21,7 @@ def contains_timestamp(ser: pd.Series, column_name: str, dtype: Dtype) -> bool:
         column_name,
         re.IGNORECASE,
     )
-
-    if match is not None:
-        return True
+    score = 0.3 if match is not None else 0.0
 
     ser = ser.iloc[:100]
     ser = ser.dropna()
@@ -34,5 +33,9 @@
     ser = ser.astype(str)  # Avoid parsing numbers as unix timestamps.
 
     with warnings.catch_warnings():
+        warnings.simplefilter('ignore', UnknownTimezoneWarning)
         warnings.filterwarnings('ignore', message='Could not infer format')
-        return pd.to_datetime(ser, errors='coerce').notna().all()
+        mask = pd.to_datetime(ser, errors='coerce').notna()
+    score += int(mask.sum()) / len(mask)
+
+    return score >= 1.0
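
The detection rule thus changes from "name match implies timestamp" to a combined score: a timestamp-like column name contributes 0.3 and the fraction of parseable sample values contributes up to 1.0, with detection requiring a total of at least 1.0. In other words, a name-matching column now also needs at least 70% of its sampled values to parse. A standalone sketch of the arithmetic (not a call into the library):

import pandas as pd

ser = pd.Series(['2021-01-01', '2021-02-03', 'not a date'])
score = 0.3  # the column name matched a timestamp-like pattern
mask = pd.to_datetime(ser, errors='coerce').notna()
score += int(mask.sum()) / len(mask)  # 2 of 3 values parse
assert score < 1.0  # ~0.97 -> the column is not flagged as a timestamp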
kumoai/experimental/rfm/local_graph_sampler.py
@@ -1,14 +1,54 @@
+import re
 from typing import Dict, List, Optional, Tuple
 
 import numpy as np
 import pandas as pd
-from kumoapi.model_plan import RunMode
 from kumoapi.rfm.context import EdgeLayout, Link, Subgraph, Table
 from kumoapi.typing import Stype
 
 import kumoai.kumolib as kumolib
 from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
-from kumoai.experimental.rfm.utils import normalize_text
+
+PUNCTUATION = re.compile(r"[\'\"\.,\(\)\!\?\;\:]")
+MULTISPACE = re.compile(r"\s+")
+
+
+def normalize_text(
+    ser: pd.Series,
+    max_words: Optional[int] = 50,
+) -> pd.Series:
+    r"""Normalizes text into a list of lower-case words.
+
+    Args:
+        ser: The :class:`pandas.Series` to normalize.
+        max_words: The maximum number of words to return.
+            This will auto-shrink any large text column to avoid blowing up
+            context size.
+    """
+    if len(ser) == 0 or pd.api.types.is_list_like(ser.iloc[0]):
+        return ser
+
+    def normalize_fn(line: str) -> list[str]:
+        line = PUNCTUATION.sub(" ", line)
+        line = re.sub(r"<br\s*/?>", " ", line)  # Handle <br /> or <br>
+        line = MULTISPACE.sub(" ", line)
+        words = line.split()
+        if max_words is not None:
+            words = words[:max_words]
+        return words
+
+    ser = ser.fillna('').astype(str)
+
+    if max_words is not None:
+        # We estimate the number of words as 5 characters + 1 space in an
+        # English text on average. We need this pre-filter here, as word
+        # splitting on a giant text can be very expensive:
+        ser = ser.str[:6 * max_words]
+
+    ser = ser.str.lower()
+    ser = ser.map(normalize_fn)
+
+    return ser
 
 
 class LocalGraphSampler:
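
normalize_text now lives next to the sampler instead of the removed utils.py. A usage sketch, assuming the module path from the file list above:

import pandas as pd

from kumoai.experimental.rfm.local_graph_sampler import normalize_text

ser = pd.Series(["Hello, world! <br> How's it going?", None])
# Lower-cases, strips punctuation and <br> tags, and splits into words:
assert normalize_text(ser).tolist() == [
    ['hello', 'world', 'how', 's', 'it', 'going'], [],
]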
@@ -33,7 +73,6 @@ class LocalGraphSampler:
         entity_table_names: Tuple[str, ...],
         node: np.ndarray,
         time: np.ndarray,
-        run_mode: RunMode,
         num_neighbors: List[int],
         exclude_cols_dict: Dict[str, List[str]],
     ) -> Subgraph:
@@ -92,15 +131,23 @@
                )
                continue
 
-            # Only store unique rows in `df` above a certain threshold:
-            unique_node, inverse_node = np.unique(node, return_inverse=True)
-            if len(node) > 1.05 * len(unique_node):
-                df = df.iloc[unique_node]
-                row = inverse_node
+            row: Optional[np.ndarray] = None
+            if table_name in self._graph_store.end_time_column_dict:
+                # Set end time to NaT for all values greater than anchor time:
+                df = df.iloc[node].reset_index(drop=True)
+                col_name = self._graph_store.end_time_column_dict[table_name]
+                ser = df[col_name]
+                value = ser.astype('datetime64[ns]').astype(int).to_numpy()
+                mask = value > time[batch]
+                df.loc[mask, col_name] = pd.NaT
            else:
-                df = df.iloc[node]
-                row = None
-            df = df.reset_index(drop=True)
+                # Only store unique rows in `df` above a certain threshold:
+                unique_node, inverse = np.unique(node, return_inverse=True)
+                if len(node) > 1.05 * len(unique_node):
+                    df = df.iloc[unique_node].reset_index(drop=True)
+                    row = inverse
+                else:
+                    df = df.iloc[node].reset_index(drop=True)
 
            # Filter data frame to minimal set of columns:
            df = df[columns]
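
The new branch hides end times that lie beyond a row's anchor time, so the model cannot peek past the prediction point. A self-contained sketch of the masking step with hypothetical data (in the diff, time[batch] holds the per-row anchor times as integer nanoseconds):

import pandas as pd

df = pd.DataFrame({
    'end_time': pd.to_datetime(['2021-01-01', '2023-06-01']),
})
anchor = pd.Timestamp('2022-01-01').value  # nanoseconds since epoch
value = df['end_time'].astype('int64').to_numpy()
df.loc[value > anchor, 'end_time'] = pd.NaT  # the 2023 end time becomes NaT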
kumoai/experimental/rfm/local_graph_store.py
@@ -1,14 +1,13 @@
 import warnings
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import pandas as pd
 from kumoapi.rfm.context import Subgraph
 from kumoapi.typing import Stype
 
-from kumoai.experimental.rfm import LocalGraph
-from kumoai.experimental.rfm.utils import normalize_text
-from kumoai.utils import ProgressLogger
+from kumoai.experimental.rfm import Graph, LocalTable
+from kumoai.utils import InteractiveProgressLogger, ProgressLogger
 
 try:
     import torch
@@ -20,13 +19,18 @@ except ImportError:
 class LocalGraphStore:
     def __init__(
         self,
-        graph: LocalGraph,
-        preprocess: bool = False,
-        verbose: bool = True,
+        graph: Graph,
+        verbose: Union[bool, ProgressLogger] = True,
     ) -> None:
 
-        with ProgressLogger("Materializing graph", verbose=verbose) as logger:
-            self.df_dict, self.mask_dict = self.sanitize(graph, preprocess)
+        if not isinstance(verbose, ProgressLogger):
+            verbose = InteractiveProgressLogger(
+                "Materializing graph",
+                verbose=verbose,
+            )
+
+        with verbose as logger:
+            self.df_dict, self.mask_dict = self.sanitize(graph)
             self.stype_dict = self.get_stype_dict(graph)
             logger.log("Sanitized input data")
 
@@ -39,6 +43,7 @@ class LocalGraphStore:
 
         (
             self.time_column_dict,
+            self.end_time_column_dict,
             self.time_dict,
             self.min_time,
             self.max_time,
@@ -98,8 +103,7 @@
 
     def sanitize(
         self,
-        graph: LocalGraph,
-        preprocess: bool = False,
+        graph: Graph,
    ) -> Tuple[Dict[str, pd.DataFrame], Dict[str, np.ndarray]]:
        r"""Sanitizes raw data according to table schema definition:
 
@@ -108,17 +112,12 @@
        * drops timezone information from timestamps
        * drops duplicate primary keys
        * removes rows with missing primary keys or time values
-
-        If ``preprocess`` is set to ``True``, it will additionally pre-process
-        data for faster model processing. In particular, it:
-        * tokenizes any text column that is not a foreign key
        """
-        df_dict: Dict[str, pd.DataFrame] = {
-            table_name: table._data.copy(deep=False).reset_index(drop=True)
-            for table_name, table in graph.tables.items()
-        }
-
-        foreign_keys = {(edge.src_table, edge.fkey) for edge in graph.edges}
+        df_dict: Dict[str, pd.DataFrame] = {}
+        for table_name, table in graph.tables.items():
+            assert isinstance(table, LocalTable)
+            df = table._data
+            df_dict[table_name] = df.copy(deep=False).reset_index(drop=True)
 
        mask_dict: Dict[str, np.ndarray] = {}
        for table in graph.tables.values():
@@ -137,12 +136,6 @@
                    ser = ser.dt.tz_localize(None)
                    df_dict[table.name][col.name] = ser
 
-                # Normalize text in advance (but exclude foreign keys):
-                if (preprocess and col.stype == Stype.text
-                        and (table.name, col.name) not in foreign_keys):
-                    ser = df_dict[table.name][col.name]
-                    df_dict[table.name][col.name] = normalize_text(ser)
-
            mask: Optional[np.ndarray] = None
            if table._time_column is not None:
                ser = df_dict[table.name][table._time_column]
@@ -158,7 +151,7 @@
 
        return df_dict, mask_dict
 
-    def get_stype_dict(self, graph: LocalGraph) -> Dict[str, Dict[str, Stype]]:
+    def get_stype_dict(self, graph: Graph) -> Dict[str, Dict[str, Stype]]:
        stype_dict: Dict[str, Dict[str, Stype]] = {}
        foreign_keys = {(edge.src_table, edge.fkey) for edge in graph.edges}
        for table in graph.tables.values():
@@ -173,7 +166,7 @@
 
    def get_pkey_data(
        self,
-        graph: LocalGraph,
+        graph: Graph,
    ) -> Tuple[
        Dict[str, str],
        Dict[str, pd.DataFrame],
@@ -195,11 +188,15 @@
            pkey_map = pkey_map[self.mask_dict[table.name]]
 
            if len(pkey_map) == 0:
-                raise ValueError(
-                    f"Found no valid rows in table '{table.name}' since there "
-                    f"exists not a single row with a non-N/A primary key."
-                    f"Consider fixing your underlying data or removing this "
-                    f"table from the graph.")
+                error_msg = f"Found no valid rows in table '{table.name}'. "
+                if table.has_time_column():
+                    error_msg += ("Please make sure that there exists valid "
+                                  "non-N/A primary key and time column pairs "
+                                  "in this table.")
+                else:
+                    error_msg += ("Please make sure that there exists valid "
+                                  "non-N/A primary keys in this table.")
+                raise ValueError(error_msg)
 
            pkey_map_dict[table.name] = pkey_map
 
@@ -207,18 +204,23 @@
 
    def get_time_data(
        self,
-        graph: LocalGraph,
+        graph: Graph,
    ) -> Tuple[
+        Dict[str, str],
        Dict[str, str],
        Dict[str, np.ndarray],
        pd.Timestamp,
        pd.Timestamp,
    ]:
        time_column_dict: Dict[str, str] = {}
+        end_time_column_dict: Dict[str, str] = {}
        time_dict: Dict[str, np.ndarray] = {}
        min_time = pd.Timestamp.max
        max_time = pd.Timestamp.min
        for table in graph.tables.values():
+            if table._end_time_column is not None:
+                end_time_column_dict[table.name] = table._end_time_column
+
            if table._time_column is None:
                continue
 
@@ -233,11 +235,17 @@
            min_time = min(min_time, time.min())
            max_time = max(max_time, time.max())
 
-        return time_column_dict, time_dict, min_time, max_time
+        return (
+            time_column_dict,
+            end_time_column_dict,
+            time_dict,
+            min_time,
+            max_time,
+        )
 
    def get_csc(
        self,
-        graph: LocalGraph,
+        graph: Graph,
    ) -> Tuple[
        Dict[Tuple[str, str, str], np.ndarray],
        Dict[Tuple[str, str, str], np.ndarray],