kumoai 2.14.0.dev202601011731__cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kumoai might be problematic. Click here for more details.

Files changed (122)
  1. kumoai/__init__.py +300 -0
  2. kumoai/_logging.py +29 -0
  3. kumoai/_singleton.py +25 -0
  4. kumoai/_version.py +1 -0
  5. kumoai/artifact_export/__init__.py +9 -0
  6. kumoai/artifact_export/config.py +209 -0
  7. kumoai/artifact_export/job.py +108 -0
  8. kumoai/client/__init__.py +5 -0
  9. kumoai/client/client.py +223 -0
  10. kumoai/client/connector.py +110 -0
  11. kumoai/client/endpoints.py +150 -0
  12. kumoai/client/graph.py +120 -0
  13. kumoai/client/jobs.py +471 -0
  14. kumoai/client/online.py +78 -0
  15. kumoai/client/pquery.py +207 -0
  16. kumoai/client/rfm.py +112 -0
  17. kumoai/client/source_table.py +53 -0
  18. kumoai/client/table.py +101 -0
  19. kumoai/client/utils.py +130 -0
  20. kumoai/codegen/__init__.py +19 -0
  21. kumoai/codegen/cli.py +100 -0
  22. kumoai/codegen/context.py +16 -0
  23. kumoai/codegen/edits.py +473 -0
  24. kumoai/codegen/exceptions.py +10 -0
  25. kumoai/codegen/generate.py +222 -0
  26. kumoai/codegen/handlers/__init__.py +4 -0
  27. kumoai/codegen/handlers/connector.py +118 -0
  28. kumoai/codegen/handlers/graph.py +71 -0
  29. kumoai/codegen/handlers/pquery.py +62 -0
  30. kumoai/codegen/handlers/table.py +109 -0
  31. kumoai/codegen/handlers/utils.py +42 -0
  32. kumoai/codegen/identity.py +114 -0
  33. kumoai/codegen/loader.py +93 -0
  34. kumoai/codegen/naming.py +94 -0
  35. kumoai/codegen/registry.py +121 -0
  36. kumoai/connector/__init__.py +31 -0
  37. kumoai/connector/base.py +153 -0
  38. kumoai/connector/bigquery_connector.py +200 -0
  39. kumoai/connector/databricks_connector.py +213 -0
  40. kumoai/connector/file_upload_connector.py +189 -0
  41. kumoai/connector/glue_connector.py +150 -0
  42. kumoai/connector/s3_connector.py +278 -0
  43. kumoai/connector/snowflake_connector.py +252 -0
  44. kumoai/connector/source_table.py +471 -0
  45. kumoai/connector/utils.py +1796 -0
  46. kumoai/databricks.py +14 -0
  47. kumoai/encoder/__init__.py +4 -0
  48. kumoai/exceptions.py +26 -0
  49. kumoai/experimental/__init__.py +0 -0
  50. kumoai/experimental/rfm/__init__.py +210 -0
  51. kumoai/experimental/rfm/authenticate.py +432 -0
  52. kumoai/experimental/rfm/backend/__init__.py +0 -0
  53. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  54. kumoai/experimental/rfm/backend/local/graph_store.py +297 -0
  55. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  56. kumoai/experimental/rfm/backend/local/table.py +113 -0
  57. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  58. kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
  59. kumoai/experimental/rfm/backend/snow/table.py +242 -0
  60. kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
  61. kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
  62. kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
  63. kumoai/experimental/rfm/base/__init__.py +30 -0
  64. kumoai/experimental/rfm/base/column.py +152 -0
  65. kumoai/experimental/rfm/base/expression.py +44 -0
  66. kumoai/experimental/rfm/base/sampler.py +761 -0
  67. kumoai/experimental/rfm/base/source.py +19 -0
  68. kumoai/experimental/rfm/base/sql_sampler.py +143 -0
  69. kumoai/experimental/rfm/base/table.py +736 -0
  70. kumoai/experimental/rfm/graph.py +1237 -0
  71. kumoai/experimental/rfm/infer/__init__.py +19 -0
  72. kumoai/experimental/rfm/infer/categorical.py +40 -0
  73. kumoai/experimental/rfm/infer/dtype.py +82 -0
  74. kumoai/experimental/rfm/infer/id.py +46 -0
  75. kumoai/experimental/rfm/infer/multicategorical.py +48 -0
  76. kumoai/experimental/rfm/infer/pkey.py +128 -0
  77. kumoai/experimental/rfm/infer/stype.py +35 -0
  78. kumoai/experimental/rfm/infer/time_col.py +61 -0
  79. kumoai/experimental/rfm/infer/timestamp.py +41 -0
  80. kumoai/experimental/rfm/pquery/__init__.py +7 -0
  81. kumoai/experimental/rfm/pquery/executor.py +102 -0
  82. kumoai/experimental/rfm/pquery/pandas_executor.py +530 -0
  83. kumoai/experimental/rfm/relbench.py +76 -0
  84. kumoai/experimental/rfm/rfm.py +1184 -0
  85. kumoai/experimental/rfm/sagemaker.py +138 -0
  86. kumoai/experimental/rfm/task_table.py +231 -0
  87. kumoai/formatting.py +30 -0
  88. kumoai/futures.py +99 -0
  89. kumoai/graph/__init__.py +12 -0
  90. kumoai/graph/column.py +106 -0
  91. kumoai/graph/graph.py +948 -0
  92. kumoai/graph/table.py +838 -0
  93. kumoai/jobs.py +80 -0
  94. kumoai/kumolib.cpython-310-x86_64-linux-gnu.so +0 -0
  95. kumoai/mixin.py +28 -0
  96. kumoai/pquery/__init__.py +25 -0
  97. kumoai/pquery/prediction_table.py +287 -0
  98. kumoai/pquery/predictive_query.py +641 -0
  99. kumoai/pquery/training_table.py +424 -0
  100. kumoai/spcs.py +121 -0
  101. kumoai/testing/__init__.py +8 -0
  102. kumoai/testing/decorators.py +57 -0
  103. kumoai/testing/snow.py +50 -0
  104. kumoai/trainer/__init__.py +42 -0
  105. kumoai/trainer/baseline_trainer.py +93 -0
  106. kumoai/trainer/config.py +2 -0
  107. kumoai/trainer/distilled_trainer.py +175 -0
  108. kumoai/trainer/job.py +1192 -0
  109. kumoai/trainer/online_serving.py +258 -0
  110. kumoai/trainer/trainer.py +475 -0
  111. kumoai/trainer/util.py +103 -0
  112. kumoai/utils/__init__.py +11 -0
  113. kumoai/utils/datasets.py +83 -0
  114. kumoai/utils/display.py +51 -0
  115. kumoai/utils/forecasting.py +209 -0
  116. kumoai/utils/progress_logger.py +343 -0
  117. kumoai/utils/sql.py +3 -0
  118. kumoai-2.14.0.dev202601011731.dist-info/METADATA +71 -0
  119. kumoai-2.14.0.dev202601011731.dist-info/RECORD +122 -0
  120. kumoai-2.14.0.dev202601011731.dist-info/WHEEL +6 -0
  121. kumoai-2.14.0.dev202601011731.dist-info/licenses/LICENSE +9 -0
  122. kumoai-2.14.0.dev202601011731.dist-info/top_level.txt +1 -0
@@ -0,0 +1,258 @@
1
+ import asyncio
2
+ import concurrent
3
+ import concurrent.futures
4
+ import logging
5
+ from datetime import datetime, timezone
6
+ from typing import Optional, Union
7
+
8
+ from kumoapi.json_serde import to_json_dict
9
+ from kumoapi.online_serving import (
10
+ NodeId,
11
+ OnlinePredictionRequest,
12
+ OnlinePredictionResponse,
13
+ OnlinePredictionResult,
14
+ OnlineServingEndpointRequest,
15
+ OnlineServingEndpointResource,
16
+ OnlineServingStatus,
17
+ OnlineServingStatusCode,
18
+ OnlineServingUpdate,
19
+ RealtimeFeatures,
20
+ TimestampNanos,
21
+ )
22
+ from typing_extensions import override
23
+
24
+ from kumoai import global_state
25
+ from kumoai.client.jobs import TrainingJobID
26
+ from kumoai.client.online import (
27
+ OnlineServingEndpointAPI,
28
+ OnlineServingEndpointID,
29
+ )
30
+ from kumoai.client.utils import parse_response
31
+ from kumoai.futures import KumoFuture, create_future
32
+ from kumoai.graph.graph import Graph
33
+
34
# Module-scoped logger; its name follows the module path so it can be
# configured through the standard logging hierarchy.
logger = logging.getLogger(name=__name__)
35
+
36
+
37
class OnlineServingEndpoint:
    """Represents a Kumo online serving endpoint that serves online `predict`
    requests.

    Requests are sent through the global Kumo client's HTTP session, so they
    carry the same API key header as the rest of the client.
    """
    def __init__(self, endpoint_url: str):
        """Creates a handle for the endpoint served at ``endpoint_url``.

        Args:
            endpoint_url: Base URL of the endpoint. Its last path segment is
                used as the endpoint ID; a trailing ``/`` is ignored.
        """
        # Strip a trailing slash so the endpoint ID (last path segment) is
        # never empty and derived URLs do not contain `//`.
        self._endpoint_url = endpoint_url.rstrip('/')
        # Use the same global session with API key in header.
        self._session = global_state.client._session
        self._endpoint_id = self._endpoint_url.split('/')[-1]
        self._predict_url = f'{self._endpoint_url}/predict'
        logger.info('Initialized OnlineServingEndpoint at: %s',
                    self._endpoint_url)

    def predict(
        self,
        fkey: NodeId,
        *,
        time: Union[datetime, TimestampNanos, None] = None,
        realtime_features: Optional[RealtimeFeatures] = None,
    ) -> OnlinePredictionResult:
        """Performs online inference for a single entity key using the
        currently deployed model and feature graph.

        This method sends a low-latency prediction request to the live
        endpoint. It supports injecting optional real-time features and
        controlling the anchor time used for temporal feature lookups.

        Parameters:
            fkey (NodeId):
                The entity key (e.g., user ID, item ID) to run inference on.

            time (datetime | TimestampNanos | None, optional):
                The effective timestamp for feature lookup and model
                prediction. A ``datetime`` is converted exactly to integer
                nanoseconds since the Unix epoch (a naive ``datetime`` is
                interpreted the same way ``datetime.timestamp()`` does); any
                other value is sent as-is. If not provided, the current
                server time will be used.

            realtime_features (Optional[RealtimeFeatures], optional):
                Additional real-time features to inject into the feature graph
                for this prediction request.
                These can complement batch-generated features, useful for
                contextual signals like current session state, real-time data,
                etc.

        Returns: The prediction result from the deployed model. The return type
            is a union type depending on the model task type.

        Raises:
            An HTTP error (via ``raise_for_status``) if the endpoint responds
            with an error status.
        """
        if isinstance(time, datetime):
            timestamp_nanos = self._to_timestamp_nanos(time)
        else:
            timestamp_nanos = time
        request = OnlinePredictionRequest(fkey, timestamp_nanos,
                                          realtime_features)
        resp = self._session.post(self._predict_url,
                                  json=to_json_dict(request))
        resp.raise_for_status()
        return parse_response(OnlinePredictionResponse, resp).result

    @staticmethod
    def _to_timestamp_nanos(time: datetime) -> int:
        """Converts ``time`` to integer nanoseconds since the Unix epoch.

        ``int(time.timestamp() * 10**9)`` routes through a 64-bit float,
        which cannot hold present-day epoch nanoseconds exactly and corrupts
        the sub-microsecond digits. Here only the whole-second part goes
        through ``timestamp()`` (where the float is integral and therefore
        exact); the microseconds are added with integer arithmetic. Naive
        vs. aware handling is delegated to ``timestamp()``, exactly as before.
        """
        whole_seconds = int(time.replace(microsecond=0).timestamp())
        return whole_seconds * 10**9 + time.microsecond * 1_000

    def ping(self) -> str:
        """Probes the endpoint's ``probe_liveness`` route.

        Returns:
            The raw response body.

        Raises:
            An HTTP error (via ``raise_for_status``) on an error status.
        """
        resp = self._session.get(f'{self._endpoint_url}/probe_liveness')
        resp.raise_for_status()
        return resp.text

    def update(
        self,
        *,
        refresh_graph_data: bool = True,
        graph_override: Optional[Graph] = None,
        new_model_id: Optional[TrainingJobID] = None,
    ) -> 'OnlineServingEndpointUpdateFuture':
        """Triggers an asynchronous update to the online serving endpoint using
        a blue-green deployment strategy.

        This method allows clients to deploy a new version of the endpoint with
        updated model weights, refreshed feature data, or a complete graph
        override. The update is applied in the background without
        interrupting availability, and will swap traffic to the new deployment
        once it is fully ready.

        Parameters:
            refresh_graph_data (bool, optional):
                Whether to reload feature data from the latest available
                source. Defaults to True.

            graph_override (Optional[Graph], optional):
                If provided, overrides the existing feature graph with the
                given one. The override is snapshotted with
                ``force_refresh=refresh_graph_data``.
                This is useful for testing or dynamic reconfiguration of the
                feature pipeline.

            new_model_id (Optional[TrainingJobID], optional):
                If specified, deploys a new model with the given training job
                ID. This model will replace the currently serving model after
                update is complete.

        Returns:
            OnlineServingEndpointUpdateFuture:
                A future-like object that can be used to check the progress and
                result of the update operation.

        Raises:
            ValueError: If neither ``refresh_graph_data`` nor
                ``new_model_id`` is set, or if the endpoint no longer exists.

        Example: (asynchronously send email notification when update is done)
            >>> fut = endpoint.update(new_model_id="model_202504")
            >>> fut.future().add_done_callback(send_email_notification)
        """
        # Validate arguments before making any API call.
        if not refresh_graph_data and not new_model_id:
            raise ValueError(
                'Expect to update online endpoint by loading a new model '
                'and/or refreshed graph data.')

        res = _endpoint_api().get_if_exists(self._endpoint_id)
        # Explicit check instead of `assert` (asserts vanish under `-O`).
        if res is None:
            raise ValueError(f'Online serving endpoint {self._endpoint_id} '
                             f'does not exist')

        model_id = new_model_id or res.config.model_training_job_id
        if graph_override is not None:
            # Previously the override was silently ignored whenever
            # `refresh_graph_data` was False.
            graph_snapshot_id = graph_override.snapshot(
                force_refresh=refresh_graph_data, non_blocking=True)
        elif refresh_graph_data:
            # No snapshot ID is sent; presumably the server then builds a
            # fresh snapshot of the existing graph — TODO confirm with API.
            graph_snapshot_id = None
        else:
            # Keep the graph snapshot the endpoint is currently configured
            # with.
            graph_snapshot_id = res.config.graph_snapshot_id

        updated = _endpoint_api().update(
            self._endpoint_id,
            OnlineServingEndpointRequest(model_id, res.config.predict_options,
                                         graph_snapshot_id))

        # A falsy `updated` presumably means no new deployment was started;
        # the future is then resolved immediately instead of polling.
        return OnlineServingEndpointUpdateFuture(self._endpoint_id,
                                                 noop=not updated)

    def destroy(self) -> None:
        """Deletes this endpoint through the online serving endpoint API."""
        _endpoint_api().delete(self._endpoint_id)
167
+
168
+
169
class OnlineServingEndpointFuture(KumoFuture[OnlineServingEndpoint]):
    """Future that resolves to an :class:`OnlineServingEndpoint` once the
    endpoint being launched leaves the in-progress state.

    Polling of the endpoint status is started at construction time via
    ``create_future``. NOTE(review): errors raised while polling (e.g. a
    failed launch) presumably surface from :meth:`result`, assuming
    ``create_future`` propagates coroutine exceptions — confirm.

    Args:
        id: ID of the online serving endpoint to wait for.
    """
    def __init__(self, id: OnlineServingEndpointID) -> None:
        self._id = id
        self._fut: concurrent.futures.Future[
            OnlineServingEndpoint] = create_future(_poll_endpoint_ready(id))

    @property
    def id(self) -> OnlineServingEndpointID:
        r"""The unique ID of the online serving endpoint tracked by this
        future."""
        return self._id

    @override
    def result(self) -> OnlineServingEndpoint:
        """Blocks until the endpoint is ready and returns a handle to it."""
        return self._fut.result()

    @override
    def future(self) -> 'concurrent.futures.Future[OnlineServingEndpoint]':
        """Returns the underlying :class:`concurrent.futures.Future`."""
        return self._fut
187
+
188
+
189
class OnlineServingEndpointUpdateFuture(KumoFuture[OnlineServingUpdate]):
    """Future that resolves to the :class:`OnlineServingUpdate` describing an
    endpoint update.

    Args:
        id: ID of the endpoint being updated.
        noop: If ``True``, no update is polled. Instead, the future is
            resolved immediately with a synthetic ``READY`` update whose
            previous and target configs are both the endpoint's current
            config. Otherwise the endpoint's update status is polled until it
            is no longer in progress.

    Raises:
        ValueError: If ``noop`` is set and the endpoint no longer exists.
    """
    def __init__(self, id: OnlineServingEndpointID, noop: bool):
        self._id = id
        if noop:
            res = _endpoint_api().get_if_exists(id)
            # Explicit check instead of `assert` (asserts vanish under `-O`).
            if res is None:
                raise ValueError(
                    f'Online serving endpoint {id} does not exist')
            # Use a single timestamp so start time and status time agree.
            now = datetime.now(timezone.utc)
            fut: concurrent.futures.Future[OnlineServingUpdate] = (
                concurrent.futures.Future())
            fut.set_result(
                OnlineServingUpdate(
                    prev_config=res.config, target_config=res.config,
                    update_started_at=now,
                    update_status=OnlineServingStatus(
                        OnlineServingStatusCode.READY, now)))
        else:
            fut = create_future(_poll_update_ready(id))
        self._fut = fut

    @property
    def id(self) -> OnlineServingEndpointID:
        r"""The unique ID of the online serving endpoint being updated."""
        return self._id

    @override
    def result(self) -> OnlineServingUpdate:
        """Blocks until the update is no longer in progress and returns it."""
        return self._fut.result()

    @override
    def future(self) -> 'concurrent.futures.Future[OnlineServingUpdate]':
        """Returns the underlying :class:`concurrent.futures.Future`."""
        return self._fut
213
+
214
+
215
async def _get_endpoint_resource(
        id: OnlineServingEndpointID) -> OnlineServingEndpointResource:
    """Fetches the current resource description of endpoint ``id``.

    The blocking API call runs on the event loop's default executor, so
    polling coroutines do not block the loop while waiting on the network.

    Raises:
        ValueError: If the endpoint does not exist (anymore).
    """
    api = _endpoint_api()
    # TODO(manan): make asynchronous natively with aiohttp:
    res = await asyncio.get_running_loop().run_in_executor(
        None, api.get_if_exists, id)
    # Explicit check instead of `assert`: under `-O` the assert vanishes and
    # callers would fail later with an opaque AttributeError on `None`.
    if res is None:
        raise ValueError(f'Online serving endpoint {id} does not exist')
    return res
223
+
224
+
225
async def _poll_endpoint_ready(
        id: OnlineServingEndpointID) -> OnlineServingEndpoint:
    """Polls endpoint ``id`` every 10 seconds until its launch is no longer
    in progress, then returns a handle to the ready endpoint.

    Raises:
        ValueError: If the launch failed or ended in an unexpected status.
    """
    while True:
        res = await _get_endpoint_resource(id)
        status = res.status.status_code
        if status != OnlineServingStatusCode.IN_PROGRESS:
            break
        await asyncio.sleep(10)

    if status == OnlineServingStatusCode.FAILED:
        raise ValueError(f"Failed to launch online endpoint id={id}, "
                         f"failure message: {res.status.failure_message}")
    if status != OnlineServingStatusCode.READY:
        raise ValueError(f"Unexpected status {status} for online endpoint "
                         f"id={id}")

    endpoint = OnlineServingEndpoint(res.endpoint_url)
    # `ping` performs a blocking HTTP request; run it off the event loop.
    liveness = await asyncio.get_running_loop().run_in_executor(
        None, endpoint.ping)
    logger.info('OnlineServingEndpoint is ready, ping: %s', liveness)
    return endpoint
243
+
244
+
245
async def _poll_update_ready(
        id: OnlineServingEndpointID) -> OnlineServingUpdate:
    """Polls endpoint ``id`` every 10 seconds until its latest update is no
    longer in progress, and returns that update.

    The loop stops on any status other than ``IN_PROGRESS``, including a
    failure status, so callers should inspect the returned
    ``update_status``.
    """
    while True:
        res = await _get_endpoint_resource(id)
        # `update_status` is an `OnlineServingStatus`, not a status code:
        # comparing the object itself to `IN_PROGRESS` never matched, which
        # made this coroutine return while the update was still running.
        update_code = res.update.update_status.status_code
        if update_code != OnlineServingStatusCode.IN_PROGRESS:
            return res.update
        await asyncio.sleep(10)
255
+
256
+
257
def _endpoint_api() -> OnlineServingEndpointAPI:
    """Returns the online serving endpoint API of the global Kumo client."""
    client = global_state.client
    return client.online_serving_endpoint_api