kumoai 2.14.0.dev202601011731__cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kumoai might be problematic. Click here for more details.

Files changed (122) hide show
  1. kumoai/__init__.py +300 -0
  2. kumoai/_logging.py +29 -0
  3. kumoai/_singleton.py +25 -0
  4. kumoai/_version.py +1 -0
  5. kumoai/artifact_export/__init__.py +9 -0
  6. kumoai/artifact_export/config.py +209 -0
  7. kumoai/artifact_export/job.py +108 -0
  8. kumoai/client/__init__.py +5 -0
  9. kumoai/client/client.py +223 -0
  10. kumoai/client/connector.py +110 -0
  11. kumoai/client/endpoints.py +150 -0
  12. kumoai/client/graph.py +120 -0
  13. kumoai/client/jobs.py +471 -0
  14. kumoai/client/online.py +78 -0
  15. kumoai/client/pquery.py +207 -0
  16. kumoai/client/rfm.py +112 -0
  17. kumoai/client/source_table.py +53 -0
  18. kumoai/client/table.py +101 -0
  19. kumoai/client/utils.py +130 -0
  20. kumoai/codegen/__init__.py +19 -0
  21. kumoai/codegen/cli.py +100 -0
  22. kumoai/codegen/context.py +16 -0
  23. kumoai/codegen/edits.py +473 -0
  24. kumoai/codegen/exceptions.py +10 -0
  25. kumoai/codegen/generate.py +222 -0
  26. kumoai/codegen/handlers/__init__.py +4 -0
  27. kumoai/codegen/handlers/connector.py +118 -0
  28. kumoai/codegen/handlers/graph.py +71 -0
  29. kumoai/codegen/handlers/pquery.py +62 -0
  30. kumoai/codegen/handlers/table.py +109 -0
  31. kumoai/codegen/handlers/utils.py +42 -0
  32. kumoai/codegen/identity.py +114 -0
  33. kumoai/codegen/loader.py +93 -0
  34. kumoai/codegen/naming.py +94 -0
  35. kumoai/codegen/registry.py +121 -0
  36. kumoai/connector/__init__.py +31 -0
  37. kumoai/connector/base.py +153 -0
  38. kumoai/connector/bigquery_connector.py +200 -0
  39. kumoai/connector/databricks_connector.py +213 -0
  40. kumoai/connector/file_upload_connector.py +189 -0
  41. kumoai/connector/glue_connector.py +150 -0
  42. kumoai/connector/s3_connector.py +278 -0
  43. kumoai/connector/snowflake_connector.py +252 -0
  44. kumoai/connector/source_table.py +471 -0
  45. kumoai/connector/utils.py +1796 -0
  46. kumoai/databricks.py +14 -0
  47. kumoai/encoder/__init__.py +4 -0
  48. kumoai/exceptions.py +26 -0
  49. kumoai/experimental/__init__.py +0 -0
  50. kumoai/experimental/rfm/__init__.py +210 -0
  51. kumoai/experimental/rfm/authenticate.py +432 -0
  52. kumoai/experimental/rfm/backend/__init__.py +0 -0
  53. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  54. kumoai/experimental/rfm/backend/local/graph_store.py +297 -0
  55. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  56. kumoai/experimental/rfm/backend/local/table.py +113 -0
  57. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  58. kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
  59. kumoai/experimental/rfm/backend/snow/table.py +242 -0
  60. kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
  61. kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
  62. kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
  63. kumoai/experimental/rfm/base/__init__.py +30 -0
  64. kumoai/experimental/rfm/base/column.py +152 -0
  65. kumoai/experimental/rfm/base/expression.py +44 -0
  66. kumoai/experimental/rfm/base/sampler.py +761 -0
  67. kumoai/experimental/rfm/base/source.py +19 -0
  68. kumoai/experimental/rfm/base/sql_sampler.py +143 -0
  69. kumoai/experimental/rfm/base/table.py +736 -0
  70. kumoai/experimental/rfm/graph.py +1237 -0
  71. kumoai/experimental/rfm/infer/__init__.py +19 -0
  72. kumoai/experimental/rfm/infer/categorical.py +40 -0
  73. kumoai/experimental/rfm/infer/dtype.py +82 -0
  74. kumoai/experimental/rfm/infer/id.py +46 -0
  75. kumoai/experimental/rfm/infer/multicategorical.py +48 -0
  76. kumoai/experimental/rfm/infer/pkey.py +128 -0
  77. kumoai/experimental/rfm/infer/stype.py +35 -0
  78. kumoai/experimental/rfm/infer/time_col.py +61 -0
  79. kumoai/experimental/rfm/infer/timestamp.py +41 -0
  80. kumoai/experimental/rfm/pquery/__init__.py +7 -0
  81. kumoai/experimental/rfm/pquery/executor.py +102 -0
  82. kumoai/experimental/rfm/pquery/pandas_executor.py +530 -0
  83. kumoai/experimental/rfm/relbench.py +76 -0
  84. kumoai/experimental/rfm/rfm.py +1184 -0
  85. kumoai/experimental/rfm/sagemaker.py +138 -0
  86. kumoai/experimental/rfm/task_table.py +231 -0
  87. kumoai/formatting.py +30 -0
  88. kumoai/futures.py +99 -0
  89. kumoai/graph/__init__.py +12 -0
  90. kumoai/graph/column.py +106 -0
  91. kumoai/graph/graph.py +948 -0
  92. kumoai/graph/table.py +838 -0
  93. kumoai/jobs.py +80 -0
  94. kumoai/kumolib.cpython-310-x86_64-linux-gnu.so +0 -0
  95. kumoai/mixin.py +28 -0
  96. kumoai/pquery/__init__.py +25 -0
  97. kumoai/pquery/prediction_table.py +287 -0
  98. kumoai/pquery/predictive_query.py +641 -0
  99. kumoai/pquery/training_table.py +424 -0
  100. kumoai/spcs.py +121 -0
  101. kumoai/testing/__init__.py +8 -0
  102. kumoai/testing/decorators.py +57 -0
  103. kumoai/testing/snow.py +50 -0
  104. kumoai/trainer/__init__.py +42 -0
  105. kumoai/trainer/baseline_trainer.py +93 -0
  106. kumoai/trainer/config.py +2 -0
  107. kumoai/trainer/distilled_trainer.py +175 -0
  108. kumoai/trainer/job.py +1192 -0
  109. kumoai/trainer/online_serving.py +258 -0
  110. kumoai/trainer/trainer.py +475 -0
  111. kumoai/trainer/util.py +103 -0
  112. kumoai/utils/__init__.py +11 -0
  113. kumoai/utils/datasets.py +83 -0
  114. kumoai/utils/display.py +51 -0
  115. kumoai/utils/forecasting.py +209 -0
  116. kumoai/utils/progress_logger.py +343 -0
  117. kumoai/utils/sql.py +3 -0
  118. kumoai-2.14.0.dev202601011731.dist-info/METADATA +71 -0
  119. kumoai-2.14.0.dev202601011731.dist-info/RECORD +122 -0
  120. kumoai-2.14.0.dev202601011731.dist-info/WHEEL +6 -0
  121. kumoai-2.14.0.dev202601011731.dist-info/licenses/LICENSE +9 -0
  122. kumoai-2.14.0.dev202601011731.dist-info/top_level.txt +1 -0
@@ -0,0 +1,138 @@
1
+ import base64
2
+ import json
3
+ from typing import Any
4
+
5
+ import requests
6
+
7
+ from kumoai.client import KumoClient
8
+ from kumoai.client.endpoints import Endpoint, HTTPMethod
9
+ from kumoai.exceptions import HTTPException
10
+
11
+ try:
12
+ # isort: off
13
+ from mypy_boto3_sagemaker_runtime.client import SageMakerRuntimeClient
14
+ from mypy_boto3_sagemaker_runtime.type_defs import (
15
+ InvokeEndpointOutputTypeDef, )
16
+ # isort: on
17
+ except ImportError:
18
+ SageMakerRuntimeClient = Any
19
+ InvokeEndpointOutputTypeDef = Any
20
+
21
+
22
class SageMakerResponseAdapter(requests.Response):
    """Wraps a SageMaker ``invoke_endpoint`` result as a
    :class:`requests.Response`.

    The body stream of the SageMaker response is read once during
    construction and kept as the response content. The status code is always
    set to ``200``.
    """
    def __init__(self, sm_response: InvokeEndpointOutputTypeDef):
        super().__init__()
        # Consume the body stream up-front; `text`/`json` work off the bytes.
        body_stream = sm_response['Body']
        self._content = body_stream.read()
        self.status_code = 200
        # Fall back to JSON if the endpoint did not report a content type.
        content_type = sm_response.get('ContentType', 'application/json')
        self.headers['Content-Type'] = content_type
        # Keep the raw SageMaker response around for debugging purposes.
        self.sm_response = sm_response

    @property
    def text(self) -> str:
        """The response body decoded as UTF-8."""
        raw = self._content
        assert isinstance(raw, bytes)
        return str(raw, 'utf-8')

    def json(self, **kwargs) -> dict[str, Any]:  # type: ignore
        """Parses the response body as JSON; ``kwargs`` are forwarded to
        :func:`json.loads`.
        """
        parsed = json.loads(self.text, **kwargs)
        return parsed
40
+
41
+
42
class KumoClient_SageMakerAdapter(KumoClient):
    """A client that routes requests through a SageMaker runtime endpoint.

    Every request is wrapped into a JSON envelope with the keys ``method``
    (the last path segment of the requested endpoint) and ``payload`` and is
    sent via ``invoke_endpoint``. Requests and responses can optionally be
    recorded between :meth:`start_recording` and :meth:`end_recording`.

    Args:
        region: The region name passed to ``boto3.client``.
        endpoint_name: The name of the SageMaker endpoint to invoke.
    """
    def __init__(self, region: str, endpoint_name: str):
        import boto3

        self._endpoint_name = endpoint_name
        self._client: SageMakerRuntimeClient = boto3.client(
            service_name="sagemaker-runtime",
            region_name=region,
        )

        # State used for recording request/response pairs:
        self._recording_active = False
        self._recorded_reqs: list[dict[str, Any]] = []
        self._recorded_resps: list[dict[str, Any]] = []

    def authenticate(self) -> None:
        """Intentionally a no-op."""
        # TODO(siyang): call /ping to verify?
        return None

    def _request(self, endpoint: Endpoint, **kwargs: Any) -> requests.Response:
        """Sends the ``json`` or ``data`` keyword argument to the endpoint.

        ``json`` payloads are serialized via :func:`json.dumps`; ``data``
        payloads must be bytes and are base64-encoded. If neither is given,
        an :class:`HTTPException` with status 400 is raised.
        """
        assert endpoint.method == HTTPMethod.POST

        if 'json' in kwargs:
            payload = json.dumps(kwargs.pop('json'))
        elif 'data' in kwargs:
            raw_payload = kwargs.pop('data')
            assert isinstance(raw_payload, bytes)
            encoded = base64.b64encode(raw_payload)
            payload = encoded.decode()
        else:
            raise HTTPException(400, 'Unable to send data to KumoRFM.')

        # The last path segment of the endpoint names the method to call:
        method_name = endpoint.get_path().rpartition('/')[2]
        request = {'method': method_name, 'payload': payload}

        response: InvokeEndpointOutputTypeDef = self._client.invoke_endpoint(
            EndpointName=self._endpoint_name,
            ContentType="application/json",
            Body=json.dumps(request),
        )
        adapted_response = SageMakerResponseAdapter(response)

        # Store the input/output pair while recording is active:
        if self._recording_active:
            self._recorded_reqs.append(request)
            self._recorded_resps.append(adapted_response.json())

        return adapted_response

    def start_recording(self) -> None:
        """Start recording requests/responses to/from sagemaker endpoint."""
        assert not self._recording_active
        self._recording_active = True
        for buffer in (self._recorded_reqs, self._recorded_resps):
            buffer.clear()

    def end_recording(self) -> list[tuple[dict[str, Any], dict[str, Any]]]:
        """Stop recording and return recorded requests/responses."""
        assert self._recording_active
        self._recording_active = False
        pairs = [(req, resp) for req, resp in zip(self._recorded_reqs,
                                                  self._recorded_resps)]
        for buffer in (self._recorded_reqs, self._recorded_resps):
            buffer.clear()
        return pairs
103
+
104
+
105
class KumoClient_SageMakerProxy_Local(KumoClient):
    """A client that forwards requests to a local SageMaker-style server.

    All requests are wrapped into a JSON envelope with the keys ``method``
    (the last path segment of the requested endpoint) and ``payload`` and
    are posted to the ``/invocations`` endpoint of a wrapped
    :class:`KumoClient`.

    Args:
        url: The URL handed to the wrapped :class:`KumoClient`.
    """
    def __init__(self, url: str):
        # NOTE `super().__init__()` is not called: all connection state
        # (URL, session, ...) lives on the wrapped client.
        self._client = KumoClient(url, api_key=None)
        self._client._api_url = self._client._url
        self._endpoint = Endpoint('/invocations', HTTPMethod.POST)

    def authenticate(self) -> None:
        """Checks that ``<url>/ping`` of the wrapped client responds.

        Raises:
            ValueError: If the request fails or returns an error status.
        """
        try:
            # Read the URL and SSL setting from the wrapped client: this
            # instance never initializes `_url`/`_verify_ssl` itself, so
            # reading them from `self` would fail on every call.
            self._client._session.get(
                self._client._url + '/ping',
                verify=getattr(self._client, '_verify_ssl', True),
            ).raise_for_status()
        except Exception as e:
            raise ValueError(
                "Client authentication failed. Please check if you "
                "have a valid API key/credentials.") from e

    def _request(self, endpoint: Endpoint, **kwargs: Any) -> requests.Response:
        """Wraps the ``json`` or ``data`` keyword argument and posts it.

        ``json`` payloads are serialized via :func:`json.dumps`; ``data``
        payloads must be bytes and are base64-encoded. Remaining keyword
        arguments are forwarded to the wrapped client.

        Raises:
            HTTPException: With status 400 if neither ``json`` nor ``data``
                is given.
        """
        assert endpoint.method == HTTPMethod.POST
        if 'json' in kwargs:
            payload = json.dumps(kwargs.pop('json'))
        elif 'data' in kwargs:
            raw_payload = kwargs.pop('data')
            assert isinstance(raw_payload, bytes)
            payload = base64.b64encode(raw_payload).decode()
        else:
            raise HTTPException(400, 'Unable to send data to KumoRFM.')
        return self._client._request(
            self._endpoint,
            json={
                # The last path segment of the endpoint names the method:
                'method': endpoint.get_path().rsplit('/')[-1],
                'payload': payload,
            },
            **kwargs,
        )
@@ -0,0 +1,231 @@
1
+ from collections.abc import Sequence
2
+
3
+ import pandas as pd
4
+ from kumoapi.task import TaskType
5
+ from kumoapi.typing import Dtype, Stype
6
+
7
+ from kumoai.experimental.rfm.base import Column
8
+ from kumoai.experimental.rfm.infer import contains_timestamp, infer_dtype
9
+
10
+
11
class TaskTable:
    r"""A :class:`TaskTable` fully specifies the task, *i.e.* its context and
    prediction examples with entity IDs, targets and timestamps.

    Args:
        task_type: The task type.
        context_df: The data frame holding context examples.
        pred_df: The data frame holding prediction examples.
        entity_table_name: The entity table to predict for. For link prediction
            tasks, needs to hold both entity and target table names.
        entity_column: The name of the entity column.
        target_column: The name of the target column.
        time_column: The name of the time column, if it exists.

    Raises:
        ValueError: If the task type is not supported, if more than two
            entity table names are given, or if any of the given columns
            fails validation (see the respective property setters).
    """
    def __init__(
        self,
        task_type: TaskType,
        context_df: pd.DataFrame,
        pred_df: pd.DataFrame,
        entity_table_name: str | Sequence[str],
        entity_column: str,
        target_column: str,
        time_column: str | None = None,
    ) -> None:

        task_type = TaskType(task_type)
        if task_type not in {  # Currently supported task types:
                TaskType.BINARY_CLASSIFICATION,
                TaskType.MULTICLASS_CLASSIFICATION,
                TaskType.REGRESSION,
                TaskType.TEMPORAL_LINK_PREDICTION,
        }:
            raise ValueError(
                f"Task type '{task_type}' is not supported by "
                f"'{self.__class__.__name__}'")
        self._task_type = task_type

        # TODO Check dfs (unify from local table)
        # Shallow copies: the wrappers are new, the underlying data is shared.
        self._context_df = context_df.copy(deep=False)
        self._pred_df = pred_df.copy(deep=False)

        # Data types are inferred from the context examples only:
        self._dtype_dict: dict[str, Dtype] = {
            column_name: infer_dtype(context_df[column_name])
            for column_name in context_df.columns
        }

        self._entity_table_names: tuple[str] | tuple[str, str]
        if isinstance(entity_table_name, str):
            self._entity_table_names = (entity_table_name, )
        elif len(entity_table_name) == 1:
            self._entity_table_names = (entity_table_name[0], )
        elif len(entity_table_name) == 2:
            self._entity_table_names = (
                entity_table_name[0],
                entity_table_name[1],
            )
        else:
            raise ValueError(
                f"'entity_table_name' needs to hold one or two table names "
                f"(got {len(entity_table_name)})")

        # Initialize with placeholders so that the setters below can check
        # for clashes between the different column roles:
        self._entity_column: str = ''
        self._target_column: str = ''
        self._time_column: str | None = None

        self.entity_column = entity_column
        self.target_column = target_column
        if time_column is not None:
            self.time_column = time_column

    @property
    def task_type(self) -> TaskType:
        r"""The task type."""
        return self._task_type

    # Entity column ###########################################################

    @property
    def entity_table_name(self) -> str:
        r"""The first entry of :obj:`entity_table_names`."""
        return self._entity_table_names[0]

    @property
    def entity_table_names(self) -> tuple[str] | tuple[str, str]:
        r"""All entity table names (one or two entries)."""
        return self._entity_table_names

    @property
    def entity_column(self) -> Column:
        r"""The entity column of this task.

        The setter raises a :class:`ValueError` if the column is already used
        as target or time column, if it is missing in the context or
        prediction data frame, or if its data type is not supported by the
        ``ID`` semantic type.
        """
        return Column(
            name=self._entity_column,
            expr=None,
            dtype=self._dtype_dict[self._entity_column],
            stype=Stype.ID,
        )

    @entity_column.setter
    def entity_column(self, name: str) -> None:
        if name in {self._target_column, self._time_column}:
            raise ValueError(
                f"Cannot use column '{name}' as entity column since it is "
                f"already used as target or time column")
        if name not in self._context_df:
            raise ValueError(
                f"Entity column '{name}' does not exist in the context "
                f"data frame")
        if name not in self._pred_df:
            raise ValueError(
                f"Entity column '{name}' does not exist in the prediction "
                f"data frame")
        if not Stype.ID.supports_dtype(self._dtype_dict[name]):
            raise ValueError(
                f"Entity column '{name}' has data type "
                f"'{self._dtype_dict[name]}' which is not supported by the "
                f"'ID' semantic type")

        self._entity_column = name

    # Target column ###########################################################

    @property
    def _target_stype(self) -> Stype:
        r"""The semantic type of the target, derived from the task type."""
        if self.task_type in {
                TaskType.BINARY_CLASSIFICATION,
                TaskType.MULTICLASS_CLASSIFICATION,
        }:
            return Stype.categorical
        if self.task_type in {TaskType.REGRESSION}:
            return Stype.numerical
        if self.task_type.is_link_pred:
            return Stype.multicategorical
        raise ValueError(
            f"Cannot derive the target semantic type for task type "
            f"'{self.task_type}'")

    @property
    def target_column(self) -> Column:
        r"""The target column of this task.

        The setter raises a :class:`ValueError` if the column is already used
        as entity or time column, if it is missing in the context data frame,
        or if its data type is not supported by the semantic type of the
        target.
        """
        return Column(
            name=self._target_column,
            expr=None,
            dtype=self._dtype_dict[self._target_column],
            stype=self._target_stype,
        )

    @target_column.setter
    def target_column(self, name: str) -> None:
        if name in {self._entity_column, self._time_column}:
            raise ValueError(
                f"Cannot use column '{name}' as target column since it is "
                f"already used as entity or time column")
        if name not in self._context_df:
            raise ValueError(
                f"Target column '{name}' does not exist in the context "
                f"data frame")
        if not self._target_stype.supports_dtype(self._dtype_dict[name]):
            raise ValueError(
                f"Target column '{name}' has data type "
                f"'{self._dtype_dict[name]}' which is not supported by the "
                f"'{self._target_stype}' semantic type")

        self._target_column = name

    # Time column #############################################################

    def has_time_column(self) -> bool:
        r"""Returns ``True`` if this task has a time column; ``False``
        otherwise.
        """
        return self._time_column is not None

    @property
    def time_column(self) -> Column | None:
        r"""The time column of this task.

        The getter returns the time column of this task, or ``None`` if no
        such time column is present.

        The setter sets a column as a time column for this task, and raises a
        :class:`ValueError` if the time column has a non-timestamp compatible
        data type or if the column name does not match a column in the data
        frame.
        """
        if self._time_column is None:
            return None
        return Column(
            name=self._time_column,
            expr=None,
            dtype=self._dtype_dict[self._time_column],
            stype=Stype.timestamp,
        )

    @time_column.setter
    def time_column(self, name: str | None) -> None:
        if name is None:
            self._time_column = None
            return

        if name in {self._entity_column, self._target_column}:
            raise ValueError(
                f"Cannot use column '{name}' as time column since it is "
                f"already used as entity or target column")
        if name not in self._context_df:
            raise ValueError(
                f"Time column '{name}' does not exist in the context "
                f"data frame")
        if name not in self._pred_df:
            raise ValueError(
                f"Time column '{name}' does not exist in the prediction "
                f"data frame")
        if not contains_timestamp(
                ser=self._context_df[name],
                column_name=name,
                dtype=self._dtype_dict[name],
        ):
            raise ValueError(
                f"Time column '{name}' does not hold timestamp compatible "
                f"values")

        self._time_column = name

    # Metadata ################################################################

    @property
    def metadata(self) -> pd.DataFrame:
        raise NotImplementedError

    def print_metadata(self) -> None:
        raise NotImplementedError

    # Python builtins #########################################################

    def __hash__(self) -> int:
        # Hashes the task specification only, not the data frame contents:
        return hash((
            self.task_type,
            self.entity_table_names,
            self._entity_column,
            self._target_column,
            self._time_column,
        ))

    def __repr__(self) -> str:
        if self.task_type.is_link_pred:
            entity_table_repr = f'entity_table_names={self.entity_table_names}'
        else:
            entity_table_repr = f'entity_table_name={self.entity_table_name}'
        return (f'{self.__class__.__name__}(\n'
                f'  task_type={self.task_type},\n'
                f'  num_context_examples={len(self._context_df)},\n'
                f'  num_prediction_examples={len(self._pred_df)},\n'
                f'  {entity_table_repr},\n'
                f'  entity_column={self._entity_column},\n'
                f'  target_column={self._target_column},\n'
                f'  time_column={self._time_column},\n'
                f')')
kumoai/formatting.py ADDED
@@ -0,0 +1,30 @@
1
+ from kumoapi.jobs import ErrorDetails
2
+
3
+
4
def pretty_print_error_details(error_details: 'ErrorDetails') -> str:
    """Pretty prints the ErrorDetails combining all the individual items.
    If there are CTAs, they are also displayed after creating
    corresponding hyperlinks.

    Each item is rendered on its own line as ``[<title>: ]<description>``,
    followed by a pointer to the CTA URL if the item has a CTA. If there is
    more than one item, a header line is emitted and items are numbered
    starting from 1.

    Arguments:
        error_details (ErrorDetails): Standard ErrorDetails response from
            get_errors APIs.

    Returns:
        The formatted message, or an empty string if there are no items.
    """
    items = error_details.items
    # Only announce and number the errors if there is more than one:
    numbered = len(items) > 1

    chunks = []
    if numbered:
        chunks.append("Encountered multiple errors:\n")
    for index, error_detail in enumerate(items, start=1):
        if numbered:
            chunks.append(f'{index}. ')
        if error_detail.title is not None:
            chunks.append(f'{error_detail.title}: ')
        chunks.append(error_detail.description)
        if error_detail.cta is not None:
            # Both literals need to be part of the same expression;
            # otherwise the URL is silently dropped from the output.
            chunks.append(' Follow the link for potential resolution:'
                          f' {error_detail.cta.url}')
        chunks.append('\n')

    return ''.join(chunks)
kumoai/futures.py ADDED
@@ -0,0 +1,99 @@
1
+ import asyncio
2
+ import concurrent
3
+ import logging
4
+ import threading
5
+ from abc import ABC, abstractmethod
6
+ from asyncio.events import AbstractEventLoop
7
+ from typing import Any, Awaitable, Coroutine, Generic, TypeVar
8
+
9
logger = logging.getLogger(__name__)

CoroFuncType = Awaitable[Any]

# The global Kumo event loop, used for pollers and other long-running
# interactions with the Kumo backend. It is created at import time but is
# deliberately never installed via `set_event_loop` in the importing thread,
# since the caller may be running an event loop of their own. Always pass
# this loop explicitly instead.
_KUMO_EVENT_LOOP: AbstractEventLoop = asyncio.new_event_loop()


def initialize_event_loop() -> None:
    r"""Starts a daemon thread that runs the Kumo event loop forever."""
    def _serve(loop: AbstractEventLoop) -> None:
        # The loop is only made current inside of the background thread:
        asyncio.set_event_loop(loop)
        loop.run_forever()

    worker = threading.Thread(
        target=_serve,
        kwargs={'loop': _KUMO_EVENT_LOOP},
        daemon=True,
    )
    worker.start()


def create_future(coro: Coroutine[Any, Any, Any]) -> concurrent.futures.Future:
    r"""Creates a future to execute in the Kumo event loop."""
    # Submit `coro` to the Kumo event loop from the calling thread. The
    # returned future resolves with the result (or exception) of `coro`:
    scheduled = asyncio.run_coroutine_threadsafe(coro, loop=_KUMO_EVENT_LOOP)
    return scheduled
37
+
38
+
39
T = TypeVar("T")


class KumoFuture(ABC, Generic[T]):
    r"""Abstract base class for a Kumo future object."""

    # Kumo futures deliberately do not subclass `asyncio.Future` or
    # `concurrent.futures.Future`: those are tied to the Python
    # implementation of asyncio and threading, which makes them hard to
    # extend. Python also advises against exposing low-level future objects
    # in user facing APIs.
    @abstractmethod
    def result(self) -> T:
        r"""Returns the resolved state of the future.

        Raises:
            Exception:
                If the future completed in a failed state because an
                exception was raised, the same exception is raised here.
        """
        raise NotImplementedError

    @abstractmethod
    def future(self) -> 'concurrent.futures.Future[T]':
        r"""Returns the :obj:`concurrent.futures.Future` object that this
        future wraps. Accessing it directly is not recommended.
        """
        raise NotImplementedError

    def done(self) -> bool:
        r"""Returns whether the wrapped :obj:`concurrent.futures.Future`
        (see :meth:`future`) reports that it is done, *i.e.* :obj:`True` if
        it has finished (successfully or in a failed state), and
        :obj:`False` if it is still in-progress.
        """
        wrapped = self.future()
        return wrapped.done()
78
+
79
+
80
class KumoProgressFuture(KumoFuture[T]):
    r"""A :class:`KumoFuture` whose progress can be followed via
    :meth:`attach`.
    """
    @abstractmethod
    def _attach_internal(self, interval_s: float = 4.0) -> T:
        r"""Implementation hook of :meth:`attach`."""
        raise NotImplementedError

    def attach(self, interval_s: float = 4.0) -> T:
        r"""Allows a user to attach to a running job and view its progress.

        If progress tracking fails for any reason, a warning is logged and
        the call falls back to waiting on :meth:`result`.

        Args:
            interval_s (float): Time interval (seconds) between polls, minimum
                value allowed is 4 seconds. The minimum is not checked by
                this method itself.
        """
        try:
            outcome = self._attach_internal(interval_s=interval_s)
        except Exception:
            # Best-effort: tracking is optional, the result is not.
            logger.warning(
                "Detailed job tracking has become temporarily unavailable. "
                "The job is continuing to proceed on the Kumo server, "
                "and this call will complete when the job has finished.")
            outcome = self.result()
        return outcome
@@ -0,0 +1,12 @@
1
+ from .column import Column, TimestampUnit
2
+ from .table import Table
3
+ from .graph import Graph, Edge, GraphHealthStats
4
+
5
+ __all__ = [
6
+ 'TimestampUnit',
7
+ 'Column',
8
+ 'Table',
9
+ 'Graph',
10
+ 'Edge',
11
+ 'GraphHealthStats',
12
+ ]
kumoai/graph/column.py ADDED
@@ -0,0 +1,106 @@
1
+ from dataclasses import dataclass
2
+ from typing import Any, Optional, Union
3
+
4
+ from kumoapi.table import TimestampUnit
5
+ from kumoapi.typing import Dtype, Stype
6
+
7
+ from kumoai.mixin import CastMixin
8
+
9
+
10
@dataclass(init=False)
class Column(CastMixin):
    r"""Metadata for a single column of a Kumo
    :class:`~kumoai.graph.Table`. A column can be created on its own, or be
    fetched from a table via the :meth:`~kumoai.graph.Table.column` method.

    .. code-block:: python

        import kumoai

        # Fetch a column from a `kumoai.Table`:
        table = kumoai.Table(...)

        column = table.column("col_name")
        column = table["col_name"]  # equivalent to the above.

        # Edit a column's data type:
        print("Existing dtype: ", column.dtype)
        column.dtype = "int"

        # Edit a column's semantic type:
        print("Existing stype: ", column.stype)
        column.stype = "ID"

    Args:
        name: The name of this column. It cannot be changed to a different
            value once it has been set.
        stype: The semantic type of this column. Semantic types can be
            given as strings; see :class:`~kumoai.Stype` for the list of
            possible semantic types.
        dtype: The data type of this column. Data types can be given as
            strings; see :class:`~kumoai.Dtype` for the list of possible
            data types.
        timestamp_format: If this column represents a timestamp, the format
            that the timestamp should be parsed in. The format can either be
            a :class:`~kumoapi.table.TimestampUnit` for integer columns or a
            string with a format identifier described
            `here <https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html>`__
            for a SaaS Kumo deployment and
            `here <https://docs.snowflake.com/en/sql-reference/date-time-input-output#about-the-elements-used-in-input-and-output-formats>`__
            for a Snowpark Container Services Kumo deployment. If left empty,
            will be intelligently inferred by Kumo.
    """  # noqa: E501
    name: str
    stype: Optional[Stype] = None
    dtype: Optional[Dtype] = None
    timestamp_format: Optional[Union[str, TimestampUnit]] = None

    def __init__(
        self,
        name: str,
        stype: Optional[Union[Stype, str]] = None,
        dtype: Optional[Union[Dtype, str]] = None,
        timestamp_format: Optional[Union[str, TimestampUnit]] = None,
    ) -> None:
        self.name = name
        self.stype = None if stype is None else Stype(stype)
        self.dtype = None if dtype is None else Dtype(dtype)
        # Prefer a `TimestampUnit`; anything that is not a valid unit
        # (including `None`) is stored as given:
        try:
            self.timestamp_format = TimestampUnit(timestamp_format)
        except ValueError:
            self.timestamp_format = timestamp_format

    def __hash__(self) -> int:
        identity = (self.name, self.stype, self.dtype, self.timestamp_format)
        return hash(identity)

    def __setattr__(self, key: Any, value: Any) -> None:
        if key == 'name':
            # The name may be set initially (or re-set to the same value),
            # but never changed:
            if value != getattr(self, key, value):
                raise AttributeError("Attribute 'name' is read-only")
        elif isinstance(value, str):
            # Strings are converted to their enum counterparts:
            if key == 'stype':
                value = Stype(value)
            elif key == 'dtype':
                value = Dtype(value)
            elif key == 'timestamp_format':
                try:
                    value = TimestampUnit(value)
                except ValueError:
                    pass  # Keep the plain format string.
        super().__setattr__(key, value)

    def update(self, obj: 'Column', override: bool = True) -> 'Column':
        r"""Copies all public attributes that are not ``None`` from ``obj``
        into this column and returns this column. If ``override`` is
        ``False``, only attributes of this column that are currently ``None``
        are filled in.
        """
        for key in self.__dict__:
            if key[0] == '_':  # Private attributes are left untouched.
                continue
            incoming = getattr(obj, key, None)
            if incoming is None:
                continue
            if not override and getattr(self, key, None) is not None:
                continue
            setattr(self, key, incoming)
        return self

    def __repr__(self) -> str:
        parts = [
            f'name="{self.name}"',
            f'stype="{self.stype}"',
            f'dtype="{self.dtype}"',
        ]
        if self.timestamp_format is not None:
            parts.append(f'timestamp_format="{self.timestamp_format}"')
        return f"Column({', '.join(parts)})"