kumoai 2.12.0.dev202511111731__cp311-cp311-macosx_11_0_arm64.whl → 2.13.0.dev202512091732__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. kumoai/__init__.py +18 -9
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +9 -13
  4. kumoai/connector/utils.py +23 -2
  5. kumoai/experimental/rfm/__init__.py +162 -46
  6. kumoai/experimental/rfm/backend/__init__.py +0 -0
  7. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  8. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +20 -30
  9. kumoai/experimental/rfm/backend/local/sampler.py +242 -0
  10. kumoai/experimental/rfm/backend/local/table.py +109 -0
  11. kumoai/experimental/rfm/backend/snow/__init__.py +35 -0
  12. kumoai/experimental/rfm/backend/snow/table.py +117 -0
  13. kumoai/experimental/rfm/backend/sqlite/__init__.py +30 -0
  14. kumoai/experimental/rfm/backend/sqlite/table.py +101 -0
  15. kumoai/experimental/rfm/base/__init__.py +14 -0
  16. kumoai/experimental/rfm/base/column.py +66 -0
  17. kumoai/experimental/rfm/base/sampler.py +374 -0
  18. kumoai/experimental/rfm/base/source.py +18 -0
  19. kumoai/experimental/rfm/{local_table.py → base/table.py} +139 -139
  20. kumoai/experimental/rfm/{local_graph.py → graph.py} +334 -79
  21. kumoai/experimental/rfm/infer/__init__.py +6 -0
  22. kumoai/experimental/rfm/infer/dtype.py +79 -0
  23. kumoai/experimental/rfm/infer/pkey.py +126 -0
  24. kumoai/experimental/rfm/infer/time_col.py +62 -0
  25. kumoai/experimental/rfm/local_graph_sampler.py +43 -4
  26. kumoai/experimental/rfm/local_pquery_driver.py +1 -1
  27. kumoai/experimental/rfm/pquery/pandas_executor.py +1 -1
  28. kumoai/experimental/rfm/rfm.py +28 -27
  29. kumoai/experimental/rfm/sagemaker.py +138 -0
  30. kumoai/spcs.py +1 -3
  31. kumoai/testing/decorators.py +1 -1
  32. {kumoai-2.12.0.dev202511111731.dist-info → kumoai-2.13.0.dev202512091732.dist-info}/METADATA +12 -2
  33. {kumoai-2.12.0.dev202511111731.dist-info → kumoai-2.13.0.dev202512091732.dist-info}/RECORD +36 -21
  34. kumoai/experimental/rfm/utils.py +0 -344
  35. {kumoai-2.12.0.dev202511111731.dist-info → kumoai-2.13.0.dev202512091732.dist-info}/WHEEL +0 -0
  36. {kumoai-2.12.0.dev202511111731.dist-info → kumoai-2.13.0.dev202512091732.dist-info}/licenses/LICENSE +0 -0
  37. {kumoai-2.12.0.dev202511111731.dist-info → kumoai-2.13.0.dev202512091732.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,79 @@
1
+ from typing import Dict
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ import pyarrow as pa
6
+ from kumoapi.typing import Dtype
7
+
8
# Maps a normalized (lower-cased, backend suffix stripped) pandas dtype name
# to the corresponding `Dtype`.
PANDAS_TO_DTYPE: Dict[str, Dtype] = {
    'bool': Dtype.bool,
    'boolean': Dtype.bool,
    'int8': Dtype.int,
    'int16': Dtype.int,
    'int32': Dtype.int,
    'int64': Dtype.int,
    'float16': Dtype.float,
    'float32': Dtype.float,
    'float64': Dtype.float,
    'object': Dtype.string,
    'string': Dtype.string,
    'string[python]': Dtype.string,
    'string[pyarrow]': Dtype.string,
    'binary': Dtype.binary,
}


def infer_dtype(ser: pd.Series) -> Dtype:
    """Extracts the :class:`Dtype` from a :class:`pandas.Series`.

    Datetime, timedelta and categorical columns are resolved directly.
    Object columns whose first non-null value (within the first 1000 rows) is
    list-like are converted to an Arrow-backed series so that the element
    type of the list can be inspected. Everything else is resolved through
    :obj:`PANDAS_TO_DTYPE` by its normalized dtype name.

    Args:
        ser: A :class:`pandas.Series` to analyze.

    Returns:
        The data type.

    Raises:
        ValueError: If the dtype name is not present in
            :obj:`PANDAS_TO_DTYPE`.
    """
    if pd.api.types.is_datetime64_any_dtype(ser.dtype):
        return Dtype.date
    if pd.api.types.is_timedelta64_dtype(ser.dtype):
        return Dtype.timedelta
    if isinstance(ser.dtype, pd.CategoricalDtype):
        return Dtype.string

    if (pd.api.types.is_object_dtype(ser.dtype)
            and not isinstance(ser.dtype, pd.ArrowDtype)):
        # Locate the first non-null value by *position* rather than by label:
        # a label lookup returns a `Series` (and `Index.get_loc` a slice or
        # mask) whenever the index holds duplicates, which previously made
        # plain string columns fail on the `isinstance(pos, int)` assertion.
        valid = ser.iloc[:1000].notna().to_numpy()
        if valid.any():
            pos = int(valid.argmax())  # Position of the first `True`.
            if pd.api.types.is_list_like(ser.iloc[pos]):
                # Only a window of up to 1000 rows is converted to Arrow in
                # order to determine the list element type:
                ser = ser.iloc[pos:pos + 1000].dropna()
                arr = pa.array(ser.tolist())
                ser = pd.Series(arr, dtype=pd.ArrowDtype(arr.type))

    if isinstance(ser.dtype, pd.ArrowDtype):
        if pa.types.is_list(ser.dtype.pyarrow_dtype):
            elem_dtype = ser.dtype.pyarrow_dtype.value_type
            if pa.types.is_integer(elem_dtype):
                return Dtype.intlist
            if pa.types.is_floating(elem_dtype):
                return Dtype.floatlist
            if pa.types.is_decimal(elem_dtype):
                return Dtype.floatlist
            if pa.types.is_string(elem_dtype):
                return Dtype.stringlist
            if pa.types.is_null(elem_dtype):
                # Lists without any typed element are reported as float
                # lists:
                return Dtype.floatlist

    if isinstance(ser.dtype, np.dtype):
        dtype_str = str(ser.dtype).lower()
    elif isinstance(ser.dtype, pd.api.extensions.ExtensionDtype):
        dtype_str = ser.dtype.name.lower()
        dtype_str = dtype_str.split('[')[0]  # Remove backend metadata
    elif isinstance(ser.dtype, pa.DataType):
        dtype_str = str(ser.dtype).lower()
    else:
        dtype_str = 'object'

    if dtype_str not in PANDAS_TO_DTYPE:
        raise ValueError(f"Unsupported data type '{ser.dtype}'")

    return PANDAS_TO_DTYPE[dtype_str]
@@ -0,0 +1,126 @@
1
+ import re
2
+ import warnings
3
+ from typing import Optional
4
+
5
+ import pandas as pd
6
+
7
+
8
def _table_name_variants(table_name: str) -> set[str]:
    r"""Returns the lower-cased table name spellings a key may refer to."""
    # A list of (potentially modified) table names that are eligible to match
    # with a primary key, i.e.:
    # - UserInfo -> User
    # - snakecase <-> camelcase
    # - camelcase <-> snakecase
    # - plural <-> singular (users -> user, eligibilities -> eligibility)
    # - verb -> noun (qualifying -> qualify)
    base_names = {table_name}
    if table_name.lower().endswith('_info'):
        base_names.add(table_name[:-5])
    elif table_name.lower().endswith('info'):
        base_names.add(table_name[:-4])

    variants: set[str] = set()
    for name in base_names:
        name_lower = name.lower()
        snakecase = re.sub(r'(.)([A-Z][a-z]+)', r'\1_\2', name)
        snakecase = re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', snakecase)
        forms = (
            name_lower,
            snakecase.lower(),
            name.replace('_', '').lower(),
        )
        variants.update(forms)

        if name_lower.endswith('s'):
            variants.update(form[:-1] for form in forms)
        else:
            variants.update(form + 's' for form in forms)

        if name_lower.endswith('ies'):
            variants.update(form[:-3] + 'y' for form in forms)
        elif name_lower.endswith('y'):
            variants.update(form[:-1] + 'ies' for form in forms)

        if name_lower.endswith('ing'):
            variants.update(form[:-3] for form in forms)

    return variants


def infer_primary_key(
    table_name: str,
    df: pd.DataFrame,
    candidates: list[str],
) -> Optional[str]:
    r"""Auto-detect potential primary key column.

    Every candidate receives a score based on how well its name matches the
    table name (or a plain ``id`` column) plus up to three points for the
    fraction of unique values within its first 1,000,000 rows. Candidates
    scoring below four are discarded.

    Args:
        table_name: The table name.
        df: The pandas DataFrame to analyze.
        candidates: A list of potential candidates.

    Returns:
        The name of the detected primary key, or ``None`` if not found. If
        several candidates share the highest score, a warning listing them is
        emitted and ``None`` is returned.
    """
    # Sorted so that the early `break` below does not depend on the (hash
    # seed dependent) iteration order of a set:
    table_names = sorted(_table_name_variants(table_name))

    scores: list[tuple[str, float]] = []
    for col_name in candidates:
        col_name_lower = col_name.lower()

        score = 0.0

        if col_name_lower == 'id':
            score += 4

        for table_name_lower in table_names:

            if col_name_lower == table_name_lower:
                score += 4  # USER -> USER
                break

            for suffix in ['id', 'hash', 'key', 'code', 'uuid']:
                if not col_name_lower.endswith(suffix):
                    continue

                if col_name_lower == f'{table_name_lower}_{suffix}':
                    score += 5  # USER -> USER_ID
                    break

                if col_name_lower == f'{table_name_lower}{suffix}':
                    score += 5  # User -> UserId
                    break

                if col_name_lower.endswith(f'{table_name_lower}_{suffix}'):
                    score += 2

                if col_name_lower.endswith(f'{table_name_lower}{suffix}'):
                    score += 2

        # `rel-bench` hard-coding :(
        if table_name == 'studies' and col_name == 'nct_id':
            score += 1

        ser = df[col_name].iloc[:1_000_000]
        if len(ser) > 0:  # An empty table carries no uniqueness signal.
            score += 3 * (ser.nunique() / len(ser))

        scores.append((col_name, score))

    scores = [x for x in scores if x[-1] >= 4]
    scores.sort(key=lambda x: x[-1], reverse=True)

    if len(scores) == 0:
        return None

    if len(scores) == 1:
        return scores[0][0]

    # In case of multiple candidates, only return one if its score is unique:
    max_score = scores[0][1]
    if max_score != scores[1][1]:
        return scores[0][0]

    tied = [col_name for col_name, score in scores if score == max_score]
    warnings.warn(f"Found multiple potential primary keys in table "
                  f"'{table_name}': {tied}. Please specify the primary "
                  f"key for this table manually.")

    return None
@@ -0,0 +1,62 @@
1
+ import re
2
+ import warnings
3
+ from typing import Optional
4
+
5
+ import pandas as pd
6
+
7
+
8
def infer_time_column(
    df: pd.DataFrame,
    candidates: list[str],
) -> Optional[str]:
    r"""Auto-detect potential time column.

    Columns with a standalone ``last`` token in their name are ignored. A
    single remaining column, or a single ``create*`` column, is returned
    directly. Otherwise, the column with the oldest parseable timestamp
    within its first 10,000 rows wins.

    Args:
        df: The pandas DataFrame to analyze.
        candidates: A list of potential candidates.

    Returns:
        The name of the detected time column, or ``None`` if not found.
    """
    # Drop every candidate that carries `last` as its own name token:
    last_token = re.compile(r'(^|_)last(_|$)', re.IGNORECASE)
    eligible = [
        name for name in candidates if last_token.search(name) is None
    ]

    if not eligible:
        return None

    if len(eligible) == 1:
        return eligible[0]

    # A dedicated `create*` column takes precedence over everything else:
    created = [name for name in eligible if name.lower().startswith('create')]
    if len(created) == 1:
        return created[0]
    if created:
        eligible = created

    # Otherwise, prefer the column pointing to the oldest timestamps:
    parsed = {}
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', message='Could not infer format')
        for name in eligible:
            parsed[name] = pd.to_datetime(
                df[name].iloc[:10_000],
                errors='coerce',
            )

    oldest = {}
    for name, timestamps in parsed.items():
        # Timezones are stripped so that the minima are comparable:
        minimum = timestamps.min().tz_localize(None)
        if not pd.isna(minimum):
            oldest[name] = minimum

    if not oldest:
        return None

    return min(oldest, key=oldest.get)  # type: ignore
@@ -1,14 +1,54 @@
1
+ import re
1
2
  from typing import Dict, List, Optional, Tuple
2
3
 
3
4
  import numpy as np
4
5
  import pandas as pd
5
- from kumoapi.model_plan import RunMode
6
6
  from kumoapi.rfm.context import EdgeLayout, Link, Subgraph, Table
7
7
  from kumoapi.typing import Stype
8
8
 
9
9
  import kumoai.kumolib as kumolib
10
- from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
11
- from kumoai.experimental.rfm.utils import normalize_text
10
+ from kumoai.experimental.rfm.backend.local import LocalGraphStore
11
+
12
# Characters that get replaced by a space before splitting into words:
PUNCTUATION = re.compile(r"[\'\"\.,\(\)\!\?\;\:]")
# Any run of whitespace characters:
MULTISPACE = re.compile(r"\s+")


def normalize_text(
    ser: pd.Series,
    max_words: Optional[int] = 50,
) -> pd.Series:
    r"""Normalizes text into a list of lower-case words.

    An empty series, or a series whose first value is already list-like, is
    returned unchanged.

    Args:
        ser: The :class:`pandas.Series` to normalize.
        max_words: The maximum number of words to return.
            This will auto-shrink any large text column to avoid blowing up
            context size.

    Returns:
        A series holding one list of words per input row.
    """
    if len(ser) == 0 or pd.api.types.is_list_like(ser.iloc[0]):
        return ser

    def to_words(text: str) -> list[str]:
        text = PUNCTUATION.sub(" ", text)
        text = re.sub(r"<br\s*/?>", " ", text)  # Handle <br /> or <br>
        tokens = MULTISPACE.sub(" ", text).split()
        return tokens if max_words is None else tokens[:max_words]

    cleaned = ser.fillna('').astype(str)

    if max_words is not None:
        # An English word is estimated as 5 characters + 1 space on average.
        # Cutting the raw text first keeps word splitting cheap on giant
        # texts:
        cleaned = cleaned.str[:6 * max_words]

    return cleaned.str.lower().map(to_words)
12
52
 
13
53
 
14
54
  class LocalGraphSampler:
@@ -33,7 +73,6 @@ class LocalGraphSampler:
33
73
  entity_table_names: Tuple[str, ...],
34
74
  node: np.ndarray,
35
75
  time: np.ndarray,
36
- run_mode: RunMode,
37
76
  num_neighbors: List[int],
38
77
  exclude_cols_dict: Dict[str, List[str]],
39
78
  ) -> Subgraph:
@@ -17,7 +17,7 @@ from kumoapi.task import TaskType
17
17
  from kumoapi.typing import AggregationType, DateOffset, Stype
18
18
 
19
19
  import kumoai.kumolib as kumolib
20
- from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
20
+ from kumoai.experimental.rfm.backend.local import LocalGraphStore
21
21
  from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
22
22
 
23
23
  _coverage_warned = False
@@ -134,7 +134,7 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
134
134
  outs: List[pd.Series] = []
135
135
  masks: List[np.ndarray] = []
136
136
  for _ in range(num_forecasts):
137
- anchor_target_time = anchor_time[target_batch]
137
+ anchor_target_time = anchor_time.iloc[target_batch]
138
138
  anchor_target_time = anchor_target_time.reset_index(drop=True)
139
139
 
140
140
  time_filter_mask = (target_time <= anchor_target_time +
@@ -30,11 +30,11 @@ from kumoapi.rfm import (
30
30
  )
31
31
  from kumoapi.task import TaskType
32
32
 
33
- from kumoai import global_state
33
+ from kumoai.client.rfm import RFMAPI
34
34
  from kumoai.exceptions import HTTPException
35
- from kumoai.experimental.rfm import LocalGraph
35
+ from kumoai.experimental.rfm import Graph
36
+ from kumoai.experimental.rfm.backend.local import LocalGraphStore
36
37
  from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
37
- from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
38
38
  from kumoai.experimental.rfm.local_pquery_driver import (
39
39
  LocalPQueryDriver,
40
40
  date_offset_to_seconds,
@@ -123,17 +123,17 @@ class KumoRFM:
123
123
  :class:`KumoRFM` is a foundation model to generate predictions for any
124
124
  relational dataset without training.
125
125
  The model is pre-trained and the class provides an interface to query the
126
- model from a :class:`LocalGraph` object.
126
+ model from a :class:`Graph` object.
127
127
 
128
128
  .. code-block:: python
129
129
 
130
- from kumoai.experimental.rfm import LocalGraph, KumoRFM
130
+ from kumoai.experimental.rfm import Graph, KumoRFM
131
131
 
132
132
  df_users = pd.DataFrame(...)
133
133
  df_items = pd.DataFrame(...)
134
134
  df_orders = pd.DataFrame(...)
135
135
 
136
- graph = LocalGraph.from_data({
136
+ graph = Graph.from_data({
137
137
  'users': df_users,
138
138
  'items': df_items,
139
139
  'orders': df_orders,
@@ -141,40 +141,41 @@ class KumoRFM:
141
141
 
142
142
  rfm = KumoRFM(graph)
143
143
 
144
- query = ("PREDICT COUNT(transactions.*, 0, 30, days)>0 "
145
- "FOR users.user_id=0")
146
- result = rfm.query(query)
144
+ query = ("PREDICT COUNT(orders.*, 0, 30, days)>0 "
145
+ "FOR users.user_id=1")
146
+ result = rfm.predict(query)
147
147
 
148
148
  print(result) # user_id COUNT(transactions.*, 0, 30, days) > 0
149
149
  # 1 0.85
150
150
 
151
151
  Args:
152
152
  graph: The graph.
153
- preprocess: Whether to pre-process the data in advance during graph
154
- materialization.
155
- This is a runtime trade-off between graph materialization and model
156
- processing speed.
157
- It can be benefical to preprocess your data once and then run many
158
- queries on top to achieve maximum model speed.
159
- However, if activiated, graph materialization can take potentially
160
- much longer, especially on graphs with many large text columns.
161
- Best to tune this option manually.
162
153
  verbose: Whether to print verbose output.
163
154
  """
164
155
  def __init__(
165
156
  self,
166
- graph: LocalGraph,
167
- preprocess: bool = False,
157
+ graph: Graph,
168
158
  verbose: Union[bool, ProgressLogger] = True,
169
159
  ) -> None:
170
160
  graph = graph.validate()
171
161
  self._graph_def = graph._to_api_graph_definition()
172
- self._graph_store = LocalGraphStore(graph, preprocess, verbose)
162
+ self._graph_store = LocalGraphStore(graph, verbose)
173
163
  self._graph_sampler = LocalGraphSampler(self._graph_store)
174
164
 
165
+ self._client: Optional[RFMAPI] = None
166
+
175
167
  self._batch_size: Optional[int | Literal['max']] = None
176
168
  self.num_retries: int = 0
177
169
 
170
+ @property
171
+ def _api_client(self) -> RFMAPI:
172
+ if self._client is not None:
173
+ return self._client
174
+
175
+ from kumoai.experimental.rfm import global_state
176
+ self._client = RFMAPI(global_state.client)
177
+ return self._client
178
+
178
179
  def __repr__(self) -> str:
179
180
  return f'{self.__class__.__name__}()'
180
181
 
@@ -420,14 +421,14 @@ class KumoRFM:
420
421
  for attempt in range(self.num_retries + 1):
421
422
  try:
422
423
  if explain_config is not None:
423
- resp = global_state.client.rfm_api.explain(
424
+ resp = self._api_client.explain(
424
425
  request=_bytes,
425
426
  skip_summary=explain_config.skip_summary,
426
427
  )
427
428
  summary = resp.summary
428
429
  details = resp.details
429
430
  else:
430
- resp = global_state.client.rfm_api.predict(_bytes)
431
+ resp = self._api_client.predict(_bytes)
431
432
  df = pd.DataFrame(**resp.prediction)
432
433
 
433
434
  # Cast 'ENTITY' to correct data type:
@@ -630,10 +631,10 @@ class KumoRFM:
630
631
 
631
632
  if len(request_bytes) > _MAX_SIZE:
632
633
  stats_msg = Context.get_memory_stats(request_msg.context)
633
- raise ValueError(_SIZE_LIMIT_MSG.format(stats_msg=stats_msg))
634
+ raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
634
635
 
635
636
  try:
636
- resp = global_state.client.rfm_api.evaluate(request_bytes)
637
+ resp = self._api_client.evaluate(request_bytes)
637
638
  except HTTPException as e:
638
639
  try:
639
640
  msg = json.loads(e.detail)['detail']
@@ -731,7 +732,8 @@ class KumoRFM:
731
732
  graph_definition=self._graph_def,
732
733
  )
733
734
 
734
- resp = global_state.client.rfm_api.parse_query(request)
735
+ resp = self._api_client.parse_query(request)
736
+
735
737
  # TODO Expose validation warnings.
736
738
 
737
739
  if len(resp.validation_response.warnings) > 0:
@@ -1035,7 +1037,6 @@ class KumoRFM:
1035
1037
  train_time.astype('datetime64[ns]').astype(int).to_numpy(),
1036
1038
  test_time.astype('datetime64[ns]').astype(int).to_numpy(),
1037
1039
  ]),
1038
- run_mode=run_mode,
1039
1040
  num_neighbors=num_neighbors,
1040
1041
  exclude_cols_dict=exclude_cols_dict,
1041
1042
  )
@@ -0,0 +1,138 @@
1
+ import base64
2
+ import json
3
+ from typing import Any, Dict, List, Tuple
4
+
5
+ import requests
6
+
7
+ from kumoai.client import KumoClient
8
+ from kumoai.client.endpoints import Endpoint, HTTPMethod
9
+ from kumoai.exceptions import HTTPException
10
+
11
+ try:
12
+ # isort: off
13
+ from mypy_boto3_sagemaker_runtime.client import SageMakerRuntimeClient
14
+ from mypy_boto3_sagemaker_runtime.type_defs import (
15
+ InvokeEndpointOutputTypeDef, )
16
+ # isort: on
17
+ except ImportError:
18
+ SageMakerRuntimeClient = Any
19
+ InvokeEndpointOutputTypeDef = Any
20
+
21
+
22
class SageMakerResponseAdapter(requests.Response):
    """Exposes a SageMaker ``invoke_endpoint`` result as a
    :class:`requests.Response`.
    """
    def __init__(self, sm_response: InvokeEndpointOutputTypeDef):
        super().__init__()
        # The body is read once up-front and kept as the response content:
        body = sm_response['Body'].read()
        content_type = sm_response.get('ContentType', 'application/json')

        self._content = body
        self.status_code = 200
        self.headers['Content-Type'] = content_type
        # The original response is kept around for debugging:
        self.sm_response = sm_response

    @property
    def text(self) -> str:
        """The response body decoded as UTF-8."""
        raw = self._content
        assert isinstance(raw, bytes)
        return raw.decode('utf-8')

    def json(self, **kwargs) -> dict[str, Any]:  # type: ignore
        """Parses the response body as JSON, forwarding ``kwargs`` to
        :func:`json.loads`.
        """
        return json.loads(self.text, **kwargs)
40
+
41
+
42
class KumoClient_SageMakerAdapter(KumoClient):
    """A :class:`KumoClient` that sends its requests to a SageMaker endpoint
    through the ``sagemaker-runtime`` client of :obj:`boto3`.
    """
    def __init__(self, region: str, endpoint_name: str):
        import boto3
        self._client: SageMakerRuntimeClient = boto3.client(
            service_name="sagemaker-runtime",
            region_name=region,
        )
        self._endpoint_name = endpoint_name

        # Buffers filled while recording is switched on:
        self._recording_active = False
        self._recorded_reqs: List[Dict[str, Any]] = []
        self._recorded_resps: List[Dict[str, Any]] = []

    def authenticate(self) -> None:
        # TODO(siyang): call /ping to verify?
        pass

    def _request(self, endpoint: Endpoint, **kwargs: Any) -> requests.Response:
        """Wraps the payload into a ``{'method', 'payload'}`` JSON envelope
        and invokes the SageMaker endpoint with it.

        Raises:
            HTTPException: If neither ``json`` nor ``data`` is given.
        """
        assert endpoint.method == HTTPMethod.POST

        if 'json' in kwargs:
            payload = json.dumps(kwargs.pop('json'))
        elif 'data' in kwargs:
            raw_payload = kwargs.pop('data')
            assert isinstance(raw_payload, bytes)
            # Raw bytes are base64-encoded to fit into the JSON envelope:
            payload = base64.b64encode(raw_payload).decode()
        else:
            raise HTTPException(400, 'Unable to send data to KumoRFM.')

        # The last path segment of the endpoint is sent as the method name:
        method = endpoint.get_path().rsplit('/', 1)[-1]
        request = {'method': method, 'payload': payload}

        response: InvokeEndpointOutputTypeDef = self._client.invoke_endpoint(
            EndpointName=self._endpoint_name,
            ContentType="application/json",
            Body=json.dumps(request),
        )
        adapted_response = SageMakerResponseAdapter(response)

        # While recording is active, keep the request/response pair:
        if self._recording_active:
            self._recorded_reqs.append(request)
            self._recorded_resps.append(adapted_response.json())

        return adapted_response

    def start_recording(self) -> None:
        """Start recording requests/responses to/from sagemaker endpoint."""
        assert not self._recording_active
        self._recording_active = True
        self._recorded_reqs.clear()
        self._recorded_resps.clear()

    def end_recording(self) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]:
        """Stop recording and return recorded requests/responses."""
        assert self._recording_active
        self._recording_active = False
        # Pair up the buffers before they are emptied:
        pairs = [
            (req, resp)
            for req, resp in zip(self._recorded_reqs, self._recorded_resps)
        ]
        self._recorded_reqs.clear()
        self._recorded_resps.clear()
        return pairs
103
+
104
+
105
class KumoClient_SageMakerProxy_Local(KumoClient):
    """A :class:`KumoClient` that forwards every request to the
    ``/invocations`` endpoint of a wrapped :class:`KumoClient` created for
    ``url``.
    """
    def __init__(self, url: str):
        # NOTE: The base initializer is not called; all connection state
        # lives on the wrapped client.
        self._client = KumoClient(url, api_key=None)
        self._client._api_url = self._client._url
        self._endpoint = Endpoint('/invocations', HTTPMethod.POST)

    def authenticate(self) -> None:
        """Checks that ``<url>/ping`` of the wrapped client answers.

        Raises:
            ValueError: If the request fails or returns an error status.
        """
        try:
            # The URL and SSL settings are read from the wrapped client:
            # this instance never ran the base initializer, so looking them
            # up on `self` is not guaranteed to work.
            self._client._session.get(
                self._client._url + '/ping',
                verify=self._client._verify_ssl,
            ).raise_for_status()
        except Exception as e:
            # Chain the cause so that the underlying failure stays visible:
            raise ValueError(
                "Client authentication failed. Please check if you "
                "have a valid API key/credentials.") from e

    def _request(self, endpoint: Endpoint, **kwargs: Any) -> requests.Response:
        """Wraps the payload into a ``{'method', 'payload'}`` JSON envelope
        and posts it through the wrapped client.

        Raises:
            HTTPException: If neither ``json`` nor ``data`` is given.
        """
        assert endpoint.method == HTTPMethod.POST
        if 'json' in kwargs:
            payload = json.dumps(kwargs.pop('json'))
        elif 'data' in kwargs:
            raw_payload = kwargs.pop('data')
            assert isinstance(raw_payload, bytes)
            # Raw bytes are base64-encoded to fit into the JSON envelope:
            payload = base64.b64encode(raw_payload).decode()
        else:
            raise HTTPException(400, 'Unable to send data to KumoRFM.')
        return self._client._request(
            self._endpoint,
            json={
                # The last path segment of the endpoint is the method name:
                'method': endpoint.get_path().rsplit('/')[-1],
                'payload': payload,
            },
            **kwargs,
        )
kumoai/spcs.py CHANGED
@@ -54,9 +54,7 @@ def _refresh_spcs_token() -> None:
54
54
  api_key=global_state._api_key,
55
55
  spcs_token=spcs_token,
56
56
  )
57
- if not client.authenticate():
58
- raise ValueError("Client authentication failed. Please check if you "
59
- "have a valid API key.")
57
+ client.authenticate()
60
58
 
61
59
  # Update state:
62
60
  global_state.set_spcs_token(spcs_token)
@@ -25,7 +25,7 @@ def onlyFullTest(func: Callable) -> Callable:
25
25
  def has_package(package: str) -> bool:
26
26
  r"""Returns ``True`` in case ``package`` is installed."""
27
27
  req = Requirement(package)
28
- if importlib.util.find_spec(req.name) is None:
28
+ if importlib.util.find_spec(req.name) is None: # type: ignore
29
29
  return False
30
30
 
31
31
  try:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: kumoai
3
- Version: 2.12.0.dev202511111731
3
+ Version: 2.13.0.dev202512091732
4
4
  Summary: AI on the Modern Data Stack
5
5
  Author-email: "Kumo.AI" <hello@kumo.ai>
6
6
  License-Expression: MIT
@@ -23,7 +23,7 @@ Requires-Dist: requests>=2.28.2
23
23
  Requires-Dist: urllib3
24
24
  Requires-Dist: plotly
25
25
  Requires-Dist: typing_extensions>=4.5.0
26
- Requires-Dist: kumo-api==0.45.0
26
+ Requires-Dist: kumo-api==0.48.0
27
27
  Requires-Dist: tqdm>=4.66.0
28
28
  Requires-Dist: aiohttp>=3.10.0
29
29
  Requires-Dist: pydantic>=1.10.21
@@ -38,6 +38,16 @@ Provides-Extra: test
38
38
  Requires-Dist: pytest; extra == "test"
39
39
  Requires-Dist: pytest-mock; extra == "test"
40
40
  Requires-Dist: requests-mock; extra == "test"
41
+ Provides-Extra: sqlite
42
+ Requires-Dist: adbc_driver_sqlite; extra == "sqlite"
43
+ Provides-Extra: snowflake
44
+ Requires-Dist: snowflake-connector-python; extra == "snowflake"
45
+ Requires-Dist: pyyaml; extra == "snowflake"
46
+ Provides-Extra: sagemaker
47
+ Requires-Dist: boto3<2.0,>=1.30.0; extra == "sagemaker"
48
+ Requires-Dist: mypy-boto3-sagemaker-runtime<2.0,>=1.34.0; extra == "sagemaker"
49
+ Provides-Extra: test-sagemaker
50
+ Requires-Dist: sagemaker<3.0; extra == "test-sagemaker"
41
51
  Dynamic: license-file
42
52
  Dynamic: requires-dist
43
53