kumoai-2.13.0.dev202512031731-cp312-cp312-macosx_11_0_arm64.whl → kumoai-2.14.0.dev202512301731-cp312-cp312-macosx_11_0_arm64.whl

This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (50)
  1. kumoai/__init__.py +35 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +24 -0
  5. kumoai/client/pquery.py +6 -2
  6. kumoai/experimental/rfm/__init__.py +49 -24
  7. kumoai/experimental/rfm/authenticate.py +3 -4
  8. kumoai/experimental/rfm/backend/local/__init__.py +4 -0
  9. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +62 -110
  10. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  11. kumoai/experimental/rfm/backend/local/table.py +32 -14
  12. kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
  13. kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
  14. kumoai/experimental/rfm/backend/snow/table.py +186 -39
  15. kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
  16. kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
  17. kumoai/experimental/rfm/backend/sqlite/table.py +131 -41
  18. kumoai/experimental/rfm/base/__init__.py +23 -3
  19. kumoai/experimental/rfm/base/column.py +96 -10
  20. kumoai/experimental/rfm/base/expression.py +44 -0
  21. kumoai/experimental/rfm/base/sampler.py +761 -0
  22. kumoai/experimental/rfm/base/source.py +2 -1
  23. kumoai/experimental/rfm/base/sql_sampler.py +143 -0
  24. kumoai/experimental/rfm/base/table.py +380 -185
  25. kumoai/experimental/rfm/graph.py +404 -144
  26. kumoai/experimental/rfm/infer/__init__.py +6 -4
  27. kumoai/experimental/rfm/infer/dtype.py +52 -60
  28. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  29. kumoai/experimental/rfm/infer/pkey.py +4 -2
  30. kumoai/experimental/rfm/infer/stype.py +35 -0
  31. kumoai/experimental/rfm/infer/time_col.py +1 -2
  32. kumoai/experimental/rfm/pquery/executor.py +27 -27
  33. kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
  34. kumoai/experimental/rfm/relbench.py +76 -0
  35. kumoai/experimental/rfm/rfm.py +283 -230
  36. kumoai/experimental/rfm/sagemaker.py +4 -4
  37. kumoai/pquery/predictive_query.py +10 -6
  38. kumoai/testing/snow.py +50 -0
  39. kumoai/trainer/distilled_trainer.py +175 -0
  40. kumoai/utils/__init__.py +3 -2
  41. kumoai/utils/display.py +51 -0
  42. kumoai/utils/progress_logger.py +178 -12
  43. kumoai/utils/sql.py +3 -0
  44. {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/METADATA +4 -2
  45. {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/RECORD +48 -38
  46. kumoai/experimental/rfm/local_graph_sampler.py +0 -223
  47. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  48. {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/WHEEL +0 -0
  49. {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/licenses/LICENSE +0 -0
  50. {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/top_level.txt +0 -0
kumoai/__init__.py CHANGED
@@ -1,3 +1,4 @@
+ import warnings
  import os
  import sys
  import threading
@@ -68,9 +69,8 @@ class GlobalState(metaclass=Singleton):
  if self._url is None or (self._api_key is None
  and self._spcs_token is None
  and self._snowpark_session is None):
- raise ValueError(
- "Client creation or authentication failed; please re-create "
- "your client before proceeding.")
+ raise ValueError("Client creation or authentication failed. "
+ "Please re-create your client before proceeding.")
 
  if hasattr(self.thread_local, '_client'):
  # Set the spcs token in the client to ensure it has the latest.
@@ -123,10 +123,9 @@ def init(
  """ # noqa
  # Avoid mutations to the global state after it is set:
  if global_state.initialized:
- print(
- "Client has already been created. To re-initialize Kumo, please "
- "start a new interpreter. No changes will be made to the current "
- "session.")
+ warnings.warn("Kumo SDK already initialized. To re-initialize the "
+ "SDK, please start a new interpreter. No changes will "
+ "be made to the current session.")
  return
 
  set_log_level(os.getenv(_ENV_KUMO_LOG, log_level))
@@ -138,15 +137,15 @@ def init(
  if snowflake_application:
  if url is not None:
  raise ValueError(
- "Client creation failed: both snowflake_application and url "
- "are specified. If running from a snowflake notebook, specify"
- "only snowflake_application.")
+ "Kumo SDK initialization failed. Both 'snowflake_application' "
+ "and 'url' are specified. If running from a Snowflake "
+ "notebook, specify only 'snowflake_application'.")
  snowpark_session = _get_active_session()
  if not snowpark_session:
  raise ValueError(
- "Client creation failed: snowflake_application is specified "
- "without an active snowpark session. If running outside "
- "a snowflake notebook, specify a URL and credentials.")
+ "Kumo SDK initialization failed. 'snowflake_application' is "
+ "specified without an active Snowpark session. If running "
+ "outside a Snowflake notebook, specify a URL and credentials.")
  description = snowpark_session.sql(
  f"DESCRIBE SERVICE {snowflake_application}."
  "USER_SCHEMA.KUMO_SERVICE").collect()[0]
@@ -155,14 +154,14 @@ def init(
  if api_key is None and not snowflake_application:
  if snowflake_credentials is None:
  raise ValueError(
- "Client creation failed: Neither API key nor snowflake "
- "credentials provided. Please either set the 'KUMO_API_KEY' "
- "or explicitly call `kumoai.init(...)`.")
+ "Kumo SDK initialization failed. Neither an API key nor "
+ "Snowflake credentials provided. Please either set the "
+ "'KUMO_API_KEY' or explicitly call `kumoai.init(...)`.")
  if (set(snowflake_credentials.keys())
  != {'user', 'password', 'account'}):
  raise ValueError(
- f"Provided credentials should be a dictionary with keys "
- f"'user', 'password', and 'account'. Only "
+ f"Provided Snowflake credentials should be a dictionary with "
+ f"keys 'user', 'password', and 'account'. Only "
  f"{set(snowflake_credentials.keys())} were provided.")
 
  # Get or infer URL:
@@ -173,10 +172,10 @@ def init(
  except KeyError:
  pass
  if url is None:
- raise ValueError(
- "Client creation failed: endpoint URL not provided. Please "
- "either set the 'KUMO_API_ENDPOINT' environment variable or "
- "explicitly call `kumoai.init(...)`.")
+ raise ValueError("Kumo SDK initialization failed since no endpoint "
+ "URL was provided. Please either set the "
+ "'KUMO_API_ENDPOINT' environment variable or "
+ "explicitly call `kumoai.init(...)`.")
 
  # Assign global state after verification that client can be created and
  # authenticated successfully:
@@ -198,10 +197,8 @@ def init(
  logger = logging.getLogger('kumoai')
  log_level = logging.getLevelName(logger.getEffectiveLevel())
 
- logger.info(
- f"Successfully initialized the Kumo SDK (version {__version__}) "
- f"against deployment {url}, with "
- f"log level {log_level}.")
+ logger.info(f"Initialized Kumo SDK v{__version__} against deployment "
+ f"'{url}'")
 
 
  def set_log_level(level: str) -> None:
@@ -280,7 +277,19 @@ __all__ = [
  ]
 
 
+ def in_snowflake_notebook() -> bool:
+ try:
+ from snowflake.snowpark.context import get_active_session
+ import streamlit # noqa: F401
+ get_active_session()
+ return True
+ except Exception:
+ return False
+
+
  def in_notebook() -> bool:
+ if in_snowflake_notebook():
+ return True
  try:
  from IPython import get_ipython
  shell = get_ipython()
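Note on the new helper: `in_snowflake_notebook()` treats an active Snowpark session plus an importable `streamlit` module as the signature of a Snowflake notebook, and `in_notebook()` now short-circuits on it before falling back to the IPython probe. A minimal sketch of how downstream code can rely on the combined check (the call site is illustrative, not from this diff):

    import kumoai

    # True in Jupyter/IPython as before, and now also in Snowflake
    # notebooks, where no IPython shell is available:
    if kumoai.in_notebook():
        ...  # e.g., pick an interactive progress logger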
kumoai/_version.py CHANGED
@@ -1 +1 @@
- __version__ = '2.13.0.dev202512031731'
+ __version__ = '2.14.0.dev202512301731'
kumoai/client/client.py CHANGED
@@ -13,6 +13,7 @@ if TYPE_CHECKING:
  ArtifactExportJobAPI,
  BaselineJobAPI,
  BatchPredictionJobAPI,
+ DistillationJobAPI,
  GeneratePredictionTableJobAPI,
  GenerateTrainTableJobAPI,
  LLMJobAPI,
@@ -132,6 +133,11 @@ class KumoClient:
  from kumoai.client.jobs import TrainingJobAPI
  return TrainingJobAPI(self)
 
+ @property
+ def distillation_job_api(self) -> 'DistillationJobAPI':
+ from kumoai.client.jobs import DistillationJobAPI
+ return DistillationJobAPI(self)
+
  @property
  def batch_prediction_job_api(self) -> 'BatchPredictionJobAPI':
  from kumoai.client.jobs import BatchPredictionJobAPI
kumoai/client/jobs.py CHANGED
@@ -22,6 +22,8 @@ from kumoapi.jobs import (
  BatchPredictionRequest,
  CancelBatchPredictionJobResponse,
  CancelTrainingJobResponse,
+ DistillationJobRequest,
+ DistillationJobResource,
  ErrorDetails,
  GeneratePredictionTableJobResource,
  GeneratePredictionTableRequest,
@@ -171,6 +173,28 @@ class TrainingJobAPI(CommonJobAPI[TrainingJobRequest, TrainingJobResource]):
  return resource.config
 
 
+ class DistillationJobAPI(CommonJobAPI[DistillationJobRequest,
+ DistillationJobResource]):
+ r"""Typed API definition for the distillation job resource."""
+ def __init__(self, client: KumoClient) -> None:
+ super().__init__(client, '/training_jobs/distilled_training_job',
+ DistillationJobResource)
+
+ def get_config(self, job_id: str) -> DistillationJobRequest:
+ raise NotImplementedError(
+ "Getting the configuration for a distillation job is "
+ "not implemented yet.")
+
+ def get_progress(self, id: str) -> AutoTrainerProgress:
+ raise NotImplementedError(
+ "Getting the progress for a distillation job is not "
+ "implemented yet.")
+
+ def cancel(self, id: str) -> CancelTrainingJobResponse:
+ raise NotImplementedError(
+ "Cancelling a distillation job is not implemented yet.")
+
+
  class BatchPredictionJobAPI(CommonJobAPI[BatchPredictionRequest,
  BatchPredictionJobResource]):
  r"""Typed API definition for the prediction job resource."""
kumoai/client/pquery.py CHANGED
@@ -176,8 +176,12 @@ def filter_model_plan(
  # Undefined
  pass
 
- new_opt_fields.append((field.name, _type, default))
- new_opts.append(getattr(section, field.name))
+ # Forward compatibility - Remove any newly introduced arguments not
+ # returned yet by the backend:
+ value = getattr(section, field.name)
+ if value != MissingType.VALUE:
+ new_opt_fields.append((field.name, _type, default))
+ new_opts.append(value)
 
  Section = dataclass(
  config=dict(validate_assignment=True),
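The rewritten block above adds a forward-compatibility guard: options whose value is still the `MissingType.VALUE` sentinel were not returned by an (older) backend and are now dropped instead of being passed through. A standalone sketch of the same pattern, using a stand-in sentinel rather than the kumoapi one:

    from enum import Enum

    class MissingType(Enum):
        VALUE = 'missing'  # stand-in sentinel: field absent from response

    def keep_known_fields(section: object, field_names: list[str]) -> dict:
        kept = {}
        for name in field_names:
            value = getattr(section, name, MissingType.VALUE)
            if value != MissingType.VALUE:  # skip fields the backend omitted
                kept[name] = value
        return kept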
kumoai/experimental/rfm/__init__.py CHANGED
@@ -6,11 +6,11 @@ import socket
  import threading
  from dataclasses import dataclass
  from enum import Enum
- from typing import Dict, Optional, Tuple
  from urllib.parse import urlparse
 
  import kumoai
  from kumoai.client.client import KumoClient
+ from kumoai.spcs import _get_active_session
 
  from .authenticate import authenticate
  from .sagemaker import (
@@ -49,7 +49,8 @@ class InferenceBackend(str, Enum):
 
 
  def _detect_backend(
- url: str) -> Tuple[InferenceBackend, Optional[str], Optional[str]]:
+ url: str, #
+ ) -> tuple[InferenceBackend, str | None, str | None]:
  parsed = urlparse(url)
 
  # Remote SageMaker
@@ -73,12 +74,27 @@ def _detect_backend(
  return InferenceBackend.REST, None, None
 
 
+ def _get_snowflake_url(snowflake_application: str) -> str:
+ snowpark_session = _get_active_session()
+ if not snowpark_session:
+ raise ValueError(
+ "KumoRFM initialization failed. 'snowflake_application' is "
+ "specified without an active Snowpark session. If running outside "
+ "a Snowflake notebook, specify a URL and credentials.")
+ with snowpark_session.connection.cursor() as cur:
+ cur.execute(
+ f"DESCRIBE SERVICE {snowflake_application}.user_schema.rfm_service"
+ f" ->> SELECT \"dns_name\" from $1")
+ dns_name: str = cur.fetchone()[0]
+ return f"http://{dns_name}:8000/api"
+
+
  @dataclass
  class RfmGlobalState:
  _url: str = '__url_not_provided__'
  _backend: InferenceBackend = InferenceBackend.UNKNOWN
- _region: Optional[str] = None
- _endpoint_name: Optional[str] = None
+ _region: str | None = None
+ _endpoint_name: str | None = None
  _thread_local = threading.local()
 
  # Thread-safe init-once.
@@ -87,6 +103,9 @@ class RfmGlobalState:
 
  @property
  def client(self) -> KumoClient:
+ if self._backend == InferenceBackend.UNKNOWN:
+ raise RuntimeError("KumoRFM is not yet initialized")
+
  if self._backend == InferenceBackend.REST:
  return kumoai.global_state.client
 
@@ -121,52 +140,58 @@ global_state = RfmGlobalState()
 
 
  def init(
- url: Optional[str] = None,
- api_key: Optional[str] = None,
- snowflake_credentials: Optional[Dict[str, str]] = None,
- snowflake_application: Optional[str] = None,
+ url: str | None = None,
+ api_key: str | None = None,
+ snowflake_credentials: dict[str, str] | None = None,
+ snowflake_application: str | None = None,
  log_level: str = "INFO",
  ) -> None:
  with global_state._lock:
  if global_state._initialized:
  if url != global_state._url:
- raise ValueError(
- "Kumo RFM has already been initialized with a different "
- "URL. Re-initialization with a different URL is not "
+ raise RuntimeError(
+ "KumoRFM has already been initialized with a different "
+ "API URL. Re-initialization with a different URL is not "
  "supported.")
  return
 
+ if snowflake_application:
+ if url is not None:
+ raise ValueError(
+ "KumoRFM initialization failed. Both "
+ "'snowflake_application' and 'url' are specified. If "
+ "running from a Snowflake notebook, specify only "
+ "'snowflake_application'.")
+ url = _get_snowflake_url(snowflake_application)
+ api_key = "test:DISABLED"
+
  if url is None:
  url = os.getenv("RFM_API_URL", "https://kumorfm.ai/api")
 
  backend, region, endpoint_name = _detect_backend(url)
  if backend == InferenceBackend.REST:
- # Initialize kumoai.global_state
- if (kumoai.global_state.initialized
- and kumoai.global_state._url != url):
- raise ValueError(
- "Kumo AI SDK has already been initialized with different "
- "API URL. Please restart Python interpreter and "
- "initialize via kumoai.rfm.init()")
- kumoai.init(url=url, api_key=api_key,
- snowflake_credentials=snowflake_credentials,
- snowflake_application=snowflake_application,
- log_level=log_level)
+ kumoai.init(
+ url=url,
+ api_key=api_key,
+ snowflake_credentials=snowflake_credentials,
+ snowflake_application=snowflake_application,
+ log_level=log_level,
+ )
  elif backend == InferenceBackend.AWS_SAGEMAKER:
  assert region
  assert endpoint_name
  KumoClient_SageMakerAdapter(region, endpoint_name).authenticate()
+ logger.info("KumoRFM initialized in AWS SageMaker")
  else:
  assert backend == InferenceBackend.LOCAL_SAGEMAKER
  KumoClient_SageMakerProxy_Local(url).authenticate()
+ logger.info(f"KumoRFM initialized in local SageMaker at '{url}'")
 
  global_state._url = url
  global_state._backend = backend
  global_state._region = region
  global_state._endpoint_name = endpoint_name
  global_state._initialized = True
- logger.info("Kumo RFM initialized with backend: %s, url: %s", backend,
- url)
 
 
  LocalGraph = Graph # NOTE Backward compatibility - do not use anymore.
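Taken together, `init()` now has two entry paths: inside a Snowflake notebook it resolves the service URL itself (DESCRIBE SERVICE yields the service's `dns_name`, which is wrapped into `http://<dns_name>:8000/api` with the API key disabled), while everywhere else the URL/API-key path is unchanged. A hedged usage sketch (application name and key are placeholders):

    from kumoai.experimental import rfm

    # Either, inside a Snowflake notebook with an active Snowpark session:
    rfm.init(snowflake_application='KUMO_RFM_APP')  # placeholder name

    # Or, anywhere else, against an explicit REST endpoint:
    # rfm.init(url='https://kumorfm.ai/api', api_key='...')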
kumoai/experimental/rfm/authenticate.py CHANGED
@@ -2,12 +2,11 @@ import logging
  import os
  import platform
  from datetime import datetime
- from typing import Optional
 
  from kumoai import in_notebook
 
 
- def authenticate(api_url: Optional[str] = None) -> None:
+ def authenticate(api_url: str | None = None) -> None:
  """Authenticates the user and sets the Kumo API key for the SDK.
 
  This function detects the current environment and launches the appropriate
@@ -65,11 +64,11 @@ def _authenticate_local(api_url: str, redirect_port: int = 8765) -> None:
  import webbrowser
  from getpass import getpass
  from socketserver import TCPServer
- from typing import Any, Dict
+ from typing import Any
 
  logger = logging.getLogger('kumoai')
 
- token_status: Dict[str, Any] = {
+ token_status: dict[str, Any] = {
  'token': None,
  'token_name': None,
  'failed': False
kumoai/experimental/rfm/backend/local/__init__.py CHANGED
@@ -32,7 +32,11 @@ Please create a feature request at 'https://github.com/kumo-ai/kumo-rfm'."""
  raise RuntimeError(_msg) from e
 
  from .table import LocalTable
+ from .graph_store import LocalGraphStore
+ from .sampler import LocalSampler
 
  __all__ = [
  'LocalTable',
+ 'LocalGraphStore',
+ 'LocalSampler',
  ]
kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} RENAMED
@@ -1,13 +1,12 @@
- import warnings
- from typing import Dict, List, Optional, Tuple, Union
+ from typing import TYPE_CHECKING
 
  import numpy as np
  import pandas as pd
  from kumoapi.rfm.context import Subgraph
- from kumoapi.typing import Stype
 
- from kumoai.experimental.rfm import Graph, LocalTable
- from kumoai.utils import InteractiveProgressLogger, ProgressLogger
+ from kumoai.experimental.rfm.backend.local import LocalTable
+ from kumoai.experimental.rfm.base import Table
+ from kumoai.utils import ProgressLogger
 
  try:
  import torch
@@ -15,42 +14,40 @@ try:
  except ImportError:
  WITH_TORCH = False
 
+ if TYPE_CHECKING:
+ from kumoai.experimental.rfm import Graph
+
 
  class LocalGraphStore:
  def __init__(
  self,
- graph: Graph,
- verbose: Union[bool, ProgressLogger] = True,
+ graph: 'Graph',
+ verbose: bool | ProgressLogger = True,
  ) -> None:
 
  if not isinstance(verbose, ProgressLogger):
- verbose = InteractiveProgressLogger(
- "Materializing graph",
+ verbose = ProgressLogger.default(
+ msg="Materializing graph",
  verbose=verbose,
  )
 
  with verbose as logger:
  self.df_dict, self.mask_dict = self.sanitize(graph)
- self.stype_dict = self.get_stype_dict(graph)
  logger.log("Sanitized input data")
 
- self.pkey_name_dict, self.pkey_map_dict = self.get_pkey_data(graph)
+ self.pkey_map_dict = self.get_pkey_map_dict(graph)
  num_pkeys = sum(t.has_primary_key() for t in graph.tables.values())
  if num_pkeys > 1:
  logger.log(f"Collected primary keys from {num_pkeys} tables")
  else:
  logger.log(f"Collected primary key from {num_pkeys} table")
 
- (
- self.time_column_dict,
- self.end_time_column_dict,
- self.time_dict,
- self.min_time,
- self.max_time,
- ) = self.get_time_data(graph)
- if self.max_time != pd.Timestamp.min:
+ self.time_dict, self.min_max_time_dict = self.get_time_data(graph)
+ if len(self.min_max_time_dict) > 0:
+ min_time = min(t for t, _ in self.min_max_time_dict.values())
+ max_time = max(t for _, t in self.min_max_time_dict.values())
  logger.log(f"Identified temporal graph from "
- f"{self.min_time.date()} to {self.max_time.date()}")
+ f"{min_time.date()} to {max_time.date()}")
  else:
  logger.log("Identified static graph without timestamps")
@@ -60,14 +57,6 @@ class LocalGraphStore:
  logger.log(f"Created graph with {num_nodes:,} nodes and "
  f"{num_edges:,} edges")
 
- @property
- def node_types(self) -> List[str]:
- return list(self.df_dict.keys())
-
- @property
- def edge_types(self) -> List[Tuple[str, str, str]]:
- return list(self.row_dict.keys())
-
  def get_node_id(self, table_name: str, pkey: pd.Series) -> np.ndarray:
  r"""Returns the node ID given primary keys.
 
@@ -103,8 +92,8 @@
 
  def sanitize(
  self,
- graph: Graph,
- ) -> Tuple[Dict[str, pd.DataFrame], Dict[str, np.ndarray]]:
+ graph: 'Graph',
+ ) -> tuple[dict[str, pd.DataFrame], dict[str, np.ndarray]]:
  r"""Sanitizes raw data according to table schema definition:
 
  In particular, it:
@@ -113,30 +102,24 @@
  * drops duplicate primary keys
  * removes rows with missing primary keys or time values
  """
- df_dict: Dict[str, pd.DataFrame] = {}
+ df_dict: dict[str, pd.DataFrame] = {}
  for table_name, table in graph.tables.items():
  assert isinstance(table, LocalTable)
- df = table._data
- df_dict[table_name] = df.copy(deep=False).reset_index(drop=True)
+ df_dict[table_name] = Table._sanitize(
+ df=table._data.copy(deep=False).reset_index(drop=True),
+ dtype_dict={
+ column.name: column.dtype
+ for column in table.columns
+ },
+ stype_dict={
+ column.name: column.stype
+ for column in table.columns
+ },
+ )
 
- mask_dict: Dict[str, np.ndarray] = {}
+ mask_dict: dict[str, np.ndarray] = {}
  for table in graph.tables.values():
- for col in table.columns:
- if col.stype == Stype.timestamp:
- ser = df_dict[table.name][col.name]
- if not pd.api.types.is_datetime64_any_dtype(ser):
- with warnings.catch_warnings():
- warnings.filterwarnings(
- 'ignore',
- message='Could not infer format',
- )
- ser = pd.to_datetime(ser, errors='coerce')
- df_dict[table.name][col.name] = ser
- if isinstance(ser.dtype, pd.DatetimeTZDtype):
- ser = ser.dt.tz_localize(None)
- df_dict[table.name][col.name] = ser
-
- mask: Optional[np.ndarray] = None
+ mask: np.ndarray | None = None
  if table._time_column is not None:
  ser = df_dict[table.name][table._time_column]
  mask = ser.notna().to_numpy()
@@ -151,34 +134,16 @@
 
  return df_dict, mask_dict
 
- def get_stype_dict(self, graph: Graph) -> Dict[str, Dict[str, Stype]]:
- stype_dict: Dict[str, Dict[str, Stype]] = {}
- foreign_keys = {(edge.src_table, edge.fkey) for edge in graph.edges}
- for table in graph.tables.values():
- stype_dict[table.name] = {}
- for column in table.columns:
- if column == table.primary_key:
- continue
- if (table.name, column.name) in foreign_keys:
- continue
- stype_dict[table.name][column.name] = column.stype
- return stype_dict
-
- def get_pkey_data(
+ def get_pkey_map_dict(
  self,
- graph: Graph,
- ) -> Tuple[
- Dict[str, str],
- Dict[str, pd.DataFrame],
- ]:
- pkey_name_dict: Dict[str, str] = {}
- pkey_map_dict: Dict[str, pd.DataFrame] = {}
+ graph: 'Graph',
+ ) -> dict[str, pd.DataFrame]:
+ pkey_map_dict: dict[str, pd.DataFrame] = {}
 
  for table in graph.tables.values():
  if table._primary_key is None:
  continue
 
- pkey_name_dict[table.name] = table._primary_key
  pkey = self.df_dict[table.name][table._primary_key]
  pkey_map = pd.DataFrame(
  dict(arange=range(len(pkey))),
@@ -200,61 +165,48 @@
 
  pkey_map_dict[table.name] = pkey_map
 
- return pkey_name_dict, pkey_map_dict
+ return pkey_map_dict
 
  def get_time_data(
  self,
- graph: Graph,
- ) -> Tuple[
- Dict[str, str],
- Dict[str, str],
- Dict[str, np.ndarray],
- pd.Timestamp,
- pd.Timestamp,
+ graph: 'Graph',
+ ) -> tuple[
+ dict[str, np.ndarray],
+ dict[str, tuple[pd.Timestamp, pd.Timestamp]],
  ]:
- time_column_dict: Dict[str, str] = {}
- end_time_column_dict: Dict[str, str] = {}
- time_dict: Dict[str, np.ndarray] = {}
- min_time = pd.Timestamp.max
- max_time = pd.Timestamp.min
+ time_dict: dict[str, np.ndarray] = {}
+ min_max_time_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
  for table in graph.tables.values():
- if table._end_time_column is not None:
- end_time_column_dict[table.name] = table._end_time_column
-
  if table._time_column is None:
  continue
 
  time = self.df_dict[table.name][table._time_column]
- time_dict[table.name] = time.astype('datetime64[ns]').astype(
- int).to_numpy() // 1000**3
- time_column_dict[table.name] = table._time_column
+ time_dict[table.name] = time.astype(int).to_numpy() // 1000**3
 
  if table.name in self.mask_dict.keys():
  time = time[self.mask_dict[table.name]]
  if len(time) > 0:
- min_time = min(min_time, time.min())
- max_time = max(max_time, time.max())
+ min_max_time_dict[table.name] = (time.min(), time.max())
+ else:
+ min_max_time_dict[table.name] = (
+ pd.Timestamp.max,
+ pd.Timestamp.min,
+ )
 
- return (
- time_column_dict,
- end_time_column_dict,
- time_dict,
- min_time,
- max_time,
- )
+ return time_dict, min_max_time_dict
 
  def get_csc(
  self,
- graph: Graph,
- ) -> Tuple[
- Dict[Tuple[str, str, str], np.ndarray],
- Dict[Tuple[str, str, str], np.ndarray],
+ graph: 'Graph',
+ ) -> tuple[
+ dict[tuple[str, str, str], np.ndarray],
+ dict[tuple[str, str, str], np.ndarray],
  ]:
  # A mapping from raw primary keys to node indices (0 to N-1):
- map_dict: Dict[str, pd.CategoricalDtype] = {}
+ map_dict: dict[str, pd.CategoricalDtype] = {}
  # A dictionary to manage offsets of node indices for invalid rows:
- offset_dict: Dict[str, np.ndarray] = {}
- for table_name in set(edge.dst_table for edge in graph.edges):
+ offset_dict: dict[str, np.ndarray] = {}
+ for table_name in {edge.dst_table for edge in graph.edges}:
  ser = self.df_dict[table_name][graph[table_name]._primary_key]
  if table_name in self.mask_dict.keys():
  mask = self.mask_dict[table_name]
@@ -263,8 +215,8 @@
  map_dict[table_name] = pd.CategoricalDtype(ser, ordered=True)
 
  # Build CSC graph representation:
- row_dict: Dict[Tuple[str, str, str], np.ndarray] = {}
- colptr_dict: Dict[Tuple[str, str, str], np.ndarray] = {}
+ row_dict: dict[tuple[str, str, str], np.ndarray] = {}
+ colptr_dict: dict[tuple[str, str, str], np.ndarray] = {}
  for src_table, fkey, dst_table in graph.edges:
  src_df = self.df_dict[src_table]
  dst_df = self.df_dict[dst_table]
@@ -326,7 +278,7 @@ def _argsort(input: np.ndarray) -> np.ndarray:
  return torch.from_numpy(input).argsort().numpy()
 
 
- def _lexsort(inputs: List[np.ndarray]) -> np.ndarray:
+ def _lexsort(inputs: list[np.ndarray]) -> np.ndarray:
  assert len(inputs) >= 1
 
  if not WITH_TORCH:
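For context on the CSC construction that `get_csc()` performs: foreign keys are mapped to dense node indices through an ordered `pd.CategoricalDtype` over the destination table's primary key, and edges are then grouped by destination node into a (row, colptr) pair. A self-contained toy sketch of that idea, ignoring the mask/offset handling the real method performs:

    import numpy as np
    import pandas as pd

    users = pd.DataFrame({'id': ['a', 'b', 'c']})
    orders = pd.DataFrame({'user_id': ['b', 'a', 'b', 'c', 'b']})

    # Map raw primary keys to node indices 0..N-1 (-1 marks unmatched keys):
    dtype = pd.CategoricalDtype(users['id'], ordered=True)
    dst = orders['user_id'].astype(dtype).cat.codes.to_numpy()
    src = np.arange(len(orders))

    # Group edges by destination node and build the column-pointer array:
    perm = np.argsort(dst, kind='stable')
    row = src[perm]  # source rows, grouped by destination user
    counts = np.bincount(dst, minlength=len(users))
    colptr = np.concatenate([[0], np.cumsum(counts)])

    print(row)     # [1 0 2 4 3]
    print(colptr)  # [0 1 4 5] -> user 'a': 1 order, 'b': 3, 'c': 1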