codon-model 0.0.1__tar.gz → 0.0.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59):
  1. {codon_model-0.0.1/codon_model.egg-info → codon_model-0.0.2}/PKG-INFO +4 -1
  2. {codon_model-0.0.1 → codon_model-0.0.2}/codon/__init__.py +1 -1
  3. codon_model-0.0.2/codon/block/__init__.py +83 -0
  4. codon_model-0.0.2/codon/block/attention.py +173 -0
  5. codon_model-0.0.2/codon/block/bio/hebian.py +223 -0
  6. codon_model-0.0.2/codon/block/bio/predictive.py +377 -0
  7. codon_model-0.0.2/codon/block/codebook.py +142 -0
  8. codon_model-0.0.2/codon/block/conv.py +571 -0
  9. codon_model-0.0.2/codon/block/embedding.py +350 -0
  10. codon_model-0.0.2/codon/block/film.py +184 -0
  11. codon_model-0.0.2/codon/block/fusion.py +348 -0
  12. codon_model-0.0.2/codon/block/lora.py +870 -0
  13. codon_model-0.0.2/codon/block/mlp.py +88 -0
  14. codon_model-0.0.2/codon/block/moe.py +239 -0
  15. codon_model-0.0.2/codon/block/pixelshuffle.py +333 -0
  16. codon_model-0.0.2/codon/block/transformer.py +346 -0
  17. codon_model-0.0.2/codon/kit/__init__.py +6 -0
  18. codon_model-0.0.2/codon/kit/auto_vision_train.py +171 -0
  19. codon_model-0.0.2/codon/model/motif/__init__.py +14 -0
  20. codon_model-0.0.2/codon/model/motif/motif_v1.py +447 -0
  21. codon_model-0.0.2/codon/model/resnet.py +273 -0
  22. {codon_model-0.0.1 → codon_model-0.0.2}/codon/model/tcn.py +1 -1
  23. codon_model-0.0.2/codon/ops/__init__.py +38 -0
  24. codon_model-0.0.2/codon/ops/bio.py +309 -0
  25. codon_model-0.0.2/codon/ops/pixelshuffle.py +92 -0
  26. codon_model-0.0.2/codon/utils/dataset/__init__.py +19 -0
  27. codon_model-0.0.2/codon/utils/dataset/corpus.py +630 -0
  28. {codon_model-0.0.1 → codon_model-0.0.2}/codon/utils/dataset/flatdata.py +4 -4
  29. codon_model-0.0.2/codon/utils/dataset/image.py +449 -0
  30. codon_model-0.0.2/codon/utils/split.py +133 -0
  31. {codon_model-0.0.1 → codon_model-0.0.2/codon_model.egg-info}/PKG-INFO +4 -1
  32. {codon_model-0.0.1 → codon_model-0.0.2}/codon_model.egg-info/SOURCES.txt +24 -1
  33. {codon_model-0.0.1 → codon_model-0.0.2}/codon_model.egg-info/requires.txt +3 -0
  34. codon_model-0.0.2/test/test_motifv1_train.py +100 -0
  35. codon_model-0.0.1/codon/model/motif/__init__.py +0 -1
  36. codon_model-0.0.1/codon/ops/__init__.py +0 -3
  37. codon_model-0.0.1/codon/utils/dataset/__init__.py +0 -3
  38. codon_model-0.0.1/codon/utils/dataset/corpus.py +0 -478
  39. {codon_model-0.0.1 → codon_model-0.0.2}/LICENSE +0 -0
  40. {codon_model-0.0.1 → codon_model-0.0.2}/codon/base.py +0 -0
  41. {codon_model-0.0.1/codon/exp → codon_model-0.0.2/codon/block/bio}/__init__.py +0 -0
  42. {codon_model-0.0.1/codon/model → codon_model-0.0.2/codon/exp}/__init__.py +0 -0
  43. {codon_model-0.0.1 → codon_model-0.0.2}/codon/exp/moe.py +0 -0
  44. {codon_model-0.0.1/codon/utils → codon_model-0.0.2/codon/model}/__init__.py +0 -0
  45. {codon_model-0.0.1 → codon_model-0.0.2}/codon/model/motif/motif_a1.py +0 -0
  46. {codon_model-0.0.1 → codon_model-0.0.2}/codon/model/patch_disc.py +0 -0
  47. {codon_model-0.0.1 → codon_model-0.0.2}/codon/ops/attention.py +0 -0
  48. /codon_model-0.0.1/codon/ops/bio.py → /codon_model-0.0.2/codon/utils/__init__.py +0 -0
  49. {codon_model-0.0.1 → codon_model-0.0.2}/codon/utils/dataset/base.py +0 -0
  50. {codon_model-0.0.1 → codon_model-0.0.2}/codon/utils/dataset/dataviewer.py +0 -0
  51. {codon_model-0.0.1 → codon_model-0.0.2}/codon/utils/mask.py +0 -0
  52. {codon_model-0.0.1 → codon_model-0.0.2}/codon/utils/safecode.py +0 -0
  53. {codon_model-0.0.1 → codon_model-0.0.2}/codon/utils/seed.py +0 -0
  54. {codon_model-0.0.1 → codon_model-0.0.2}/codon/utils/theta.py +0 -0
  55. {codon_model-0.0.1 → codon_model-0.0.2}/codon/utils/token.py +0 -0
  56. {codon_model-0.0.1 → codon_model-0.0.2}/codon_model.egg-info/dependency_links.txt +0 -0
  57. {codon_model-0.0.1 → codon_model-0.0.2}/codon_model.egg-info/top_level.txt +0 -0
  58. {codon_model-0.0.1 → codon_model-0.0.2}/setup.cfg +0 -0
  59. {codon_model-0.0.1 → codon_model-0.0.2}/setup.py +0 -0
@@ -1,15 +1,18 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: codon-model
3
- Version: 0.0.1
3
+ Version: 0.0.2
4
4
  Summary: Codon model package
5
5
  Author: CodonTeam
6
6
  Requires-Python: >=3.8
7
7
  License-File: LICENSE
8
8
  Requires-Dist: torch
9
+ Requires-Dist: torchvision
9
10
  Requires-Dist: transformers
10
11
  Requires-Dist: tokenizers
11
12
  Requires-Dist: pandas
13
+ Requires-Dist: numpy
12
14
  Requires-Dist: pyarrow
15
+ Requires-Dist: pillow
13
16
  Dynamic: author
14
17
  Dynamic: license-file
15
18
  Dynamic: requires-dist
@@ -1,5 +1,5 @@
1
1
  from typing import Optional
2
2
 
3
- __version__ = '0.0.1'
3
+ __version__ = '0.0.2'
4
4
 
5
5
  __seed__: Optional[int] = None
@@ -0,0 +1,83 @@
1
+ from .attention import AttentionOutput, MultiHeadAttention
2
+ from .codebook import LookupFreeQuantization, LookupFreeQuantizationOutput
3
+ from .conv import (
4
+ CausalConv1d,
5
+ ConvBlock,
6
+ DepthwiseSeparableConv,
7
+ ResBasicBlock,
8
+ calculate_causal_layer,
9
+ )
10
+ from .embedding import (
11
+ BasicEmbedding,
12
+ BasicRotaryEmbedding,
13
+ InterleavedRotaryEmbedding,
14
+ RotaryEmbedding,
15
+ SinusoidalEmbedding,
16
+ )
17
+ from .film import FiLM, FiLMOutput
18
+ from .fusion import (
19
+ CompactMultimodalPooling,
20
+ DiffusionMapsFusion,
21
+ GatedMultimodalUnit,
22
+ LowRankFusion,
23
+ )
24
+ from .lora import BasicLoRA, Conv1dLoRA, Conv2dLoRA, EmbeddingLoRA, LinearLoRA
25
+ from .mlp import MLP
26
+ from .moe import Expert, MoE, MoEInfo, MoEOutput
27
+ from .pixelshuffle import PixelShuffleUpSample, UnPixelShuffleDownSample
28
+ from .transformer import (
29
+ TransformerDecoderOutput,
30
+ TransformerDenseDecoder,
31
+ TransformerMoEDecoder,
32
+ _TransformerDecoder,
33
+ )
34
+
35
# Public API of the package, grouped by the submodule that defines each name.
# NOTE: '_TransformerDecoder' is underscore-prefixed but is exported here on
# purpose; it is kept so that existing star-imports continue to work.
__all__ = [
    # --- attention ---
    'AttentionOutput',
    'MultiHeadAttention',

    # --- codebook ---
    'LookupFreeQuantization',
    'LookupFreeQuantizationOutput',

    # --- conv ---
    'CausalConv1d',
    'ConvBlock',
    'DepthwiseSeparableConv',
    'ResBasicBlock',
    'calculate_causal_layer',

    # --- embedding ---
    'BasicEmbedding',
    'BasicRotaryEmbedding',
    'InterleavedRotaryEmbedding',
    'RotaryEmbedding',
    'SinusoidalEmbedding',

    # --- film ---
    'FiLM',
    'FiLMOutput',

    # --- fusion ---
    'CompactMultimodalPooling',
    'DiffusionMapsFusion',
    'GatedMultimodalUnit',
    'LowRankFusion',

    # --- lora ---
    'BasicLoRA',
    'Conv1dLoRA',
    'Conv2dLoRA',
    'EmbeddingLoRA',
    'LinearLoRA',

    # --- mlp ---
    'MLP',

    # --- moe ---
    'Expert',
    'MoE',
    'MoEInfo',
    'MoEOutput',

    # --- pixelshuffle ---
    'PixelShuffleUpSample',
    'UnPixelShuffleDownSample',

    # --- transformer ---
    '_TransformerDecoder',
    'TransformerDecoderOutput',
    'TransformerDenseDecoder',
    'TransformerMoEDecoder',
]
@@ -0,0 +1,173 @@
1
+ from codon.base import *
2
+ from codon.block.embedding import BasicEmbedding
3
+ from codon.ops.attention import AttentionOutput, apply_attention
4
+
5
+
6
class MultiHeadAttention(BasicModel):
    '''
    Multi-Head Attention module.
    Supports Grouped Query Attention (GQA), QK Normalization, and Gating mechanisms.

    Attributes:
        hidden_size (int): Model width; input and output feature size.
        num_heads (int): Number of query heads.
        num_kv_heads (int): Number of key/value heads (equals num_heads without GQA).
        num_kv_queries (int): Query heads sharing one key/value head
            (num_heads // num_kv_heads).
        head_dim (int): Per-head feature size (hidden_size // num_heads).
        kv_dim (int): Output width of the key/value projections
            (num_kv_heads * head_dim).
        q_proj (nn.Linear): Linear layer for query projection.
        k_proj (nn.Linear): Linear layer for key projection.
        v_proj (nn.Linear): Linear layer for value projection.
        o_proj (nn.Linear): Linear layer for output projection.
        q_norm (nn.RMSNorm, optional): Normalization layer for queries.
            Only created when use_qk_norm is True.
        k_norm (nn.RMSNorm, optional): Normalization layer for keys.
            Only created when use_qk_norm is True.
        g_proj (nn.Linear, optional): Linear layer for gating mechanism.
            Only created when use_gate is True.
    '''

    def __init__(
        self,
        hidden_size,
        num_heads,
        num_kv_heads=None,
        use_qk_norm=True,
        use_gate=False,
        dropout=0.1,
        is_causal=True
    ):
        '''
        Initialize the Multi-Head Attention module.

        Args:
            hidden_size (int): Size of the hidden layer.
            num_heads (int): Number of attention heads.
            num_kv_heads (int, optional): Number of key/value heads for GQA.
                If None, defaults to num_heads.
            use_qk_norm (bool, optional): Whether to apply RMSNorm to queries and keys.
                Defaults to True.
            use_gate (bool, optional): Whether to apply a gating mechanism. Defaults to False.
            dropout (float, optional): Dropout probability. Defaults to 0.1.
            is_causal (bool, optional): Whether to apply a causal mask.
                Defaults to True (for Decoder architectures).

        Raises:
            AssertionError: If hidden_size is not divisible by num_heads, or
                num_heads is not divisible by num_kv_heads.
        '''
        super(MultiHeadAttention, self).__init__()

        if num_kv_heads is None:
            num_kv_heads = num_heads

        assert hidden_size % num_heads == 0, (
            f'hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})'
        )
        assert num_heads % num_kv_heads == 0, (
            f'num_heads ({num_heads}) must be divisible by num_kv_heads ({num_kv_heads})'
        )

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.num_kv_queries = num_heads // num_kv_heads
        self.head_dim = hidden_size // num_heads
        self.kv_dim = self.num_kv_heads * self.head_dim
        self.use_qk_norm = use_qk_norm
        self.use_gate = use_gate
        self.dropout = dropout
        self.is_causal = is_causal

        if use_qk_norm:
            # Normalization acts on the last (per-head) dimension.
            self.q_norm = nn.RMSNorm(self.head_dim)
            self.k_norm = nn.RMSNorm(self.head_dim)

        if use_gate:
            self.g_proj = nn.Linear(hidden_size, hidden_size)

        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, self.kv_dim)
        self.v_proj = nn.Linear(hidden_size, self.kv_dim)
        self.o_proj = nn.Linear(hidden_size, hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_states: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
        output_attentions: bool = False,
        position_emb: BasicEmbedding = None,
        embedding_start: int = 0,
        embedding_pos: torch.Tensor = None,
        # Kept as a string so the builtin-generic subscript is not evaluated at
        # definition time; `tuple[...]` raises TypeError on Python 3.8.
        past_key_value: 'tuple[torch.Tensor, torch.Tensor]' = None,
        use_cache: bool = False
    ) -> AttentionOutput:
        '''
        Perform forward pass of Multi-Head Attention.

        Args:
            hidden_states (torch.Tensor): Input hidden states, a 3-D tensor
                unpacked as (batch_size, q_len, features).
            kv_states (torch.Tensor, optional): Hidden states for keys/values.
                If None, uses hidden_states. Defaults to None.
            attention_mask (torch.Tensor, optional): Attention mask, forwarded
                unchanged to apply_attention. Defaults to None.
            output_attentions (bool, optional): Whether to output attention weights.
                Defaults to False.
            position_emb (BasicEmbedding, optional): Positional embedding module,
                applied to both queries and keys. Defaults to None.
            embedding_start (int, optional): Starting position for embedding. Defaults to 0.
            embedding_pos (torch.Tensor, optional): Explicit position indices for positional embedding.
                Defaults to None.
            past_key_value (tuple[torch.Tensor, torch.Tensor], optional): Past key-value cache.
                Only consulted when use_cache is True. Defaults to None.
            use_cache (bool, optional): Whether to use KV cache. Defaults to False.

        Returns:
            AttentionOutput: Object containing output, attention weights, and KV cache.
                The cache entry is None when use_cache is False.
        '''
        if kv_states is None:
            kv_states = hidden_states

        batch_size, q_len, _ = hidden_states.shape
        kv_len_input = kv_states.shape[1]

        if self.use_gate:
            # Gate is computed from the layer input, not from the attention output.
            G = torch.sigmoid(self.g_proj(hidden_states))

        Q = self.q_proj(hidden_states)
        K = self.k_proj(kv_states)
        V = self.v_proj(kv_states)

        # [B, L, H * D] -> [B, H, L, D]
        Q = Q.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, kv_len_input, self.num_kv_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, kv_len_input, self.num_kv_heads, self.head_dim).transpose(1, 2)

        if self.use_qk_norm:
            Q = self.q_norm(Q)
            K = self.k_norm(K)

        if position_emb is not None:
            Q = position_emb(Q, start_pos=embedding_start, positions=embedding_pos)
            K = position_emb(K, start_pos=embedding_start, positions=embedding_pos)

        current_key_value = None
        if use_cache:
            if past_key_value is not None:
                # Append the new keys/values after the cached ones along the
                # sequence axis (dim=2 of [B, H_kv, L, D]).
                past_k, past_v = past_key_value
                K = torch.cat((past_k, K), dim=2)
                V = torch.cat((past_v, V), dim=2)
            # The cache is stored before GQA expansion, i.e. with H_kv heads.
            current_key_value = (K, V)

        kv_seq_len_total = K.shape[2]

        if self.num_kv_queries > 1:
            # [B, H_kv, 1, L, D] -> [B, H_kv, G, L, D]
            K = K[:, :, None, :, :].expand(
                batch_size, self.num_kv_heads, self.num_kv_queries, kv_seq_len_total, self.head_dim
            )
            V = V[:, :, None, :, :].expand(
                batch_size, self.num_kv_heads, self.num_kv_queries, kv_seq_len_total, self.head_dim
            )

            # [B, H_kv, G, L, D] -> [B, H, L, D]
            K = K.reshape(batch_size, self.num_heads, kv_seq_len_total, self.head_dim)
            V = V.reshape(batch_size, self.num_heads, kv_seq_len_total, self.head_dim)

        attn_output = apply_attention(
            Q, K, V,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            is_causal=self.is_causal,
            # Attention dropout is disabled outside of training mode.
            dropout=self.dropout if self.training else 0.0
        )

        output = attn_output.output
        attention_weights = attn_output.attention_weights

        # [B, H, L, D] -> [B, L, H * D]
        output = output.transpose(1, 2).contiguous().view(batch_size, q_len, self.hidden_size)
        output = self.o_proj(output)

        if self.use_gate:
            output = output * G

        return AttentionOutput(
            output=output,
            attention_weights=attention_weights,
            past_key_value=current_key_value
        )
@@ -0,0 +1,223 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+
6
+ from typing import Optional, Dict
7
+ from dataclasses import dataclass
8
+
9
+ from codon.base import BasicModel
10
+ from codon.ops.bio import (
11
+ hebbian_update,
12
+ oja_update,
13
+ bcm_update,
14
+ covariance_update,
15
+ instar_update,
16
+ synaptic_scaling_update,
17
+ vogels_sprekeler_update,
18
+ reward_modulated_hebbian_update,
19
+ rate_based_stdp_update,
20
+ eligibility_trace_update
21
+ )
22
+
23
@dataclass
class HebianOutput:
    '''
    Result bundle returned by a forward pass of the Hebian layer.

    Attributes:
        output_tensor (torch.Tensor): Activation produced by the layer.
        weight_updates (Dict[str, torch.Tensor]): Proposed synaptic updates,
            keyed by the name of the parameter each update targets.
    '''
    # Layer activation after the forward pass.
    output_tensor: torch.Tensor
    # Parameter name -> additive update tensor.
    weight_updates: Dict[str, torch.Tensor]
35
+
36
+
37
class Hebian(BasicModel):
    '''
    A layer implementing various biologically plausible Hebbian learning rules.

    The layer computes a linear map followed by an optional activation and, in
    the same forward pass, a weight update from the configured plasticity rule.
    Parameters are created with requires_grad=False and forward() runs under
    torch.no_grad(), so the layer does not take part in backpropagation.

    Attributes:
        weight (nn.Parameter): Forward synaptic weights, shape (out_features, in_features).
        bias (nn.Parameter, optional): Bias term for the forward activation.
        bcm_threshold (torch.Tensor, optional): Sliding threshold buffer; only
            a tensor when rule == 'bcm', otherwise None.
        prev_input / prev_state (torch.Tensor, optional): Input and activation
            of the previous forward call; only registered when rule == 'stdp'.
        eligibility_trace (torch.Tensor, optional): Decaying accumulation of
            Hebbian terms; only registered when rule == 'eligibility'.
    '''

    def __init__(
        self,
        in_features: int,
        out_features: int,
        learning_rate: float = 0.01,
        rule: str = 'oja',
        use_bias: bool = True,
        auto_update: bool = False,
        bcm_momentum: float = 0.1,
        target_rate: float = 0.1,
        trace_decay: float = 0.9,
        activation: str = 'linear'
    ) -> None:
        '''
        Initializes the Hebian layer.

        Args:
            in_features (int): Dimension of the input data.
            out_features (int): Dimension of the output representation.
            learning_rate (float, optional): Synaptic plasticity learning rate. Defaults to 0.01.
            rule (str, optional): Learning rule ('hebbian', 'oja', 'bcm', 'covariance', 'instar',
                'scaling', 'vogels', 'reward_hebb', 'stdp', 'eligibility'). Case-insensitive.
                An unknown rule is only rejected when forward() is called. Defaults to 'oja'.
            use_bias (bool, optional): Whether to use a bias term. Defaults to True.
            auto_update (bool, optional): Automatically apply calculated weight updates in forward. Defaults to False.
            bcm_momentum (float, optional): Momentum for BCM sliding threshold. Defaults to 0.1.
            target_rate (float, optional): Desired average firing rate for homeostasis. Defaults to 0.1.
            trace_decay (float, optional): Decay factor for eligibility trace. Defaults to 0.9.
            activation (str, optional): Activation function ('linear', 'relu', 'sigmoid', 'tanh').
                Case-insensitive; any other value behaves like 'linear'. Defaults to 'linear'.
        '''
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.learning_rate = learning_rate
        self.rule = rule.lower()
        self.use_bias = use_bias
        self.auto_update = auto_update
        self.bcm_momentum = bcm_momentum
        self.target_rate = target_rate
        self.trace_decay = trace_decay
        self.activation = activation.lower()

        # Allocated uninitialized; reset_parameters() below fills the values.
        self.weight = nn.Parameter(torch.empty(out_features, in_features), requires_grad=False)

        if self.use_bias:
            self.bias = nn.Parameter(torch.empty(out_features), requires_grad=False)
        else:
            self.register_parameter('bias', None)

        if self.rule == 'bcm':
            self.register_buffer('bcm_threshold', torch.zeros(out_features))
        else:
            self.bcm_threshold = None

        # The buffers below start as None and are only materialized by forward().
        # They are registered as non-persistent: a persistent buffer that is
        # None on a fresh instance but a tensor after a forward pass would put
        # keys into state_dict() that a newly constructed layer rejects as
        # unexpected under strict loading.
        if self.rule == 'stdp':
            self.register_buffer('prev_input', None, persistent=False)
            self.register_buffer('prev_state', None, persistent=False)

        if self.rule == 'eligibility':
            self.register_buffer('eligibility_trace', None, persistent=False)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        '''
        Resets all synaptic parameters using Kaiming/Uniform initializations.
        '''
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

        if self.bias is not None:
            # Bias bound is derived from the fan-in of the weight matrix.
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def _apply_activation(self, input_tensor: torch.Tensor) -> torch.Tensor:
        '''
        Applies the selected activation function.

        Args:
            input_tensor (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Activated tensor. The input is returned unchanged for
                'linear' and for any unrecognized activation name.
        '''
        if self.activation == 'relu':
            return F.relu(input_tensor)
        if self.activation == 'sigmoid':
            return torch.sigmoid(input_tensor)
        if self.activation == 'tanh':
            return torch.tanh(input_tensor)
        return input_tensor

    @torch.no_grad()
    def forward(self, input_tensor: torch.Tensor, reward: Optional[torch.Tensor] = None) -> HebianOutput:
        '''
        Calculates the forward pass and biological synaptic updates.

        Note:
            Decorated with @torch.no_grad() to block global backpropagation,
            as learning is done via local plasticity rules.

        Args:
            input_tensor (torch.Tensor): The input data with shape (batch_size, in_features).
            reward (Optional[torch.Tensor], optional): Global reward signal for 'reward_hebb' or 'eligibility' rules. Defaults to None.

        Returns:
            HebianOutput: Output containing the activation state and weight updates.

        Raises:
            ValueError: If the rule is 'reward_hebb' and no reward is given, or
                if the configured rule is not supported.
        '''
        r_state = F.linear(input_tensor, self.weight, self.bias)
        r_state = self._apply_activation(r_state)

        updates: Dict[str, torch.Tensor] = {}

        if self.rule == 'hebbian':
            updates['weight'] = hebbian_update(self.weight, input_tensor, r_state, self.learning_rate)
        elif self.rule == 'oja':
            updates['weight'] = oja_update(self.weight, input_tensor, r_state, self.learning_rate)
        elif self.rule == 'bcm':
            # The update uses the threshold from before this batch; the
            # threshold is refreshed afterwards.
            updates['weight'] = bcm_update(
                self.weight, input_tensor, r_state, self.bcm_threshold, self.learning_rate
            )
            # Update sliding threshold: E[y^2]
            current_y2 = torch.mean(r_state ** 2, dim=0)
            self.bcm_threshold.mul_(1 - self.bcm_momentum).add_(current_y2, alpha=self.bcm_momentum)
        elif self.rule == 'covariance':
            updates['weight'] = covariance_update(self.weight, input_tensor, r_state, self.learning_rate)
        elif self.rule == 'instar':
            updates['weight'] = instar_update(self.weight, input_tensor, r_state, self.learning_rate)
        elif self.rule == 'scaling':
            updates['weight'] = synaptic_scaling_update(
                self.weight, r_state, target_rate=self.target_rate, learning_rate=self.learning_rate
            )
        elif self.rule == 'vogels':
            updates['weight'] = vogels_sprekeler_update(
                input_tensor, r_state, target_rate=self.target_rate, learning_rate=self.learning_rate
            )
        elif self.rule == 'reward_hebb':
            if reward is None:
                raise ValueError("The 'reward_hebb' rule requires a reward signal to be passed to forward().")
            updates['weight'] = reward_modulated_hebbian_update(
                input_tensor, r_state, reward, self.learning_rate
            )
        elif self.rule == 'stdp':
            prev_input = getattr(self, 'prev_input', None)
            if prev_input is None or prev_input.shape != input_tensor.shape:
                # No usable history (first call, or the input shape changed):
                # emit a zero update and just start recording.
                updates['weight'] = torch.zeros_like(self.weight)
            else:
                updates['weight'] = rate_based_stdp_update(
                    input_tensor, self.prev_input, r_state, self.prev_state, self.learning_rate
                )
            # Remember the current step for the next call.
            self.prev_input = input_tensor.clone().detach()
            self.prev_state = r_state.clone().detach()
        elif self.rule == 'eligibility':
            # Raw Hebbian term (learning_rate=1.0); the configured learning
            # rate is applied when the trace is turned into an update.
            current_hebbian = hebbian_update(self.weight, input_tensor, r_state, learning_rate=1.0)

            trace = getattr(self, 'eligibility_trace', None)
            if trace is None or trace.shape != self.weight.shape:
                self.eligibility_trace = torch.zeros_like(self.weight)

            self.eligibility_trace = self.eligibility_trace * self.trace_decay + current_hebbian

            if reward is not None:
                updates['weight'] = eligibility_trace_update(
                    self.eligibility_trace, reward, self.learning_rate
                )
            else:
                # Without a reward the trace keeps accumulating but no weight
                # change is proposed.
                updates['weight'] = torch.zeros_like(self.weight)
        else:
            raise ValueError(f'Unsupported learning rule: {self.rule}')

        if self.auto_update:
            self.apply_updates(updates)

        return HebianOutput(
            output_tensor=r_state,
            weight_updates=updates
        )

    @torch.no_grad()
    def apply_updates(self, updates: Dict[str, torch.Tensor]) -> None:
        '''
        Applies the calculated weight updates to the layer's parameters in-place.

        Args:
            updates (Dict[str, torch.Tensor]): A dictionary of parameter names and their updates.
                Only the 'weight' and 'bias' keys are used; a 'bias' entry is
                ignored when the layer has no bias.
        '''
        if 'weight' in updates and self.weight is not None:
            self.weight.add_(updates['weight'])

        if 'bias' in updates and getattr(self, 'bias', None) is not None:
            self.bias.add_(updates['bias'])