pyattackforge 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pyattackforge/client.py CHANGED
@@ -11,6 +11,7 @@ from .transport import AttackForgeTransport
11
11
  from .resources import (
12
12
  AssetsResource,
13
13
  ProjectsResource,
14
+ GroupsResource,
14
15
  FindingsResource,
15
16
  WriteupsResource,
16
17
  TestcasesResource,
@@ -51,6 +52,7 @@ class AttackForgeClient:
51
52
 
52
53
  self.assets = AssetsResource(self._transport)
53
54
  self.projects = ProjectsResource(self._transport)
55
+ self.groups = GroupsResource(self._transport)
54
56
  self.findings = FindingsResource(self._transport)
55
57
  self.writeups = WriteupsResource(self._transport)
56
58
  self.testcases = TestcasesResource(self._transport)
pyattackforge/config.py CHANGED
@@ -16,7 +16,7 @@ class ClientConfig:
16
16
  timeout: float = 30.0
17
17
  max_retries: int = 3
18
18
  backoff_factor: float = 0.5
19
- user_agent: str = "pyattackforge/0.0.1"
19
+ user_agent: str = "pyattackforge/0.2.2"
20
20
  http2: bool = True
21
21
  # Default visibility for newly created findings. False = pending/hidden.
22
22
  default_findings_visible: bool = False
@@ -2,6 +2,7 @@
2
2
 
3
3
  from .assets import AssetsResource
4
4
  from .projects import ProjectsResource
5
+ from .groups import GroupsResource
5
6
  from .findings import FindingsResource
6
7
  from .writeups import WriteupsResource
7
8
  from .testcases import TestcasesResource
@@ -13,6 +14,7 @@ from .reports import ReportsResource
13
14
  __all__ = [
14
15
  "AssetsResource",
15
16
  "ProjectsResource",
17
+ "GroupsResource",
16
18
  "FindingsResource",
17
19
  "WriteupsResource",
18
20
  "TestcasesResource",
@@ -5,6 +5,7 @@ from __future__ import annotations
5
5
  from typing import Any, Dict, Iterable, List, Optional, Sequence
6
6
  import os
7
7
  from datetime import datetime
8
+ import time
8
9
 
9
10
  from ..cache import TTLCache
10
11
  from .base import BaseResource
@@ -247,18 +248,29 @@ class FindingsResource(BaseResource):
247
248
  return {"action": "noop", "existing": match}
248
249
 
249
250
  merged = sorted(existing_assets.union(new_assets))
250
- asset_ids = self._resolve_project_asset_ids(project_id, merged)
251
- payload_assets = []
252
- for name in merged:
253
- asset_id = asset_ids.get(name)
254
- if asset_id:
255
- payload_assets.append({"assetId": asset_id})
256
- else:
257
- payload_assets.append({"assetName": name})
258
- payload = {"affected_assets": payload_assets}
251
+ asset_ids = {}
252
+ asset_ids.update(self._extract_asset_id_map_from_finding(match))
253
+ asset_ids.update(self._map_names_to_ids_from_payload(affected_assets, create_payload))
259
254
  if update_payload:
260
- payload.update(update_payload)
261
- result = self.update_vulnerability(match.get("vulnerability_id") or match.get("id"), payload)
255
+ asset_ids.update(self._map_names_to_ids_from_payload(affected_assets, update_payload))
256
+ resolved = self._resolve_project_asset_ids_with_retry(project_id, merged)
257
+ asset_ids.update(resolved)
258
+ payload_assets = self._build_affected_assets_payload(merged, asset_ids)
259
+ payload = dict(update_payload or {})
260
+ payload.pop("affected_assets", None)
261
+ payload.pop("vulnerability_affected_assets", None)
262
+ payload.pop("projectId", None)
263
+ payload.pop("project_id", None)
264
+ payload["affected_assets"] = payload_assets
265
+ vuln_id = match.get("vulnerability_id") or match.get("id")
266
+ result = self.update_vulnerability(vuln_id, payload)
267
+ if missing:
268
+ result = self._ensure_assets_on_vulnerability(
269
+ vulnerability_id=vuln_id,
270
+ project_id=project_id,
271
+ desired_assets=merged,
272
+ asset_ids=asset_ids,
273
+ )
262
274
  return {"action": "update", "result": result, "added_assets": missing}
263
275
 
264
276
  def _normalize_title(self, title: str) -> str:
@@ -376,11 +388,143 @@ class FindingsResource(BaseResource):
376
388
  if not isinstance(entry, dict):
377
389
  continue
378
390
  name = entry.get("name")
379
- asset_id = entry.get("asset_id") or entry.get("assetId")
391
+ asset_id = entry.get("id") or entry.get("asset_id") or entry.get("assetId")
380
392
  if isinstance(name, str) and isinstance(asset_id, str):
393
+ stripped = name.strip()
394
+ if stripped:
395
+ mapping[stripped] = asset_id
396
+ return mapping
397
+
398
+ def _resolve_project_asset_ids_with_retry(
399
+ self,
400
+ project_id: str,
401
+ asset_names: Sequence[str],
402
+ *,
403
+ attempts: int = 3,
404
+ delay: float = 0.5,
405
+ ) -> Dict[str, str]:
406
+ seen: set = set()
407
+ unique_names: List[str] = []
408
+ for value in asset_names:
409
+ if not isinstance(value, str):
410
+ continue
411
+ name = value.strip()
412
+ if not name or name in seen:
413
+ continue
414
+ seen.add(name)
415
+ unique_names.append(name)
416
+ if not unique_names:
417
+ return {}
418
+ mapping = self._resolve_project_asset_ids(project_id, unique_names)
419
+ missing = [name for name in unique_names if name not in mapping]
420
+ if not missing or attempts <= 1:
421
+ return mapping
422
+ for _ in range(attempts - 1):
423
+ time.sleep(delay)
424
+ refreshed = self._resolve_project_asset_ids(project_id, missing)
425
+ mapping.update(refreshed)
426
+ missing = [name for name in missing if name not in mapping]
427
+ if not missing:
428
+ break
429
+ return mapping
430
+
431
+ def _build_affected_assets_payload(self, asset_names: Sequence[str], asset_ids: Dict[str, str]) -> List[Dict[str, Any]]:
432
+ payload_assets: List[Dict[str, Any]] = []
433
+ for name in asset_names:
434
+ if not isinstance(name, str):
435
+ continue
436
+ stripped = name.strip()
437
+ if not stripped:
438
+ continue
439
+ asset_id = asset_ids.get(stripped)
440
+ if asset_id:
441
+ payload_assets.append({"assetId": asset_id})
442
+ else:
443
+ payload_assets.append({"assetName": stripped})
444
+ return payload_assets
445
+
446
+ def _ensure_assets_on_vulnerability(
447
+ self,
448
+ *,
449
+ vulnerability_id: str,
450
+ project_id: str,
451
+ desired_assets: Sequence[str],
452
+ asset_ids: Dict[str, str],
453
+ attempts: int = 3,
454
+ delay: float = 1.0,
455
+ ) -> Any:
456
+ last_result: Any = None
457
+ for _ in range(max(attempts, 1)):
458
+ try:
459
+ vuln = self.get_vulnerability(vulnerability_id)
460
+ except Exception:
461
+ break
462
+ current_assets = self._extract_asset_names_from_finding(vuln)
463
+ missing = [name for name in desired_assets if name not in current_assets]
464
+ if not missing:
465
+ return last_result if last_result is not None else vuln
466
+ refreshed = self._resolve_project_asset_ids_with_retry(
467
+ project_id, missing, attempts=2, delay=delay
468
+ )
469
+ asset_ids.update(refreshed)
470
+ payload_assets = self._build_affected_assets_payload(desired_assets, asset_ids)
471
+ payload = {"affected_assets": payload_assets}
472
+ last_result = self.update_vulnerability(vulnerability_id, payload)
473
+ time.sleep(delay)
474
+ return last_result
475
+
476
+ def _extract_asset_id_map_from_finding(self, finding: Dict[str, Any]) -> Dict[str, str]:
477
+ raw = finding.get("vulnerability_affected_assets") or finding.get("affected_assets") or []
478
+ return self._extract_asset_id_map(raw if isinstance(raw, list) else [])
479
+
480
+ def _extract_asset_id_map(self, assets: Sequence[Any]) -> Dict[str, str]:
481
+ mapping: Dict[str, str] = {}
482
+ for item in assets:
483
+ if not isinstance(item, dict):
484
+ continue
485
+ name = None
486
+ asset_id = None
487
+ for key in ("assetName", "name"):
488
+ value = item.get(key)
489
+ if isinstance(value, str) and value.strip():
490
+ name = value.strip()
491
+ break
492
+ asset_obj = item.get("asset")
493
+ if isinstance(asset_obj, dict):
494
+ obj_name = asset_obj.get("name")
495
+ if isinstance(obj_name, str) and obj_name.strip():
496
+ name = obj_name.strip()
497
+ obj_id = asset_obj.get("id") or asset_obj.get("asset_id") or asset_obj.get("assetId")
498
+ if isinstance(obj_id, str) and obj_id.strip():
499
+ asset_id = obj_id.strip()
500
+ for key in ("assetId", "asset_id", "id"):
501
+ value = item.get(key)
502
+ if isinstance(value, str) and value.strip():
503
+ asset_id = value.strip()
504
+ break
505
+ if name and asset_id:
381
506
  mapping[name] = asset_id
382
507
  return mapping
383
508
 
509
+ def _map_names_to_ids_from_payload(
510
+ self, asset_names: Sequence[Any], payload: Optional[Dict[str, Any]]
511
+ ) -> Dict[str, str]:
512
+ if not isinstance(payload, dict):
513
+ return {}
514
+ payload_assets = payload.get("affected_assets") or payload.get("vulnerability_affected_assets") or []
515
+ if not isinstance(payload_assets, list):
516
+ return {}
517
+ mapping = self._extract_asset_id_map(payload_assets)
518
+ names_list = [name for name in asset_names if isinstance(name, str) and name.strip()]
519
+ if names_list and len(names_list) == len(payload_assets):
520
+ for name, entry in zip(names_list, payload_assets):
521
+ if name in mapping or not isinstance(entry, dict):
522
+ continue
523
+ asset_id = entry.get("assetId") or entry.get("asset_id") or entry.get("id")
524
+ if isinstance(asset_id, str) and asset_id.strip():
525
+ mapping[name] = asset_id.strip()
526
+ return mapping
527
+
384
528
  def _enforce_vulnerability_evidence_fifo(
385
529
  self, vulnerability_id: str, keep_last: int, *, project_id: Optional[str] = None
386
530
  ) -> List[str]:
@@ -0,0 +1,20 @@
1
+ """Resource: groups."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Dict, Optional
6
+
7
+ from .base import BaseResource
8
+
9
+
10
class GroupsResource(BaseResource):
    """Groups API resource wrapper (``/api/ss/group`` and ``/api/ss/groups``)."""

    def create_group(self, payload: Dict[str, Any]) -> Any:
        """Create a group by POSTing ``payload`` to ``/api/ss/group``."""
        return self._post("/api/ss/group", json=payload)

    def get_groups(self, params: Optional[Dict[str, Any]] = None) -> Any:
        """List groups via GET ``/api/ss/groups`` with optional query ``params``."""
        return self._get("/api/ss/groups", params=params)

    def get_group(self, group_id: str, params: Optional[Dict[str, Any]] = None) -> Any:
        """Fetch a single group via GET ``/api/ss/group/{group_id}``.

        Raises:
            ValueError: If ``group_id`` is empty. Without this check the request
                would silently target ``/api/ss/group/`` instead of a group.
        """
        if not group_id:
            raise ValueError("group_id is required")
        return self._get(f"/api/ss/group/{group_id}", params=params)
@@ -130,6 +130,46 @@ class ProjectsResource(BaseResource):
130
130
  def update_user_access_on_project(self, project_id: str, user_id: str, payload: Dict[str, Any]) -> Any:
131
131
  return self._put(f"/api/ss/project/{project_id}/access/{user_id}", json=payload)
132
132
 
133
def add_project_to_group(
    self,
    project_id: str,
    group_id: str,
    *,
    project_data: Optional[Dict[str, Any]] = None,
) -> Any:
    """
    Add a project to a group by rewriting the project's ``groups`` list.

    Steps:

    1. Read current group IDs from ``project_data``, or fetch them with
       ``get_project`` when it is omitted. An optional ``data`` envelope and
       then an optional ``project`` envelope are unwrapped.
    2. Take groups from ``groups``, falling back to ``project_groups``.
       Entries may be non-empty ID strings or dicts carrying
       ``id`` / ``group_id`` / ``_id``; anything else is ignored.
    3. De-duplicate the existing IDs in order and append ``group_id`` if absent.
    4. Send the list via ``update_project(project_id, {"groups": [...]})``.

    Raises:
        ValueError: If ``group_id`` is empty.
    """
    if not group_id:
        raise ValueError("group_id is required")
    record = self.get_project(project_id) if project_data is None else project_data
    for envelope in ("data", "project"):
        if isinstance(record, dict) and isinstance(record.get(envelope), dict):
            record = record[envelope]
    raw_groups = record.get("groups") or record.get("project_groups") or []
    # Dict keys give an insertion-ordered de-duplicated collection of IDs.
    ordered: Dict[str, None] = {}
    if isinstance(raw_groups, list):
        for entry in raw_groups:
            if isinstance(entry, dict):
                entry = entry.get("id") or entry.get("group_id") or entry.get("_id")
            if isinstance(entry, str) and entry:
                ordered.setdefault(entry, None)
    ordered.setdefault(group_id, None)
    return self.update_project(project_id, {"groups": list(ordered)})
172
+
133
173
  def extract_projects_list(self, projects_data: Any) -> List[Dict[str, Any]]:
134
174
  if not isinstance(projects_data, dict):
135
175
  return []
@@ -3,7 +3,7 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  from dataclasses import dataclass
6
- from typing import Any, Dict, Optional
6
+ from typing import Any, Dict, Optional, Iterator, Tuple
7
7
  import time
8
8
  import httpx
9
9
 
@@ -52,6 +52,42 @@ class AttackForgeTransport:
52
52
  )
53
53
  return self._http2_client
54
54
 
55
+ def _iter_file_objects(self, files: Dict[str, Any]) -> Iterator[Any]:
56
+ for value in files.values():
57
+ file_obj = None
58
+ if hasattr(value, "read"):
59
+ file_obj = value
60
+ elif isinstance(value, tuple) and len(value) >= 2 and hasattr(value[1], "read"):
61
+ file_obj = value[1]
62
+ if file_obj is not None:
63
+ yield file_obj
64
+
65
+ def _capture_file_positions(self, files: Dict[str, Any]) -> Tuple[Dict[int, Optional[int]], bool]:
66
+ positions: Dict[int, Optional[int]] = {}
67
+ reusable = True
68
+ for file_obj in self._iter_file_objects(files):
69
+ try:
70
+ if hasattr(file_obj, "seekable") and not file_obj.seekable():
71
+ reusable = False
72
+ positions[id(file_obj)] = file_obj.tell()
73
+ except Exception:
74
+ positions[id(file_obj)] = None
75
+ reusable = False
76
+ return positions, reusable
77
+
78
+ def _rewind_files(self, files: Dict[str, Any], positions: Dict[int, Optional[int]]) -> bool:
79
+ ok = True
80
+ for file_obj in self._iter_file_objects(files):
81
+ pos = positions.get(id(file_obj))
82
+ if pos is None:
83
+ ok = False
84
+ continue
85
+ try:
86
+ file_obj.seek(pos)
87
+ except Exception:
88
+ ok = False
89
+ return ok
90
+
55
91
  def request(
56
92
  self,
57
93
  method: str,
@@ -68,8 +104,16 @@ class AttackForgeTransport:
68
104
  attempt = 0
69
105
  max_retries = self._config.max_retries if retries is None else retries
70
106
  backoff = self._config.backoff_factor
107
+ file_positions: Optional[Dict[int, Optional[int]]] = None
108
+ files_reusable = True
109
+ if files:
110
+ file_positions, files_reusable = self._capture_file_positions(files)
111
+ if not files_reusable:
112
+ max_retries = 0
71
113
  while True:
72
114
  try:
115
+ if files and attempt > 0 and files_reusable:
116
+ self._rewind_files(files, file_positions or {})
73
117
  request_headers = None
74
118
  if headers:
75
119
  request_headers = headers
@@ -96,19 +140,21 @@ class AttackForgeTransport:
96
140
  continue
97
141
 
98
142
  if response.status_code >= 500 and files:
99
- try:
100
- response = self._get_http2_client().request(
101
- method=method,
102
- url=path,
103
- params=params,
104
- json=json,
105
- files=files,
106
- data=data,
107
- headers=request_headers,
108
- timeout=timeout,
109
- )
110
- except httpx.HTTPError:
111
- pass
143
+ if files_reusable:
144
+ self._rewind_files(files, file_positions or {})
145
+ try:
146
+ response = self._get_http2_client().request(
147
+ method=method,
148
+ url=path,
149
+ params=params,
150
+ json=json,
151
+ files=files,
152
+ data=data,
153
+ headers=request_headers,
154
+ timeout=timeout,
155
+ )
156
+ except httpx.HTTPError:
157
+ pass
112
158
 
113
159
  if response.status_code in (429, 500, 502, 503, 504) and attempt < max_retries:
114
160
  time.sleep(backoff * (2 ** attempt))