hiddenlayer-sdk 1.2.1__py3-none-any.whl → 2.0.1__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (38) hide show
  1. hiddenlayer/__init__.py +0 -10
  2. hiddenlayer/sdk/exceptions.py +1 -1
  3. hiddenlayer/sdk/models.py +2 -3
  4. hiddenlayer/sdk/rest/__init__.py +16 -11
  5. hiddenlayer/sdk/rest/api/__init__.py +0 -1
  6. hiddenlayer/sdk/rest/api/model_supply_chain_api.py +1706 -571
  7. hiddenlayer/sdk/rest/api/sensor_api.py +214 -1320
  8. hiddenlayer/sdk/rest/models/__init__.py +16 -10
  9. hiddenlayer/sdk/rest/models/{scan_model_request.py → begin_multi_file_upload200_response.py} +9 -9
  10. hiddenlayer/sdk/rest/models/{get_multipart_upload_response.py → begin_multipart_file_upload200_response.py} +9 -9
  11. hiddenlayer/sdk/rest/models/{multipart_upload_part.py → begin_multipart_file_upload200_response_parts_inner.py} +11 -10
  12. hiddenlayer/sdk/rest/models/errors_inner.py +91 -0
  13. hiddenlayer/sdk/rest/models/file_details_v3.py +8 -2
  14. hiddenlayer/sdk/rest/models/{scan_results_v2.py → file_result_v3.py} +21 -32
  15. hiddenlayer/sdk/rest/models/{model_scan_api_v3_scan_query200_response.py → get_condensed_model_scan_reports200_response.py} +4 -4
  16. hiddenlayer/sdk/rest/models/inventory_v3.py +97 -0
  17. hiddenlayer/sdk/rest/models/model_inventory_info.py +1 -1
  18. hiddenlayer/sdk/rest/models/{detections.py → multi_file_upload_request_v3.py} +14 -22
  19. hiddenlayer/sdk/rest/models/{model_scan_api_v3_scan_model_version_id_patch200_response.py → notify_model_scan_completed200_response.py} +4 -4
  20. hiddenlayer/sdk/rest/models/pagination_v3.py +95 -0
  21. hiddenlayer/sdk/rest/models/problem_details.py +103 -0
  22. hiddenlayer/sdk/rest/models/scan_detection_v31.py +155 -0
  23. hiddenlayer/sdk/rest/models/scan_model_details_v3.py +1 -1
  24. hiddenlayer/sdk/rest/models/scan_results_map_v3.py +105 -0
  25. hiddenlayer/sdk/rest/models/scan_results_v3.py +120 -0
  26. hiddenlayer/sdk/rest/models/{model.py → sensor.py} +4 -4
  27. hiddenlayer/sdk/rest/models/{model_query_response.py → sensor_query_response.py} +7 -7
  28. hiddenlayer/sdk/services/aidr_predictive.py +57 -3
  29. hiddenlayer/sdk/services/model_scan.py +98 -135
  30. hiddenlayer/sdk/version.py +1 -1
  31. {hiddenlayer_sdk-1.2.1.dist-info → hiddenlayer_sdk-2.0.1.dist-info}/METADATA +12 -2
  32. {hiddenlayer_sdk-1.2.1.dist-info → hiddenlayer_sdk-2.0.1.dist-info}/RECORD +35 -31
  33. hiddenlayer/sdk/rest/api/model_scan_api.py +0 -591
  34. hiddenlayer/sdk/rest/models/scan_results.py +0 -118
  35. hiddenlayer/sdk/services/model.py +0 -149
  36. {hiddenlayer_sdk-1.2.1.dist-info → hiddenlayer_sdk-2.0.1.dist-info}/LICENSE +0 -0
  37. {hiddenlayer_sdk-1.2.1.dist-info → hiddenlayer_sdk-2.0.1.dist-info}/WHEEL +0 -0
  38. {hiddenlayer_sdk-1.2.1.dist-info → hiddenlayer_sdk-2.0.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,120 @@
1
+ # coding: utf-8
2
+
3
+ """
4
+ HiddenLayer ModelScan V2
5
+
6
+ HiddenLayer ModelScan API for scanning of models
7
+
8
+ The version of the OpenAPI document: 1
9
+ Generated by OpenAPI Generator (https://openapi-generator.tech)
10
+
11
+ Do not edit the class manually.
12
+ """ # noqa: E501
13
+
14
+
15
+ from __future__ import annotations
16
+ import pprint
17
+ import re # noqa: F401
18
+ import json
19
+
20
+ from pydantic import BaseModel, ConfigDict, StrictInt, StrictStr, field_validator
21
+ from typing import Any, ClassVar, Dict, List, Optional
22
+ from hiddenlayer.sdk.rest.models.file_result_v3 import FileResultV3
23
+ from typing import Optional, Set
24
+ from typing_extensions import Self
25
+
26
class ScanResultsV3(BaseModel):
    """
    ScanResultsV3: the outcome of one model scan plus its per-file results.

    This class originates from OpenAPI Generator output. Two defects of the
    generated serializers are corrected here, both caused by `inventory`
    being declared as a plain ``Dict[str, Any]``:

    * ``to_dict()`` used to call ``self.inventory.to_dict()``, which raises
      ``AttributeError`` on a dict.
    * ``from_dict()`` used to call ``InventoryV3.from_dict(...)`` although
      ``InventoryV3`` is not imported by this module (``NameError``).

    If the generator is re-run, these corrections need to be re-applied or
    fixed in the OpenAPI document itself.
    """ # noqa: E501
    scan_id: Optional[StrictStr] = None
    start_time: Optional[StrictInt] = None
    end_time: Optional[StrictInt] = None
    # Restricted to the values checked by `status_validate_enum`.
    status: Optional[StrictStr] = None
    version: Optional[StrictStr] = None
    # Free-form mapping; handled as a plain dict when (de)serializing.
    inventory: Optional[Dict[str, Any]] = None
    # Required field: validation fails when it is missing or None.
    file_results: List[FileResultV3]
    __properties: ClassVar[List[str]] = ["scan_id", "start_time", "end_time", "status", "version", "inventory", "file_results"]

    @field_validator('status')
    def status_validate_enum(cls, value):
        """Validates the enum"""
        if value is None:
            return value

        if value not in {'done', 'running', 'failed', 'pending', 'canceled'}:
            raise ValueError("must be one of enum values ('done', 'running', 'failed', 'pending', 'canceled')")
        return value

    model_config = ConfigDict(
        populate_by_name=True,
        validate_assignment=True,
        protected_namespaces=(),
    )


    def to_str(self) -> str:
        """Returns the string representation of the model using alias"""
        return pprint.pformat(self.model_dump(by_alias=True))

    def to_json(self) -> str:
        """Returns the JSON representation of the model using alias"""
        # TODO: pydantic v2: use .model_dump_json(by_alias=True, exclude_unset=True) instead
        return json.dumps(self.to_dict())

    @classmethod
    def from_json(cls, json_str: str) -> Optional[Self]:
        """Create an instance of ScanResultsV3 from a JSON string"""
        return cls.from_dict(json.loads(json_str))

    def to_dict(self) -> Dict[str, Any]:
        """Return the dictionary representation of the model using alias.

        This has the following differences from calling pydantic's
        `self.model_dump(by_alias=True)`:

        * `None` is only added to the output dict for nullable fields that
          were set at model initialization. Other fields with value `None`
          are ignored.
        * Each entry of `file_results` is rendered through its own
          `to_dict()`.
        """
        excluded_fields: Set[str] = set()

        _dict = self.model_dump(
            by_alias=True,
            exclude=excluded_fields,
            exclude_none=True,
        )
        # `inventory` is a plain dict, so the value produced by `model_dump`
        # is kept as-is; a dict has no `to_dict()` to delegate to.

        # override the default output from pydantic by calling `to_dict()` of each item in file_results (list)
        if self.file_results:
            _dict['file_results'] = [
                _item.to_dict() for _item in self.file_results if _item
            ]
        return _dict

    @classmethod
    def from_dict(cls, obj: Optional[Dict[str, Any]]) -> Optional[Self]:
        """Create an instance of ScanResultsV3 from a dict

        Returns None when `obj` is None. A non-dict `obj` is handed to
        `model_validate` unchanged. Pydantic validation errors propagate to
        the caller (for example when `file_results` is absent).
        """
        if obj is None:
            return None

        if not isinstance(obj, dict):
            return cls.model_validate(obj)

        _file_results = obj.get("file_results")
        _obj = cls.model_validate({
            "scan_id": obj.get("scan_id"),
            "start_time": obj.get("start_time"),
            "end_time": obj.get("end_time"),
            "status": obj.get("status"),
            "version": obj.get("version"),
            # Passed through untouched: the field is declared as Dict[str, Any].
            "inventory": obj.get("inventory"),
            "file_results": [FileResultV3.from_dict(_item) for _item in _file_results] if _file_results is not None else None
        })
        return _obj
119
+
120
+
@@ -23,9 +23,9 @@ from typing import Any, ClassVar, Dict, List, Optional
23
23
  from typing import Optional, Set
24
24
  from typing_extensions import Self
25
25
 
26
- class Model(BaseModel):
26
+ class Sensor(BaseModel):
27
27
  """
28
- Model
28
+ Sensor
29
29
  """ # noqa: E501
30
30
  sensor_id: StrictStr
31
31
  created_at: datetime
@@ -54,7 +54,7 @@ class Model(BaseModel):
54
54
 
55
55
  @classmethod
56
56
  def from_json(cls, json_str: str) -> Optional[Self]:
57
- """Create an instance of Model from a JSON string"""
57
+ """Create an instance of Sensor from a JSON string"""
58
58
  return cls.from_dict(json.loads(json_str))
59
59
 
60
60
  def to_dict(self) -> Dict[str, Any]:
@@ -79,7 +79,7 @@ class Model(BaseModel):
79
79
 
80
80
  @classmethod
81
81
  def from_dict(cls, obj: Optional[Dict[str, Any]]) -> Optional[Self]:
82
- """Create an instance of Model from a dict"""
82
+ """Create an instance of Sensor from a dict"""
83
83
  if obj is None:
84
84
  return None
85
85
 
@@ -19,18 +19,18 @@ import json
19
19
 
20
20
  from pydantic import BaseModel, ConfigDict, StrictInt
21
21
  from typing import Any, ClassVar, Dict, List
22
- from hiddenlayer.sdk.rest.models.model import Model
22
+ from hiddenlayer.sdk.rest.models.sensor import Sensor
23
23
  from typing import Optional, Set
24
24
  from typing_extensions import Self
25
25
 
26
- class ModelQueryResponse(BaseModel):
26
+ class SensorQueryResponse(BaseModel):
27
27
  """
28
- ModelQueryResponse
28
+ SensorQueryResponse
29
29
  """ # noqa: E501
30
30
  total_count: StrictInt
31
31
  page_size: StrictInt
32
32
  page_number: StrictInt
33
- results: List[Model]
33
+ results: List[Sensor]
34
34
  __properties: ClassVar[List[str]] = ["total_count", "page_size", "page_number", "results"]
35
35
 
36
36
  model_config = ConfigDict(
@@ -51,7 +51,7 @@ class ModelQueryResponse(BaseModel):
51
51
 
52
52
  @classmethod
53
53
  def from_json(cls, json_str: str) -> Optional[Self]:
54
- """Create an instance of ModelQueryResponse from a JSON string"""
54
+ """Create an instance of SensorQueryResponse from a JSON string"""
55
55
  return cls.from_dict(json.loads(json_str))
56
56
 
57
57
  def to_dict(self) -> Dict[str, Any]:
@@ -83,7 +83,7 @@ class ModelQueryResponse(BaseModel):
83
83
 
84
84
  @classmethod
85
85
  def from_dict(cls, obj: Optional[Dict[str, Any]]) -> Optional[Self]:
86
- """Create an instance of ModelQueryResponse from a dict"""
86
+ """Create an instance of SensorQueryResponse from a dict"""
87
87
  if obj is None:
88
88
  return None
89
89
 
@@ -94,7 +94,7 @@ class ModelQueryResponse(BaseModel):
94
94
  "total_count": obj.get("total_count"),
95
95
  "page_size": obj.get("page_size"),
96
96
  "page_number": obj.get("page_number"),
97
- "results": [Model.from_dict(_item) for _item in obj["results"]] if obj.get("results") is not None else None
97
+ "results": [Sensor.from_dict(_item) for _item in obj["results"]] if obj.get("results") is not None else None
98
98
  })
99
99
  return _obj
100
100
 
@@ -4,22 +4,29 @@ from typing import Any, Dict, List, Optional, Union
4
4
 
5
5
  import numpy as np
6
6
 
7
+ from hiddenlayer.sdk.exceptions import SensorDoesNotExistError
7
8
  from hiddenlayer.sdk.rest.api import AidrPredictiveApi
9
+ from hiddenlayer.sdk.rest.api.sensor_api import SensorApi
8
10
  from hiddenlayer.sdk.rest.api_client import ApiClient
9
11
  from hiddenlayer.sdk.rest.models import (
10
12
  SubmissionResponse,
11
13
  SubmissionV2,
12
14
  )
15
+ from hiddenlayer.sdk.rest.models.create_sensor_request import CreateSensorRequest
16
+ from hiddenlayer.sdk.rest.models.sensor import Sensor
17
+ from hiddenlayer.sdk.rest.models.sensor_sor_query_filter import SensorSORQueryFilter
18
+ from hiddenlayer.sdk.rest.models.sensor_sor_query_request import SensorSORQueryRequest
13
19
 
14
20
 
15
21
  class AIDRPredictive:
16
22
  def __init__(self, api_client: ApiClient) -> None:
23
+ self._sensor_api = SensorApi(api_client=api_client)
17
24
  self._aidr_predictive = AidrPredictiveApi(api_client=api_client)
18
25
 
19
26
  def submit_vectors(
20
27
  self,
21
28
  *,
22
- model_id: str,
29
+ sensor_id: str,
23
30
  requester_id: str,
24
31
  input_vectors: Union[List[float], np.ndarray],
25
32
  output: Union[List[float], np.ndarray],
@@ -31,7 +38,7 @@ class AIDRPredictive:
31
38
  """
32
39
  Submit feature vectors and model outputs via the HiddenLayer API.
33
40
 
34
- :param model_id: Model id.
41
+ :param sensor_id: Sensor id.
35
42
  :param requester_id: Custom identifier for the inbound request. This should be a value that can be used to identify individual users interacting with the model.
36
43
  :param input_vectors: Feature vectors for your model.
37
44
  :param output: Output vectors directly from your model.
@@ -60,7 +67,7 @@ class AIDRPredictive:
60
67
  SubmissionV2(
61
68
  metadata=metadata if metadata else {},
62
69
  tags=tags if tags else [],
63
- sensor_id=model_id,
70
+ sensor_id=sensor_id,
64
71
  requester_id=requester_id,
65
72
  input_layer=input_layer,
66
73
  input_layer_dtype=str(input_vectors.dtype),
@@ -74,3 +81,50 @@ class AIDRPredictive:
74
81
  else str(datetime.now().isoformat()),
75
82
  )
76
83
  )
84
+
85
def create_sensor(self, *, sensor_name: str) -> Sensor:
    """
    Register a new sensor with the HiddenLayer Platform.

    :param sensor_name: Name of the sensor; sent as the request's ``plaintext_name``.

    :returns: HiddenLayer Sensor
    """
    create_request = CreateSensorRequest(plaintext_name=sensor_name)
    return self._sensor_api.create_sensor(create_request)
96
+
97
def get_sensor(self, *, sensor_name: str) -> Sensor:
    """
    Look up a HiddenLayer sensor by its name.

    Thin public wrapper that delegates to ``_get_sensor_by_name``.

    :param sensor_name: Name of the sensor

    :returns: HiddenLayer Sensor
    """
    sensor = self._get_sensor_by_name(sensor_name=sensor_name)
    return sensor
107
+
108
def _get_sensor_by_name(self, *, sensor_name: str) -> Sensor:
    """
    Gets a sensor by name.

    Queries the sensor API with a ``plaintext_name`` filter and, when several
    results come back, picks the one with the highest ``version``.

    :param sensor_name: Name of the sensor.

    :returns: HiddenLayer Sensor with the highest version among the matches
    :raises SensorDoesNotExistError: If the query returns no results.
    """

    sensors = self._sensor_api.query_sensor(
        sensor_sor_query_request=SensorSORQueryRequest(
            filter=SensorSORQueryFilter(plaintext_name=sensor_name)
        )
    )

    # A missing (None) and an empty result list are both "not found".
    if not sensors.results:
        msg = f"Sensor {sensor_name} does not exist"

        raise SensorDoesNotExistError(msg)

    # Same pick as sorting by version in descending order and taking the
    # first element, without reordering the response's list in place.
    return max(sensors.results, key=lambda sensor: sensor.version)
@@ -1,26 +1,18 @@
1
- import json
2
1
  import os
3
2
  import random
4
3
  import tempfile
5
4
  import time
6
- import warnings
7
5
  import zipfile
8
- from datetime import datetime
9
6
  from pathlib import Path
10
7
  from typing import List, Optional, Union
11
- from uuid import uuid4
12
-
13
- from pydantic_core import ValidationError
14
8
 
15
9
  from hiddenlayer.sdk.constants import ScanStatus
16
- from hiddenlayer.sdk.models import EmptyScanResults, Sarif, ScanResults
17
- from hiddenlayer.sdk.rest.api import ModelScanApi, ModelSupplyChainApi, SensorApi
10
+ from hiddenlayer.sdk.models import EmptyScanResults, ScanResults
11
+ from hiddenlayer.sdk.rest.api import ModelSupplyChainApi
18
12
  from hiddenlayer.sdk.rest.api_client import ApiClient
19
- from hiddenlayer.sdk.rest.models import MultipartUploadPart
20
- from hiddenlayer.sdk.rest.models.model import Model
21
- from hiddenlayer.sdk.rest.models.sarif210 import Sarif210
22
- from hiddenlayer.sdk.services.model import ModelAPI
23
- from hiddenlayer.sdk.utils import filter_path_objects, is_saas
13
+ from hiddenlayer.sdk.rest.exceptions import NotFoundException
14
+ from hiddenlayer.sdk.rest.models import MultiFileUploadRequestV3
15
+ from hiddenlayer.sdk.utils import filter_path_objects
24
16
 
25
17
  EXCLUDE_FILE_TYPES = [
26
18
  "*.txt",
@@ -37,22 +29,14 @@ EXCLUDE_FILE_TYPES = [
37
29
  class ModelScanAPI:
38
30
  def __init__(self, api_client: ApiClient) -> None:
39
31
  self._api_client = api_client
40
-
41
32
  self._model_supply_chain_api = ModelSupplyChainApi(api_client=api_client)
42
- self._model_api = ModelAPI(api_client=api_client)
43
- self._sensor_api = SensorApi(
44
- api_client=api_client
45
- ) # lower level api of ModelAPI
46
-
47
- self._model_scan_api = ModelScanApi(api_client=api_client)
48
33
 
49
34
  def scan_file(
50
35
  self,
51
36
  *,
52
37
  model_name: str,
53
38
  model_path: Union[str, os.PathLike],
54
- model_version: Optional[int] = None,
55
- chunk_size: int = 16,
39
+ model_version: str = "1",
56
40
  wait_for_results: bool = True,
57
41
  ) -> ScanResults:
58
42
  """
@@ -69,52 +53,22 @@ class ModelScanAPI:
69
53
 
70
54
  file_path = Path(model_path)
71
55
 
72
- filesize = file_path.stat().st_size
73
- sensor = self._model_api.create_or_get(
74
- model_name=model_name, model_version=model_version
56
+ request = MultiFileUploadRequestV3(
57
+ model_name=model_name,
58
+ model_version=model_version,
59
+ requesting_entity="hiddenlayer-python-sdk",
75
60
  )
76
- upload = self._sensor_api.begin_multipart_upload(sensor.sensor_id, filesize)
77
-
78
- with open(file_path, "rb") as f:
79
- for i in range(0, len(upload.parts), chunk_size):
80
- group: List[MultipartUploadPart] = upload.parts[i : i + chunk_size]
81
- for part in group:
82
- read_amount = part.end_offset - part.start_offset
83
- f.seek(int(part.start_offset))
84
- part_data = f.read(int(read_amount))
85
-
86
- # The SaaS multipart upload returns a upload url for each part
87
- # So there is no specified route
88
- self._api_client.call_api(
89
- "PUT",
90
- part.upload_url,
91
- body=part_data,
92
- header_params={"Content-Type": "application/octet-binary"},
93
- )
94
-
95
- self._sensor_api.complete_multipart_upload(sensor.sensor_id, upload.upload_id)
96
-
97
- self._model_scan_api.scan_model(sensor.sensor_id)
98
-
99
- scan_results = self.get_scan_results(
100
- model_name=model_name, model_version=model_version
61
+ response = self._model_supply_chain_api.begin_multi_file_upload(
62
+ multi_file_upload_request_v3=request
101
63
  )
64
+ scan_id = response.scan_id
65
+ if scan_id is None:
66
+ raise Exception("scan_id must have a value")
102
67
 
103
- base_delay = 0.1 # seconds
104
- retries = 0
105
- if wait_for_results:
106
- print(f"{file_path.name} scan status: {scan_results.status}")
107
- while scan_results.status not in [ScanStatus.DONE, ScanStatus.FAILED]:
108
- retries += 1
109
- delay = base_delay * 2**retries + random.uniform(
110
- 0, 1
111
- ) # exponential back off retry
112
- time.sleep(delay)
113
- scan_results = self.get_scan_results(
114
- model_name=model_name, model_version=model_version
115
- )
116
- print(f"{file_path.name} scan status: {scan_results.status}")
68
+ self._scan_file(scan_id=scan_id, file_path=file_path)
117
69
 
70
+ self._model_supply_chain_api.complete_multi_file_upload(scan_id=scan_id)
71
+ scan_results = self._wait_for_scan_results(scan_id=scan_id)
118
72
  scan_results.file_name = file_path.name
119
73
  scan_results.file_path = str(file_path)
120
74
 
@@ -126,9 +80,8 @@ class ModelScanAPI:
126
80
  model_name: str,
127
81
  bucket: str,
128
82
  key: str,
129
- model_version: Optional[int] = None,
83
+ model_version: str = "1",
130
84
  s3_client: Optional[object] = None,
131
- chunk_size: int = 4,
132
85
  wait_for_results: bool = True,
133
86
  ) -> ScanResults:
134
87
  """
@@ -173,7 +126,6 @@ class ModelScanAPI:
173
126
  model_path=f"/tmp/{file_name}",
174
127
  model_name=model_name,
175
128
  model_version=model_version,
176
- chunk_size=chunk_size,
177
129
  wait_for_results=wait_for_results,
178
130
  )
179
131
 
@@ -184,10 +136,9 @@ class ModelScanAPI:
184
136
  account_url: str,
185
137
  container: str,
186
138
  blob: str,
187
- model_version: Optional[int] = None,
139
+ model_version: str = "1",
188
140
  blob_service_client: Optional[object] = None,
189
141
  credential: Optional[object] = None,
190
- chunk_size: int = 4,
191
142
  wait_for_results: bool = True,
192
143
  ) -> ScanResults:
193
144
  """
@@ -253,7 +204,6 @@ class ModelScanAPI:
253
204
  model_path=f"/tmp/{file_name}",
254
205
  model_name=model_name,
255
206
  model_version=model_version,
256
- chunk_size=chunk_size,
257
207
  wait_for_results=wait_for_results,
258
208
  )
259
209
 
@@ -270,8 +220,6 @@ class ModelScanAPI:
270
220
  ignore_file_patterns: Optional[List[str]] = None,
271
221
  force_download: bool = False,
272
222
  hf_token: Optional[Union[str, bool]] = None,
273
- # HL parameters
274
- chunk_size: int = 4,
275
223
  wait_for_results: bool = True,
276
224
  ) -> ScanResults:
277
225
  """
@@ -317,21 +265,19 @@ class ModelScanAPI:
317
265
  token=hf_token,
318
266
  )
319
267
 
268
+ if revision is None:
269
+ revision = "1"
270
+
320
271
  return self.scan_folder(
321
272
  model_name=model_name or repo_id,
273
+ model_version=revision,
322
274
  path=local_dir,
323
275
  allow_file_patterns=allow_file_patterns,
324
276
  ignore_file_patterns=ignore_file_patterns,
325
- chunk_size=chunk_size,
326
277
  wait_for_results=wait_for_results,
327
278
  )
328
279
 
329
- def get_scan_results(
330
- self,
331
- *,
332
- model_name: str,
333
- model_version: Optional[int] = None,
334
- ) -> ScanResults:
280
+ def get_scan_results(self, *, scan_id: str) -> ScanResults:
335
281
  """
336
282
  Get results from a model scan.
337
283
 
@@ -341,48 +287,17 @@ class ModelScanAPI:
341
287
  :returns: Scan results.
342
288
  """
343
289
 
344
- response = self._sensor_api.sensor_sor_api_v3_model_cards_query_get(
345
- model_name_eq=model_name, limit=1
346
- )
347
- model_id = response.results[0].model_id
348
-
349
- scans = self._model_supply_chain_api.model_scan_api_v3_scan_query(
350
- model_ids=[model_id], latest_per_model_version_only=True
351
- )
352
- if scans.total == 0:
353
- return EmptyScanResults()
354
-
355
- if scans.items is None:
356
- return EmptyScanResults()
357
-
358
- scan = scans.items[0]
359
- if model_version:
360
- scan = next(
361
- (
362
- s
363
- for s in scans.items
364
- if s.inventory.model_version == str(model_version)
365
- ),
366
- None,
367
- )
368
- if not scan:
290
+ try:
291
+ scan_report = self._model_supply_chain_api.get_scan_results(scan_id)
292
+ except NotFoundException:
369
293
  return EmptyScanResults()
370
294
 
371
- scan_report = (
372
- self._model_supply_chain_api.model_scan_api_v3_scan_model_version_id_get(
373
- scan.scan_id
374
- )
375
- )
376
-
377
- return ScanResults.from_scanreportv3(
378
- scan_report_v3=scan_report, model_id=model_id
379
- )
295
+ return ScanResults.from_scanreportv3(scan_report_v3=scan_report)
380
296
 
381
297
  def get_sarif_results(
382
298
  self,
383
299
  *,
384
- model_name: str,
385
- model_version: Optional[int] = None,
300
+ scan_id: str,
386
301
  ) -> Optional[str]:
387
302
  """
388
303
  Get sarif results from a model scan.
@@ -392,34 +307,31 @@ class ModelScanAPI:
392
307
 
393
308
  :returns: Scan results.
394
309
  """
395
- scan = self.get_scan_results(model_name=model_name, model_version=model_version)
396
- if scan.scan_id == "":
397
- return None
398
310
 
399
311
  # Unfortunately, the generated code for the API doesn't directly support modifying the Accept header
400
312
  # in order to enable us to get the Sarif results
401
313
  # Here we will reach in to the request serialization process. The element at index 2 of the tuple is the headers
402
314
  # where we will modify the Accept header to application/sarif+json
403
- request = self._model_supply_chain_api._model_scan_api_v3_scan_model_version_id_get_serialize(
404
- scan.scan_id, None, None, None, None, 0
315
+ request = self._model_supply_chain_api._get_scan_results_serialize(
316
+ scan_id, None, None, None, None, 0
405
317
  )
406
318
  request[2]["Accept"] = "application/sarif+json"
407
319
  response = self._api_client.call_api(*request)
408
320
  response.read()
409
321
 
410
- return self._api_client.response_deserialize(
411
- response_data=response, response_types_map={"200": str}
412
- ).data # type: ignore
322
+ if response.data is None:
323
+ return None
324
+
325
+ return response.data.decode()
413
326
 
414
327
  def scan_folder(
415
328
  self,
416
329
  *,
417
330
  model_name: str,
418
331
  path: Union[str, os.PathLike],
419
- model_version: Optional[int] = None,
332
+ model_version: str = "1",
420
333
  allow_file_patterns: Optional[List[str]] = None,
421
334
  ignore_file_patterns: Optional[List[str]] = None,
422
- chunk_size: int = 4,
423
335
  wait_for_results: bool = True,
424
336
  ) -> ScanResults:
425
337
  """
@@ -437,7 +349,18 @@ class ModelScanAPI:
437
349
  """
438
350
 
439
351
  model_path = Path(path)
440
- filename = tempfile.NamedTemporaryFile().name + ".zip"
352
+
353
+ request = MultiFileUploadRequestV3(
354
+ model_name=model_name,
355
+ model_version=model_version,
356
+ requesting_entity="hiddenlayer-python-sdk",
357
+ )
358
+ response = self._model_supply_chain_api.begin_multi_file_upload(
359
+ multi_file_upload_request_v3=request
360
+ )
361
+ scan_id = response.scan_id
362
+ if scan_id is None:
363
+ raise Exception("scan_id must have a value")
441
364
 
442
365
  ignore_file_patterns = (
443
366
  EXCLUDE_FILE_TYPES + ignore_file_patterns
@@ -451,14 +374,54 @@ class ModelScanAPI:
451
374
  ignore_patterns=ignore_file_patterns,
452
375
  )
453
376
 
454
- with zipfile.ZipFile(filename, "a") as zipf:
455
- for file in files:
456
- zipf.write(file, os.path.relpath(file, model_path))
377
+ for file in files:
378
+ self._scan_file(scan_id=scan_id, file_path=Path(file))
457
379
 
458
- return self.scan_file(
459
- model_name=model_name,
460
- model_version=model_version,
461
- model_path=filename,
462
- chunk_size=chunk_size,
463
- wait_for_results=wait_for_results,
380
+ self._model_supply_chain_api.complete_multi_file_upload(scan_id=scan_id)
381
+ scan_results = self._wait_for_scan_results(scan_id=scan_id)
382
+
383
+ return scan_results
384
+
385
+ def _scan_file(self, *, scan_id: str, file_path: Path):
386
+ filesize = file_path.stat().st_size
387
+ upload = self._model_supply_chain_api.begin_multipart_file_upload(
388
+ scan_id=str(scan_id), file_name=str(file_path), file_content_length=filesize
464
389
  )
390
+
391
+ with open(file_path, "rb") as f:
392
+ for part in upload.parts:
393
+ if part.start_offset is None:
394
+ raise Exception("part must have a start_offset")
395
+ if part.stop_offset is not None:
396
+ read_amount = part.stop_offset - part.start_offset
397
+ else:
398
+ read_amount = None
399
+ f.seek(part.start_offset)
400
+ part_data = f.read(read_amount)
401
+ self._api_client.call_api(
402
+ "PUT",
403
+ part.upload_url,
404
+ body=part_data,
405
+ header_params={"Content-Type": "application/octet-binary"},
406
+ )
407
+
408
+ self._model_supply_chain_api.complete_multipart_file_upload(
409
+ scan_id=scan_id, file_id=upload.upload_id
410
+ )
411
+
412
def _wait_for_scan_results(self, *, scan_id: str) -> ScanResults:
    """
    Poll a scan until its status is DONE or FAILED, then return the results.

    Polling uses exponential backoff with random jitter. The exponential
    part is capped so the gap between polls never exceeds roughly
    ``max_delay + 1`` seconds; without a cap the delay doubled on every poll
    and grew to close to an hour by the 15th poll.

    :param scan_id: Identifier of the scan to wait for.

    :returns: Scan results as returned by ``get_scan_results`` once the
        status is DONE or FAILED.
    """
    base_delay = 0.1  # seconds
    max_delay = 30.0  # seconds; upper bound for the exponential component
    max_exponent = 16  # keeps 2**retries small; the cap applies long before
    terminal_statuses = [ScanStatus.DONE, ScanStatus.FAILED]

    scan_results = self.get_scan_results(scan_id=scan_id)
    print(f"scan status: {scan_results.status}")

    # NOTE(review): any other status keeps this loop polling. Whether a
    # "canceled" scan or an empty result should also end it needs confirming
    # against ScanStatus and get_scan_results.
    retries = 0
    while scan_results.status not in terminal_statuses:
        retries += 1
        backoff = min(base_delay * 2 ** min(retries, max_exponent), max_delay)
        delay = backoff + random.uniform(0, 1)  # exponential back off retry
        time.sleep(delay)
        scan_results = self.get_scan_results(scan_id=scan_id)
        print(f"scan status: {scan_results.status}")

    return scan_results
@@ -1 +1 @@
1
- VERSION = "1.2.1"
1
+ VERSION = "2.0.1"