kumoai 2.14.0.dev202601011731__cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kumoai might be problematic.
- kumoai/__init__.py +300 -0
- kumoai/_logging.py +29 -0
- kumoai/_singleton.py +25 -0
- kumoai/_version.py +1 -0
- kumoai/artifact_export/__init__.py +9 -0
- kumoai/artifact_export/config.py +209 -0
- kumoai/artifact_export/job.py +108 -0
- kumoai/client/__init__.py +5 -0
- kumoai/client/client.py +223 -0
- kumoai/client/connector.py +110 -0
- kumoai/client/endpoints.py +150 -0
- kumoai/client/graph.py +120 -0
- kumoai/client/jobs.py +471 -0
- kumoai/client/online.py +78 -0
- kumoai/client/pquery.py +207 -0
- kumoai/client/rfm.py +112 -0
- kumoai/client/source_table.py +53 -0
- kumoai/client/table.py +101 -0
- kumoai/client/utils.py +130 -0
- kumoai/codegen/__init__.py +19 -0
- kumoai/codegen/cli.py +100 -0
- kumoai/codegen/context.py +16 -0
- kumoai/codegen/edits.py +473 -0
- kumoai/codegen/exceptions.py +10 -0
- kumoai/codegen/generate.py +222 -0
- kumoai/codegen/handlers/__init__.py +4 -0
- kumoai/codegen/handlers/connector.py +118 -0
- kumoai/codegen/handlers/graph.py +71 -0
- kumoai/codegen/handlers/pquery.py +62 -0
- kumoai/codegen/handlers/table.py +109 -0
- kumoai/codegen/handlers/utils.py +42 -0
- kumoai/codegen/identity.py +114 -0
- kumoai/codegen/loader.py +93 -0
- kumoai/codegen/naming.py +94 -0
- kumoai/codegen/registry.py +121 -0
- kumoai/connector/__init__.py +31 -0
- kumoai/connector/base.py +153 -0
- kumoai/connector/bigquery_connector.py +200 -0
- kumoai/connector/databricks_connector.py +213 -0
- kumoai/connector/file_upload_connector.py +189 -0
- kumoai/connector/glue_connector.py +150 -0
- kumoai/connector/s3_connector.py +278 -0
- kumoai/connector/snowflake_connector.py +252 -0
- kumoai/connector/source_table.py +471 -0
- kumoai/connector/utils.py +1796 -0
- kumoai/databricks.py +14 -0
- kumoai/encoder/__init__.py +4 -0
- kumoai/exceptions.py +26 -0
- kumoai/experimental/__init__.py +0 -0
- kumoai/experimental/rfm/__init__.py +210 -0
- kumoai/experimental/rfm/authenticate.py +432 -0
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/backend/local/graph_store.py +297 -0
- kumoai/experimental/rfm/backend/local/sampler.py +312 -0
- kumoai/experimental/rfm/backend/local/table.py +113 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
- kumoai/experimental/rfm/backend/snow/table.py +242 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
- kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
- kumoai/experimental/rfm/base/__init__.py +30 -0
- kumoai/experimental/rfm/base/column.py +152 -0
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +761 -0
- kumoai/experimental/rfm/base/source.py +19 -0
- kumoai/experimental/rfm/base/sql_sampler.py +143 -0
- kumoai/experimental/rfm/base/table.py +736 -0
- kumoai/experimental/rfm/graph.py +1237 -0
- kumoai/experimental/rfm/infer/__init__.py +19 -0
- kumoai/experimental/rfm/infer/categorical.py +40 -0
- kumoai/experimental/rfm/infer/dtype.py +82 -0
- kumoai/experimental/rfm/infer/id.py +46 -0
- kumoai/experimental/rfm/infer/multicategorical.py +48 -0
- kumoai/experimental/rfm/infer/pkey.py +128 -0
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +61 -0
- kumoai/experimental/rfm/infer/timestamp.py +41 -0
- kumoai/experimental/rfm/pquery/__init__.py +7 -0
- kumoai/experimental/rfm/pquery/executor.py +102 -0
- kumoai/experimental/rfm/pquery/pandas_executor.py +530 -0
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +1184 -0
- kumoai/experimental/rfm/sagemaker.py +138 -0
- kumoai/experimental/rfm/task_table.py +231 -0
- kumoai/formatting.py +30 -0
- kumoai/futures.py +99 -0
- kumoai/graph/__init__.py +12 -0
- kumoai/graph/column.py +106 -0
- kumoai/graph/graph.py +948 -0
- kumoai/graph/table.py +838 -0
- kumoai/jobs.py +80 -0
- kumoai/kumolib.cpython-310-x86_64-linux-gnu.so +0 -0
- kumoai/mixin.py +28 -0
- kumoai/pquery/__init__.py +25 -0
- kumoai/pquery/prediction_table.py +287 -0
- kumoai/pquery/predictive_query.py +641 -0
- kumoai/pquery/training_table.py +424 -0
- kumoai/spcs.py +121 -0
- kumoai/testing/__init__.py +8 -0
- kumoai/testing/decorators.py +57 -0
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/__init__.py +42 -0
- kumoai/trainer/baseline_trainer.py +93 -0
- kumoai/trainer/config.py +2 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/trainer/job.py +1192 -0
- kumoai/trainer/online_serving.py +258 -0
- kumoai/trainer/trainer.py +475 -0
- kumoai/trainer/util.py +103 -0
- kumoai/utils/__init__.py +11 -0
- kumoai/utils/datasets.py +83 -0
- kumoai/utils/display.py +51 -0
- kumoai/utils/forecasting.py +209 -0
- kumoai/utils/progress_logger.py +343 -0
- kumoai/utils/sql.py +3 -0
- kumoai-2.14.0.dev202601011731.dist-info/METADATA +71 -0
- kumoai-2.14.0.dev202601011731.dist-info/RECORD +122 -0
- kumoai-2.14.0.dev202601011731.dist-info/WHEEL +6 -0
- kumoai-2.14.0.dev202601011731.dist-info/licenses/LICENSE +9 -0
- kumoai-2.14.0.dev202601011731.dist-info/top_level.txt +1 -0
kumoai/trainer/online_serving.py
@@ -0,0 +1,258 @@
import asyncio
import concurrent
import concurrent.futures
import logging
from datetime import datetime, timezone
from typing import Optional, Union

from kumoapi.json_serde import to_json_dict
from kumoapi.online_serving import (
    NodeId,
    OnlinePredictionRequest,
    OnlinePredictionResponse,
    OnlinePredictionResult,
    OnlineServingEndpointRequest,
    OnlineServingEndpointResource,
    OnlineServingStatus,
    OnlineServingStatusCode,
    OnlineServingUpdate,
    RealtimeFeatures,
    TimestampNanos,
)
from typing_extensions import override

from kumoai import global_state
from kumoai.client.jobs import TrainingJobID
from kumoai.client.online import (
    OnlineServingEndpointAPI,
    OnlineServingEndpointID,
)
from kumoai.client.utils import parse_response
from kumoai.futures import KumoFuture, create_future
from kumoai.graph.graph import Graph

logger = logging.getLogger(__name__)


class OnlineServingEndpoint:
    """Represents a Kumo online serving endpoint that serves online `predict`
    requests.
    """
    def __init__(self, endpoint_url: str):
        self._endpoint_url = endpoint_url
        # Use the same global session with API key in header.
        self._session = global_state.client._session
        self._endpoint_id = self._endpoint_url.split('/')[-1]
        self._predict_url = f'{endpoint_url}/predict'
        logger.info('Initialized OnlineServingEndpoint at: %s', endpoint_url)

    def predict(
        self,
        fkey: NodeId,
        *,
        time: Union[datetime, TimestampNanos, None] = None,
        realtime_features: Optional[RealtimeFeatures] = None,
    ) -> OnlinePredictionResult:
        """Performs online inference for a single entity key using the
        currently deployed model and feature graph.

        This method sends a low-latency prediction request to the live
        endpoint. It supports injecting optional real-time features and
        controlling the anchor time used for temporal feature lookups.

        Parameters:
            fkey (NodeId):
                The entity key (e.g., user ID, item ID) to run inference on.

            time (datetime | TimestampNanos | None, optional):
                The effective timestamp for feature lookup and model
                prediction. If not provided, the current server time will be
                used.

            realtime_features (Optional[RealtimeFeatures], optional):
                Additional real-time features to inject into the feature
                graph for this prediction request. These can complement
                batch-generated features and are useful for contextual
                signals such as current session state or other real-time
                data.

        Returns:
            The prediction result from the deployed model. The return type
            is a union type depending on the model task type.
        """
        timestamp_nanos = time
        if isinstance(time, datetime):
            timestamp_nanos = int(time.timestamp() * 10**9)
        resp = self._session.post(
            self._predict_url, json=to_json_dict(
                OnlinePredictionRequest(fkey, timestamp_nanos,
                                        realtime_features)))
        resp.raise_for_status()
        return parse_response(OnlinePredictionResponse, resp).result

    def ping(self) -> str:
        resp = self._session.get(f'{self._endpoint_url}/probe_liveness')
        resp.raise_for_status()
        return resp.text

    def update(
        self,
        *,
        refresh_graph_data: bool = True,
        graph_override: Optional[Graph] = None,
        new_model_id: Optional[TrainingJobID] = None,
    ) -> 'OnlineServingEndpointUpdateFuture':
        """Triggers an asynchronous update to the online serving endpoint
        using a blue-green deployment strategy.

        This method allows clients to deploy a new version of the endpoint
        with updated model weights, refreshed feature data, or a complete
        graph override. The update is applied in the background without
        interrupting availability, and traffic is swapped to the new
        deployment once it is fully ready.

        Parameters:
            refresh_graph_data (bool, optional):
                Whether to reload feature data from the latest available
                source. Defaults to True.

            graph_override (Optional[Graph], optional):
                If provided, overrides the existing feature graph with the
                given one. This is useful for testing or dynamic
                reconfiguration of the feature pipeline.

            new_model_id (Optional[TrainingJobID], optional):
                If specified, deploys a new model with the given training job
                ID. This model will replace the currently serving model after
                the update is complete.

        Returns:
            OnlineServingEndpointUpdateFuture:
                A future-like object that can be used to check the progress
                and result of the update operation.

        Example: (asynchronously send an email notification when the update
        is done)
            >>> fut = endpoint.update(new_model_id="model_202504")
            >>> fut.future().add_done_callback(send_email_notification)
        """
        res = _endpoint_api().get_if_exists(self._endpoint_id)
        assert res

        if not refresh_graph_data and not new_model_id:
            raise ValueError(
                'Expect to update online endpoint by loading a new model '
                'and/or refreshed graph data.')

        model_id = new_model_id or res.config.model_training_job_id
        if refresh_graph_data:
            graph_snapshot_id = None
        if graph_override:
            graph_snapshot_id = graph_override.snapshot(
                force_refresh=refresh_graph_data, non_blocking=True)
        else:
            graph_snapshot_id = (None if refresh_graph_data else
                                 res.config.graph_snapshot_id)

        updated = _endpoint_api().update(
            self._endpoint_id,
            OnlineServingEndpointRequest(model_id, res.config.predict_options,
                                         graph_snapshot_id))

        return OnlineServingEndpointUpdateFuture(self._endpoint_id,
                                                 noop=not updated)

    def destroy(self) -> None:
        _endpoint_api().delete(self._endpoint_id)


class OnlineServingEndpointFuture(KumoFuture[OnlineServingEndpoint]):
    def __init__(self, id: OnlineServingEndpointID) -> None:
        self._id = id
        self._fut: concurrent.futures.Future[
            OnlineServingEndpoint] = create_future(_poll_endpoint_ready(id))

    @property
    def id(self) -> OnlineServingEndpointID:
        r"""The unique ID of this online serving endpoint."""
        return self._id

    @override
    def result(self) -> OnlineServingEndpoint:
        return self._fut.result()

    @override
    def future(self) -> 'concurrent.futures.Future[OnlineServingEndpoint]':
        return self._fut


class OnlineServingEndpointUpdateFuture(KumoFuture[OnlineServingUpdate]):
    def __init__(self, id: OnlineServingEndpointID, noop: bool):
        if noop:
            res = _endpoint_api().get_if_exists(id)
            assert res
            fut = concurrent.futures.Future[OnlineServingUpdate]()
            fut.set_result(
                OnlineServingUpdate(
                    prev_config=res.config, target_config=res.config,
                    update_started_at=datetime.now(timezone.utc),
                    update_status=OnlineServingStatus(
                        OnlineServingStatusCode.READY,
                        datetime.now(timezone.utc))))
        else:
            fut = create_future(_poll_update_ready(id))
        self._fut = fut

    @override
    def result(self) -> OnlineServingUpdate:
        return self._fut.result()

    @override
    def future(self) -> 'concurrent.futures.Future[OnlineServingUpdate]':
        return self._fut


async def _get_endpoint_resource(
        id: OnlineServingEndpointID) -> OnlineServingEndpointResource:
    api = global_state.client.online_serving_endpoint_api
    # TODO(manan): make asynchronous natively with aiohttp:
    res = await asyncio.get_running_loop().run_in_executor(
        None, api.get_if_exists, id)
    assert res
    return res


async def _poll_endpoint_ready(
        id: OnlineServingEndpointID) -> OnlineServingEndpoint:
    while True:
        res = await _get_endpoint_resource(id)
        status = res.status.status_code
        if status == OnlineServingStatusCode.IN_PROGRESS:
            await asyncio.sleep(10)
        else:
            break

    if status == OnlineServingStatusCode.FAILED:
        raise ValueError(f"Failed to launch online endpoint id={id}, "
                         f"failure message: {res.status.failure_message}")

    assert status == OnlineServingStatusCode.READY
    endpoint = OnlineServingEndpoint(res.endpoint_url)
    logger.info('OnlineServingEndpoint is ready, ping: %s', endpoint.ping())
    return endpoint


async def _poll_update_ready(
        id: OnlineServingEndpointID) -> OnlineServingUpdate:
    while True:
        res = await _get_endpoint_resource(id)
        if res.update.update_status == OnlineServingStatusCode.IN_PROGRESS:
            await asyncio.sleep(10)
        else:
            break

    return res.update


def _endpoint_api() -> OnlineServingEndpointAPI:
    return global_state.client.online_serving_endpoint_api