cortexnet 3.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,415 @@
1
+ """
2
+ 权重适配层 (Weight Adapter)
3
+
4
+ 核心功能:
5
+ 将开源大模型的原生权重无损映射到 CortexNet 的模块结构。
6
+
7
+ 技术要点:
8
+ 1. 模式化权重名称匹配(支持模糊匹配)
9
+ 2. 维度自适应投影(当源模型与 CortexNet 维度不匹配时)
10
+ 3. 归一化参数自动转换(LayerNorm ↔ RMSNorm)
11
+ 4. GQA 权重扩展(num_kv_heads < num_heads 时复制 KV 投影)
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import re
17
+ import logging
18
+ from typing import Dict, List, Optional, Tuple, Any
19
+ from collections import OrderedDict
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
+ # ═══════════════════════════════════════════════════════════════
28
+ # 权重名称映射规则库
29
+ # ═══════════════════════════════════════════════════════════════
30
+
31
+ # 每个模型族定义从 HuggingFace 权重名称到 CortexNet 模块的映射规则。
32
+ # 格式: (source_pattern, cortex_pattern)
33
+ # source_pattern: HuggingFace 权重名中的子串
34
+ # cortex_pattern: CortexNet 权重名模板({layer} 会被替换为层索引)
35
+
36
# Per-family rules that translate HuggingFace checkpoint names into CortexNet
# parameter names. Each rule is ``(source_substring, cortex_template)``; the
# ``{layer}`` placeholder in a template is replaced by the layer index parsed
# from the source name. Rules are tried in list order and the first substring
# hit wins.
WEIGHT_MAPPING_RULES: Dict[str, List[Tuple[str, str]]] = {
    "llama": [
        # Embedding / output head
        ("model.embed_tokens.weight", "embed.weight"),
        ("lm_head.weight", "lm_head.weight"),
        # Per-layer mappings
        ("self_attn.q_proj.weight", "blocks.{layer}.attention.q_proj.weight"),
        ("self_attn.k_proj.weight", "blocks.{layer}.attention.k_proj.weight"),
        ("self_attn.v_proj.weight", "blocks.{layer}.attention.v_proj.weight"),
        ("self_attn.o_proj.weight", "blocks.{layer}.attention.o_proj.weight"),
        ("mlp.gate_proj.weight", "blocks.{layer}.moe.experts.0.gate_proj.weight"),
        ("mlp.up_proj.weight", "blocks.{layer}.moe.experts.0.up_proj.weight"),
        ("mlp.down_proj.weight", "blocks.{layer}.moe.experts.0.down_proj.weight"),
        ("input_layernorm.weight", "blocks.{layer}.norm1.weight"),
        ("post_attention_layernorm.weight", "blocks.{layer}.norm2.weight"),
        # Normalization biases (when the checkpoint has them)
        ("input_layernorm.bias", "blocks.{layer}.norm1.bias"),
        ("post_attention_layernorm.bias", "blocks.{layer}.norm2.bias"),
        # Final norm
        ("model.norm.weight", "final_norm.weight"),
    ],
    "qwen": [
        ("transformer.wte.weight", "embed.weight"),
        ("lm_head.weight", "lm_head.weight"),
        # Qwen2 style
        ("model.embed_tokens.weight", "embed.weight"),
        ("self_attn.q_proj.weight", "blocks.{layer}.attention.q_proj.weight"),
        ("self_attn.k_proj.weight", "blocks.{layer}.attention.k_proj.weight"),
        ("self_attn.v_proj.weight", "blocks.{layer}.attention.v_proj.weight"),
        ("self_attn.q_norm.weight", "blocks.{layer}.attention.q_norm.weight"),
        ("self_attn.k_norm.weight", "blocks.{layer}.attention.k_norm.weight"),
        ("self_attn.o_proj.weight", "blocks.{layer}.attention.o_proj.weight"),
        ("mlp.gate_proj.weight", "blocks.{layer}.moe.experts.0.gate_proj.weight"),
        ("mlp.up_proj.weight", "blocks.{layer}.moe.experts.0.up_proj.weight"),
        ("mlp.down_proj.weight", "blocks.{layer}.moe.experts.0.down_proj.weight"),
        ("input_layernorm.weight", "blocks.{layer}.norm1.weight"),
        ("post_attention_layernorm.weight", "blocks.{layer}.norm2.weight"),
        ("model.norm.weight", "final_norm.weight"),
    ],
    "glm": [
        ("transformer.embedding.word_embeddings.weight", "embed.weight"),
        ("transformer.output_layer.weight", "lm_head.weight"),
        ("self_attention.query_key_value.weight", "blocks.{layer}.attention.qkv_proj.weight"),
        ("self_attention.dense.weight", "blocks.{layer}.attention.o_proj.weight"),
        ("mlp.dense_h_to_4h.weight", "blocks.{layer}.moe.experts.0.gate_up_proj.weight"),
        ("mlp.dense_4h_to_h.weight", "blocks.{layer}.moe.experts.0.down_proj.weight"),
        ("input_layernorm.weight", "blocks.{layer}.norm1.weight"),
        ("post_attention_layernorm.weight", "blocks.{layer}.norm2.weight"),
        ("transformer.encoder.final_layernorm.weight", "final_norm.weight"),
    ],
    "baichuan": [
        ("model.embed_tokens.weight", "embed.weight"),
        ("lm_head.weight", "lm_head.weight"),
        ("self_attn.W_pack.weight", "blocks.{layer}.attention.qkv_proj.weight"),
        ("self_attn.o_proj.weight", "blocks.{layer}.attention.o_proj.weight"),
        ("mlp.gate_proj.weight", "blocks.{layer}.moe.experts.0.gate_proj.weight"),
        ("mlp.up_proj.weight", "blocks.{layer}.moe.experts.0.up_proj.weight"),
        ("mlp.down_proj.weight", "blocks.{layer}.moe.experts.0.down_proj.weight"),
        ("input_layernorm.weight", "blocks.{layer}.norm1.weight"),
        ("post_attention_layernorm.weight", "blocks.{layer}.norm2.weight"),
        ("model.norm.weight", "final_norm.weight"),
    ],
    "phi": [
        ("model.embed_tokens.weight", "embed.weight"),
        ("lm_head.weight", "lm_head.weight"),
        ("self_attn.q_proj.weight", "blocks.{layer}.attention.q_proj.weight"),
        ("self_attn.k_proj.weight", "blocks.{layer}.attention.k_proj.weight"),
        ("self_attn.v_proj.weight", "blocks.{layer}.attention.v_proj.weight"),
        ("self_attn.dense.weight", "blocks.{layer}.attention.o_proj.weight"),
        ("mlp.fc1.weight", "blocks.{layer}.moe.experts.0.gate_proj.weight"),
        ("mlp.fc2.weight", "blocks.{layer}.moe.experts.0.down_proj.weight"),
        ("input_layernorm.weight", "blocks.{layer}.norm1.weight"),
        ("input_layernorm.bias", "blocks.{layer}.norm1.bias"),
        ("post_attention_layernorm.weight", "blocks.{layer}.norm2.weight"),
        ("post_attention_layernorm.bias", "blocks.{layer}.norm2.bias"),
        ("model.final_layernorm.weight", "final_norm.weight"),
    ],
    "gemma": [
        ("model.embed_tokens.weight", "embed.weight"),
        # Gemma has no separate lm_head; it reuses embed_tokens.
        ("self_attn.q_proj.weight", "blocks.{layer}.attention.q_proj.weight"),
        ("self_attn.k_proj.weight", "blocks.{layer}.attention.k_proj.weight"),
        ("self_attn.v_proj.weight", "blocks.{layer}.attention.v_proj.weight"),
        ("self_attn.o_proj.weight", "blocks.{layer}.attention.o_proj.weight"),
        ("mlp.gate_proj.weight", "blocks.{layer}.moe.experts.0.gate_proj.weight"),
        ("mlp.up_proj.weight", "blocks.{layer}.moe.experts.0.up_proj.weight"),
        ("mlp.down_proj.weight", "blocks.{layer}.moe.experts.0.down_proj.weight"),
        ("input_layernorm.weight", "blocks.{layer}.norm1.weight"),
        ("post_attention_layernorm.weight", "blocks.{layer}.norm2.weight"),
        ("model.norm.weight", "final_norm.weight"),
    ],
    "falcon": [
        ("transformer.word_embeddings.weight", "embed.weight"),
        ("lm_head.weight", "lm_head.weight"),
        ("self_attention.query_key_value.weight", "blocks.{layer}.attention.qkv_proj.weight"),
        ("self_attention.dense.weight", "blocks.{layer}.attention.o_proj.weight"),
        ("mlp.dense_h_to_4h.weight", "blocks.{layer}.moe.experts.0.gate_proj.weight"),
        ("mlp.dense_4h_to_h.weight", "blocks.{layer}.moe.experts.0.down_proj.weight"),
        ("input_layernorm.weight", "blocks.{layer}.norm1.weight"),
        ("input_layernorm.bias", "blocks.{layer}.norm1.bias"),
        ("ln_attn.weight", "blocks.{layer}.norm2.weight"),
        ("ln_attn.bias", "blocks.{layer}.norm2.bias"),
        ("transformer.ln_f.weight", "final_norm.weight"),
    ],
}

# Model families whose checkpoints use the LLaMA naming scheme share the LLaMA
# rule list (the same list object, not a copy).
# "codelama" is a misspelling that earlier releases registered; it is kept so
# existing callers keep working, and the correctly spelled "codellama" is added.
for _alias in ["mistral", "yi", "deepseek", "internlm", "codelama", "codellama"]:
    WEIGHT_MAPPING_RULES[_alias] = WEIGHT_MAPPING_RULES["llama"]
145
+
146
+
147
+ def _extract_layer_idx(name: str) -> Optional[int]:
148
+ """从权重名称中提取层索引。"""
149
+ patterns = [
150
+ r"\.layers\.(\d+)\.", # LLaMA/Qwen/Mistral ...
151
+ r"\.h\.(\d+)\.", # GPT-style
152
+ r"\.encoder\.layers\.(\d+)\.", # ChatGLM
153
+ r"\.blocks\.(\d+)\.", # 某些自定义模型
154
+ r"\.transformer\.h\.(\d+)\.", # GPT-2 style
155
+ ]
156
+ for pat in patterns:
157
+ m = re.search(pat, name)
158
+ if m:
159
+ return int(m.group(1))
160
+ return None
161
+
162
+
163
class WeightAdapter:
    """Map an open-source model's checkpoint onto CortexNet parameter names.

    The adapter renames tensors according to the rule list of the model
    family, splits fused projections, fuses K/V for lite attention, drops
    LayerNorm biases where CortexNet uses RMSNorm, and expands grouped-query
    K/V projections to the full head count.

    Args:
        model_type: Model-family identifier (e.g. "llama", "qwen").
        config: CortexNetConfig. The adapter reads ``hidden_size``,
            ``num_heads`` and ``num_kv_heads``, plus the optional ``lite``
            and ``expand_gqa_weights`` flags.
    """

    def __init__(self, model_type: str, config: Any):
        self.model_type = model_type
        self.config = config
        self.mapping_rules = self._get_mapping_rules()
        self._unmapped: List[str] = []
        # Lite attention uses a fused kv_proj. With sharded / streaming loads
        # the K and V halves can arrive in different map_weights() calls, so
        # the half seen first is parked here until its partner shows up.
        self._pending_kv_proj: Dict[str, Dict[str, torch.Tensor]] = {}

    def _get_mapping_rules(self) -> List[Tuple[str, str]]:
        """Return the rule list for this adapter's model family.

        The family's ``weight_format`` is looked up in ``MODEL_REGISTRY``;
        an unregistered ``model_type`` is used as the format name directly.
        Unknown formats fall back to the "llama" rules with a warning. In
        lite mode the ``moe.experts.0.`` path segment of every target name is
        rewritten to ``ffn.``.
        """
        from .model_registry import MODEL_REGISTRY

        model_info = MODEL_REGISTRY.get(self.model_type)
        weight_format = model_info.weight_format if model_info else self.model_type

        rules = WEIGHT_MAPPING_RULES.get(weight_format)
        if rules is None:
            logger.warning(
                "No weight mapping rules for '%s', falling back to 'llama' format.",
                weight_format,
            )
            rules = WEIGHT_MAPPING_RULES["llama"]

        if getattr(self.config, "lite", False):
            # Build a new list so the shared module-level rules stay untouched.
            rules = [
                (src, dst.replace("moe.experts.0.", "ffn.")) for src, dst in rules
            ]

        return rules

    def map_weights(
        self,
        raw_weights: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """Translate a HuggingFace state dict into CortexNet parameter names.

        Args:
            raw_weights: HuggingFace model ``state_dict`` (possibly one shard).

        Returns:
            A CortexNet-format state dict containing only the tensors that a
            rule matched. Names that matched no rule are available afterwards
            through :meth:`get_unmapped_weights`.
        """
        cortex_weights: Dict[str, torch.Tensor] = OrderedDict()
        self._unmapped = []
        mapped_count = 0
        is_lite = getattr(self.config, "lite", False)
        lite_kv_suffixes = (".attention.k_proj.weight", ".attention.v_proj.weight")

        for raw_name, tensor in raw_weights.items():
            layer_idx = _extract_layer_idx(raw_name)

            for src_pattern, dst_pattern in self.mapping_rules:
                if src_pattern not in raw_name:
                    continue

                cortex_name = dst_pattern
                if "{layer}" in cortex_name:
                    if layer_idx is None:
                        # A per-layer rule cannot apply without a layer index;
                        # keep looking for another rule.
                        continue
                    cortex_name = cortex_name.replace("{layer}", str(layer_idx))

                if "qkv_proj" in cortex_name:
                    self._split_qkv(cortex_name, tensor, cortex_weights)
                elif "gate_up_proj" in cortex_name:
                    self._split_gate_up(cortex_name, tensor, cortex_weights)
                elif is_lite and cortex_name.endswith(lite_kv_suffixes):
                    self._merge_lite_kv(cortex_name, tensor, cortex_weights)
                else:
                    cortex_weights[cortex_name] = tensor

                mapped_count += 1
                break
            else:
                # No rule consumed this tensor.
                self._unmapped.append(raw_name)

        cortex_weights = self._convert_norm_params(cortex_weights)
        cortex_weights = self._expand_gqa_weights(cortex_weights)

        logger.info(
            "Weight mapping complete: %d mapped, %d unmapped out of %d total",
            mapped_count,
            len(self._unmapped),
            len(raw_weights),
        )

        if self._unmapped:
            logger.debug("Unmapped weights: %s...", self._unmapped[:10])

        return cortex_weights

    def _merge_lite_kv(
        self,
        cortex_name: str,
        tensor: torch.Tensor,
        weights: Dict[str, torch.Tensor],
    ) -> None:
        """Fuse separate K and V weights into one ``kv_proj.weight`` entry.

        The tensor is parked until both halves of the same attention module
        have been seen; then ``[K; V]`` is concatenated along dim 0, written
        to ``weights`` and the pending entry is released.

        Args:
            cortex_name: Target name ending in ``.attention.k_proj.weight``;
                any other name is treated as the V half.
            tensor: The K or V projection weight.
            weights: Output state dict, updated in place.
        """
        if cortex_name.endswith(".attention.k_proj.weight"):
            base = cortex_name[: -len("k_proj.weight")]
            slot = "k"
        else:
            base = cortex_name[: -len("v_proj.weight")]
            slot = "v"

        bucket = self._pending_kv_proj.setdefault(base, {})
        bucket[slot] = tensor
        if "k" in bucket and "v" in bucket:
            weights[f"{base}kv_proj.weight"] = torch.cat(
                [bucket["k"], bucket["v"]], dim=0
            )
            del self._pending_kv_proj[base]

    def _split_qkv(
        self,
        cortex_name: str,
        tensor: torch.Tensor,
        weights: Dict[str, torch.Tensor],
    ):
        """Split a fused QKV projection into Q, K and V entries.

        When the row count equals the Q + K + V sizes derived from the config
        the tensor is split at those sizes; otherwise it is cut into three
        equal chunks (e.g. Baichuan ``W_pack``). In lite mode K and V are
        handed to :meth:`_merge_lite_kv` so the result is a fused
        ``kv_proj.weight``, matching what ``map_weights`` produces for
        checkpoints with separate K/V tensors.

        Args:
            cortex_name: Target name containing ``qkv_proj``.
            tensor: The fused projection weight.
            weights: Output state dict, updated in place.
        """
        # NOTE(review): assumes the fused tensor is stored as contiguous
        # [Q; K; V] blocks along dim 0 — confirm for checkpoints that
        # interleave Q/K/V per head.
        head_dim = self.config.hidden_size // self.config.num_heads
        q_size = self.config.num_heads * head_dim
        kv_size = self.config.num_kv_heads * head_dim

        if tensor.shape[0] == q_size + 2 * kv_size:
            q, k, v = tensor.split([q_size, kv_size, kv_size], dim=0)
        else:
            q, k, v = tensor.chunk(3, dim=0)

        k_name = cortex_name.replace("qkv_proj", "k_proj")
        v_name = cortex_name.replace("qkv_proj", "v_proj")
        weights[cortex_name.replace("qkv_proj", "q_proj")] = q

        if getattr(self.config, "lite", False):
            self._merge_lite_kv(k_name, k, weights)
            self._merge_lite_kv(v_name, v, weights)
        else:
            weights[k_name] = k
            weights[v_name] = v

    def _split_gate_up(
        self,
        cortex_name: str,
        tensor: torch.Tensor,
        weights: Dict[str, torch.Tensor],
    ):
        """Split a fused gate+up projection into two equal halves along dim 0.

        The first half is stored as ``gate_proj`` and the second as
        ``up_proj``.
        """
        gate, up = tensor.chunk(2, dim=0)
        weights[cortex_name.replace("gate_up_proj", "gate_proj")] = gate
        weights[cortex_name.replace("gate_up_proj", "up_proj")] = up

    def _convert_norm_params(
        self,
        weights: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """Drop normalization biases for families registered as LayerNorm.

        CortexNet uses RMSNorm (weight only), so for a model whose registry
        entry has ``default_norm_type == "layernorm"`` every key that ends in
        ``.bias`` and contains ``norm`` is removed. ``weights`` is modified
        in place and returned.
        """
        from .model_registry import MODEL_REGISTRY

        model_info = MODEL_REGISTRY.get(self.model_type)
        if model_info and model_info.default_norm_type == "layernorm":
            norm_bias_keys = [
                k for k in weights if k.endswith(".bias") and "norm" in k
            ]
            for k in norm_bias_keys:
                logger.debug(
                    "Removing LayerNorm bias (CortexNet uses RMSNorm): %s", k
                )
                del weights[k]

        return weights

    @staticmethod
    def _repeat_kv_heads(
        tensor: torch.Tensor,
        num_kv_heads: int,
        repeat_factor: int,
    ) -> Optional[torch.Tensor]:
        """Repeat each KV head's rows ``repeat_factor`` times, head by head.

        Rows are grouped into ``num_kv_heads`` equal blocks and every block is
        duplicated next to itself (h0, h0, ..., h1, h1, ...), so that query
        head ``i`` is paired with KV head ``i // repeat_factor``.

        Returns:
            The expanded tensor, or ``None`` when the row count is not a
            multiple of ``num_kv_heads`` and no per-head grouping exists.
        """
        rows = tensor.shape[0]
        if rows % num_kv_heads:
            return None
        trailing = tuple(tensor.shape[1:])
        per_head = tensor.reshape(num_kv_heads, rows // num_kv_heads, *trailing)
        expanded = per_head.repeat_interleave(repeat_factor, dim=0)
        return expanded.reshape(rows * repeat_factor, *trailing)

    def _expand_gqa_weights(
        self,
        weights: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """Expand grouped-query K/V projections to the full head count.

        Runs only when ``config.expand_gqa_weights`` is not disabled and
        ``num_kv_heads < num_heads``. Every ``k_proj.weight`` /
        ``v_proj.weight`` whose row count is below ``num_heads * head_dim`` is
        expanded head by head. A fused ``kv_proj.weight`` is split into its K
        and V halves, each half is expanded, and the halves are re-joined.

        ``weights`` is modified in place and returned.
        """
        if not getattr(self.config, "expand_gqa_weights", True):
            return weights

        num_heads = self.config.num_heads
        num_kv_heads = self.config.num_kv_heads
        if num_kv_heads >= num_heads:
            return weights

        repeat_factor = num_heads // num_kv_heads
        head_dim = self.config.hidden_size // num_heads
        expected_full_size = num_heads * head_dim

        for key in list(weights):
            tensor = weights[key]
            if key.endswith("kv_proj.weight"):
                # Fused lite projection: K and V must be expanded separately,
                # otherwise K rows and V rows would be interleaved.
                parts = tensor.chunk(2, dim=0)
            elif "k_proj.weight" in key or "v_proj.weight" in key:
                parts = (tensor,)
            else:
                continue

            if parts[0].shape[0] >= expected_full_size:
                continue

            expanded = [
                self._repeat_kv_heads(part, num_kv_heads, repeat_factor)
                for part in parts
            ]
            if any(part is None for part in expanded):
                logger.warning(
                    "Skipping GQA expansion for %s: %d rows cannot be split "
                    "into %d KV heads",
                    key,
                    parts[0].shape[0],
                    num_kv_heads,
                )
                continue

            weights[key] = expanded[0] if len(expanded) == 1 else torch.cat(expanded, dim=0)
            logger.debug(
                "GQA weight expansion: %s %s -> %s",
                key,
                tensor.shape,
                weights[key].shape,
            )

        return weights

    def get_unmapped_weights(self) -> List[str]:
        """Return a copy of the names left unmapped by the last ``map_weights``."""
        return self._unmapped.copy()

    def verify_mapping(
        self,
        cortex_state_dict: Dict[str, torch.Tensor],
        model: nn.Module,
    ) -> Dict[str, Any]:
        """Compare mapped parameter names against a CortexNet model.

        Only names are compared; tensor shapes are not checked.

        Returns:
            A dict with sorted name lists ``"matched"``, ``"missing"`` (in the
            model but not mapped) and ``"unexpected"`` (mapped but not in the
            model), plus ``"match_ratio"``: matched names divided by the
            number of model parameters (denominator clamped to at least 1).
        """
        model_keys = set(model.state_dict().keys())
        mapped_keys = set(cortex_state_dict.keys())
        matched = model_keys & mapped_keys

        return {
            "matched": sorted(matched),
            "missing": sorted(model_keys - mapped_keys),
            "unexpected": sorted(mapped_keys - model_keys),
            "match_ratio": len(matched) / max(len(model_keys), 1),
        }
@@ -0,0 +1,195 @@
1
+ """
2
+ 对抗防御系统 (Adversarial Defense System)
3
+
4
+ 核心创新:
5
+ 多层防御机制保护模型免受对抗攻击,确保在恶意输入下仍能
6
+ 可靠运行。同时提供对抗训练工具增强模型鲁棒性。
7
+
8
+ ┌─────────────────────────────────────────────────────────────┐
9
+ │ 三层防御架构 │
10
+ ├─────────────────────────────────────────────────────────────┤
11
+ │ │
12
+ │ 第1层: 输入防御 ──── 异常检测 + 随机平滑 │
13
+ │ │ │
14
+ │ 第2层: 特征防御 ──── 特征去噪 + 鲁棒归一化 │
15
+ │ │ │
16
+ │ 第3层: 输出防御 ──── 置信度校准 + 一致性检查 │
17
+ │ │
18
+ │ 训练工具: │
19
+ │ ● FGSM 对抗训练 — 快速梯度符号攻击 │
20
+ │ ● PGD 对抗训练 — 投影梯度下降攻击 │
21
+ │ ● 随机平滑 — 概率性鲁棒性保证 │
22
+ └─────────────────────────────────────────────────────────────┘
23
+ """
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+
28
+
29
class InputShield(nn.Module):
    """Input-side guard that scores, smooths and corrects incoming features.

    * An MLP ending in a sigmoid produces one anomaly score per position.
    * In eval mode Gaussian noise scaled by ``|noise_scale|`` is added
      (randomized smoothing); training mode adds no noise.
    * A residual correction from ``denoiser`` is added, weighted by the
      anomaly score. The denoiser's last layer is zero-initialised, so the
      correction is zero for a freshly constructed module.

    Args:
        d_model: Size of the last (feature) dimension of the input.
    """

    def __init__(self, d_model: int):
        super().__init__()
        bottleneck = d_model // 4

        # Scores each feature vector in (0, 1).
        self.anomaly_detector = nn.Sequential(
            nn.Linear(d_model, bottleneck),
            nn.GELU(),
            nn.Linear(bottleneck, 1),
            nn.Sigmoid(),
        )
        # Learnable noise scale; its absolute value is used in forward().
        self.noise_scale = nn.Parameter(torch.tensor(0.01))
        # Produces the residual correction.
        self.denoiser = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        )
        final_layer = self.denoiser[-1]
        for param in (final_layer.weight, final_layer.bias):
            nn.init.zeros_(param)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` (noised in eval mode) plus the score-weighted correction."""
        # The score is computed on the clean input, before any noise is added;
        # its last dimension has size 1 and broadcasts over the features.
        score = self.anomaly_detector(x)

        if self.training:
            smoothed = x
        else:
            smoothed = x + torch.randn_like(x) * self.noise_scale.abs()

        # Higher anomaly score -> stronger correction.
        return smoothed + score * self.denoiser(smoothed)
69
+
70
+
71
class FeatureShield(nn.Module):
    """Feature-side guard: clip intermediate features, then normalise them.

    The forward pass clamps every value to ``[-clip_value, clip_value]``,
    applies ``nn.LayerNorm`` over the last dimension and multiplies by a
    learnable per-feature ``scale`` (initialised to ones).

    Args:
        d_model: Size of the last (feature) dimension.
        clip_value: Symmetric clamp bound applied before normalisation.
    """

    def __init__(self, d_model: int, clip_value: float = 10.0):
        super().__init__()
        self.clip_value = clip_value
        self.robust_norm = nn.LayerNorm(d_model)
        self.scale = nn.Parameter(torch.ones(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the clipped, layer-normalised and rescaled features."""
        bound = self.clip_value
        # Clip first so extreme values cannot dominate the normalisation.
        clipped = torch.clamp(x, min=-bound, max=bound)
        normalised = self.robust_norm(clipped)
        return normalised * self.scale
91
+
92
+
93
class OutputShield(nn.Module):
    """Output-side guard that scales features by a learned confidence.

    A small MLP ending in a sigmoid yields one confidence value in (0, 1)
    per position; the input is multiplied by it.

    Args:
        d_model: Size of the last (feature) dimension of the input.
    """

    def __init__(self, d_model: int):
        super().__init__()
        bottleneck = d_model // 4
        self.confidence_calibrator = nn.Sequential(
            nn.Linear(d_model, bottleneck),
            nn.GELU(),
            nn.Linear(bottleneck, 1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled by its confidence (broadcast over features)."""
        return x * self.confidence_calibrator(x)
108
+
109
+
110
class AdversarialShield(nn.Module):
    """Single entry point for the three defense layers.

    Each stage is either the corresponding shield module or ``nn.Identity``
    when that stage is disabled, so the ``defend_*`` methods can always be
    called.

    Args:
        d_model: Model (feature) dimension passed to each enabled shield.
        enable_input: Build an ``InputShield`` for :meth:`defend_input`.
        enable_feature: Build a ``FeatureShield`` for :meth:`defend_features`.
        enable_output: Build an ``OutputShield`` for :meth:`defend_output`.
    """

    def __init__(self, d_model: int, enable_input: bool = True,
                 enable_feature: bool = True, enable_output: bool = True):
        super().__init__()

        if enable_input:
            self.input_shield = InputShield(d_model)
        else:
            self.input_shield = nn.Identity()

        if enable_feature:
            self.feature_shield = FeatureShield(d_model)
        else:
            self.feature_shield = nn.Identity()

        if enable_output:
            self.output_shield = OutputShield(d_model)
        else:
            self.output_shield = nn.Identity()

    def defend_input(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the input-stage shield to ``x``."""
        return self.input_shield(x)

    def defend_features(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the feature-stage shield to ``x``."""
        return self.feature_shield(x)

    def defend_output(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the output-stage shield to ``x``."""
        return self.output_shield(x)
135
+
136
+
137
class AdversarialTrainer:
    """Adversarial-example generator for robustness training.

    Supports FGSM (single step) and PGD (iterated, projected) attacks in
    embedding space. The wrapped model must expose
    ``forward_from_embeddings(embeddings, labels=...)`` returning a mapping
    with a scalar ``"loss"`` entry.

    Generating an attack does not touch the ``.grad`` of the model's
    parameters: only the gradient with respect to the embeddings is computed.

    Usage:
        adv_trainer = AdversarialTrainer(model)
        for embeddings, labels in dataloader:
            adv_embeddings = adv_trainer.generate_adversarial(embeddings, labels)
            loss = model.forward_from_embeddings(adv_embeddings, labels=labels)["loss"]

    Args:
        model: Model to attack.
        epsilon: Maximum absolute perturbation per embedding element.
        attack_type: ``"fgsm"`` for FGSM; any other value selects PGD.
        pgd_steps: Number of PGD iterations.
        use_amp: Run the forward pass under CUDA autocast. Ignored (treated
            as ``False``) when CUDA is unavailable.
    """

    def __init__(self, model: nn.Module, epsilon: float = 0.01,
                 attack_type: str = "fgsm", pgd_steps: int = 3,
                 use_amp: bool = False):
        self.model = model
        self.epsilon = epsilon
        self.attack_type = attack_type
        self.pgd_steps = pgd_steps
        self.use_amp = use_amp and torch.cuda.is_available()

    @torch.enable_grad()
    def generate_adversarial(
        self, embeddings: torch.Tensor, labels: torch.Tensor
    ) -> torch.Tensor:
        """Return adversarially perturbed copies of ``embeddings``.

        Gradients are enabled for the duration of the call even when the
        caller is inside ``torch.no_grad()``. The input tensor is not
        modified and the result is detached from the autograd graph.
        """
        emb = embeddings.clone().detach().requires_grad_(True)

        if self.attack_type == "fgsm":
            return self._fgsm(emb, labels)
        return self._pgd(emb, labels)

    def _loss_gradient(self, emb: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Return d(loss)/d(emb) without accumulating into parameter grads."""
        with torch.autocast(device_type="cuda", enabled=self.use_amp):
            output = self.model.forward_from_embeddings(emb, labels=labels)
            loss = output["loss"]
        # autograd.grad differentiates w.r.t. ``emb`` only and frees the graph,
        # so the optimizer never sees gradients produced by the attack.
        (grad,) = torch.autograd.grad(loss, emb)
        return grad

    def _fgsm(self, emb: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """FGSM: one step of size ``epsilon`` along the gradient sign."""
        perturbation = self.epsilon * self._loss_gradient(emb, labels).sign()
        return (emb + perturbation).detach()

    def _pgd(self, emb: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """PGD: ``pgd_steps`` sign-gradient steps, each clamped to ±epsilon.

        With ``pgd_steps <= 0`` no attack is run and the unperturbed
        embeddings are returned.
        """
        if self.pgd_steps <= 0:
            return emb.detach()

        origin = emb.detach()
        perturbed = origin.clone()
        # Twice the uniform step so the epsilon bound is reachable early and
        # the clamp below actually constrains the later iterations.
        step_size = self.epsilon / self.pgd_steps * 2

        for _ in range(self.pgd_steps):
            perturbed = perturbed.detach().requires_grad_(True)
            grad = self._loss_gradient(perturbed, labels)
            stepped = perturbed.detach() + step_size * grad.sign()
            # Project back into the epsilon box around the original point.
            delta = (stepped - origin).clamp(-self.epsilon, self.epsilon)
            perturbed = origin + delta

        return perturbed.detach()