cve-sentinel 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,544 @@
1
+ """NVD (National Vulnerability Database) API client."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import time
7
+ from dataclasses import dataclass, field
8
+ from datetime import datetime, timezone
9
+ from pathlib import Path
10
+ from typing import Any, Dict, List, Optional
11
+
12
+ import requests
13
+
14
+ from cve_sentinel.utils.cache import Cache
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
class NVDAPIError(Exception):
    """Base error for failures while talking to the NVD API.

    Attributes:
        status_code: HTTP status code associated with the failure, or None
            when the error did not come from an HTTP response.
    """

    def __init__(self, message: str, status_code: Optional[int] = None) -> None:
        # Record the HTTP status first; the message itself is stored by
        # Exception and is what str(error) returns.
        self.status_code = status_code
        Exception.__init__(self, message)
25
+
26
+
27
class NVDRateLimitError(NVDAPIError):
    """Raised when the NVD API reports that the rate limit was exceeded.

    Behaves exactly like NVDAPIError; the separate type only lets callers
    tell throttling apart from other API failures.
    """
31
+
32
+
33
@dataclass
class CVEData:
    """A single CVE record as obtained from NVD.

    Attributes:
        cve_id: The CVE identifier (e.g., CVE-2021-44228).
        description: Description of the vulnerability.
        cvss_score: CVSS v3.x base score (0.0-10.0), if available.
        cvss_severity: Severity level (CRITICAL, HIGH, MEDIUM, LOW), if available.
        affected_cpes: List of affected CPE URIs.
        fixed_versions: List of fixed versions if known.
        references: List of reference URLs.
        published_date: Date the CVE was published.
        last_modified: Date the CVE was last modified.
    """

    cve_id: str
    description: str
    cvss_score: Optional[float]
    cvss_severity: Optional[str]
    affected_cpes: List[str]
    fixed_versions: Optional[List[str]]
    references: List[str]
    published_date: datetime
    last_modified: datetime

    def to_dict(self) -> Dict[str, Any]:
        """Return a serializable dictionary view of this record.

        The two datetime fields are rendered as ISO 8601 strings; every other
        field is stored as-is (list values are not copied).
        """
        passthrough = (
            "cve_id",
            "description",
            "cvss_score",
            "cvss_severity",
            "affected_cpes",
            "fixed_versions",
            "references",
        )
        payload: Dict[str, Any] = {name: getattr(self, name) for name in passthrough}
        payload["published_date"] = self.published_date.isoformat()
        payload["last_modified"] = self.last_modified.isoformat()
        return payload

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "CVEData":
        """Build a CVEData from a dictionary shaped like the output of to_dict().

        Raises:
            KeyError: If any expected key is missing from ``data``.
            ValueError: If a date string is not valid ISO 8601.
        """
        passthrough = (
            "cve_id",
            "description",
            "cvss_score",
            "cvss_severity",
            "affected_cpes",
            "fixed_versions",
            "references",
        )
        # Plain fields are looked up first, then the dates, so a missing key
        # is reported in field order.
        kwargs: Dict[str, Any] = {name: data[name] for name in passthrough}
        kwargs["published_date"] = datetime.fromisoformat(data["published_date"])
        kwargs["last_modified"] = datetime.fromisoformat(data["last_modified"])
        return cls(**kwargs)
87
+
88
+
89
@dataclass
class RateLimiter:
    """Sliding-window limiter used to pace NVD API requests.

    At most ``max_requests`` calls are admitted within any ``window_seconds``
    span. Defaults follow the "with API key" limit of 50 requests per
    30 seconds.
    """

    max_requests: int = 50
    window_seconds: int = 30
    request_times: List[float] = field(default_factory=list)

    def wait_if_needed(self) -> None:
        """Block until a request slot is free, then record the request."""
        while True:
            current = time.time()

            # Keep only the timestamps that still fall inside the window.
            self.request_times = [
                stamp for stamp in self.request_times if current - stamp < self.window_seconds
            ]

            if len(self.request_times) < self.max_requests:
                break

            # Window is full: sleep until the oldest entry ages out (plus a
            # small margin), then re-evaluate from scratch.
            first = self.request_times[0]
            wait_time = self.window_seconds - (current - first) + 0.1
            if wait_time > 0:
                logger.debug(f"Rate limit reached, waiting {wait_time:.2f}s")
                time.sleep(wait_time)

        self.request_times.append(time.time())
120
+
121
+
122
class NVDClient:
    """Client for NVD API 2.0.

    This client provides methods to search and retrieve CVE data from the
    National Vulnerability Database (NVD) API.

    Attributes:
        api_key: NVD API key for authentication.
        cache: Optional cache for storing API responses.
        rate_limiter: Rate limiter to respect API limits.
    """

    BASE_URL = "https://services.nvd.nist.gov/rest/json/cves/2.0"
    DEFAULT_TIMEOUT = 30
    MAX_RETRIES = 3
    RETRY_DELAY = 2

    def __init__(
        self,
        api_key: str,
        cache_dir: Optional[Path] = None,
        cache_ttl_hours: int = 24,
    ) -> None:
        """Initialize NVD client with API key.

        Args:
            api_key: NVD API key for authentication.
            cache_dir: Directory for caching responses. If None, caching is disabled.
            cache_ttl_hours: Cache time-to-live in hours.
        """
        self.api_key = api_key
        self.cache: Optional[Cache] = None
        if cache_dir:
            self.cache = Cache(cache_dir, ttl_hours=cache_ttl_hours)
        self.rate_limiter = RateLimiter()
        self._session = requests.Session()
        self._session.headers.update(
            {
                "apiKey": api_key,
                "Accept": "application/json",
            }
        )

    @staticmethod
    def _cache_key(params: Dict[str, Any]) -> str:
        """Return a cache key for ``params`` that is stable across processes.

        The built-in ``hash()`` of strings is randomized per interpreter run,
        so it cannot identify entries in a cache that outlives the process.
        A SHA-256 digest of the canonical (key-sorted) JSON form is used instead.
        """
        import hashlib
        import json

        canonical = json.dumps(params, sort_keys=True, default=str)
        digest = hashlib.sha256(canonical.encode("utf-8")).hexdigest()
        return f"nvd_{digest}"

    def _make_request(
        self,
        params: Dict[str, Any],
        use_cache: bool = True,
    ) -> Dict[str, Any]:
        """Make a request to the NVD API with retry logic.

        Timeouts, connection errors and rate-limit responses (HTTP 403/429)
        are retried up to MAX_RETRIES attempts with increasing delays. Any
        other non-200/404 status fails immediately.

        Args:
            params: Query parameters for the request.
            use_cache: Whether to use caching for this request.

        Returns:
            JSON response from the API. A 404 yields an empty result set.

        Raises:
            NVDAPIError: If the API request fails after retries, returns an
                unexpected status, or returns a body that is not valid JSON.
            NVDRateLimitError: If rate limit is exceeded and retries fail.
        """
        caching = use_cache and self.cache is not None
        cache_key = self._cache_key(params)

        # Check cache first
        if caching:
            cached = self.cache.get(cache_key)
            if cached:
                logger.debug("Cache hit for NVD query: %s", params)
                return cached

        last_error: Optional[Exception] = None
        for attempt in range(self.MAX_RETRIES):
            has_retries_left = attempt < self.MAX_RETRIES - 1

            # Every attempt is a real HTTP request, so each one (including
            # retries) must be counted against the sliding window.
            self.rate_limiter.wait_if_needed()

            try:
                response = self._session.get(
                    self.BASE_URL,
                    params=params,
                    timeout=self.DEFAULT_TIMEOUT,
                )
                status = response.status_code

                if status == 200:
                    try:
                        data = response.json()
                    except ValueError as exc:
                        # Surface a malformed body through the documented
                        # exception type instead of a raw decode error.
                        raise NVDAPIError(
                            f"NVD API returned invalid JSON: {exc}",
                            status_code=status,
                        ) from exc
                    # Cache successful response
                    if caching:
                        self.cache.set(cache_key, data)
                    return data

                if status in (403, 429):
                    raise NVDRateLimitError(
                        "NVD API rate limit exceeded",
                        status_code=status,
                    )

                if status == 404:
                    # No results found - return empty response
                    return {"vulnerabilities": [], "totalResults": 0}

                raise NVDAPIError(
                    f"NVD API error: {status} - {response.text}",
                    status_code=status,
                )

            except requests.exceptions.Timeout as e:
                last_error = NVDAPIError(f"Request timeout: {e}")
                logger.warning(
                    "NVD API timeout (attempt %d/%d)", attempt + 1, self.MAX_RETRIES
                )

            except requests.exceptions.ConnectionError as e:
                last_error = NVDAPIError(f"Connection error: {e}")
                logger.warning(
                    "NVD API connection error (attempt %d/%d)",
                    attempt + 1,
                    self.MAX_RETRIES,
                )

            except NVDRateLimitError:
                # Wait longer for rate limit errors; give up on the last attempt.
                if not has_retries_left:
                    raise
                wait_time = self.RETRY_DELAY * (attempt + 1) * 2
                logger.warning("Rate limited, waiting %ss before retry", wait_time)
                time.sleep(wait_time)

            except NVDAPIError:
                # Non-retryable API failure.
                raise

            # Wait before retry
            if has_retries_left:
                time.sleep(self.RETRY_DELAY * (attempt + 1))

        raise last_error or NVDAPIError("Unknown error after retries")

    @staticmethod
    def _parse_timestamp(raw: str) -> Optional[datetime]:
        """Parse an ISO 8601 timestamp, accepting a trailing 'Z'.

        Returns None when ``raw`` is empty or cannot be parsed.
        """
        if not raw:
            return None
        try:
            return datetime.fromisoformat(raw.replace("Z", "+00:00"))
        except ValueError:
            return None

    def _parse_cve_item(self, cve_item: Dict[str, Any]) -> "CVEData":
        """Parse a CVE item from NVD API response.

        Args:
            cve_item: Raw CVE item from API response.

        Returns:
            Parsed CVEData object.
        """
        cve = cve_item.get("cve", {})

        # Description: prefer English, otherwise fall back to the first entry.
        descriptions = cve.get("descriptions", [])
        description = ""
        for desc in descriptions:
            if desc.get("lang") == "en":
                description = desc.get("value", "")
                break
        if not description and descriptions:
            description = descriptions[0].get("value", "")

        # CVSS score and severity: prefer v3.1, fall back to v3.0. Only the
        # first metric entry of the chosen version is consulted.
        cvss_score: Optional[float] = None
        cvss_severity: Optional[str] = None
        metrics = cve.get("metrics", {})
        for metric_key in ("cvssMetricV31", "cvssMetricV30"):
            entries = metrics.get(metric_key, [])
            if entries:
                cvss_data = entries[0].get("cvssData", {})
                cvss_score = cvss_data.get("baseScore")
                cvss_severity = cvss_data.get("baseSeverity")
                break

        # Affected CPEs: every match flagged vulnerable that carries criteria.
        affected_cpes: List[str] = [
            cpe_match.get("criteria", "")
            for config in cve.get("configurations", [])
            for node in config.get("nodes", [])
            for cpe_match in node.get("cpeMatch", [])
            if cpe_match.get("vulnerable", False) and cpe_match.get("criteria", "")
        ]

        # Reference URLs, skipping entries without one.
        references: List[str] = [
            ref.get("url", "") for ref in cve.get("references", []) if ref.get("url", "")
        ]

        # Dates: a missing or unparseable publish date falls back to the
        # current UTC time; a missing or unparseable modification date falls
        # back to the publish date.
        # NOTE(review): timestamps without an offset parse as naive datetimes
        # while the fallback is timezone-aware - confirm callers never
        # compare the two.
        published_date = self._parse_timestamp(cve.get("published", ""))
        if published_date is None:
            published_date = datetime.now(timezone.utc)

        last_modified = self._parse_timestamp(cve.get("lastModified", ""))
        if last_modified is None:
            last_modified = published_date

        return CVEData(
            cve_id=cve.get("id", ""),
            description=description,
            cvss_score=cvss_score,
            cvss_severity=cvss_severity,
            affected_cpes=affected_cpes,
            fixed_versions=None,  # NVD doesn't provide this directly
            references=references,
            published_date=published_date,
            last_modified=last_modified,
        )

    def _parse_vulnerabilities(self, response: Dict[str, Any]) -> List["CVEData"]:
        """Parse every entry under "vulnerabilities", skipping ones that fail."""
        parsed: List["CVEData"] = []
        for vuln in response.get("vulnerabilities", []):
            try:
                parsed.append(self._parse_cve_item(vuln))
            except (KeyError, ValueError) as e:
                logger.warning("Failed to parse CVE item: %s", e)
        return parsed

    def search_by_keyword(
        self,
        keyword: str,
        start_index: int = 0,
        results_per_page: int = 100,
    ) -> List["CVEData"]:
        """Search CVEs by keyword.

        Args:
            keyword: Search keyword (package name, etc.).
            start_index: Starting index for pagination.
            results_per_page: Number of results per page (max 2000).

        Returns:
            List of matching CVEData objects.
        """
        params = {
            "keywordSearch": keyword,
            "startIndex": start_index,
            "resultsPerPage": min(results_per_page, 2000),
        }
        return self._parse_vulnerabilities(self._make_request(params))

    def search_by_cpe(
        self,
        cpe_name: str,
        start_index: int = 0,
        results_per_page: int = 100,
    ) -> List["CVEData"]:
        """Search CVEs by CPE name.

        Args:
            cpe_name: CPE URI string (e.g., cpe:2.3:a:vendor:product:*:*:*:*:*:*:*:*).
            start_index: Starting index for pagination.
            results_per_page: Number of results per page (max 2000).

        Returns:
            List of matching CVEData objects.
        """
        params = {
            "cpeName": cpe_name,
            "startIndex": start_index,
            "resultsPerPage": min(results_per_page, 2000),
        }
        return self._parse_vulnerabilities(self._make_request(params))

    def search_by_date_range(
        self,
        start_date: datetime,
        end_date: Optional[datetime] = None,
        start_index: int = 0,
        results_per_page: int = 100,
    ) -> List["CVEData"]:
        """Search CVEs modified within a date range.

        Args:
            start_date: Start of the date range.
            end_date: End of the date range (defaults to now).
            start_index: Starting index for pagination.
            results_per_page: Number of results per page (max 2000).

        Returns:
            List of matching CVEData objects.
        """
        if end_date is None:
            end_date = datetime.now(timezone.utc)

        # NOTE(review): the timestamps are sent without a UTC offset, so the
        # wall-clock values of start_date/end_date are used as given -
        # confirm callers pass UTC.
        timestamp_format = "%Y-%m-%dT%H:%M:%S.000"
        params = {
            "lastModStartDate": start_date.strftime(timestamp_format),
            "lastModEndDate": end_date.strftime(timestamp_format),
            "startIndex": start_index,
            "resultsPerPage": min(results_per_page, 2000),
        }
        return self._parse_vulnerabilities(self._make_request(params))

    def get_cve(self, cve_id: str) -> Optional["CVEData"]:
        """Get a specific CVE by ID.

        Args:
            cve_id: CVE identifier (e.g., CVE-2021-44228).

        Returns:
            CVEData object if found, None otherwise.
        """
        try:
            response = self._make_request({"cveId": cve_id})
        except NVDAPIError as e:
            # Defensive: treat an explicit 404 error as "not found" as well.
            if e.status_code == 404:
                return None
            raise

        vulnerabilities = response.get("vulnerabilities", [])
        if not vulnerabilities:
            return None
        return self._parse_cve_item(vulnerabilities[0])

    def get_total_results(self, keyword: str) -> int:
        """Get total number of results for a keyword search.

        The cache is bypassed so the count reflects the current API state.

        Args:
            keyword: Search keyword.

        Returns:
            Total number of matching CVEs.
        """
        params = {
            "keywordSearch": keyword,
            "startIndex": 0,
            "resultsPerPage": 1,
        }
        response = self._make_request(params, use_cache=False)
        return response.get("totalResults", 0)

    def search_all_by_keyword(
        self,
        keyword: str,
        max_results: int = 1000,
    ) -> List["CVEData"]:
        """Search all CVEs by keyword with pagination.

        Args:
            keyword: Search keyword.
            max_results: Maximum number of results to return.

        Returns:
            List of all matching CVEData objects up to max_results.
        """
        all_results: List["CVEData"] = []
        start_index = 0
        results_per_page = 100

        while len(all_results) < max_results:
            page = self.search_by_keyword(
                keyword,
                start_index=start_index,
                results_per_page=results_per_page,
            )
            if not page:
                break

            all_results.extend(page)
            start_index += len(page)

            # A short page means the result set is exhausted.
            if len(page) < results_per_page:
                break

        return all_results[:max_results]