hiddenlayer-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hiddenlayer/__init__.py +109 -0
- hiddenlayer/sdk/__init__.py +0 -0
- hiddenlayer/sdk/constants.py +14 -0
- hiddenlayer/sdk/enterprise/__init__.py +0 -0
- hiddenlayer/sdk/enterprise/enterprise_model_scan_api.py +55 -0
- hiddenlayer/sdk/exceptions.py +12 -0
- hiddenlayer/sdk/models.py +22 -0
- hiddenlayer/sdk/rest/__init__.py +49 -0
- hiddenlayer/sdk/rest/api/__init__.py +7 -0
- hiddenlayer/sdk/rest/api/aidr_predictive_api.py +308 -0
- hiddenlayer/sdk/rest/api/model_scan_api.py +591 -0
- hiddenlayer/sdk/rest/api/sensor_api.py +1966 -0
- hiddenlayer/sdk/rest/api_client.py +770 -0
- hiddenlayer/sdk/rest/api_response.py +21 -0
- hiddenlayer/sdk/rest/configuration.py +445 -0
- hiddenlayer/sdk/rest/exceptions.py +199 -0
- hiddenlayer/sdk/rest/models/__init__.py +30 -0
- hiddenlayer/sdk/rest/models/create_sensor_request.py +95 -0
- hiddenlayer/sdk/rest/models/file_info.py +110 -0
- hiddenlayer/sdk/rest/models/get_multipart_upload_response.py +97 -0
- hiddenlayer/sdk/rest/models/model.py +100 -0
- hiddenlayer/sdk/rest/models/model_query_response.py +101 -0
- hiddenlayer/sdk/rest/models/multipart_upload_part.py +93 -0
- hiddenlayer/sdk/rest/models/scan_model_request.py +87 -0
- hiddenlayer/sdk/rest/models/scan_results_v2.py +108 -0
- hiddenlayer/sdk/rest/models/sensor_sor_query_filter.py +108 -0
- hiddenlayer/sdk/rest/models/sensor_sor_query_request.py +109 -0
- hiddenlayer/sdk/rest/models/submission_response.py +95 -0
- hiddenlayer/sdk/rest/models/submission_v2.py +109 -0
- hiddenlayer/sdk/rest/models/validation_error_model.py +99 -0
- hiddenlayer/sdk/rest/models/validation_error_model_loc_inner.py +138 -0
- hiddenlayer/sdk/rest/rest.py +257 -0
- hiddenlayer/sdk/services/__init__.py +0 -0
- hiddenlayer/sdk/services/aidr_predictive.py +76 -0
- hiddenlayer/sdk/services/model.py +101 -0
- hiddenlayer/sdk/services/model_scan.py +414 -0
- hiddenlayer/sdk/utils.py +92 -0
- hiddenlayer/sdk/version.py +1 -0
- hiddenlayer_sdk-0.1.0.dist-info/LICENSE +201 -0
- hiddenlayer_sdk-0.1.0.dist-info/METADATA +320 -0
- hiddenlayer_sdk-0.1.0.dist-info/RECORD +43 -0
- hiddenlayer_sdk-0.1.0.dist-info/WHEEL +5 -0
- hiddenlayer_sdk-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,138 @@
|
|
1
|
+
# coding: utf-8
|
2
|
+
|
3
|
+
"""
|
4
|
+
HiddenLayer ModelScan
|
5
|
+
|
6
|
+
HiddenLayer ModelScan API for scanning of models
|
7
|
+
|
8
|
+
The version of the OpenAPI document: 1
|
9
|
+
Generated by OpenAPI Generator (https://openapi-generator.tech)
|
10
|
+
|
11
|
+
Do not edit the class manually.
|
12
|
+
""" # noqa: E501
|
13
|
+
|
14
|
+
|
15
|
+
from __future__ import annotations
|
16
|
+
from inspect import getfullargspec
|
17
|
+
import json
|
18
|
+
import pprint
|
19
|
+
import re # noqa: F401
|
20
|
+
from pydantic import BaseModel, ConfigDict, Field, StrictInt, StrictStr, ValidationError, field_validator
|
21
|
+
from typing import Optional
|
22
|
+
from typing import Union, Any, List, Set, TYPE_CHECKING, Optional, Dict
|
23
|
+
from typing_extensions import Literal, Self
|
24
|
+
from pydantic import Field
|
25
|
+
|
26
|
+
VALIDATIONERRORMODELLOCINNER_ANY_OF_SCHEMAS = ["int", "str"]


class ValidationErrorModelLocInner(BaseModel):
    """
    anyOf wrapper holding a single ``str`` or ``int`` value.

    The wrapped value is stored in ``actual_instance``. The two
    ``anyof_schema_*_validator`` fields are scratch slots: assigning a
    candidate value to one of them (with ``validate_assignment`` enabled)
    checks it against that slot's strict type. ``str`` is always tried
    before ``int``.
    """

    # data type: str
    anyof_schema_1_validator: Optional[StrictStr] = None
    # data type: int
    anyof_schema_2_validator: Optional[StrictInt] = None
    if TYPE_CHECKING:
        actual_instance: Optional[Union[int, str]] = None
    else:
        actual_instance: Any = None
    any_of_schemas: Set[str] = {"int", "str"}

    model_config = {
        "validate_assignment": True,
        "protected_namespaces": (),
    }

    def __init__(self, *args, **kwargs) -> None:
        """
        Build the wrapper either from keyword fields or from exactly one
        positional value, which becomes ``actual_instance``.

        :raises ValueError: more than one positional argument, or positional
            and keyword arguments mixed.
        """
        if not args:
            super().__init__(**kwargs)
            return
        if len(args) > 1:
            raise ValueError("If a position argument is used, only 1 is allowed to set `actual_instance`")
        if kwargs:
            raise ValueError("If a position argument is used, keyword arguments cannot be used.")
        super().__init__(actual_instance=args[0])

    @field_validator('actual_instance')
    def actual_instance_must_validate_anyof(cls, v):
        """
        Accept ``v`` if it passes the strict ``str`` check or, failing that,
        the strict ``int`` check. Otherwise raise ``ValueError`` listing
        both failures.
        """
        probe = ValidationErrorModelLocInner.model_construct()
        failures = []
        # Order matters: str is tried first, then int.
        for slot in ("anyof_schema_1_validator", "anyof_schema_2_validator"):
            try:
                setattr(probe, slot, v)
            except (ValidationError, ValueError) as exc:
                failures.append(str(exc))
            else:
                return v
        # no match
        raise ValueError("No match found when setting the actual_instance in ValidationErrorModelLocInner with anyOf schemas: int, str. Details: " + ", ".join(failures))

    @classmethod
    def from_dict(cls, obj: Dict[str, Any]) -> Self:
        """Round-trip ``obj`` through JSON and delegate to ``from_json``."""
        serialized = json.dumps(obj)
        return cls.from_json(serialized)

    @classmethod
    def from_json(cls, json_str: str) -> Self:
        """Returns the object represented by the json string"""
        instance = cls.model_construct()
        failures = []
        for slot in ("anyof_schema_1_validator", "anyof_schema_2_validator"):
            try:
                # Parsing happens inside the try so invalid JSON
                # (json.JSONDecodeError is a ValueError) is recorded as a
                # failure for this schema instead of escaping.
                setattr(instance, slot, json.loads(json_str))
                instance.actual_instance = getattr(instance, slot)
            except (ValidationError, ValueError) as exc:
                failures.append(str(exc))
            else:
                return instance
        # no match
        raise ValueError("No match found when deserializing the JSON string into ValidationErrorModelLocInner with anyOf schemas: int, str. Details: " + ", ".join(failures))

    def to_json(self) -> str:
        """Returns the JSON representation of the actual instance"""
        value = self.actual_instance
        if value is None:
            return "null"
        serializer = getattr(value, "to_json", None)
        if callable(serializer):
            return serializer()
        return json.dumps(value)

    def to_dict(self) -> Optional[Union[Dict[str, Any], int, str]]:
        """Returns the dict representation of the actual instance"""
        value = self.actual_instance
        if value is None:
            return None
        converter = getattr(value, "to_dict", None)
        return converter() if callable(converter) else value

    def to_str(self) -> str:
        """Returns the string representation of the actual instance"""
        dumped = self.model_dump()
        return pprint.pformat(dumped)
|
137
|
+
|
138
|
+
|
@@ -0,0 +1,257 @@
|
|
1
|
+
# coding: utf-8
|
2
|
+
|
3
|
+
"""
|
4
|
+
HiddenLayer ModelScan
|
5
|
+
|
6
|
+
HiddenLayer ModelScan API for scanning of models
|
7
|
+
|
8
|
+
The version of the OpenAPI document: 1
|
9
|
+
Generated by OpenAPI Generator (https://openapi-generator.tech)
|
10
|
+
|
11
|
+
Do not edit the class manually.
|
12
|
+
""" # noqa: E501
|
13
|
+
|
14
|
+
|
15
|
+
import io
|
16
|
+
import json
|
17
|
+
import re
|
18
|
+
import ssl
|
19
|
+
|
20
|
+
import urllib3
|
21
|
+
|
22
|
+
from hiddenlayer.sdk.rest.exceptions import ApiException, ApiValueError
|
23
|
+
|
24
|
+
# Proxy URL schemes (compared lower-cased) that is_socks_proxy_url treats as SOCKS.
SUPPORTED_SOCKS_PROXIES = {"socks4", "socks4a", "socks5", "socks5h"}
# Alias for urllib3's HTTP response class.
RESTResponseType = urllib3.HTTPResponse
|
26
|
+
|
27
|
+
|
28
|
+
def is_socks_proxy_url(url):
    """Return True when the scheme of *url* names a supported SOCKS proxy.

    The scheme is the text before the first ``://``. It is compared
    case-insensitively against ``SUPPORTED_SOCKS_PROXIES``. ``None`` and
    strings without a ``://`` separator always yield False.
    """
    if url is None:
        return False
    scheme, separator, _ = url.partition("://")
    if not separator:
        return False
    return scheme.lower() in SUPPORTED_SOCKS_PROXIES
|
36
|
+
|
37
|
+
|
38
|
+
class RESTResponse(io.IOBase):
    """File-like wrapper around a low-level HTTP response.

    Exposes ``status`` and ``reason``, and reads the body lazily: the first
    ``read()`` copies ``response.data`` into ``self.data``, and later calls
    return that cached value.
    """

    def __init__(self, resp) -> None:
        self.response = resp
        self.status, self.reason = resp.status, resp.reason
        # Filled in on the first read().
        self.data = None

    def read(self):
        """Return the response body, fetching it from the response once."""
        cached = self.data
        if cached is None:
            cached = self.data = self.response.data
        return cached

    def getheaders(self):
        """Returns a dictionary of the response headers."""
        headers = self.response.headers
        return headers

    def getheader(self, name, default=None):
        """Returns a given response header."""
        headers = self.response.headers
        return headers.get(name, default)
|
58
|
+
|
59
|
+
|
60
|
+
class RESTClientObject:
    """HTTP transport built on a urllib3 pool manager.

    The pool manager is configured from ``configuration``: SSL verification
    and certificates, hostname checks, retries, socket options and pool
    size. It is a ``SOCKSProxyManager`` or ``ProxyManager`` when
    ``configuration.proxy`` is set.
    """

    def __init__(self, configuration) -> None:
        # urllib3.PoolManager will pass all kw parameters to connectionpool
        # https://github.com/shazow/urllib3/blob/f9409436f83aeb79fbaf090181cd81b784f1b8ce/urllib3/poolmanager.py#L75  # noqa: E501
        # https://github.com/shazow/urllib3/blob/f9409436f83aeb79fbaf090181cd81b784f1b8ce/urllib3/connectionpool.py#L680  # noqa: E501
        # Custom SSL certificates and client certificates: http://urllib3.readthedocs.io/en/latest/advanced-usage.html  # noqa: E501

        # cert_reqs
        if configuration.verify_ssl:
            cert_reqs = ssl.CERT_REQUIRED
        else:
            cert_reqs = ssl.CERT_NONE

        pool_args = {
            "cert_reqs": cert_reqs,
            "ca_certs": configuration.ssl_ca_cert,
            "cert_file": configuration.cert_file,
            "key_file": configuration.key_file,
        }
        if configuration.assert_hostname is not None:
            pool_args['assert_hostname'] = configuration.assert_hostname

        if configuration.retries is not None:
            pool_args['retries'] = configuration.retries

        if configuration.tls_server_name:
            pool_args['server_hostname'] = configuration.tls_server_name

        if configuration.socket_options is not None:
            pool_args['socket_options'] = configuration.socket_options

        if configuration.connection_pool_maxsize is not None:
            pool_args['maxsize'] = configuration.connection_pool_maxsize

        # https pool manager
        self.pool_manager: urllib3.PoolManager

        if configuration.proxy:
            if is_socks_proxy_url(configuration.proxy):
                from urllib3.contrib.socks import SOCKSProxyManager
                pool_args["proxy_url"] = configuration.proxy
                # SOCKSProxyManager takes the proxy headers as ``headers``.
                pool_args["headers"] = configuration.proxy_headers
                self.pool_manager = SOCKSProxyManager(**pool_args)
            else:
                pool_args["proxy_url"] = configuration.proxy
                pool_args["proxy_headers"] = configuration.proxy_headers
                self.pool_manager = urllib3.ProxyManager(**pool_args)
        else:
            self.pool_manager = urllib3.PoolManager(**pool_args)

    def request(
        self,
        method,
        url,
        headers=None,
        body=None,
        post_params=None,
        _request_timeout=None
    ):
        """Perform requests.

        :param method: http request method (case-insensitive)
        :param url: http request url
        :param headers: http request headers. The caller's dict is never
                        modified.
        :param body: request json body, for `application/json`
        :param post_params: request post parameters,
                            `application/x-www-form-urlencoded`
                            and `multipart/form-data`. For multipart
                            requests either a list of ``(name, value)``
                            pairs or a dict is accepted.
        :param _request_timeout: timeout setting for this request. If one
                                 number provided, it will be total request
                                 timeout. It can also be a pair (tuple) of
                                 (connection, read) timeouts.
        :returns: RESTResponse wrapping the un-preloaded urllib3 response.
        :raises ApiValueError: unsupported HTTP method, or both ``body`` and
                               ``post_params`` given.
        :raises ApiException: SSL failure, or arguments that do not match the
                              declared Content-Type (``status=0``).
        """
        method = method.upper()
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would let an unknown verb fall through to
        # the GET/HEAD branch below.
        if method not in ('GET', 'HEAD', 'DELETE', 'POST', 'PUT', 'PATCH', 'OPTIONS'):
            raise ApiValueError(f"Unsupported HTTP method: {method}")

        if post_params and body:
            raise ApiValueError(
                "body parameter cannot be used with post_params parameter."
            )

        post_params = post_params or {}
        headers = headers or {}

        timeout = None
        if _request_timeout:
            if isinstance(_request_timeout, (int, float)):
                timeout = urllib3.Timeout(total=_request_timeout)
            elif (
                isinstance(_request_timeout, tuple)
                and len(_request_timeout) == 2
            ):
                timeout = urllib3.Timeout(
                    connect=_request_timeout[0],
                    read=_request_timeout[1]
                )

        try:
            # For `POST`, `PUT`, `PATCH`, `OPTIONS`, `DELETE`
            if method in ('POST', 'PUT', 'PATCH', 'OPTIONS', 'DELETE'):

                # no content type provided or payload is json
                content_type = headers.get('Content-Type')
                if (
                    not content_type
                    or re.search('json', content_type, re.IGNORECASE)
                ):
                    request_body = None
                    if body is not None:
                        request_body = json.dumps(body)
                    r = self.pool_manager.request(
                        method,
                        url,
                        body=request_body,
                        timeout=timeout,
                        headers=headers,
                        preload_content=False
                    )
                elif content_type == 'application/x-www-form-urlencoded':
                    r = self.pool_manager.request(
                        method,
                        url,
                        fields=post_params,
                        encode_multipart=False,
                        timeout=timeout,
                        headers=headers,
                        preload_content=False
                    )
                elif content_type == 'multipart/form-data':
                    # urllib3 generates its own multipart Content-Type
                    # (including the boundary), so the caller's value must
                    # not be sent. Build a filtered copy instead of
                    # ``del headers[...]`` so the caller's dict is untouched.
                    headers = {
                        key: value
                        for key, value in headers.items()
                        if key.lower() != 'content-type'
                    }
                    # Accept a mapping as well as a list of (name, value)
                    # pairs; iterating a dict directly would yield only keys.
                    pairs = (
                        post_params.items()
                        if isinstance(post_params, dict)
                        else post_params
                    )
                    # Ensures that dict objects are serialized
                    fields = [
                        (name, json.dumps(value)) if isinstance(value, dict) else (name, value)
                        for name, value in pairs
                    ]
                    r = self.pool_manager.request(
                        method,
                        url,
                        fields=fields,
                        encode_multipart=True,
                        timeout=timeout,
                        headers=headers,
                        preload_content=False
                    )
                # Pass a `string` parameter directly in the body to support
                # other content types than JSON when `body` argument is
                # provided in serialized form.
                elif isinstance(body, (str, bytes)):
                    r = self.pool_manager.request(
                        method,
                        url,
                        body=body,
                        timeout=timeout,
                        headers=headers,
                        preload_content=False
                    )
                elif content_type == 'text/plain' and isinstance(body, bool):
                    request_body = "true" if body else "false"
                    r = self.pool_manager.request(
                        method,
                        url,
                        body=request_body,
                        preload_content=False,
                        timeout=timeout,
                        headers=headers)
                else:
                    # Cannot generate the request from given parameters
                    msg = """Cannot prepare a request message for provided
                             arguments. Please check that your arguments match
                             declared content type."""
                    raise ApiException(status=0, reason=msg)
            # For `GET`, `HEAD`
            else:
                r = self.pool_manager.request(
                    method,
                    url,
                    fields={},
                    timeout=timeout,
                    headers=headers,
                    preload_content=False
                )
        except urllib3.exceptions.SSLError as e:
            msg = "\n".join([type(e).__name__, str(e)])
            # Chain so the original SSL traceback is preserved.
            raise ApiException(status=0, reason=msg) from e

        return RESTResponse(r)
|
File without changes
|
@@ -0,0 +1,76 @@
|
|
1
|
+
import base64
|
2
|
+
from datetime import datetime
|
3
|
+
from typing import Any, Dict, List, Optional, Union
|
4
|
+
|
5
|
+
import numpy as np
|
6
|
+
|
7
|
+
from hiddenlayer.sdk.rest.api import AidrPredictiveApi
|
8
|
+
from hiddenlayer.sdk.rest.api_client import ApiClient
|
9
|
+
from hiddenlayer.sdk.rest.models import (
|
10
|
+
SubmissionResponse,
|
11
|
+
SubmissionV2,
|
12
|
+
)
|
13
|
+
|
14
|
+
|
15
|
+
class AIDRPredictive:
    """Wrapper around ``AidrPredictiveApi`` for submitting feature vectors and model outputs."""

    def __init__(self, api_client: ApiClient) -> None:
        self._aidr_predictive = AidrPredictiveApi(api_client=api_client)

    @staticmethod
    def _b64encode_array(array: np.ndarray) -> str:
        """Return the raw (C-order) bytes of *array* as a base64 ASCII string."""
        return base64.b64encode(array.tobytes()).decode()

    def submit_vectors(
        self,
        *,
        model_id: str,
        requester_id: str,
        input_vectors: Union[List[float], np.ndarray],
        output: Union[List[float], np.ndarray],
        predictions: Optional[Union[List[float], np.ndarray]] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        event_time: Optional[str] = None,
    ) -> SubmissionResponse:
        """
        Submit feature vectors and model outputs via the HiddenLayer API.

        Each array is sent as base64-encoded raw bytes together with its
        dtype and shape.

        :param model_id: Model id (sent as ``sensor_id``).
        :param requester_id: Custom identifier for the inbound request. This should be a value that can be used to identify individual users interacting with the model.
        :param input_vectors: Feature vectors for your model. Any array-like
            accepted by ``np.asanyarray`` works; ndarrays are used as-is.
        :param output: Output vectors directly from your model. Outputs with
            fewer than 2 dimensions are reshaped to ``(-1, 1)``.
        :param predictions: Result of post-processing the model output
            (e.g. ``np.argmax`` / ``np.argmin`` or custom logic). NumPy
            arrays are converted to Python lists.
        :param tags: Custom tags attached to the request.
        :param metadata: Custom metadata attached to the request.
        :param event_time: Time when the features and outputs were created.
            Defaults to ``datetime.now().isoformat()``, i.e. local time with
            no UTC offset.

        :returns: Submission Response
        """
        # asanyarray returns ndarrays (and subclasses) unchanged, and also
        # converts tuples and other array-likes. Those would otherwise fail
        # on ``.tobytes()``.
        input_array = np.asanyarray(input_vectors)
        output_array = np.asanyarray(output)

        # Output vectors need to be at least 2d or AIDR will fail silently.
        # ``ndim < 2`` also covers 0-d (scalar) outputs, which become (1, 1).
        if output_array.ndim < 2:
            output_array = output_array.reshape(-1, 1)

        # Predictions typically come straight from np.argmax/np.argmin.
        # Send them as a plain list.
        if isinstance(predictions, np.ndarray):
            predictions = predictions.tolist()

        # NOTE(review): the default timestamp is naive local time. Confirm
        # whether the API expects UTC / an explicit offset.
        timestamp = event_time if event_time else datetime.now().isoformat()

        return self._aidr_predictive.submit_vectors(
            SubmissionV2(
                metadata=metadata if metadata else {},
                tags=tags if tags else [],
                sensor_id=model_id,
                requester_id=requester_id,
                input_layer=self._b64encode_array(input_array),
                input_layer_dtype=str(input_array.dtype),
                input_layer_shape=list(input_array.shape),
                output_layer=self._b64encode_array(output_array),
                output_layer_dtype=str(output_array.dtype),
                output_layer_shape=list(output_array.shape),
                predictions=predictions,
                event_time=timestamp,
            )
        )
|
@@ -0,0 +1,101 @@
|
|
1
|
+
import json
|
2
|
+
from typing import Optional
|
3
|
+
|
4
|
+
from hiddenlayer.sdk.constants import ApiErrors
|
5
|
+
from hiddenlayer.sdk.exceptions import (
|
6
|
+
HiddenlayerConflictError,
|
7
|
+
IncompatibleModelError,
|
8
|
+
ModelDoesNotExistError,
|
9
|
+
)
|
10
|
+
from hiddenlayer.sdk.rest.api import SensorApi
|
11
|
+
from hiddenlayer.sdk.rest.api_client import ApiClient
|
12
|
+
from hiddenlayer.sdk.rest.exceptions import ApiException
|
13
|
+
from hiddenlayer.sdk.rest.models import (
|
14
|
+
CreateSensorRequest,
|
15
|
+
Model,
|
16
|
+
SensorSORQueryFilter,
|
17
|
+
SensorSORQueryRequest,
|
18
|
+
)
|
19
|
+
|
20
|
+
|
21
|
+
class ModelAPI:
    """Create, look up and delete models in the HiddenLayer Platform via ``SensorApi``."""

    def __init__(self, api_client: ApiClient) -> None:
        self._sensor_api = SensorApi(api_client=api_client)

    def create(self, *, model_name: str) -> Model:
        """
        Creates a model in the HiddenLayer Platform.

        The model is created with ``adhoc=True``.

        :param model_name: Name of the model.

        :returns: HiddenLayer Model object for the created model.
        """
        return self._sensor_api.create_sensor(
            CreateSensorRequest(plaintext_name=model_name, adhoc=True)
        )

    def get(self, *, model_name: str, version: Optional[int] = None) -> Model:
        """
        Gets a HiddenLayer model object. If no version is supplied, the model
        with the highest version is returned.

        :param model_name: Name of the model.
        :param version: Version of the model to get.

        :returns: HiddenLayer Model object
        :raises ModelDoesNotExistError: If no matching model is found.
        """
        return self._get_model_by_name(model_name=model_name, version=version)

    def delete(self, *, model_name: str) -> None:
        """
        Delete a model.

        Looks up the highest version of the model by name and deletes it by
        its ``sensor_id``.

        :param model_name: Name of the model.

        :raises ModelDoesNotExistError: If no model with that name exists.
        :raises IncompatibleModelError: If the API reports that this type of
            model cannot be deleted.
        :raises HiddenlayerConflictError: For any other API error raised by
            the delete call. The message is the error body's ``detail``
            field when available.
        """
        model = self._get_model_by_name(model_name=model_name, version=None)

        try:
            self._sensor_api.delete_model(sensor_id=model.sensor_id)
        except ApiException as e:
            reason = self._error_detail(e)

            if reason == ApiErrors.NON_ADHOC_SENSOR_DELETE:
                raise IncompatibleModelError(
                    "This type of model is unable to be deleted."
                ) from e
            raise HiddenlayerConflictError(reason) from e

    @staticmethod
    def _error_detail(error: ApiException):
        """
        Return the ``detail`` field of an API error's JSON body.

        Falls back to ``str(error)`` when the body is missing, is not valid
        JSON, or has no ``detail`` key. A parsing failure therefore never
        replaces the original API error.
        """
        body = error.body
        if not isinstance(body, (str, bytes, bytearray)):
            body = str(body)
        try:
            return json.loads(body)["detail"]
        except (ValueError, TypeError, KeyError):
            # ValueError covers JSONDecodeError/UnicodeDecodeError;
            # TypeError covers non-object JSON (lists, strings, null).
            return str(error)

    def _get_model_by_name(
        self, *, model_name: str, version: Optional[int] = None
    ) -> Model:
        """
        Gets a model object by name. If version is not supplied, get the
        model with the highest version.

        :param model_name: Name of the model.
        :param version: Version of the model to get.

        :returns: HiddenLayer Model object
        :raises ModelDoesNotExistError: If the query returns no results.
        """
        models = self._sensor_api.query_sensor(
            sensor_sor_query_request=SensorSORQueryRequest(
                filter=SensorSORQueryFilter(plaintext_name=model_name, version=version)
            )
        )

        if not models.results:
            msg = f"Model {model_name} does not exist"

            if version is not None:
                msg = f"{msg} for version {version}."

            raise ModelDoesNotExistError(msg)

        if version is not None:
            return models.results[0]

        # max() returns the first highest-version result (same tie-breaking
        # as a stable descending sort) without reordering the response.
        return max(models.results, key=lambda m: m.version)
|