codon-model 0.0.3b2__tar.gz → 0.0.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {codon_model-0.0.3b2/codon_model.egg-info → codon_model-0.0.4}/PKG-INFO +1 -1
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/__init__.py +1 -1
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/base.py +29 -3
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/block/__init__.py +11 -0
- codon_model-0.0.4/codon/block/bio/__init__.py +9 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/block/lora.py +2 -0
- codon_model-0.0.3b2/codon/exp/block/manifold_conv.py → codon_model-0.0.4/codon/block/manifold.py +255 -134
- codon_model-0.0.4/codon/exp/block/bio.py +494 -0
- codon_model-0.0.4/codon/exp/block/manifold.py +88 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/kit/train/vision.py +1 -1
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/ops/__init__.py +4 -0
- codon_model-0.0.4/codon/ops/manifold/__init__.py +133 -0
- codon_model-0.0.4/codon/ops/manifold/conv.py +217 -0
- codon_model-0.0.3b2/codon/exp/ops/manifold_triton.py → codon_model-0.0.4/codon/ops/manifold/linear.py +6 -5
- codon_model-0.0.4/codon/utils/eval/__init__.py +24 -0
- codon_model-0.0.4/codon/utils/eval/activation.py +127 -0
- codon_model-0.0.4/codon/utils/eval/base.py +210 -0
- codon_model-0.0.4/codon/utils/eval/boundary.py +157 -0
- codon_model-0.0.4/codon/utils/eval/cka.py +191 -0
- codon_model-0.0.4/codon/utils/eval/confusion.py +77 -0
- codon_model-0.0.4/codon/utils/eval/gradcam.py +121 -0
- codon_model-0.0.4/codon/utils/eval/layer_rsa.py +130 -0
- codon_model-0.0.4/codon/utils/eval/rsa.py +103 -0
- codon_model-0.0.4/codon/utils/eval/selectivity.py +149 -0
- codon_model-0.0.4/codon/utils/eval/similarity.py +65 -0
- codon_model-0.0.4/codon/utils/eval/tsne.py +109 -0
- codon_model-0.0.4/codon/utils/info.py +137 -0
- codon_model-0.0.4/codon/utils/layer/lora.py +13 -0
- codon_model-0.0.4/codon/utils/layer/manifold.py +70 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/utils/seed.py +8 -4
- codon_model-0.0.3b2/codon/utils/token.py → codon_model-0.0.4/codon/utils/tokens.py +12 -2
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/utils/transforms.py +13 -6
- {codon_model-0.0.3b2 → codon_model-0.0.4/codon_model.egg-info}/PKG-INFO +1 -1
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon_model.egg-info/SOURCES.txt +26 -8
- {codon_model-0.0.3b2 → codon_model-0.0.4}/test/test_motifv1_train.py +1 -1
- codon_model-0.0.3b2/codon/exp/block/manifold.py +0 -332
- codon_model-0.0.3b2/codon/exp/ops/manifold.py +0 -63
- {codon_model-0.0.3b2 → codon_model-0.0.4}/LICENSE +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/block/attention.py +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/block/bio/hebian.py +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/block/bio/predictive.py +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/block/codebook.py +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/block/conv.py +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/block/embedding.py +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/block/film.py +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/block/fusion.py +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/block/mlp.py +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/block/moe.py +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/block/pixelshuffle.py +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/block/transformer.py +0 -0
- {codon_model-0.0.3b2/codon/block/bio → codon_model-0.0.4/codon/exp}/__init__.py +0 -0
- {codon_model-0.0.3b2/codon/exp → codon_model-0.0.4/codon/exp/block}/__init__.py +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/exp/block/moe.py +0 -0
- {codon_model-0.0.3b2/codon/exp/block → codon_model-0.0.4/codon/exp/ops}/__init__.py +0 -0
- {codon_model-0.0.3b2/codon/exp/ops → codon_model-0.0.4/codon/kit}/__init__.py +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/kit/train/__init__.py +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/model/__init__.py +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/model/patch_disc.py +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/model/resnet.py +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/model/tcn.py +0 -0
- {codon_model-0.0.3b2/codon/model → codon_model-0.0.4/codon}/motif/__init__.py +0 -0
- {codon_model-0.0.3b2/codon/model → codon_model-0.0.4/codon}/motif/base.py +0 -0
- {codon_model-0.0.3b2/codon/model → codon_model-0.0.4/codon}/motif/motif_a1.py +0 -0
- {codon_model-0.0.3b2/codon/model → codon_model-0.0.4/codon}/motif/motif_v1.py +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/ops/attention.py +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/ops/bio.py +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/ops/pixelshuffle.py +0 -0
- {codon_model-0.0.3b2/codon/kit → codon_model-0.0.4/codon/utils}/__init__.py +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/utils/dataset/__init__.py +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/utils/dataset/base.py +0 -0
- {codon_model-0.0.3b2/codon/utils → codon_model-0.0.4/codon/utils/dataset/conflux}/__init__.py +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/utils/dataset/conflux/base.py +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/utils/dataset/conflux/reader.py +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/utils/dataset/conflux/writer.py +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/utils/dataset/corpus.py +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/utils/dataset/dataviewer.py +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/utils/dataset/flatdata.py +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/utils/dataset/image.py +0 -0
- {codon_model-0.0.3b2/codon/utils/dataset/conflux → codon_model-0.0.4/codon/utils/layer}/__init__.py +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/utils/mask.py +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/utils/safecode.py +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/utils/split.py +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/utils/theta.py +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon_model.egg-info/dependency_links.txt +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon_model.egg-info/requires.txt +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/codon_model.egg-info/top_level.txt +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/setup.cfg +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/setup.py +0 -0
- {codon_model-0.0.3b2 → codon_model-0.0.4}/test/test_conflux_dataset.py +0 -0
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
import torch.nn as nn
|
|
3
|
+
import torch.nn.functional as F
|
|
3
4
|
|
|
4
5
|
from typing import Callable, Any, Iterator, Union
|
|
5
6
|
|
|
@@ -132,7 +133,7 @@ class BasicModel(nn.Module):
|
|
|
132
133
|
|
|
133
134
|
return total
|
|
134
135
|
|
|
135
|
-
def load_pretrained(self, path: str) ->
|
|
136
|
+
def load_pretrained(self, path: str) -> 'BasicModel':
|
|
136
137
|
'''
|
|
137
138
|
Load a pretrained model from a file.
|
|
138
139
|
|
|
@@ -141,7 +142,7 @@ class BasicModel(nn.Module):
|
|
|
141
142
|
'''
|
|
142
143
|
if path.endswith('.safetensors'):
|
|
143
144
|
safe_load_model(self, path)
|
|
144
|
-
return
|
|
145
|
+
return self
|
|
145
146
|
|
|
146
147
|
state_dict = torch.load(path, map_location=self.device)
|
|
147
148
|
|
|
@@ -152,8 +153,10 @@ class BasicModel(nn.Module):
|
|
|
152
153
|
state_dict = state_dict['state_dict']
|
|
153
154
|
|
|
154
155
|
self.load_state_dict(state_dict)
|
|
156
|
+
|
|
157
|
+
return self
|
|
155
158
|
|
|
156
|
-
def save_pretrained(self, path: str) ->
|
|
159
|
+
def save_pretrained(self, path: str) -> 'BasicModel':
|
|
157
160
|
'''
|
|
158
161
|
Save the model to a file.
|
|
159
162
|
|
|
@@ -165,3 +168,26 @@ class BasicModel(nn.Module):
|
|
|
165
168
|
else:
|
|
166
169
|
state_dict = self.state_dict()
|
|
167
170
|
torch.save(state_dict, path)
|
|
171
|
+
return self
|
|
172
|
+
|
|
173
|
+
def freeze(self) -> 'BasicModel':
    '''
    Turn off gradient tracking for every parameter of the model.

    Returns:
        BasicModel: The model itself, so calls can be chained.
    '''
    for tensor in self.parameters():
        # In-place switch; same effect as assigning requires_grad = False.
        tensor.requires_grad_(False)
    return self
|
|
183
|
+
|
|
184
|
+
def unfreeze(self) -> 'BasicModel':
    '''
    Unfreeze the model by setting requires_grad to True on its parameters.

    Parameters whose dtype is neither floating point nor complex are left
    untouched: PyTorch raises a RuntimeError when such a tensor is asked to
    require gradients, which would otherwise abort the whole call.

    Returns:
        BasicModel: The model itself for method chaining.
    '''
    for param in self.parameters():
        # Only floating point / complex tensors are allowed to require grad.
        if param.is_floating_point() or param.is_complex():
            param.requires_grad = True
    return self
|
|
@@ -31,6 +31,11 @@ from .transformer import (
|
|
|
31
31
|
TransformerMoEDecoder,
|
|
32
32
|
_TransformerDecoder,
|
|
33
33
|
)
|
|
34
|
+
from .manifold import (
|
|
35
|
+
MainfoldLoss,
|
|
36
|
+
BasicManifoldLinear, RiemannianManifoldLinear,
|
|
37
|
+
BasicManifoldConv2d, RiemannianManifoldConv2d
|
|
38
|
+
)
|
|
34
39
|
|
|
35
40
|
__all__ = [
|
|
36
41
|
# attention
|
|
@@ -80,4 +85,10 @@ __all__ = [
|
|
|
80
85
|
'TransformerDecoderOutput',
|
|
81
86
|
'TransformerDenseDecoder',
|
|
82
87
|
'TransformerMoEDecoder',
|
|
88
|
+
# manifold
|
|
89
|
+
'MainfoldLoss',
|
|
90
|
+
'BasicManifoldLinear',
|
|
91
|
+
'RiemannianManifoldLinear',
|
|
92
|
+
'BasicManifoldConv2d',
|
|
93
|
+
'RiemannianManifoldConv2d'
|
|
83
94
|
]
|
|
@@ -50,6 +50,8 @@ class BasicLoRA(BasicModel):
|
|
|
50
50
|
super().__init__()
|
|
51
51
|
self.gradient_checkpointing = gradient_checkpointing
|
|
52
52
|
self.original_layer = original_layer
|
|
53
|
+
self.lora_dropout_p = lora_dropout
|
|
54
|
+
self.merge_weights = merge_weights
|
|
53
55
|
|
|
54
56
|
# Freeze original layer
|
|
55
57
|
for p in self.original_layer.parameters():
|
codon_model-0.0.3b2/codon/exp/block/manifold_conv.py → codon_model-0.0.4/codon/block/manifold.py
RENAMED
|
@@ -1,11 +1,231 @@
|
|
|
1
|
+
from codon.base import *
|
|
2
|
+
|
|
1
3
|
import torch.nn.functional as F
|
|
2
4
|
|
|
3
|
-
from
|
|
5
|
+
from typing import Tuple
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
|
|
8
|
+
from codon.ops.manifold import riemannian_manifold_linear, riemannian_manifold_conv2d
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class MainfoldLoss:
    '''
    Container pairing the two manifold regularization terms.

    Attributes:
        cosine (torch.Tensor): The cosine similarity loss.
        laplacian (torch.Tensor): The Laplacian regularization loss.
    '''
    cosine: torch.Tensor
    laplacian: torch.Tensor

    def factor_loss(self, factor_cos: float = 0.013, factor_lap: float = 0.012) -> torch.Tensor:
        '''
        Combines the two stored terms into a single weighted loss.

        Args:
            factor_cos (float): Multiplier applied to the cosine loss.
            factor_lap (float): Multiplier applied to the Laplacian loss.

        Returns:
            torch.Tensor: cosine * factor_cos + laplacian * factor_lap.
        '''
        weighted_cosine = self.cosine * factor_cos
        weighted_laplacian = self.laplacian * factor_lap
        return weighted_cosine + weighted_laplacian
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class BasicManifoldLinear(BasicModel):
    '''
    Base class for manifold-based linear layers.

    Holds the weight matrix and the two regularizers derived from the
    pairwise cosine similarity of its rows. This class defines neither a
    forward pass nor a weight initialization.

    Attributes:
        in_features (int): Size of each input sample.
        out_features (int): Size of each output sample.
        k_neighbors (int): Number of nearest neighbors to consider for the
            Laplacian loss, clamped to the range [0, out_features - 1].
        weight (nn.Parameter): The learnable weights of the layer, with shape
            (out_features, in_features).
    '''

    def __init__(
        self,
        in_features: int,
        out_features: int,
        k_neighbors: int = 2,
    ) -> None:
        '''
        Initializes the BasicManifoldLinear layer.

        Args:
            in_features (int): Size of each input sample.
            out_features (int): Size of each output sample.
            k_neighbors (int): Number of nearest neighbors for the Laplacian graph.
        '''
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        # A row has at most out_features - 1 other rows to be neighbors with;
        # the lower bound keeps torch.topk from ever receiving a negative k.
        self.k_neighbors = max(0, min(k_neighbors, out_features - 1))

        # Allocated uninitialized; this class does not fill in the values.
        self.weight = nn.Parameter(torch.empty(out_features, in_features))

    def _cosine_matrix(self) -> torch.Tensor:
        '''
        Builds the pairwise cosine similarity matrix of the weight rows.

        Returns:
            torch.Tensor: Matrix of shape (out_features, out_features).
        '''
        w_norm = F.normalize(self.weight, p=2, dim=1)
        return torch.matmul(w_norm, w_norm.T)

    def _cosine_penalty(self, C: torch.Tensor) -> torch.Tensor:
        '''
        Mean squared off-diagonal entry of the cosine similarity matrix.

        Args:
            C (torch.Tensor): Matrix produced by _cosine_matrix.

        Returns:
            torch.Tensor: The penalty, or zero when there are no row pairs.
        '''
        n_pairs = self.out_features * (self.out_features - 1)
        if n_pairs <= 0:
            # Fewer than two rows: no off-diagonal entries, avoid 0 / 0.
            return C.new_zeros(())

        I = torch.eye(self.out_features, device=C.device)
        return torch.sum((C * (1 - I)) ** 2) / n_pairs

    def _laplacian_penalty(self, C: torch.Tensor) -> torch.Tensor:
        '''
        Mean of (1 - cosine) over the edges of the symmetrized k-NN graph.

        Args:
            C (torch.Tensor): Matrix produced by _cosine_matrix.

        Returns:
            torch.Tensor: The penalty, or zero when the graph has no edges.
        '''
        if self.k_neighbors < 1:
            # No neighbors requested or available: empty graph, avoid 0 / 0.
            return C.new_zeros(())

        # k + 1 because each row is normally its own most similar entry.
        _, topk_idx = torch.topk(C, self.k_neighbors + 1, dim=1)

        A = torch.zeros_like(C)
        A.scatter_(1, topk_idx, 1.0)
        # Drop self-loops explicitly so the diagonal is 0 even when ties keep
        # a row out of its own top-k (subtracting the identity would give -1).
        A.fill_diagonal_(0.0)
        # Symmetrize: keep an edge if either endpoint selected the other.
        A = torch.max(A, A.T)

        return torch.sum(A * (1.0 - C)) / torch.sum(A)

    @property
    def loss_cosine(self) -> torch.Tensor:
        '''
        Calculates the cosine similarity penalty loss among the weight vectors.

        Returns:
            torch.Tensor: The computed cosine penalty loss.
        '''
        return self._cosine_penalty(self._cosine_matrix())

    @property
    def loss_laplacian(self) -> torch.Tensor:
        '''
        Calculates the Laplacian regularization loss based on k-nearest neighbors.

        Returns:
            torch.Tensor: The computed Laplacian regularization loss.
        '''
        return self._laplacian_penalty(self._cosine_matrix())

    def compute_loss(self) -> MainfoldLoss:
        '''
        Computes both the cosine and Laplacian losses from a single similarity matrix.

        Returns:
            MainfoldLoss: An object containing the computed cosine and Laplacian losses.
        '''
        C = self._cosine_matrix()
        return MainfoldLoss(
            cosine=self._cosine_penalty(C),
            laplacian=self._laplacian_penalty(C),
        )

    def extra_repr(self) -> str:
        '''
        Sets the extra representation of the module for printing.

        Returns:
            str: The in_features, out_features and k_neighbors settings.
        '''
        return f'in_features={self.in_features}, out_features={self.out_features}, k_neighbors={self.k_neighbors}'
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
class RiemannianManifoldLinear(BasicManifoldLinear):
    '''
    Linear manifold layer whose forward pass is delegated to
    riemannian_manifold_linear.

    Attributes:
        kappa (nn.Parameter): Concentration parameter for the von Mises-Fisher (vMF) distribution.
        lambda_rate (nn.Parameter): Gravitational attraction coefficient.
        scale (nn.Parameter): Vector amplifier for the hyperspherical network.
        bias (nn.Parameter): Manifold bias vector.
    '''

    def __init__(
        self,
        in_features: int,
        out_features: int,
        kappa_init: float = 2.0,
        lambda_init: float = 0.1,
        scale_init: float = 15.0,
        k_neighbors: int = 2,
        rule: str = 'near'
    ) -> None:
        '''
        Builds the layer and initializes its parameters.

        Args:
            in_features (int): Size of each input sample.
            out_features (int): Size of each output sample.
            kappa_init (float): Starting value of the vMF concentration parameter.
            lambda_init (float): Starting value of the gravitational attraction coefficient.
            scale_init (float): Starting value of every entry of the scale vector.
            k_neighbors (int): Number of nearest neighbors for the Laplacian graph.
            rule (str): Attraction rule, either 'near' or 'far' (case-insensitive).

        Raises:
            ValueError: If rule is neither 'near' nor 'far'.
        '''
        super().__init__(
            in_features=in_features,
            out_features=out_features,
            k_neighbors=k_neighbors
        )

        # Keep the raw initial values around as plain attributes.
        self.kappa_init = kappa_init
        self.lambda_init = lambda_init
        self.scale_init = scale_init

        self.rule = rule.lower()
        if self.rule not in ('far', 'near'):
            raise ValueError(f"Invalid rule: {self.rule}, must be 'far' or 'near'")

        # Scalar parameters: vMF concentration and attraction coefficient.
        kappa_value = torch.tensor(float(kappa_init))
        self.kappa = nn.Parameter(kappa_value)
        lambda_value = torch.tensor(float(lambda_init))
        self.lambda_rate = nn.Parameter(lambda_value)

        # Per-output-feature amplifier and bias.
        scale_value = torch.ones(out_features) * scale_init
        self.scale = nn.Parameter(scale_value)
        self.bias = nn.Parameter(torch.zeros(out_features))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        '''
        Re-draws the weight matrix from a normal distribution (mean 0, std 0.01).
        '''
        nn.init.normal_(self.weight, mean=0, std=0.01)

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        '''
        Runs the layer by calling riemannian_manifold_linear with its parameters.

        Args:
            input_tensor (torch.Tensor): The input data with shape (batch_size, in_features).

        Returns:
            torch.Tensor: The output data with shape (batch_size, out_features).
        '''
        op_kwargs = dict(
            input_tensor=input_tensor,
            weight=self.weight,
            kappa=self.kappa,
            lambda_rate=self.lambda_rate,
            scale=self.scale,
            bias=self.bias,
            rule=self.rule,
        )
        return riemannian_manifold_linear(**op_kwargs)

    def extra_repr(self) -> str:
        '''
        Extends the parent representation with the rule and the current
        kappa and lambda values.
        '''
        base_repr = super().extra_repr()
        kappa_now = self.kappa.item()
        lambda_now = self.lambda_rate.item()
        return f'{base_repr}, rule={self.rule}, kappa={kappa_now:.4f}, lambda={lambda_now:.4f}'
|
|
9
229
|
|
|
10
230
|
|
|
11
231
|
class BasicManifoldConv2d(BasicModel):
|
|
@@ -126,6 +346,16 @@ class BasicManifoldConv2d(BasicModel):
|
|
|
126
346
|
loss_lap = torch.sum(A * (1.0 - C)) / torch.sum(A)
|
|
127
347
|
|
|
128
348
|
return MainfoldLoss(cosine=loss_cos, laplacian=loss_lap)
|
|
349
|
+
|
|
350
|
+
def extra_repr(self) -> str:
    '''
    Sets the extra representation of the module for printing.

    Padding and dilation are listed only when they differ from 0 and 1.
    The rule and use_norm fields are listed only when the instance has them;
    in the visible code they are assigned by subclasses, so reading them
    unconditionally could raise AttributeError on an instance without them.

    Returns:
        str: Comma-separated description of the layer configuration.
    '''
    s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
         ', stride={stride}')
    if self.padding != 0:
        s += ', padding={padding}'
    if self.dilation != 1:
        s += ', dilation={dilation}'
    # Fill the template before appending free-form values, so that braces in
    # those values can never be mistaken for format fields.
    s = s.format(**self.__dict__)
    if hasattr(self, 'rule'):
        s += f', rule={self.rule}'
    if hasattr(self, 'use_norm'):
        s += f', use_norm={self.use_norm}'
    return s
|
|
129
359
|
|
|
130
360
|
|
|
131
361
|
class RiemannianManifoldConv2d(BasicManifoldConv2d):
|
|
@@ -138,6 +368,7 @@ class RiemannianManifoldConv2d(BasicManifoldConv2d):
|
|
|
138
368
|
scale (nn.Parameter): Vector amplifier for the hyperspherical network.
|
|
139
369
|
bias (nn.Parameter): Manifold bias vector.
|
|
140
370
|
weight_ones (torch.Tensor): Fixed all-ones kernel for computing patch norm rapidly.
|
|
371
|
+
use_norm (bool): Whether to scale the output by the input patch norm.
|
|
141
372
|
'''
|
|
142
373
|
|
|
143
374
|
def __init__(
|
|
@@ -152,7 +383,9 @@ class RiemannianManifoldConv2d(BasicManifoldConv2d):
|
|
|
152
383
|
lambda_init: float = 0.1,
|
|
153
384
|
scale_init: float = 15.0,
|
|
154
385
|
k_neighbors: int = 2,
|
|
155
|
-
rule: str = 'near'
|
|
386
|
+
rule: str = 'near',
|
|
387
|
+
use_norm_gate: bool = False,
|
|
388
|
+
use_norm: bool = False
|
|
156
389
|
) -> None:
|
|
157
390
|
'''
|
|
158
391
|
Initializes the RiemannianManifoldConv2d layer.
|
|
@@ -169,6 +402,7 @@ class RiemannianManifoldConv2d(BasicManifoldConv2d):
|
|
|
169
402
|
scale_init (float): Initial value for the vector amplifier scale.
|
|
170
403
|
k_neighbors (int): Number of nearest neighbors for the Laplacian graph.
|
|
171
404
|
rule (str): Attraction rule, either 'near' or 'far'.
|
|
405
|
+
use_norm (bool): Whether to scale the output by the input patch norm. Default: False.
|
|
172
406
|
'''
|
|
173
407
|
super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, k_neighbors)
|
|
174
408
|
|
|
@@ -177,6 +411,8 @@ class RiemannianManifoldConv2d(BasicManifoldConv2d):
|
|
|
177
411
|
self.lambda_rate = nn.Parameter(torch.tensor(float(lambda_init)))
|
|
178
412
|
self.scale = nn.Parameter(torch.ones(out_channels) * scale_init)
|
|
179
413
|
self.bias = nn.Parameter(torch.zeros(out_channels))
|
|
414
|
+
self.use_norm_gate = use_norm_gate
|
|
415
|
+
self.use_norm = use_norm
|
|
180
416
|
|
|
181
417
|
# All-ones kernel for ultra-fast calculation of patch norm
|
|
182
418
|
weight_ones = torch.ones(1, in_channels, *self.kernel_size)
|
|
@@ -199,132 +435,17 @@ class RiemannianManifoldConv2d(BasicManifoldConv2d):
|
|
|
199
435
|
Returns:
|
|
200
436
|
torch.Tensor: The output manifold projection tensor.
|
|
201
437
|
'''
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
cosine = torch.clamp(cosine, -1.0 + 1e-6, 1.0 - 1e-6)
|
|
217
|
-
|
|
218
|
-
# 4. vMF gravitational field calculation (applied pixel-wise)
|
|
219
|
-
theta = torch.acos(cosine)
|
|
220
|
-
exp_val = torch.exp(self.kappa * (cosine - 1.0))
|
|
221
|
-
attraction = exp_val if self.rule == 'near' else 1.0 - exp_val
|
|
222
|
-
|
|
223
|
-
# 5. Riemannian geodesic pullback
|
|
224
|
-
safe_lambda = torch.clamp(self.lambda_rate, 1e-6, 1.0 - 1e-4)
|
|
225
|
-
effective_theta = theta * (1.0 - safe_lambda * attraction)
|
|
226
|
-
|
|
227
|
-
# 6. Reconstruct the output (note the shape broadcasting)
|
|
228
|
-
scale_view = self.scale.view(1, -1, 1, 1)
|
|
229
|
-
bias_view = self.bias.view(1, -1, 1, 1)
|
|
230
|
-
|
|
231
|
-
output = scale_view * torch.cos(effective_theta) + bias_view
|
|
232
|
-
|
|
233
|
-
return output
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
class EuclideanManifoldConv2d(BasicManifoldConv2d):
|
|
237
|
-
'''
|
|
238
|
-
A 2D convolutional layer simulating a manifold structure in Euclidean space.
|
|
239
|
-
|
|
240
|
-
Attributes:
|
|
241
|
-
tau (nn.Parameter): Temperature or radius parameter for the basin of attraction.
|
|
242
|
-
lambda_rate (nn.Parameter): Gravitational strength parameter.
|
|
243
|
-
bias (nn.Parameter): Translation bias vector.
|
|
244
|
-
weight_ones (torch.Tensor): Fixed all-ones kernel for computing patch norm rapidly.
|
|
245
|
-
'''
|
|
246
|
-
|
|
247
|
-
def __init__(
|
|
248
|
-
self,
|
|
249
|
-
in_channels: int,
|
|
250
|
-
out_channels: int,
|
|
251
|
-
kernel_size: int,
|
|
252
|
-
stride: int = 1,
|
|
253
|
-
padding: int = 0,
|
|
254
|
-
dilation: int = 1,
|
|
255
|
-
tau_init: float = 5.0,
|
|
256
|
-
lambda_init: float = 0.5,
|
|
257
|
-
k_neighbors: int = 2,
|
|
258
|
-
rule: str = 'near'
|
|
259
|
-
) -> None:
|
|
260
|
-
'''
|
|
261
|
-
Initializes the EuclideanManifoldConv2d layer.
|
|
262
|
-
|
|
263
|
-
Args:
|
|
264
|
-
in_channels (int): Number of channels in the input image.
|
|
265
|
-
out_channels (int): Number of channels produced by the convolution.
|
|
266
|
-
kernel_size (int): Size of the convolving kernel.
|
|
267
|
-
stride (int): Stride of the convolution. Default: 1.
|
|
268
|
-
padding (int): Zero-padding added to both sides of the input. Default: 0.
|
|
269
|
-
dilation (int): Spacing between kernel elements. Default: 1.
|
|
270
|
-
tau_init (float): Initial value for the basin temperature/radius.
|
|
271
|
-
lambda_init (float): Initial value for the gravitational strength.
|
|
272
|
-
k_neighbors (int): Number of nearest neighbors for the Laplacian graph.
|
|
273
|
-
rule (str): Attraction rule, either 'near' or 'far'.
|
|
274
|
-
'''
|
|
275
|
-
super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, k_neighbors)
|
|
276
|
-
|
|
277
|
-
self.rule = rule.lower()
|
|
278
|
-
self.tau = nn.Parameter(torch.tensor(float(tau_init)))
|
|
279
|
-
self.lambda_rate = nn.Parameter(torch.tensor(float(lambda_init)))
|
|
280
|
-
self.bias = nn.Parameter(torch.Tensor(out_channels))
|
|
281
|
-
|
|
282
|
-
weight_ones = torch.ones(1, in_channels, *self.kernel_size)
|
|
283
|
-
self.register_buffer('weight_ones', weight_ones)
|
|
284
|
-
|
|
285
|
-
self.reset_parameters()
|
|
286
|
-
|
|
287
|
-
def reset_parameters(self) -> None:
|
|
288
|
-
'''
|
|
289
|
-
Resets the parameters of the layer using Kaiming uniform initialization.
|
|
290
|
-
'''
|
|
291
|
-
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
|
292
|
-
if self.bias is not None:
|
|
293
|
-
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
|
|
294
|
-
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
|
295
|
-
nn.init.uniform_(self.bias, -bound, bound)
|
|
296
|
-
|
|
297
|
-
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
|
|
298
|
-
'''
|
|
299
|
-
Defines the computation performed at every call.
|
|
300
|
-
|
|
301
|
-
Args:
|
|
302
|
-
input_tensor (torch.Tensor): The input data tensor.
|
|
303
|
-
|
|
304
|
-
Returns:
|
|
305
|
-
torch.Tensor: The output manifold projection tensor.
|
|
306
|
-
'''
|
|
307
|
-
# 1. Base physical projection
|
|
308
|
-
# base_proj: [batch, out_channels, H_out, W_out]
|
|
309
|
-
base_proj = F.conv2d(input_tensor, self.weight, stride=self.stride, padding=self.padding, dilation=self.dilation)
|
|
310
|
-
|
|
311
|
-
# 2. Ultra-fast algebraic expansion of the squared L2 distance for local patches
|
|
312
|
-
# ||patch - W||^2 = ||patch||^2 + ||W||^2 - 2<patch, W>
|
|
313
|
-
x_sq = F.conv2d(input_tensor ** 2, self.weight_ones, stride=self.stride, padding=self.padding, dilation=self.dilation)
|
|
314
|
-
w_sq = torch.sum(self.weight ** 2, dim=(1,2,3)).view(1, -1, 1, 1)
|
|
315
|
-
|
|
316
|
-
dist_sq = x_sq + w_sq - 2 * base_proj
|
|
317
|
-
dist_sq = torch.clamp(dist_sq, min=1e-6)
|
|
318
|
-
|
|
319
|
-
# 3. Compute the attraction index
|
|
320
|
-
exp_val = torch.exp(-dist_sq / (self.tau ** 2 + 1e-8))
|
|
321
|
-
attraction = exp_val if self.rule == 'near' else 1.0 - exp_val
|
|
322
|
-
|
|
323
|
-
# 4. Gravitational correction
|
|
324
|
-
safe_lambda = torch.clamp(self.lambda_rate, 1e-6, 1.0 - 1e-4)
|
|
325
|
-
correction = safe_lambda * attraction * (w_sq - base_proj)
|
|
326
|
-
|
|
327
|
-
# 5. Combine outputs
|
|
328
|
-
output = base_proj + correction + self.bias.view(1, -1, 1, 1)
|
|
329
|
-
|
|
330
|
-
return output
|
|
438
|
+
return riemannian_manifold_conv2d(
|
|
439
|
+
input_tensor=input_tensor,
|
|
440
|
+
weight=self.weight,
|
|
441
|
+
weight_ones=self.weight_ones,
|
|
442
|
+
kappa=self.kappa,
|
|
443
|
+
lambda_rate=self.lambda_rate,
|
|
444
|
+
scale=self.scale,
|
|
445
|
+
bias=self.bias,
|
|
446
|
+
stride=self.stride,
|
|
447
|
+
padding=self.padding,
|
|
448
|
+
dilation=self.dilation,
|
|
449
|
+
rule=self.rule,
|
|
450
|
+
use_norm=self.use_norm
|
|
451
|
+
)
|