pyattackforge 0.1.1__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pyattackforge/client.py CHANGED
@@ -1,383 +1,1107 @@
1
- """
2
- PyAttackForge is free software: you can redistribute it and/or modify
3
- it under the terms of the GNU Affero General Public License as published by
4
- the Free Software Foundation, either version 3 of the License, or
5
- (at your option) any later version.
6
-
7
- PyAttackForge is distributed in the hope that it will be useful,
8
- but WITHOUT ANY WARRANTY; without even the implied warranty of
9
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10
- GNU Affero General Public License for more details.
11
-
12
- You should have received a copy of the GNU Affero General Public License
13
- along with this program. If not, see <https://www.gnu.org/licenses/>.
14
- """
15
-
16
- import requests
17
- import logging
18
- from datetime import datetime, timezone, timedelta
19
- from typing import Any, Dict, Optional, Set, Tuple, List
20
-
21
-
22
# Shared library logger; every client method reports through this named logger.
logger = logging.getLogger(name="pyattackforge")
23
-
24
class PyAttackForgeClient:
    """
    Python client for interacting with the AttackForge API.

    Provides methods to manage assets, projects, and vulnerabilities.
    Supports dry-run mode for testing without making real API calls.
    """

    def __init__(self, api_key: str, base_url: str = "https://demo.attackforge.com", dry_run: bool = False):
        """
        Initialize the PyAttackForgeClient.

        Args:
            api_key (str): Your AttackForge API key, sent as the ``X-SSAPI-KEY`` header.
            base_url (str, optional): The base URL for the AttackForge instance; a trailing
                slash is stripped. Defaults to "https://demo.attackforge.com".
            dry_run (bool, optional): If True, requests are only logged and a
                ``DummyResponse`` is returned instead. Defaults to False.
        """
        self.base_url = base_url.rstrip("/")
        self.headers = {
            "X-SSAPI-KEY": api_key,
            "Content-Type": "application/json",
            "Connection": "close",
        }
        self.dry_run = dry_run
        # Filled lazily by get_assets(); None means "not fetched yet".
        self._asset_cache: Optional[Dict[str, Dict[str, Any]]] = None
        # {project_id: set(asset_names)}, filled by get_project_scope().
        self._project_scope_cache: Dict[str, Set[str]] = {}

    def _request(
        self,
        method: str,
        endpoint: str,
        json_data: Optional[Dict[str, Any]] = None,
        params: Optional[Dict[str, Any]] = None
    ) -> Any:
        """
        Send an HTTP request to the AttackForge API.

        Args:
            method (str): HTTP method (get, post, put, etc.).
            endpoint (str): API endpoint path, appended to ``base_url``.
            json_data (dict, optional): JSON payload for the request.
            params (dict, optional): Query parameters.

        Returns:
            The ``requests`` response object, or a ``DummyResponse`` in dry-run mode.
        """
        url = f"{self.base_url}{endpoint}"
        if self.dry_run:
            logger.info("[DRY RUN] %s %s", method.upper(), url)
            if json_data:
                logger.info("Payload: %s", json_data)
            if params:
                logger.info("Params: %s", params)
            return DummyResponse()
        # NOTE(review): no timeout is passed, so a stalled server blocks this call indefinitely.
        return requests.request(method, url, headers=self.headers, json=json_data, params=params)

    def get_assets(self) -> Dict[str, Dict[str, Any]]:
        """
        Retrieve all assets from AttackForge, paging 500 at a time.

        The mapping is cached on the client and reused until create_asset()
        invalidates it.

        Returns:
            dict: Mapping of asset names to asset details.
        """
        if self._asset_cache is None:
            # Build into a local so a failure mid-pagination does not leave a
            # partial listing cached for every later call.
            assets: Dict[str, Dict[str, Any]] = {}
            skip, limit = 0, 500
            while True:
                resp = self._request("get", "/api/ss/assets", params={"skip": skip, "limit": limit})
                data = resp.json()
                for asset in data.get("assets") or []:
                    name = asset.get("asset")
                    if name:
                        assets[name] = asset
                # A missing or null count means there is nothing further to page through.
                if skip + limit >= (data.get("count") or 0):
                    break
                skip += limit
            self._asset_cache = assets
        return self._asset_cache

    def get_asset_by_name(self, name: str) -> Optional[Dict[str, Any]]:
        """
        Retrieve an asset by its name.

        Args:
            name (str): The asset name.

        Returns:
            dict or None: Asset details if found, else None.
        """
        return self.get_assets().get(name)

    def create_asset(self, asset_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Create a new asset in AttackForge.

        If the API reports that the asset already exists, the existing asset is
        looked up by ``asset_data["name"]`` and returned instead.

        Args:
            asset_data (dict): Asset details; must contain "name".

        Returns:
            dict: Created (or pre-existing) asset details.

        Raises:
            RuntimeError: If asset creation fails.
        """
        resp = self._request("post", "/api/ss/library/asset", json_data=asset_data)
        if resp.status_code == 201:
            asset = resp.json()
            self._asset_cache = None  # Invalidate cache
            return asset
        if "Asset Already Exists" in resp.text:
            # The asset was created outside this client's cached view, so drop
            # the cache to force a fresh listing before looking it up.
            self._asset_cache = None
            return self.get_asset_by_name(asset_data["name"])
        raise RuntimeError(f"Asset creation failed: {resp.text}")

    def get_project_by_name(self, name: str) -> Optional[Dict[str, Any]]:
        """
        Retrieve a project by its name.

        Args:
            name (str): The project name.

        Returns:
            dict or None: Project details if found, else None.
        """
        params = {
            "startDate": "2000-01-01T00:00:00.000Z",
            "endDate": "2100-01-01T00:00:00.000Z",
            "status": "All"
        }
        resp = self._request("get", "/api/ss/projects", params=params)
        for proj in resp.json().get("projects", []):
            if proj.get("project_name") == name:
                return proj
        return None

    def get_project_scope(self, project_id: str) -> Set[str]:
        """
        Retrieve the scope (assets) of a project.

        Results are cached per project. A copy is returned so that callers
        mutating the set cannot corrupt the cache.

        Args:
            project_id (str): The project ID.

        Returns:
            set: Set of asset names in the project scope.

        Raises:
            RuntimeError: If project retrieval fails.
        """
        if project_id not in self._project_scope_cache:
            resp = self._request("get", f"/api/ss/project/{project_id}")
            if resp.status_code != 200:
                raise RuntimeError(f"Failed to retrieve project: {resp.text}")
            self._project_scope_cache[project_id] = set(resp.json().get("scope", []))
        return set(self._project_scope_cache[project_id])

    def update_project_scope(self, project_id: str, new_assets: List[str]) -> Dict[str, Any]:
        """
        Add assets to the scope of a project.

        Args:
            project_id (str): The project ID.
            new_assets (iterable): Asset names to add to the scope.

        Returns:
            dict: Updated project details.

        Raises:
            RuntimeError: If update fails.
        """
        current_scope = self.get_project_scope(project_id)
        updated_scope = list(current_scope.union(new_assets))
        resp = self._request("put", f"/api/ss/project/{project_id}", json_data={"scope": updated_scope})
        if resp.status_code not in (200, 201):
            raise RuntimeError(f"Failed to update project scope: {resp.text}")
        self._project_scope_cache[project_id] = set(updated_scope)
        return resp.json()

    def create_project(self, name: str, **kwargs) -> Dict[str, Any]:
        """
        Create a new project in AttackForge.

        Args:
            name (str): Project name.
            **kwargs: Additional project fields. Recognised keys fall back to the
                defaults shown in the payload below; start/end dates default to
                now and 30 days from now.

        Returns:
            dict: Created project details.

        Raises:
            RuntimeError: If project creation fails.
        """
        start, end = get_default_dates()
        payload = {
            "name": name,
            "code": kwargs.get("code", "DEFAULT"),
            "groups": kwargs.get("groups", []),
            "startDate": kwargs.get("startDate", start),
            "endDate": kwargs.get("endDate", end),
            "scope": kwargs.get("scope", []),
            "testsuites": kwargs.get("testsuites", []),
            "organization_code": kwargs.get("organization_code", "ORG_DEFAULT"),
            "vulnerability_code": kwargs.get("vulnerability_code", "VULN_"),
            "scoringSystem": kwargs.get("scoringSystem", "CVSSv3.1"),
            "team_notifications": kwargs.get("team_notifications", []),
            "admin_notifications": kwargs.get("admin_notifications", []),
            "custom_fields": kwargs.get("custom_fields", []),
            "asset_library_ids": kwargs.get("asset_library_ids", []),
            "sla_activation": kwargs.get("sla_activation", "automatic")
        }
        resp = self._request("post", "/api/ss/project", json_data=payload)
        if resp.status_code in (200, 201):
            return resp.json()
        raise RuntimeError(f"Project creation failed: {resp.text}")

    def update_project(self, project_id: str, update_fields: Dict[str, Any]) -> Dict[str, Any]:
        """
        Update an existing project.

        Args:
            project_id (str): The project ID.
            update_fields (dict): Fields to update.

        Returns:
            dict: Updated project details.

        Raises:
            RuntimeError: If update fails.
        """
        resp = self._request("put", f"/api/ss/project/{project_id}", json_data=update_fields)
        if resp.status_code in (200, 201):
            return resp.json()
        raise RuntimeError(f"Project update failed: {resp.text}")

    def create_vulnerability(
        self,
        project_id: str,
        title: str,
        affected_asset_name: str,
        priority: str,
        likelihood_of_exploitation: int,
        description: str,
        attack_scenario: str,
        remediation_recommendation: str,
        steps_to_reproduce: str,
        tags: Optional[list] = None,
        notes: Optional[list] = None,
        is_zeroday: bool = False,
        is_visible: bool = True,
        import_to_library: Optional[str] = None,
        import_source: Optional[str] = None,
        import_source_id: Optional[str] = None,
        custom_fields: Optional[list] = None,
        linked_testcases: Optional[list] = None,
        custom_tags: Optional[list] = None,
    ) -> Dict[str, Any]:
        """
        Create a new security finding (vulnerability) in AttackForge.

        Args:
            project_id (str): The project ID.
            title (str): The title of the finding.
            affected_asset_name (str): The name of the affected asset.
            priority (str): The priority (e.g., "Critical").
            likelihood_of_exploitation (int): Likelihood of exploitation (e.g., 10).
            description (str): Description of the finding.
            attack_scenario (str): Attack scenario details.
            remediation_recommendation (str): Remediation recommendation.
            steps_to_reproduce (str): Steps to reproduce the finding.
            tags (list, optional): List of tags.
            notes (list, optional): List of notes; only sent when non-empty.
            is_zeroday (bool, optional): Whether this is a zero-day finding.
            is_visible (bool, optional): Whether the finding is visible.
            import_to_library (str, optional): Library to import to.
            import_source (str, optional): Source of import.
            import_source_id (str, optional): Source ID for import.
            custom_fields (list, optional): List of custom fields.
            linked_testcases (list, optional): List of linked testcases.
            custom_tags (list, optional): List of custom tags.

        Returns:
            dict: Created vulnerability details.

        Raises:
            ValueError: If any required field is None.
            RuntimeError: If vulnerability creation fails.
        """
        required_fields = [
            ("project_id", project_id),
            ("title", title),
            ("affected_asset_name", affected_asset_name),
            ("priority", priority),
            ("likelihood_of_exploitation", likelihood_of_exploitation),
            ("description", description),
            ("attack_scenario", attack_scenario),
            ("remediation_recommendation", remediation_recommendation),
            ("steps_to_reproduce", steps_to_reproduce),
        ]
        for field_name, value in required_fields:
            if value is None:
                raise ValueError(f"Missing required field: {field_name}")

        payload = {
            "projectId": project_id,
            "title": title,
            "affected_asset_name": affected_asset_name,
            "priority": priority,
            "likelihood_of_exploitation": likelihood_of_exploitation,
            "description": description,
            "attack_scenario": attack_scenario,
            "remediation_recommendation": remediation_recommendation,
            "steps_to_reproduce": steps_to_reproduce,
            "tags": tags or [],
            "is_zeroday": is_zeroday,
            "is_visible": is_visible,
            "import_to_library": import_to_library,
            "import_source": import_source,
            "import_source_id": import_source_id,
            "custom_fields": custom_fields or [],
            "linked_testcases": linked_testcases or [],
            "custom_tags": custom_tags or [],
        }
        # Only include notes if it is a non-empty list
        if notes:
            payload["notes"] = notes

        # Drop unset optional fields rather than sending explicit nulls.
        payload = {k: v for k, v in payload.items() if v is not None}

        resp = self._request("post", "/api/ss/vulnerability", json_data=payload)
        if resp.status_code in (200, 201):
            return resp.json()
        raise RuntimeError(f"Vulnerability creation failed: {resp.text}")
359
-
360
-
361
class DummyResponse:
    """
    Stand-in for an HTTP response, returned by the client while in dry-run mode.

    It always reports HTTP 200 with a fixed explanatory text and an empty JSON body.
    """

    status_code: int
    text: str

    def __init__(self) -> None:
        self.status_code, self.text = 200, "[DRY RUN] No real API call performed."

    def json(self) -> Dict[str, Any]:
        """Return an empty payload, as if the call succeeded with no body."""
        return dict()
371
-
372
-
373
def get_default_dates() -> Tuple[str, str]:
    """
    Compute a default project window: from now (UTC) until 30 days later.

    Returns:
        tuple: ``(start_date, end_date)`` as ISO 8601 strings with millisecond
        precision and a trailing "Z" instead of "+00:00".
    """
    def _as_zulu(moment: datetime) -> str:
        return moment.isoformat(timespec="milliseconds").replace("+00:00", "Z")

    opening = datetime.now(timezone.utc)
    closing = opening + timedelta(days=30)
    return _as_zulu(opening), _as_zulu(closing)
1
+ """
2
+ PyAttackForge is free software: you can redistribute it and/or modify
3
+ it under the terms of the GNU Affero General Public License as published by
4
+ the Free Software Foundation, either version 3 of the License, or
5
+ (at your option) any later version.
6
+
7
+ PyAttackForge is distributed in the hope that it will be useful,
8
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
9
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10
+ GNU Affero General Public License for more details.
11
+
12
+ You should have received a copy of the GNU Affero General Public License
13
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
14
+ """
15
+
16
+ import os
17
+ import requests
18
+ import logging
19
+ from datetime import datetime, timezone, timedelta
20
+ from typing import Any, Dict, Optional, Set, Tuple, List
21
+
22
+
23
# Shared library logger; every client method reports through this named logger.
logger = logging.getLogger(name="pyattackforge")
24
+
25
+
26
+ class PyAttackForgeClient:
27
+ """
28
+ Python client for interacting with the AttackForge API.
29
+
30
+ Provides methods to manage assets, projects, and vulnerabilities.
31
+ Supports dry-run mode for testing without making real API calls.
32
+ """
33
+
34
def upsert_finding_for_project(  # noqa: C901
    self,
    project_id: str,
    title: str,
    affected_assets: list,
    priority: str,
    likelihood_of_exploitation: int,
    description: str,
    attack_scenario: str,
    remediation_recommendation: str,
    steps_to_reproduce: str,
    tags: Optional[list] = None,
    notes: Optional[list] = None,
    is_zeroday: bool = False,
    is_visible: bool = True,
    import_to_library: Optional[str] = None,
    import_source: Optional[str] = None,
    import_source_id: Optional[str] = None,
    custom_fields: Optional[list] = None,
    linked_testcases: Optional[list] = None,
    custom_tags: Optional[list] = None,
    writeup_custom_fields: Optional[list] = None,
) -> Dict[str, Any]:
    """
    Create or update a finding for a project.

    An existing finding is matched on its title only: the first finding whose
    ``vulnerability_title`` equals ``title``. When one is found, the supplied
    assets and notes are merged into it (no duplicates, existing order first)
    and the finding is updated with a PUT. Otherwise a new finding is created
    through ``create_vulnerability``.

    Args:
        project_id (str): The project ID.
        title (str): The title of the finding.
        affected_assets (list): Affected assets, as names or dicts with a "name" key.
        priority (str): The priority (e.g., "Critical").
        likelihood_of_exploitation (int): Likelihood of exploitation (e.g., 10).
        description (str): Description of the finding.
        attack_scenario (str): Attack scenario details.
        remediation_recommendation (str): Remediation recommendation.
        steps_to_reproduce (str): Steps to reproduce the finding.
        tags (list, optional): List of tags.
        notes (list, optional): Notes, as strings or dicts with a "note" key.
        is_zeroday (bool, optional): Whether this is a zero-day finding.
        is_visible (bool, optional): Whether the finding is visible.
        import_to_library (str, optional): Library to import to.
        import_source (str, optional): Source of import.
        import_source_id (str, optional): Source ID for import.
        custom_fields (list, optional): List of custom fields.
        linked_testcases (list, optional): List of linked testcases.
        custom_tags (list, optional): List of custom tags.
        writeup_custom_fields (list, optional): List of custom fields for the writeup.

    Returns:
        dict: ``{"action": "update", "existing_finding_id", "update_payload",
        "api_response"}`` after merging into an existing finding, or
        ``{"action": "create", "result": ...}`` for a new finding.

    Raises:
        RuntimeError: If the PUT updating an existing finding fails.
    """
    # Normalise the supplied assets to plain names.
    asset_names = []
    for asset in affected_assets:
        name = asset["name"] if isinstance(asset, dict) and "name" in asset else asset
        # NOTE(review): the lookup result is unused; the call is kept so any side
        # effects of get_asset_by_name (API errors, cache warm-up) still happen.
        self.get_asset_by_name(name)
        asset_names.append(name)

    findings = self.get_findings_for_project(project_id)
    logger.debug("Found %s findings for project %s", len(findings), project_id)
    for finding in findings:
        logger.debug(
            "Finding id=%s title=%s steps=%s",
            finding.get("vulnerability_id"),
            finding.get("vulnerability_title"),
            finding.get("vulnerability_steps_to_reproduce"),
        )
        logger.debug("Finding payload: %s", finding)

    match = next((f for f in findings if f.get("vulnerability_title") == title), None)

    if match is None:
        # No match: create a new finding; the API expects 'assetName' entries.
        result = self.create_vulnerability(
            project_id=project_id,
            title=title,
            affected_assets=[{"assetName": name} for name in asset_names],
            priority=priority,
            likelihood_of_exploitation=likelihood_of_exploitation,
            description=description,
            attack_scenario=attack_scenario,
            remediation_recommendation=remediation_recommendation,
            steps_to_reproduce=steps_to_reproduce,
            tags=tags,
            notes=notes,
            is_zeroday=is_zeroday,
            is_visible=is_visible,
            import_to_library=import_to_library,
            import_source=import_source,
            import_source_id=import_source_id,
            custom_fields=custom_fields,
            linked_testcases=linked_testcases,
            custom_tags=custom_tags,
            writeup_custom_fields=writeup_custom_fields,
        )
        return {
            "action": "create",
            "result": result,
        }

    # Merge asset names. A dict keeps insertion order; the old set made the
    # payload order nondeterministic.
    merged_assets: Dict[Any, None] = {}
    if "vulnerability_affected_assets" in match:
        for asset in match["vulnerability_affected_assets"] or []:
            if isinstance(asset, dict):
                # The API may nest the asset as {"asset": {"name": ...}}.
                nested = asset.get("asset")
                if isinstance(nested, dict) and "name" in nested:
                    merged_assets[nested["name"]] = None
                elif "name" in asset:
                    merged_assets[asset["name"]] = None
            elif isinstance(asset, str):
                merged_assets[asset] = None
    elif "vulnerability_affected_asset_name" in match:
        merged_assets[match["vulnerability_affected_asset_name"]] = None
    merged_assets.update(dict.fromkeys(asset_names))

    # Merge notes into a copy so the fetched finding is not mutated; a null
    # notes field is treated as empty.
    merged_notes = list(match.get("vulnerability_notes") or [])
    note_texts = {n["note"] for n in merged_notes if isinstance(n, dict) and "note" in n}
    for n in notes or []:
        if isinstance(n, dict) and "note" in n:
            text, entry = n["note"], n
        elif isinstance(n, str):
            text, entry = n, {"note": n, "type": "PLAINTEXT"}
        else:
            continue
        if text not in note_texts:
            merged_notes.append(entry)
            note_texts.add(text)

    update_payload = {
        "affected_assets": [{"assetName": n} for n in merged_assets],
        "notes": merged_notes,
        "project_id": project_id,
    }
    resp = self._request("put", f"/api/ss/vulnerability/{match['vulnerability_id']}", json_data=update_payload)
    if resp.status_code not in (200, 201):
        raise RuntimeError(f"Failed to update finding: {resp.text}")
    return {
        "action": "update",
        "existing_finding_id": match["vulnerability_id"],
        "update_payload": update_payload,
        "api_response": resp.json(),
    }
200
+
201
def get_findings_for_project(self, project_id: str, priority: Optional[str] = None) -> list:
    """
    List the findings (vulnerabilities) recorded against a project.

    Args:
        project_id (str): The project ID.
        priority (str, optional): Only request findings with this priority
            (e.g., "Critical"). Defaults to None, meaning no filter.

    Returns:
        list: Finding dicts. The API may wrap them under a "vulnerabilities"
        key or return a bare list; any other shape yields an empty list.

    Raises:
        RuntimeError: If the API does not answer with HTTP 200.
    """
    query = {"priority": priority} if priority else {}
    response = self._request("get", f"/api/ss/project/{project_id}/vulnerabilities", params=query)
    if response.status_code != 200:
        raise RuntimeError(f"Failed to fetch findings: {response.text}")
    body = response.json()
    if isinstance(body, list):
        return body
    if isinstance(body, dict) and "vulnerabilities" in body:
        return body["vulnerabilities"]
    return []
226
+
227
def get_vulnerability(self, vulnerability_id: str) -> Dict[str, Any]:
    """
    Fetch a single vulnerability by its ID.

    Args:
        vulnerability_id (str): The vulnerability ID.

    Returns:
        dict: Vulnerability details. If the response wraps them under a
        "vulnerability" key, that inner object is returned; otherwise the
        decoded body is returned unchanged.

    Raises:
        ValueError: If vulnerability_id is empty.
        RuntimeError: If the API does not answer with HTTP 200.
    """
    if not vulnerability_id:
        raise ValueError("Missing required field: vulnerability_id")
    response = self._request("get", f"/api/ss/vulnerability/{vulnerability_id}")
    if response.status_code != 200:
        raise RuntimeError(f"Failed to fetch vulnerability: {response.text}")
    body = response.json()
    is_wrapped = isinstance(body, dict) and "vulnerability" in body
    return body["vulnerability"] if is_wrapped else body
246
+
247
def add_note_to_finding(
    self,
    vulnerability_id: str,
    note: Any,
    note_type: str = "PLAINTEXT"
) -> Dict[str, Any]:
    """
    Append a note to an existing finding.

    The current notes are fetched and deduplicated by their "note" text, and
    the new note is appended unless its text is already present. The full
    list is then sent back in a single PUT.

    NOTE(review): if fetching the existing notes fails, only the new note is
    sent. If the API replaces the notes list on PUT, that would drop the
    existing notes, so confirm the server semantics.

    Args:
        vulnerability_id (str): The vulnerability ID.
        note (str or dict): Note text, or a note object with a 'note' key.
        note_type (str): Note type used when ``note`` is not a dict
            (default: "PLAINTEXT").

    Returns:
        dict: API response.

    Raises:
        ValueError: If vulnerability_id is empty or the note has no text.
        RuntimeError: If the update request fails.
    """
    if not vulnerability_id:
        raise ValueError("Missing required field: vulnerability_id")
    if note is None or note == "":
        raise ValueError("Missing required field: note")
    if isinstance(note, dict):
        note_text, note_entry = note.get("note"), note
    else:
        note_text = str(note)
        note_entry = {"note": note_text, "type": note_type}
    if not note_text:
        raise ValueError("Note text cannot be empty")

    try:
        vuln = self.get_vulnerability(vulnerability_id)
        if isinstance(vuln, dict):
            existing_notes = vuln.get("vulnerability_notes") or vuln.get("notes") or []
        else:
            existing_notes = []
    except Exception as exc:  # pragma: no cover - best-effort fetch
        logger.warning(
            "Unable to fetch existing vulnerability notes; proceeding with provided note only: %s",
            exc
        )
        existing_notes = []

    # Keep only well-formed existing notes, first occurrence of each text wins.
    merged_notes = []
    seen_texts = set()
    for existing in existing_notes:
        if not (isinstance(existing, dict) and "note" in existing):
            continue
        text = existing["note"]
        if text in seen_texts:
            continue
        merged_notes.append(existing)
        seen_texts.add(text)
    if note_entry.get("note") not in seen_texts:
        merged_notes.append(note_entry)

    response = self._request(
        "put",
        f"/api/ss/vulnerability/{vulnerability_id}",
        json_data={"notes": merged_notes},
    )
    if response.status_code not in (200, 201):
        raise RuntimeError(f"Failed to add note: {response.text}")
    return response.json()
304
+
305
def link_vulnerability_to_testcases(
    self,
    vulnerability_id: str,
    testcase_ids: List[str],
    project_id: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Link a vulnerability to one or more testcases.

    The given IDs are sent as the finding's ``linked_testcases`` field in a
    PUT to the vulnerability endpoint. They are not merged with links that
    already exist.

    Args:
        vulnerability_id (str): The vulnerability ID.
        testcase_ids (list): Testcase IDs to link.
        project_id (str, optional): Project ID, included only when truthy.

    Returns:
        dict: API response.

    Raises:
        ValueError: If vulnerability_id or testcase_ids is empty.
        RuntimeError: If the API does not answer with 200/201.
    """
    if not vulnerability_id:
        raise ValueError("Missing required field: vulnerability_id")
    if not testcase_ids:
        raise ValueError("testcase_ids must contain at least one ID")
    body: Dict[str, Any] = {"linked_testcases": testcase_ids}
    if project_id:
        body["project_id"] = project_id
    endpoint = f"/api/ss/vulnerability/{vulnerability_id}"
    response = self._request("put", endpoint, json_data=body)
    if response.status_code in (200, 201):
        return response.json()
    raise RuntimeError(f"Failed to link vulnerability to testcases: {response.text}")
339
+
340
def get_testcases(self, project_id: str) -> List[Dict[str, Any]]:
    """
    List the testcases of a project.

    Args:
        project_id (str): Project ID.

    Returns:
        list: Testcase dicts, taken from a "testcases" key or a bare list
        body; any other response shape yields an empty list.

    Raises:
        ValueError: If project_id is empty.
        RuntimeError: If the API does not answer with 200/201.
    """
    if not project_id:
        raise ValueError("Missing required field: project_id")
    response = self._request("get", f"/api/ss/project/{project_id}/testcases")
    if response.status_code not in (200, 201):
        raise RuntimeError(f"Failed to fetch testcases: {response.text}")
    body = response.json()
    if isinstance(body, dict):
        return body["testcases"] if "testcases" in body else []
    return body if isinstance(body, list) else []
361
+
362
def get_testcase(self, project_id: str, testcase_id: str) -> Optional[Dict[str, Any]]:
    """
    Fetch a single testcase by ID.

    Args:
        project_id (str): Project ID.
        testcase_id (str): Testcase ID.

    Returns:
        dict or None: The testcase (unwrapped from a "testcase" key when
        present), or None on HTTP 404 or a non-dict body.

    Raises:
        ValueError: If project_id or testcase_id is empty.
        RuntimeError: If the API answers with anything other than 200/201/404.
    """
    if not project_id:
        raise ValueError("Missing required field: project_id")
    if not testcase_id:
        raise ValueError("Missing required field: testcase_id")
    response = self._request("get", f"/api/ss/project/{project_id}/testcase/{testcase_id}")
    status = response.status_code
    if status == 404:
        return None
    if status not in (200, 201):
        raise RuntimeError(f"Failed to fetch testcase: {response.text}")
    body = response.json()
    if not isinstance(body, dict):
        return None
    return body["testcase"] if "testcase" in body else body
386
+
387
def upload_finding_evidence(self, vulnerability_id: str, file_path: str) -> Dict[str, Any]:
    """
    Upload an evidence file to a finding/vulnerability.

    In dry-run mode no file is read: a bare POST is issued through
    ``_request`` and its JSON is returned without a status check.

    Args:
        vulnerability_id (str): The vulnerability ID.
        file_path (str): Path to the evidence file.

    Returns:
        dict: API response.

    Raises:
        ValueError: If vulnerability_id or file_path is empty.
        FileNotFoundError: If file_path is not an existing file.
        RuntimeError: If the upload does not return 200/201.
    """
    for field_name, value in (("vulnerability_id", vulnerability_id), ("file_path", file_path)):
        if not value:
            raise ValueError(f"Missing required field: {field_name}")
    if not os.path.isfile(file_path):
        raise FileNotFoundError(f"Evidence file not found: {file_path}")
    endpoint = f"/api/ss/vulnerability/{vulnerability_id}/evidence"
    if self.dry_run:
        return self._request("post", endpoint).json()
    upload_name = os.path.basename(file_path)
    with open(file_path, "rb") as handle:
        response = self._request("post", endpoint, files={"file": (upload_name, handle)})
    if response.status_code not in (200, 201):
        raise RuntimeError(f"Evidence upload failed: {response.text}")
    return response.json()
417
+
418
def upload_testcase_evidence(
    self,
    project_id: str,
    testcase_id: str,
    file_path: str
) -> Dict[str, Any]:
    """
    Upload an evidence file to a testcase.

    In dry-run mode no file is read: a bare POST is issued through
    ``_request`` and its JSON is returned without a status check.

    Args:
        project_id (str): The project ID.
        testcase_id (str): The testcase ID.
        file_path (str): Path to the evidence file.

    Returns:
        dict: API response.

    Raises:
        ValueError: If any argument is empty.
        FileNotFoundError: If file_path is not an existing file.
        RuntimeError: If the upload does not return 200/201.
    """
    required = (
        ("project_id", project_id),
        ("testcase_id", testcase_id),
        ("file_path", file_path),
    )
    for field_name, value in required:
        if not value:
            raise ValueError(f"Missing required field: {field_name}")
    if not os.path.isfile(file_path):
        raise FileNotFoundError(f"Evidence file not found: {file_path}")
    endpoint = f"/api/ss/project/{project_id}/testcase/{testcase_id}/file"
    if self.dry_run:
        return self._request("post", endpoint).json()
    upload_name = os.path.basename(file_path)
    with open(file_path, "rb") as handle:
        response = self._request("post", endpoint, files={"file": (upload_name, handle)})
    if response.status_code not in (200, 201):
        raise RuntimeError(f"Testcase evidence upload failed: {response.text}")
    return response.json()
456
+
457
def add_note_to_testcase(
    self,
    project_id: str,
    testcase_id: str,
    note: str,
    status: Optional[str] = None
) -> Dict[str, Any]:
    """
    Create a plain-text note on a testcase, optionally updating its status.

    The note is POSTed to the testcase ``/note`` endpoint as a PLAINTEXT note.
    When ``status`` is given, ``update_testcase`` is then called with
    ``{"status": status}``. A failure there is logged as a warning and does
    not affect the returned note response.

    Args:
        project_id (str): Project ID.
        testcase_id (str): Testcase ID.
        note (str): Note text.
        status (str, optional): Status to set (e.g., "Tested").

    Returns:
        dict: API response from the note creation request.

    Raises:
        ValueError: If project_id, testcase_id or note is empty.
        RuntimeError: If the note endpoint does not return 200/201.
    """
    if not project_id:
        raise ValueError("Missing required field: project_id")
    if not testcase_id:
        raise ValueError("Missing required field: testcase_id")
    if not note:
        raise ValueError("Missing required field: note")
    endpoint = f"/api/ss/project/{project_id}/testcase/{testcase_id}/note"
    payload: Dict[str, Any] = {"note": note, "note_type": "PLAINTEXT"}
    resp = self._request("post", endpoint, json_data=payload)
    if resp.status_code not in (200, 201):
        raise RuntimeError(f"Failed to add testcase note: {resp.text}")
    result = resp.json()

    if status:
        try:
            self.update_testcase(project_id, testcase_id, {"status": status})
        except Exception as exc:  # best-effort: the note itself was already created
            # Previously swallowed silently; surface it so failed status changes are visible.
            logger.warning(
                "Testcase note created but status update to %r failed for testcase %s: %s",
                status,
                testcase_id,
                exc,
            )
    return result
497
+
498
def assign_findings_to_testcase(
    self,
    project_id: str,
    testcase_id: str,
    vulnerability_ids: List[str],
    existing_linked_vulnerabilities: Optional[List[str]] = None,
    additional_fields: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """
    Assign one or more findings to a testcase.

    Existing and new IDs are merged in that order, with falsy IDs and
    duplicates dropped (the first occurrence wins). The result is sent as
    ``linked_vulnerabilities`` through ``update_testcase``.

    Args:
        project_id (str): The project ID.
        testcase_id (str): The testcase ID.
        vulnerability_ids (list): Vulnerability IDs to assign; any iterable
            such as a tuple is accepted.
        existing_linked_vulnerabilities (list, optional): Already-linked IDs to
            merge with; any iterable is accepted.
        additional_fields (dict, optional): Extra testcase fields (e.g., status,
            tags). It is copied, not mutated.

    Returns:
        dict: Result of ``update_testcase``.

    Raises:
        ValueError: If project_id or testcase_id is empty, or vulnerability_ids is empty.
    """
    if not project_id:
        raise ValueError("Missing required field: project_id")
    if not testcase_id:
        raise ValueError("Missing required field: testcase_id")
    if not vulnerability_ids:
        raise ValueError("vulnerability_ids must contain at least one ID")
    payload = additional_fields.copy() if additional_fields else {}
    # list() on both sides avoids TypeError for list + tuple (or other iterables).
    candidates = list(existing_linked_vulnerabilities or []) + list(vulnerability_ids)
    # dict.fromkeys keeps first-seen order while removing duplicates.
    payload["linked_vulnerabilities"] = list(dict.fromkeys(vid for vid in candidates if vid))
    return self.update_testcase(project_id, testcase_id, payload)
534
+
535
+ def update_testcase(
536
+ self,
537
+ project_id: str,
538
+ testcase_id: str,
539
+ update_fields: Dict[str, Any]
540
+ ) -> Dict[str, Any]:
541
+ """
542
+ Update a testcase with the provided fields.
543
+
544
+ Args:
545
+ project_id (str): The project ID.
546
+ testcase_id (str): The testcase ID.
547
+ update_fields (dict): Fields to update (e.g., linked_vulnerabilities, details).
548
+
549
+ Returns:
550
+ dict: API response.
551
+ """
552
+ if not project_id:
553
+ raise ValueError("Missing required field: project_id")
554
+ if not testcase_id:
555
+ raise ValueError("Missing required field: testcase_id")
556
+ if not update_fields:
557
+ raise ValueError("update_fields cannot be empty")
558
+ endpoint = f"/api/ss/project/{project_id}/testcase/{testcase_id}"
559
+ resp = self._request("put", endpoint, json_data=update_fields)
560
+ if resp.status_code not in (200, 201):
561
+ raise RuntimeError(f"Failed to update testcase: {resp.text}")
562
+ return resp.json()
563
+
564
+ def add_findings_to_testcase(
565
+ self,
566
+ project_id: str,
567
+ testcase_id: str,
568
+ vulnerability_ids: List[str],
569
+ additional_fields: Optional[Dict[str, Any]] = None
570
+ ) -> Dict[str, Any]:
571
+ """
572
+ Fetch a testcase, merge existing linked vulnerabilities with the provided list, and update it.
573
+
574
+ Args:
575
+ project_id (str): The project ID.
576
+ testcase_id (str): The testcase ID.
577
+ vulnerability_ids (list): List of vulnerability IDs to add.
578
+ additional_fields (dict, optional): Extra fields to include (e.g., status).
579
+
580
+ Returns:
581
+ dict: API response from the update.
582
+ """
583
+ if not project_id:
584
+ raise ValueError("Missing required field: project_id")
585
+ if not testcase_id:
586
+ raise ValueError("Missing required field: testcase_id")
587
+ if not vulnerability_ids:
588
+ raise ValueError("vulnerability_ids must contain at least one ID")
589
+
590
+ testcases = self.get_testcases(project_id)
591
+ testcase = next((t for t in testcases if t.get("id") == testcase_id), None)
592
+ if not testcase:
593
+ raise RuntimeError(f"Testcase '{testcase_id}' not found in project '{project_id}'")
594
+
595
+ existing_raw = testcase.get("linked_vulnerabilities", []) or []
596
+ existing_ids: List[str] = []
597
+ for item in existing_raw:
598
+ if isinstance(item, dict) and item.get("id"):
599
+ existing_ids.append(item["id"])
600
+ elif isinstance(item, str):
601
+ existing_ids.append(item)
602
+
603
+ return self.assign_findings_to_testcase(
604
+ project_id=project_id,
605
+ testcase_id=testcase_id,
606
+ vulnerability_ids=vulnerability_ids,
607
+ existing_linked_vulnerabilities=existing_ids,
608
+ additional_fields=additional_fields,
609
+ )
610
+
611
+ def __init__(self, api_key: str, base_url: str = "https://demo.attackforge.com", dry_run: bool = False):
612
+ """
613
+ Initialize the PyAttackForgeClient.
614
+
615
+ Args:
616
+ api_key (str): Your AttackForge API key.
617
+ base_url (str, optional): The base URL for the AttackForge instance. Defaults to "https://demo.attackforge.com".
618
+ dry_run (bool, optional): If True, no real API calls are made. Defaults to False.
619
+ """
620
+ self.base_url = base_url.rstrip("/")
621
+ self.headers = {
622
+ "X-SSAPI-KEY": api_key,
623
+ "Content-Type": "application/json",
624
+ "Connection": "close"
625
+ }
626
+ self.dry_run = dry_run
627
+ self._asset_cache = None
628
+ self._project_scope_cache = {} # {project_id: set(asset_names)}
629
+ self._writeup_cache = None # Cache for all writeups
630
+
631
+ def get_all_writeups(self, force_refresh: bool = False) -> list:
632
+ """
633
+ Fetches and caches all writeups from the /api/ss/library endpoint.
634
+
635
+ Args:
636
+ force_refresh (bool): If True, refresh the cache even if it exists.
637
+
638
+ Returns:
639
+ list: List of writeup dicts.
640
+ """
641
+ if self._writeup_cache is not None and not force_refresh:
642
+ return self._writeup_cache
643
+ resp = self._request("get", "/api/ss/library")
644
+ if resp.status_code != 200:
645
+ raise RuntimeError(f"Failed to fetch writeups: {resp.text}")
646
+ data = resp.json()
647
+ # The endpoint may return a list or a dict with a key like "vulnerabilities"
648
+ if isinstance(data, dict) and "vulnerabilities" in data:
649
+ self._writeup_cache = data["vulnerabilities"]
650
+ elif isinstance(data, list):
651
+ self._writeup_cache = data
652
+ else:
653
+ # fallback: try to treat as a list of writeups
654
+ self._writeup_cache = data if isinstance(data, list) else []
655
+ return self._writeup_cache
656
+
657
+ def find_writeup_in_cache(self, title: str, library: str = "Main Library") -> str:
658
+ """
659
+ Searches the cached writeups for a writeup with the given title and library.
660
+
661
+ Args:
662
+ title (str): The title of the writeup to find.
663
+ library (str): The library name (default: "Main Library").
664
+
665
+ Returns:
666
+ str: The writeup's reference_id if found, else None.
667
+ """
668
+ writeups = self.get_all_writeups()
669
+ for w in writeups:
670
+ if w.get("title") == title and w.get("belongs_to_library", w.get("library", "")) == library:
671
+ return w.get("reference_id")
672
+ return None
673
+
674
+ def _request(
675
+ self,
676
+ method: str,
677
+ endpoint: str,
678
+ json_data: Optional[Dict[str, Any]] = None,
679
+ params: Optional[Dict[str, Any]] = None,
680
+ files: Optional[Dict[str, Any]] = None,
681
+ data: Optional[Dict[str, Any]] = None,
682
+ headers_override: Optional[Dict[str, str]] = None
683
+ ) -> Any:
684
+ url = f"{self.base_url}{endpoint}"
685
+ if self.dry_run:
686
+ logger.info("[DRY RUN] %s %s", method.upper(), url)
687
+ if json_data:
688
+ logger.info("Payload: %s", json_data)
689
+ if params:
690
+ logger.info("Params: %s", params)
691
+ if files:
692
+ logger.info("Files: %s", list(files.keys()))
693
+ if data:
694
+ logger.info("Data: %s", data)
695
+ return DummyResponse()
696
+ headers = self.headers.copy()
697
+ if files:
698
+ headers.pop("Content-Type", None)
699
+ if headers_override:
700
+ headers.update(headers_override)
701
+ return requests.request(
702
+ method,
703
+ url,
704
+ headers=headers,
705
+ json=json_data,
706
+ params=params,
707
+ files=files,
708
+ data=data
709
+ )
710
+
711
+ def get_assets(self) -> Dict[str, Dict[str, Any]]:
712
+ if self._asset_cache is None:
713
+ self._asset_cache = {}
714
+ skip, limit = 0, 500
715
+ while True:
716
+ resp = self._request("get", "/api/ss/assets", params={"skip": skip, "limit": limit})
717
+ data = resp.json()
718
+ for asset in data.get("assets", []):
719
+ name = asset.get("asset")
720
+ if name:
721
+ self._asset_cache[name] = asset
722
+ if skip + limit >= data.get("count", 0):
723
+ break
724
+ skip += limit
725
+ return self._asset_cache
726
+
727
+ def get_asset_by_name(self, name: str) -> Optional[Dict[str, Any]]:
728
+ return self.get_assets().get(name)
729
+
730
+ def create_asset(self, asset_data: Dict[str, Any]) -> Dict[str, Any]:
731
+ pass
732
+ # resp = self._request("post", "/api/ss/library/asset", json_data=asset_data)
733
+ # if resp.status_code in (200, 201):
734
+ # asset = resp.json()
735
+ # self._asset_cache = None # Invalidate cache
736
+ # return asset
737
+ # if "Asset Already Exists" in resp.text:
738
+ # return self.get_asset_by_name(asset_data["name"])
739
+ # raise RuntimeError(f"Asset creation failed: {resp.text}")
740
+
741
+ def get_project_by_name(self, name: str) -> Optional[Dict[str, Any]]:
742
+ params = {
743
+ "startDate": "2000-01-01T00:00:00.000Z",
744
+ "endDate": "2100-01-01T00:00:00.000Z",
745
+ "status": "All"
746
+ }
747
+ resp = self._request("get", "/api/ss/projects", params=params)
748
+ for proj in resp.json().get("projects", []):
749
+ if proj.get("project_name") == name:
750
+ return proj
751
+ return None
752
+
753
+ def get_project_scope(self, project_id: str) -> Set[str]:
754
+ if project_id in self._project_scope_cache:
755
+ return self._project_scope_cache[project_id]
756
+
757
+ resp = self._request("get", f"/api/ss/project/{project_id}")
758
+ if resp.status_code != 200:
759
+ raise RuntimeError(f"Failed to retrieve project: {resp.text}")
760
+
761
+ scope = set(resp.json().get("scope", []))
762
+ self._project_scope_cache[project_id] = scope
763
+ return scope
764
+
765
+ def update_project_scope(self, project_id: str, new_assets: List[str]) -> Dict[str, Any]:
766
+ current_scope = self.get_project_scope(project_id)
767
+ updated_scope = list(current_scope.union(new_assets))
768
+ resp = self._request("put", f"/api/ss/project/{project_id}", json_data={"scope": updated_scope})
769
+ if resp.status_code not in (200, 201):
770
+ raise RuntimeError(f"Failed to update project scope: {resp.text}")
771
+ self._project_scope_cache[project_id] = set(updated_scope)
772
+ return resp.json()
773
+
774
+ def create_project(self, name: str, **kwargs) -> Dict[str, Any]:
775
+ start, end = get_default_dates()
776
+ payload = {
777
+ "name": name,
778
+ "code": kwargs.get("code", "DEFAULT"),
779
+ "groups": kwargs.get("groups", []),
780
+ "startDate": kwargs.get("startDate", start),
781
+ "endDate": kwargs.get("endDate", end),
782
+ "scope": kwargs.get("scope", []),
783
+ "testsuites": kwargs.get("testsuites", []),
784
+ "organization_code": kwargs.get("organization_code", "ORG_DEFAULT"),
785
+ "vulnerability_code": kwargs.get("vulnerability_code", "VULN_"),
786
+ "scoringSystem": kwargs.get("scoringSystem", "CVSSv3.1"),
787
+ "team_notifications": kwargs.get("team_notifications", []),
788
+ "admin_notifications": kwargs.get("admin_notifications", []),
789
+ "custom_fields": kwargs.get("custom_fields", []),
790
+ "asset_library_ids": kwargs.get("asset_library_ids", []),
791
+ "sla_activation": kwargs.get("sla_activation", "automatic")
792
+ }
793
+ resp = self._request("post", "/api/ss/project", json_data=payload)
794
+ if resp.status_code in (200, 201):
795
+ return resp.json()
796
+ raise RuntimeError(f"Project creation failed: {resp.text}")
797
+
798
+ def update_project(self, project_id: str, update_fields: Dict[str, Any]) -> Dict[str, Any]:
799
+ resp = self._request("put", f"/api/ss/project/{project_id}", json_data=update_fields)
800
+ if resp.status_code in (200, 201):
801
+ return resp.json()
802
+ raise RuntimeError(f"Project update failed: {resp.text}")
803
+
804
+ def create_writeup(
805
+ self,
806
+ title: str,
807
+ description: str,
808
+ remediation_recommendation: str,
809
+ custom_fields: Optional[list] = None,
810
+ **kwargs
811
+ ) -> Dict[str, Any]:
812
+ if not title or not description or not remediation_recommendation:
813
+ raise ValueError("Missing required field: title, description, or remediation_recommendation")
814
+
815
+ payload = {
816
+ "title": title,
817
+ "description": description,
818
+ "remediation_recommendation": remediation_recommendation,
819
+ "custom_fields": custom_fields or []
820
+ }
821
+ payload.update(kwargs)
822
+ resp = self._request("post", "/api/ss/library/vulnerability", json_data=payload)
823
+ if resp.status_code in (200, 201):
824
+ result = resp.json()
825
+ print("DEBUG: create_writeup API response:", result)
826
+ return result
827
+ raise RuntimeError(f"Writeup creation failed: {resp.text}")
828
+
829
+ def create_finding_from_writeup(
830
+ self,
831
+ project_id: str,
832
+ writeup_id: str,
833
+ priority: str,
834
+ affected_assets: Optional[list] = None,
835
+ linked_testcases: Optional[list] = None,
836
+ **kwargs
837
+ ) -> Dict[str, Any]:
838
+ """
839
+ Create a finding from a writeup, supporting multiple affected assets.
840
+
841
+ Args:
842
+ project_id (str): The project ID.
843
+ writeup_id (str): The writeup/library ID.
844
+ priority (str): The priority.
845
+ affected_assets (list, optional): List of affected asset objects or names.
846
+ linked_testcases (list, optional): List of testcase IDs to link.
847
+ **kwargs: Additional fields.
848
+
849
+ Returns:
850
+ dict: Created finding details.
851
+ """
852
+ if not project_id or not writeup_id or not priority:
853
+ raise ValueError("Missing required field: project_id, writeup_id, or priority")
854
+
855
+ payload = {
856
+ "projectId": project_id,
857
+ "vulnerabilityLibraryId": writeup_id,
858
+ "priority": priority
859
+ }
860
+ if affected_assets is not None:
861
+ asset_names = [
862
+ asset["assetName"] if isinstance(asset, dict) and "assetName" in asset
863
+ else asset["name"] if isinstance(asset, dict) and "name" in asset
864
+ else asset
865
+ for asset in affected_assets
866
+ ]
867
+ payload["affected_assets"] = [{"assetName": n} for n in asset_names]
868
+ if linked_testcases:
869
+ payload["linked_testcases"] = linked_testcases
870
+ payload.update(kwargs)
871
+ resp = self._request("post", "/api/ss/vulnerability-with-library", json_data=payload)
872
+ if resp.status_code in (200, 201):
873
+ return resp.json()
874
+ raise RuntimeError(f"Finding creation from writeup failed: {resp.text}")
875
+
876
+ def create_vulnerability(
877
+ self,
878
+ project_id: str,
879
+ title: str,
880
+ affected_assets: list,
881
+ priority: str,
882
+ likelihood_of_exploitation: int,
883
+ description: str,
884
+ attack_scenario: str,
885
+ remediation_recommendation: str,
886
+ steps_to_reproduce: str,
887
+ writeup_id: Optional[str] = None,
888
+ tags: Optional[list] = None,
889
+ notes: Optional[list] = None,
890
+ is_zeroday: bool = False,
891
+ is_visible: bool = True,
892
+ import_to_library: Optional[str] = None,
893
+ import_source: Optional[str] = None,
894
+ import_source_id: Optional[str] = None,
895
+ custom_fields: Optional[list] = None,
896
+ linked_testcases: Optional[list] = None,
897
+ custom_tags: Optional[list] = None,
898
+ writeup_custom_fields: Optional[list] = None,
899
+ ) -> Dict[str, Any]:
900
+ """
901
+ Create a new security finding (vulnerability) in AttackForge with support for multiple assets.
902
+
903
+ Args:
904
+ project_id (str): The project ID.
905
+ title (str): The title of the finding.
906
+ affected_assets (list): List of affected asset objects or names.
907
+ priority (str): The priority (e.g., "Critical").
908
+ likelihood_of_exploitation (int): Likelihood of exploitation (e.g., 10).
909
+ description (str): Description of the finding.
910
+ attack_scenario (str): Attack scenario details.
911
+ remediation_recommendation (str): Remediation recommendation.
912
+ steps_to_reproduce (str): Steps to reproduce the finding.
913
+ writeup_id (str, optional): Existing writeup/library reference ID to use directly.
914
+ tags (list, optional): List of tags.
915
+ notes (list, optional): List of notes.
916
+ is_zeroday (bool, optional): Whether this is a zero-day finding.
917
+ is_visible (bool, optional): Whether the finding is visible.
918
+ import_to_library (str, optional): Library to import to.
919
+ import_source (str, optional): Source of import.
920
+ import_source_id (str, optional): Source ID for import.
921
+ custom_fields (list, optional): List of custom fields.
922
+ linked_testcases (list, optional): List of linked testcases.
923
+ custom_tags (list, optional): List of custom tags.
924
+ writeup_custom_fields (list, optional): List of custom fields for the writeup.
925
+
926
+ Returns:
927
+ dict: Created vulnerability details.
928
+ """
929
+ # Ensure all assets exist and are in project scope
930
+ asset_names = []
931
+ for asset in affected_assets:
932
+ name = asset["assetName"] if isinstance(asset, dict) and "assetName" in asset \
933
+ else asset["name"] if isinstance(asset, dict) and "name" in asset \
934
+ else asset
935
+ self.get_asset_by_name(name)
936
+ # if not asset_obj:
937
+ # asset_obj = self.create_asset({"name": name})
938
+ asset_names.append(name)
939
+ # Ensure all assets are in project scope
940
+ scope = self.get_project_scope(project_id)
941
+ missing_in_scope = [n for n in asset_names if n not in scope]
942
+ if missing_in_scope:
943
+ self.update_project_scope(project_id, missing_in_scope)
944
+
945
+ finding_payload = {
946
+ "affected_assets": [{"assetName": n} for n in asset_names],
947
+ "likelihood_of_exploitation": likelihood_of_exploitation,
948
+ "steps_to_reproduce": steps_to_reproduce,
949
+ "tags": tags or [],
950
+ "is_zeroday": is_zeroday,
951
+ "is_visible": is_visible,
952
+ "import_to_library": import_to_library,
953
+ "import_source": import_source,
954
+ "import_source_id": import_source_id,
955
+ "custom_fields": custom_fields or [],
956
+ "linked_testcases": linked_testcases or [],
957
+ "custom_tags": custom_tags or [],
958
+ }
959
+ if notes:
960
+ finding_payload["notes"] = notes
961
+ finding_payload = {k: v for k, v in finding_payload.items() if v is not None}
962
+ resolved_writeup_id = writeup_id
963
+ if not resolved_writeup_id:
964
+ self.get_all_writeups()
965
+ resolved_writeup_id = self.find_writeup_in_cache(title, "Main Vulnerabilities")
966
+ if not resolved_writeup_id:
967
+ writeup_fields = writeup_custom_fields[:] if writeup_custom_fields else []
968
+ if import_source:
969
+ writeup_fields.append({"key": "import_source", "value": import_source})
970
+ self.create_writeup(
971
+ title=title,
972
+ description=description,
973
+ remediation_recommendation=remediation_recommendation,
974
+ attack_scenario=attack_scenario,
975
+ custom_fields=writeup_fields
976
+ )
977
+ # Refresh the cache and search again
978
+ self.get_all_writeups(force_refresh=True)
979
+ resolved_writeup_id = self.find_writeup_in_cache(
980
+ title, "Main Vulnerabilities"
981
+ )
982
+ if not resolved_writeup_id:
983
+ raise RuntimeError(
984
+ "Writeup creation failed: missing reference_id"
985
+ )
986
+ result = self.create_finding_from_writeup(
987
+ project_id=project_id,
988
+ writeup_id=resolved_writeup_id,
989
+ priority=priority,
990
+ **finding_payload
991
+ )
992
+ return result
993
+
994
+ def create_vulnerability_old(
995
+ self,
996
+ project_id: str,
997
+ title: str,
998
+ affected_asset_name: str,
999
+ priority: str,
1000
+ likelihood_of_exploitation: int,
1001
+ description: str,
1002
+ attack_scenario: str,
1003
+ remediation_recommendation: str,
1004
+ steps_to_reproduce: str,
1005
+ tags: Optional[list] = None,
1006
+ notes: Optional[list] = None,
1007
+ is_zeroday: bool = False,
1008
+ is_visible: bool = True,
1009
+ import_to_library: Optional[str] = None,
1010
+ import_source: Optional[str] = None,
1011
+ import_source_id: Optional[str] = None,
1012
+ custom_fields: Optional[list] = None,
1013
+ linked_testcases: Optional[list] = None,
1014
+ custom_tags: Optional[list] = None,
1015
+ ) -> Dict[str, Any]:
1016
+ """
1017
+ [DEPRECATED] Create a new security finding (vulnerability) in AttackForge.
1018
+
1019
+ Args:
1020
+ project_id (str): The project ID.
1021
+ title (str): The title of the finding.
1022
+ affected_asset_name (str): The name of the affected asset.
1023
+ priority (str): The priority (e.g., "Critical").
1024
+ likelihood_of_exploitation (int): Likelihood of exploitation (e.g., 10).
1025
+ description (str): Description of the finding.
1026
+ attack_scenario (str): Attack scenario details.
1027
+ remediation_recommendation (str): Remediation recommendation.
1028
+ steps_to_reproduce (str): Steps to reproduce the finding.
1029
+ tags (list, optional): List of tags.
1030
+ notes (list, optional): List of notes.
1031
+ is_zeroday (bool, optional): Whether this is a zero-day finding.
1032
+ is_visible (bool, optional): Whether the finding is visible.
1033
+ import_to_library (str, optional): Library to import to.
1034
+ import_source (str, optional): Source of import.
1035
+ import_source_id (str, optional): Source ID for import.
1036
+ custom_fields (list, optional): List of custom fields.
1037
+ linked_testcases (list, optional): List of linked testcases.
1038
+ custom_tags (list, optional): List of custom tags.
1039
+
1040
+ Returns:
1041
+ dict: Created vulnerability details.
1042
+
1043
+ Raises:
1044
+ ValueError: If any required field is missing.
1045
+ RuntimeError: If vulnerability creation fails.
1046
+ """
1047
+ # Validate required fields
1048
+ required_fields = [
1049
+ ("project_id", project_id),
1050
+ ("title", title),
1051
+ ("affected_asset_name", affected_asset_name),
1052
+ ("priority", priority),
1053
+ ("likelihood_of_exploitation", likelihood_of_exploitation),
1054
+ ("description", description),
1055
+ ("attack_scenario", attack_scenario),
1056
+ ("remediation_recommendation", remediation_recommendation),
1057
+ ("steps_to_reproduce", steps_to_reproduce),
1058
+ ]
1059
+ for field_name, value in required_fields:
1060
+ if value is None:
1061
+ raise ValueError(f"Missing required field: {field_name}")
1062
+
1063
+ payload = {
1064
+ "projectId": project_id,
1065
+ "title": title,
1066
+ "affected_asset_name": affected_asset_name,
1067
+ "priority": priority,
1068
+ "likelihood_of_exploitation": likelihood_of_exploitation,
1069
+ "description": description,
1070
+ "attack_scenario": attack_scenario,
1071
+ "remediation_recommendation": remediation_recommendation,
1072
+ "steps_to_reproduce": steps_to_reproduce,
1073
+ "tags": tags or [],
1074
+ "is_zeroday": is_zeroday,
1075
+ "is_visible": is_visible,
1076
+ "import_to_library": import_to_library,
1077
+ "import_source": import_source,
1078
+ "import_source_id": import_source_id,
1079
+ "custom_fields": custom_fields or [],
1080
+ "linked_testcases": linked_testcases or [],
1081
+ "custom_tags": custom_tags or [],
1082
+ }
1083
+ if notes:
1084
+ payload["notes"] = notes
1085
+ payload = {k: v for k, v in payload.items() if v is not None}
1086
+ resp = self._request("post", "/api/ss/vulnerability", json_data=payload)
1087
+ if resp.status_code in (200, 201):
1088
+ return resp.json()
1089
+ raise RuntimeError(f"Vulnerability creation failed: {resp.text}")
1090
+
1091
+
1092
class DummyResponse:
    """
    Stand-in response returned by ``PyAttackForgeClient._request`` in dry-run mode.

    It always reports HTTP 200, carries a fixed explanatory ``text``, and
    ``json()`` yields an empty dict.
    """

    def __init__(self) -> None:
        self.text = "[DRY RUN] No real API call performed."
        self.status_code = 200

    def json(self) -> Dict[str, Any]:
        """Return a new empty dict each call, mimicking an empty JSON object body."""
        return dict()
1099
+
1100
+
1101
def get_default_dates() -> Tuple[str, str]:
    """
    Return ``(start, end)`` UTC timestamps spanning 30 days from now.

    Both strings are ISO-8601 with millisecond precision and a trailing ``Z``
    in place of ``+00:00`` (e.g. ``2024-01-01T12:00:00.000Z``).
    """
    def _as_zulu(moment: datetime) -> str:
        return moment.isoformat(timespec="milliseconds").replace("+00:00", "Z")

    begin = datetime.now(timezone.utc)
    finish = begin + timedelta(days=30)
    return _as_zulu(begin), _as_zulu(finish)