codon-model 0.0.1__tar.gz → 0.0.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {codon_model-0.0.1/codon_model.egg-info → codon_model-0.0.2}/PKG-INFO +4 -1
- {codon_model-0.0.1 → codon_model-0.0.2}/codon/__init__.py +1 -1
- codon_model-0.0.2/codon/block/__init__.py +83 -0
- codon_model-0.0.2/codon/block/attention.py +173 -0
- codon_model-0.0.2/codon/block/bio/hebian.py +223 -0
- codon_model-0.0.2/codon/block/bio/predictive.py +377 -0
- codon_model-0.0.2/codon/block/codebook.py +142 -0
- codon_model-0.0.2/codon/block/conv.py +571 -0
- codon_model-0.0.2/codon/block/embedding.py +350 -0
- codon_model-0.0.2/codon/block/film.py +184 -0
- codon_model-0.0.2/codon/block/fusion.py +348 -0
- codon_model-0.0.2/codon/block/lora.py +870 -0
- codon_model-0.0.2/codon/block/mlp.py +88 -0
- codon_model-0.0.2/codon/block/moe.py +239 -0
- codon_model-0.0.2/codon/block/pixelshuffle.py +333 -0
- codon_model-0.0.2/codon/block/transformer.py +346 -0
- codon_model-0.0.2/codon/kit/__init__.py +6 -0
- codon_model-0.0.2/codon/kit/auto_vision_train.py +171 -0
- codon_model-0.0.2/codon/model/motif/__init__.py +14 -0
- codon_model-0.0.2/codon/model/motif/motif_v1.py +447 -0
- codon_model-0.0.2/codon/model/resnet.py +273 -0
- {codon_model-0.0.1 → codon_model-0.0.2}/codon/model/tcn.py +1 -1
- codon_model-0.0.2/codon/ops/__init__.py +38 -0
- codon_model-0.0.2/codon/ops/bio.py +309 -0
- codon_model-0.0.2/codon/ops/pixelshuffle.py +92 -0
- codon_model-0.0.2/codon/utils/dataset/__init__.py +19 -0
- codon_model-0.0.2/codon/utils/dataset/corpus.py +630 -0
- {codon_model-0.0.1 → codon_model-0.0.2}/codon/utils/dataset/flatdata.py +4 -4
- codon_model-0.0.2/codon/utils/dataset/image.py +449 -0
- codon_model-0.0.2/codon/utils/split.py +133 -0
- {codon_model-0.0.1 → codon_model-0.0.2/codon_model.egg-info}/PKG-INFO +4 -1
- {codon_model-0.0.1 → codon_model-0.0.2}/codon_model.egg-info/SOURCES.txt +24 -1
- {codon_model-0.0.1 → codon_model-0.0.2}/codon_model.egg-info/requires.txt +3 -0
- codon_model-0.0.2/test/test_motifv1_train.py +100 -0
- codon_model-0.0.1/codon/model/motif/__init__.py +0 -1
- codon_model-0.0.1/codon/ops/__init__.py +0 -3
- codon_model-0.0.1/codon/utils/dataset/__init__.py +0 -3
- codon_model-0.0.1/codon/utils/dataset/corpus.py +0 -478
- {codon_model-0.0.1 → codon_model-0.0.2}/LICENSE +0 -0
- {codon_model-0.0.1 → codon_model-0.0.2}/codon/base.py +0 -0
- {codon_model-0.0.1/codon/exp → codon_model-0.0.2/codon/block/bio}/__init__.py +0 -0
- {codon_model-0.0.1/codon/model → codon_model-0.0.2/codon/exp}/__init__.py +0 -0
- {codon_model-0.0.1 → codon_model-0.0.2}/codon/exp/moe.py +0 -0
- {codon_model-0.0.1/codon/utils → codon_model-0.0.2/codon/model}/__init__.py +0 -0
- {codon_model-0.0.1 → codon_model-0.0.2}/codon/model/motif/motif_a1.py +0 -0
- {codon_model-0.0.1 → codon_model-0.0.2}/codon/model/patch_disc.py +0 -0
- {codon_model-0.0.1 → codon_model-0.0.2}/codon/ops/attention.py +0 -0
- /codon_model-0.0.1/codon/ops/bio.py → /codon_model-0.0.2/codon/utils/__init__.py +0 -0
- {codon_model-0.0.1 → codon_model-0.0.2}/codon/utils/dataset/base.py +0 -0
- {codon_model-0.0.1 → codon_model-0.0.2}/codon/utils/dataset/dataviewer.py +0 -0
- {codon_model-0.0.1 → codon_model-0.0.2}/codon/utils/mask.py +0 -0
- {codon_model-0.0.1 → codon_model-0.0.2}/codon/utils/safecode.py +0 -0
- {codon_model-0.0.1 → codon_model-0.0.2}/codon/utils/seed.py +0 -0
- {codon_model-0.0.1 → codon_model-0.0.2}/codon/utils/theta.py +0 -0
- {codon_model-0.0.1 → codon_model-0.0.2}/codon/utils/token.py +0 -0
- {codon_model-0.0.1 → codon_model-0.0.2}/codon_model.egg-info/dependency_links.txt +0 -0
- {codon_model-0.0.1 → codon_model-0.0.2}/codon_model.egg-info/top_level.txt +0 -0
- {codon_model-0.0.1 → codon_model-0.0.2}/setup.cfg +0 -0
- {codon_model-0.0.1 → codon_model-0.0.2}/setup.py +0 -0
|
@@ -1,15 +1,18 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: codon-model
|
|
3
|
-
Version: 0.0.1
|
|
3
|
+
Version: 0.0.2
|
|
4
4
|
Summary: Codon model package
|
|
5
5
|
Author: CodonTeam
|
|
6
6
|
Requires-Python: >=3.8
|
|
7
7
|
License-File: LICENSE
|
|
8
8
|
Requires-Dist: torch
|
|
9
|
+
Requires-Dist: torchvision
|
|
9
10
|
Requires-Dist: transformers
|
|
10
11
|
Requires-Dist: tokenizers
|
|
11
12
|
Requires-Dist: pandas
|
|
13
|
+
Requires-Dist: numpy
|
|
12
14
|
Requires-Dist: pyarrow
|
|
15
|
+
Requires-Dist: pillow
|
|
13
16
|
Dynamic: author
|
|
14
17
|
Dynamic: license-file
|
|
15
18
|
Dynamic: requires-dist
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
from .attention import AttentionOutput, MultiHeadAttention
|
|
2
|
+
from .codebook import LookupFreeQuantization, LookupFreeQuantizationOutput
|
|
3
|
+
from .conv import (
|
|
4
|
+
CausalConv1d,
|
|
5
|
+
ConvBlock,
|
|
6
|
+
DepthwiseSeparableConv,
|
|
7
|
+
ResBasicBlock,
|
|
8
|
+
calculate_causal_layer,
|
|
9
|
+
)
|
|
10
|
+
from .embedding import (
|
|
11
|
+
BasicEmbedding,
|
|
12
|
+
BasicRotaryEmbedding,
|
|
13
|
+
InterleavedRotaryEmbedding,
|
|
14
|
+
RotaryEmbedding,
|
|
15
|
+
SinusoidalEmbedding,
|
|
16
|
+
)
|
|
17
|
+
from .film import FiLM, FiLMOutput
|
|
18
|
+
from .fusion import (
|
|
19
|
+
CompactMultimodalPooling,
|
|
20
|
+
DiffusionMapsFusion,
|
|
21
|
+
GatedMultimodalUnit,
|
|
22
|
+
LowRankFusion,
|
|
23
|
+
)
|
|
24
|
+
from .lora import BasicLoRA, Conv1dLoRA, Conv2dLoRA, EmbeddingLoRA, LinearLoRA
|
|
25
|
+
from .mlp import MLP
|
|
26
|
+
from .moe import Expert, MoE, MoEInfo, MoEOutput
|
|
27
|
+
from .pixelshuffle import PixelShuffleUpSample, UnPixelShuffleDownSample
|
|
28
|
+
from .transformer import (
|
|
29
|
+
TransformerDecoderOutput,
|
|
30
|
+
TransformerDenseDecoder,
|
|
31
|
+
TransformerMoEDecoder,
|
|
32
|
+
_TransformerDecoder,
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
# Names re-exported by this package, grouped by the submodule that defines them.
__all__ = [
    # attention
    'AttentionOutput', 'MultiHeadAttention',
    # codebook
    'LookupFreeQuantization', 'LookupFreeQuantizationOutput',
    # conv
    'CausalConv1d', 'ConvBlock', 'DepthwiseSeparableConv', 'ResBasicBlock',
    'calculate_causal_layer',
    # embedding
    'BasicEmbedding', 'BasicRotaryEmbedding', 'InterleavedRotaryEmbedding',
    'RotaryEmbedding', 'SinusoidalEmbedding',
    # film
    'FiLM', 'FiLMOutput',
    # fusion
    'CompactMultimodalPooling', 'DiffusionMapsFusion', 'GatedMultimodalUnit',
    'LowRankFusion',
    # lora
    'BasicLoRA', 'Conv1dLoRA', 'Conv2dLoRA', 'EmbeddingLoRA', 'LinearLoRA',
    # mlp
    'MLP',
    # moe
    'Expert', 'MoE', 'MoEInfo', 'MoEOutput',
    # pixelshuffle
    'PixelShuffleUpSample', 'UnPixelShuffleDownSample',
    # transformer (the underscore-prefixed name is listed explicitly, so a
    # star-import of this package picks it up as well)
    '_TransformerDecoder', 'TransformerDecoderOutput', 'TransformerDenseDecoder',
    'TransformerMoEDecoder',
]
|
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
from codon.base import *
|
|
2
|
+
from codon.block.embedding import BasicEmbedding
|
|
3
|
+
from codon.ops.attention import AttentionOutput, apply_attention
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class MultiHeadAttention(BasicModel):
    '''
    Multi-head attention block.

    Supports Grouped Query Attention (GQA, fewer key/value heads than query
    heads), optional RMS normalization of queries and keys per head, and an
    optional sigmoid gate applied to the projected output.

    Attributes:
        q_proj (nn.Linear): Query projection, hidden_size -> hidden_size.
        k_proj (nn.Linear): Key projection, hidden_size -> num_kv_heads * head_dim.
        v_proj (nn.Linear): Value projection, hidden_size -> num_kv_heads * head_dim.
        o_proj (nn.Linear): Output projection, hidden_size -> hidden_size.
        q_norm (nn.RMSNorm): Query normalization over head_dim; only created
            when use_qk_norm is True.
        k_norm (nn.RMSNorm): Key normalization over head_dim; only created
            when use_qk_norm is True.
        g_proj (nn.Linear): Gate projection, hidden_size -> hidden_size; only
            created when use_gate is True.
    '''

    def __init__(
        self,
        hidden_size,
        num_heads,
        num_kv_heads=None,
        use_qk_norm=True,
        use_gate=False,
        dropout=0.1,
        is_causal=True
    ):
        '''
        Initialize the Multi-Head Attention module.

        Args:
            hidden_size (int): Size of the hidden layer. Must be divisible by
                num_heads.
            num_heads (int): Number of query heads. Must be divisible by
                num_kv_heads.
            num_kv_heads (int, optional): Number of key/value heads for GQA.
                If None, defaults to num_heads.
            use_qk_norm (bool, optional): Whether to apply RMSNorm to queries
                and keys. Defaults to True.
            use_gate (bool, optional): Whether to apply a gating mechanism.
                Defaults to False.
            dropout (float, optional): Dropout probability handed to the
                attention op while training. Defaults to 0.1.
            is_causal (bool, optional): Whether to apply a causal mask.
                Defaults to True (for Decoder architectures).

        Raises:
            AssertionError: If the head counts do not divide evenly.
        '''
        super(MultiHeadAttention, self).__init__()

        if num_kv_heads is None:
            num_kv_heads = num_heads

        assert hidden_size % num_heads == 0, \
            'hidden_size must be divisible by num_heads'
        assert num_heads % num_kv_heads == 0, \
            'num_heads must be divisible by num_kv_heads'

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        # Number of query heads that share one key/value head.
        self.num_kv_queries = num_heads // num_kv_heads
        self.head_dim = hidden_size // num_heads
        self.kv_dim = self.num_kv_heads * self.head_dim
        self.use_qk_norm = use_qk_norm
        self.use_gate = use_gate
        self.dropout = dropout
        self.is_causal = is_causal

        # Sub-modules are created in this exact order on purpose: it fixes
        # both the state_dict ordering and the order in which the random
        # initializers are drawn.
        if use_qk_norm:
            self.q_norm = nn.RMSNorm(self.head_dim)
            self.k_norm = nn.RMSNorm(self.head_dim)

        if use_gate:
            self.g_proj = nn.Linear(hidden_size, hidden_size)

        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, self.kv_dim)
        self.v_proj = nn.Linear(hidden_size, self.kv_dim)
        self.o_proj = nn.Linear(hidden_size, hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_states: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
        output_attentions: bool = False,
        position_emb: BasicEmbedding = None,
        embedding_start: int = 0,
        embedding_pos: torch.Tensor = None,
        # Quoted on purpose: subscripting the builtin ``tuple`` in an eagerly
        # evaluated annotation raises TypeError on Python 3.8, which the
        # package metadata still declares as supported.
        past_key_value: 'tuple[torch.Tensor, torch.Tensor]' = None,
        use_cache: bool = False
    ) -> AttentionOutput:
        '''
        Perform forward pass of Multi-Head Attention.

        Args:
            hidden_states (torch.Tensor): Input hidden states, unpacked as
                (batch, q_len, hidden).
            kv_states (torch.Tensor, optional): Hidden states for keys/values.
                If None, uses hidden_states (self-attention). Defaults to None.
            attention_mask (torch.Tensor, optional): Attention mask forwarded
                to the attention op. Defaults to None.
            output_attentions (bool, optional): Whether to output attention
                weights. Defaults to False.
            position_emb (BasicEmbedding, optional): Positional embedding
                module applied to both queries and keys. Defaults to None.
            embedding_start (int, optional): Starting position for embedding.
                Defaults to 0.
            embedding_pos (torch.Tensor, optional): Explicit position indices
                for positional embedding. Defaults to None.
            past_key_value (tuple[torch.Tensor, torch.Tensor], optional): Past
                key/value cache. Only consumed when use_cache is True.
                Defaults to None.
            use_cache (bool, optional): Whether to use KV cache. Defaults to
                False.

        Returns:
            AttentionOutput: Object containing output, attention weights, and
            KV cache (None when use_cache is False).
        '''
        if kv_states is None:
            kv_states = hidden_states

        batch_size, q_len, _ = hidden_states.shape
        kv_len_input = kv_states.shape[1]

        # The gate is computed from the block input, before attention.
        if self.use_gate:
            G = torch.sigmoid(self.g_proj(hidden_states))

        Q = self.q_proj(hidden_states)
        K = self.k_proj(kv_states)
        V = self.v_proj(kv_states)

        # [B, L, H * D] -> [B, H, L, D]
        Q = Q.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, kv_len_input, self.num_kv_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, kv_len_input, self.num_kv_heads, self.head_dim).transpose(1, 2)

        if self.use_qk_norm:
            Q = self.q_norm(Q)
            K = self.k_norm(K)

        # Only the new keys are position-encoded here; cached keys are
        # concatenated below exactly as they were stored.
        if position_emb is not None:
            Q = position_emb(Q, start_pos=embedding_start, positions=embedding_pos)
            K = position_emb(K, start_pos=embedding_start, positions=embedding_pos)

        current_key_value = None
        if use_cache:
            if past_key_value is not None:
                past_k, past_v = past_key_value
                # Append the new steps along the sequence axis.
                K = torch.cat((past_k, K), dim=2)
                V = torch.cat((past_v, V), dim=2)
            # The cache holds the un-expanded (num_kv_heads) tensors.
            current_key_value = (K, V)

        kv_seq_len_total = K.shape[2]

        if self.num_kv_queries > 1:
            # GQA: repeat every KV head for the query heads that share it.
            # [B, H_kv, 1, L, D] -> [B, H_kv, G, L, D] -> [B, H, L, D]
            K = K[:, :, None, :, :].expand(
                batch_size, self.num_kv_heads, self.num_kv_queries,
                kv_seq_len_total, self.head_dim
            )
            V = V[:, :, None, :, :].expand(
                batch_size, self.num_kv_heads, self.num_kv_queries,
                kv_seq_len_total, self.head_dim
            )

            K = K.reshape(batch_size, self.num_heads, kv_seq_len_total, self.head_dim)
            V = V.reshape(batch_size, self.num_heads, kv_seq_len_total, self.head_dim)

        attn_output = apply_attention(
            Q, K, V,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            is_causal=self.is_causal,
            # Attention dropout is disabled outside of training mode.
            dropout=self.dropout if self.training else 0.0
        )

        output = attn_output.output
        attention_weights = attn_output.attention_weights
        # [B, H, L, D] -> [B, L, H * D]
        output = output.transpose(1, 2).contiguous().view(batch_size, q_len, self.hidden_size)
        output = self.o_proj(output)

        if self.use_gate:
            output = output * G

        return AttentionOutput(
            output=output,
            attention_weights=attention_weights,
            past_key_value=current_key_value
        )
|
|
@@ -0,0 +1,223 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
import math
|
|
5
|
+
|
|
6
|
+
from typing import Optional, Dict
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
|
|
9
|
+
from codon.base import BasicModel
|
|
10
|
+
from codon.ops.bio import (
|
|
11
|
+
hebbian_update,
|
|
12
|
+
oja_update,
|
|
13
|
+
bcm_update,
|
|
14
|
+
covariance_update,
|
|
15
|
+
instar_update,
|
|
16
|
+
synaptic_scaling_update,
|
|
17
|
+
vogels_sprekeler_update,
|
|
18
|
+
reward_modulated_hebbian_update,
|
|
19
|
+
rate_based_stdp_update,
|
|
20
|
+
eligibility_trace_update
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
@dataclass
class HebianOutput:
    '''
    Result bundle returned by the Hebian layer's forward pass.

    Attributes:
        output_tensor (torch.Tensor): Activation produced by the layer.
        weight_updates (Dict[str, torch.Tensor]): Synaptic updates keyed by
            parameter name (for example 'weight').
    '''
    output_tensor: torch.Tensor
    weight_updates: Dict[str, torch.Tensor]
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class Hebian(BasicModel):
    '''
    Linear layer trained with local plasticity rules rather than backprop.

    The forward pass computes an (optionally activated) affine map and, for
    the configured rule, a weight update that is returned to the caller and
    optionally applied in place.

    Attributes:
        weight (nn.Parameter): Forward synaptic weights of shape
            (out_features, in_features); created with requires_grad=False.
        bias (nn.Parameter, optional): Bias term for the forward activation;
            None when use_bias is False.
    '''

    def __init__(
        self,
        in_features: int,
        out_features: int,
        learning_rate: float = 0.01,
        rule: str = 'oja',
        use_bias: bool = True,
        auto_update: bool = False,
        bcm_momentum: float = 0.1,
        target_rate: float = 0.1,
        trace_decay: float = 0.9,
        activation: str = 'linear'
    ) -> None:
        '''
        Builds the layer and initializes its parameters.

        Args:
            in_features (int): Dimension of the input data.
            out_features (int): Dimension of the output representation.
            learning_rate (float, optional): Plasticity learning rate. Defaults to 0.01.
            rule (str, optional): Learning rule, lower-cased on storage. One of
                'hebbian', 'oja', 'bcm', 'covariance', 'instar', 'scaling',
                'vogels', 'reward_hebb', 'stdp', 'eligibility'. Defaults to 'oja'.
            use_bias (bool, optional): Whether to use a bias term. Defaults to True.
            auto_update (bool, optional): Apply the computed updates inside
                forward(). Defaults to False.
            bcm_momentum (float, optional): Momentum of the BCM sliding
                threshold. Defaults to 0.1.
            target_rate (float, optional): Target rate passed to the 'scaling'
                and 'vogels' rules. Defaults to 0.1.
            trace_decay (float, optional): Decay factor of the eligibility
                trace. Defaults to 0.9.
            activation (str, optional): 'linear', 'relu', 'sigmoid' or 'tanh',
                lower-cased on storage. Defaults to 'linear'.
        '''
        super().__init__()

        # Plain configuration.
        self.in_features = in_features
        self.out_features = out_features
        self.learning_rate = learning_rate
        self.rule = rule.lower()
        self.use_bias = use_bias
        self.auto_update = auto_update
        self.bcm_momentum = bcm_momentum
        self.target_rate = target_rate
        self.trace_decay = trace_decay
        self.activation = activation.lower()

        # Parameters are allocated uninitialized; reset_parameters() fills them.
        self.weight = nn.Parameter(torch.Tensor(out_features, in_features), requires_grad=False)

        if not self.use_bias:
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(torch.Tensor(out_features), requires_grad=False)

        # Rule-specific running state.
        if self.rule != 'bcm':
            self.bcm_threshold = None
        else:
            self.register_buffer('bcm_threshold', torch.zeros(out_features))

        # NOTE(review): the buffers below start as None and are filled lazily
        # in forward(), so what state_dict() contains depends on whether
        # forward() has already run. Confirm this is intended for checkpoints.
        if self.rule == 'stdp':
            self.register_buffer('prev_input', None)
            self.register_buffer('prev_state', None)

        if self.rule == 'eligibility':
            self.register_buffer('eligibility_trace', None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        '''
        Re-initializes the weight (Kaiming uniform) and, if present, the bias
        (uniform within +/- 1/sqrt(fan_in)).
        '''
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

        if self.bias is None:
            return

        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        nn.init.uniform_(self.bias, -bound, bound)

    def _apply_activation(self, input_tensor: torch.Tensor) -> torch.Tensor:
        '''
        Runs the configured activation over a tensor.

        Args:
            input_tensor (torch.Tensor): Pre-activation values.

        Returns:
            torch.Tensor: Activated tensor; any name other than 'relu',
            'sigmoid' or 'tanh' returns the input unchanged.
        '''
        nonlinearities = {
            'relu': F.relu,
            'sigmoid': torch.sigmoid,
            'tanh': torch.tanh,
        }
        nonlinearity = nonlinearities.get(self.activation)
        if nonlinearity is None:
            return input_tensor
        return nonlinearity(input_tensor)

    @torch.no_grad()
    def forward(self, input_tensor: torch.Tensor, reward: Optional[torch.Tensor] = None) -> HebianOutput:
        '''
        Computes the layer activation and the weight update of the active rule.

        Note:
            Runs under @torch.no_grad(): no autograd graph is recorded because
            learning happens through the local plasticity rules.

        Args:
            input_tensor (torch.Tensor): The input data with shape (batch_size, in_features).
            reward (Optional[torch.Tensor], optional): Reward signal. Required
                by 'reward_hebb'; for 'eligibility' the update is zero while it
                is None. Defaults to None.

        Returns:
            HebianOutput: The activation and a dict holding the 'weight' update.

        Raises:
            ValueError: If the rule is 'reward_hebb' and reward is None, or if
                the configured rule is not supported.
        '''
        state = self._apply_activation(F.linear(input_tensor, self.weight, self.bias))

        rule = self.rule
        rate = self.learning_rate

        # Rules that share the (weight, input, output, learning_rate) signature.
        same_signature_rules = {
            'hebbian': hebbian_update,
            'oja': oja_update,
            'covariance': covariance_update,
            'instar': instar_update,
        }

        if rule in same_signature_rules:
            delta = same_signature_rules[rule](self.weight, input_tensor, state, rate)

        elif rule == 'bcm':
            # The update sees the threshold from before this batch ...
            delta = bcm_update(self.weight, input_tensor, state, self.bcm_threshold, rate)
            # ... then the sliding threshold moves toward the batch mean of y^2.
            batch_y2 = torch.mean(state ** 2, dim=0)
            self.bcm_threshold.mul_(1 - self.bcm_momentum).add_(batch_y2, alpha=self.bcm_momentum)

        elif rule == 'scaling':
            delta = synaptic_scaling_update(self.weight, state, target_rate=self.target_rate, learning_rate=rate)

        elif rule == 'vogels':
            delta = vogels_sprekeler_update(input_tensor, state, target_rate=self.target_rate, learning_rate=rate)

        elif rule == 'reward_hebb':
            if reward is None:
                raise ValueError("The 'reward_hebb' rule requires a reward signal to be passed to forward().")
            delta = reward_modulated_hebbian_update(input_tensor, state, reward, rate)

        elif rule == 'stdp':
            earlier_input = getattr(self, 'prev_input', None)
            if earlier_input is not None and earlier_input.shape == input_tensor.shape:
                delta = rate_based_stdp_update(input_tensor, earlier_input, state, self.prev_state, rate)
            else:
                # No usable history (first call or changed input shape).
                delta = torch.zeros_like(self.weight)
            # The current step becomes the history for the next call.
            self.prev_input = input_tensor.clone().detach()
            self.prev_state = state.clone().detach()

        elif rule == 'eligibility':
            instantaneous = hebbian_update(self.weight, input_tensor, state, learning_rate=1.0)

            trace = getattr(self, 'eligibility_trace', None)
            if trace is None or trace.shape != self.weight.shape:
                trace = torch.zeros_like(self.weight)

            # Decay the old trace, then add the newest Hebbian term.
            self.eligibility_trace = trace * self.trace_decay + instantaneous

            if reward is None:
                delta = torch.zeros_like(self.weight)
            else:
                delta = eligibility_trace_update(self.eligibility_trace, reward, rate)

        else:
            raise ValueError(f'Unsupported learning rule: {self.rule}')

        updates: Dict[str, torch.Tensor] = {'weight': delta}

        if self.auto_update:
            self.apply_updates(updates)

        return HebianOutput(output_tensor=state, weight_updates=updates)

    @torch.no_grad()
    def apply_updates(self, updates: Dict[str, torch.Tensor]) -> None:
        '''
        Adds the given updates to the matching parameters in place.

        Args:
            updates (Dict[str, torch.Tensor]): Updates keyed by parameter name;
                only 'weight' and 'bias' are looked at, and an entry is skipped
                when the layer has no such parameter.
        '''
        for name in ('weight', 'bias'):
            target = getattr(self, name, None)
            if name in updates and target is not None:
                target.add_(updates[name])
|