kumoai 2.14.0.dev202601011731__cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kumoai might be problematic. Click here for more details.
- kumoai/__init__.py +300 -0
- kumoai/_logging.py +29 -0
- kumoai/_singleton.py +25 -0
- kumoai/_version.py +1 -0
- kumoai/artifact_export/__init__.py +9 -0
- kumoai/artifact_export/config.py +209 -0
- kumoai/artifact_export/job.py +108 -0
- kumoai/client/__init__.py +5 -0
- kumoai/client/client.py +223 -0
- kumoai/client/connector.py +110 -0
- kumoai/client/endpoints.py +150 -0
- kumoai/client/graph.py +120 -0
- kumoai/client/jobs.py +471 -0
- kumoai/client/online.py +78 -0
- kumoai/client/pquery.py +207 -0
- kumoai/client/rfm.py +112 -0
- kumoai/client/source_table.py +53 -0
- kumoai/client/table.py +101 -0
- kumoai/client/utils.py +130 -0
- kumoai/codegen/__init__.py +19 -0
- kumoai/codegen/cli.py +100 -0
- kumoai/codegen/context.py +16 -0
- kumoai/codegen/edits.py +473 -0
- kumoai/codegen/exceptions.py +10 -0
- kumoai/codegen/generate.py +222 -0
- kumoai/codegen/handlers/__init__.py +4 -0
- kumoai/codegen/handlers/connector.py +118 -0
- kumoai/codegen/handlers/graph.py +71 -0
- kumoai/codegen/handlers/pquery.py +62 -0
- kumoai/codegen/handlers/table.py +109 -0
- kumoai/codegen/handlers/utils.py +42 -0
- kumoai/codegen/identity.py +114 -0
- kumoai/codegen/loader.py +93 -0
- kumoai/codegen/naming.py +94 -0
- kumoai/codegen/registry.py +121 -0
- kumoai/connector/__init__.py +31 -0
- kumoai/connector/base.py +153 -0
- kumoai/connector/bigquery_connector.py +200 -0
- kumoai/connector/databricks_connector.py +213 -0
- kumoai/connector/file_upload_connector.py +189 -0
- kumoai/connector/glue_connector.py +150 -0
- kumoai/connector/s3_connector.py +278 -0
- kumoai/connector/snowflake_connector.py +252 -0
- kumoai/connector/source_table.py +471 -0
- kumoai/connector/utils.py +1796 -0
- kumoai/databricks.py +14 -0
- kumoai/encoder/__init__.py +4 -0
- kumoai/exceptions.py +26 -0
- kumoai/experimental/__init__.py +0 -0
- kumoai/experimental/rfm/__init__.py +210 -0
- kumoai/experimental/rfm/authenticate.py +432 -0
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/backend/local/graph_store.py +297 -0
- kumoai/experimental/rfm/backend/local/sampler.py +312 -0
- kumoai/experimental/rfm/backend/local/table.py +113 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
- kumoai/experimental/rfm/backend/snow/table.py +242 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
- kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
- kumoai/experimental/rfm/base/__init__.py +30 -0
- kumoai/experimental/rfm/base/column.py +152 -0
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +761 -0
- kumoai/experimental/rfm/base/source.py +19 -0
- kumoai/experimental/rfm/base/sql_sampler.py +143 -0
- kumoai/experimental/rfm/base/table.py +736 -0
- kumoai/experimental/rfm/graph.py +1237 -0
- kumoai/experimental/rfm/infer/__init__.py +19 -0
- kumoai/experimental/rfm/infer/categorical.py +40 -0
- kumoai/experimental/rfm/infer/dtype.py +82 -0
- kumoai/experimental/rfm/infer/id.py +46 -0
- kumoai/experimental/rfm/infer/multicategorical.py +48 -0
- kumoai/experimental/rfm/infer/pkey.py +128 -0
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +61 -0
- kumoai/experimental/rfm/infer/timestamp.py +41 -0
- kumoai/experimental/rfm/pquery/__init__.py +7 -0
- kumoai/experimental/rfm/pquery/executor.py +102 -0
- kumoai/experimental/rfm/pquery/pandas_executor.py +530 -0
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +1184 -0
- kumoai/experimental/rfm/sagemaker.py +138 -0
- kumoai/experimental/rfm/task_table.py +231 -0
- kumoai/formatting.py +30 -0
- kumoai/futures.py +99 -0
- kumoai/graph/__init__.py +12 -0
- kumoai/graph/column.py +106 -0
- kumoai/graph/graph.py +948 -0
- kumoai/graph/table.py +838 -0
- kumoai/jobs.py +80 -0
- kumoai/kumolib.cpython-310-x86_64-linux-gnu.so +0 -0
- kumoai/mixin.py +28 -0
- kumoai/pquery/__init__.py +25 -0
- kumoai/pquery/prediction_table.py +287 -0
- kumoai/pquery/predictive_query.py +641 -0
- kumoai/pquery/training_table.py +424 -0
- kumoai/spcs.py +121 -0
- kumoai/testing/__init__.py +8 -0
- kumoai/testing/decorators.py +57 -0
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/__init__.py +42 -0
- kumoai/trainer/baseline_trainer.py +93 -0
- kumoai/trainer/config.py +2 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/trainer/job.py +1192 -0
- kumoai/trainer/online_serving.py +258 -0
- kumoai/trainer/trainer.py +475 -0
- kumoai/trainer/util.py +103 -0
- kumoai/utils/__init__.py +11 -0
- kumoai/utils/datasets.py +83 -0
- kumoai/utils/display.py +51 -0
- kumoai/utils/forecasting.py +209 -0
- kumoai/utils/progress_logger.py +343 -0
- kumoai/utils/sql.py +3 -0
- kumoai-2.14.0.dev202601011731.dist-info/METADATA +71 -0
- kumoai-2.14.0.dev202601011731.dist-info/RECORD +122 -0
- kumoai-2.14.0.dev202601011731.dist-info/WHEEL +6 -0
- kumoai-2.14.0.dev202601011731.dist-info/licenses/LICENSE +9 -0
- kumoai-2.14.0.dev202601011731.dist-info/top_level.txt +1 -0
kumoai/databricks.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from typing import Optional, Tuple, Union
|
|
2
|
+
|
|
3
|
+
DB_SEP = '__kumo__'


def to_db_table_name(
        table_name: Optional[Union[str, Tuple]] = None) -> Optional[str]:
    r"""For Databricks connectors, return the table name, which may be given
    as a ``(schema, table)`` tuple, as a string with the format
    `f"{schema}__kumo__{table}"`.

    Args:
        table_name: A plain table name, ``None``, or a ``(schema, table)``
            tuple. Only the first two elements of a tuple are used.

    Returns:
        ``f"{schema}{DB_SEP}{table}"`` for a non-empty tuple; otherwise
        ``table_name`` is returned unchanged.
    """
    if table_name and isinstance(table_name, tuple):
        schema, table = table_name[0], table_name[1]
        return f"{schema}{DB_SEP}{table}"
    return table_name  # type: ignore
|
kumoai/exceptions.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
import http
|
|
2
|
+
from typing import Dict, Optional
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class HTTPException(Exception):
    r"""An HTTP exception, with detailed information and headers.

    Args:
        status_code: The HTTP status code.
        detail: Human-readable error description. Defaults to the standard
            reason phrase of ``status_code`` (e.g. ``"Not Found"`` for 404),
            or ``"Unknown Status"`` for codes unknown to
            :class:`http.HTTPStatus`.
        headers: Optional headers associated with the error.
    """
    def __init__(
        self,
        status_code: int,
        detail: Optional[str] = None,
        headers: Optional[Dict[str, str]] = None,
    ) -> None:
        # Derived from starlette/blob/master/starlette/exceptions.py
        if detail is None:
            try:
                detail = http.HTTPStatus(status_code).phrase
            except ValueError:
                # Non-standard codes (e.g. 520, 599) are not members of
                # `HTTPStatus`; do not let the constructor itself raise and
                # hide the actual HTTP error.
                detail = "Unknown Status"
        # Populate `self.args` so the exception can be pickled and
        # unpickled (unpickling re-invokes `HTTPException(*self.args)`).
        super().__init__(status_code, detail, headers)
        self.status_code = status_code
        self.detail = detail
        self.headers = headers

    def __str__(self) -> str:
        return f"{self.status_code}: {self.detail}"

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        return (f"{class_name}(status_code={self.status_code!r}, "
                f"detail={self.detail!r})")
|
|
File without changes
|
|
@@ -0,0 +1,210 @@
|
|
|
1
|
+
import ipaddress
|
|
2
|
+
import logging
|
|
3
|
+
import os
|
|
4
|
+
import re
|
|
5
|
+
import socket
|
|
6
|
+
import threading
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from enum import Enum
|
|
9
|
+
from urllib.parse import urlparse
|
|
10
|
+
|
|
11
|
+
import kumoai
|
|
12
|
+
from kumoai.client.client import KumoClient
|
|
13
|
+
from kumoai.spcs import _get_active_session
|
|
14
|
+
|
|
15
|
+
from .authenticate import authenticate
|
|
16
|
+
from .sagemaker import (
|
|
17
|
+
KumoClient_SageMakerAdapter,
|
|
18
|
+
KumoClient_SageMakerProxy_Local,
|
|
19
|
+
)
|
|
20
|
+
from .base import Table
|
|
21
|
+
from .backend.local import LocalTable
|
|
22
|
+
from .graph import Graph
|
|
23
|
+
from .task_table import TaskTable
|
|
24
|
+
from .rfm import ExplainConfig, Explanation, KumoRFM
|
|
25
|
+
|
|
26
|
+
# Module logger, used by `init` to report which SageMaker backend was
# initialized. Note it is named 'kumoai_rfm' rather than `__name__`.
logger = logging.getLogger('kumoai_rfm')
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _is_local_address(host: str | None) -> bool:
|
|
30
|
+
"""Return True if the hostname/IP refers to the local machine."""
|
|
31
|
+
if not host:
|
|
32
|
+
return False
|
|
33
|
+
try:
|
|
34
|
+
infos = socket.getaddrinfo(host, None)
|
|
35
|
+
for _, _, _, _, sockaddr in infos:
|
|
36
|
+
ip = sockaddr[0]
|
|
37
|
+
ip_obj = ipaddress.ip_address(ip)
|
|
38
|
+
if ip_obj.is_loopback or ip_obj.is_unspecified:
|
|
39
|
+
return True
|
|
40
|
+
return False
|
|
41
|
+
except Exception:
|
|
42
|
+
return False
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class InferenceBackend(str, Enum):
    """Target that KumoRFM inference requests are routed to.

    Being a ``str`` mix-in, every member compares equal to its own value.
    """
    # Regular Kumo REST API; used for any URL not recognised as SageMaker.
    REST = 'REST'
    # SageMaker-style container on a local address (port 8080).
    LOCAL_SAGEMAKER = 'LOCAL_SAGEMAKER'
    # Hosted AWS SageMaker runtime endpoint.
    AWS_SAGEMAKER = 'AWS_SAGEMAKER'
    # Placeholder before `init` has picked a backend.
    UNKNOWN = 'UNKNOWN'
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _detect_backend(
    url: str,
) -> tuple[InferenceBackend, str | None, str | None]:
    """Classify ``url`` as one of the supported inference backends.

    Args:
        url: The KumoRFM API URL.

    Returns:
        A ``(backend, region, endpoint_name)`` triple. ``region`` and
        ``endpoint_name`` are only set for
        :attr:`InferenceBackend.AWS_SAGEMAKER`; otherwise both are ``None``.

    Raises:
        ValueError: If ``url`` is an AWS SageMaker runtime URL from which
            the region or endpoint name cannot be extracted.
    """
    parsed = urlparse(url)
    # `hostname` excludes userinfo/port and is lower-cased, so matching is
    # case-insensitive like DNS itself.
    hostname = parsed.hostname or ""

    # Remote SageMaker
    if ("runtime.sagemaker" in hostname
            and parsed.path.endswith("/invocations")):
        # Example: https://runtime.sagemaker.us-west-2.amazonaws.com/
        #          endpoints/Name/invocations
        match = re.search(r"runtime\.sagemaker\.([a-z0-9-]+)\.amazonaws\.com",
                          hostname)
        region = match.group(1) if match else None
        m = re.search(r"/endpoints/([^/]+)/invocations", parsed.path)
        endpoint_name = m.group(1) if m else None
        # Both values are required to build a SageMaker client; fail here
        # with a clear message rather than on a bare `assert` later.
        if region is None or endpoint_name is None:
            raise ValueError(
                f"Could not parse the AWS region and endpoint name from the "
                f"SageMaker URL '{url}'. Expected a URL of the form "
                "'https://runtime.sagemaker.<region>.amazonaws.com/"
                "endpoints/<name>/invocations'.")
        return InferenceBackend.AWS_SAGEMAKER, region, endpoint_name

    # Local SageMaker
    if (parsed.port == 8080 and parsed.path.endswith("/invocations")
            and _is_local_address(parsed.hostname)):
        return InferenceBackend.LOCAL_SAGEMAKER, None, None

    # Default: regular REST
    return InferenceBackend.REST, None, None
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _get_snowflake_url(snowflake_application: str) -> str:
    """Resolve the KumoRFM API URL inside a Snowflake application.

    Queries the DNS name of ``<snowflake_application>.user_schema.rfm_service``
    through the active Snowpark session.

    Args:
        snowflake_application: Name of the Snowflake application. It is
            interpolated into the SQL text as an identifier, so it must be a
            trusted value.

    Returns:
        ``http://<dns_name>:8000/api``.

    Raises:
        ValueError: If there is no active Snowpark session, or the service
            description yields no DNS name.
    """
    snowpark_session = _get_active_session()
    if not snowpark_session:
        raise ValueError(
            "KumoRFM initialization failed. 'snowflake_application' is "
            "specified without an active Snowpark session. If running outside "
            "a Snowflake notebook, specify a URL and credentials.")
    with snowpark_session.connection.cursor() as cur:
        cur.execute(
            f"DESCRIBE SERVICE {snowflake_application}.user_schema.rfm_service"
            f" ->> SELECT \"dns_name\" from $1")
        row = cur.fetchone()
    # `fetchone()` returns None for an empty result; report that clearly
    # instead of failing with "'NoneType' object is not subscriptable".
    if row is None or not row[0]:
        raise ValueError(
            f"KumoRFM initialization failed. Could not determine the DNS "
            f"name of the KumoRFM service in Snowflake application "
            f"'{snowflake_application}'.")
    dns_name: str = row[0]
    return f"http://{dns_name}:8000/api"
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
@dataclass
class RfmGlobalState:
    """Process-wide KumoRFM state written by :func:`init`.

    Stores the resolved URL and detected backend, and exposes the matching
    client through :attr:`client`.
    """
    _url: str = '__url_not_provided__'
    _backend: InferenceBackend = InferenceBackend.UNKNOWN
    _region: str | None = None
    _endpoint_name: str | None = None
    # No annotation, so this is a plain class attribute (not a dataclass
    # field). It holds one lazily-built SageMaker client per thread.
    _thread_local = threading.local()

    # Thread-safe init-once.
    _initialized: bool = False
    _lock: threading.Lock = threading.Lock()

    @property
    def client(self) -> KumoClient:
        """Client for the configured backend.

        The REST backend delegates to ``kumoai.global_state.client``. The
        SageMaker backends build a client on first use in each thread and
        reuse it afterwards.

        Raises:
            RuntimeError: If no backend has been configured yet.
        """
        if self._backend == InferenceBackend.UNKNOWN:
            raise RuntimeError("KumoRFM is not yet initialized")
        if self._backend == InferenceBackend.REST:
            return kumoai.global_state.client

        # Reuse this thread's SageMaker client if one was already built.
        cached = getattr(self._thread_local, '_sagemaker', None)
        if cached is not None:
            return cached

        fresh: KumoClient
        if self._backend == InferenceBackend.LOCAL_SAGEMAKER:
            fresh = KumoClient_SageMakerProxy_Local(self._url)
        else:
            assert self._backend == InferenceBackend.AWS_SAGEMAKER
            assert self._region
            assert self._endpoint_name
            fresh = KumoClient_SageMakerAdapter(self._region,
                                                self._endpoint_name)

        self._thread_local._sagemaker = fresh
        return fresh

    def reset(self) -> None:  # For testing only.
        """Restore the uninitialized state (intended for tests)."""
        with self._lock:
            self._initialized = False
            self._url = '__url_not_provided__'
            self._backend = InferenceBackend.UNKNOWN
            self._region = self._endpoint_name = None
            # A new thread-local discards every cached per-thread client.
            self._thread_local = threading.local()


global_state = RfmGlobalState()
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def init(
    url: str | None = None,
    api_key: str | None = None,
    snowflake_credentials: dict[str, str] | None = None,
    snowflake_application: str | None = None,
    log_level: str = "INFO",
) -> None:
    r"""Initialize KumoRFM and select the inference backend.

    After the first successful call, further calls are no-ops as long as
    they target the same API URL (calling with the default arguments again,
    e.g. re-running a notebook cell, is allowed).

    Args:
        url: The KumoRFM API URL. Defaults to the ``RFM_API_URL``
            environment variable, falling back to
            ``"https://kumorfm.ai/api"``. Mutually exclusive with
            ``snowflake_application``.
        api_key: Passed to :func:`kumoai.init` (REST backend only).
        snowflake_credentials: Passed to :func:`kumoai.init` (REST backend
            only).
        snowflake_application: Name of a Snowflake application whose
            KumoRFM service URL is looked up via the active Snowpark session.
        log_level: Passed to :func:`kumoai.init` (REST backend only).

    Raises:
        RuntimeError: If KumoRFM was already initialized with a different
            API URL.
        ValueError: If both ``url`` and ``snowflake_application`` are given,
            or ``snowflake_application`` is given without an active Snowpark
            session.
    """
    with global_state._lock:
        if global_state._initialized:
            # Resolve the URL the way a first call would, so that repeating
            # `init()` with default arguments does not spuriously fail. The
            # Snowflake URL needs a database query to resolve, so calls that
            # rely on `snowflake_application` are not compared.
            requested_url = url
            if requested_url is None and not snowflake_application:
                requested_url = os.getenv("RFM_API_URL",
                                          "https://kumorfm.ai/api")
            if (requested_url is not None
                    and requested_url != global_state._url):
                raise RuntimeError(
                    "KumoRFM has already been initialized with a different "
                    "API URL. Re-initialization with a different URL is not "
                    "supported.")
            return

        if snowflake_application:
            if url is not None:
                raise ValueError(
                    "KumoRFM initialization failed. Both "
                    "'snowflake_application' and 'url' are specified. If "
                    "running from a Snowflake notebook, specify only "
                    "'snowflake_application'.")
            url = _get_snowflake_url(snowflake_application)
            api_key = "test:DISABLED"

        if url is None:
            url = os.getenv("RFM_API_URL", "https://kumorfm.ai/api")

        backend, region, endpoint_name = _detect_backend(url)
        if backend == InferenceBackend.REST:
            kumoai.init(
                url=url,
                api_key=api_key,
                snowflake_credentials=snowflake_credentials,
                snowflake_application=snowflake_application,
                log_level=log_level,
            )
        elif backend == InferenceBackend.AWS_SAGEMAKER:
            assert region
            assert endpoint_name
            KumoClient_SageMakerAdapter(region, endpoint_name).authenticate()
            logger.info("KumoRFM initialized in AWS SageMaker")
        else:
            assert backend == InferenceBackend.LOCAL_SAGEMAKER
            KumoClient_SageMakerProxy_Local(url).authenticate()
            logger.info(f"KumoRFM initialized in local SageMaker at '{url}'")

        global_state._url = url
        global_state._backend = backend
        global_state._region = region
        global_state._endpoint_name = endpoint_name
        global_state._initialized = True
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
# Backward-compatibility alias: `LocalGraph` is the same object as `Graph`.
# It is not listed in `__all__` below.
LocalGraph = Graph  # NOTE Backward compatibility - do not use anymore.

# Public API of this package (names exported by `from ... import *`).
__all__ = [
    'authenticate',
    'init',
    'Table',
    'LocalTable',
    'Graph',
    'TaskTable',
    'KumoRFM',
    'ExplainConfig',
    'Explanation',
]
|
|
@@ -0,0 +1,432 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
import platform
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
|
|
6
|
+
from kumoai import in_notebook
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def authenticate(api_url: str | None = None) -> None:
    """Authenticates the user and sets the Kumo API key for the SDK.

    The authentication flow depends on the environment:

    - Google Colab (``google.colab`` importable): an interactive widget is
      shown that generates the API key.
    - Anywhere else: a browser is opened for login, with a fallback that
      lets the user paste an API key manually.

    On success the key is stored in the "KUMO_API_KEY" environment variable.

    Args:
        api_url (str, optional): The base URL for the Kumo API
            (e.g., 'https://kumorfm.ai'). If not provided, uses the
            'KUMO_API_URL' environment variable.
    """
    import re

    if api_url is None:
        base_url = os.getenv("KUMO_API_URL", "https://kumorfm.ai")
    else:
        base_url = api_url

    # The API-key popup flow lives at the site root, so keep only the scheme
    # and domain, e.g. https://kumorfm.ai/api/xyz -> https://kumorfm.ai
    if '://' not in base_url:
        base_url = base_url.split('/')[0]
    else:
        base_url = re.sub(r"(https?://[^/]+).*", r"\1", base_url.rstrip('/'))

    try:
        from google.colab import output  # noqa: F401
    except Exception:
        _authenticate_local(base_url)
    else:
        _authenticate_colab(base_url)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _authenticate_local(api_url: str, redirect_port: int = 8765) -> None:
    """Starts an HTTP server on the user's local machine to handle OAuth2
    or similar login flow, opens the browser for user login, and sets the
    API key via the "KUMO_API_KEY" environment variable.

    If browser-based authentication fails or is not possible, allows the
    user to manually paste an API key (only when not running in a notebook).

    The callback server listens on the loopback interface only, matching the
    ``http://127.0.0.1:<port>/`` callback URL given to the browser, so other
    hosts on the network cannot deliver a token to it. The server is shut
    down and its port released before this function returns.

    Args:
        api_url (str): The base URL for authentication (login page).
        redirect_port (int, optional): The first port tried for the local
            callback server (default: 8765); up to 10 consecutive ports are
            attempted.

    Raises:
        RuntimeError: If none of the attempted ports could be bound.
        TimeoutError: If no API key is received within 120 seconds.
    """
    import http.server
    import threading
    import time
    import urllib.parse
    import webbrowser
    from getpass import getpass
    from socketserver import TCPServer
    from typing import Any

    logger = logging.getLogger('kumoai')

    # Shared between the callback-server thread, the optional stdin thread,
    # and this function.
    token_status: dict[str, Any] = {
        'token': None,
        'token_name': None,
        'failed': False
    }

    token_name = (f"sdk-{platform.node().lower()}-" +
                  datetime.now().strftime('%Y-%m-%d-%H-%M-%S') + '-Z')

    class CallbackHandler(http.server.BaseHTTPRequestHandler):
        def do_GET(self) -> None:
            parsed_path = urllib.parse.urlparse(self.path)
            params = urllib.parse.parse_qs(parsed_path.query)
            token = params.get('token', [None])[0]
            received_token_name = params.get('token_name', [None])[0]

            if token:
                token_status['token'] = token
                token_status['token_name'] = received_token_name
                self.send_response(200)
                self.send_header('Content-type', 'text/html')
                self.end_headers()
            else:
                token_status['failed'] = True
                self.send_response(400)
                self.end_headers()

            html = f'''
<!DOCTYPE html>
<html>
<head>
    <title>Authenticate SDK</title>
    <style>
        body {{
            margin: 0;
            padding: 0;
            display: flex;
            justify-content: center;
            align-items: center;
            min-height: 100vh;
            font-family:
                -apple-system,
                BlinkMacSystemFont,
                'Segoe UI', Roboto, sans-serif;
        }}
        .container {{
            text-align: center;
            padding: 40px;
        }}
        svg {{
            margin-bottom: 20px;
        }}
        p {{
            font-size: 18px;
            color: #333;
        }}
    </style>
</head>
<body>
    <div class="container">
        <?xml version="1.0" encoding="UTF-8"?>
        <svg xmlns="http://www.w3.org/2000/svg"
            id="kumo-logo" width="183.908" height="91.586"
            viewBox="0 0 183.908 91.586">
            <g id="c">
                <g id="Group_9893" data-name="Group 9893">
                    <path id="Path_4831" data-name="Path 4831"
                        d="M67.159,67.919V46.238L53.494,59.491,
                        68.862,82.3H61.567L49.1,63.74l-7.011,6.8V82.3h-6.02V29.605h6.02V62.182l16.642-16.36H73.109v22.1c0,5.453,3.611,9.419,9.277,9.419,5.547,0,9.14-3.9,9.2-9.282V0H0V91.586H91.586V80.317a15.7,15.7,0,0,1-9.2,2.828c-8.569,0-15.226-6.02-15.226-15.226Z"
                        fill="#d40e8c">
                    </path>
                    <path id="Path_4832" data-name="Path 4832"
                        d="M233.452,121.881h-6.019V98.3c0-4.745-3.117-8.286-7.932-8.286s-7.932,3.541-7.932,8.286v23.583h-6.02V98.3c0-4.745-3.116-8.286-7.932-8.286s-7.932,3.541-7.932,8.286v23.583h-6.02V98.51c0-7.932,5.736-14.023,13.952-14.023a12.106,12.106,0,0,1,10.906,6.02,12.3,12.3,0,0,1,10.978-6.02c8.285,0,13.951,6.091,13.951,14.023v23.37Z"
                        transform="translate(-86.054 -39.585)"
                        fill="#d40e8c">
                    </path>
                    <path id="Path_4833" data-name="Path 4833"
                        d="M313.7,103.751c0,10.481-7.932,
                        19.051-18.342,19.051-10.341,
                        0-18.343-8.569-18.343-19.051,0-10.623,
                        8-19.263,18.343-19.263C305.767,84.488,
                        313.7,93.128,313.7,103.751Zm-6.02,
                        0c0-7.436-5.523-13.527-12.322-13.527-6.728
                        ,0-12.252,6.091-12.252,13.527,0,7.295,
                        5.524,13.244,12.252,13.244,6.8,0,
                        12.322-5.949,12.322-13.244Z"
                        transform="translate(-129.791 -39.585)"
                        fill="#d40e8c">
                    </path>
                </g>
            </g>
        </svg>

        <div id="success-div"
            style="background: #f2f8f0;
                border: 1px solid #1d8102;
                border-radius: 1px;
                padding: 24px 32px;
                margin: 24px auto 0 auto;
                max-width: 400px;
                text-align: left;
                display: none;"
        >
            <div style="font-size: 1.1em;
                font-weight: bold;
                margin-bottom: 10px;
                text-align: left;"
            >
                Request successful
            </div>
            <div style="font-size: 1.1em;">
                Kumo SDK has been granted a token.
                You may now close this window.
            </div>
        </div>

        <div id="failure-div"
            style="background: #ffebeb;
                border: 1px solid #ff837a;
                border-radius: 1px;
                padding: 24px 32px;
                margin: 24px auto 0 auto;
                max-width: 400px;
                text-align: left;
                display: none;"
        >
            <div style="font-size: 1.1em;
                font-weight: bold;
                margin-bottom: 10px;
                text-align: left;"
            >
                Request failed
            </div>
            <div style="font-size: 1.1em;">
                Failed to generate a token.
                Please try manually creating a token at
                <a href="{api_url}/api-keys" target="_blank">
                    {api_url}/api-keys
                </a>
                or contact Kumo for further assistance.
            </div>
        </div>

        <script>
            // Show only the appropriate div based on the result
            const search = window.location.search;
            const urlParams = new URLSearchParams(search);
            const hasToken = urlParams.has('token');
            if (hasToken) {{
                document
                    .getElementById('success-div')
                    .style.display = 'block';
            }} else {{
                document
                    .getElementById('failure-div')
                    .style.display = 'block';
            }}
        </script>
    </div>
</body>
</html>
'''
            self.wfile.write(html.encode('utf-8'))

        def log_message(self, format: str, *args: object) -> None:
            return  # Suppress logging

    # Bind the callback server exactly once and keep it. Probing a port,
    # closing it, and re-binding later in another thread is racy (another
    # process can take the port in between, and a bind error inside the
    # thread would go unnoticed until the timeout). Bind to loopback only:
    # the browser is redirected to 127.0.0.1, and nothing else on the
    # network should be able to hand us a token.
    httpd: TCPServer | None = None
    port = redirect_port
    for _ in range(10):
        try:
            httpd = TCPServer(("127.0.0.1", port), CallbackHandler)
            break
        except OSError:
            port += 1
    if httpd is None:
        raise RuntimeError(
            "Could not find a free port for the callback server.")

    # Serve callbacks in the background until shut down below.
    server_thread = threading.Thread(target=httpd.serve_forever, daemon=True)
    server_thread.start()

    def get_user_input() -> None:
        token_entered = getpass(
            "or paste the API key here and press enter: ").strip()

        while (len(token_entered) == 0):
            token_entered = getpass(
                "API Key (type then press enter): ").strip()

        token_status['token'] = token_entered

    try:
        # Construct the login URL with callback_url and token_name
        callback_url = f"http://127.0.0.1:{port}/"
        login_url = (f"{api_url}/authenticate-sdk/" +
                     f"?callback_url={urllib.parse.quote(callback_url)}" +
                     f"&token_name={urllib.parse.quote(token_name)}")

        print(
            "Opening browser page to automatically generate an API key...\n" +
            "If the page does not open, manually create a new API key at " +
            f"{api_url}/api-keys and set it using os.environ[\"KUMO_API_KEY\"] " +
            "= \"YOUR_API_KEY\"")

        webbrowser.open(login_url)

        if not in_notebook():
            user_input_thread = threading.Thread(target=get_user_input,
                                                 daemon=True)
            user_input_thread.start()

        # Wait for the token (timeout after 120s). `monotonic` is not
        # affected by wall-clock adjustments during the wait.
        start = time.monotonic()
        while (token_status['token'] is None
               and time.monotonic() - start < 120):
            time.sleep(1)
    finally:
        # Stop serving and release the port whether we received a token,
        # timed out, or were interrupted.
        httpd.shutdown()
        httpd.server_close()

    if not isinstance(token_status['token'], str) or not token_status['token']:
        raise TimeoutError(
            "Timed out waiting for authentication or API key input.")

    os.environ['KUMO_API_KEY'] = token_status['token']

    logger.info(
        f"Generated token \"{token_status['token_name'] or token_name}\" " +
        "and saved to KUMO_API_KEY env variable")
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
def _authenticate_colab(api_url: str) -> None:
    """Displays an interactive widget in Google Colab to authenticate the user
    and generate a Kumo API key.

    This method is intended to be used within a Google Colab notebook. It
    presents a button that, when clicked, opens a popup for the user to
    authenticate with KumoRFM and generate an API key. Upon successful
    authentication, the API key is set in the notebook's environment using the
    "KUMO_API_KEY" variable. Note that Jupyter Notebook support is unavailable
    at this time.

    Args:
        api_url (str): The base URL for the Kumo API
            (e.g., 'https://kumorfm.ai').

    Raises:
        ImportError: If not running in a Google Colab environment or
            required modules are missing.
    """
    try:
        from google.colab import output
        from IPython.display import HTML, display
    except Exception as e:
        raise ImportError(
            'This method is meant to be used in Google Colab.\n If your ' +
            'python code is running on your local machine, use ' +
            'kumo.authenticate_local().\n Otherwise, visit ' +
            f'{api_url}/api-keys to generate an API key.') from e

    import uuid

    token_name = "sdk-colab-" + datetime.now().strftime(
        '%Y-%m-%d-%H-%M-%S') + '-Z'

    def handle_api_key(api_key: str) -> None:
        os.environ['KUMO_API_KEY'] = api_key

    # Unique id shared by the button element and the registered Python
    # callback, so several widgets in one notebook do not collide.
    callback_id = 'api-key-button-' + str(uuid.uuid4())

    output.register_callback(callback_id, handle_api_key)

    # NOTE(security): the message listener below does not verify
    # `event.origin`, so any window able to post messages to this frame can
    # supply an API key. Consider restricting it to the origin of `api_url`
    # once the popup's origin is confirmed.
    display(
        HTML(f"""
    <div style="padding: 10px;">
        <!-- <script src="https://cdn.tailwindcss.com"></script> -->
        <svg width="100" height="50" viewBox="0 0 184 92" fill="none"
            xmlns="http://www.w3.org/2000/svg">
            <g clip-path="url(#clip0_749_1962)">
                <path d="M67.159 67.919V46.238L53.494 59.491L68.862 82.3H61.567L49.1 63.74L42.089 70.54V82.3H36.069V29.605H42.089V62.182L58.731 45.822H73.109V67.922C73.109 73.375 76.72 77.341 82.386 77.341C87.933 77.341 91.526 73.441 91.586 68.059V0H0V91.586H91.586V80.317C88.891 82.1996 85.6731 83.1888 82.386 83.145C73.817 83.145 67.16 77.125 67.16 67.919H67.159Z"
                    fill="#FC1373"/>
                <path d="M147.398 82.296H141.379V58.715C141.379 53.97 138.262 50.429 133.447 50.429C128.632 50.429 125.515 53.97 125.515 58.715V82.298H119.495V58.715C119.495 53.97 116.379 50.429 111.563 50.429C106.747 50.429 103.631 53.97 103.631 58.715V82.298H97.611V58.925C97.611 50.993 103.347 44.902 111.563 44.902C113.756 44.8229 115.929 45.3412 117.85 46.4016C119.771 47.4619 121.367 49.0244 122.469 50.922C123.592 49.0276 125.204 47.4696 127.135 46.4107C129.066 45.3517 131.246 44.8307 133.447 44.902C141.732 44.902 147.398 50.993 147.398 58.925V82.296Z"
                    fill="#FC1373"/>
                <path d="M183.909 64.166C183.909 74.647 175.977 83.217 165.567 83.217C155.226 83.217 147.224 74.648 147.224 64.166C147.224 53.543 155.224 44.903 165.567 44.903C175.976 44.903 183.909 53.543 183.909 64.166ZM177.889 64.166C177.889 56.73 172.366 50.639 165.567 50.639C158.839 50.639 153.315 56.73 153.315 64.166C153.315 71.461 158.839 77.41 165.567 77.41C172.367 77.41 177.889 71.461 177.889 64.166Z"
                    fill="#FC1373"/>
            </g>
            <defs>
                <clipPath id="clip0_749_1962">
                    <rect width="183.908" height="91.586" fill="white"/>
                </clipPath>
            </defs>
        </svg>
        <div id="prompt">
            <p>
                Click the button below to connect to KumoRFM and
                generate your API key.
            </p>
            <button id="{callback_id}">
                Generate API Key
            </button>
        </div>
        <div id="success" style="display: none;">
            <p>
                ✓ Your API key has been created and configured in your
                colab notebook.
            </p>
            To manage all your API keys, visit the
            <a href="{api_url}/api-keys" target="_blank">
                KumoRFM website.
            </a>
        </div>
        <div id="failed" style="display: none; color: red;">
            <p>
                API key creation failed with error:
                <span id="error-message"></span>
            </p>
        </div>
        <script>
            // Listen for messages from the popup
            window.addEventListener('message', function(event) {{
                // Ignore messages without a payload (e.g. null data).
                if (!event.data) {{
                    return;
                }}
                if (event.data.type === 'API_KEY_GENERATED') {{
                    // Call the Python callback with the API key
                    google.colab.kernel.invokeFunction(
                        '{callback_id}', [event.data.apiKey], {{}}
                    );
                    document.getElementById('prompt')
                        .style.display = "none";
                    document.getElementById('success')
                        .style.display = "block";
                    document.getElementById('failed')
                        .style.display = "none";
                }} else if (
                    event.data.type === 'API_KEY_GENERATION_FAILED'
                ) {{
                    document.getElementById('failed')
                        .style.display = "block";
                    // Render the message as text, never as HTML.
                    document.getElementById('error-message')
                        .textContent = event.data.errorMessage;
                }}
            }});

            document.getElementById('{callback_id}')
                .onclick = function() {{
                // Open the popup
                const popup = window.open(
                    '{api_url}/authenticate-sdk?opener=colab&token_name={token_name}',
                    'apiKeyPopup',
                    'width=600,height=700,scrollbars=yes,resizable=yes'
                );

                // Focus the popup
                if (popup) {{
                    popup.focus();
                }}
            }};
        </script>
    </div>
    """))
|