hiddenlayer-sdk 0.1.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- hiddenlayer/__init__.py +109 -0
- hiddenlayer/sdk/__init__.py +0 -0
- hiddenlayer/sdk/constants.py +14 -0
- hiddenlayer/sdk/enterprise/__init__.py +0 -0
- hiddenlayer/sdk/enterprise/enterprise_model_scan_api.py +55 -0
- hiddenlayer/sdk/exceptions.py +12 -0
- hiddenlayer/sdk/models.py +22 -0
- hiddenlayer/sdk/rest/__init__.py +49 -0
- hiddenlayer/sdk/rest/api/__init__.py +7 -0
- hiddenlayer/sdk/rest/api/aidr_predictive_api.py +308 -0
- hiddenlayer/sdk/rest/api/model_scan_api.py +591 -0
- hiddenlayer/sdk/rest/api/sensor_api.py +1966 -0
- hiddenlayer/sdk/rest/api_client.py +770 -0
- hiddenlayer/sdk/rest/api_response.py +21 -0
- hiddenlayer/sdk/rest/configuration.py +445 -0
- hiddenlayer/sdk/rest/exceptions.py +199 -0
- hiddenlayer/sdk/rest/models/__init__.py +30 -0
- hiddenlayer/sdk/rest/models/create_sensor_request.py +95 -0
- hiddenlayer/sdk/rest/models/file_info.py +110 -0
- hiddenlayer/sdk/rest/models/get_multipart_upload_response.py +97 -0
- hiddenlayer/sdk/rest/models/model.py +100 -0
- hiddenlayer/sdk/rest/models/model_query_response.py +101 -0
- hiddenlayer/sdk/rest/models/multipart_upload_part.py +93 -0
- hiddenlayer/sdk/rest/models/scan_model_request.py +87 -0
- hiddenlayer/sdk/rest/models/scan_results_v2.py +108 -0
- hiddenlayer/sdk/rest/models/sensor_sor_query_filter.py +108 -0
- hiddenlayer/sdk/rest/models/sensor_sor_query_request.py +109 -0
- hiddenlayer/sdk/rest/models/submission_response.py +95 -0
- hiddenlayer/sdk/rest/models/submission_v2.py +109 -0
- hiddenlayer/sdk/rest/models/validation_error_model.py +99 -0
- hiddenlayer/sdk/rest/models/validation_error_model_loc_inner.py +138 -0
- hiddenlayer/sdk/rest/rest.py +257 -0
- hiddenlayer/sdk/services/__init__.py +0 -0
- hiddenlayer/sdk/services/aidr_predictive.py +76 -0
- hiddenlayer/sdk/services/model.py +101 -0
- hiddenlayer/sdk/services/model_scan.py +414 -0
- hiddenlayer/sdk/utils.py +92 -0
- hiddenlayer/sdk/version.py +1 -0
- hiddenlayer_sdk-0.1.0.dist-info/LICENSE +201 -0
- hiddenlayer_sdk-0.1.0.dist-info/METADATA +320 -0
- hiddenlayer_sdk-0.1.0.dist-info/RECORD +43 -0
- hiddenlayer_sdk-0.1.0.dist-info/WHEEL +5 -0
- hiddenlayer_sdk-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,138 @@
|
|
1
|
+
# coding: utf-8
|
2
|
+
|
3
|
+
"""
|
4
|
+
HiddenLayer ModelScan
|
5
|
+
|
6
|
+
HiddenLayer ModelScan API for scanning of models
|
7
|
+
|
8
|
+
The version of the OpenAPI document: 1
|
9
|
+
Generated by OpenAPI Generator (https://openapi-generator.tech)
|
10
|
+
|
11
|
+
Do not edit the class manually.
|
12
|
+
""" # noqa: E501
|
13
|
+
|
14
|
+
|
15
|
+
from __future__ import annotations
|
16
|
+
from inspect import getfullargspec
|
17
|
+
import json
|
18
|
+
import pprint
|
19
|
+
import re # noqa: F401
|
20
|
+
from pydantic import BaseModel, ConfigDict, Field, StrictInt, StrictStr, ValidationError, field_validator
|
21
|
+
from typing import Optional
|
22
|
+
from typing import Union, Any, List, Set, TYPE_CHECKING, Optional, Dict
|
23
|
+
from typing_extensions import Literal, Self
|
24
|
+
from pydantic import Field
|
25
|
+
|
26
|
+
# Names of the schemas accepted by ValidationErrorModelLocInner's anyOf.
VALIDATIONERRORMODELLOCINNER_ANY_OF_SCHEMAS = ["int", "str"]
|
27
|
+
|
28
|
+
class ValidationErrorModelLocInner(BaseModel):
    """
    ValidationErrorModelLocInner

    anyOf wrapper around a single value that must be either a ``str`` or an
    ``int``. The wrapped value is stored in ``actual_instance``. The two
    ``anyof_schema_*_validator`` fields are scratch slots, used only to
    type-check a candidate value against each allowed schema.
    """

    # data type: str
    anyof_schema_1_validator: Optional[StrictStr] = None
    # data type: int
    anyof_schema_2_validator: Optional[StrictInt] = None
    # Static type checkers see the precise union. At runtime the field is
    # typed Any, and the anyOf check is done by
    # actual_instance_must_validate_anyof below.
    if TYPE_CHECKING:
        actual_instance: Optional[Union[int, str]] = None
    else:
        actual_instance: Any = None
    # Names of the schemas this anyOf accepts.
    any_of_schemas: Set[str] = { "int", "str" }

    model_config = {
        # Re-validate on attribute assignment. The validator and from_json
        # rely on this so that assigning to a scratch field acts as a type check.
        "validate_assignment": True,
        # NOTE(review): presumably set so pydantic does not reserve the
        # "model_" field-name prefix; confirm against the generator template.
        "protected_namespaces": (),
    }

    def __init__(self, *args, **kwargs) -> None:
        """Build the wrapper from one positional value or from keyword arguments.

        ``ValidationErrorModelLocInner(v)`` stores ``v`` as ``actual_instance``.
        Otherwise the keyword arguments are forwarded unchanged to
        ``BaseModel.__init__``.

        :raises ValueError: if more than one positional argument is given, or
            if positional and keyword arguments are mixed.
        """
        if args:
            if len(args) > 1:
                raise ValueError("If a position argument is used, only 1 is allowed to set `actual_instance`")
            if kwargs:
                raise ValueError("If a position argument is used, keyword arguments cannot be used.")
            super().__init__(actual_instance=args[0])
        else:
            super().__init__(**kwargs)

    @field_validator('actual_instance')
    def actual_instance_must_validate_anyof(cls, v):
        """Accept ``v`` only if it validates as StrictStr or as StrictInt.

        Each schema is tried by assigning ``v`` to the matching scratch field
        of an instance created with ``model_construct()``. Because
        ``validate_assignment`` is enabled in ``model_config``, a type
        mismatch raises on assignment.

        NOTE(review): the scratch fields are ``Optional``, so ``None`` is
        accepted by the first (str) attempt.

        :returns: ``v`` unchanged, as soon as one schema accepts it.
        :raises ValueError: if neither schema accepts ``v``.
        """
        instance = ValidationErrorModelLocInner.model_construct()
        error_messages = []
        # validate data type: str
        try:
            instance.anyof_schema_1_validator = v
            return v
        except (ValidationError, ValueError) as e:
            error_messages.append(str(e))
        # validate data type: int
        try:
            instance.anyof_schema_2_validator = v
            return v
        except (ValidationError, ValueError) as e:
            error_messages.append(str(e))
        if error_messages:
            # no match
            raise ValueError("No match found when setting the actual_instance in ValidationErrorModelLocInner with anyOf schemas: int, str. Details: " + ", ".join(error_messages))
        else:
            # Not reached in practice: each attempt above either returns or
            # appends an error message.
            return v

    @classmethod
    def from_dict(cls, obj: Dict[str, Any]) -> Self:
        """Build an instance from an already-decoded value.

        The value is re-encoded with ``json.dumps`` and passed to
        :meth:`from_json`.

        NOTE(review): annotated as ``Dict[str, Any]``, but only a ``str`` or
        ``int`` payload can satisfy this anyOf.
        """
        return cls.from_json(json.dumps(obj))

    @classmethod
    def from_json(cls, json_str: str) -> Self:
        """Returns the object represented by the json string"""
        # Try each anyOf schema in order (str, then int) and return on the
        # first one that validates.
        # A malformed JSON string makes json.loads raise json.JSONDecodeError,
        # which is a ValueError subclass. It is therefore collected as a
        # per-schema error, and the final "No match found" ValueError is raised.
        instance = cls.model_construct()
        error_messages = []
        # deserialize data into str
        try:
            # validation
            instance.anyof_schema_1_validator = json.loads(json_str)
            # assign value to actual_instance
            instance.actual_instance = instance.anyof_schema_1_validator
            return instance
        except (ValidationError, ValueError) as e:
            error_messages.append(str(e))
        # deserialize data into int
        try:
            # validation
            instance.anyof_schema_2_validator = json.loads(json_str)
            # assign value to actual_instance
            instance.actual_instance = instance.anyof_schema_2_validator
            return instance
        except (ValidationError, ValueError) as e:
            error_messages.append(str(e))

        if error_messages:
            # no match
            raise ValueError("No match found when deserializing the JSON string into ValidationErrorModelLocInner with anyOf schemas: int, str. Details: " + ", ".join(error_messages))
        else:
            # Not reached in practice: each attempt above either returns or
            # appends an error message.
            return instance

    def to_json(self) -> str:
        """Returns the JSON representation of the actual instance"""
        if self.actual_instance is None:
            return "null"

        # Delegate to the wrapped value's own to_json when it has one.
        # Plain str/int values are encoded with json.dumps.
        if hasattr(self.actual_instance, "to_json") and callable(self.actual_instance.to_json):
            return self.actual_instance.to_json()
        else:
            return json.dumps(self.actual_instance)

    def to_dict(self) -> Optional[Union[Dict[str, Any], int, str]]:
        """Returns the dict representation of the actual instance"""
        if self.actual_instance is None:
            return None

        # Delegate to the wrapped value's own to_dict when it has one.
        # Plain str/int values are returned as-is.
        if hasattr(self.actual_instance, "to_dict") and callable(self.actual_instance.to_dict):
            return self.actual_instance.to_dict()
        else:
            return self.actual_instance

    def to_str(self) -> str:
        """Returns the string representation of the actual instance"""
        # Pretty-prints model_dump() of this wrapper.
        return pprint.pformat(self.model_dump())
|
137
|
+
|
138
|
+
|
@@ -0,0 +1,257 @@
|
|
1
|
+
# coding: utf-8
|
2
|
+
|
3
|
+
"""
|
4
|
+
HiddenLayer ModelScan
|
5
|
+
|
6
|
+
HiddenLayer ModelScan API for scanning of models
|
7
|
+
|
8
|
+
The version of the OpenAPI document: 1
|
9
|
+
Generated by OpenAPI Generator (https://openapi-generator.tech)
|
10
|
+
|
11
|
+
Do not edit the class manually.
|
12
|
+
""" # noqa: E501
|
13
|
+
|
14
|
+
|
15
|
+
import io
|
16
|
+
import json
|
17
|
+
import re
|
18
|
+
import ssl
|
19
|
+
|
20
|
+
import urllib3
|
21
|
+
|
22
|
+
from hiddenlayer.sdk.rest.exceptions import ApiException, ApiValueError
|
23
|
+
|
24
|
+
# Proxy URL schemes (lower-case) that is_socks_proxy_url treats as SOCKS.
# RESTClientObject uses urllib3's SOCKSProxyManager for these schemes.
SUPPORTED_SOCKS_PROXIES = {"socks5", "socks5h", "socks4", "socks4a"}
# Alias for urllib3's HTTPResponse class.
# NOTE(review): presumably used for type hints elsewhere in the client.
RESTResponseType = urllib3.HTTPResponse
|
26
|
+
|
27
|
+
|
28
|
+
def is_socks_proxy_url(url):
    """Tell whether *url* names a SOCKS proxy.

    The text before the first ``://`` is treated as the scheme and compared,
    case-insensitively, against ``SUPPORTED_SOCKS_PROXIES``. ``None`` and
    strings without a ``://`` separator are never considered SOCKS URLs.
    """
    if url is None:
        return False
    scheme, separator, _rest = url.partition("://")
    if not separator:
        return False
    return scheme.lower() in SUPPORTED_SOCKS_PROXIES
|
36
|
+
|
37
|
+
|
38
|
+
class RESTResponse(io.IOBase):
    """File-like view over a urllib3 response.

    ``status`` and ``reason`` are copied from the wrapped response at
    construction time. The payload is pulled from ``response.data`` on the
    first :meth:`read` and cached in ``self.data`` afterwards.
    """

    def __init__(self, resp) -> None:
        self.response = resp
        self.status = resp.status
        self.reason = resp.reason
        self.data = None

    def read(self):
        """Return the response payload, fetching it from the wrapped response once."""
        if self.data is not None:
            return self.data
        self.data = self.response.data
        return self.data

    def getheaders(self):
        """Return the header mapping of the wrapped response."""
        headers = self.response.headers
        return headers

    def getheader(self, name, default=None):
        """Look up header *name*, falling back to *default* when it is absent."""
        headers = self.response.headers
        return headers.get(name, default)
|
58
|
+
|
59
|
+
|
60
|
+
class RESTClientObject:
    """HTTP transport for the generated API client, built on urllib3.

    ``__init__`` chooses the pool manager from ``configuration``:

    - a ``SOCKSProxyManager`` when ``configuration.proxy`` is a SOCKS URL;
    - a ``urllib3.ProxyManager`` for any other proxy URL;
    - a plain ``urllib3.PoolManager`` when no proxy is set.
    """

    # HTTP methods accepted by ``request``.
    _ALLOWED_METHODS = ('GET', 'HEAD', 'DELETE', 'POST', 'PUT', 'PATCH', 'OPTIONS')
    # Methods whose payload is taken from ``body`` / ``post_params``.
    _BODY_METHODS = ('POST', 'PUT', 'PATCH', 'OPTIONS', 'DELETE')

    def __init__(self, configuration) -> None:
        # urllib3.PoolManager will pass all kw parameters to connectionpool
        # https://github.com/shazow/urllib3/blob/f9409436f83aeb79fbaf090181cd81b784f1b8ce/urllib3/poolmanager.py#L75 # noqa: E501
        # https://github.com/shazow/urllib3/blob/f9409436f83aeb79fbaf090181cd81b784f1b8ce/urllib3/connectionpool.py#L680 # noqa: E501
        # Custom SSL certificates and client certificates: http://urllib3.readthedocs.io/en/latest/advanced-usage.html # noqa: E501

        # cert_reqs
        if configuration.verify_ssl:
            cert_reqs = ssl.CERT_REQUIRED
        else:
            cert_reqs = ssl.CERT_NONE

        pool_args = {
            "cert_reqs": cert_reqs,
            "ca_certs": configuration.ssl_ca_cert,
            "cert_file": configuration.cert_file,
            "key_file": configuration.key_file,
        }
        if configuration.assert_hostname is not None:
            pool_args['assert_hostname'] = (
                configuration.assert_hostname
            )

        if configuration.retries is not None:
            pool_args['retries'] = configuration.retries

        if configuration.tls_server_name:
            pool_args['server_hostname'] = configuration.tls_server_name

        if configuration.socket_options is not None:
            pool_args['socket_options'] = configuration.socket_options

        if configuration.connection_pool_maxsize is not None:
            pool_args['maxsize'] = configuration.connection_pool_maxsize

        # https pool manager
        self.pool_manager: urllib3.PoolManager

        if configuration.proxy:
            if is_socks_proxy_url(configuration.proxy):
                # Imported here so urllib3's SOCKS support is only loaded when
                # a SOCKS proxy is actually configured.
                from urllib3.contrib.socks import SOCKSProxyManager
                pool_args["proxy_url"] = configuration.proxy
                pool_args["headers"] = configuration.proxy_headers
                self.pool_manager = SOCKSProxyManager(**pool_args)
            else:
                pool_args["proxy_url"] = configuration.proxy
                pool_args["proxy_headers"] = configuration.proxy_headers
                self.pool_manager = urllib3.ProxyManager(**pool_args)
        else:
            self.pool_manager = urllib3.PoolManager(**pool_args)

    @staticmethod
    def _build_timeout(request_timeout):
        """Translate a ``_request_timeout`` argument into a ``urllib3.Timeout``.

        A single number becomes a total timeout. A 2-tuple becomes
        ``(connect, read)`` timeouts. Anything else, including falsy values,
        yields ``None``.
        """
        if not request_timeout:
            return None
        if isinstance(request_timeout, (int, float)):
            return urllib3.Timeout(total=request_timeout)
        if isinstance(request_timeout, tuple) and len(request_timeout) == 2:
            return urllib3.Timeout(
                connect=request_timeout[0],
                read=request_timeout[1]
            )
        return None

    def request(
        self,
        method,
        url,
        headers=None,
        body=None,
        post_params=None,
        _request_timeout=None
    ):
        """Perform requests.

        :param method: http request method
        :param url: http request url
        :param headers: http request headers. The caller's dict is copied and
                        never modified.
        :param body: request json body, for `application/json`
        :param post_params: request post parameters,
                            `application/x-www-form-urlencoded`
                            and `multipart/form-data`. Either a sequence of
                            (name, value) pairs or a dict.
        :param _request_timeout: timeout setting for this request. If one
                                 number provided, it will be total request
                                 timeout. It can also be a pair (tuple) of
                                 (connection, read) timeouts.
        :return: ``RESTResponse`` wrapping the urllib3 response. The content
                 is not preloaded (``preload_content=False``).
        :raises ApiValueError: if ``method`` is not a supported HTTP method,
                               or if both ``body`` and ``post_params`` are given.
        :raises ApiException: with ``status=0`` if the arguments do not match
                              the declared content type, or on an SSL error.
        """
        method = method.upper()
        # Explicit check instead of ``assert``, so it still runs under ``python -O``.
        if method not in self._ALLOWED_METHODS:
            raise ApiValueError(f"Unsupported HTTP method: {method}")

        if post_params and body:
            raise ApiValueError(
                "body parameter cannot be used with post_params parameter."
            )

        post_params = post_params or {}
        # Work on a copy: the multipart branch deletes 'Content-Type', and that
        # must not leak into the dict owned by the caller.
        headers = dict(headers) if headers else {}

        timeout = self._build_timeout(_request_timeout)

        try:
            # For `POST`, `PUT`, `PATCH`, `OPTIONS`, `DELETE`
            if method in self._BODY_METHODS:

                # no content type provided or payload is json
                content_type = headers.get('Content-Type')
                if (
                    not content_type
                    or re.search('json', content_type, re.IGNORECASE)
                ):
                    request_body = None
                    if body is not None:
                        request_body = json.dumps(body)
                    r = self.pool_manager.request(
                        method,
                        url,
                        body=request_body,
                        timeout=timeout,
                        headers=headers,
                        preload_content=False
                    )
                elif content_type == 'application/x-www-form-urlencoded':
                    r = self.pool_manager.request(
                        method,
                        url,
                        fields=post_params,
                        encode_multipart=False,
                        timeout=timeout,
                        headers=headers,
                        preload_content=False
                    )
                elif content_type == 'multipart/form-data':
                    # must del headers['Content-Type'], or the correct
                    # Content-Type which generated by urllib3 will be
                    # overwritten.
                    del headers['Content-Type']
                    # Accept either (name, value) pairs or a mapping of them.
                    # Iterating a dict directly would yield only its keys.
                    pairs = (
                        post_params.items()
                        if isinstance(post_params, dict)
                        else post_params
                    )
                    # Ensures that dict objects are serialized
                    fields = [
                        (name, json.dumps(value)) if isinstance(value, dict) else (name, value)
                        for name, value in pairs
                    ]
                    r = self.pool_manager.request(
                        method,
                        url,
                        fields=fields,
                        encode_multipart=True,
                        timeout=timeout,
                        headers=headers,
                        preload_content=False
                    )
                # Pass a `string` parameter directly in the body to support
                # other content types than JSON when `body` argument is
                # provided in serialized form.
                elif isinstance(body, (str, bytes)):
                    r = self.pool_manager.request(
                        method,
                        url,
                        body=body,
                        timeout=timeout,
                        headers=headers,
                        preload_content=False
                    )
                elif content_type == 'text/plain' and isinstance(body, bool):
                    request_body = "true" if body else "false"
                    r = self.pool_manager.request(
                        method,
                        url,
                        body=request_body,
                        preload_content=False,
                        timeout=timeout,
                        headers=headers)
                else:
                    # Cannot generate the request from given parameters
                    msg = """Cannot prepare a request message for provided
arguments. Please check that your arguments match
declared content type."""
                    raise ApiException(status=0, reason=msg)
            # For `GET`, `HEAD`
            else:
                r = self.pool_manager.request(
                    method,
                    url,
                    fields={},
                    timeout=timeout,
                    headers=headers,
                    preload_content=False
                )
        except urllib3.exceptions.SSLError as e:
            msg = "\n".join([type(e).__name__, str(e)])
            raise ApiException(status=0, reason=msg) from e

        return RESTResponse(r)
|
File without changes
|
@@ -0,0 +1,76 @@
|
|
1
|
+
import base64
|
2
|
+
from datetime import datetime
|
3
|
+
from typing import Any, Dict, List, Optional, Union
|
4
|
+
|
5
|
+
import numpy as np
|
6
|
+
|
7
|
+
from hiddenlayer.sdk.rest.api import AidrPredictiveApi
|
8
|
+
from hiddenlayer.sdk.rest.api_client import ApiClient
|
9
|
+
from hiddenlayer.sdk.rest.models import (
|
10
|
+
SubmissionResponse,
|
11
|
+
SubmissionV2,
|
12
|
+
)
|
13
|
+
|
14
|
+
|
15
|
+
class AIDRPredictive:
    """Client for submitting model inputs and outputs to HiddenLayer AIDR."""

    def __init__(self, api_client: ApiClient) -> None:
        self._aidr_predictive = AidrPredictiveApi(api_client=api_client)

    def submit_vectors(
        self,
        *,
        model_id: str,
        requester_id: str,
        input_vectors: Union[List[float], np.ndarray],
        output: Union[List[float], np.ndarray],
        predictions: Optional[Union[List[float], np.ndarray]] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        event_time: Optional[str] = None,
    ) -> SubmissionResponse:
        """
        Submit feature vectors and model outputs via the HiddenLayer API.

        Inputs and outputs are converted to NumPy arrays and sent as
        base64-encoded raw bytes, together with their dtype and shape.

        :param model_id: Model id.
        :param requester_id: Custom identifier for the inbound request. This should be a value that can be used to identify individual users interacting with the model.
        :param input_vectors: Feature vectors for your model. Any array-like
            (list, tuple, ndarray) is accepted.
        :param output: Output vectors directly from your model. Any array-like
            is accepted. 0-d and 1-d outputs are reshaped to a single column.
        :param predictions: If you ran `np.argmax` or `np.argmin` or provided custom logic onto the model output.
            An ndarray is converted to a list.
        :param tags: Custom tags attached to the request.
        :param metadata: Custom metadata attached to the request.
        :param event_time: Time when the features and outputs were created, defaults to now.

        :returns: Submission Response
        """
        # np.asarray returns an existing ndarray unchanged and converts any
        # other array-like (lists, tuples, scalars) to an ndarray.
        input_vectors = np.asarray(input_vectors)
        output = np.asarray(output)

        # Output vectors need to be at least 2d or AIDR will fail silently.
        # This covers both 1-d vectors and 0-d scalars.
        if output.ndim < 2:
            output = output.reshape(-1, 1)

        if isinstance(predictions, np.ndarray):
            predictions = predictions.tolist()

        input_layer = base64.b64encode(input_vectors.tobytes()).decode()
        output_layer = base64.b64encode(output.tobytes()).decode()

        if not event_time:
            # NOTE(review): naive local time with no UTC offset; confirm which
            # timezone the API expects.
            event_time = datetime.now().isoformat()

        return self._aidr_predictive.submit_vectors(
            SubmissionV2(
                metadata=metadata if metadata else {},
                tags=tags if tags else [],
                sensor_id=model_id,
                requester_id=requester_id,
                input_layer=input_layer,
                input_layer_dtype=str(input_vectors.dtype),
                input_layer_shape=list(input_vectors.shape),
                output_layer=output_layer,
                output_layer_dtype=str(output.dtype),
                output_layer_shape=list(output.shape),
                predictions=predictions,
                event_time=event_time,
            )
        )
|
@@ -0,0 +1,101 @@
|
|
1
|
+
import json
|
2
|
+
from typing import Optional
|
3
|
+
|
4
|
+
from hiddenlayer.sdk.constants import ApiErrors
|
5
|
+
from hiddenlayer.sdk.exceptions import (
|
6
|
+
HiddenlayerConflictError,
|
7
|
+
IncompatibleModelError,
|
8
|
+
ModelDoesNotExistError,
|
9
|
+
)
|
10
|
+
from hiddenlayer.sdk.rest.api import SensorApi
|
11
|
+
from hiddenlayer.sdk.rest.api_client import ApiClient
|
12
|
+
from hiddenlayer.sdk.rest.exceptions import ApiException
|
13
|
+
from hiddenlayer.sdk.rest.models import (
|
14
|
+
CreateSensorRequest,
|
15
|
+
Model,
|
16
|
+
SensorSORQueryFilter,
|
17
|
+
SensorSORQueryRequest,
|
18
|
+
)
|
19
|
+
|
20
|
+
|
21
|
+
class ModelAPI:
    """Create, look up and delete models (backed by sensors) in the HiddenLayer Platform."""

    def __init__(self, api_client: ApiClient) -> None:
        self._sensor_api = SensorApi(api_client=api_client)

    def create(self, *, model_name: str) -> Model:
        """
        Creates a model in the HiddenLayer Platform.

        The model is created as an ad hoc sensor named ``model_name``.

        :param model_name: Name of the model

        :returns: HiddenLayer Model object for the created model
        """
        return self._sensor_api.create_sensor(
            CreateSensorRequest(plaintext_name=model_name, adhoc=True)
        )

    def get(self, *, model_name: str, version: Optional[int] = None) -> Model:
        """
        Gets a HiddenLayer model object. If no version is supplied, the latest model is returned.

        :param model_name: Name of the model.
        :param version: Version of the model to get.

        :returns: HiddenLayer Model object

        :raises ModelDoesNotExistError: if no matching model exists.
        """
        return self._get_model_by_name(model_name=model_name, version=version)

    def delete(self, *, model_name: str) -> None:
        """
        Delete a model.

        The latest version of ``model_name`` is looked up, and its sensor is
        deleted.

        :param model_name: Name of the model.

        :raises ModelDoesNotExistError: if no model with that name exists.
        :raises IncompatibleModelError: if the API reports that this type of
            model cannot be deleted.
        :raises HiddenlayerConflictError: for any other API error whose body
            carries a ``detail`` message.
        :raises ApiException: unchanged, when the error body has no parsable
            ``detail``.
        """
        model = self._get_model_by_name(model_name=model_name, version=None)

        try:
            self._sensor_api.delete_model(sensor_id=model.sensor_id)
        except ApiException as e:
            reason = self._error_detail(e)
            if reason is None:
                # Not a structured error we understand; surface the original.
                raise

            if reason == ApiErrors.NON_ADHOC_SENSOR_DELETE:
                raise IncompatibleModelError(
                    "This type of model is unable to be deleted."
                ) from e

            raise HiddenlayerConflictError(reason) from e

    @staticmethod
    def _error_detail(error: ApiException):
        """Return the ``detail`` field of an API error's JSON body.

        Returns ``None`` when the body is missing, is not valid JSON, is not a
        JSON object, or has no ``detail`` key.
        """
        body = error.body
        if isinstance(body, (bytes, bytearray)):
            body = body.decode("utf-8", errors="replace")
        try:
            return json.loads(body)["detail"]
        except (TypeError, ValueError, KeyError):
            return None

    def _get_model_by_name(
        self, *, model_name: str, version: Optional[int] = None
    ) -> Model:
        """
        Gets a model object by name. If version is not supplied, get the latest.

        :param model_name: Name of the model.
        :param version: Version of the model to get.

        :returns: HiddenLayer Model object

        :raises ModelDoesNotExistError: if the query returns no results.
        """

        models = self._sensor_api.query_sensor(
            sensor_sor_query_request=SensorSORQueryRequest(
                filter=SensorSORQueryFilter(plaintext_name=model_name, version=version)
            )
        )

        if not models.results:
            msg = f"Model {model_name} does not exist"

            if version is not None:
                msg = f"{msg} for version {version}."

            raise ModelDoesNotExistError(msg)

        if version is not None:
            return models.results[0]

        # No version requested: return the highest version.
        # max() picks the first of any ties, like a stable descending sort
        # would, and leaves the response list unmodified.
        return max(models.results, key=lambda m: m.version)
|