cortexnet 3.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,221 @@
1
+ """
2
+ CortexBlockLite — 精简版 CortexNet 块
3
+
4
+ 仅保留三条核心路径 + SwiGLU FFN,参数量与 Transformer 相当:
5
+ 1. Multi-Scale SSM — O(1) 解码,长序列优势
6
+ 2. Selective Attention — 兼容预训练权重
7
+ 3. Synaptic Memory — 快速权重系统
8
+
9
+ 相比 CortexBlockV3 移除了:
10
+ - MultiAgentSystem (23.7%)
11
+ - CollaborativeMoE → 直接 SwiGLU FFN
12
+ - CausalReasoning (8.0%)
13
+ - TaskAdaptiveController (6.3%)
14
+ - GraphReasoning (5.2%)
15
+ - MetaLearningAdapter (4.2%)
16
+ - AdversarialShield (2.1%)
17
+
18
+ 典型参数节省:60-70%
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
from typing import Any, Optional, Tuple, Union
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+ import torch.nn.functional as F
28
+ import torch.utils.checkpoint as checkpoint
29
+
30
+ try:
31
+ from .blocks import RMSNorm
32
+ from .config import CortexNetConfig
33
+ from .ssm import MultiScaleSSM
34
+ from .attention import SelectiveSparseAttention
35
+ from .memory import SynapticMemory
36
+ except ImportError:
37
+ from cortexnet.blocks import RMSNorm
38
+ from cortexnet.config import CortexNetConfig
39
+ from cortexnet.ssm import MultiScaleSSM
40
+ from cortexnet.attention import SelectiveSparseAttention
41
+ from cortexnet.memory import SynapticMemory
42
+
43
+
44
class SwiGLU_FFN(nn.Module):
    """SwiGLU feed-forward network (LLaMA/Qwen-compatible layout).

    Computes ``down_proj(silu(gate_proj(x)) * up_proj(x))``. All three
    projections are bias-free, so the parameter count is
    ``3 * d_model * intermediate_size``.

    Args:
        d_model: input/output feature size.
        intermediate_size: hidden size of the gated projection.
    """

    def __init__(self, d_model: int, intermediate_size: int):
        super().__init__()
        hidden = intermediate_size
        # Attribute names and registration order (gate, up, down) define the
        # state_dict keys, so they stay fixed for checkpoint compatibility.
        self.gate_proj = nn.Linear(in_features=d_model, out_features=hidden, bias=False)
        self.up_proj = nn.Linear(in_features=d_model, out_features=hidden, bias=False)
        self.down_proj = nn.Linear(in_features=hidden, out_features=d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated projection to ``x`` (last dim must be d_model)."""
        gate = F.silu(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)
58
+
59
+
60
class CortexBlockLite(nn.Module):
    """Lightweight CortexNet block: SSM + Attention + Memory + SwiGLU FFN.

    Design:
    - Three parallel paths (low-rank SSM, selective sparse attention,
      synaptic memory) read the same pre-normalized input and are merged
      with three softmax-normalized scalar weights.
    - A SwiGLU FFN replaces the heavier MoE of the full block.
    - Supports incremental decoding via ``use_cache`` and optional gradient
      checkpointing (``config.use_gradient_checkpointing``).
    - ``ssm_only=True`` skips the attention and memory paths entirely.

    Args:
        config: CortexNetConfig configuration object.
        layer_idx: index of this layer in the stack; stored as
            ``self.layer_idx`` (not otherwise read inside this block).
    """

    def __init__(self, config: CortexNetConfig, layer_idx: int = 0):
        super().__init__()
        D = config.hidden_size
        self.layer_idx = layer_idx
        self.use_checkpoint = config.use_gradient_checkpointing

        # ═══ Normalization ═══
        self.norm1 = RMSNorm(D, eps=config.norm_eps)
        self.norm2 = RMSNorm(D, eps=config.norm_eps)

        # ═══ Path 1: low-rank SSM ═══
        # The SSM runs in a reduced ``ssm_rank``-dim space between a down- and
        # an up-projection, which keeps its parameter count small.
        ssm_rank = getattr(config, "compat_ssm_rank", None)
        if ssm_rank is None:
            # Attribute missing or explicitly None: use the default rank.
            ssm_rank = max(64, D // 4)
        ssm_rank = min(ssm_rank, D)  # never exceed D
        self.ssm_rank = ssm_rank
        # At full rank the projections collapse to identities (no parameters).
        self.ssm_down = nn.Linear(D, ssm_rank, bias=False) if ssm_rank < D else nn.Identity()
        self.ssm = MultiScaleSSM(
            d_model=ssm_rank,
            num_scales=1,
            state_size=config.ssm_state_size,
            expand_factor=1,
        )
        self.ssm_up = nn.Linear(ssm_rank, D, bias=False) if ssm_rank < D else nn.Identity()

        # ═══ Path 2: Selective Sparse Attention ═══
        self.attention = SelectiveSparseAttention(
            d_model=D,
            num_heads=config.num_heads,
            num_kv_heads=config.num_kv_heads,
            top_k_ratio=config.top_k_ratio,
            k_mode=config.attention_k_mode,
            max_seq_len=config.max_seq_len,
            rope_theta=config.rope_theta,
            dropout=config.dropout,
            sliding_window_size=config.sliding_window_size,
            use_qk_norm=getattr(config, "use_qk_norm", False),
            norm_eps=config.norm_eps,
        )

        # ═══ Path 3: Synaptic Memory ═══
        # Cap memory_dim at max(32, D // 16) so wide models do not inflate
        # the parameter count.
        mem_dim = min(config.memory_dim, max(32, D // 16))
        self.memory = SynapticMemory(
            d_model=D,
            memory_dim=mem_dim,
            decay_init=config.memory_decay_init,
        )

        # ═══ Lightweight three-path fusion (3 scalar weights) ═══
        # Replaces a fully-connected gating network: O(d^2) parameters -> 3.
        # Equal init -> uniform 1/3 mix after the softmax in forward().
        self.fusion_weights = nn.Parameter(torch.ones(3) / 3)

        # ═══ SwiGLU FFN (replaces MoE) ═══
        # Use intermediate_size when it is a positive number, else expert_ff_dim.
        inter = config.intermediate_size
        ff_dim = inter if inter is not None and inter > 0 else config.expert_ff_dim
        self.ffn = SwiGLU_FFN(D, ff_dim)

        # ═══ Dropout ═══
        self.dropout = nn.Dropout(config.dropout) if config.dropout > 0 else nn.Identity()

    def forward(
        self,
        x: torch.Tensor,
        past_cache: Optional[Tuple[Any, Any, Any]] = None,
        use_cache: bool = False,
        ssm_only: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[Any, Any, Any]]]:
        """Run the block on ``x``.

        Args:
            x: (B, L, D) input hidden states.
            past_cache: optional ``(past_attn_kv, past_ssm_state, past_mem_state)``
                from a previous call with ``use_cache=True``; its entries are
                only used when ``use_cache`` is True.
            use_cache: if True, also return the updated per-path caches.
            ssm_only: if True, skip the attention and memory paths and add
                only the SSM output to the residual (fast decoding mode).

        Returns:
            ``use_cache=False``: output tensor (B, L, D).
            ``use_cache=True``: ``(output, (attn_cache, ssm_cache, mem_cache))``.
        """
        residual = x
        x_norm = self.norm1(x)

        # Unpack caches
        past_attn = past_ssm = past_mem = None
        if past_cache is not None:
            past_attn, past_ssm, past_mem = past_cache

        # ═══ SSM path (low-rank space) ═══
        ssm_input = self.ssm_down(x_norm)
        if use_cache:
            ssm_out_lr, new_ssm_cache = self.ssm(
                ssm_input, past_state=past_ssm, use_cache=True,
            )
        else:
            ssm_out_lr = self._maybe_checkpoint(self.ssm, ssm_input)
            new_ssm_cache = None
        ssm_out = self.ssm_up(ssm_out_lr)

        if ssm_only:
            # Pure SSM mode: skip Attention and Memory (and the fusion weights).
            x = residual + self.dropout(ssm_out)
        else:
            # ═══ Attention path ═══
            if use_cache:
                attn_out, new_attn_cache = self.attention(
                    x_norm, past_key_value=past_attn, use_cache=True,
                )
            else:
                attn_out = self._maybe_checkpoint(self.attention, x_norm)
                new_attn_cache = None

            # ═══ Memory path ═══
            if use_cache:
                # The memory cache is either (memory, z) or a bare memory state.
                past_memory = past_mem[0] if isinstance(past_mem, tuple) else past_mem
                past_z = past_mem[1] if isinstance(past_mem, tuple) else None
                mem_out, new_mem, new_z = self.memory(
                    x_norm, past_memory=past_memory, past_z=past_z, use_cache=True,
                )
                new_mem_cache = (new_mem, new_z)
            else:
                mem_out = self._maybe_checkpoint(self.memory, x_norm)
                new_mem_cache = None

            # ═══ Lightweight fusion ═══
            w = torch.softmax(self.fusion_weights, dim=0)
            fused = w[0] * ssm_out + w[1] * attn_out + w[2] * mem_out
            x = residual + self.dropout(fused)

        # ═══ FFN ═══
        residual = x
        x = residual + self.dropout(self.ffn(self.norm2(x)))

        if use_cache:
            # In ssm_only mode the attention/memory caches are passed through
            # unchanged, so tokens processed here are NOT added to them.
            cache = (
                new_attn_cache if not ssm_only else past_attn,
                new_ssm_cache,
                new_mem_cache if not ssm_only else past_mem,
            )
            return x, cache
        return x

    def _maybe_checkpoint(self, module, *args):
        """Call ``module(*args)``, under non-reentrant activation
        checkpointing when ``use_checkpoint`` is set and the block is training."""
        if self.use_checkpoint and self.training:
            return checkpoint.checkpoint(module, *args, use_reentrant=False)
        return module(*args)

    def get_aux_loss(self) -> torch.Tensor:
        """Return a zero scalar: the Lite block has no MoE auxiliary loss."""
        return self.fusion_weights.new_zeros(())
@@ -0,0 +1,213 @@
1
+ """
2
+ CortexNet 分布式训练支持 (Distributed Training)
3
+
4
+ 提供 FSDP (Fully Sharded Data Parallel) 策略和基础分布式工具,
5
+ 为大规模训练做准备。
6
+
7
+ 用法:
8
+ from cortexnet.distributed import setup_distributed, wrap_fsdp, cleanup_distributed
9
+
10
+ setup_distributed()
11
+ model = CortexNet(config)
12
+ model = wrap_fsdp(model, config)
13
+ # ... training loop ...
14
+ cleanup_distributed()
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import logging
20
+ import os
21
+ from typing import Optional, Any
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+
26
# Module-level logger shared by the setup/teardown and model-wrapping helpers.
logger = logging.getLogger(__name__)
27
+
28
+
29
def setup_distributed(backend: str = "nccl") -> bool:
    """Initialize the distributed process group from environment variables.

    Initialization only happens when both ``RANK`` and ``WORLD_SIZE`` are set
    (as done by launchers such as torchrun). ``LOCAL_RANK`` selects the CUDA
    device and falls back to ``RANK`` when unset.

    Args:
        backend: communication backend ("nccl" for GPU, "gloo" for CPU).
            "nccl" is replaced by "gloo" when CUDA is unavailable.

    Returns:
        True if a process group is (now or already) initialized; False when
        not running distributed or when this PyTorch build has no
        ``torch.distributed`` support.
    """
    # Without distributed support torch.distributed exposes no other APIs
    # (is_initialized, init_process_group, ...), so check first.
    if not torch.distributed.is_available():
        logger.info("torch.distributed is not available in this PyTorch build")
        return False

    if torch.distributed.is_initialized():
        logger.info("Distributed already initialized")
        return True

    # Read launcher-provided environment variables
    rank_env = os.environ.get("RANK")
    world_size_env = os.environ.get("WORLD_SIZE")
    local_rank_env = os.environ.get("LOCAL_RANK")

    if rank_env is None or world_size_env is None:
        logger.info("Not in distributed environment, skipping initialization")
        return False

    rank = int(rank_env)
    world_size = int(world_size_env)
    local_rank = int(local_rank_env) if local_rank_env is not None else rank

    # Auto-select backend
    if backend == "nccl" and not torch.cuda.is_available():
        backend = "gloo"
        logger.info("CUDA not available, falling back to gloo backend")

    # Bind this process to its GPU before creating the process group.
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)

    torch.distributed.init_process_group(
        backend=backend,
        rank=rank,
        world_size=world_size,
    )

    logger.info(
        "Distributed initialized: rank=%d/%d, local_rank=%d, backend=%s",
        rank, world_size, local_rank, backend,
    )
    return True
78
+
79
+
80
def cleanup_distributed():
    """Destroy the default process group if one is initialized.

    A no-op when not running distributed, or when this PyTorch build has no
    ``torch.distributed`` support.
    """
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
        logger.info("Distributed process group destroyed")
85
+
86
+
87
def get_rank() -> int:
    """Return this process's global rank (0 when not running distributed).

    Safe to call on PyTorch builds without ``torch.distributed`` support.
    """
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_rank()
    return 0
92
+
93
+
94
def get_world_size() -> int:
    """Return the total number of processes (1 when not running distributed).

    Safe to call on PyTorch builds without ``torch.distributed`` support.
    """
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_world_size()
    return 1
99
+
100
+
101
def is_main_process() -> bool:
    """Return True when this process has global rank 0.

    ``get_rank()`` reports 0 outside a distributed run, so a single process
    always counts as the main process.
    """
    current_rank = get_rank()
    return current_rank == 0
104
+
105
+
106
def wrap_fsdp(
    model: nn.Module,
    config: Optional[Any] = None,
    mixed_precision: Optional[str] = None,
    activation_checkpointing: bool = False,
) -> nn.Module:
    """Wrap *model* with FSDP (FullyShardedDataParallel).

    Submodules with at least 1M parameters become their own FSDP units
    (size-based auto-wrap policy); sharding strategy is FULL_SHARD.

    Args:
        model: CortexNet model instance.
        config: optional CortexNetConfig; accepted for API compatibility and
            currently not read.
        mixed_precision: "fp16", "bf16", or None (full precision). Any other
            value logs a warning and falls back to full precision.
        activation_checkpointing: if True, wrap every submodule whose class
            name starts with "CortexBlock" in an activation-checkpoint wrapper.

    Returns:
        The FSDP-wrapped model, or *model* unchanged when no process group is
        initialized (or torch.distributed / FSDP is unavailable).
    """
    import functools

    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        logger.warning("Distributed not initialized, returning model as-is")
        return model

    try:
        from torch.distributed.fsdp import (
            FullyShardedDataParallel as FSDP,
            MixedPrecision,
            ShardingStrategy,
        )
        from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
    except ImportError:
        logger.error("FSDP not available in this PyTorch version")
        return model

    # Mixed-precision policy: params, gradient reduction and buffers share one dtype.
    precision_dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16}
    mp_policy = None
    if mixed_precision is not None:
        dtype = precision_dtypes.get(mixed_precision)
        if dtype is None:
            logger.warning(
                "Unknown mixed_precision=%r (expected 'fp16' or 'bf16'); "
                "using full precision",
                mixed_precision,
            )
        else:
            mp_policy = MixedPrecision(
                param_dtype=dtype,
                reduce_dtype=dtype,
                buffer_dtype=dtype,
            )

    # size_based_auto_wrap_policy *is* the policy callable
    # (module, recurse, nonwrapped_numel, ...); FSDP calls it itself, so only
    # the threshold is bound here. Calling it directly raises a TypeError.
    min_params = 1_000_000  # submodules above 1M params are sharded separately
    auto_wrap_policy = functools.partial(
        size_based_auto_wrap_policy, min_num_params=min_params
    )

    model = FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=mp_policy,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        device_id=torch.cuda.current_device() if torch.cuda.is_available() else None,
    )

    # Activation checkpointing (applied after FSDP wrapping)
    if activation_checkpointing:
        try:
            # apply_activation_checkpointing lives here, not in torch.distributed.fsdp.
            from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
                apply_activation_checkpointing,
            )
        except ImportError:
            logger.warning("Activation checkpointing not available")
        else:
            def _is_cortex_block(module: nn.Module) -> bool:
                # Restrict checkpointing to CortexBlock* layers instead of
                # wrapping every submodule.
                return type(module).__name__.startswith("CortexBlock")

            apply_activation_checkpointing(model, check_fn=_is_cortex_block)
            logger.info("Activation checkpointing enabled")

    logger.info("Model wrapped with FSDP (mixed_precision=%s)", mixed_precision)
    return model
181
+
182
+
183
def wrap_ddp(
    model: nn.Module,
    find_unused_parameters: bool = False,
) -> nn.Module:
    """Wrap *model* with DistributedDataParallel (DDP).

    A simpler alternative to FSDP when parameter sharding is not needed.

    Args:
        model: model to wrap.
        find_unused_parameters: forwarded to DDP; enable when some parameters
            may not receive gradients in a step (for example CortexBlockLite
            with ``ssm_only=True`` skips its attention and memory paths).

    Returns:
        The DDP-wrapped model, or *model* unchanged when no process group is
        initialized (or torch.distributed is unavailable).
    """
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        logger.warning("Distributed not initialized, returning model as-is")
        return model

    use_cuda = torch.cuda.is_available()
    env_local_rank = os.environ.get("LOCAL_RANK")
    if env_local_rank is not None:
        local_rank = int(env_local_rank)
    elif use_cuda:
        # No LOCAL_RANK: use the device already selected for this process
        # (setup_distributed() sets it from RANK in that case; otherwise 0).
        local_rank = torch.cuda.current_device()
    else:
        local_rank = 0

    if use_cuda:
        model = model.to(local_rank)

    model = nn.parallel.DistributedDataParallel(
        model,
        device_ids=[local_rank] if use_cuda else None,
        find_unused_parameters=find_unused_parameters,
    )

    logger.info("Model wrapped with DDP (device=%s)", local_rank if use_cuda else "cpu")
    return model
@@ -0,0 +1,207 @@
1
+ """
2
+ 图推理模块 (Graph Reasoning Module)
3
+
4
+ 核心创新:
5
+ 将 token 序列视为图结构,通过消息传递神经网络 (MPNN)
6
+ 进行多步关系推理。与注意力机制的关键区别:
7
+
8
+ ┌─────────────────────────────────────────────────────┐
9
+ │ 注意力 vs 图推理 │
10
+ ├─────────────────────────────────────────────────────┤
11
+ │ 注意力: │
12
+ │ • 单步信息聚合 │
13
+ │ • 全局/稀疏的软关联 │
14
+ │ • 无显式关系建模 │
15
+ │ │
16
+ │ 图推理: │
17
+ │ • 多步迭代推理(信息逐步传播) │
18
+ │ • 结构化的邻域关系 │
19
+ │ • 显式边特征建模 token 间关系 │
20
+ │ • 门控线性单元保持推理记忆 │
21
+ └─────────────────────────────────────────────────────┘
22
+
23
+ 通过 K 步消息传递,信息可以在图中传播 K 跳距离,
24
+ 实现类似 "思维链" 的逐步推理能力。
25
+
26
+ 复杂度: O(L · k · D) per iteration,其中 k = 邻居数
27
+
28
+ 优化 (v3.2):
29
+ - 用门控线性单元 (GLU) 替代 GRUCell,完全消除顺序依赖,
30
+ 所有节点可并行更新。GRU 的 (B*L, D) 独立 cell 调用本身
31
+ 无序列依赖,但 GLU 更轻量且编译器友好
32
+ - 添加了 dropout 支持
33
+ """
34
+
35
+ import logging
36
+ import torch
37
+ import torch.nn as nn
38
+ import torch.nn.functional as F
39
+
40
# Module-level logger for the graph reasoning module.
logger = logging.getLogger(__name__)
41
+
42
+
43
class GraphReasoningModule(nn.Module):
    """Graph reasoning module: relational reasoning via message passing.

    Each iteration:
      1. Build a sparse neighbor graph (local window + strided positions).
      2. Compute edge features from each (node, neighbor) feature pair.
      3. Build messages from neighbor features and edge features.
      4. Aggregate messages with softmax weights over the neighbors.
      5. Update node states with a sigmoid-gated linear unit (instead of a
         GRU), so all nodes update in parallel.

    Running several iterations lets information travel several hops.

    NOTE(review): local neighbors include later positions (i+1, ...) and the
    strided neighbors are fixed absolute positions, so a token can receive
    messages from tokens after it. If this module runs inside a causal /
    autoregressive model this leaks future information — confirm intended use.

    Args:
        d_model: node feature dimension.
        num_neighbors: maximum number of neighbors per node.
        num_iterations: number of message-passing iterations.
        edge_dim: edge feature dimension.
        dropout: dropout rate applied to the module output.
    """

    def __init__(
        self,
        d_model: int,
        num_neighbors: int = 16,
        num_iterations: int = 2,
        edge_dim: int = 64,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.d_model = d_model
        self.num_neighbors = num_neighbors
        self.num_iterations = num_iterations

        # Edge encoder: relation features from concatenated (node, neighbor) pairs.
        self.edge_encoder = nn.Sequential(
            nn.Linear(d_model * 2, edge_dim),
            nn.GELU(),
            nn.Linear(edge_dim, edge_dim),
        )

        # Message function: neighbor features modulated by edge features.
        self.message_fn = nn.Sequential(
            nn.Linear(d_model + edge_dim, d_model),
            nn.GELU(),
        )

        # Per-edge scalar logit for message aggregation weights.
        self.attn_proj = nn.Linear(edge_dim, 1)

        # Node update: gated linear unit; output is split into gate and value halves.
        self.node_update = nn.Linear(d_model * 2, d_model * 2)
        self.update_norm = nn.LayerNorm(d_model)

        # Output projection
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def _build_neighbor_indices(
        self, seq_len: int, device: torch.device
    ) -> torch.Tensor:
        """Build hybrid neighbor indices: local window + strided positions.

        Avoids an O(L^2) similarity search by using:
          - local neighbors: small non-zero offsets around each position;
          - strided neighbors: positions 0, s, 2s, ... shared by every token
            (s = seq_len // (k_stride + 1)), zero-padded when too few.

        Indices are clamped to [0, seq_len - 1], so positions near the edges
        may list the same neighbor more than once.

        Returns:
            (1, seq_len, k) long tensor of neighbor indices with
            k <= num_neighbors (k == 1 when seq_len <= 1).
        """
        k = min(self.num_neighbors, seq_len - 1)
        if k <= 0:
            return torch.zeros(1, seq_len, 1, dtype=torch.long, device=device)

        k_local = max(k // 2, 1)
        k_stride = k - k_local

        indices = torch.arange(seq_len, device=device)

        # Local neighbors: symmetric offsets -half..half, excluding 0.
        half = k_local // 2
        local_offsets = torch.arange(-half, half + 1, device=device)
        local_offsets = local_offsets[local_offsets != 0][:k_local]
        missing = k_local - len(local_offsets)
        if missing > 0:
            # Odd k_local leaves slots unfilled; extend forward *past* the
            # symmetric window so no offset is repeated.
            extra = torch.arange(half + 1, half + 1 + missing, device=device)
            local_offsets = torch.cat([local_offsets, extra])
        local_nb = (indices.unsqueeze(1) + local_offsets.unsqueeze(0)).clamp(
            0, seq_len - 1
        )  # (L, k_local)

        # Strided neighbors
        if k_stride > 0 and seq_len > k_local + 1:
            stride = max(1, seq_len // (k_stride + 1))
            stride_positions = torch.arange(0, seq_len, stride, device=device)
            if len(stride_positions) > k_stride:
                stride_positions = stride_positions[:k_stride]
            else:
                pad = torch.zeros(
                    k_stride - len(stride_positions),
                    dtype=torch.long,
                    device=device,
                )
                stride_positions = torch.cat([stride_positions, pad])
            stride_nb = stride_positions.unsqueeze(0).expand(
                seq_len, -1
            )  # (L, k_stride)
            neighbors = torch.cat([local_nb, stride_nb], dim=-1)  # (L, k)
        else:
            neighbors = local_nb  # (L, k_local)

        return neighbors.unsqueeze(0)  # (1, L, k)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, d_model) node features.
        Returns:
            output: (batch, seq_len, d_model) updated node features.
        """
        B, L, D = x.shape

        if L <= 1:
            # A single token has no neighbors; still project + normalize and
            # apply the same output dropout as the main path.
            return self.dropout(self.out_proj(self.norm(x)))

        neighbors = self._build_neighbor_indices(L, x.device)  # (1, L, k)
        k = neighbors.shape[-1]
        # The gather index depends only on the graph, so build it once.
        nb_flat = neighbors.expand(B, -1, -1).reshape(B, -1)  # (B, L*k)
        gather_idx = nb_flat.unsqueeze(-1).expand(-1, -1, D)  # (B, L*k, D)

        h = x  # initial node states

        for _ in range(self.num_iterations):
            # Gather neighbor features
            h_neighbors = h.gather(1, gather_idx).view(B, L, k, D)  # (B, L, k, D)

            # Edge features
            h_expanded = h.unsqueeze(2).expand(-1, -1, k, -1)  # (B, L, k, D)
            edge_input = torch.cat([h_expanded, h_neighbors], dim=-1)  # (B, L, k, 2D)
            edge_features = self.edge_encoder(edge_input)  # (B, L, k, edge_dim)

            # Messages
            msg_input = torch.cat([h_neighbors, edge_features], dim=-1)
            messages = self.message_fn(msg_input)  # (B, L, k, D)

            # Softmax-weighted aggregation over the k neighbors
            attn_logits = self.attn_proj(edge_features).squeeze(-1)  # (B, L, k)
            attn_weights = F.softmax(attn_logits, dim=-1)
            agg_messages = (messages * attn_weights.unsqueeze(-1)).sum(
                dim=2
            )  # (B, L, D)

            # Gated node update: convex mix of normalized candidate and old state.
            update_input = torch.cat([h, agg_messages], dim=-1)  # (B, L, 2D)
            update_raw = self.node_update(update_input)  # (B, L, 2D)
            gate_val, value = update_raw.chunk(2, dim=-1)  # each (B, L, D)
            gate_val = torch.sigmoid(gate_val)
            h = gate_val * self.update_norm(value) + (1 - gate_val) * h

        output = self.out_proj(self.norm(h))
        return self.dropout(output)