cortexnet 3.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cortexnet/__init__.py +197 -0
- cortexnet/adapter/__init__.py +26 -0
- cortexnet/adapter/arch_adapter.py +209 -0
- cortexnet/adapter/calibrator.py +244 -0
- cortexnet/adapter/inference_adapter.py +272 -0
- cortexnet/adapter/model_registry.py +378 -0
- cortexnet/adapter/weight_adapter.py +415 -0
- cortexnet/adversarial.py +195 -0
- cortexnet/attention.py +520 -0
- cortexnet/blocks.py +682 -0
- cortexnet/cache.py +83 -0
- cortexnet/causal_reasoning.py +232 -0
- cortexnet/compat.py +245 -0
- cortexnet/config.py +234 -0
- cortexnet/continual_learning.py +256 -0
- cortexnet/cortex_block_lite.py +221 -0
- cortexnet/distributed.py +213 -0
- cortexnet/graph_reasoning.py +207 -0
- cortexnet/hierarchical_memory.py +360 -0
- cortexnet/interpretability.py +196 -0
- cortexnet/memory.py +179 -0
- cortexnet/meta_learning.py +187 -0
- cortexnet/model.py +1360 -0
- cortexnet/multi_agent.py +241 -0
- cortexnet/multimodal.py +278 -0
- cortexnet/ops/__init__.py +28 -0
- cortexnet/ops/device_manager.py +449 -0
- cortexnet/ops/npu_ops.py +243 -0
- cortexnet/quantization.py +496 -0
- cortexnet/routing.py +335 -0
- cortexnet/self_evolution.py +174 -0
- cortexnet/ssm.py +340 -0
- cortexnet/training_utils.py +204 -0
- cortexnet/transformer_baseline.py +157 -0
- cortexnet-3.2.1.dist-info/METADATA +114 -0
- cortexnet-3.2.1.dist-info/RECORD +39 -0
- cortexnet-3.2.1.dist-info/WHEEL +5 -0
- cortexnet-3.2.1.dist-info/licenses/LICENSE +201 -0
- cortexnet-3.2.1.dist-info/top_level.txt +1 -0
cortexnet/multi_agent.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
1
|
+
"""
|
|
2
|
+
多智能体协作系统 (Multi-Agent Collaboration System)
|
|
3
|
+
|
|
4
|
+
核心创新:
|
|
5
|
+
将单一模型扩展为多个专业化智能体的协作系统。
|
|
6
|
+
每个智能体拥有独特的 "认知风格",通过协调器融合集体智慧。
|
|
7
|
+
|
|
8
|
+
┌─────────────────────────────────────────────────────────────┐
|
|
9
|
+
│ 多智能体协作架构 │
|
|
10
|
+
├─────────────────────────────────────────────────────────────┤
|
|
11
|
+
│ │
|
|
12
|
+
│ 输入 ─────┬──► Agent 1 (逻辑推理型) ──┐ │
|
|
13
|
+
│ ├──► Agent 2 (模式识别型) ──┤ │
|
|
14
|
+
│ ├──► Agent 3 (创意联想型) ──┤──► 协调器 ──► 输出│
|
|
15
|
+
│ └──► Agent 4 (记忆检索型) ──┘ │
|
|
16
|
+
│ │
|
|
17
|
+
│ ● 每个 Agent 有独特的 "认知风格" (specialty embedding) │
|
|
18
|
+
│ ● 协调器根据输入内容动态分配各 Agent 的权重 │
|
|
19
|
+
│ ● Agent 间通过共享消息板交换关键信息 │
|
|
20
|
+
└─────────────────────────────────────────────────────────────┘
|
|
21
|
+
|
|
22
|
+
类比: 如同一个专家团队——不同专家从不同角度分析问题,
|
|
23
|
+
最终由一个主持人综合各方意见做出决策。
|
|
24
|
+
|
|
25
|
+
优化 (v3.2):
|
|
26
|
+
- SharedMessageBoard.write() 使用残差连接保持梯度流,
|
|
27
|
+
替代了原先 .data 原地赋值绕过 autograd 的问题
|
|
28
|
+
- BatchedAgents 将多个 Agent 的参数合并为批量矩阵乘法,
|
|
29
|
+
替代了原先的列表推导式串行执行
|
|
30
|
+
"""
|
|
31
|
+
|
|
32
|
+
import logging
|
|
33
|
+
import torch
|
|
34
|
+
import torch.nn as nn
|
|
35
|
+
import torch.nn.functional as F
|
|
36
|
+
|
|
37
|
+
logger = logging.getLogger(__name__)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class BatchedAgents(nn.Module):
|
|
41
|
+
"""批量化智能体:将多个 Agent 的 SwiGLU FFN 合并为单次矩阵乘法。
|
|
42
|
+
|
|
43
|
+
相比 ModuleList + for 循环,通过 (num_agents, d_model, d_ff) 的权重张量
|
|
44
|
+
一次 batched matmul 完成所有 Agent 的计算,无 Python 循环开销。
|
|
45
|
+
|
|
46
|
+
Args:
|
|
47
|
+
d_model: 输入/输出维度
|
|
48
|
+
num_agents: 智能体数量
|
|
49
|
+
"""
|
|
50
|
+
|
|
51
|
+
def __init__(self, d_model: int, num_agents: int = 4):
|
|
52
|
+
super().__init__()
|
|
53
|
+
self.num_agents = num_agents
|
|
54
|
+
d_ff = d_model * 2
|
|
55
|
+
|
|
56
|
+
# 每个 Agent 的独特认知风格
|
|
57
|
+
self.specialty = nn.Parameter(torch.randn(num_agents, 1, d_model) * 0.02)
|
|
58
|
+
|
|
59
|
+
# 合并所有 Agent 的 SwiGLU 权重: (num_agents, in, out)
|
|
60
|
+
self.gate_proj = nn.Parameter(torch.randn(num_agents, d_model, d_ff) * (d_model ** -0.5))
|
|
61
|
+
self.up_proj = nn.Parameter(torch.randn(num_agents, d_model, d_ff) * (d_model ** -0.5))
|
|
62
|
+
self.down_proj = nn.Parameter(torch.randn(num_agents, d_ff, d_model) * (d_ff ** -0.5))
|
|
63
|
+
|
|
64
|
+
self.norm = nn.LayerNorm(d_model)
|
|
65
|
+
|
|
66
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
67
|
+
"""
|
|
68
|
+
Args:
|
|
69
|
+
x: (B, L, D)
|
|
70
|
+
Returns:
|
|
71
|
+
outputs: (B, L, num_agents, D) — 所有 Agent 的输出
|
|
72
|
+
"""
|
|
73
|
+
B, L, D = x.shape
|
|
74
|
+
|
|
75
|
+
# 加入 specialty 并归一化: (K, B, L, D)
|
|
76
|
+
x_agents = self.norm(x.unsqueeze(0) + self.specialty.unsqueeze(1))
|
|
77
|
+
|
|
78
|
+
# 批量 SwiGLU: 所有 Agent 同时计算
|
|
79
|
+
# x_agents: (K, B, L, D), gate_proj: (K, D, FF) → (K, B, L, FF)
|
|
80
|
+
gate = torch.einsum('kbld,kdf->kblf', x_agents, self.gate_proj)
|
|
81
|
+
up = torch.einsum('kbld,kdf->kblf', x_agents, self.up_proj)
|
|
82
|
+
hidden = F.silu(gate) * up
|
|
83
|
+
out = torch.einsum('kblf,kfd->kbld', hidden, self.down_proj)
|
|
84
|
+
|
|
85
|
+
# 转换为 (B, L, K, D)
|
|
86
|
+
return out.permute(1, 2, 0, 3)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class SharedMessageBoard(nn.Module):
    """Shared message board: an information-exchange channel between agents.

    The board holds ``num_slots`` learnable slot vectors. ``read`` lets every
    query position attend over the slots. ``write`` summarizes new content into
    a per-slot message; subsequent reads see the slots EMA-blended with that
    message, using ``sigmoid(ema_decay)`` as the weight of the base slots.
    Each ``write`` replaces the previous message; messages do not accumulate.

    The written message is stored *detached* and blended with the live
    parameters at read time. Caching the blended slots together with the
    autograd graph of the written content made the next backward pass run
    through a graph that had already been freed, and through ``slots`` after
    the optimizer had modified it in place. Both raise ``RuntimeError``
    whenever ``reset_state`` is not called between steps, and the cached graph
    also grew across calls. With the detached message, gradients from ``read``
    reach ``read_proj``, ``slots`` and ``ema_decay`` on every step.
    ``write_proj`` / ``write_gate`` receive no gradient through the board,
    because back-propagation is truncated between calls.

    Args:
        d_model: feature dimension.
        num_slots: number of message slots.
    """

    def __init__(self, d_model: int, num_slots: int = 8):
        super().__init__()
        self.slots = nn.Parameter(torch.randn(1, num_slots, d_model) * 0.02)
        self.write_proj = nn.Linear(d_model, d_model, bias=False)
        self.write_gate = nn.Linear(d_model, num_slots, bias=False)
        self.read_proj = nn.Linear(d_model, d_model, bias=False)
        # Raw (pre-sigmoid) EMA parameter. The effective decay is
        # sigmoid(ema_decay), i.e. about 0.711 at initialization.
        self.ema_decay = nn.Parameter(torch.tensor(0.9))

        # Last written message, (1, num_slots, d_model), detached from autograd.
        # None means nothing has been written since construction/reset_state().
        self._message: torch.Tensor | None = None

    def _get_slots(self) -> torch.Tensor:
        """Return the slots to read from, shape (1, num_slots, d_model).

        Without a pending message this is ``self.slots``. Otherwise the message
        is blended in as ``d * slots + (1 - d) * message`` with
        ``d = sigmoid(ema_decay)``, computed from the current parameters so the
        result is part of the current autograd graph.
        """
        if self._message is None:
            return self.slots
        decay = torch.sigmoid(self.ema_decay)
        # Follow the module if it was moved to another device/dtype after the write.
        message = self._message.to(self.slots)
        return decay * self.slots + (1 - decay) * message

    def read(self, query: torch.Tensor) -> torch.Tensor:
        """Read from the board: each query position attends over the slots.

        Args:
            query: (B, L, D).

        Returns:
            (B, L, D): attention-weighted mixture of the slot vectors.
        """
        B, L, D = query.shape
        q = self.read_proj(query)
        slots = self._get_slots().expand(B, -1, -1)
        attn = F.softmax(q @ slots.transpose(-1, -2) / D**0.5, dim=-1)
        return attn @ slots

    def write(self, content: torch.Tensor) -> None:
        """Store a message derived from ``content`` for subsequent reads.

        Args:
            content: (B, L, D), e.g. the merged agent output.
        """
        # The message is stored detached, so building it needs no autograd graph.
        with torch.no_grad():
            projected = self.write_proj(content)  # (B, L, D)
            # Softmax over slots decides how much of the message each slot gets.
            gate = F.softmax(self.write_gate(content.mean(dim=1)), dim=-1)  # (B, S)
            new_msg = projected.mean(dim=1, keepdim=True)  # (B, 1, D)
            update = gate.unsqueeze(-1) * new_msg  # (B, S, D)
            # Averaged over the batch: one board is shared by all samples.
            self._message = update.mean(dim=0, keepdim=True)  # (1, S, D)

    def reset_state(self):
        """Forget the written message (e.g. at the start of each training step)."""
        self._message = None
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
class AgentCoordinator(nn.Module):
    """Agent coordinator: fuse the outputs of all agents into one output.

    A small MLP scores every agent from the token representation concatenated
    with a context vector. The softmax of these scores ("trust") weights the
    agent outputs, and the weighted sum is projected by ``out_proj``.

    Args:
        d_model: feature dimension.
        num_agents: number of agents (K).
        causal: how the context vector is computed (keyword-only).

            - ``False`` (default, original behaviour): the mean over the whole
              sequence, so the trust at position t also depends on tokens
              after t.
            - ``True``: the running mean over positions ``0..t``, so no later
              token influences position t. This suits left-to-right models.
              The running mean only covers the tokens passed in ``x``; with
              incremental decoding (one new token per call) earlier tokens
              are not included.
    """

    def __init__(self, d_model: int, num_agents: int, *, causal: bool = False):
        super().__init__()
        self.causal = causal
        self.trust_scorer = nn.Sequential(
            nn.Linear(d_model * 2, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, num_agents),
        )
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

    def _context(self, x: torch.Tensor) -> torch.Tensor:
        """Per-position context vector for ``x`` (B, L, D), returned as (B, L, D)."""
        seq_len = x.shape[1]
        if not self.causal:
            return x.mean(dim=1, keepdim=True).expand(-1, seq_len, -1)
        # Running mean, accumulated in float32 so half-precision inputs and
        # long sequences keep exact counts.
        counts = torch.arange(1, seq_len + 1, device=x.device, dtype=torch.float32)
        prefix_mean = x.float().cumsum(dim=1) / counts.view(1, seq_len, 1)
        return prefix_mean.to(x.dtype)

    def forward(
        self, x: torch.Tensor, agent_outputs: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            x: (B, L, D) original input, used only to score the agents.
            agent_outputs: (B, L, K, D) stacked outputs of all agents.

        Returns:
            merged: (B, L, D)

        Raises:
            ValueError: if ``x`` is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(f"x must have shape (B, L, D), got {tuple(x.shape)}")

        coord_input = torch.cat([x, self._context(x)], dim=-1)  # (B, L, 2D)
        trust = F.softmax(self.trust_scorer(coord_input), dim=-1)  # (B, L, K)

        # Trust-weighted sum over the agent axis.
        merged = (agent_outputs * trust.unsqueeze(-1)).sum(dim=2)  # (B, L, D)
        return self.out_proj(merged)
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
class MultiAgentSystem(nn.Module):
    """Multi-agent collaboration layer with a gated residual connection.

    Each call:
      1. reads shared context from the message board and adds it to ``x``,
         scaled by the learnable scalar ``context_scale`` (initialized to 0.1);
      2. runs all agents in one batched pass on that enriched input;
      3. fuses the agent outputs with the coordinator, which scores the agents
         from the original ``x``, then applies dropout;
      4. writes the fused result to the message board for later calls;
      5. returns ``x + sigmoid(residual_gate(x)) * fused``.

    ``forward`` never calls ``message_board.reset_state()``; resetting the
    board between independent sequences or steps is left to the caller.

    Args:
        d_model: model dimension.
        num_agents: number of agents.
        message_slots: number of slots on the shared message board.
        dropout: dropout probability applied to the fused output.
    """

    def __init__(
        self,
        d_model: int,
        num_agents: int = 4,
        message_slots: int = 8,
        dropout: float = 0.0,
    ):
        super().__init__()
        # Creation order fixes the RNG sequence used for initialization and the
        # state_dict layout; keep it stable.
        self.batched_agents = BatchedAgents(d_model, num_agents)
        self.message_board = SharedMessageBoard(d_model, message_slots)
        self.coordinator = AgentCoordinator(d_model, num_agents)
        self.residual_gate = nn.Linear(d_model, 1)
        self.context_scale = nn.Parameter(torch.tensor(0.1))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the multi-agent layer to ``x`` of shape (B, L, D)."""
        board = self.message_board

        enriched = x + board.read(x) * self.context_scale
        per_agent = self.batched_agents(enriched)  # (B, L, K, D)

        fused = self.dropout(self.coordinator(x, per_agent))
        board.write(fused)

        residual_weight = torch.sigmoid(self.residual_gate(x))
        return x + residual_weight * fused
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
# ═══ Backward-compatible alias ═══
# Keeps the old import name ``SpecialistAgent`` importable. NOTE: the class
# interface changed. The name now refers to BatchedAgents, which evaluates all
# agents at once and returns (B, L, num_agents, D); code written against the
# previous SpecialistAgent interface may need updating.
SpecialistAgent = BatchedAgents  # interface changed, name kept
|
cortexnet/multimodal.py
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
1
|
+
"""
|
|
2
|
+
多模态编码器框架 (Multi-Modal Encoder Framework)
|
|
3
|
+
|
|
4
|
+
核心创新:
|
|
5
|
+
统一的多模态输入处理框架,支持文本、图像、音频等多种模态。
|
|
6
|
+
通过模态特定的编码器将不同类型的输入映射到统一的表示空间,
|
|
7
|
+
然后通过跨模态融合机制建立模态间的语义关联。
|
|
8
|
+
|
|
9
|
+
┌────────────────────────────────────────────────────────┐
|
|
10
|
+
│ 多模态处理流水线 │
|
|
11
|
+
├────────────────────────────────────────────────────────┤
|
|
12
|
+
│ │
|
|
13
|
+
│ 文本 ──► TokenEmbed ──┐ │
|
|
14
|
+
│ │ │
|
|
15
|
+
│ 图像 ──► PatchEmbed ──┼──► 模态标记 ──► 拼接 ──► 输出│
|
|
16
|
+
│ │ │ │
|
|
17
|
+
│ 音频 ──► FrameEmbed ──┘ │ │
|
|
18
|
+
│ │ │
|
|
19
|
+
│ 跨模态融合层 (可选) │
|
|
20
|
+
│ │
|
|
21
|
+
└────────────────────────────────────────────────────────┘
|
|
22
|
+
|
|
23
|
+
每种模态的编码器将输入转换为 (batch, num_tokens, d_model) 的格式,
|
|
24
|
+
然后通过模态类型嵌入标记来源,最后拼接成统一序列。
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
import torch
|
|
28
|
+
import torch.nn as nn
|
|
29
|
+
from typing import Optional
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class PatchEmbedding(nn.Module):
    """ViT-style patch embedding.

    Splits the image into non-overlapping ``patch_size`` x ``patch_size``
    patches, projects each patch to ``d_model`` with a strided convolution,
    adds a learnable position embedding and applies LayerNorm.

    The position table is learned for a square grid of
    ``image_size // patch_size`` patches per side. Inputs whose patch grid
    differs (other resolutions or non-square images) get the table resized
    bicubically to their grid. Slicing the first N entries of the table would
    misplace 2-D positions for smaller grids and fail with a shape error for
    larger ones.

    Args:
        d_model: output dimension.
        image_size: side length of the (square) image the table is built for.
        patch_size: patch side length.
        in_channels: number of input channels.
    """

    def __init__(
        self,
        d_model: int,
        image_size: int = 224,
        patch_size: int = 16,
        in_channels: int = 3,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.grid_size = image_size // patch_size
        self.num_patches = self.grid_size ** 2

        # Patch projection implemented as a convolution with stride == kernel.
        self.proj = nn.Conv2d(
            in_channels, d_model, kernel_size=patch_size, stride=patch_size
        )

        # Learnable position embedding, one row per patch in row-major order.
        self.pos_embed = nn.Parameter(
            torch.randn(1, self.num_patches, d_model) * 0.02
        )

        self.norm = nn.LayerNorm(d_model)

    def _pos_embed_for(self, grid_h: int, grid_w: int) -> torch.Tensor:
        """Position embedding for a (grid_h, grid_w) patch grid: (1, grid_h*grid_w, D)."""
        if grid_h == self.grid_size and grid_w == self.grid_size:
            return self.pos_embed
        dim = self.pos_embed.shape[-1]
        # (1, N, D) -> (1, D, G, G) so the table can be resized as a 2-D image.
        table = self.pos_embed.reshape(1, self.grid_size, self.grid_size, dim)
        table = table.permute(0, 3, 1, 2)
        table = nn.functional.interpolate(
            table, size=(grid_h, grid_w), mode="bicubic", align_corners=False
        )
        return table.permute(0, 2, 3, 1).reshape(1, grid_h * grid_w, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, channels, height, width)
        Returns:
            patches: (batch, (height // P) * (width // P), d_model)
        """
        x = self.proj(x)  # (B, D, H/P, W/P)
        grid_h, grid_w = x.shape[-2], x.shape[-1]
        x = x.flatten(2).transpose(1, 2)  # (B, grid_h * grid_w, D), row-major
        x = x + self._pos_embed_for(grid_h, grid_w)
        return self.norm(x)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class AudioEmbedding(nn.Module):
    """Frame-level audio embedding computed from a Mel spectrogram.

    Two 1-D convolutions, each followed by GELU: a kernel-7 convolution
    (padding 3, so the frame count is unchanged) from ``n_mels`` to
    ``d_model`` channels, then a convolution whose kernel and stride both equal
    ``frame_stride``, which downsamples time by that factor. LayerNorm is
    applied over the feature axis at the end.

    Args:
        d_model: output feature dimension.
        n_mels: number of Mel channels in the input.
        frame_stride: temporal downsampling factor.
    """

    def __init__(
        self, d_model: int, n_mels: int = 80, frame_stride: int = 4
    ):
        super().__init__()
        layers = [
            nn.Conv1d(n_mels, d_model, kernel_size=7, padding=3),
            nn.GELU(),
            nn.Conv1d(
                d_model,
                d_model,
                kernel_size=frame_stride,
                stride=frame_stride,
            ),
            nn.GELU(),
        ]
        self.encoder = nn.Sequential(*layers)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, n_mels, time_frames) Mel spectrogram.
        Returns:
            frames: (batch, time_frames // frame_stride, d_model) for
                time_frames >= frame_stride.
        """
        encoded = self.encoder(x)  # (B, D, T')
        frames = encoded.transpose(1, 2)  # (B, T', D)
        return self.norm(frames)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
class CrossModalFusion(nn.Module):
    """Cross-modal fusion layer: pre-norm cross-attention followed by an FFN.

    Tokens of the query sequence ``x`` attend to ``context``; both sub-layers
    are residual. Only ``x`` is normalized before attention, and ``context``
    is used as keys/values as given.

    Args:
        d_model: feature dimension (``nn.MultiheadAttention`` requires it to
            be divisible by ``num_heads``).
        num_heads: number of attention heads.
    """

    def __init__(self, d_model: int, num_heads: int = 8):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(
            d_model, num_heads, batch_first=True
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 2),
            nn.GELU(),
            nn.Linear(d_model * 2, d_model),
        )

    def forward(
        self,
        x: torch.Tensor,
        context: torch.Tensor,
        context_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            x: (batch, L1, d_model) query-modality tokens.
            context: (batch, L2, d_model) context-modality tokens.
            context_padding_mask: optional (batch, L2) boolean mask, passed to
                ``nn.MultiheadAttention`` as ``key_padding_mask``; ``True``
                marks context positions to ignore.
        Returns:
            output: (batch, L1, d_model)
        """
        # Cross-attention. The weights are discarded, so don't ask MHA to
        # materialize and average them.
        attended, _ = self.cross_attn(
            self.norm1(x),
            context,
            context,
            key_padding_mask=context_padding_mask,
            need_weights=False,
        )
        x = x + attended

        # Position-wise FFN.
        return x + self.ffn(self.norm2(x))
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
class MultiModalEncoder(nn.Module):
    """Unified multi-modal encoder.

    Encodes text tokens, images and audio Mel spectrograms into the same
    ``d_model``-dimensional space. Every token is tagged with a learned
    modality embedding (text=0, image=1, audio=2), and the modalities are
    concatenated along the sequence axis in the order text, image, audio.

    Args:
        d_model: shared representation dimension.
        vocab_size: text vocabulary size.
        image_size: side length of the square image the patch position table
            is built for.
        patch_size: image patch size.
        n_mels: number of Mel channels in the audio input.
        num_fusion_heads: attention heads of the fusion layer.
        in_channels: number of image channels (previously fixed at 3).
        frame_stride: audio temporal downsampling factor (previously fixed at 4).
    """

    # Row indices into ``modality_embed``.
    _TEXT, _IMAGE, _AUDIO = 0, 1, 2

    def __init__(
        self,
        d_model: int,
        vocab_size: int = 32000,
        image_size: int = 224,
        patch_size: int = 16,
        n_mels: int = 80,
        num_fusion_heads: int = 8,
        in_channels: int = 3,
        frame_stride: int = 4,
    ):
        super().__init__()
        self.d_model = d_model

        # Modality-specific encoders.
        self.text_embed = nn.Embedding(vocab_size, d_model)
        self.image_embed = PatchEmbedding(d_model, image_size, patch_size, in_channels)
        self.audio_embed = AudioEmbedding(d_model, n_mels, frame_stride)

        # Modality-type embedding (text=0, image=1, audio=2).
        self.modality_embed = nn.Embedding(3, d_model)

        # Optional fusion layer, used when forward(fuse_modalities=True).
        self.cross_modal_fusion = CrossModalFusion(d_model, num_fusion_heads)

        # Normalization for text embeddings; the image and audio encoders
        # apply their own LayerNorm.
        self.text_norm = nn.LayerNorm(d_model)

    def _tag(self, tokens: torch.Tensor, modality: int) -> torch.Tensor:
        """Add the ``modality`` embedding row to every token of ``tokens`` (B, L, D)."""
        # Broadcasting one weight row is equivalent to an embedding lookup on a
        # (B, L) tensor filled with ``modality``, but allocates no index tensor.
        return tokens + self.modality_embed.weight[modality]

    def encode_text(self, tokens: torch.Tensor) -> torch.Tensor:
        """Encode text token ids (B, L_text) into (B, L_text, d_model)."""
        return self._tag(self.text_norm(self.text_embed(tokens)), self._TEXT)

    def encode_image(self, images: torch.Tensor) -> torch.Tensor:
        """Encode images (B, C, H, W) into (B, num_patches, d_model)."""
        return self._tag(self.image_embed(images), self._IMAGE)

    def encode_audio(self, mel_specs: torch.Tensor) -> torch.Tensor:
        """Encode Mel spectrograms (B, n_mels, T) into (B, num_frames, d_model)."""
        return self._tag(self.audio_embed(mel_specs), self._AUDIO)

    def forward(
        self,
        text: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        audio: Optional[torch.Tensor] = None,
        fuse_modalities: bool = False,
    ) -> torch.Tensor:
        """Encode the given modalities and concatenate them into one sequence.

        All provided inputs must share the same batch size, because the parts
        are concatenated along dim 1.

        Args:
            text: (B, L_text) text token ids.
            images: (B, C, H, W) images.
            audio: (B, n_mels, T) audio Mel spectrograms.
            fuse_modalities: if True and more than one modality is given, pass
                the concatenated sequence through the fusion layer as both
                query and context. This amounts to self-attention over all
                tokens, not attention restricted to other modalities.
        Returns:
            embeddings: (B, L_total, d_model)

        Raises:
            ValueError: if no modality is provided.
        """
        parts = []

        if text is not None:
            parts.append(self.encode_text(text))
        if images is not None:
            parts.append(self.encode_image(images))
        if audio is not None:
            parts.append(self.encode_audio(audio))

        if not parts:
            raise ValueError("至少需要提供一种模态的输入")

        combined = torch.cat(parts, dim=1)  # (B, L_total, D)

        if fuse_modalities and len(parts) > 1:
            combined = self.cross_modal_fusion(combined, combined)

        return combined
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CortexNet 算子子包 (Ops Sub-package)
|
|
3
|
+
|
|
4
|
+
提供硬件抽象层,支持 NVIDIA GPU、国产 NPU 和 CPU 后端。
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from .device_manager import (
|
|
8
|
+
DeviceManager,
|
|
9
|
+
get_best_device_info,
|
|
10
|
+
is_npu_available,
|
|
11
|
+
is_mlu_available,
|
|
12
|
+
get_device_type,
|
|
13
|
+
resolve_device_string,
|
|
14
|
+
resolve_dtype_for_device,
|
|
15
|
+
)
|
|
16
|
+
from .npu_ops import NPUOperators, get_operators
|
|
17
|
+
|
|
18
|
+
# Names exported by a star-import of this package: the device-management
# helpers imported from ``.device_manager`` and the operator helpers imported
# from ``.npu_ops`` above.
__all__ = [
    "DeviceManager",
    "get_best_device_info",
    "is_npu_available",
    "is_mlu_available",
    "get_device_type",
    "resolve_device_string",
    "resolve_dtype_for_device",
    "NPUOperators",
    "get_operators",
]
|