cve-sentinel 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,496 @@
1
+ """Vulnerability matcher for correlating packages with CVEs."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import re
7
+ from dataclasses import dataclass, field
8
+ from typing import Dict, List, Optional, Set, Tuple
9
+
10
+ from packaging.version import InvalidVersion, Version
11
+
12
+ from cve_sentinel.analyzers.base import Package
13
+ from cve_sentinel.fetchers.nvd import CVEData, NVDClient
14
+ from cve_sentinel.fetchers.osv import OSVClient, OSVVulnerability
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
@dataclass
class VulnerabilityMatch:
    """A vulnerability matched to a specific package.

    Two matches are considered equal (and hash identically) when they share
    the same ``cve_id``, package name and package ecosystem; every other
    field is descriptive and does not take part in identity.
    """

    # Primary identifier of the vulnerability.
    cve_id: str
    # The package the vulnerability was matched against.
    package: Package
    # Severity label (e.g. "HIGH"); compared case-insensitively by consumers.
    severity: str
    # CVSS base score, or None when no score is available.
    cvss_score: Optional[float]
    # Human-readable description of the vulnerability.
    description: str
    # Version that resolves the vulnerability, when known.
    fix_version: Optional[str]
    # Suggested command (or instruction) to apply the fix, when known.
    fix_command: Optional[str]
    # Locations where the package was declared: dicts with "file"/"line" keys.
    affected_files: List[Dict] = field(default_factory=list)
    # Reference URLs for the vulnerability.
    references: List[str] = field(default_factory=list)
    # Secondary (OSV) identifier, when one is recorded separately from cve_id.
    osv_id: Optional[str] = None

    def _identity(self) -> Tuple[str, str, str]:
        """Return the tuple that defines equality and hashing."""
        return (self.cve_id, self.package.name, self.package.ecosystem)

    def __hash__(self) -> int:
        return hash(self._identity())

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, VulnerabilityMatch):
            # Let Python try the reflected comparison instead of forcing
            # False; ``match == foreign`` still evaluates to False.
            return NotImplemented
        return self._identity() == other._identity()

    def to_dict(self) -> Dict:
        """Convert to dictionary for serialization.

        Returns:
            A plain dictionary with the match fields; the package is
            flattened to its name, version and ecosystem.
        """
        return {
            "cve_id": self.cve_id,
            "osv_id": self.osv_id,
            "package": {
                "name": self.package.name,
                "version": self.package.version,
                "ecosystem": self.package.ecosystem,
            },
            "severity": self.severity,
            "cvss_score": self.cvss_score,
            "description": self.description,
            "fix_version": self.fix_version,
            "fix_command": self.fix_command,
            "affected_files": self.affected_files,
            "references": self.references,
        }
64
+
65
+
66
class VersionMatcher:
    """Utility class for version parsing, comparison and range matching."""

    @staticmethod
    def parse_version(version_str: str) -> Optional[Version]:
        """Parse a version string into a Version object.

        A leading ``v`` is dropped. If the whole string is not a valid
        version, the leading dotted-numeric prefix (e.g. ``1.2.3`` in
        ``1.2.3-custom``) is tried instead.

        Args:
            version_str: Version string to parse.

        Returns:
            Version object if parseable, None otherwise (including for
            empty strings and the ``*`` wildcard).
        """
        if not version_str or version_str == "*":
            return None

        cleaned = version_str.strip()
        if cleaned.startswith("v"):
            cleaned = cleaned[1:]

        try:
            return Version(cleaned)
        except InvalidVersion:
            pass

        # Fall back to the numeric prefix of the string.
        numeric_prefix = re.match(r"^(\d+(?:\.\d+)*)", cleaned)
        if numeric_prefix:
            try:
                return Version(numeric_prefix.group(1))
            except InvalidVersion:
                pass

        logger.debug("Could not parse version: %s", cleaned)
        return None

    @staticmethod
    def is_version_affected(
        version: str,
        affected_ranges: List[Dict],
    ) -> Tuple[bool, Optional[str]]:
        """Check if a version is affected by the vulnerability.

        Each entry of ``affected_ranges`` may carry an explicit ``versions``
        list and/or ``ranges`` whose ``events`` describe one or more
        introduced/fixed intervals. All intervals of a range are honoured,
        not just the last one, and ``GIT`` ranges are skipped because their
        events are commit hashes rather than comparable versions.

        Args:
            version: The version to check.
            affected_ranges: List of affected version ranges from OSV.

        Returns:
            Tuple of (is_affected, fix_version). When ``version`` cannot be
            parsed the result is ``(True, None)``: the package is
            conservatively assumed to be affected.
        """
        parsed_version = VersionMatcher.parse_version(version)
        if parsed_version is None:
            # Can't determine, assume affected
            return True, None

        for affected in affected_ranges:
            ranges = affected.get("ranges") or []
            versions = affected.get("versions") or []

            # Explicit version list: exact string match wins immediately.
            if version in versions:
                return True, VersionMatcher._last_fixed_version(ranges)

            for version_range in ranges:
                if VersionMatcher._is_git_range(version_range.get("type", "")):
                    continue
                is_affected, fixed = VersionMatcher._evaluate_events(
                    parsed_version,
                    version_range.get("events") or [],
                )
                if is_affected:
                    return True, fixed

        return False, None

    @staticmethod
    def _is_git_range(range_type: Optional[str]) -> bool:
        """Return True for ranges whose events are commits, not versions."""
        return (range_type or "").upper() == "GIT"

    @staticmethod
    def _last_fixed_version(ranges: List[Dict]) -> Optional[str]:
        """Return the last ``fixed`` value listed in the non-GIT ranges."""
        fix_version: Optional[str] = None
        for version_range in ranges:
            if VersionMatcher._is_git_range(version_range.get("type", "")):
                # A commit hash is not an installable fix version.
                continue
            for event in version_range.get("events") or []:
                if "fixed" in event:
                    fix_version = event["fixed"]
        return fix_version

    @staticmethod
    def _evaluate_events(
        version: Version,
        events: List[Dict],
    ) -> Tuple[bool, Optional[str]]:
        """Walk a range's events in version order and decide affectedness.

        Unparseable bounds are treated conservatively: an unparseable (or
        ``"0"``) ``introduced`` means "no lower bound", an unparseable
        ``fixed``/``last_affected`` means "no upper bound". A range without
        any ``introduced`` event is treated as affected from the beginning.

        Args:
            version: Parsed version to check.
            events: The range's event dicts (``introduced``, ``fixed``,
                ``last_affected``); other keys are ignored.

        Returns:
            Tuple of (is_affected, fix_version), where fix_version is the
            ``fixed`` value closing the interval that contains ``version``.
        """
        unbounded_lower: List[Tuple[str, str, Optional[Version]]] = []
        bounded: List[Tuple[str, str, Optional[Version]]] = []
        unbounded_upper: List[Tuple[str, str, Optional[Version]]] = []
        has_introduced = False

        for event in events:
            for kind in ("introduced", "fixed", "last_affected"):
                if kind not in event:
                    continue
                raw = event[kind]
                if kind == "introduced":
                    has_introduced = True
                    bound = None if raw == "0" else VersionMatcher.parse_version(raw)
                    target = unbounded_lower if bound is None else bounded
                else:
                    bound = VersionMatcher.parse_version(raw)
                    target = unbounded_upper if bound is None else bounded
                target.append((kind, raw, bound))

        # Stable sort: events at the same version keep their listed order.
        bounded.sort(key=lambda item: item[2])

        vulnerable = not has_introduced
        fix_version: Optional[str] = None
        for kind, raw, bound in unbounded_lower + bounded + unbounded_upper:
            if kind == "introduced":
                if bound is None or version >= bound:
                    vulnerable = True
            elif bound is None:
                # Upper bound could not be parsed: cannot prove the version
                # is past it, but still report the advertised fix.
                if kind == "fixed" and vulnerable and fix_version is None:
                    fix_version = raw
            elif kind == "fixed":
                if version >= bound:
                    vulnerable = False
                elif vulnerable and fix_version is None:
                    fix_version = raw
            elif version > bound:  # last_affected is an inclusive bound
                vulnerable = False

        return vulnerable, (fix_version if vulnerable else None)

    @staticmethod
    def _is_in_range(
        version: Version,
        introduced: Optional[str],
        fixed: Optional[str],
        range_type: str,
    ) -> bool:
        """Check if version is within a single introduced/fixed interval.

        Args:
            version: Parsed version to check.
            introduced: Version where vulnerability was introduced.
            fixed: Version where vulnerability was fixed.
            range_type: Type of range (SEMVER, ECOSYSTEM, GIT).

        Returns:
            True if version is in the affected range. Always False for
            ``GIT`` ranges, whose bounds are commits and cannot be compared
            with a package version.
        """
        if VersionMatcher._is_git_range(range_type):
            return False

        events: List[Dict] = []
        if introduced:
            events.append({"introduced": introduced})
        if fixed:
            events.append({"fixed": fixed})
        in_range, _ = VersionMatcher._evaluate_events(version, events)
        return in_range

    @staticmethod
    def compare_versions(v1: str, v2: str) -> int:
        """Compare two version strings.

        Args:
            v1: First version.
            v2: Second version.

        Returns:
            -1 if v1 < v2, 0 if v1 == v2, 1 if v1 > v2. Also 0 when either
            version cannot be parsed.
        """
        parsed_v1 = VersionMatcher.parse_version(v1)
        parsed_v2 = VersionMatcher.parse_version(v2)

        if parsed_v1 is None or parsed_v2 is None:
            return 0

        if parsed_v1 < parsed_v2:
            return -1
        if parsed_v1 > parsed_v2:
            return 1
        return 0
211
+
212
+
213
+ class VulnerabilityMatcher:
214
+ """Matches detected packages against known CVEs."""
215
+
216
+ # Fix command templates by ecosystem
217
+ FIX_COMMANDS = {
218
+ "npm": "npm install {package}@{version}",
219
+ "pypi": "pip install {package}=={version}",
220
+ "go": "go get {package}@v{version}",
221
+ "maven": "Update version in pom.xml to {version}",
222
+ "rubygems": "bundle update {package}",
223
+ "crates.io": "cargo update -p {package}",
224
+ "packagist": "composer require {package}:{version}",
225
+ }
226
+
227
+ def __init__(
228
+ self,
229
+ nvd_client: Optional[NVDClient] = None,
230
+ osv_client: Optional[OSVClient] = None,
231
+ fetch_nvd_details: bool = True,
232
+ ) -> None:
233
+ """Initialize matcher with API clients.
234
+
235
+ Args:
236
+ nvd_client: NVD API client for CVE details.
237
+ osv_client: OSV API client for vulnerability queries.
238
+ fetch_nvd_details: Whether to fetch additional details from NVD.
239
+ """
240
+ self.nvd_client = nvd_client
241
+ self.osv_client = osv_client
242
+ self.fetch_nvd_details = fetch_nvd_details
243
+ self._seen_vulns: Set[Tuple[str, str, str]] = set()
244
+
245
+ def match(self, packages: List[Package]) -> List[VulnerabilityMatch]:
246
+ """Match packages against known vulnerabilities.
247
+
248
+ Args:
249
+ packages: List of packages to check.
250
+
251
+ Returns:
252
+ List of vulnerability matches.
253
+ """
254
+ if not self.osv_client:
255
+ logger.warning("No OSV client configured, skipping vulnerability matching")
256
+ return []
257
+
258
+ matches: List[VulnerabilityMatch] = []
259
+ self._seen_vulns.clear()
260
+
261
+ # Group packages by ecosystem for batch querying
262
+ packages_by_ecosystem: Dict[str, List[Package]] = {}
263
+ for pkg in packages:
264
+ if pkg.ecosystem not in packages_by_ecosystem:
265
+ packages_by_ecosystem[pkg.ecosystem] = []
266
+ packages_by_ecosystem[pkg.ecosystem].append(pkg)
267
+
268
+ # Process each package
269
+ for ecosystem, eco_packages in packages_by_ecosystem.items():
270
+ # Build batch query
271
+ batch_packages = [
272
+ {
273
+ "name": pkg.name,
274
+ "ecosystem": ecosystem,
275
+ "version": pkg.version if pkg.version != "*" else None,
276
+ }
277
+ for pkg in eco_packages
278
+ ]
279
+
280
+ try:
281
+ # Query OSV for vulnerabilities
282
+ results = self.osv_client.query_batch(batch_packages)
283
+
284
+ # Process results
285
+ for pkg in eco_packages:
286
+ pkg_key = f"{ecosystem}:{pkg.name}"
287
+ vulns = results.get(pkg_key, [])
288
+
289
+ for osv_vuln in vulns:
290
+ match = self._process_osv_vulnerability(pkg, osv_vuln)
291
+ if match:
292
+ matches.append(match)
293
+
294
+ except Exception as e:
295
+ logger.error(f"Error querying OSV for {ecosystem}: {e}")
296
+ # Fall back to individual queries
297
+ for pkg in eco_packages:
298
+ try:
299
+ vulns = self.osv_client.query(
300
+ pkg.name,
301
+ ecosystem,
302
+ pkg.version if pkg.version != "*" else None,
303
+ )
304
+ for osv_vuln in vulns:
305
+ match = self._process_osv_vulnerability(pkg, osv_vuln)
306
+ if match:
307
+ matches.append(match)
308
+ except Exception as e2:
309
+ logger.error(f"Error querying OSV for {pkg.name}: {e2}")
310
+
311
+ return matches
312
+
313
+ def _process_osv_vulnerability(
314
+ self,
315
+ package: Package,
316
+ osv_vuln: OSVVulnerability,
317
+ ) -> Optional[VulnerabilityMatch]:
318
+ """Process an OSV vulnerability and create a match.
319
+
320
+ Args:
321
+ package: The affected package.
322
+ osv_vuln: The OSV vulnerability data.
323
+
324
+ Returns:
325
+ VulnerabilityMatch if affected, None otherwise.
326
+ """
327
+ # Check version affectedness
328
+ is_affected, fix_version = VersionMatcher.is_version_affected(
329
+ package.version,
330
+ osv_vuln.affected,
331
+ )
332
+
333
+ if not is_affected:
334
+ return None
335
+
336
+ # Get CVE ID (prefer CVE over GHSA/OSV ID)
337
+ cve_ids = osv_vuln.get_cve_ids()
338
+ primary_id = cve_ids[0] if cve_ids else osv_vuln.id
339
+
340
+ # Check for duplicates
341
+ vuln_key = (primary_id, package.name, package.ecosystem)
342
+ if vuln_key in self._seen_vulns:
343
+ return None
344
+ self._seen_vulns.add(vuln_key)
345
+
346
+ # Get CVSS score and severity from OSV
347
+ cvss_score = osv_vuln.get_cvss_score()
348
+ severity = osv_vuln.get_cvss_severity()
349
+ description = osv_vuln.summary or osv_vuln.details
350
+ references = osv_vuln.references
351
+
352
+ # Try to get better data from NVD if we have a CVE ID
353
+ if cve_ids and self.nvd_client and self.fetch_nvd_details:
354
+ nvd_data = self._fetch_nvd_details(cve_ids[0])
355
+ if nvd_data:
356
+ if nvd_data.cvss_score is not None:
357
+ cvss_score = nvd_data.cvss_score
358
+ if nvd_data.cvss_severity:
359
+ severity = nvd_data.cvss_severity
360
+ if nvd_data.description:
361
+ description = nvd_data.description
362
+ if nvd_data.references:
363
+ references = list(set(references + nvd_data.references))
364
+
365
+ # Use fix version from OSV if not found in range check
366
+ if not fix_version and osv_vuln.fixed_versions:
367
+ fix_version = osv_vuln.fixed_versions[0]
368
+
369
+ # Generate fix command
370
+ fix_command = None
371
+ if fix_version:
372
+ fix_command = self.generate_fix_command(package, fix_version)
373
+
374
+ # Build affected files info
375
+ affected_files = []
376
+ if package.source_file:
377
+ affected_files.append(
378
+ {
379
+ "file": str(package.source_file),
380
+ "line": package.source_line,
381
+ }
382
+ )
383
+
384
+ return VulnerabilityMatch(
385
+ cve_id=primary_id,
386
+ osv_id=osv_vuln.id if osv_vuln.id != primary_id else None,
387
+ package=package,
388
+ severity=severity or "UNKNOWN",
389
+ cvss_score=cvss_score,
390
+ description=description[:500] if description else "",
391
+ fix_version=fix_version,
392
+ fix_command=fix_command,
393
+ affected_files=affected_files,
394
+ references=references[:5], # Limit references
395
+ )
396
+
397
+ def _fetch_nvd_details(self, cve_id: str) -> Optional[CVEData]:
398
+ """Fetch CVE details from NVD.
399
+
400
+ Args:
401
+ cve_id: CVE identifier.
402
+
403
+ Returns:
404
+ CVEData if found, None otherwise.
405
+ """
406
+ if not self.nvd_client:
407
+ return None
408
+
409
+ try:
410
+ return self.nvd_client.get_cve(cve_id)
411
+ except Exception as e:
412
+ logger.debug(f"Could not fetch NVD data for {cve_id}: {e}")
413
+ return None
414
+
415
+ def generate_fix_command(
416
+ self,
417
+ package: Package,
418
+ fix_version: str,
419
+ ) -> Optional[str]:
420
+ """Generate fix command for a package.
421
+
422
+ Args:
423
+ package: The package to fix.
424
+ fix_version: The version to upgrade to.
425
+
426
+ Returns:
427
+ Fix command string or None if not supported.
428
+ """
429
+ template = self.FIX_COMMANDS.get(package.ecosystem)
430
+ if template:
431
+ return template.format(package=package.name, version=fix_version)
432
+ return None
433
+
434
+ def match_single(self, package: Package) -> List[VulnerabilityMatch]:
435
+ """Match a single package against known vulnerabilities.
436
+
437
+ Args:
438
+ package: Package to check.
439
+
440
+ Returns:
441
+ List of vulnerability matches.
442
+ """
443
+ return self.match([package])
444
+
445
+ def get_severity_counts(
446
+ self,
447
+ matches: List[VulnerabilityMatch],
448
+ ) -> Dict[str, int]:
449
+ """Get count of vulnerabilities by severity.
450
+
451
+ Args:
452
+ matches: List of vulnerability matches.
453
+
454
+ Returns:
455
+ Dictionary mapping severity to count.
456
+ """
457
+ counts: Dict[str, int] = {
458
+ "CRITICAL": 0,
459
+ "HIGH": 0,
460
+ "MEDIUM": 0,
461
+ "LOW": 0,
462
+ "UNKNOWN": 0,
463
+ }
464
+ for match in matches:
465
+ severity = match.severity.upper() if match.severity else "UNKNOWN"
466
+ if severity in counts:
467
+ counts[severity] += 1
468
+ else:
469
+ counts["UNKNOWN"] += 1
470
+ return counts
471
+
472
+ def filter_by_severity(
473
+ self,
474
+ matches: List[VulnerabilityMatch],
475
+ min_severity: str = "LOW",
476
+ ) -> List[VulnerabilityMatch]:
477
+ """Filter matches by minimum severity.
478
+
479
+ Args:
480
+ matches: List of vulnerability matches.
481
+ min_severity: Minimum severity to include.
482
+
483
+ Returns:
484
+ Filtered list of matches.
485
+ """
486
+ severity_order = ["CRITICAL", "HIGH", "MEDIUM", "LOW", "UNKNOWN"]
487
+ try:
488
+ min_index = severity_order.index(min_severity.upper())
489
+ except ValueError:
490
+ min_index = len(severity_order) - 1
491
+
492
+ return [
493
+ m
494
+ for m in matches
495
+ if severity_order.index(m.severity.upper() if m.severity else "UNKNOWN") <= min_index
496
+ ]