kumoai 2.14.0.dev202601011731__cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kumoai might be problematic.
- kumoai/__init__.py +300 -0
- kumoai/_logging.py +29 -0
- kumoai/_singleton.py +25 -0
- kumoai/_version.py +1 -0
- kumoai/artifact_export/__init__.py +9 -0
- kumoai/artifact_export/config.py +209 -0
- kumoai/artifact_export/job.py +108 -0
- kumoai/client/__init__.py +5 -0
- kumoai/client/client.py +223 -0
- kumoai/client/connector.py +110 -0
- kumoai/client/endpoints.py +150 -0
- kumoai/client/graph.py +120 -0
- kumoai/client/jobs.py +471 -0
- kumoai/client/online.py +78 -0
- kumoai/client/pquery.py +207 -0
- kumoai/client/rfm.py +112 -0
- kumoai/client/source_table.py +53 -0
- kumoai/client/table.py +101 -0
- kumoai/client/utils.py +130 -0
- kumoai/codegen/__init__.py +19 -0
- kumoai/codegen/cli.py +100 -0
- kumoai/codegen/context.py +16 -0
- kumoai/codegen/edits.py +473 -0
- kumoai/codegen/exceptions.py +10 -0
- kumoai/codegen/generate.py +222 -0
- kumoai/codegen/handlers/__init__.py +4 -0
- kumoai/codegen/handlers/connector.py +118 -0
- kumoai/codegen/handlers/graph.py +71 -0
- kumoai/codegen/handlers/pquery.py +62 -0
- kumoai/codegen/handlers/table.py +109 -0
- kumoai/codegen/handlers/utils.py +42 -0
- kumoai/codegen/identity.py +114 -0
- kumoai/codegen/loader.py +93 -0
- kumoai/codegen/naming.py +94 -0
- kumoai/codegen/registry.py +121 -0
- kumoai/connector/__init__.py +31 -0
- kumoai/connector/base.py +153 -0
- kumoai/connector/bigquery_connector.py +200 -0
- kumoai/connector/databricks_connector.py +213 -0
- kumoai/connector/file_upload_connector.py +189 -0
- kumoai/connector/glue_connector.py +150 -0
- kumoai/connector/s3_connector.py +278 -0
- kumoai/connector/snowflake_connector.py +252 -0
- kumoai/connector/source_table.py +471 -0
- kumoai/connector/utils.py +1796 -0
- kumoai/databricks.py +14 -0
- kumoai/encoder/__init__.py +4 -0
- kumoai/exceptions.py +26 -0
- kumoai/experimental/__init__.py +0 -0
- kumoai/experimental/rfm/__init__.py +210 -0
- kumoai/experimental/rfm/authenticate.py +432 -0
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/backend/local/graph_store.py +297 -0
- kumoai/experimental/rfm/backend/local/sampler.py +312 -0
- kumoai/experimental/rfm/backend/local/table.py +113 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
- kumoai/experimental/rfm/backend/snow/table.py +242 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
- kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
- kumoai/experimental/rfm/base/__init__.py +30 -0
- kumoai/experimental/rfm/base/column.py +152 -0
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +761 -0
- kumoai/experimental/rfm/base/source.py +19 -0
- kumoai/experimental/rfm/base/sql_sampler.py +143 -0
- kumoai/experimental/rfm/base/table.py +736 -0
- kumoai/experimental/rfm/graph.py +1237 -0
- kumoai/experimental/rfm/infer/__init__.py +19 -0
- kumoai/experimental/rfm/infer/categorical.py +40 -0
- kumoai/experimental/rfm/infer/dtype.py +82 -0
- kumoai/experimental/rfm/infer/id.py +46 -0
- kumoai/experimental/rfm/infer/multicategorical.py +48 -0
- kumoai/experimental/rfm/infer/pkey.py +128 -0
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +61 -0
- kumoai/experimental/rfm/infer/timestamp.py +41 -0
- kumoai/experimental/rfm/pquery/__init__.py +7 -0
- kumoai/experimental/rfm/pquery/executor.py +102 -0
- kumoai/experimental/rfm/pquery/pandas_executor.py +530 -0
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +1184 -0
- kumoai/experimental/rfm/sagemaker.py +138 -0
- kumoai/experimental/rfm/task_table.py +231 -0
- kumoai/formatting.py +30 -0
- kumoai/futures.py +99 -0
- kumoai/graph/__init__.py +12 -0
- kumoai/graph/column.py +106 -0
- kumoai/graph/graph.py +948 -0
- kumoai/graph/table.py +838 -0
- kumoai/jobs.py +80 -0
- kumoai/kumolib.cpython-310-x86_64-linux-gnu.so +0 -0
- kumoai/mixin.py +28 -0
- kumoai/pquery/__init__.py +25 -0
- kumoai/pquery/prediction_table.py +287 -0
- kumoai/pquery/predictive_query.py +641 -0
- kumoai/pquery/training_table.py +424 -0
- kumoai/spcs.py +121 -0
- kumoai/testing/__init__.py +8 -0
- kumoai/testing/decorators.py +57 -0
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/__init__.py +42 -0
- kumoai/trainer/baseline_trainer.py +93 -0
- kumoai/trainer/config.py +2 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/trainer/job.py +1192 -0
- kumoai/trainer/online_serving.py +258 -0
- kumoai/trainer/trainer.py +475 -0
- kumoai/trainer/util.py +103 -0
- kumoai/utils/__init__.py +11 -0
- kumoai/utils/datasets.py +83 -0
- kumoai/utils/display.py +51 -0
- kumoai/utils/forecasting.py +209 -0
- kumoai/utils/progress_logger.py +343 -0
- kumoai/utils/sql.py +3 -0
- kumoai-2.14.0.dev202601011731.dist-info/METADATA +71 -0
- kumoai-2.14.0.dev202601011731.dist-info/RECORD +122 -0
- kumoai-2.14.0.dev202601011731.dist-info/WHEEL +6 -0
- kumoai-2.14.0.dev202601011731.dist-info/licenses/LICENSE +9 -0
- kumoai-2.14.0.dev202601011731.dist-info/top_level.txt +1 -0
kumoai/experimental/rfm/sagemaker.py
ADDED
@@ -0,0 +1,138 @@
+import base64
+import json
+from typing import Any
+
+import requests
+
+from kumoai.client import KumoClient
+from kumoai.client.endpoints import Endpoint, HTTPMethod
+from kumoai.exceptions import HTTPException
+
+try:
+    # isort: off
+    from mypy_boto3_sagemaker_runtime.client import SageMakerRuntimeClient
+    from mypy_boto3_sagemaker_runtime.type_defs import (
+        InvokeEndpointOutputTypeDef, )
+    # isort: on
+except ImportError:
+    SageMakerRuntimeClient = Any
+    InvokeEndpointOutputTypeDef = Any
+
+
+class SageMakerResponseAdapter(requests.Response):
+    def __init__(self, sm_response: InvokeEndpointOutputTypeDef):
+        super().__init__()
+        # Read the body bytes
+        self._content = sm_response['Body'].read()
+        self.status_code = 200
+        self.headers['Content-Type'] = sm_response.get('ContentType',
+                                                       'application/json')
+        # Optionally, you can store original sm_response for debugging
+        self.sm_response = sm_response
+
+    @property
+    def text(self) -> str:
+        assert isinstance(self._content, bytes)
+        return self._content.decode('utf-8')
+
+    def json(self, **kwargs) -> dict[str, Any]:  # type: ignore
+        return json.loads(self.text, **kwargs)
+
+
+class KumoClient_SageMakerAdapter(KumoClient):
+    def __init__(self, region: str, endpoint_name: str):
+        import boto3
+        self._client: SageMakerRuntimeClient = boto3.client(
+            service_name="sagemaker-runtime", region_name=region)
+        self._endpoint_name = endpoint_name
+
+        # Recording buffers.
+        self._recording_active = False
+        self._recorded_reqs: list[dict[str, Any]] = []
+        self._recorded_resps: list[dict[str, Any]] = []
+
+    def authenticate(self) -> None:
+        # TODO(siyang): call /ping to verify?
+        pass
+
+    def _request(self, endpoint: Endpoint, **kwargs: Any) -> requests.Response:
+        assert endpoint.method == HTTPMethod.POST
+        if 'json' in kwargs:
+            payload = json.dumps(kwargs.pop('json'))
+        elif 'data' in kwargs:
+            raw_payload = kwargs.pop('data')
+            assert isinstance(raw_payload, bytes)
+            payload = base64.b64encode(raw_payload).decode()
+        else:
+            raise HTTPException(400, 'Unable to send data to KumoRFM.')
+
+        request = {
+            'method': endpoint.get_path().rsplit('/')[-1],
+            'payload': payload,
+        }
+        response: InvokeEndpointOutputTypeDef = self._client.invoke_endpoint(
+            EndpointName=self._endpoint_name,
+            ContentType="application/json",
+            Body=json.dumps(request),
+        )
+
+        adapted_response = SageMakerResponseAdapter(response)
+
+        # If validation is active, store input/output
+        if self._recording_active:
+            self._recorded_reqs.append(request)
+            self._recorded_resps.append(adapted_response.json())
+
+        return adapted_response
+
+    def start_recording(self) -> None:
+        """Start recording requests/responses to/from sagemaker endpoint."""
+        assert not self._recording_active
+        self._recording_active = True
+        self._recorded_reqs.clear()
+        self._recorded_resps.clear()
+
+    def end_recording(self) -> list[tuple[dict[str, Any], dict[str, Any]]]:
+        """Stop recording and return recorded requests/responses."""
+        assert self._recording_active
+        self._recording_active = False
+        recorded = list(zip(self._recorded_reqs, self._recorded_resps))
+        self._recorded_reqs.clear()
+        self._recorded_resps.clear()
+        return recorded
+
+
+class KumoClient_SageMakerProxy_Local(KumoClient):
+    def __init__(self, url: str):
+        self._client = KumoClient(url, api_key=None)
+        self._client._api_url = self._client._url
+        self._endpoint = Endpoint('/invocations', HTTPMethod.POST)
+
+    def authenticate(self) -> None:
+        try:
+            self._client._session.get(
+                self._url + '/ping',
+                verify=self._verify_ssl).raise_for_status()
+        except Exception:
+            raise ValueError(
+                "Client authentication failed. Please check if you "
+                "have a valid API key/credentials.")
+
+    def _request(self, endpoint: Endpoint, **kwargs: Any) -> requests.Response:
+        assert endpoint.method == HTTPMethod.POST
+        if 'json' in kwargs:
+            payload = json.dumps(kwargs.pop('json'))
+        elif 'data' in kwargs:
+            raw_payload = kwargs.pop('data')
+            assert isinstance(raw_payload, bytes)
+            payload = base64.b64encode(raw_payload).decode()
+        else:
+            raise HTTPException(400, 'Unable to send data to KumoRFM.')
+        return self._client._request(
+            self._endpoint,
+            json={
+                'method': endpoint.get_path().rsplit('/')[-1],
+                'payload': payload,
+            },
+            **kwargs,
+        )
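For orientation, a minimal usage sketch of the recording hooks above (not part of the package; the region and endpoint name are placeholders):

from kumoai.experimental.rfm.sagemaker import KumoClient_SageMakerAdapter

# Hypothetical values; substitute a real AWS region and SageMaker endpoint.
client = KumoClient_SageMakerAdapter(region='us-east-1',
                                     endpoint_name='kumo-rfm-endpoint')
client.authenticate()           # currently a no-op (see TODO above)
client.start_recording()        # begin capturing request/response pairs
# ... issue calls that route through client._request(...) ...
pairs = client.end_recording()  # list of (request, response) dictionaries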
kumoai/experimental/rfm/task_table.py
ADDED
@@ -0,0 +1,231 @@
+from collections.abc import Sequence
+
+import pandas as pd
+from kumoapi.task import TaskType
+from kumoapi.typing import Dtype, Stype
+
+from kumoai.experimental.rfm.base import Column
+from kumoai.experimental.rfm.infer import contains_timestamp, infer_dtype
+
+
+class TaskTable:
+    r"""A :class:`TaskTable` fully specifies the task, *i.e.* its context and
+    prediction examples with entity IDs, targets and timestamps.
+
+    Args:
+        task_type: The task type.
+        context_df: The data frame holding context examples.
+        pred_df: The data frame holding prediction examples.
+        entity_table_name: The entity table to predict for. For link prediction
+            tasks, needs to hold both entity and target table names.
+        entity_column: The name of the entity column.
+        target_column: The name of the target column.
+        time_column: The name of the time column, if it exists.
+    """
+    def __init__(
+        self,
+        task_type: TaskType,
+        context_df: pd.DataFrame,
+        pred_df: pd.DataFrame,
+        entity_table_name: str | Sequence[str],
+        entity_column: str,
+        target_column: str,
+        time_column: str | None = None,
+    ) -> None:
+
+        task_type = TaskType(task_type)
+        if task_type not in {  # Currently supported task types:
+                TaskType.BINARY_CLASSIFICATION,
+                TaskType.MULTICLASS_CLASSIFICATION,
+                TaskType.REGRESSION,
+                TaskType.TEMPORAL_LINK_PREDICTION,
+        }:
+            raise ValueError  # TODO
+        self._task_type = task_type
+
+        # TODO Check dfs (unify from local table)
+        self._context_df = context_df.copy(deep=False)
+        self._pred_df = pred_df.copy(deep=False)
+
+        self._dtype_dict: dict[str, Dtype] = {
+            column_name: infer_dtype(context_df[column_name])
+            for column_name in context_df.columns
+        }
+
+        self._entity_table_names: tuple[str] | tuple[str, str]
+        if isinstance(entity_table_name, str):
+            self._entity_table_names = (entity_table_name, )
+        elif len(entity_table_name) == 1:
+            self._entity_table_names = (entity_table_name[0], )
+        elif len(entity_table_name) == 2:
+            self._entity_table_names = (
+                entity_table_name[0],
+                entity_table_name[1],
+            )
+        else:
+            raise ValueError  # TODO
+
+        self._entity_column: str = ''
+        self._target_column: str = ''
+        self._time_column: str | None = None
+
+        self.entity_column = entity_column
+        self.target_column = target_column
+        if time_column is not None:
+            self.time_column = time_column
+
+    @property
+    def task_type(self) -> TaskType:
+        return self._task_type
+
+    # Entity column ###########################################################
+
+    @property
+    def entity_table_name(self) -> str:
+        return self._entity_table_names[0]
+
+    @property
+    def entity_table_names(self) -> tuple[str] | tuple[str, str]:
+        return self._entity_table_names
+
+    @property
+    def entity_column(self) -> Column:
+        return Column(
+            name=self._entity_column,
+            expr=None,
+            dtype=self._dtype_dict[self._entity_column],
+            stype=Stype.ID,
+        )
+
+    @entity_column.setter
+    def entity_column(self, name: str) -> None:
+        if name in {self._target_column, self._time_column}:
+            raise ValueError  # TODO
+        if name not in self._context_df:
+            raise ValueError  # TODO
+        if name not in self._pred_df:
+            raise ValueError  # TODO
+        if not Stype.ID.supports_dtype(self._dtype_dict[name]):
+            raise ValueError  # TODO
+
+        self._entity_column = name
+
+    # Target column ###########################################################
+
+    @property
+    def _target_stype(self) -> Stype:
+        if self.task_type in {
+                TaskType.BINARY_CLASSIFICATION,
+                TaskType.MULTICLASS_CLASSIFICATION,
+        }:
+            return Stype.categorical
+        if self.task_type in {TaskType.REGRESSION}:
+            return Stype.numerical
+        if self.task_type.is_link_pred:
+            return Stype.multicategorical
+        raise ValueError
+
+    @property
+    def target_column(self) -> Column:
+        return Column(
+            name=self._target_column,
+            expr=None,
+            dtype=self._dtype_dict[self._target_column],
+            stype=self._target_stype,
+        )
+
+    @target_column.setter
+    def target_column(self, name: str) -> None:
+        if name in {self._entity_column, self._time_column}:
+            raise ValueError  # TODO
+        if name not in self._context_df:
+            raise ValueError  # TODO
+        if not self._target_stype.supports_dtype(self._dtype_dict[name]):
+            raise ValueError  # TODO
+
+        self._target_column = name
+
+    # Time column #############################################################
+
+    def has_time_column(self) -> bool:
+        r"""Returns ``True`` if this task has a time column; ``False``
+        otherwise.
+        """
+        return self._time_column is not None
+
+    @property
+    def time_column(self) -> Column | None:
+        r"""The time column of this task.
+
+        The getter returns the time column of this task, or ``None`` if no
+        such time column is present.
+
+        The setter sets a column as a time column for this task, and raises a
+        :class:`ValueError` if the time column has a non-timestamp compatible
+        data type or if the column name does not match a column in the data
+        frame.
+        """
+        if self._time_column is None:
+            return None
+        return Column(
+            name=self._time_column,
+            expr=None,
+            dtype=self._dtype_dict[self._time_column],
+            stype=Stype.timestamp,
+        )
+
+    @time_column.setter
+    def time_column(self, name: str | None) -> None:
+        if name is None:
+            self._time_column = None
+            return
+
+        if name in {self._entity_column, self._target_column}:
+            raise ValueError  # TODO
+        if name not in self._context_df:
+            raise ValueError  # TODO
+        if name not in self._pred_df:
+            raise ValueError  # TODO
+        if not contains_timestamp(
+                ser=self._context_df[name],
+                column_name=name,
+                dtype=self._dtype_dict[name],
+        ):
+            raise ValueError  # TODO
+
+        self._time_column = name
+
+    # Metadata ################################################################
+
+    @property
+    def metadata(self) -> pd.DataFrame:
+        raise NotImplementedError
+
+    def print_metadata(self) -> None:
+        raise NotImplementedError
+
+    # Python builtins #########################################################
+
+    def __hash__(self) -> int:
+        return hash((
+            self.task_type,
+            self.entity_table_names,
+            self._entity_column,
+            self._target_column,
+            self._time_column,
+        ))
+
+    def __repr__(self) -> str:
+        if self.task_type.is_link_pred:
+            entity_table_repr = f'entity_table_names={self.entity_table_names}'
+        else:
+            entity_table_repr = f'entity_table_name={self.entity_table_name}'
+        return (f'{self.__class__.__name__}(\n'
+                f'  task_type={self.task_type},\n'
+                f'  num_context_examples={len(self._context_df)},\n'
+                f'  num_prediction_examples={len(self._pred_df)},\n'
+                f'  {entity_table_repr},\n'
+                f'  entity_column={self._entity_column},\n'
+                f'  target_column={self._target_column},\n'
+                f'  time_column={self._time_column},\n'
+                f')')
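As an illustration (not taken from the package), a binary-classification task table might be assembled roughly as follows, assuming integer targets are accepted as categorical:

import pandas as pd
from kumoapi.task import TaskType
from kumoai.experimental.rfm.task_table import TaskTable

# Hypothetical context (labeled) and prediction (unlabeled) examples.
context_df = pd.DataFrame({
    'user_id': [1, 2, 3],
    'churned': [0, 1, 0],
    'ts': pd.to_datetime(['2024-01-01', '2024-01-02', '2024-01-03']),
})
pred_df = pd.DataFrame({
    'user_id': [4, 5],
    'ts': pd.to_datetime(['2024-02-01', '2024-02-02']),
})

task = TaskTable(
    task_type=TaskType.BINARY_CLASSIFICATION,
    context_df=context_df,
    pred_df=pred_df,
    entity_table_name='users',
    entity_column='user_id',
    target_column='churned',
    time_column='ts',
)
print(task)  # __repr__ summarizes example counts and column roles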
kumoai/formatting.py
ADDED
@@ -0,0 +1,30 @@
+from kumoapi.jobs import ErrorDetails
+
+
+def pretty_print_error_details(error_details: ErrorDetails) -> str:
+    """Pretty prints the ErrorDetails combining all the individual items.
+    If there are CTAs, they are also displayed after creating
+    corresponding hyperlinks.
+
+    Arguments:
+        error_details (ErrorDetails): Standard ErrorDetails response from
+            get_errors APIs.
+    """
+    out = ""
+    ctr = None
+    if len(error_details.items) != 1:
+        out += "Encountered multiple errors:\n"
+        ctr = 1
+    for error_detail in error_details.items:
+        if ctr is not None:
+            out += f'{ctr}.'
+            ctr += 1
+        if error_detail.title is not None:
+            out += f'{error_detail.title}: '
+        out += error_detail.description
+        if error_detail.cta is not None:
+            out += 'Follow the link for potential resolution:'
+            f' {error_detail.cta.url}'
+        out += '\n'
+
+    return out
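For orientation, with a hypothetical ErrorDetails holding two items (titles and descriptions invented here), the helper above produces output shaped like:

Encountered multiple errors:
1.Table validation failed: Description of the first error.
2.Ingestion failed: Description of the second error.

Note that the counter is emitted without a trailing space, and the CTA URL line inside the loop is a bare f-string expression that is not appended to the returned text.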
kumoai/futures.py
ADDED
@@ -0,0 +1,99 @@
+import asyncio
+import concurrent
+import logging
+import threading
+from abc import ABC, abstractmethod
+from asyncio.events import AbstractEventLoop
+from typing import Any, Awaitable, Coroutine, Generic, TypeVar
+
+logger = logging.getLogger(__name__)
+
+CoroFuncType = Awaitable[Any]
+
+# Kumo global event loop (our implementation of green threads for pollers and
+# other interactions with the Kumo backend that require long-running tasks).
+# Since the caller may have their own event loop that we do not want to
+# mess with, _do not_ ever call `set_event_loop` here!! Instead, be extra
+# cautious to pass this loop everywhere.
+_KUMO_EVENT_LOOP: AbstractEventLoop = asyncio.new_event_loop()
+
+
+def initialize_event_loop() -> None:
+    def _run_background_loop(loop: AbstractEventLoop) -> None:
+        asyncio.set_event_loop(loop)
+        loop.run_forever()
+
+    t = threading.Thread(target=_run_background_loop,
+                         args=(_KUMO_EVENT_LOOP, ), daemon=True)
+    t.start()
+
+
+def create_future(coro: Coroutine[Any, Any, Any]) -> concurrent.futures.Future:
+    r"""Creates a future to execute in the Kumo event loop."""
+    # NOTE this function creates a future, chains it to the output of the
+    # execution of `coro` in the Kumo event loop, and handles exceptions
+    # before scheduling to run in the loop:
+    return asyncio.run_coroutine_threadsafe(coro, _KUMO_EVENT_LOOP)
+
+
+T = TypeVar("T")
+
+
+class KumoFuture(ABC, Generic[T]):
+    r"""Abstract base class for a Kumo future object."""
+
+    # We cannot use Python future implementations (`asyncio.Future` or
+    # `concurrent.futures.Future`) as they are native to the Python
+    # implementation of asyncio and threading, and thus not easily extensible.
+    # Python additionally recommends not exposing low-level Future objects in
+    # user facing APIs.
+    @abstractmethod
+    def result(self) -> T:
+        r"""Returns the resolved state of the future.
+
+        Raises:
+            Exception:
+                If the future is complete but in a failed state due to an
+                exception being raised, this method will raise the same
+                exception.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def future(self) -> 'concurrent.futures.Future[T]':
+        r"""Returns the :obj:`concurrent.futures.Future` object wrapped by
+        this future. It is not recommended to access this object directly.
+        """
+        raise NotImplementedError
+
+    def done(self) -> bool:
+        r"""Returns :obj:`True` if this future has been resolved with
+        ``result()``, or :obj:`False` if this future is still
+        in-progress. Note that this method will return :obj:`True` if the
+        future is complete, but in a failed state, and that this method will
+        return :obj:`False` if the job is complete, but the future has not
+        been awaited.
+        """
+        return self.future().done()
+
+
+class KumoProgressFuture(KumoFuture[T]):
+    @abstractmethod
+    def _attach_internal(self, interval_s: float = 4.0) -> T:
+        raise NotImplementedError
+
+    def attach(self, interval_s: float = 4.0) -> T:
+        r"""Allows a user to attach to a running job and view its progress.
+
+        Args:
+            interval_s (float): Time interval (seconds) between polls, minimum
+                value allowed is 4 seconds.
+        """
+        try:
+            return self._attach_internal(interval_s=interval_s)
+        except Exception:
+            logger.warning(
+                "Detailed job tracking has become temporarily unavailable. "
+                "The job is continuing to proceed on the Kumo server, "
+                "and this call will complete when the job has finished.")
+            return self.result()
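Illustrative sketch (not part of the module) of how the background loop and create_future compose; the coroutine below is a stand-in for the SDK's long-running pollers:

import asyncio

from kumoai.futures import create_future, initialize_event_loop

initialize_event_loop()  # start the daemon thread driving _KUMO_EVENT_LOOP

async def poll_once() -> str:
    # Placeholder coroutine standing in for a backend poller.
    await asyncio.sleep(1.0)
    return 'DONE'

fut = create_future(poll_once())  # concurrent.futures.Future on the Kumo loop
print(fut.result(timeout=10))     # blocks the caller; the Kumo loop keeps running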
kumoai/graph/__init__.py
ADDED
kumoai/graph/column.py
ADDED
@@ -0,0 +1,106 @@
+from dataclasses import dataclass
+from typing import Any, Optional, Union
+
+from kumoapi.table import TimestampUnit
+from kumoapi.typing import Dtype, Stype
+
+from kumoai.mixin import CastMixin
+
+
+@dataclass(init=False)
+class Column(CastMixin):
+    r"""A column represents metadata information for a column in a Kumo
+    :class:`~kumoai.graph.Table`. Columns can be created independent of
+    a table, or can be fetched from a table with the
+    :meth:`~kumoai.graph.Table.column` method.
+
+    .. code-block:: python
+
+        import kumoai
+
+        # Fetch a column from a `kumoai.Table`:
+        table = kumoai.Table(...)
+
+        column = table.column("col_name")
+        column = table["col_name"]  # equivalent to the above.
+
+        # Edit a column's data type:
+        print("Existing dtype: ", column.dtype)
+        column.dtype = "int"
+
+        # Edit a column's semantic type:
+        print("Existing stype: ", column.stype)
+        column.stype = "ID"
+
+    Args:
+        name: The name of this column.
+        stype: The semantic type of this column. Semantic types can be
+            specified as strings: the list of possible semantic types
+            is located at :class:`~kumoai.Stype`.
+        dtype: The data type of this column. Data types can be specified
+            as strings: the list of possible data types is located at
+            :class:`~kumoai.Dtype`.
+        timestamp_format: If this column represents a timestamp, the format
+            that the timestamp should be parsed in. The format can either be
+            a :class:`~kumoapi.table.TimestampUnit` for integer columns or a
+            string with a format identifier described
+            `here <https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html>`__
+            for a SaaS Kumo deployment and
+            `here <https://docs.snowflake.com/en/sql-reference/date-time-input-output#about-the-elements-used-in-input-and-output-formats>`__
+            for a Snowpark Container Services Kumo deployment. If left empty,
+            will be intelligently inferred by Kumo.
+    """  # noqa: E501
+    name: str
+    stype: Optional[Stype] = None
+    dtype: Optional[Dtype] = None
+    timestamp_format: Optional[Union[str, TimestampUnit]] = None
+
+    def __init__(
+        self,
+        name: str,
+        stype: Optional[Union[Stype, str]] = None,
+        dtype: Optional[Union[Dtype, str]] = None,
+        timestamp_format: Optional[Union[str, TimestampUnit]] = None,
+    ) -> None:
+        self.name = name
+        self.stype = Stype(stype) if stype is not None else None
+        self.dtype = Dtype(dtype) if dtype is not None else None
+        try:
+            self.timestamp_format = TimestampUnit(timestamp_format)
+        except ValueError:
+            self.timestamp_format = timestamp_format
+
+    def __hash__(self) -> int:
+        return hash((self.name, self.stype, self.dtype, self.timestamp_format))
+
+    def __setattr__(self, key: Any, value: Any) -> None:
+        if key == 'name' and value != getattr(self, key, value):
+            raise AttributeError("Attribute 'name' is read-only")
+        elif key == 'stype' and isinstance(value, str):
+            value = Stype(value)
+        elif key == 'dtype' and isinstance(value, str):
+            value = Dtype(value)
+        elif key == 'timestamp_format' and isinstance(value, str):
+            try:
+                value = TimestampUnit(value)
+            except ValueError:
+                pass
+        super().__setattr__(key, value)
+
+    def update(self, obj: 'Column', override: bool = True) -> 'Column':
+        for key in self.__dict__:
+            if key[0] == '_':  # Skip private attributes:
+                continue
+            value = getattr(obj, key, None)
+            if value is not None:
+                if override or getattr(self, key, None) is None:
+                    setattr(self, key, value)
+        return self
+
+    def __repr__(self) -> str:
+        out = (f"Column(name=\"{self.name}\", stype=\"{self.stype}\", "
+               f"dtype=\"{self.dtype}\"")
+        if self.timestamp_format is not None:
+            out += f", timestamp_format=\"{self.timestamp_format}\""
+        out += ")"
+        return out
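To round out the docstring example above, a small sketch (not from the package docs) of Column.update's merge semantics, assuming the string values shown are accepted by Dtype/Stype:

from kumoai.graph.column import Column

base = Column(name='age', dtype='int')         # no stype yet
other = Column(name='age', stype='numerical')  # no dtype

# With override=False, only fields that are still None on `base` are filled in:
base.update(other, override=False)
assert base.dtype is not None and base.stype is not None
print(base)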