hiddenlayer-sdk 1.2.1__py3-none-any.whl → 2.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. hiddenlayer/__init__.py +0 -10
  2. hiddenlayer/sdk/exceptions.py +1 -1
  3. hiddenlayer/sdk/models.py +2 -3
  4. hiddenlayer/sdk/rest/__init__.py +16 -11
  5. hiddenlayer/sdk/rest/api/__init__.py +0 -1
  6. hiddenlayer/sdk/rest/api/model_supply_chain_api.py +1706 -571
  7. hiddenlayer/sdk/rest/api/sensor_api.py +214 -1320
  8. hiddenlayer/sdk/rest/models/__init__.py +16 -10
  9. hiddenlayer/sdk/rest/models/{scan_model_request.py → begin_multi_file_upload200_response.py} +9 -9
  10. hiddenlayer/sdk/rest/models/{get_multipart_upload_response.py → begin_multipart_file_upload200_response.py} +9 -9
  11. hiddenlayer/sdk/rest/models/{multipart_upload_part.py → begin_multipart_file_upload200_response_parts_inner.py} +11 -10
  12. hiddenlayer/sdk/rest/models/errors_inner.py +91 -0
  13. hiddenlayer/sdk/rest/models/file_details_v3.py +8 -2
  14. hiddenlayer/sdk/rest/models/{scan_results_v2.py → file_result_v3.py} +21 -32
  15. hiddenlayer/sdk/rest/models/{model_scan_api_v3_scan_query200_response.py → get_condensed_model_scan_reports200_response.py} +4 -4
  16. hiddenlayer/sdk/rest/models/inventory_v3.py +97 -0
  17. hiddenlayer/sdk/rest/models/model_inventory_info.py +1 -1
  18. hiddenlayer/sdk/rest/models/{detections.py → multi_file_upload_request_v3.py} +14 -22
  19. hiddenlayer/sdk/rest/models/{model_scan_api_v3_scan_model_version_id_patch200_response.py → notify_model_scan_completed200_response.py} +4 -4
  20. hiddenlayer/sdk/rest/models/pagination_v3.py +95 -0
  21. hiddenlayer/sdk/rest/models/problem_details.py +103 -0
  22. hiddenlayer/sdk/rest/models/scan_detection_v31.py +155 -0
  23. hiddenlayer/sdk/rest/models/scan_model_details_v3.py +1 -1
  24. hiddenlayer/sdk/rest/models/scan_results_map_v3.py +105 -0
  25. hiddenlayer/sdk/rest/models/scan_results_v3.py +120 -0
  26. hiddenlayer/sdk/rest/models/{model.py → sensor.py} +4 -4
  27. hiddenlayer/sdk/rest/models/{model_query_response.py → sensor_query_response.py} +7 -7
  28. hiddenlayer/sdk/services/aidr_predictive.py +57 -3
  29. hiddenlayer/sdk/services/model_scan.py +98 -135
  30. hiddenlayer/sdk/version.py +1 -1
  31. {hiddenlayer_sdk-1.2.1.dist-info → hiddenlayer_sdk-2.0.1.dist-info}/METADATA +12 -2
  32. {hiddenlayer_sdk-1.2.1.dist-info → hiddenlayer_sdk-2.0.1.dist-info}/RECORD +35 -31
  33. hiddenlayer/sdk/rest/api/model_scan_api.py +0 -591
  34. hiddenlayer/sdk/rest/models/scan_results.py +0 -118
  35. hiddenlayer/sdk/services/model.py +0 -149
  36. {hiddenlayer_sdk-1.2.1.dist-info → hiddenlayer_sdk-2.0.1.dist-info}/LICENSE +0 -0
  37. {hiddenlayer_sdk-1.2.1.dist-info → hiddenlayer_sdk-2.0.1.dist-info}/WHEEL +0 -0
  38. {hiddenlayer_sdk-1.2.1.dist-info → hiddenlayer_sdk-2.0.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,120 @@
1
+ # coding: utf-8
2
+
3
+ """
4
+ HiddenLayer ModelScan V2
5
+
6
+ HiddenLayer ModelScan API for scanning of models
7
+
8
+ The version of the OpenAPI document: 1
9
+ Generated by OpenAPI Generator (https://openapi-generator.tech)
10
+
11
+ Do not edit the class manually.
12
+ """ # noqa: E501
13
+
14
+
15
+ from __future__ import annotations
16
+ import pprint
17
+ import re # noqa: F401
18
+ import json
19
+
20
+ from pydantic import BaseModel, ConfigDict, StrictInt, StrictStr, field_validator
21
+ from typing import Any, ClassVar, Dict, List, Optional
22
+ from hiddenlayer.sdk.rest.models.file_result_v3 import FileResultV3
23
+ from typing import Optional, Set
24
+ from typing_extensions import Self
25
+
26
class ScanResultsV3(BaseModel):
    """
    ScanResultsV3

    Results of a model scan. ``inventory`` is declared as a free-form
    ``Dict[str, Any]`` and is passed through unchanged by ``to_dict`` and
    ``from_dict``.
    """ # noqa: E501
    scan_id: Optional[StrictStr] = None
    start_time: Optional[StrictInt] = None
    end_time: Optional[StrictInt] = None
    status: Optional[StrictStr] = None
    version: Optional[StrictStr] = None
    inventory: Optional[Dict[str, Any]] = None
    file_results: List[FileResultV3]
    __properties: ClassVar[List[str]] = ["scan_id", "start_time", "end_time", "status", "version", "inventory", "file_results"]

    @field_validator('status')
    def status_validate_enum(cls, value):
        """Validates the enum"""
        if value is None:
            return value

        if value not in set(['done', 'running', 'failed', 'pending', 'canceled']):
            raise ValueError("must be one of enum values ('done', 'running', 'failed', 'pending', 'canceled')")
        return value

    model_config = ConfigDict(
        populate_by_name=True,
        validate_assignment=True,
        protected_namespaces=(),
    )


    def to_str(self) -> str:
        """Returns the string representation of the model using alias"""
        return pprint.pformat(self.model_dump(by_alias=True))

    def to_json(self) -> str:
        """Returns the JSON representation of the model using alias"""
        # TODO: pydantic v2: use .model_dump_json(by_alias=True, exclude_unset=True) instead
        return json.dumps(self.to_dict())

    @classmethod
    def from_json(cls, json_str: str) -> Optional[Self]:
        """Create an instance of ScanResultsV3 from a JSON string"""
        return cls.from_dict(json.loads(json_str))

    def to_dict(self) -> Dict[str, Any]:
        """Return the dictionary representation of the model using alias.

        This has the following differences from calling pydantic's
        `self.model_dump(by_alias=True)`:

        * `None` is only added to the output dict for nullable fields that
          were set at model initialization. Other fields with value `None`
          are ignored.
        * Each entry of `file_results` is serialized with its own `to_dict()`.
        """
        excluded_fields: Set[str] = set([
        ])

        _dict = self.model_dump(
            by_alias=True,
            exclude=excluded_fields,
            exclude_none=True,
        )
        # `inventory` is a plain dict (see the field annotation), so the
        # `model_dump` output above already contains it; calling `.to_dict()`
        # on a dict would raise AttributeError.

        # override the default output from pydantic by calling `to_dict()` of each item in file_results (list)
        _items = []
        if self.file_results:
            for _item in self.file_results:
                if _item:
                    _items.append(_item.to_dict())
            _dict['file_results'] = _items
        return _dict

    @classmethod
    def from_dict(cls, obj: Optional[Dict[str, Any]]) -> Optional[Self]:
        """Create an instance of ScanResultsV3 from a dict"""
        if obj is None:
            return None

        if not isinstance(obj, dict):
            return cls.model_validate(obj)

        _obj = cls.model_validate({
            "scan_id": obj.get("scan_id"),
            "start_time": obj.get("start_time"),
            "end_time": obj.get("end_time"),
            "status": obj.get("status"),
            "version": obj.get("version"),
            # Free-form dict field: pass it through as-is so it validates
            # against `Dict[str, Any]` (no model class is involved).
            "inventory": obj.get("inventory"),
            "file_results": [FileResultV3.from_dict(_item) for _item in obj["file_results"]] if obj.get("file_results") is not None else None
        })
        return _obj
119
+
120
+
@@ -23,9 +23,9 @@ from typing import Any, ClassVar, Dict, List, Optional
23
23
  from typing import Optional, Set
24
24
  from typing_extensions import Self
25
25
 
26
- class Model(BaseModel):
26
+ class Sensor(BaseModel):
27
27
  """
28
- Model
28
+ Sensor
29
29
  """ # noqa: E501
30
30
  sensor_id: StrictStr
31
31
  created_at: datetime
@@ -54,7 +54,7 @@ class Model(BaseModel):
54
54
 
55
55
  @classmethod
56
56
  def from_json(cls, json_str: str) -> Optional[Self]:
57
- """Create an instance of Model from a JSON string"""
57
+ """Create an instance of Sensor from a JSON string"""
58
58
  return cls.from_dict(json.loads(json_str))
59
59
 
60
60
  def to_dict(self) -> Dict[str, Any]:
@@ -79,7 +79,7 @@ class Model(BaseModel):
79
79
 
80
80
  @classmethod
81
81
  def from_dict(cls, obj: Optional[Dict[str, Any]]) -> Optional[Self]:
82
- """Create an instance of Model from a dict"""
82
+ """Create an instance of Sensor from a dict"""
83
83
  if obj is None:
84
84
  return None
85
85
 
@@ -19,18 +19,18 @@ import json
19
19
 
20
20
  from pydantic import BaseModel, ConfigDict, StrictInt
21
21
  from typing import Any, ClassVar, Dict, List
22
- from hiddenlayer.sdk.rest.models.model import Model
22
+ from hiddenlayer.sdk.rest.models.sensor import Sensor
23
23
  from typing import Optional, Set
24
24
  from typing_extensions import Self
25
25
 
26
- class ModelQueryResponse(BaseModel):
26
+ class SensorQueryResponse(BaseModel):
27
27
  """
28
- ModelQueryResponse
28
+ SensorQueryResponse
29
29
  """ # noqa: E501
30
30
  total_count: StrictInt
31
31
  page_size: StrictInt
32
32
  page_number: StrictInt
33
- results: List[Model]
33
+ results: List[Sensor]
34
34
  __properties: ClassVar[List[str]] = ["total_count", "page_size", "page_number", "results"]
35
35
 
36
36
  model_config = ConfigDict(
@@ -51,7 +51,7 @@ class ModelQueryResponse(BaseModel):
51
51
 
52
52
  @classmethod
53
53
  def from_json(cls, json_str: str) -> Optional[Self]:
54
- """Create an instance of ModelQueryResponse from a JSON string"""
54
+ """Create an instance of SensorQueryResponse from a JSON string"""
55
55
  return cls.from_dict(json.loads(json_str))
56
56
 
57
57
  def to_dict(self) -> Dict[str, Any]:
@@ -83,7 +83,7 @@ class ModelQueryResponse(BaseModel):
83
83
 
84
84
  @classmethod
85
85
  def from_dict(cls, obj: Optional[Dict[str, Any]]) -> Optional[Self]:
86
- """Create an instance of ModelQueryResponse from a dict"""
86
+ """Create an instance of SensorQueryResponse from a dict"""
87
87
  if obj is None:
88
88
  return None
89
89
 
@@ -94,7 +94,7 @@ class ModelQueryResponse(BaseModel):
94
94
  "total_count": obj.get("total_count"),
95
95
  "page_size": obj.get("page_size"),
96
96
  "page_number": obj.get("page_number"),
97
- "results": [Model.from_dict(_item) for _item in obj["results"]] if obj.get("results") is not None else None
97
+ "results": [Sensor.from_dict(_item) for _item in obj["results"]] if obj.get("results") is not None else None
98
98
  })
99
99
  return _obj
100
100
 
@@ -4,22 +4,29 @@ from typing import Any, Dict, List, Optional, Union
4
4
 
5
5
  import numpy as np
6
6
 
7
+ from hiddenlayer.sdk.exceptions import SensorDoesNotExistError
7
8
  from hiddenlayer.sdk.rest.api import AidrPredictiveApi
9
+ from hiddenlayer.sdk.rest.api.sensor_api import SensorApi
8
10
  from hiddenlayer.sdk.rest.api_client import ApiClient
9
11
  from hiddenlayer.sdk.rest.models import (
10
12
  SubmissionResponse,
11
13
  SubmissionV2,
12
14
  )
15
+ from hiddenlayer.sdk.rest.models.create_sensor_request import CreateSensorRequest
16
+ from hiddenlayer.sdk.rest.models.sensor import Sensor
17
+ from hiddenlayer.sdk.rest.models.sensor_sor_query_filter import SensorSORQueryFilter
18
+ from hiddenlayer.sdk.rest.models.sensor_sor_query_request import SensorSORQueryRequest
13
19
 
14
20
 
15
21
  class AIDRPredictive:
16
22
  def __init__(self, api_client: ApiClient) -> None:
23
+ self._sensor_api = SensorApi(api_client=api_client)
17
24
  self._aidr_predictive = AidrPredictiveApi(api_client=api_client)
18
25
 
19
26
  def submit_vectors(
20
27
  self,
21
28
  *,
22
- model_id: str,
29
+ sensor_id: str,
23
30
  requester_id: str,
24
31
  input_vectors: Union[List[float], np.ndarray],
25
32
  output: Union[List[float], np.ndarray],
@@ -31,7 +38,7 @@ class AIDRPredictive:
31
38
  """
32
39
  Submit feature vectors and model outputs via the HiddenLayer API.
33
40
 
34
- :param model_id: Model id.
41
+ :param sensor_id: Sensor id.
35
42
  :param requester_id: Custom identifier for the inbound request. This should be a value that can be used to identify individual users interacting with the model.
36
43
  :param input_vectors: Feature vectors for your model.
37
44
  :param output: Output vectors directly from your model.
@@ -60,7 +67,7 @@ class AIDRPredictive:
60
67
  SubmissionV2(
61
68
  metadata=metadata if metadata else {},
62
69
  tags=tags if tags else [],
63
- sensor_id=model_id,
70
+ sensor_id=sensor_id,
64
71
  requester_id=requester_id,
65
72
  input_layer=input_layer,
66
73
  input_layer_dtype=str(input_vectors.dtype),
@@ -74,3 +81,50 @@ class AIDRPredictive:
74
81
  else str(datetime.now().isoformat()),
75
82
  )
76
83
  )
84
+
85
def create_sensor(self, *, sensor_name: str) -> Sensor:
    """
    Create a new sensor in the HiddenLayer Platform.

    :param sensor_name: Plaintext name to give the new sensor.

    :returns: The HiddenLayer Sensor returned by the sensor API.
    """
    request = CreateSensorRequest(plaintext_name=sensor_name)
    created = self._sensor_api.create_sensor(request)
    return created
96
+
97
def get_sensor(self, *, sensor_name: str) -> Sensor:
    """
    Look up a HiddenLayer sensor by its name.

    :param sensor_name: Name of the sensor to look up.

    :returns: The matching HiddenLayer Sensor.

    :raises SensorDoesNotExistError: If no sensor with that name exists.
    """
    lookup = self._get_sensor_by_name
    return lookup(sensor_name=sensor_name)
107
+
108
def _get_sensor_by_name(self, *, sensor_name: str) -> Sensor:
    """
    Look up a sensor by its plaintext name.

    When several sensors share the name, the one with the highest
    ``version`` is returned.

    :param sensor_name: Name of the sensor.

    :returns: HiddenLayer Sensor object.

    :raises SensorDoesNotExistError: If no sensor with that name exists.
    """
    sensors = self._sensor_api.query_sensor(
        sensor_sor_query_request=SensorSORQueryRequest(
            filter=SensorSORQueryFilter(plaintext_name=sensor_name)
        )
    )

    # `not` covers both a missing (None) and an empty results list.
    if not sensors.results:
        raise SensorDoesNotExistError(f"Sensor {sensor_name} does not exist")

    # max() yields the first highest-version sensor -- the same element a
    # stable reverse sort would put first -- without mutating the response.
    return max(sensors.results, key=lambda sensor: sensor.version)
@@ -1,26 +1,18 @@
1
- import json
2
1
  import os
3
2
  import random
4
3
  import tempfile
5
4
  import time
6
- import warnings
7
5
  import zipfile
8
- from datetime import datetime
9
6
  from pathlib import Path
10
7
  from typing import List, Optional, Union
11
- from uuid import uuid4
12
-
13
- from pydantic_core import ValidationError
14
8
 
15
9
  from hiddenlayer.sdk.constants import ScanStatus
16
- from hiddenlayer.sdk.models import EmptyScanResults, Sarif, ScanResults
17
- from hiddenlayer.sdk.rest.api import ModelScanApi, ModelSupplyChainApi, SensorApi
10
+ from hiddenlayer.sdk.models import EmptyScanResults, ScanResults
11
+ from hiddenlayer.sdk.rest.api import ModelSupplyChainApi
18
12
  from hiddenlayer.sdk.rest.api_client import ApiClient
19
- from hiddenlayer.sdk.rest.models import MultipartUploadPart
20
- from hiddenlayer.sdk.rest.models.model import Model
21
- from hiddenlayer.sdk.rest.models.sarif210 import Sarif210
22
- from hiddenlayer.sdk.services.model import ModelAPI
23
- from hiddenlayer.sdk.utils import filter_path_objects, is_saas
13
+ from hiddenlayer.sdk.rest.exceptions import NotFoundException
14
+ from hiddenlayer.sdk.rest.models import MultiFileUploadRequestV3
15
+ from hiddenlayer.sdk.utils import filter_path_objects
24
16
 
25
17
  EXCLUDE_FILE_TYPES = [
26
18
  "*.txt",
@@ -37,22 +29,14 @@ EXCLUDE_FILE_TYPES = [
37
29
  class ModelScanAPI:
38
30
  def __init__(self, api_client: ApiClient) -> None:
39
31
  self._api_client = api_client
40
-
41
32
  self._model_supply_chain_api = ModelSupplyChainApi(api_client=api_client)
42
- self._model_api = ModelAPI(api_client=api_client)
43
- self._sensor_api = SensorApi(
44
- api_client=api_client
45
- ) # lower level api of ModelAPI
46
-
47
- self._model_scan_api = ModelScanApi(api_client=api_client)
48
33
 
49
34
  def scan_file(
50
35
  self,
51
36
  *,
52
37
  model_name: str,
53
38
  model_path: Union[str, os.PathLike],
54
- model_version: Optional[int] = None,
55
- chunk_size: int = 16,
39
+ model_version: str = "1",
56
40
  wait_for_results: bool = True,
57
41
  ) -> ScanResults:
58
42
  """
@@ -69,52 +53,22 @@ class ModelScanAPI:
69
53
 
70
54
  file_path = Path(model_path)
71
55
 
72
- filesize = file_path.stat().st_size
73
- sensor = self._model_api.create_or_get(
74
- model_name=model_name, model_version=model_version
56
+ request = MultiFileUploadRequestV3(
57
+ model_name=model_name,
58
+ model_version=model_version,
59
+ requesting_entity="hiddenlayer-python-sdk",
75
60
  )
76
- upload = self._sensor_api.begin_multipart_upload(sensor.sensor_id, filesize)
77
-
78
- with open(file_path, "rb") as f:
79
- for i in range(0, len(upload.parts), chunk_size):
80
- group: List[MultipartUploadPart] = upload.parts[i : i + chunk_size]
81
- for part in group:
82
- read_amount = part.end_offset - part.start_offset
83
- f.seek(int(part.start_offset))
84
- part_data = f.read(int(read_amount))
85
-
86
- # The SaaS multipart upload returns a upload url for each part
87
- # So there is no specified route
88
- self._api_client.call_api(
89
- "PUT",
90
- part.upload_url,
91
- body=part_data,
92
- header_params={"Content-Type": "application/octet-binary"},
93
- )
94
-
95
- self._sensor_api.complete_multipart_upload(sensor.sensor_id, upload.upload_id)
96
-
97
- self._model_scan_api.scan_model(sensor.sensor_id)
98
-
99
- scan_results = self.get_scan_results(
100
- model_name=model_name, model_version=model_version
61
+ response = self._model_supply_chain_api.begin_multi_file_upload(
62
+ multi_file_upload_request_v3=request
101
63
  )
64
+ scan_id = response.scan_id
65
+ if scan_id is None:
66
+ raise Exception("scan_id must have a value")
102
67
 
103
- base_delay = 0.1 # seconds
104
- retries = 0
105
- if wait_for_results:
106
- print(f"{file_path.name} scan status: {scan_results.status}")
107
- while scan_results.status not in [ScanStatus.DONE, ScanStatus.FAILED]:
108
- retries += 1
109
- delay = base_delay * 2**retries + random.uniform(
110
- 0, 1
111
- ) # exponential back off retry
112
- time.sleep(delay)
113
- scan_results = self.get_scan_results(
114
- model_name=model_name, model_version=model_version
115
- )
116
- print(f"{file_path.name} scan status: {scan_results.status}")
68
+ self._scan_file(scan_id=scan_id, file_path=file_path)
117
69
 
70
+ self._model_supply_chain_api.complete_multi_file_upload(scan_id=scan_id)
71
+ scan_results = self._wait_for_scan_results(scan_id=scan_id)
118
72
  scan_results.file_name = file_path.name
119
73
  scan_results.file_path = str(file_path)
120
74
 
@@ -126,9 +80,8 @@ class ModelScanAPI:
126
80
  model_name: str,
127
81
  bucket: str,
128
82
  key: str,
129
- model_version: Optional[int] = None,
83
+ model_version: str = "1",
130
84
  s3_client: Optional[object] = None,
131
- chunk_size: int = 4,
132
85
  wait_for_results: bool = True,
133
86
  ) -> ScanResults:
134
87
  """
@@ -173,7 +126,6 @@ class ModelScanAPI:
173
126
  model_path=f"/tmp/{file_name}",
174
127
  model_name=model_name,
175
128
  model_version=model_version,
176
- chunk_size=chunk_size,
177
129
  wait_for_results=wait_for_results,
178
130
  )
179
131
 
@@ -184,10 +136,9 @@ class ModelScanAPI:
184
136
  account_url: str,
185
137
  container: str,
186
138
  blob: str,
187
- model_version: Optional[int] = None,
139
+ model_version: str = "1",
188
140
  blob_service_client: Optional[object] = None,
189
141
  credential: Optional[object] = None,
190
- chunk_size: int = 4,
191
142
  wait_for_results: bool = True,
192
143
  ) -> ScanResults:
193
144
  """
@@ -253,7 +204,6 @@ class ModelScanAPI:
253
204
  model_path=f"/tmp/{file_name}",
254
205
  model_name=model_name,
255
206
  model_version=model_version,
256
- chunk_size=chunk_size,
257
207
  wait_for_results=wait_for_results,
258
208
  )
259
209
 
@@ -270,8 +220,6 @@ class ModelScanAPI:
270
220
  ignore_file_patterns: Optional[List[str]] = None,
271
221
  force_download: bool = False,
272
222
  hf_token: Optional[Union[str, bool]] = None,
273
- # HL parameters
274
- chunk_size: int = 4,
275
223
  wait_for_results: bool = True,
276
224
  ) -> ScanResults:
277
225
  """
@@ -317,21 +265,19 @@ class ModelScanAPI:
317
265
  token=hf_token,
318
266
  )
319
267
 
268
+ if revision is None:
269
+ revision = "1"
270
+
320
271
  return self.scan_folder(
321
272
  model_name=model_name or repo_id,
273
+ model_version=revision,
322
274
  path=local_dir,
323
275
  allow_file_patterns=allow_file_patterns,
324
276
  ignore_file_patterns=ignore_file_patterns,
325
- chunk_size=chunk_size,
326
277
  wait_for_results=wait_for_results,
327
278
  )
328
279
 
329
- def get_scan_results(
330
- self,
331
- *,
332
- model_name: str,
333
- model_version: Optional[int] = None,
334
- ) -> ScanResults:
280
+ def get_scan_results(self, *, scan_id: str) -> ScanResults:
335
281
  """
336
282
  Get results from a model scan.
337
283
 
@@ -341,48 +287,17 @@ class ModelScanAPI:
341
287
  :returns: Scan results.
342
288
  """
343
289
 
344
- response = self._sensor_api.sensor_sor_api_v3_model_cards_query_get(
345
- model_name_eq=model_name, limit=1
346
- )
347
- model_id = response.results[0].model_id
348
-
349
- scans = self._model_supply_chain_api.model_scan_api_v3_scan_query(
350
- model_ids=[model_id], latest_per_model_version_only=True
351
- )
352
- if scans.total == 0:
353
- return EmptyScanResults()
354
-
355
- if scans.items is None:
356
- return EmptyScanResults()
357
-
358
- scan = scans.items[0]
359
- if model_version:
360
- scan = next(
361
- (
362
- s
363
- for s in scans.items
364
- if s.inventory.model_version == str(model_version)
365
- ),
366
- None,
367
- )
368
- if not scan:
290
+ try:
291
+ scan_report = self._model_supply_chain_api.get_scan_results(scan_id)
292
+ except NotFoundException:
369
293
  return EmptyScanResults()
370
294
 
371
- scan_report = (
372
- self._model_supply_chain_api.model_scan_api_v3_scan_model_version_id_get(
373
- scan.scan_id
374
- )
375
- )
376
-
377
- return ScanResults.from_scanreportv3(
378
- scan_report_v3=scan_report, model_id=model_id
379
- )
295
+ return ScanResults.from_scanreportv3(scan_report_v3=scan_report)
380
296
 
381
297
  def get_sarif_results(
382
298
  self,
383
299
  *,
384
- model_name: str,
385
- model_version: Optional[int] = None,
300
+ scan_id: str,
386
301
  ) -> Optional[str]:
387
302
  """
388
303
  Get sarif results from a model scan.
@@ -392,34 +307,31 @@ class ModelScanAPI:
392
307
 
393
308
  :returns: Scan results.
394
309
  """
395
- scan = self.get_scan_results(model_name=model_name, model_version=model_version)
396
- if scan.scan_id == "":
397
- return None
398
310
 
399
311
  # Unfortunately, the generated code for the API doesn't directly support modifying the Accept header
400
312
  # in order to enable us to get the Sarif results
401
313
  # Here we will reach in to the request serialization process. The 2nd element in the tuple is the headers
402
314
  # where we will modify the Accept header to application/sarif+json
403
- request = self._model_supply_chain_api._model_scan_api_v3_scan_model_version_id_get_serialize(
404
- scan.scan_id, None, None, None, None, 0
315
+ request = self._model_supply_chain_api._get_scan_results_serialize(
316
+ scan_id, None, None, None, None, 0
405
317
  )
406
318
  request[2]["Accept"] = "application/sarif+json"
407
319
  response = self._api_client.call_api(*request)
408
320
  response.read()
409
321
 
410
- return self._api_client.response_deserialize(
411
- response_data=response, response_types_map={"200": str}
412
- ).data # type: ignore
322
+ if response.data is None:
323
+ return None
324
+
325
+ return response.data.decode()
413
326
 
414
327
  def scan_folder(
415
328
  self,
416
329
  *,
417
330
  model_name: str,
418
331
  path: Union[str, os.PathLike],
419
- model_version: Optional[int] = None,
332
+ model_version: str = "1",
420
333
  allow_file_patterns: Optional[List[str]] = None,
421
334
  ignore_file_patterns: Optional[List[str]] = None,
422
- chunk_size: int = 4,
423
335
  wait_for_results: bool = True,
424
336
  ) -> ScanResults:
425
337
  """
@@ -437,7 +349,18 @@ class ModelScanAPI:
437
349
  """
438
350
 
439
351
  model_path = Path(path)
440
- filename = tempfile.NamedTemporaryFile().name + ".zip"
352
+
353
+ request = MultiFileUploadRequestV3(
354
+ model_name=model_name,
355
+ model_version=model_version,
356
+ requesting_entity="hiddenlayer-python-sdk",
357
+ )
358
+ response = self._model_supply_chain_api.begin_multi_file_upload(
359
+ multi_file_upload_request_v3=request
360
+ )
361
+ scan_id = response.scan_id
362
+ if scan_id is None:
363
+ raise Exception("scan_id must have a value")
441
364
 
442
365
  ignore_file_patterns = (
443
366
  EXCLUDE_FILE_TYPES + ignore_file_patterns
@@ -451,14 +374,54 @@ class ModelScanAPI:
451
374
  ignore_patterns=ignore_file_patterns,
452
375
  )
453
376
 
454
- with zipfile.ZipFile(filename, "a") as zipf:
455
- for file in files:
456
- zipf.write(file, os.path.relpath(file, model_path))
377
+ for file in files:
378
+ self._scan_file(scan_id=scan_id, file_path=Path(file))
457
379
 
458
- return self.scan_file(
459
- model_name=model_name,
460
- model_version=model_version,
461
- model_path=filename,
462
- chunk_size=chunk_size,
463
- wait_for_results=wait_for_results,
380
+ self._model_supply_chain_api.complete_multi_file_upload(scan_id=scan_id)
381
+ scan_results = self._wait_for_scan_results(scan_id=scan_id)
382
+
383
+ return scan_results
384
+
385
def _scan_file(self, *, scan_id: str, file_path: Path):
    """
    Upload a single file into an already-started multi-file scan.

    Begins a multipart upload for ``file_path``, PUTs each part returned by
    the API to its ``upload_url``, then marks the file upload complete.

    :param scan_id: Scan id returned by ``begin_multi_file_upload``.
    :param file_path: Local file to upload.

    :raises Exception: If a part has no ``start_offset``, or a part upload
        responds with a non-2xx HTTP status.
    """
    filesize = file_path.stat().st_size
    upload = self._model_supply_chain_api.begin_multipart_file_upload(
        scan_id=str(scan_id), file_name=str(file_path), file_content_length=filesize
    )

    with open(file_path, "rb") as f:
        for part in upload.parts:
            if part.start_offset is None:
                raise Exception("part must have a start_offset")
            # A part without a stop_offset is read through to end of file.
            if part.stop_offset is not None:
                read_amount = part.stop_offset - part.start_offset
            else:
                read_amount = None
            f.seek(part.start_offset)
            part_data = f.read(read_amount)
            # Each part carries its own absolute upload URL, so the raw client
            # is used instead of a generated endpoint method.
            response = self._api_client.call_api(
                "PUT",
                part.upload_url,
                body=part_data,
                header_params={"Content-Type": "application/octet-binary"},
            )
            # NOTE(review): the generated client's call_api appears to return the
            # raw response without raising on error statuses (that is normally
            # done in response_deserialize), so check it here rather than
            # completing an upload that is missing parts.
            status = getattr(response, "status", None)
            if status is not None and not 200 <= status < 300:
                raise Exception(
                    f"upload of part at offset {part.start_offset} of {file_path} "
                    f"failed with HTTP status {status}"
                )

    self._model_supply_chain_api.complete_multipart_file_upload(
        scan_id=scan_id, file_id=upload.upload_id
    )
411
+
412
def _wait_for_scan_results(
    self,
    *,
    scan_id: str,
    max_delay: float = 60.0,
    timeout: Optional[float] = None,
):
    """
    Poll a scan until it reaches a terminal status and return its results.

    Polling uses exponential backoff with up to one second of random jitter.
    The backoff is capped at ``max_delay`` seconds, so a long-running scan is
    re-checked at a steady rate instead of the wait doubling without bound.

    :param scan_id: Id of the scan to poll.
    :param max_delay: Upper bound, in seconds, on the backoff (before jitter).
    :param timeout: Optional overall limit in seconds; ``None`` waits forever.

    :returns: The scan results once the scan is in a terminal status.

    :raises TimeoutError: If ``timeout`` elapses before the scan finishes.
    """
    terminal_statuses = [ScanStatus.DONE, ScanStatus.FAILED]
    # The v3 scan results also allow a 'canceled' status; treat it as terminal
    # when ScanStatus defines it, otherwise a canceled scan is polled forever.
    canceled = getattr(ScanStatus, "CANCELED", None)
    if canceled is not None:
        terminal_statuses.append(canceled)

    deadline = None if timeout is None else time.monotonic() + timeout
    base_delay = 0.1  # seconds
    retries = 0

    scan_results = self.get_scan_results(scan_id=scan_id)
    print(f"scan status: {scan_results.status}")
    while scan_results.status not in terminal_statuses:
        retries += 1
        # Cap the exponent too, so the float product can never overflow.
        backoff = min(max_delay, base_delay * 2 ** min(retries, 32))
        delay = backoff + random.uniform(0, 1)  # exponential back off retry
        if deadline is not None:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                raise TimeoutError(
                    f"scan {scan_id} did not finish within {timeout} seconds"
                )
            delay = min(delay, remaining)
        time.sleep(delay)
        scan_results = self.get_scan_results(scan_id=scan_id)
        print(f"scan status: {scan_results.status}")

    return scan_results
@@ -1 +1 @@
1
- VERSION = "1.2.1"
1
+ VERSION = "2.0.1"