kumoai 2.14.0.dev202601011731__cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kumoai might be problematic; see the package's registry page for more details.

Files changed (122):
  1. kumoai/__init__.py +300 -0
  2. kumoai/_logging.py +29 -0
  3. kumoai/_singleton.py +25 -0
  4. kumoai/_version.py +1 -0
  5. kumoai/artifact_export/__init__.py +9 -0
  6. kumoai/artifact_export/config.py +209 -0
  7. kumoai/artifact_export/job.py +108 -0
  8. kumoai/client/__init__.py +5 -0
  9. kumoai/client/client.py +223 -0
  10. kumoai/client/connector.py +110 -0
  11. kumoai/client/endpoints.py +150 -0
  12. kumoai/client/graph.py +120 -0
  13. kumoai/client/jobs.py +471 -0
  14. kumoai/client/online.py +78 -0
  15. kumoai/client/pquery.py +207 -0
  16. kumoai/client/rfm.py +112 -0
  17. kumoai/client/source_table.py +53 -0
  18. kumoai/client/table.py +101 -0
  19. kumoai/client/utils.py +130 -0
  20. kumoai/codegen/__init__.py +19 -0
  21. kumoai/codegen/cli.py +100 -0
  22. kumoai/codegen/context.py +16 -0
  23. kumoai/codegen/edits.py +473 -0
  24. kumoai/codegen/exceptions.py +10 -0
  25. kumoai/codegen/generate.py +222 -0
  26. kumoai/codegen/handlers/__init__.py +4 -0
  27. kumoai/codegen/handlers/connector.py +118 -0
  28. kumoai/codegen/handlers/graph.py +71 -0
  29. kumoai/codegen/handlers/pquery.py +62 -0
  30. kumoai/codegen/handlers/table.py +109 -0
  31. kumoai/codegen/handlers/utils.py +42 -0
  32. kumoai/codegen/identity.py +114 -0
  33. kumoai/codegen/loader.py +93 -0
  34. kumoai/codegen/naming.py +94 -0
  35. kumoai/codegen/registry.py +121 -0
  36. kumoai/connector/__init__.py +31 -0
  37. kumoai/connector/base.py +153 -0
  38. kumoai/connector/bigquery_connector.py +200 -0
  39. kumoai/connector/databricks_connector.py +213 -0
  40. kumoai/connector/file_upload_connector.py +189 -0
  41. kumoai/connector/glue_connector.py +150 -0
  42. kumoai/connector/s3_connector.py +278 -0
  43. kumoai/connector/snowflake_connector.py +252 -0
  44. kumoai/connector/source_table.py +471 -0
  45. kumoai/connector/utils.py +1796 -0
  46. kumoai/databricks.py +14 -0
  47. kumoai/encoder/__init__.py +4 -0
  48. kumoai/exceptions.py +26 -0
  49. kumoai/experimental/__init__.py +0 -0
  50. kumoai/experimental/rfm/__init__.py +210 -0
  51. kumoai/experimental/rfm/authenticate.py +432 -0
  52. kumoai/experimental/rfm/backend/__init__.py +0 -0
  53. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  54. kumoai/experimental/rfm/backend/local/graph_store.py +297 -0
  55. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  56. kumoai/experimental/rfm/backend/local/table.py +113 -0
  57. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  58. kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
  59. kumoai/experimental/rfm/backend/snow/table.py +242 -0
  60. kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
  61. kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
  62. kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
  63. kumoai/experimental/rfm/base/__init__.py +30 -0
  64. kumoai/experimental/rfm/base/column.py +152 -0
  65. kumoai/experimental/rfm/base/expression.py +44 -0
  66. kumoai/experimental/rfm/base/sampler.py +761 -0
  67. kumoai/experimental/rfm/base/source.py +19 -0
  68. kumoai/experimental/rfm/base/sql_sampler.py +143 -0
  69. kumoai/experimental/rfm/base/table.py +736 -0
  70. kumoai/experimental/rfm/graph.py +1237 -0
  71. kumoai/experimental/rfm/infer/__init__.py +19 -0
  72. kumoai/experimental/rfm/infer/categorical.py +40 -0
  73. kumoai/experimental/rfm/infer/dtype.py +82 -0
  74. kumoai/experimental/rfm/infer/id.py +46 -0
  75. kumoai/experimental/rfm/infer/multicategorical.py +48 -0
  76. kumoai/experimental/rfm/infer/pkey.py +128 -0
  77. kumoai/experimental/rfm/infer/stype.py +35 -0
  78. kumoai/experimental/rfm/infer/time_col.py +61 -0
  79. kumoai/experimental/rfm/infer/timestamp.py +41 -0
  80. kumoai/experimental/rfm/pquery/__init__.py +7 -0
  81. kumoai/experimental/rfm/pquery/executor.py +102 -0
  82. kumoai/experimental/rfm/pquery/pandas_executor.py +530 -0
  83. kumoai/experimental/rfm/relbench.py +76 -0
  84. kumoai/experimental/rfm/rfm.py +1184 -0
  85. kumoai/experimental/rfm/sagemaker.py +138 -0
  86. kumoai/experimental/rfm/task_table.py +231 -0
  87. kumoai/formatting.py +30 -0
  88. kumoai/futures.py +99 -0
  89. kumoai/graph/__init__.py +12 -0
  90. kumoai/graph/column.py +106 -0
  91. kumoai/graph/graph.py +948 -0
  92. kumoai/graph/table.py +838 -0
  93. kumoai/jobs.py +80 -0
  94. kumoai/kumolib.cpython-310-x86_64-linux-gnu.so +0 -0
  95. kumoai/mixin.py +28 -0
  96. kumoai/pquery/__init__.py +25 -0
  97. kumoai/pquery/prediction_table.py +287 -0
  98. kumoai/pquery/predictive_query.py +641 -0
  99. kumoai/pquery/training_table.py +424 -0
  100. kumoai/spcs.py +121 -0
  101. kumoai/testing/__init__.py +8 -0
  102. kumoai/testing/decorators.py +57 -0
  103. kumoai/testing/snow.py +50 -0
  104. kumoai/trainer/__init__.py +42 -0
  105. kumoai/trainer/baseline_trainer.py +93 -0
  106. kumoai/trainer/config.py +2 -0
  107. kumoai/trainer/distilled_trainer.py +175 -0
  108. kumoai/trainer/job.py +1192 -0
  109. kumoai/trainer/online_serving.py +258 -0
  110. kumoai/trainer/trainer.py +475 -0
  111. kumoai/trainer/util.py +103 -0
  112. kumoai/utils/__init__.py +11 -0
  113. kumoai/utils/datasets.py +83 -0
  114. kumoai/utils/display.py +51 -0
  115. kumoai/utils/forecasting.py +209 -0
  116. kumoai/utils/progress_logger.py +343 -0
  117. kumoai/utils/sql.py +3 -0
  118. kumoai-2.14.0.dev202601011731.dist-info/METADATA +71 -0
  119. kumoai-2.14.0.dev202601011731.dist-info/RECORD +122 -0
  120. kumoai-2.14.0.dev202601011731.dist-info/WHEEL +6 -0
  121. kumoai-2.14.0.dev202601011731.dist-info/licenses/LICENSE +9 -0
  122. kumoai-2.14.0.dev202601011731.dist-info/top_level.txt +1 -0
kumoai/client/client.py ADDED
@@ -0,0 +1,223 @@
1
+ from typing import TYPE_CHECKING, Any, Optional
2
+
3
+ import requests
4
+ from requests.adapters import HTTPAdapter
5
+ from urllib3.util import Retry
6
+
7
+ from kumoai.client.endpoints import Endpoint, HTTPMethod
8
+
9
+ if TYPE_CHECKING:
10
+ from kumoai.client.connector import ConnectorAPI
11
+ from kumoai.client.graph import GraphAPI
12
+ from kumoai.client.jobs import (
13
+ ArtifactExportJobAPI,
14
+ BaselineJobAPI,
15
+ BatchPredictionJobAPI,
16
+ DistillationJobAPI,
17
+ GeneratePredictionTableJobAPI,
18
+ GenerateTrainTableJobAPI,
19
+ LLMJobAPI,
20
+ TrainingJobAPI,
21
+ )
22
+ from kumoai.client.online import OnlineServingEndpointAPI
23
+ from kumoai.client.pquery import PQueryAPI
24
+ from kumoai.client.source_table import SourceTableAPI
25
+ from kumoai.client.table import TableAPI
26
+
27
API_VERSION = 'v1'  # Path segment inserted between the base URL and routes.


class KumoClient:
    r"""Low-level HTTP client for the Kumo public API.

    Holds a :class:`requests.Session` that retries transient failures and
    carries the authentication header, and exposes the typed per-resource
    APIs as properties. Requests are sent to
    ``<url>/<API_VERSION>/<endpoint path>``.
    """
    def __init__(
        self,
        url: str,
        api_key: Optional[str],
        spcs_token: Optional[str] = None,
        verify_ssl: bool = True,
    ) -> None:
        r"""Creates a client against the Kumo public API, provided a URL of
        the endpoint and an authentication token.

        Args:
            url: the public API endpoint URL. Trailing slashes are ignored
                when building request URLs.
            api_key: the public API authentication token. If set, it is
                used instead of ``spcs_token``.
            spcs_token: the SPCS token used for authentication to access the
                Kumo API endpoint. Only sent when ``api_key`` is falsy.
            verify_ssl: whether to verify SSL certificates. Set to False to
                skip SSL certificate verification (equivalent to curl -k).
        """
        self._url = url
        # Drop trailing slashes so that ``https://host/`` does not produce
        # ``https://host//v1``.
        self._api_url = f"{url.rstrip('/')}/{API_VERSION}"
        self._api_key = api_key
        self._spcs_token = spcs_token
        self._verify_ssl = verify_ssl

        # NOTE(review): urllib3's default ``allowed_methods`` presumably
        # excludes non-idempotent verbs (POST/PATCH), so those would not be
        # retried on read errors / status codes - confirm this is intended.
        retry_strategy = Retry(
            total=10,  # Maximum number of retries
            connect=3,  # How many connection-related errors to retry on
            read=3,  # How many times to retry on read errors
            status=5,  # How many times to retry on bad status codes (below)
            # Status codes to retry on.
            status_forcelist=[408, 429, 500, 502, 503, 504],
            # Exponential backoff between retries; see urllib3's
            # ``Retry.backoff_factor`` for the exact delay schedule.
            backoff_factor=2.0,
        )
        http_adapter = HTTPAdapter(max_retries=retry_strategy)
        session = requests.Session()
        session.mount('http://', http_adapter)
        session.mount('https://', http_adapter)
        self._session = session
        # The API key wins over the SPCS token when both are provided.
        if self._api_key:
            self._session.headers.update({"X-API-Key": self._api_key})
        elif self._spcs_token:
            self._session.headers.update(
                {'Authorization': f'Snowflake Token={self._spcs_token}'})

    def authenticate(self) -> None:
        r"""Checks the configured credentials against the API.

        Issues a ``GET`` on the connectors endpoint and checks its status.

        Raises:
            ValueError: If the request fails for any reason or returns an
                error status. The underlying exception is chained as
                ``__cause__``.
        """
        try:
            self._session.get(
                self._format_endpoint_url('/connectors'),
                verify=self._verify_ssl,
            ).raise_for_status()
        except Exception as err:
            raise ValueError(
                "Client authentication failed. Please check if you "
                "have a valid API key/credentials.") from err

    def set_spcs_token(self, spcs_token: str) -> None:
        r"""Sets the SPCS token for the client and updates the session
        headers.

        This always (re)writes the ``Authorization`` header, even if an API
        key header was set at construction time.
        """
        self._spcs_token = spcs_token
        self._session.headers.update(
            {'Authorization': f'Snowflake Token={self._spcs_token}'})

    # The typed API properties below build a fresh wrapper on every access.
    # Their imports are deferred, presumably to avoid an import cycle (the
    # API modules import ``KumoClient`` from ``kumoai.client``).

    @property
    def artifact_export_api(self) -> 'ArtifactExportJobAPI':
        r"""Returns the artifact export API."""
        from kumoai.client.jobs import ArtifactExportJobAPI
        return ArtifactExportJobAPI(self)

    @property
    def connector_api(self) -> 'ConnectorAPI':
        r"""Returns the typed connector API."""
        from kumoai.client.connector import ConnectorAPI
        return ConnectorAPI(self)

    @property
    def source_table_api(self) -> 'SourceTableAPI':
        r"""Returns the typed source table API."""
        from kumoai.client.source_table import SourceTableAPI
        return SourceTableAPI(self)

    @property
    def table_api(self) -> 'TableAPI':
        r"""Returns the typed Kumo Table (snapshot) API."""
        from kumoai.client.table import TableAPI
        return TableAPI(self)

    @property
    def graph_api(self) -> 'GraphAPI':
        r"""Returns the typed Graph (metadata and snapshot) API."""
        from kumoai.client.graph import GraphAPI
        return GraphAPI(self)

    @property
    def pquery_api(self) -> 'PQueryAPI':
        r"""Returns the typed Predictive Query API."""
        from kumoai.client.pquery import PQueryAPI
        return PQueryAPI(self)

    @property
    def training_job_api(self) -> 'TrainingJobAPI':
        r"""Returns the typed Training Job API."""
        from kumoai.client.jobs import TrainingJobAPI
        return TrainingJobAPI(self)

    @property
    def distillation_job_api(self) -> 'DistillationJobAPI':
        r"""Returns the typed Distillation Job API."""
        from kumoai.client.jobs import DistillationJobAPI
        return DistillationJobAPI(self)

    @property
    def batch_prediction_job_api(self) -> 'BatchPredictionJobAPI':
        r"""Returns the typed Batch Prediction Job API."""
        from kumoai.client.jobs import BatchPredictionJobAPI
        return BatchPredictionJobAPI(self)

    @property
    def generate_train_table_job_api(self) -> 'GenerateTrainTableJobAPI':
        r"""Returns the typed Generate-Train-Table Job API."""
        from kumoai.client.jobs import GenerateTrainTableJobAPI
        return GenerateTrainTableJobAPI(self)

    @property
    def generate_prediction_table_job_api(
            self) -> 'GeneratePredictionTableJobAPI':
        r"""Returns the typed Generate-Prediction-Table Job API."""
        from kumoai.client.jobs import GeneratePredictionTableJobAPI
        return GeneratePredictionTableJobAPI(self)

    @property
    def llm_job_api(self) -> 'LLMJobAPI':
        r"""Returns the typed LLM Job API."""
        from kumoai.client.jobs import LLMJobAPI
        return LLMJobAPI(self)

    @property
    def baseline_job_api(self) -> 'BaselineJobAPI':
        r"""Returns the typed Baseline Job API."""
        from kumoai.client.jobs import BaselineJobAPI
        return BaselineJobAPI(self)

    @property
    def online_serving_endpoint_api(self) -> 'OnlineServingEndpointAPI':
        r"""Returns the typed Online Serving Endpoint API."""
        from kumoai.client.online import OnlineServingEndpointAPI
        return OnlineServingEndpointAPI(self)

    def _request(self, endpoint: Endpoint, **kwargs: Any) -> requests.Response:
        r"""Send a HTTP request to the specified endpoint.

        Keyword arguments are forwarded to the matching ``_get`` /
        ``_post`` / ``_patch`` / ``_delete`` helper.

        Raises:
            ValueError: If the endpoint fails validation (e.g. an
                ``IDEndpoint`` whose ``with_id()`` was never called), has no
                path, or uses an unsupported HTTP method.
        """
        # Run the endpoint's own check first so that misuse produces its
        # specific message instead of the generic "requires a path" one.
        endpoint.validate()
        endpoint_str = endpoint.get_path()
        if endpoint.method == HTTPMethod.GET:
            return self._get(endpoint_str, **kwargs)
        elif endpoint.method == HTTPMethod.POST:
            return self._post(endpoint_str, **kwargs)
        elif endpoint.method == HTTPMethod.PATCH:
            return self._patch(endpoint_str, **kwargs)
        elif endpoint.method == HTTPMethod.DELETE:
            return self._delete(endpoint_str, **kwargs)
        else:
            raise ValueError(f"Unsupported HTTP method: {endpoint.method}")

    def _get(self, endpoint: str, **kwargs: Any) -> requests.Response:
        r"""Send a GET request to the specified endpoint, with keyword
        arguments, returned objects, and exceptions raised corresponding to
        :meth:`requests.Session.get`.
        """
        url = self._format_endpoint_url(endpoint)
        return self._session.get(url=url, verify=self._verify_ssl, **kwargs)

    def _post(self, endpoint: str, **kwargs: Any) -> requests.Response:
        r"""Send a POST request to the specified endpoint, with keyword
        arguments, returned objects, and exceptions raised corresponding to
        :meth:`requests.Session.post`.
        """
        url = self._format_endpoint_url(endpoint)
        return self._session.post(url=url, verify=self._verify_ssl, **kwargs)

    def _patch(self, endpoint: str, **kwargs: Any) -> requests.Response:
        r"""Send a PATCH request to the specified endpoint, with keyword
        arguments, returned objects, and exceptions raised corresponding to
        :meth:`requests.Session.patch`.
        """
        url = self._format_endpoint_url(endpoint)
        return self._session.patch(url=url, verify=self._verify_ssl, **kwargs)

    def _delete(self, endpoint: str, **kwargs: Any) -> requests.Response:
        r"""Send a DELETE request to the specified endpoint, with keyword
        arguments, returned objects, and exceptions raised corresponding to
        :meth:`requests.Session.delete`.
        """
        url = self._format_endpoint_url(endpoint)
        return self._session.delete(url=url, verify=self._verify_ssl, **kwargs)

    def _format_endpoint_url(self, endpoint: str) -> str:
        r"""Joins ``endpoint`` onto the versioned API base URL.

        A single leading ``/`` on ``endpoint`` is dropped so the result has
        exactly one separator. An empty ``endpoint`` yields the base URL
        followed by ``/`` (instead of raising ``IndexError``).
        """
        if endpoint.startswith("/"):
            endpoint = endpoint[1:]
        return f"{self._api_url}/{endpoint}"
kumoai/client/connector.py ADDED
@@ -0,0 +1,110 @@
1
+ from http import HTTPStatus
2
+ from typing import List, Optional
3
+
4
+ from kumoapi.data_source import (
5
+ CompleteFileUploadRequest,
6
+ ConnectorResponse,
7
+ CreateConnectorArgs,
8
+ DataSourceType,
9
+ DeleteUploadedFileRequest,
10
+ StartFileUploadRequest,
11
+ StartFileUploadResponse,
12
+ )
13
+ from kumoapi.json_serde import to_json_dict
14
+
15
+ from kumoai.client import KumoClient
16
+ from kumoai.client.endpoints import ConnectorEndpoints
17
+ from kumoai.client.utils import parse_response, raise_on_error
18
+ from kumoai.exceptions import HTTPException
19
+
20
+
21
class ConnectorAPI:
    r"""Typed API definition for Kumo connectors.

    Each method sends one request through the wrapped :class:`KumoClient`.
    Error responses are surfaced via :func:`raise_on_error`, except where a
    404 is handled explicitly.
    """
    def __init__(self, client: KumoClient) -> None:
        self._client = client

    def start_file_upload(
            self, req: StartFileUploadRequest) -> StartFileUploadResponse:
        r"""Starts a file upload and returns the parsed server response."""
        res = self._client._request(
            ConnectorEndpoints.start_file_upload,
            json=to_json_dict(req, insecure=True),
        )
        raise_on_error(res)
        return parse_response(StartFileUploadResponse, res)

    def delete_file_upload(self, req: DeleteUploadedFileRequest) -> None:
        r"""Deletes a previously uploaded file described by ``req``."""
        res = self._client._request(
            ConnectorEndpoints.delete_uploaded_file,
            json=to_json_dict(req, insecure=True),
        )
        raise_on_error(res)

    def complete_file_upload(self, req: CompleteFileUploadRequest) -> None:
        r"""Marks the file upload described by ``req`` as complete."""
        res = self._client._request(
            ConnectorEndpoints.complete_file_upload,
            json=to_json_dict(req, insecure=True),
        )
        raise_on_error(res)

    def create(self, create_connector_args: CreateConnectorArgs) -> None:
        r"""Creates a connector in Kumo."""
        resp = self._client._request(
            ConnectorEndpoints.create,
            json=to_json_dict(create_connector_args, insecure=True),
        )
        raise_on_error(resp)

    def create_if_not_exist(
        self,
        create_connector_args: CreateConnectorArgs,
    ) -> bool:
        r"""Creates a connector in Kumo if the connector does not exist.

        The connector is looked up by ``create_connector_args.config.name``.

        Returns:
            :obj:`True` if the connector is newly created, :obj:`False`
            if an identically-configured connector already exists.

        Raises:
            HTTPException: (422) if a connector with the same name exists
                but with a different configuration.
        """
        _id = create_connector_args.config.name
        existing_connector = self.get(_id)
        # ``get`` signals absence with ``None``; test for that explicitly
        # rather than relying on the response object's truthiness.
        if existing_connector is not None:
            if existing_connector.config == create_connector_args.config:
                return False
            raise HTTPException(
                HTTPStatus.UNPROCESSABLE_ENTITY,
                f"Connector {_id} already exists, but has a differing "
                f"configuration. Input: {create_connector_args.config}, "
                f"existing: {existing_connector.config}",
            )
        self.create(create_connector_args)
        return True

    def get(self, connector_id: str) -> Optional[ConnectorResponse]:
        r"""Fetches a connector given its ID, or :obj:`None` on a 404."""
        resp = self._client._request(
            ConnectorEndpoints.get.with_id(connector_id))
        if resp.status_code == HTTPStatus.NOT_FOUND:
            return None

        raise_on_error(resp)
        return parse_response(ConnectorResponse, resp)

    def list(
        self, data_source_type: Optional[DataSourceType] = None
    ) -> List[ConnectorResponse]:
        r"""Lists connectors, optionally filtered by data source type.

        When ``data_source_type`` is :obj:`None`, no filter is sent.
        """
        params = ({
            'data_source_type': data_source_type
        } if data_source_type is not None else {})

        resp = self._client._request(ConnectorEndpoints.list, params=params)
        raise_on_error(resp)
        return parse_response(List[ConnectorResponse], resp)

    def delete_if_exists(self, connector_id: str) -> bool:
        r"""Deletes a connector if it exists.

        Returns:
            :obj:`True` if the connector was deleted, :obj:`False` if the
            server answered 404.
        """
        resp = self._client._request(
            ConnectorEndpoints.delete.with_id(connector_id))
        if resp.status_code == HTTPStatus.NOT_FOUND:
            return False
        raise_on_error(resp)
        return True
kumoai/client/endpoints.py ADDED
@@ -0,0 +1,150 @@
1
+ from dataclasses import dataclass, field
2
+ from enum import Enum
3
+ from typing import Final, Optional
4
+
5
+
6
class HTTPMethod(Enum):
    r"""The HTTP verbs an :class:`Endpoint` may be declared with.

    Each member's value is the verb spelled exactly as it goes on the wire.
    """
    GET = 'GET'
    POST = 'POST'
    DELETE = 'DELETE'
    PATCH = 'PATCH'
12
+
13
+
14
@dataclass(frozen=True)
class Endpoint:
    r"""An immutable API route: an optional path plus its HTTP method.

    ``method`` defaults to :attr:`HTTPMethod.GET`.
    """
    path: Optional[str] = None
    method: HTTPMethod = HTTPMethod.GET

    def validate(self) -> None:
        r"""Hook for subclasses to reject unusable endpoints; no-op here."""
        return None

    def get_path(self) -> str:
        r"""Returns :attr:`path`.

        Raises:
            ValueError: If :attr:`path` is unset.
        """
        path = self.path
        if path is not None:
            return path
        # Checked here (not only in ``validate``) so the return type narrows
        # to ``str`` for the type checker.
        raise ValueError("Endpoint requires a path")
28
+
29
+
30
@dataclass(frozen=True)
class IDEndpoint(Endpoint):
    r"""Represents an API endpoint with an additional id/name in its path.

    ``template_path`` must contain an ``{id}`` placeholder. :meth:`with_id`
    fills it in to produce an endpoint with a concrete :attr:`path`.
    """
    template_path: Optional[str] = field(default=None)

    def __post_init__(self) -> None:
        r"""Rejects a missing template or one lacking ``{id}``.

        Raises:
            ValueError: If ``template_path`` is unset or has no ``{id}``.
        """
        if self.template_path is None:
            raise ValueError("template_path must be set explicitly")
        if "{id}" not in self.template_path:
            # The first literal is a plain (non-f) string, so its braces are
            # not unescaped: write ``{id}`` with single braces so the message
            # shows the placeholder correctly.
            raise ValueError("IDEndpoint path must contain '{id}': "
                             f"got '{self.template_path}'")

    def with_id(self, resource_id: str) -> 'IDEndpoint':
        r"""Returns a copy whose :attr:`path` has ``{id}`` replaced by
        ``resource_id``, keeping the template and HTTP method.

        NOTE(review): ``resource_id`` is inserted verbatim (no URL quoting);
        ids containing ``/``, ``?`` or ``#`` would alter the URL - confirm
        callers never pass such ids.
        """
        # Always true after ``__post_init__``; narrows the type.
        assert self.template_path is not None
        return IDEndpoint(template_path=self.template_path,
                          path=self.template_path.format(id=resource_id),
                          method=self.method)

    def validate(self) -> None:
        r"""Raises :class:`ValueError` if :meth:`with_id` has not been used
        to bind a resource id (i.e. :attr:`path` is unset).
        """
        if self.path is None:
            raise ValueError(
                "IDEndpoint requires with_id() to be called with a resource id"
            )

    @classmethod
    def from_base(cls, base: str, method: HTTPMethod) -> 'IDEndpoint':
        r"""Alternate constructor to build a template_path by appending /{id}
        to base. The resulting endpoint has no :attr:`path` until
        :meth:`with_id` is called.
        """
        template = f"{base}/{{id}}"
        return cls(template_path=template, method=method)
61
+
62
+
63
class ConnectorEndpoints:
    r"""Routes for connector CRUD and the file-upload utilities."""
    BASE: Final[str] = "/connectors"

    create = Endpoint(path=BASE, method=HTTPMethod.POST)
    list = Endpoint(path=BASE, method=HTTPMethod.GET)
    get = IDEndpoint.from_base(base=BASE, method=HTTPMethod.GET)
    delete = IDEndpoint.from_base(base=BASE, method=HTTPMethod.DELETE)

    # File-upload helpers, all POSTed under ``<BASE>/utils``.
    start_file_upload = Endpoint(
        path=f"{BASE}/utils/start_file_upload",
        method=HTTPMethod.POST,
    )
    delete_uploaded_file = Endpoint(
        path=f"{BASE}/utils/delete_uploaded_file",
        method=HTTPMethod.POST,
    )
    complete_file_upload = Endpoint(
        path=f"{BASE}/utils/complete_file_upload",
        method=HTTPMethod.POST,
    )
77
+
78
+
79
class PQueryEndpoints:
    r"""Routes for predictive-query resources and their planning helpers."""
    BASE: Final[str] = '/predictive_queries'

    create = Endpoint(path=BASE, method=HTTPMethod.POST)
    list = Endpoint(path=BASE, method=HTTPMethod.GET)
    get = IDEndpoint.from_base(base=BASE, method=HTTPMethod.GET)
    delete = IDEndpoint.from_base(base=BASE, method=HTTPMethod.DELETE)

    # Stateless helpers; each takes a request body via POST.
    infer_task_type = Endpoint(
        path=f"{BASE}/infer_task_type",
        method=HTTPMethod.POST,
    )
    validate = Endpoint(
        path=f"{BASE}/validate",
        method=HTTPMethod.POST,
    )
    suggest_training_table_plan = Endpoint(
        path=f"{BASE}/train_table_generation_plan",
        method=HTTPMethod.POST,
    )
    suggest_model_plan = Endpoint(
        path=f"{BASE}/model_plan",
        method=HTTPMethod.POST,
    )
92
+
93
+
94
class SourceTableEndpoints:
    r"""Routes for source-table lookup and table inspection utilities."""
    BASE: Final[str] = "/source_tables"

    get = IDEndpoint.from_base(base=BASE, method=HTTPMethod.GET)

    # Inspection helpers, all POSTed under ``BASE``.
    list_tables = Endpoint(
        path=f"{BASE}/list_tables",
        method=HTTPMethod.POST,
    )
    validate_table = Endpoint(
        path=f"{BASE}/validate_table",
        method=HTTPMethod.POST,
    )
    get_table_data = Endpoint(
        path=f"{BASE}/get_table_data",
        method=HTTPMethod.POST,
    )
    get_table_config = Endpoint(
        path=f"{BASE}/get_table_config",
        method=HTTPMethod.POST,
    )
103
+
104
+
105
class GraphEndpoints:
    r"""Routes for graph definitions and graph snapshots."""
    BASE: Final[str] = "/graphs"
    SNAPSHOTS: Final[str] = "/graphsnapshots"

    create = Endpoint(path=BASE, method=HTTPMethod.POST)
    get = IDEndpoint.from_base(base=BASE, method=HTTPMethod.GET)

    validate = Endpoint(
        path=f"{BASE}/validate",
        method=HTTPMethod.POST,
    )
    infer_links = Endpoint(
        path=f"{BASE}/infer_links",
        method=HTTPMethod.POST,
    )
    create_snapshot = Endpoint(path=SNAPSHOTS, method=HTTPMethod.POST)
    get_snapshot = IDEndpoint.from_base(base=SNAPSHOTS, method=HTTPMethod.GET)
    # Template resolves to ``/graphsnapshots/edge_health/{id}``.
    get_edge_stats = IDEndpoint.from_base(
        base=f"{SNAPSHOTS}/edge_health",
        method=HTTPMethod.GET,
    )
118
+
119
+
120
class OnlineServingEndpoints:
    r"""CRUD routes for online serving endpoints (PATCH for updates)."""
    BASE: Final[str] = "/online_serving_endpoints"

    create = Endpoint(path=BASE, method=HTTPMethod.POST)
    get = IDEndpoint.from_base(base=BASE, method=HTTPMethod.GET)
    list = Endpoint(path=BASE, method=HTTPMethod.GET)
    update = IDEndpoint.from_base(base=BASE, method=HTTPMethod.PATCH)
    delete = IDEndpoint.from_base(base=BASE, method=HTTPMethod.DELETE)
128
+
129
+
130
class TableEndpoints:
    r"""Routes for Kumo tables, table snapshots and table helpers."""
    BASE: Final[str] = "/tables"
    SNAPSHOTS: Final[str] = "/tablesnapshots"

    create = Endpoint(path=BASE, method=HTTPMethod.POST)
    get = IDEndpoint.from_base(base=BASE, method=HTTPMethod.GET)

    create_snapshot = Endpoint(path=SNAPSHOTS, method=HTTPMethod.POST)
    get_snapshot = IDEndpoint.from_base(base=SNAPSHOTS, method=HTTPMethod.GET)
    validate = Endpoint(
        path=f"{BASE}/validate",
        method=HTTPMethod.POST,
    )
    infer_metadata = Endpoint(
        path=f"{BASE}/infer_metadata",
        method=HTTPMethod.POST,
    )
141
+
142
+
143
class RFMEndpoints:
    r"""POST routes under ``/rfm``: predict, explain, evaluate, and query
    validation/parsing."""
    BASE: Final[str] = "/rfm"

    predict = Endpoint(path=f"{BASE}/predict", method=HTTPMethod.POST)
    explain = Endpoint(path=f"{BASE}/explain", method=HTTPMethod.POST)
    evaluate = Endpoint(path=f"{BASE}/evaluate", method=HTTPMethod.POST)
    validate_query = Endpoint(
        path=f"{BASE}/validate_query",
        method=HTTPMethod.POST,
    )
    parse_query = Endpoint(
        path=f"{BASE}/parse_query",
        method=HTTPMethod.POST,
    )
kumoai/client/graph.py ADDED
@@ -0,0 +1,120 @@
1
+ from http import HTTPStatus
2
+ from typing import Any, Dict, Optional
3
+
4
+ from kumoapi.common import ValidationResponse
5
+ from kumoapi.data_snapshot import GraphSnapshotID, GraphSnapshotResource
6
+ from kumoapi.graph import (
7
+ EdgeHealthResponse,
8
+ GraphDefinition,
9
+ GraphResource,
10
+ GraphValidationRequest,
11
+ )
12
+ from kumoapi.json_serde import to_json_dict
13
+
14
+ from kumoai.client import KumoClient
15
+ from kumoai.client.endpoints import GraphEndpoints
16
+ from kumoai.client.utils import (
17
+ parse_id_response,
18
+ parse_response,
19
+ raise_on_error,
20
+ )
21
+
22
# Graph identifiers are plain strings.
GraphID = str


class GraphAPI:
    r"""Typed client for Kumo graph definitions and graph snapshots.

    Every call goes through the wrapped :class:`KumoClient`. Error responses
    are surfaced via :func:`raise_on_error`, except where a 404 is mapped to
    :obj:`None`.
    """
    def __init__(self, client: KumoClient) -> None:
        self._client = client

    def create_graph(
        self,
        graph_def: GraphDefinition,
        *,
        name_alias: Optional[str] = None,
        force_rename: bool = False,
    ) -> GraphID:
        r"""Creates a Graph (metadata definition) resource object in Kumo.

        ``name_alias`` is only sent when it is a non-empty string. Returns
        the id parsed from the response.
        """
        query: Dict[str, Any] = {'force_rename': force_rename}
        if name_alias:
            query['name_alias'] = name_alias
        response = self._client._request(
            GraphEndpoints.create,
            params=query,
            json=to_json_dict(graph_def),
        )
        raise_on_error(response)
        return parse_id_response(response)

    def get_graph_if_exists(
        self,
        graph_id_or_name: str,
    ) -> Optional[GraphResource]:
        r"""Fetches a graph by id or name; returns :obj:`None` on a 404."""
        endpoint = GraphEndpoints.get.with_id(graph_id_or_name)
        response = self._client._request(endpoint)
        if response.status_code != HTTPStatus.NOT_FOUND:
            raise_on_error(response)
            return parse_response(GraphResource, response)
        return None

    def create_snapshot(
        self,
        graph_id: GraphID,
        *,
        refresh_source: bool = False,
    ) -> GraphSnapshotID:
        r"""Requests a snapshot of ``graph_id`` and returns its id."""
        response = self._client._request(
            GraphEndpoints.create_snapshot,
            params={
                'graph_id': graph_id,
                'refresh_source': refresh_source
            },
        )
        raise_on_error(response)
        snapshot_id = parse_id_response(response)
        return GraphSnapshotID(snapshot_id)

    def get_snapshot(
        self,
        snapshot_id: GraphSnapshotID,
    ) -> GraphSnapshotResource:
        r"""Fetches a graph snapshot; any error status is raised."""
        endpoint = GraphEndpoints.get_snapshot.with_id(snapshot_id)
        response = self._client._request(endpoint)
        raise_on_error(response)
        return parse_response(GraphSnapshotResource, response)

    def get_snapshot_if_exists(
        self,
        snapshot_id: GraphSnapshotID,
    ) -> Optional[GraphSnapshotResource]:
        r"""Like :meth:`get_snapshot`, but returns :obj:`None` on a 404."""
        endpoint = GraphEndpoints.get_snapshot.with_id(snapshot_id)
        response = self._client._request(endpoint)
        if response.status_code != HTTPStatus.NOT_FOUND:
            raise_on_error(response)
            return parse_response(GraphSnapshotResource, response)
        return None

    def get_edge_stats(
        self,
        graph_snapshot_id: GraphSnapshotID,
    ) -> EdgeHealthResponse:
        r"""Fetches edge statistics for the given graph snapshot id."""
        endpoint = GraphEndpoints.get_edge_stats.with_id(graph_snapshot_id)
        response = self._client._request(endpoint)
        raise_on_error(response)
        return parse_response(EdgeHealthResponse, response)

    def validate_graph(
        self,
        request: GraphValidationRequest,
    ) -> ValidationResponse:
        r"""Submits ``request`` for server-side graph validation."""
        payload = to_json_dict(request)
        response = self._client._request(GraphEndpoints.validate,
                                         json=payload)
        raise_on_error(response)
        return parse_response(ValidationResponse, response)

    def infer_links(self, graph: GraphDefinition) -> GraphDefinition:
        r"""Sends ``graph`` to the link-inference route and returns the
        graph definition parsed from the response."""
        payload = to_json_dict(graph)
        response = self._client._request(GraphEndpoints.infer_links,
                                         json=payload)
        raise_on_error(response)
        return parse_response(GraphDefinition, response)