kumoai 2.13.0.dev202511131731__cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kumoai might be problematic. Click here for more details.

Files changed (98) hide show
  1. kumoai/__init__.py +294 -0
  2. kumoai/_logging.py +29 -0
  3. kumoai/_singleton.py +25 -0
  4. kumoai/_version.py +1 -0
  5. kumoai/artifact_export/__init__.py +9 -0
  6. kumoai/artifact_export/config.py +209 -0
  7. kumoai/artifact_export/job.py +108 -0
  8. kumoai/client/__init__.py +5 -0
  9. kumoai/client/client.py +221 -0
  10. kumoai/client/connector.py +110 -0
  11. kumoai/client/endpoints.py +150 -0
  12. kumoai/client/graph.py +120 -0
  13. kumoai/client/jobs.py +447 -0
  14. kumoai/client/online.py +78 -0
  15. kumoai/client/pquery.py +203 -0
  16. kumoai/client/rfm.py +112 -0
  17. kumoai/client/source_table.py +53 -0
  18. kumoai/client/table.py +101 -0
  19. kumoai/client/utils.py +130 -0
  20. kumoai/codegen/__init__.py +19 -0
  21. kumoai/codegen/cli.py +100 -0
  22. kumoai/codegen/context.py +16 -0
  23. kumoai/codegen/edits.py +473 -0
  24. kumoai/codegen/exceptions.py +10 -0
  25. kumoai/codegen/generate.py +222 -0
  26. kumoai/codegen/handlers/__init__.py +4 -0
  27. kumoai/codegen/handlers/connector.py +118 -0
  28. kumoai/codegen/handlers/graph.py +71 -0
  29. kumoai/codegen/handlers/pquery.py +62 -0
  30. kumoai/codegen/handlers/table.py +109 -0
  31. kumoai/codegen/handlers/utils.py +42 -0
  32. kumoai/codegen/identity.py +114 -0
  33. kumoai/codegen/loader.py +93 -0
  34. kumoai/codegen/naming.py +94 -0
  35. kumoai/codegen/registry.py +121 -0
  36. kumoai/connector/__init__.py +31 -0
  37. kumoai/connector/base.py +153 -0
  38. kumoai/connector/bigquery_connector.py +200 -0
  39. kumoai/connector/databricks_connector.py +213 -0
  40. kumoai/connector/file_upload_connector.py +189 -0
  41. kumoai/connector/glue_connector.py +150 -0
  42. kumoai/connector/s3_connector.py +278 -0
  43. kumoai/connector/snowflake_connector.py +252 -0
  44. kumoai/connector/source_table.py +471 -0
  45. kumoai/connector/utils.py +1775 -0
  46. kumoai/databricks.py +14 -0
  47. kumoai/encoder/__init__.py +4 -0
  48. kumoai/exceptions.py +26 -0
  49. kumoai/experimental/__init__.py +0 -0
  50. kumoai/experimental/rfm/__init__.py +67 -0
  51. kumoai/experimental/rfm/authenticate.py +433 -0
  52. kumoai/experimental/rfm/infer/__init__.py +11 -0
  53. kumoai/experimental/rfm/infer/categorical.py +40 -0
  54. kumoai/experimental/rfm/infer/id.py +46 -0
  55. kumoai/experimental/rfm/infer/multicategorical.py +48 -0
  56. kumoai/experimental/rfm/infer/timestamp.py +41 -0
  57. kumoai/experimental/rfm/local_graph.py +810 -0
  58. kumoai/experimental/rfm/local_graph_sampler.py +184 -0
  59. kumoai/experimental/rfm/local_graph_store.py +359 -0
  60. kumoai/experimental/rfm/local_pquery_driver.py +689 -0
  61. kumoai/experimental/rfm/local_table.py +545 -0
  62. kumoai/experimental/rfm/pquery/__init__.py +7 -0
  63. kumoai/experimental/rfm/pquery/executor.py +102 -0
  64. kumoai/experimental/rfm/pquery/pandas_executor.py +532 -0
  65. kumoai/experimental/rfm/rfm.py +1130 -0
  66. kumoai/experimental/rfm/utils.py +344 -0
  67. kumoai/formatting.py +30 -0
  68. kumoai/futures.py +99 -0
  69. kumoai/graph/__init__.py +12 -0
  70. kumoai/graph/column.py +106 -0
  71. kumoai/graph/graph.py +948 -0
  72. kumoai/graph/table.py +838 -0
  73. kumoai/jobs.py +80 -0
  74. kumoai/kumolib.cpython-313-x86_64-linux-gnu.so +0 -0
  75. kumoai/mixin.py +28 -0
  76. kumoai/pquery/__init__.py +25 -0
  77. kumoai/pquery/prediction_table.py +287 -0
  78. kumoai/pquery/predictive_query.py +637 -0
  79. kumoai/pquery/training_table.py +424 -0
  80. kumoai/spcs.py +123 -0
  81. kumoai/testing/__init__.py +8 -0
  82. kumoai/testing/decorators.py +57 -0
  83. kumoai/trainer/__init__.py +42 -0
  84. kumoai/trainer/baseline_trainer.py +93 -0
  85. kumoai/trainer/config.py +2 -0
  86. kumoai/trainer/job.py +1192 -0
  87. kumoai/trainer/online_serving.py +258 -0
  88. kumoai/trainer/trainer.py +475 -0
  89. kumoai/trainer/util.py +103 -0
  90. kumoai/utils/__init__.py +10 -0
  91. kumoai/utils/datasets.py +83 -0
  92. kumoai/utils/forecasting.py +209 -0
  93. kumoai/utils/progress_logger.py +177 -0
  94. kumoai-2.13.0.dev202511131731.dist-info/METADATA +60 -0
  95. kumoai-2.13.0.dev202511131731.dist-info/RECORD +98 -0
  96. kumoai-2.13.0.dev202511131731.dist-info/WHEEL +6 -0
  97. kumoai-2.13.0.dev202511131731.dist-info/licenses/LICENSE +9 -0
  98. kumoai-2.13.0.dev202511131731.dist-info/top_level.txt +1 -0
@@ -0,0 +1,221 @@
1
+ from typing import TYPE_CHECKING, Any, Optional
2
+
3
+ import requests
4
+ from requests.adapters import HTTPAdapter
5
+ from urllib3.util import Retry
6
+
7
+ from kumoai.client.endpoints import Endpoint, HTTPMethod
8
+
9
+ if TYPE_CHECKING:
10
+ from kumoai.client.connector import ConnectorAPI
11
+ from kumoai.client.graph import GraphAPI
12
+ from kumoai.client.jobs import (
13
+ ArtifactExportJobAPI,
14
+ BaselineJobAPI,
15
+ BatchPredictionJobAPI,
16
+ GeneratePredictionTableJobAPI,
17
+ GenerateTrainTableJobAPI,
18
+ LLMJobAPI,
19
+ TrainingJobAPI,
20
+ )
21
+ from kumoai.client.online import OnlineServingEndpointAPI
22
+ from kumoai.client.pquery import PQueryAPI
23
+ from kumoai.client.rfm import RFMAPI
24
+ from kumoai.client.source_table import SourceTableAPI
25
+ from kumoai.client.table import TableAPI
26
+
27
# Version segment appended to the base URL to form every API request URL.
API_VERSION = "v1"
28
+
29
+
30
class KumoClient:
    r"""HTTP client for the Kumo public API.

    Holds a :class:`requests.Session` configured with a retry policy and
    authentication headers, and hands out the typed per-resource API
    wrappers as properties.
    """
    def __init__(
        self,
        url: str,
        api_key: Optional[str],
        spcs_token: Optional[str] = None,
        verify_ssl: bool = True,
    ) -> None:
        r"""Creates a client against the Kumo public API, provided a URL of
        the endpoint and an authentication token.

        Args:
            url: the public API endpoint URL.
            api_key: the public API authentication token.
            spcs_token: the SPCS token used for authentication to access the
                Kumo API endpoint.
            verify_ssl: whether to verify SSL certificates. Set to False to
                skip SSL certificate verification (equivalent to curl -k).
        """
        self._url = url
        self._api_url = f"{url}/{API_VERSION}"
        self._api_key = api_key
        self._spcs_token = spcs_token
        self._verify_ssl = verify_ssl

        retry_strategy = Retry(
            total=10,  # Maximum number of retries
            connect=3,  # How many connection-related errors to retry on
            read=3,  # How many times to retry on read errors
            status=5,  # How many times to retry on bad status codes (below)
            # Status codes to retry on.
            status_forcelist=[408, 429, 500, 502, 503, 504],
            # Exponential backoff factor (2, 4, 8, 16 seconds delay).
            backoff_factor=2.0,
        )
        http_adapter = HTTPAdapter(max_retries=retry_strategy)
        session = requests.Session()
        # The same retrying adapter serves both plain and TLS connections.
        session.mount('http://', http_adapter)
        session.mount('https://', http_adapter)
        self._session = session
        # The API key takes precedence; the SPCS token is only used as the
        # initial credential when no API key is given.
        if self._api_key:
            self._session.headers.update({"X-API-Key": self._api_key})
        elif self._spcs_token:
            self._session.headers.update(
                {'Authorization': f'Snowflake Token={self._spcs_token}'})

    def authenticate(self) -> bool:
        r"""Checks whether the API accepts this client's credentials.

        Issues a GET request against the connectors endpoint.

        Returns:
            The ``ok`` flag of the response, i.e. :obj:`False` for an error
            status code. No exception is raised for a rejected request;
            transport errors from :mod:`requests` propagate unchanged.
        """
        # Built from `_api_url` so the version segment follows API_VERSION.
        return self._session.get(f"{self._api_url}/connectors",
                                 verify=self._verify_ssl).ok

    def set_spcs_token(self, spcs_token: str) -> None:
        r"""Sets the SPCS token for the client and updates the session
        headers.
        """
        self._spcs_token = spcs_token
        self._session.headers.update(
            {'Authorization': f'Snowflake Token={self._spcs_token}'})

    # NOTE: The API wrappers below are imported inside each property. The
    # wrapper modules import `KumoClient` themselves, so a module-level import
    # here would be circular.

    @property
    def artifact_export_api(self) -> 'ArtifactExportJobAPI':
        r"""Returns the artifact export API."""
        from kumoai.client.jobs import ArtifactExportJobAPI
        return ArtifactExportJobAPI(self)

    @property
    def connector_api(self) -> 'ConnectorAPI':
        r"""Returns the typed connector API."""
        from kumoai.client.connector import ConnectorAPI
        return ConnectorAPI(self)

    @property
    def source_table_api(self) -> 'SourceTableAPI':
        r"""Returns the typed source table API."""
        from kumoai.client.source_table import SourceTableAPI
        return SourceTableAPI(self)

    @property
    def table_api(self) -> 'TableAPI':
        r"""Returns the typed Kumo Table (snapshot) API."""
        from kumoai.client.table import TableAPI
        return TableAPI(self)

    @property
    def graph_api(self) -> 'GraphAPI':
        r"""Returns the typed Graph (metadata and snapshot) API."""
        from kumoai.client.graph import GraphAPI
        return GraphAPI(self)

    @property
    def pquery_api(self) -> 'PQueryAPI':
        r"""Returns the typed Predictive Query API."""
        from kumoai.client.pquery import PQueryAPI
        return PQueryAPI(self)

    @property
    def training_job_api(self) -> 'TrainingJobAPI':
        r"""Returns the typed Training Job API."""
        from kumoai.client.jobs import TrainingJobAPI
        return TrainingJobAPI(self)

    @property
    def batch_prediction_job_api(self) -> 'BatchPredictionJobAPI':
        r"""Returns the typed Batch Prediction Job API."""
        from kumoai.client.jobs import BatchPredictionJobAPI
        return BatchPredictionJobAPI(self)

    @property
    def generate_train_table_job_api(self) -> 'GenerateTrainTableJobAPI':
        r"""Returns the typed Generate-Train-Table Job API."""
        from kumoai.client.jobs import GenerateTrainTableJobAPI
        return GenerateTrainTableJobAPI(self)

    @property
    def generate_prediction_table_job_api(
            self) -> 'GeneratePredictionTableJobAPI':
        r"""Returns the typed Generate-Prediction-Table Job API."""
        from kumoai.client.jobs import GeneratePredictionTableJobAPI
        return GeneratePredictionTableJobAPI(self)

    @property
    def llm_job_api(self) -> 'LLMJobAPI':
        r"""Returns the typed LLM Job API."""
        from kumoai.client.jobs import LLMJobAPI
        return LLMJobAPI(self)

    @property
    def baseline_job_api(self) -> 'BaselineJobAPI':
        r"""Returns the typed Baseline Job API."""
        from kumoai.client.jobs import BaselineJobAPI
        return BaselineJobAPI(self)

    @property
    def online_serving_endpoint_api(self) -> 'OnlineServingEndpointAPI':
        r"""Returns the typed Online Serving Endpoint API."""
        from kumoai.client.online import OnlineServingEndpointAPI
        return OnlineServingEndpointAPI(self)

    @property
    def rfm_api(self) -> 'RFMAPI':
        r"""Returns the typed RFM API."""
        from kumoai.client.rfm import RFMAPI
        return RFMAPI(self)

    def _request(self, endpoint: Endpoint, **kwargs: Any) -> requests.Response:
        r"""Send a HTTP request to the specified endpoint.

        Args:
            endpoint: the endpoint supplying both the path and the HTTP
                method.
            **kwargs: forwarded unchanged to the matching session call.

        Raises:
            ValueError: if the endpoint has no path, or its method is not one
                of GET, POST, PATCH or DELETE.
        """
        # Resolve the path first so a path-less endpoint fails before the
        # method is inspected.
        endpoint_str = endpoint.get_path()
        senders = {
            HTTPMethod.GET: self._get,
            HTTPMethod.POST: self._post,
            HTTPMethod.PATCH: self._patch,
            HTTPMethod.DELETE: self._delete,
        }
        send = senders.get(endpoint.method)
        if send is None:
            raise ValueError(f"Unsupported HTTP method: {endpoint.method}")
        return send(endpoint_str, **kwargs)

    def _get(self, endpoint: str, **kwargs: Any) -> requests.Response:
        r"""Send a GET request to the specified endpoint, with keyword
        arguments, returned objects, and exceptions raised corresponding to
        :meth:`requests.Session.get`.
        """
        url = self._format_endpoint_url(endpoint)
        return self._session.get(url=url, verify=self._verify_ssl, **kwargs)

    def _post(self, endpoint: str, **kwargs: Any) -> requests.Response:
        r"""Send a POST request to the specified endpoint, with keyword
        arguments, returned objects, and exceptions raised corresponding to
        :meth:`requests.Session.post`.
        """
        url = self._format_endpoint_url(endpoint)
        return self._session.post(url=url, verify=self._verify_ssl, **kwargs)

    def _patch(self, endpoint: str, **kwargs: Any) -> requests.Response:
        r"""Send a PATCH request to the specified endpoint, with keyword
        arguments, returned objects, and exceptions raised corresponding to
        :meth:`requests.Session.patch`.
        """
        url = self._format_endpoint_url(endpoint)
        return self._session.patch(url=url, verify=self._verify_ssl, **kwargs)

    def _delete(self, endpoint: str, **kwargs: Any) -> requests.Response:
        r"""Send a DELETE request to the specified endpoint, with keyword
        arguments, returned objects, and exceptions raised corresponding to
        :meth:`requests.Session.delete`.
        """
        url = self._format_endpoint_url(endpoint)
        return self._session.delete(url=url, verify=self._verify_ssl, **kwargs)

    def _format_endpoint_url(self, endpoint: str) -> str:
        r"""Joins an endpoint path onto the versioned API URL.

        A single leading ``/`` on :obj:`endpoint` is dropped so the result
        contains exactly one separator after the version segment.
        """
        # `startswith` instead of `endpoint[0]`: an empty endpoint must not
        # raise IndexError.
        if endpoint.startswith("/"):
            endpoint = endpoint[1:]
        return f"{self._api_url}/{endpoint}"
@@ -0,0 +1,110 @@
1
+ from http import HTTPStatus
2
+ from typing import List, Optional
3
+
4
+ from kumoapi.data_source import (
5
+ CompleteFileUploadRequest,
6
+ ConnectorResponse,
7
+ CreateConnectorArgs,
8
+ DataSourceType,
9
+ DeleteUploadedFileRequest,
10
+ StartFileUploadRequest,
11
+ StartFileUploadResponse,
12
+ )
13
+ from kumoapi.json_serde import to_json_dict
14
+
15
+ from kumoai.client import KumoClient
16
+ from kumoai.client.endpoints import ConnectorEndpoints
17
+ from kumoai.client.utils import parse_response, raise_on_error
18
+ from kumoai.exceptions import HTTPException
19
+
20
+
21
class ConnectorAPI:
    r"""Typed API definition for Kumo connectors."""
    def __init__(self, client: KumoClient) -> None:
        self._client = client

    def start_file_upload(
            self, req: StartFileUploadRequest) -> StartFileUploadResponse:
        r"""Posts a start-file-upload request and returns the parsed
        response.
        """
        response = self._client._request(
            ConnectorEndpoints.start_file_upload,
            json=to_json_dict(req, insecure=True),
        )
        raise_on_error(response)
        return parse_response(StartFileUploadResponse, response)

    def delete_file_upload(self, req: DeleteUploadedFileRequest) -> None:
        r"""Posts a delete-uploaded-file request."""
        response = self._client._request(
            ConnectorEndpoints.delete_uploaded_file,
            json=to_json_dict(req, insecure=True),
        )
        raise_on_error(response)

    def complete_file_upload(self, req: CompleteFileUploadRequest) -> None:
        r"""Posts a complete-file-upload request."""
        response = self._client._request(
            ConnectorEndpoints.complete_file_upload,
            json=to_json_dict(req, insecure=True),
        )
        raise_on_error(response)

    def create(self, create_connector_args: CreateConnectorArgs) -> None:
        r"""Creates a connector in Kumo."""
        response = self._client._request(
            ConnectorEndpoints.create,
            json=to_json_dict(create_connector_args, insecure=True),
        )
        raise_on_error(response)

    def create_if_not_exist(
        self,
        create_connector_args: CreateConnectorArgs,
    ) -> bool:
        r"""Creates a connector in Kumo if the connector does not exist.

        Returns:
            :obj:`True` if the connector is newly created, :obj:`False`
            otherwise.

        Raises:
            HTTPException: if a connector of the same name exists with a
                different configuration.
        """
        name = create_connector_args.config.name
        current = self.get(name)
        if not current:
            # Nothing registered under this name yet.
            self.create(create_connector_args)
            return current is None

        if current.config == create_connector_args.config:
            return False

        raise HTTPException(
            HTTPStatus.UNPROCESSABLE_ENTITY,
            f"Connector {name} already exists, but has a differing "
            f"configuration. Input: {create_connector_args.config}, "
            f"existing: {current.config}",
        )

    def get(self, connector_id: str) -> Optional[ConnectorResponse]:
        r"""Fetches a connector given its ID.

        Returns:
            The parsed connector, or :obj:`None` on a 404 response.
        """
        endpoint = ConnectorEndpoints.get.with_id(connector_id)
        response = self._client._request(endpoint)
        if response.status_code == HTTPStatus.NOT_FOUND:
            return None

        raise_on_error(response)
        return parse_response(ConnectorResponse, response)

    def list(
        self, data_source_type: Optional[DataSourceType] = None
    ) -> List[ConnectorResponse]:
        r"""Lists connectors for a given data source type."""
        # The filter is only sent when a data source type was supplied.
        if data_source_type:
            params = {'data_source_type': data_source_type}
        else:
            params = {}

        response = self._client._request(ConnectorEndpoints.list,
                                         params=params)
        raise_on_error(response)
        return parse_response(List[ConnectorResponse], response)

    def delete_if_exists(self, connector_id: str) -> bool:
        r"""Deletes a connector if it exists.

        Returns:
            :obj:`False` on a 404 response, :obj:`True` once the deletion
            request succeeded.
        """
        endpoint = ConnectorEndpoints.delete.with_id(connector_id)
        response = self._client._request(endpoint)
        if response.status_code != HTTPStatus.NOT_FOUND:
            raise_on_error(response)
            return True
        return False
@@ -0,0 +1,150 @@
1
+ from dataclasses import dataclass, field
2
+ from enum import Enum
3
+ from typing import Final, Optional
4
+
5
+
6
class HTTPMethod(Enum):
    r"""HTTP methods supported by the API."""
    GET = "GET"
    POST = "POST"
    DELETE = "DELETE"
    PATCH = "PATCH"


@dataclass(frozen=True)
class Endpoint:
    r"""Represents an API endpoint with its path and HTTP method."""
    path: Optional[str] = field(default=None)
    method: HTTPMethod = HTTPMethod.GET

    def validate(self) -> None:
        r"""Checks that the endpoint is ready to be requested.

        The base implementation accepts every endpoint; subclasses override
        it to raise :class:`ValueError` with a specific message.
        """

    def get_path(self) -> str:
        r"""Returns the path of this endpoint.

        Raises:
            ValueError: if :meth:`validate` rejects the endpoint or no path
                is set.
        """
        # Run the subclass check first so its more specific message wins.
        self.validate()
        if self.path is None:
            # Kept here (not only in validate) so type checkers can narrow
            # the return type to `str`.
            raise ValueError("Endpoint requires a path")
        return self.path


@dataclass(frozen=True)
class IDEndpoint(Endpoint):
    r"""Represents an API endpoint with an additional id/name in its path.

    An instance starts out as a template (:obj:`path` unset) and is turned
    into a requestable endpoint with :meth:`with_id`.
    """
    template_path: Optional[str] = field(default=None)

    def __post_init__(self) -> None:
        if self.template_path is None:
            raise ValueError("template_path must be set explicitly")
        if "{id}" not in self.template_path:
            # The first literal is a plain string, so braces must not be
            # doubled; only the f-string part interpolates.
            raise ValueError("IDEndpoint path must contain '{id}': "
                             f"got '{self.template_path}'")

    def with_id(self, resource_id: str) -> 'IDEndpoint':
        r"""Returns a copy whose path has ``{id}`` replaced by
        :obj:`resource_id`.
        """
        # Guaranteed by __post_init__; the assert only narrows the type.
        assert self.template_path is not None
        return IDEndpoint(template_path=self.template_path,
                          path=self.template_path.format(id=resource_id),
                          method=self.method)

    def validate(self) -> None:
        r"""Raises :class:`ValueError` if no resource id was bound yet."""
        if self.path is None:
            raise ValueError(
                "IDEndpoint requires with_id() to be called with a resource id"
            )

    @classmethod
    def from_base(cls, base: str, method: HTTPMethod) -> 'IDEndpoint':
        r"""Alternate constructor to build a template_path by appending /{id}
        to base.
        """
        template = f"{base}/{{id}}"
        return cls(template_path=template, method=method)
61
+
62
+
63
class ConnectorEndpoints:
    r"""Endpoint definitions rooted at ``/connectors``."""
    BASE: Final[str] = "/connectors"

    # Operations on the collection itself.
    create = Endpoint(path=BASE, method=HTTPMethod.POST)
    list = Endpoint(path=BASE, method=HTTPMethod.GET)

    # Operations on a single connector; the id is bound via `with_id()`.
    get = IDEndpoint.from_base(base=BASE, method=HTTPMethod.GET)
    delete = IDEndpoint.from_base(base=BASE, method=HTTPMethod.DELETE)

    # File upload utilities.
    start_file_upload = Endpoint(
        path=f"{BASE}/utils/start_file_upload",
        method=HTTPMethod.POST,
    )
    delete_uploaded_file = Endpoint(
        path=f"{BASE}/utils/delete_uploaded_file",
        method=HTTPMethod.POST,
    )
    complete_file_upload = Endpoint(
        path=f"{BASE}/utils/complete_file_upload",
        method=HTTPMethod.POST,
    )
77
+
78
+
79
class PQueryEndpoints:
    r"""Endpoint definitions rooted at ``/predictive_queries``."""
    BASE: Final[str] = '/predictive_queries'

    # Operations on the collection itself.
    create = Endpoint(path=BASE, method=HTTPMethod.POST)
    list = Endpoint(path=BASE, method=HTTPMethod.GET)

    # Operations on a single query; the id is bound via `with_id()`.
    get = IDEndpoint.from_base(base=BASE, method=HTTPMethod.GET)
    delete = IDEndpoint.from_base(base=BASE, method=HTTPMethod.DELETE)

    # Fixed-path helper endpoints, all POST.
    infer_task_type = Endpoint(
        path=f"{BASE}/infer_task_type",
        method=HTTPMethod.POST,
    )
    validate = Endpoint(path=f"{BASE}/validate", method=HTTPMethod.POST)
    suggest_training_table_plan = Endpoint(
        path=f"{BASE}/train_table_generation_plan",
        method=HTTPMethod.POST,
    )
    suggest_model_plan = Endpoint(
        path=f"{BASE}/model_plan",
        method=HTTPMethod.POST,
    )
92
+
93
+
94
class SourceTableEndpoints:
    r"""Endpoint definitions rooted at ``/source_tables``."""
    BASE: Final[str] = "/source_tables"

    # Lookup by id; the id is bound via `with_id()`.
    get = IDEndpoint.from_base(base=BASE, method=HTTPMethod.GET)

    # Fixed-path helper endpoints, all POST.
    list_tables = Endpoint(
        path=f"{BASE}/list_tables",
        method=HTTPMethod.POST,
    )
    validate_table = Endpoint(
        path=f"{BASE}/validate_table",
        method=HTTPMethod.POST,
    )
    get_table_data = Endpoint(
        path=f"{BASE}/get_table_data",
        method=HTTPMethod.POST,
    )
    get_table_config = Endpoint(
        path=f"{BASE}/get_table_config",
        method=HTTPMethod.POST,
    )
103
+
104
+
105
class GraphEndpoints:
    r"""Endpoint definitions rooted at ``/graphs`` and ``/graphsnapshots``."""
    BASE: Final[str] = "/graphs"
    SNAPSHOTS: Final[str] = "/graphsnapshots"

    # Graph definitions.
    create = Endpoint(path=BASE, method=HTTPMethod.POST)
    get = IDEndpoint.from_base(base=BASE, method=HTTPMethod.GET)

    validate = Endpoint(path=f"{BASE}/validate", method=HTTPMethod.POST)
    infer_links = Endpoint(path=f"{BASE}/infer_links", method=HTTPMethod.POST)

    # Graph snapshots; ids are bound via `with_id()`.
    create_snapshot = Endpoint(path=SNAPSHOTS, method=HTTPMethod.POST)
    get_snapshot = IDEndpoint.from_base(base=SNAPSHOTS, method=HTTPMethod.GET)
    get_edge_stats = IDEndpoint.from_base(
        base=f"{SNAPSHOTS}/edge_health",
        method=HTTPMethod.GET,
    )
118
+
119
+
120
class OnlineServingEndpoints:
    r"""Endpoint definitions rooted at ``/online_serving_endpoints``."""
    BASE: Final[str] = "/online_serving_endpoints"

    # Operations on the collection itself.
    create = Endpoint(path=BASE, method=HTTPMethod.POST)
    # Operations on a single resource; the id is bound via `with_id()`.
    get = IDEndpoint.from_base(base=BASE, method=HTTPMethod.GET)
    list = Endpoint(path=BASE, method=HTTPMethod.GET)
    update = IDEndpoint.from_base(base=BASE, method=HTTPMethod.PATCH)
    delete = IDEndpoint.from_base(base=BASE, method=HTTPMethod.DELETE)
128
+
129
+
130
class TableEndpoints:
    r"""Endpoint definitions rooted at ``/tables`` and ``/tablesnapshots``."""
    BASE: Final[str] = "/tables"
    SNAPSHOTS: Final[str] = "/tablesnapshots"

    # Table definitions.
    create = Endpoint(path=BASE, method=HTTPMethod.POST)
    get = IDEndpoint.from_base(base=BASE, method=HTTPMethod.GET)

    # Table snapshots; ids are bound via `with_id()`.
    create_snapshot = Endpoint(path=SNAPSHOTS, method=HTTPMethod.POST)
    get_snapshot = IDEndpoint.from_base(base=SNAPSHOTS, method=HTTPMethod.GET)

    # Fixed-path helper endpoints, all POST.
    validate = Endpoint(path=f"{BASE}/validate", method=HTTPMethod.POST)
    infer_metadata = Endpoint(
        path=f"{BASE}/infer_metadata",
        method=HTTPMethod.POST,
    )
141
+
142
+
143
class RFMEndpoints:
    r"""Endpoint definitions rooted at ``/rfm``; every endpoint is a POST."""
    BASE: Final[str] = "/rfm"

    predict = Endpoint(path=f"{BASE}/predict", method=HTTPMethod.POST)
    explain = Endpoint(path=f"{BASE}/explain", method=HTTPMethod.POST)
    evaluate = Endpoint(path=f"{BASE}/evaluate", method=HTTPMethod.POST)
    validate_query = Endpoint(
        path=f"{BASE}/validate_query",
        method=HTTPMethod.POST,
    )
    parse_query = Endpoint(
        path=f"{BASE}/parse_query",
        method=HTTPMethod.POST,
    )
kumoai/client/graph.py ADDED
@@ -0,0 +1,120 @@
1
+ from http import HTTPStatus
2
+ from typing import Any, Dict, Optional
3
+
4
+ from kumoapi.common import ValidationResponse
5
+ from kumoapi.data_snapshot import GraphSnapshotID, GraphSnapshotResource
6
+ from kumoapi.graph import (
7
+ EdgeHealthResponse,
8
+ GraphDefinition,
9
+ GraphResource,
10
+ GraphValidationRequest,
11
+ )
12
+ from kumoapi.json_serde import to_json_dict
13
+
14
+ from kumoai.client import KumoClient
15
+ from kumoai.client.endpoints import GraphEndpoints
16
+ from kumoai.client.utils import (
17
+ parse_id_response,
18
+ parse_response,
19
+ raise_on_error,
20
+ )
21
+
22
# Alias used in signatures below for the identifier of a graph resource.
GraphID = str
23
+
24
+
25
class GraphAPI:
    r"""Typed API definition for Kumo graph definition."""
    def __init__(self, client: KumoClient) -> None:
        self._client = client

    def create_graph(
        self,
        graph_def: GraphDefinition,
        *,
        name_alias: Optional[str] = None,
        force_rename: bool = False,
    ) -> GraphID:
        r"""Creates a Graph (metadata definition) resource object in Kumo.

        Args:
            graph_def: the graph definition sent as the JSON body.
            name_alias: sent as the ``name_alias`` query parameter, only
                when non-empty.
            force_rename: sent as the ``force_rename`` query parameter.

        Returns:
            The id parsed from the response.
        """
        query: Dict[str, Any] = {'force_rename': force_rename}
        if name_alias:
            query['name_alias'] = name_alias
        response = self._client._request(
            GraphEndpoints.create,
            params=query,
            json=to_json_dict(graph_def),
        )
        raise_on_error(response)
        return parse_id_response(response)

    def get_graph_if_exists(
        self,
        graph_id_or_name: str,
    ) -> Optional[GraphResource]:
        r"""Fetches a graph, returning :obj:`None` on a 404 response."""
        endpoint = GraphEndpoints.get.with_id(graph_id_or_name)
        response = self._client._request(endpoint)
        if response.status_code == HTTPStatus.NOT_FOUND:
            return None

        raise_on_error(response)
        return parse_response(GraphResource, response)

    def create_snapshot(
        self,
        graph_id: GraphID,
        *,
        refresh_source: bool = False,
    ) -> GraphSnapshotID:
        r"""Requests a snapshot of the given graph and returns its id.

        Both arguments are sent as query parameters.
        """
        query: Dict[str, Any] = {
            'graph_id': graph_id,
            'refresh_source': refresh_source,
        }
        response = self._client._request(GraphEndpoints.create_snapshot,
                                         params=query)
        raise_on_error(response)
        snapshot_id = parse_id_response(response)
        return GraphSnapshotID(snapshot_id)

    def get_snapshot(
        self,
        snapshot_id: GraphSnapshotID,
    ) -> GraphSnapshotResource:
        r"""Fetches a graph snapshot given its id."""
        endpoint = GraphEndpoints.get_snapshot.with_id(snapshot_id)
        response = self._client._request(endpoint)
        raise_on_error(response)
        return parse_response(GraphSnapshotResource, response)

    def get_snapshot_if_exists(
        self,
        snapshot_id: GraphSnapshotID,
    ) -> Optional[GraphSnapshotResource]:
        r"""Fetches a graph snapshot, returning :obj:`None` on a 404
        response.
        """
        endpoint = GraphEndpoints.get_snapshot.with_id(snapshot_id)
        response = self._client._request(endpoint)
        if response.status_code == HTTPStatus.NOT_FOUND:
            return None

        raise_on_error(response)
        return parse_response(GraphSnapshotResource, response)

    def get_edge_stats(
        self,
        graph_snapshot_id: GraphSnapshotID,
    ) -> EdgeHealthResponse:
        r"""Fetches edge statistics given a snapshot id"""
        endpoint = GraphEndpoints.get_edge_stats.with_id(graph_snapshot_id)
        response = self._client._request(endpoint)
        raise_on_error(response)
        return parse_response(EdgeHealthResponse, response)

    def validate_graph(
        self,
        request: GraphValidationRequest,
    ) -> ValidationResponse:
        r"""Posts a graph validation request and returns the parsed
        validation response.
        """
        resp = self._client._request(GraphEndpoints.validate,
                                     json=to_json_dict(request))
        raise_on_error(resp)
        return parse_response(ValidationResponse, resp)

    def infer_links(self, graph: GraphDefinition) -> GraphDefinition:
        r"""Posts a graph definition to the link inference endpoint and
        returns the graph definition parsed from the response.
        """
        response = self._client._request(GraphEndpoints.infer_links,
                                         json=to_json_dict(graph))
        raise_on_error(response)
        return parse_response(GraphDefinition, response)