cortexnet 3.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cortexnet/memory.py ADDED
@@ -0,0 +1,179 @@
1
+ from __future__ import annotations
2
+
3
+ """
4
+ 突触可塑性记忆模块 (Synaptic Plasticity Memory)
5
+
6
+ 核心创新:
7
+ 受生物神经系统突触可塑性启发的快速权重记忆系统。
8
+ 在前向传播过程中逐步累积键值关联,使模型具备:
9
+
10
+ 1. 即时上下文学习:无需梯度更新即可在推理时"学习"新模式
11
+ 2. 工作记忆:维护当前上下文的动态表示
12
+ 3. 模式发现:自动发现并记忆输入序列中的规律
13
+
14
+ 与传统注意力的区别:
15
+ - 注意力:每次重新计算所有 token 间的关系 → O(n²)
16
+ - 突触记忆:递增地累积关联 → O(n),且具有压缩效果
17
+
18
+ 学术渊源:
19
+ - Fast Weight Programmers (Schmidhuber, 1992)
20
+ - Linear Attention (Katharopoulos et al., 2020)
21
+ - Delta Net (Schlag et al., 2021)
22
+
23
+ 本模块在此基础上创新性地加入:
24
+ - 可学习的衰减率(平衡新旧信息)
25
+ - ELU+1 正值化(类似核方法,保证查询/键特征恒为正,从而归一化分母 φ(q)·z 恒为正,读取结果是值向量的加权平均)
26
+ - 归一化读取(防止记忆值爆炸或消失)
27
+ """
28
+
29
+ from typing import Optional, Tuple
30
+
31
+ import torch
32
+ import torch.nn as nn
33
+ import torch.nn.functional as F
34
+
35
+
36
class _RMSNorm(nn.Module):
    """Inline RMSNorm (kept here to avoid a circular import of blocks.py).

    The mean square is computed in float32 for MPS/float16 compatibility; the
    normalized result is cast back to the input dtype before the learnable
    per-channel gain is applied.

    Args:
        dim: size of the last dimension, which is the one normalized.
        eps: added to the mean square before the square root.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        dtype = x.dtype
        x_f = x.float()
        rms = torch.sqrt(x_f.pow(2).mean(-1, keepdim=True) + self.eps)
        return (x_f / rms).to(dtype) * self.weight.to(dtype)


class SynapticMemory(nn.Module):
    """Synaptic-plasticity memory: a fast-weight matrix updated during the forward pass.

    With lambda = sigmoid(decay_logit) and phi(x) = ELU(x) + 1, each time step t does:

        1. Write:  M_t = lambda * M_{t-1} + phi(k_t)^T v_t
        2. Update: z_t = lambda * z_{t-1} + phi(k_t)
        3. Read:   o_t = phi(q_t) M_t / (phi(q_t) . z_t + 1e-6)
        4. Output: out_proj(RMSNorm(o_t))

    The current token is written before it is read (causal linear attention),
    so position t sees positions <= t. The chunkwise parallel path (training /
    non-cached inference) and the token-by-token path (cached decoding) both
    implement exactly this recurrence, so outputs agree between the two and
    incremental decoding with the returned cache reproduces the full-sequence
    output.

    Args:
        d_model: input/output feature size.
        memory_dim: rows of the memory matrix (projection size of q and k).
        decay_init: initial value of lambda; must lie strictly in (0, 1).
        chunk_size: block length of the chunkwise parallel path. Each chunk
            builds a (chunk_size x chunk_size) score matrix; the number of
            Python-level iterations is ceil(seq_len / chunk_size).

    Raises:
        ValueError: if ``decay_init`` is not in (0, 1) or ``chunk_size`` < 1.
    """

    def __init__(
        self,
        d_model: int,
        memory_dim: int = 64,
        decay_init: float = 0.95,
        chunk_size: int = 64,
    ):
        super().__init__()
        if not 0.0 < decay_init < 1.0:
            # logit(0) / logit(1) would be -inf / +inf and poison training.
            raise ValueError(f"decay_init must lie strictly in (0, 1), got {decay_init}")
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
        self.d_model = d_model
        self.memory_dim = memory_dim
        self.chunk_size = chunk_size

        # Projections (created in the same order as before so seeded inits match).
        self.query_proj = nn.Linear(d_model, memory_dim, bias=False)
        self.key_proj = nn.Linear(d_model, memory_dim, bias=False)
        self.value_proj = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

        # Learnable decay: lambda = sigmoid(decay_logit); logit(0.95) ~= 2.944.
        self.decay_logit = nn.Parameter(torch.logit(torch.tensor(float(decay_init))))

        # Query scale 1/sqrt(memory_dim).
        self.scale = memory_dim ** -0.5

        # Output normalization (RMSNorm for consistency with the rest of the model).
        self.norm = _RMSNorm(d_model)

    def forward(
        self,
        x: torch.Tensor,
        past_memory: Optional[torch.Tensor] = None,
        past_z: Optional[torch.Tensor] = None,
        use_cache: bool = False,
    ) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the memory over ``x``.

        Args:
            x: (batch, seq_len, d_model) input.
            past_memory: (batch, memory_dim, d_model) memory matrix returned by a
                previous call with ``use_cache=True``.
            past_z: (batch, memory_dim, 1) normalizer returned by that same call.
            use_cache: if True, also return the updated ``(memory, z)`` state.

        Returns:
            ``out`` of shape (batch, seq_len, d_model), or ``(out, memory, z)``
            when ``use_cache`` is True.
        """
        # Any incoming state (or a request for the outgoing state) needs the
        # stateful path; checking past_z too means it is never silently dropped.
        if use_cache or past_memory is not None or past_z is not None:
            return self._sequential_forward(x, past_memory, past_z, use_cache)
        return self._parallel_forward(x)

    def _project(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return (scaled phi(q), phi(k), v, lambda) for input ``x``."""
        queries = (F.elu(self.query_proj(x), alpha=1.0) + 1) * self.scale
        keys = F.elu(self.key_proj(x), alpha=1.0) + 1
        values = self.value_proj(x)
        decay = torch.sigmoid(self.decay_logit)
        return queries, keys, values, decay

    def _parallel_forward(self, x: torch.Tensor) -> torch.Tensor:
        """Chunkwise-parallel evaluation of the recurrence (no cache in or out).

        Inside a chunk of c tokens every read is computed at once with the causal
        decay mask lambda^(i-j) (j <= i); the memory carried in from earlier
        chunks has decayed by lambda^(i+1) at local position i. Every factor is
        lambda raised to a non-negative power, so nothing can overflow or need
        clamping no matter how long the sequence is, and no
        (batch, seq_len, memory_dim, d_model) tensor is ever materialized.
        """
        B, L, D = x.shape
        queries, keys, values, decay = self._project(x)
        # Decay powers in at least float32 (float64 stays float64) for accuracy.
        acc_dtype = torch.promote_types(decay.dtype, torch.float32)
        log_decay = torch.log(decay.to(acc_dtype).clamp(min=1e-6))

        memory = x.new_zeros(B, self.memory_dim, D)  # M just before the current chunk
        z = x.new_zeros(B, self.memory_dim)          # z just before the current chunk
        reads = []
        for start in range(0, L, self.chunk_size):
            stop = start + self.chunk_size
            q = queries[:, start:stop]  # (B, c, memory_dim)
            k = keys[:, start:stop]     # (B, c, memory_dim)
            v = values[:, start:stop]   # (B, c, D)
            c = q.shape[1]
            pos = torch.arange(c, device=x.device, dtype=acc_dtype)

            # lambda^(i-j) on/below the diagonal, 0 above (clamp keeps exp finite there).
            gap = pos.view(-1, 1) - pos.view(1, -1)
            intra = torch.exp(gap.clamp(min=0) * log_decay).tril().to(q.dtype)
            # lambda^(i+1): decay of the carried-in memory as seen by local token i.
            carry = torch.exp((pos + 1) * log_decay).to(q.dtype)
            # lambda^(c-1-j): decay of token j's write by the end of the chunk.
            fade = torch.exp((c - 1 - pos) * log_decay).to(q.dtype)

            scores = torch.bmm(q, k.transpose(1, 2)) * intra  # (B, c, c)
            num = torch.bmm(scores, v) + carry.view(1, c, 1) * torch.bmm(q, memory)
            den = scores.sum(-1) + carry.view(1, c) * torch.bmm(q, z.unsqueeze(-1)).squeeze(-1)
            reads.append(num / (den.unsqueeze(-1) + 1e-6))

            # Advance the state to the last token of this chunk.
            faded_k = k * fade.view(1, c, 1)
            memory = carry[-1] * memory + torch.bmm(faded_k.transpose(1, 2), v)
            z = carry[-1] * z + faded_k.sum(dim=1)

        # For an empty sequence `values` is already the (B, 0, D) result shape.
        read = torch.cat(reads, dim=1) if reads else values
        return self.out_proj(self.norm(read))

    def _sequential_forward(
        self, x, past_memory, past_z, use_cache,
    ):
        """Token-by-token recurrence, used for incremental decoding with a cache.

        Missing state defaults to zeros: memory (B, memory_dim, D) and
        z (B, memory_dim, 1). Each token is written into the memory before it is
        read, matching ``_parallel_forward``.
        """
        B, L, D = x.shape
        queries, keys, values, decay = self._project(x)
        memory = (
            past_memory if past_memory is not None
            else x.new_zeros(B, self.memory_dim, D)
        )
        z = (
            past_z if past_z is not None
            else x.new_zeros(B, self.memory_dim, 1)
        )
        outputs = []
        for t in range(L):
            q_t, k_t, v_t = queries[:, t], keys[:, t], values[:, t]
            # Write first, then read (inclusive of the current token).
            memory = decay * memory + torch.bmm(k_t.unsqueeze(2), v_t.unsqueeze(1))
            z = decay * z + k_t.unsqueeze(2)
            num = torch.bmm(q_t.unsqueeze(1), memory).squeeze(1)
            den = torch.bmm(q_t.unsqueeze(1), z).squeeze(1)
            outputs.append(num / (den + 1e-6))
        # torch.stack([]) raises; for L == 0 `values` already has shape (B, 0, D).
        read = torch.stack(outputs, dim=1) if outputs else values
        out = self.out_proj(self.norm(read))
        if use_cache:
            return out, memory, z
        return out
@@ -0,0 +1,187 @@
1
+ """
2
+ 元学习自适应模块 (Meta-Learning Adaptation Module)
3
+
4
+ 核心创新:
5
+ 使网络能够根据输入上下文快速调整自身行为,
6
+ 无需梯度更新即可适应新模式和新任务。
7
+
8
+ 灵感来源:
9
+ - FiLM (Feature-wise Linear Modulation) — 特征级调制
10
+ - MAML (Model-Agnostic Meta-Learning) — 快速适应
11
+ - HyperNetworks — 用网络生成网络参数
12
+
13
+ ┌─────────────────────────────────────────────────────┐
14
+ │ 元学习自适应机制 │
15
+ ├─────────────────────────────────────────────────────┤
16
+ │ │
17
+ │ 输入序列 ──► 上下文编码器 ──► 全局上下文向量 │
18
+ │ │ │
19
+ │ ┌─────────┴────────┐ │
20
+ │ ▼ ▼ │
21
+ │ γ (缩放因子) β (偏移因子) │
22
+ │ │ │ │
23
+ │ └─────────┬────────┘ │
24
+ │ ▼ │
25
+ │ FiLM 调制: y = γ ⊙ x + β │
26
+ │ │
27
+ │ 效果:每个样本/任务获得定制化的特征变换 │
28
+ └─────────────────────────────────────────────────────┘
29
+
30
+ 额外创新:
31
+ - 层级自适应:不同层可以有不同的适应策略
32
+ - 注意力池化:比均值池化更精确的上下文提取
33
+ - 残差调制:保证训练稳定性
34
+ """
35
+
36
+ import torch
37
+ import torch.nn as nn
38
+ import torch.nn.functional as F
39
+
40
+
41
class ContextEncoder(nn.Module):
    """Summarize a sequence into a single global context vector via attention pooling.

    A single learned query scores every (projected) position; the softmax
    weights are applied to the raw inputs, and the pooled vector is passed
    through a two-layer GELU MLP. Compared with plain mean pooling this lets
    the model up-weight the most informative positions.

    NOTE(review): pooling spans the whole sequence and takes no padding mask,
    so the summary depends on every position (including padding or later
    tokens). Confirm that is acceptable wherever this encoder is used.

    Args:
        d_model: feature size of the input and of the returned context.
    """

    def __init__(self, d_model: int):
        super().__init__()
        # Learned pooling query of shape (1, 1, d_model), small random init.
        self.attn_query = nn.Parameter(torch.randn(1, 1, d_model) * 0.02)
        # Projection used only for scoring positions (values are the raw inputs).
        self.attn_proj = nn.Linear(d_model, d_model, bias=False)

        # Post-pooling transform.
        self.context_transform = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool ``x`` of shape (batch, seq_len, d_model) into (batch, d_model)."""
        batch, _, dim = x.shape
        score_keys = self.attn_proj(x)                       # (B, L, D)
        pool_query = self.attn_query.expand(batch, -1, -1)   # (B, 1, D)
        logits = pool_query @ score_keys.transpose(-1, -2)   # (B, 1, L)
        weights = F.softmax(logits / dim**0.5, dim=-1)       # scaled dot-product
        summary = (weights @ x).squeeze(1)                   # (B, D)
        return self.context_transform(summary)
83
+
84
+
85
class MetaLearningAdapter(nn.Module):
    """FiLM conditioning: scale and shift features using a per-sample context vector.

        context = ContextEncoder(x)          (unless the caller supplies one)
        gamma   = 1 + W_gamma . context
        beta    = W_beta . context
        output  = gamma * x + beta           (same gamma/beta for every position)

    Both generators are zero-initialized, so a freshly built adapter is the
    identity map (gamma = 1, beta = 0).

    Args:
        d_model: feature size.
    """

    def __init__(self, d_model: int):
        super().__init__()
        self.context_encoder = ContextEncoder(d_model)

        # FiLM parameter generators.
        self.gamma_gen = nn.Linear(d_model, d_model)
        self.beta_gen = nn.Linear(d_model, d_model)

        # Zero weights and biases -> identity modulation at initialization.
        for generator in (self.gamma_gen, self.beta_gen):
            nn.init.zeros_(generator.weight)
            nn.init.zeros_(generator.bias)

    def forward(
        self, x: torch.Tensor, context: torch.Tensor = None
    ) -> torch.Tensor:
        """Modulate ``x``.

        Args:
            x: (batch, seq_len, d_model) features to modulate.
            context: optional (batch, d_model) external context; when omitted it
                is computed from ``x`` by the context encoder.

        Returns:
            (batch, seq_len, d_model) modulated features.
        """
        ctx = self.context_encoder(x) if context is None else context
        scale = self.gamma_gen(ctx).unsqueeze(1) + 1  # (B, 1, D)
        shift = self.beta_gen(ctx).unsqueeze(1)       # (B, 1, D)
        return x * scale + shift
132
+
133
+
134
class TaskAdaptiveController(nn.Module):
    """Infer a soft "task mode" from the input and apply a mode-specific modulation.

    A small classifier over the pooled context yields softmax weights over
    ``num_modes`` modes; each mode contributes a tanh-bounded per-channel
    vector, and their weighted sum modulates the input residually:

        output = x + x * adaptation * adapt_scale

    The merged mode transform is zero-initialized, so at initialization every
    adaptation vector is tanh(0) = 0 and the controller is the identity map.

    Args:
        d_model: feature size.
        num_modes: number of task modes.
    """

    def __init__(self, d_model: int, num_modes: int = 4):
        super().__init__()
        self.num_modes = num_modes
        self.context_encoder = ContextEncoder(d_model)
        # Learnable modulation strength (starts at 0.1).
        self.adapt_scale = nn.Parameter(torch.tensor(0.1))

        # Mode classifier: context -> logits over modes.
        hidden = d_model // 2
        self.mode_classifier = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, num_modes),
        )

        # All per-mode transforms packed into one Linear (one matmul instead of a loop).
        self.mode_transform_merged = nn.Linear(d_model, d_model * num_modes, bias=False)
        nn.init.zeros_(self.mode_transform_merged.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Adapt ``x`` of shape (batch, seq_len, d_model); returns the same shape."""
        context = self.context_encoder(x)  # (B, D)
        dim = context.shape[-1]

        # Soft assignment to modes.
        mode_probs = F.softmax(self.mode_classifier(context), dim=-1)  # (B, num_modes)

        # Per-mode adaptation vectors, unpacked from the merged projection.
        packed = self.mode_transform_merged(context)  # (B, D * num_modes)
        per_mode = torch.tanh(packed.view(-1, self.num_modes, dim))  # (B, num_modes, D)

        # Probability-weighted blend across modes.
        adaptation = (per_mode * mode_probs.unsqueeze(-1)).sum(dim=1)  # (B, D)

        # Residual modulation with learnable strength.
        return x + x * adaptation.unsqueeze(1) * self.adapt_scale
+ return x + x * adaptation.unsqueeze(1) * self.adapt_scale # 可学习的调制强度