cortexnet 3.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cortexnet/__init__.py +197 -0
- cortexnet/adapter/__init__.py +26 -0
- cortexnet/adapter/arch_adapter.py +209 -0
- cortexnet/adapter/calibrator.py +244 -0
- cortexnet/adapter/inference_adapter.py +272 -0
- cortexnet/adapter/model_registry.py +378 -0
- cortexnet/adapter/weight_adapter.py +415 -0
- cortexnet/adversarial.py +195 -0
- cortexnet/attention.py +520 -0
- cortexnet/blocks.py +682 -0
- cortexnet/cache.py +83 -0
- cortexnet/causal_reasoning.py +232 -0
- cortexnet/compat.py +245 -0
- cortexnet/config.py +234 -0
- cortexnet/continual_learning.py +256 -0
- cortexnet/cortex_block_lite.py +221 -0
- cortexnet/distributed.py +213 -0
- cortexnet/graph_reasoning.py +207 -0
- cortexnet/hierarchical_memory.py +360 -0
- cortexnet/interpretability.py +196 -0
- cortexnet/memory.py +179 -0
- cortexnet/meta_learning.py +187 -0
- cortexnet/model.py +1360 -0
- cortexnet/multi_agent.py +241 -0
- cortexnet/multimodal.py +278 -0
- cortexnet/ops/__init__.py +28 -0
- cortexnet/ops/device_manager.py +449 -0
- cortexnet/ops/npu_ops.py +243 -0
- cortexnet/quantization.py +496 -0
- cortexnet/routing.py +335 -0
- cortexnet/self_evolution.py +174 -0
- cortexnet/ssm.py +340 -0
- cortexnet/training_utils.py +204 -0
- cortexnet/transformer_baseline.py +157 -0
- cortexnet-3.2.1.dist-info/METADATA +114 -0
- cortexnet-3.2.1.dist-info/RECORD +39 -0
- cortexnet-3.2.1.dist-info/WHEEL +5 -0
- cortexnet-3.2.1.dist-info/licenses/LICENSE +201 -0
- cortexnet-3.2.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,221 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CortexBlockLite — 精简版 CortexNet 块
|
|
3
|
+
|
|
4
|
+
仅保留三条核心路径 + SwiGLU FFN,参数量与 Transformer 相当:
|
|
5
|
+
1. Multi-Scale SSM — O(1) 解码,长序列优势
|
|
6
|
+
2. Selective Attention — 兼容预训练权重
|
|
7
|
+
3. Synaptic Memory — 快速权重系统
|
|
8
|
+
|
|
9
|
+
相比 CortexBlockV3 移除了:
|
|
10
|
+
- MultiAgentSystem (23.7%)
|
|
11
|
+
- CollaborativeMoE → 直接 SwiGLU FFN
|
|
12
|
+
- CausalReasoning (8.0%)
|
|
13
|
+
- TaskAdaptiveController (6.3%)
|
|
14
|
+
- GraphReasoning (5.2%)
|
|
15
|
+
- MetaLearningAdapter (4.2%)
|
|
16
|
+
- AdversarialShield (2.1%)
|
|
17
|
+
|
|
18
|
+
典型参数节省:60-70%
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
from __future__ import annotations
|
|
22
|
+
|
|
23
|
+
from typing import Optional, Tuple, Any
|
|
24
|
+
|
|
25
|
+
import torch
|
|
26
|
+
import torch.nn as nn
|
|
27
|
+
import torch.nn.functional as F
|
|
28
|
+
import torch.utils.checkpoint as checkpoint
|
|
29
|
+
|
|
30
|
+
try:
|
|
31
|
+
from .blocks import RMSNorm
|
|
32
|
+
from .config import CortexNetConfig
|
|
33
|
+
from .ssm import MultiScaleSSM
|
|
34
|
+
from .attention import SelectiveSparseAttention
|
|
35
|
+
from .memory import SynapticMemory
|
|
36
|
+
except ImportError:
|
|
37
|
+
from cortexnet.blocks import RMSNorm
|
|
38
|
+
from cortexnet.config import CortexNetConfig
|
|
39
|
+
from cortexnet.ssm import MultiScaleSSM
|
|
40
|
+
from cortexnet.attention import SelectiveSparseAttention
|
|
41
|
+
from cortexnet.memory import SynapticMemory
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class SwiGLU_FFN(nn.Module):
    """SwiGLU feed-forward network built from three bias-free linear layers.

    The layer computes ``down_proj(silu(gate_proj(x)) * up_proj(x))``.
    Parameter count is ``3 * d_model * intermediate_size`` because none of
    the projections carries a bias.

    Args:
        d_model: Size of the input and output feature dimension.
        intermediate_size: Width of the gated hidden projection.
    """

    def __init__(self, d_model: int, intermediate_size: int):
        super().__init__()
        hidden = intermediate_size
        # Creation order is kept (gate, up, down) so seeded initialisation
        # and state-dict ordering stay the same.
        self.gate_proj = nn.Linear(d_model, hidden, bias=False)
        self.up_proj = nn.Linear(d_model, hidden, bias=False)
        self.down_proj = nn.Linear(hidden, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated feed-forward transform to the last dimension of ``x``."""
        gate = F.silu(self.gate_proj(x))
        value = self.up_proj(x)
        return self.down_proj(gate * value)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class CortexBlockLite(nn.Module):
    """Slimmed-down CortexNet block: SSM + attention + memory + SwiGLU FFN.

    Design:
        - Three parallel paths (SSM, attention, synaptic memory) are merged
          with three learned scalar weights passed through a softmax.
        - A plain SwiGLU feed-forward network follows the fused paths.
        - Supports cached inference (``use_cache``) and optional gradient
          checkpointing of each path during training.
        - ``ssm_only`` mode skips the attention and memory paths entirely.

    Args:
        config: CortexNetConfig-like object. The attributes read here are
            ``hidden_size``, ``use_gradient_checkpointing``, ``norm_eps``,
            ``ssm_state_size``, ``num_heads``, ``num_kv_heads``,
            ``top_k_ratio``, ``attention_k_mode``, ``max_seq_len``,
            ``rope_theta``, ``dropout``, ``sliding_window_size``,
            ``memory_dim``, ``memory_decay_init``, ``intermediate_size`` and
            ``expert_ff_dim``; ``compat_ssm_rank`` and ``use_qk_norm`` are
            optional.
        layer_idx: Index of this layer. It is only stored on the instance;
            nothing in this block consumes it.
    """

    def __init__(self, config: "CortexNetConfig", layer_idx: int = 0):
        super().__init__()
        D = config.hidden_size
        self.layer_idx = layer_idx
        self.use_checkpoint = config.use_gradient_checkpointing

        # Pre-normalisation for the mixing paths (norm1) and the FFN (norm2).
        self.norm1 = RMSNorm(D, eps=config.norm_eps)
        self.norm2 = RMSNorm(D, eps=config.norm_eps)

        # Path 1: SSM running in a reduced ``ssm_rank``-dimensional space.
        # The rank defaults to max(64, D // 4) and is capped at D; when it
        # equals D the down/up projections degenerate to identities.
        ssm_rank = getattr(config, 'compat_ssm_rank', max(64, D // 4))
        ssm_rank = min(ssm_rank, D)
        self.ssm_rank = ssm_rank
        self.ssm_down = nn.Linear(D, ssm_rank, bias=False) if ssm_rank < D else nn.Identity()
        self.ssm = MultiScaleSSM(
            d_model=ssm_rank,
            num_scales=1,
            state_size=config.ssm_state_size,
            expand_factor=1,
        )
        self.ssm_up = nn.Linear(ssm_rank, D, bias=False) if ssm_rank < D else nn.Identity()

        # Path 2: selective sparse attention at full width.
        self.attention = SelectiveSparseAttention(
            d_model=D,
            num_heads=config.num_heads,
            num_kv_heads=config.num_kv_heads,
            top_k_ratio=config.top_k_ratio,
            k_mode=config.attention_k_mode,
            max_seq_len=config.max_seq_len,
            rope_theta=config.rope_theta,
            dropout=config.dropout,
            sliding_window_size=config.sliding_window_size,
            use_qk_norm=getattr(config, "use_qk_norm", False),
            norm_eps=config.norm_eps,
        )

        # Path 3: synaptic memory. The memory width is capped at
        # max(32, D // 16) so it does not grow with large hidden sizes.
        mem_dim = min(config.memory_dim, max(32, D // 16))
        self.memory = SynapticMemory(
            d_model=D,
            memory_dim=mem_dim,
            decay_init=config.memory_decay_init,
        )

        # Fusion of the three paths: three scalars, softmax-normalised in
        # forward(). Equal initial values give uniform weights at start.
        self.fusion_weights = nn.Parameter(torch.ones(3) / 3)

        # SwiGLU FFN; falls back to ``expert_ff_dim`` when
        # ``intermediate_size`` is not a positive number.
        ff_dim = config.intermediate_size if config.intermediate_size > 0 else config.expert_ff_dim
        self.ffn = SwiGLU_FFN(D, ff_dim)

        self.dropout = nn.Dropout(config.dropout) if config.dropout > 0 else nn.Identity()

    def forward(
        self,
        x: torch.Tensor,
        past_cache: Optional[Tuple[Any, Any, Any]] = None,
        use_cache: bool = False,
        ssm_only: bool = False,
    ) -> "torch.Tensor | Tuple[torch.Tensor, Tuple[Any, Any, Any]]":
        """Run the block.

        Args:
            x: Input of shape (B, L, D).
            past_cache: Optional ``(past_attn_kv, past_ssm_state,
                past_mem_state)`` triple from a previous call. The cached
                states are only forwarded to the sub-modules when
                ``use_cache`` is True.
            use_cache: When True, also return the updated cache triple.
            ssm_only: When True, only the SSM path is evaluated; attention
                and memory are skipped and their cache entries are passed
                through unchanged.

        Returns:
            ``output`` of shape (B, L, D) when ``use_cache`` is False,
            otherwise ``(output, (attn_cache, ssm_cache, mem_cache))``.
        """
        residual = x
        x_norm = self.norm1(x)

        past_attn = past_ssm = past_mem = None
        if past_cache is not None:
            past_attn, past_ssm, past_mem = past_cache

        # SSM path (reduced-rank space).
        ssm_input = self.ssm_down(x_norm)
        if use_cache:
            ssm_out_lr, new_ssm_cache = self.ssm(
                ssm_input, past_state=past_ssm, use_cache=True,
            )
        else:
            ssm_out_lr = self._maybe_checkpoint(self.ssm, ssm_input)
            new_ssm_cache = None
        ssm_out = self.ssm_up(ssm_out_lr)

        if ssm_only:
            # Attention and memory are skipped, so their previous cache
            # entries are carried over untouched.
            new_attn_cache = past_attn
            new_mem_cache = past_mem
            x = residual + self.dropout(ssm_out)
        else:
            # Attention path.
            if use_cache:
                attn_out, new_attn_cache = self.attention(
                    x_norm, past_key_value=past_attn, use_cache=True,
                )
            else:
                attn_out = self._maybe_checkpoint(self.attention, x_norm)
                new_attn_cache = None

            # Memory path. A cached memory state may be a (memory, z) pair
            # or a bare memory object.
            if use_cache:
                if isinstance(past_mem, tuple):
                    past_memory, past_z = past_mem[0], past_mem[1]
                else:
                    past_memory, past_z = past_mem, None
                mem_out, new_mem, new_z = self.memory(
                    x_norm, past_memory=past_memory, past_z=past_z, use_cache=True,
                )
                new_mem_cache = (new_mem, new_z)
            else:
                mem_out = self._maybe_checkpoint(self.memory, x_norm)
                new_mem_cache = None

            # Softmax-weighted sum of the three path outputs.
            w = torch.softmax(self.fusion_weights, dim=0)
            fused = w[0] * ssm_out + w[1] * attn_out + w[2] * mem_out
            x = residual + self.dropout(fused)

        # Feed-forward sub-layer with its own residual connection.
        residual = x
        x = residual + self.dropout(self.ffn(self.norm2(x)))

        if use_cache:
            return x, (new_attn_cache, new_ssm_cache, new_mem_cache)
        return x

    def _maybe_checkpoint(self, module, *args):
        """Call ``module(*args)``, using gradient checkpointing when enabled.

        Checkpointing is applied only when ``use_checkpoint`` is set and the
        block is in training mode.
        """
        if self.use_checkpoint and self.training:
            return checkpoint.checkpoint(module, *args, use_reentrant=False)
        return module(*args)

    def get_aux_loss(self) -> torch.Tensor:
        """Return a scalar zero; this block has no MoE auxiliary loss.

        The zero is created from ``fusion_weights`` so it shares that
        parameter's dtype and device.
        """
        return self.fusion_weights.new_zeros(())
|
cortexnet/distributed.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CortexNet 分布式训练支持 (Distributed Training)
|
|
3
|
+
|
|
4
|
+
提供 FSDP (Fully Sharded Data Parallel) 策略和基础分布式工具,
|
|
5
|
+
为大规模训练做准备。
|
|
6
|
+
|
|
7
|
+
用法:
|
|
8
|
+
from cortexnet.distributed import setup_distributed, wrap_fsdp, cleanup_distributed
|
|
9
|
+
|
|
10
|
+
setup_distributed()
|
|
11
|
+
model = CortexNet(config)
|
|
12
|
+
model = wrap_fsdp(model, config)
|
|
13
|
+
# ... training loop ...
|
|
14
|
+
cleanup_distributed()
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import logging
|
|
20
|
+
import os
|
|
21
|
+
from typing import Optional, Any
|
|
22
|
+
|
|
23
|
+
import torch
|
|
24
|
+
import torch.nn as nn
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def setup_distributed(backend: str = "nccl") -> bool:
    """Initialise the default process group when running in a distributed job.

    Detection is based on the ``RANK`` and ``WORLD_SIZE`` environment
    variables; when either is missing nothing is initialised.

    Args:
        backend: Communication backend ("nccl" for GPU, "gloo" for CPU).
            "nccl" is replaced by "gloo" when CUDA is not available.

    Returns:
        True if a process group is (or already was) initialised, False when
        the distributed package is unavailable or the environment does not
        describe a distributed job.

    Raises:
        ValueError: If ``RANK``, ``WORLD_SIZE`` or ``LOCAL_RANK`` is set to a
            non-integer value.
    """
    dist = torch.distributed
    # On builds without distributed support torch.distributed exposes no
    # other API, so is_initialized() would raise AttributeError.
    if not dist.is_available():
        logger.info("torch.distributed is not available, skipping initialization")
        return False

    if dist.is_initialized():
        logger.info("Distributed already initialized")
        return True

    rank_env = os.environ.get("RANK")
    world_size_env = os.environ.get("WORLD_SIZE")
    if rank_env is None or world_size_env is None:
        logger.info("Not in distributed environment, skipping initialization")
        return False

    rank = int(rank_env)
    world_size = int(world_size_env)
    cuda_available = torch.cuda.is_available()

    local_rank_env = os.environ.get("LOCAL_RANK")
    if local_rank_env is not None:
        local_rank = int(local_rank_env)
    elif cuda_available:
        # The global rank is not a valid device index on multi-node jobs;
        # map it onto the devices visible to this process instead.
        local_rank = rank % torch.cuda.device_count()
    else:
        local_rank = rank

    if backend == "nccl" and not cuda_available:
        backend = "gloo"
        logger.info("CUDA not available, falling back to gloo backend")

    if cuda_available:
        torch.cuda.set_device(local_rank)

    dist.init_process_group(
        backend=backend,
        rank=rank,
        world_size=world_size,
    )

    logger.info(
        "Distributed initialized: rank=%s/%s, local_rank=%s, backend=%s",
        rank,
        world_size,
        local_rank,
        backend,
    )
    return True
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def cleanup_distributed() -> None:
    """Destroy the default process group if one is active.

    Safe to call when distributed training was never initialised or when
    PyTorch was built without distributed support; in both cases it does
    nothing.
    """
    dist = torch.distributed
    # is_available() is checked first because torch.distributed exposes no
    # other API on builds without distributed support.
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
        logger.info("Distributed process group destroyed")
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def get_rank() -> int:
    """Return the rank of the current process.

    Returns 0 when no process group is initialised or when PyTorch was
    built without distributed support.
    """
    dist = torch.distributed
    # is_available() guards builds where torch.distributed has no other API.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def get_world_size() -> int:
    """Return the number of processes in the default process group.

    Returns 1 when no process group is initialised or when PyTorch was
    built without distributed support.
    """
    dist = torch.distributed
    # is_available() guards builds where torch.distributed has no other API.
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def is_main_process() -> bool:
    """Return True when the current process has rank 0."""
    current_rank = get_rank()
    return current_rank == 0
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def wrap_fsdp(
    model: nn.Module,
    config: Optional[Any] = None,
    mixed_precision: Optional[str] = None,
    activation_checkpointing: bool = False,
) -> nn.Module:
    """Wrap ``model`` with FSDP (Fully Sharded Data Parallel).

    Sub-modules are auto-wrapped by parameter count, with full sharding.

    Args:
        model: Model instance to wrap.
        config: Optional configuration object. Accepted for API
            compatibility; it is not read by this function.
        mixed_precision: "fp16", "bf16", or None for full precision. Any
            other value is reported with a warning and treated as None.
        activation_checkpointing: Whether to apply activation checkpointing
            to the wrapped model.

    Returns:
        The FSDP-wrapped model, or ``model`` unchanged when no process group
        is initialised or FSDP cannot be imported.
    """
    import functools

    dist = torch.distributed
    # is_available() guards builds where torch.distributed has no other API.
    if not (dist.is_available() and dist.is_initialized()):
        logger.warning("Distributed not initialized, returning model as-is")
        return model

    try:
        from torch.distributed.fsdp import (
            FullyShardedDataParallel as FSDP,
            MixedPrecision,
            ShardingStrategy,
        )
        from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
    except ImportError:
        logger.error("FSDP not available in this PyTorch version")
        return model

    # Mixed-precision policy: one dtype for parameters, gradient reduction
    # and buffers.
    dtype_by_name = {"fp16": torch.float16, "bf16": torch.bfloat16}
    mp_policy = None
    if mixed_precision is not None:
        mp_dtype = dtype_by_name.get(mixed_precision)
        if mp_dtype is None:
            logger.warning(
                "Unknown mixed_precision=%r, falling back to full precision",
                mixed_precision,
            )
        else:
            mp_policy = MixedPrecision(
                param_dtype=mp_dtype,
                reduce_dtype=mp_dtype,
                buffer_dtype=mp_dtype,
            )

    # Sub-modules with more than 1M parameters get their own FSDP unit.
    # size_based_auto_wrap_policy is itself the policy callable (FSDP calls
    # it with module/recurse/nonwrapped_numel), so the threshold has to be
    # bound with functools.partial rather than by calling it directly.
    min_params = 1_000_000
    auto_wrap_policy = functools.partial(
        size_based_auto_wrap_policy, min_num_params=min_params
    )

    model = FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=mp_policy,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        device_id=torch.cuda.current_device() if torch.cuda.is_available() else None,
    )

    if activation_checkpointing:
        try:
            # The helper lives in the checkpoint_wrapper module; it is not
            # exported from torch.distributed.fsdp.
            from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
                apply_activation_checkpointing,
            )
        except ImportError:
            logger.warning("Activation checkpointing not available")
        else:
            # Restrict checkpointing to CortexBlock-style layers instead of
            # wrapping every sub-module.
            # NOTE(review): selection is by class-name prefix because the
            # block classes are not importable here -- confirm the naming
            # convention covers every block type that should be wrapped.
            apply_activation_checkpointing(
                model,
                check_fn=lambda module: type(module).__name__.startswith("CortexBlock"),
            )
            logger.info("Activation checkpointing enabled")

    logger.info("Model wrapped with FSDP (mixed_precision=%s)", mixed_precision)
    return model
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def wrap_ddp(
    model: nn.Module,
    find_unused_parameters: bool = False,
) -> nn.Module:
    """Wrap ``model`` with DistributedDataParallel (DDP).

    Intended for cases that do not need FSDP-level sharding.

    Args:
        model: Model to wrap.
        find_unused_parameters: Forwarded to DistributedDataParallel.

    Returns:
        The DDP-wrapped model, or ``model`` unchanged when no process group
        is initialised.
    """
    dist = torch.distributed
    # is_available() guards builds where torch.distributed has no other API.
    if not (dist.is_available() and dist.is_initialized()):
        logger.warning("Distributed not initialized, returning model as-is")
        return model

    use_cuda = torch.cuda.is_available()
    local_rank_env = os.environ.get("LOCAL_RANK")
    if local_rank_env is not None:
        local_rank = int(local_rank_env)
    elif use_cuda:
        # Without LOCAL_RANK, use the currently selected CUDA device
        # instead of assuming GPU 0.
        local_rank = torch.cuda.current_device()
    else:
        local_rank = 0

    if use_cuda:
        model = model.to(local_rank)

    model = nn.parallel.DistributedDataParallel(
        model,
        device_ids=[local_rank] if use_cuda else None,
        find_unused_parameters=find_unused_parameters,
    )

    logger.info("Model wrapped with DDP (device=%s)", local_rank)
    return model
|
|
@@ -0,0 +1,207 @@
|
|
|
1
|
+
"""
|
|
2
|
+
图推理模块 (Graph Reasoning Module)
|
|
3
|
+
|
|
4
|
+
核心创新:
|
|
5
|
+
将 token 序列视为图结构,通过消息传递神经网络 (MPNN)
|
|
6
|
+
进行多步关系推理。与注意力机制的关键区别:
|
|
7
|
+
|
|
8
|
+
┌─────────────────────────────────────────────────────┐
|
|
9
|
+
│ 注意力 vs 图推理 │
|
|
10
|
+
├─────────────────────────────────────────────────────┤
|
|
11
|
+
│ 注意力: │
|
|
12
|
+
│ • 单步信息聚合 │
|
|
13
|
+
│ • 全局/稀疏的软关联 │
|
|
14
|
+
│ • 无显式关系建模 │
|
|
15
|
+
│ │
|
|
16
|
+
│ 图推理: │
|
|
17
|
+
│ • 多步迭代推理(信息逐步传播) │
|
|
18
|
+
│ • 结构化的邻域关系 │
|
|
19
|
+
│ • 显式边特征建模 token 间关系 │
|
|
20
|
+
│ • 门控线性单元保持推理记忆 │
|
|
21
|
+
└─────────────────────────────────────────────────────┘
|
|
22
|
+
|
|
23
|
+
通过 K 步消息传递,信息可以在图中传播 K 跳距离,
|
|
24
|
+
实现类似 "思维链" 的逐步推理能力。
|
|
25
|
+
|
|
26
|
+
复杂度: O(L · k · D) per iteration,其中 k = 邻居数
|
|
27
|
+
|
|
28
|
+
优化 (v3.2):
|
|
29
|
+
- 用门控线性单元 (GLU) 替代 GRUCell,完全消除顺序依赖,
|
|
30
|
+
所有节点可并行更新。GRU 的 (B*L, D) 独立 cell 调用本身
|
|
31
|
+
无序列依赖,但 GLU 更轻量且编译器友好
|
|
32
|
+
- 添加了 dropout 支持
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
import logging
|
|
36
|
+
import torch
|
|
37
|
+
import torch.nn as nn
|
|
38
|
+
import torch.nn.functional as F
|
|
39
|
+
|
|
40
|
+
logger = logging.getLogger(__name__)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class GraphReasoningModule(nn.Module):
    """Relational reasoning over a token sequence via message passing.

    Each iteration:
        1. Gathers features of a fixed set of neighbours (local + strided).
        2. Encodes an edge feature from every (node, neighbour) pair.
        3. Builds a message from the neighbour feature and the edge feature.
        4. Aggregates messages with softmax weights derived from the edges.
        5. Updates node states through a sigmoid-gated interpolation between
           the normalised update value and the previous state.

    Repeating the iteration lets information travel several hops.

    Args:
        d_model: Node feature dimension.
        num_neighbors: Maximum number of neighbours per node.
        num_iterations: Number of message-passing iterations.
        edge_dim: Edge feature dimension.
        dropout: Dropout probability applied to the output.
    """

    def __init__(
        self,
        d_model: int,
        num_neighbors: int = 16,
        num_iterations: int = 2,
        edge_dim: int = 64,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.d_model = d_model
        self.num_neighbors = num_neighbors
        self.num_iterations = num_iterations

        # Edge features from concatenated (node, neighbour) pairs.
        self.edge_encoder = nn.Sequential(
            nn.Linear(d_model * 2, edge_dim),
            nn.GELU(),
            nn.Linear(edge_dim, edge_dim),
        )

        # Message from the neighbour feature concatenated with its edge.
        self.message_fn = nn.Sequential(
            nn.Linear(d_model + edge_dim, d_model),
            nn.GELU(),
        )

        # One attention logit per edge.
        self.attn_proj = nn.Linear(edge_dim, 1)

        # Node update: a single linear layer whose output is split into a
        # gate half and a value half.
        self.node_update = nn.Linear(d_model * 2, d_model * 2)
        self.update_norm = nn.LayerNorm(d_model)

        # Output projection.
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def _build_neighbor_indices(
        self, seq_len: int, device: torch.device
    ) -> torch.Tensor:
        """Build the neighbour index table: local offsets plus strided anchors.

        Roughly half of the ``k`` neighbours are relative offsets around
        each position (clamped to the sequence bounds); the rest are
        absolute positions sampled at a fixed stride and shared by all
        nodes. No pairwise similarity is computed.

        Args:
            seq_len: Sequence length L.
            device: Device on which the index tensor is created.

        Returns:
            Long tensor of shape (1, L, k) with neighbour positions, where
            ``k <= min(num_neighbors, L - 1)``. For ``L <= 1`` a zero tensor
            of shape (1, L, 1) is returned.
        """
        k = min(self.num_neighbors, seq_len - 1)
        if k <= 0:
            return torch.zeros(1, seq_len, 1, dtype=torch.long, device=device)

        k_local = max(k // 2, 1)
        k_stride = k - k_local

        positions = torch.arange(seq_len, device=device)

        # Local neighbours: symmetric non-zero offsets around each position.
        half = k_local // 2
        local_offsets = torch.arange(-half, half + 1, device=device)
        local_offsets = local_offsets[local_offsets != 0][:k_local]
        missing = k_local - len(local_offsets)
        if missing > 0:
            # Odd k_local leaves one slot open. Continue past the largest
            # offset already used so the extra neighbour is a new position
            # rather than a duplicate of offset +1.
            extra = torch.arange(half + 1, half + 1 + missing, device=device)
            local_offsets = torch.cat([local_offsets, extra])
        local_nb = (positions.unsqueeze(1) + local_offsets.unsqueeze(0)).clamp(
            0, seq_len - 1
        )  # (L, k_local)

        if k_stride <= 0 or seq_len <= k_local + 1:
            return local_nb.unsqueeze(0)  # (1, L, k_local)

        # Strided neighbours: evenly spaced absolute positions, truncated or
        # zero-padded to exactly k_stride entries.
        stride = max(1, seq_len // (k_stride + 1))
        stride_positions = torch.arange(0, seq_len, stride, device=device)
        if len(stride_positions) > k_stride:
            stride_positions = stride_positions[:k_stride]
        else:
            pad = torch.zeros(
                k_stride - len(stride_positions),
                dtype=torch.long,
                device=device,
            )
            stride_positions = torch.cat([stride_positions, pad])
        stride_nb = stride_positions.unsqueeze(0).expand(seq_len, -1)  # (L, k_stride)

        neighbors = torch.cat([local_nb, stride_nb], dim=-1)  # (L, k)
        return neighbors.unsqueeze(0)  # (1, L, k)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run message passing over the sequence.

        Args:
            x: Node features of shape (batch, seq_len, d_model).

        Returns:
            Updated node features of shape (batch, seq_len, d_model).
        """
        B, L, D = x.shape

        if L <= 1:
            # No graph can be built from a single token; still apply the
            # normalisation, projection and dropout of the main path.
            return self.dropout(self.out_proj(self.norm(x)))

        neighbors = self._build_neighbor_indices(L, x.device)  # (1, L, k)
        k = neighbors.shape[-1]

        # The gather index depends only on the neighbour table, so it is
        # built once and reused by every iteration.
        gather_index = neighbors.reshape(1, L * k, 1).expand(B, -1, D)  # (B, L*k, D)

        h = x  # initial node states

        for _ in range(self.num_iterations):
            # Neighbour features for every node.
            h_neighbors = h.gather(1, gather_index).view(B, L, k, D)

            # Edge features from (node, neighbour) pairs.
            h_expanded = h.unsqueeze(2).expand(-1, -1, k, -1)  # (B, L, k, D)
            edge_input = torch.cat([h_expanded, h_neighbors], dim=-1)  # (B, L, k, 2D)
            edge_features = self.edge_encoder(edge_input)  # (B, L, k, edge_dim)

            # Messages.
            msg_input = torch.cat([h_neighbors, edge_features], dim=-1)
            messages = self.message_fn(msg_input)  # (B, L, k, D)

            # Softmax-weighted aggregation over the neighbour axis.
            attn_logits = self.attn_proj(edge_features).squeeze(-1)  # (B, L, k)
            attn_weights = F.softmax(attn_logits, dim=-1)
            agg_messages = (messages * attn_weights.unsqueeze(-1)).sum(dim=2)  # (B, L, D)

            # Gated update: interpolate between the new value and the old
            # state.
            update_input = torch.cat([h, agg_messages], dim=-1)  # (B, L, 2D)
            gate_val, value = self.node_update(update_input).chunk(2, dim=-1)
            gate_val = torch.sigmoid(gate_val)
            h = gate_val * self.update_norm(value) + (1 - gate_val) * h

        output = self.out_proj(self.norm(h))
        return self.dropout(output)
|