codon-model 0.0.3b2__tar.gz → 0.0.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. {codon_model-0.0.3b2/codon_model.egg-info → codon_model-0.0.4}/PKG-INFO +1 -1
  2. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/__init__.py +1 -1
  3. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/base.py +29 -3
  4. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/block/__init__.py +11 -0
  5. codon_model-0.0.4/codon/block/bio/__init__.py +9 -0
  6. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/block/lora.py +2 -0
  7. codon_model-0.0.3b2/codon/exp/block/manifold_conv.py → codon_model-0.0.4/codon/block/manifold.py +255 -134
  8. codon_model-0.0.4/codon/exp/block/bio.py +494 -0
  9. codon_model-0.0.4/codon/exp/block/manifold.py +88 -0
  10. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/kit/train/vision.py +1 -1
  11. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/ops/__init__.py +4 -0
  12. codon_model-0.0.4/codon/ops/manifold/__init__.py +133 -0
  13. codon_model-0.0.4/codon/ops/manifold/conv.py +217 -0
  14. codon_model-0.0.3b2/codon/exp/ops/manifold_triton.py → codon_model-0.0.4/codon/ops/manifold/linear.py +6 -5
  15. codon_model-0.0.4/codon/utils/eval/__init__.py +24 -0
  16. codon_model-0.0.4/codon/utils/eval/activation.py +127 -0
  17. codon_model-0.0.4/codon/utils/eval/base.py +210 -0
  18. codon_model-0.0.4/codon/utils/eval/boundary.py +157 -0
  19. codon_model-0.0.4/codon/utils/eval/cka.py +191 -0
  20. codon_model-0.0.4/codon/utils/eval/confusion.py +77 -0
  21. codon_model-0.0.4/codon/utils/eval/gradcam.py +121 -0
  22. codon_model-0.0.4/codon/utils/eval/layer_rsa.py +130 -0
  23. codon_model-0.0.4/codon/utils/eval/rsa.py +103 -0
  24. codon_model-0.0.4/codon/utils/eval/selectivity.py +149 -0
  25. codon_model-0.0.4/codon/utils/eval/similarity.py +65 -0
  26. codon_model-0.0.4/codon/utils/eval/tsne.py +109 -0
  27. codon_model-0.0.4/codon/utils/info.py +137 -0
  28. codon_model-0.0.4/codon/utils/layer/lora.py +13 -0
  29. codon_model-0.0.4/codon/utils/layer/manifold.py +70 -0
  30. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/utils/seed.py +8 -4
  31. codon_model-0.0.3b2/codon/utils/token.py → codon_model-0.0.4/codon/utils/tokens.py +12 -2
  32. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/utils/transforms.py +13 -6
  33. {codon_model-0.0.3b2 → codon_model-0.0.4/codon_model.egg-info}/PKG-INFO +1 -1
  34. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon_model.egg-info/SOURCES.txt +26 -8
  35. {codon_model-0.0.3b2 → codon_model-0.0.4}/test/test_motifv1_train.py +1 -1
  36. codon_model-0.0.3b2/codon/exp/block/manifold.py +0 -332
  37. codon_model-0.0.3b2/codon/exp/ops/manifold.py +0 -63
  38. {codon_model-0.0.3b2 → codon_model-0.0.4}/LICENSE +0 -0
  39. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/block/attention.py +0 -0
  40. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/block/bio/hebian.py +0 -0
  41. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/block/bio/predictive.py +0 -0
  42. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/block/codebook.py +0 -0
  43. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/block/conv.py +0 -0
  44. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/block/embedding.py +0 -0
  45. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/block/film.py +0 -0
  46. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/block/fusion.py +0 -0
  47. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/block/mlp.py +0 -0
  48. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/block/moe.py +0 -0
  49. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/block/pixelshuffle.py +0 -0
  50. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/block/transformer.py +0 -0
  51. {codon_model-0.0.3b2/codon/block/bio → codon_model-0.0.4/codon/exp}/__init__.py +0 -0
  52. {codon_model-0.0.3b2/codon/exp → codon_model-0.0.4/codon/exp/block}/__init__.py +0 -0
  53. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/exp/block/moe.py +0 -0
  54. {codon_model-0.0.3b2/codon/exp/block → codon_model-0.0.4/codon/exp/ops}/__init__.py +0 -0
  55. {codon_model-0.0.3b2/codon/exp/ops → codon_model-0.0.4/codon/kit}/__init__.py +0 -0
  56. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/kit/train/__init__.py +0 -0
  57. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/model/__init__.py +0 -0
  58. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/model/patch_disc.py +0 -0
  59. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/model/resnet.py +0 -0
  60. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/model/tcn.py +0 -0
  61. {codon_model-0.0.3b2/codon/model → codon_model-0.0.4/codon}/motif/__init__.py +0 -0
  62. {codon_model-0.0.3b2/codon/model → codon_model-0.0.4/codon}/motif/base.py +0 -0
  63. {codon_model-0.0.3b2/codon/model → codon_model-0.0.4/codon}/motif/motif_a1.py +0 -0
  64. {codon_model-0.0.3b2/codon/model → codon_model-0.0.4/codon}/motif/motif_v1.py +0 -0
  65. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/ops/attention.py +0 -0
  66. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/ops/bio.py +0 -0
  67. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/ops/pixelshuffle.py +0 -0
  68. {codon_model-0.0.3b2/codon/kit → codon_model-0.0.4/codon/utils}/__init__.py +0 -0
  69. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/utils/dataset/__init__.py +0 -0
  70. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/utils/dataset/base.py +0 -0
  71. {codon_model-0.0.3b2/codon/utils → codon_model-0.0.4/codon/utils/dataset/conflux}/__init__.py +0 -0
  72. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/utils/dataset/conflux/base.py +0 -0
  73. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/utils/dataset/conflux/reader.py +0 -0
  74. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/utils/dataset/conflux/writer.py +0 -0
  75. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/utils/dataset/corpus.py +0 -0
  76. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/utils/dataset/dataviewer.py +0 -0
  77. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/utils/dataset/flatdata.py +0 -0
  78. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/utils/dataset/image.py +0 -0
  79. {codon_model-0.0.3b2/codon/utils/dataset/conflux → codon_model-0.0.4/codon/utils/layer}/__init__.py +0 -0
  80. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/utils/mask.py +0 -0
  81. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/utils/safecode.py +0 -0
  82. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/utils/split.py +0 -0
  83. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon/utils/theta.py +0 -0
  84. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon_model.egg-info/dependency_links.txt +0 -0
  85. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon_model.egg-info/requires.txt +0 -0
  86. {codon_model-0.0.3b2 → codon_model-0.0.4}/codon_model.egg-info/top_level.txt +0 -0
  87. {codon_model-0.0.3b2 → codon_model-0.0.4}/setup.cfg +0 -0
  88. {codon_model-0.0.3b2 → codon_model-0.0.4}/setup.py +0 -0
  89. {codon_model-0.0.3b2 → codon_model-0.0.4}/test/test_conflux_dataset.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: codon-model
3
- Version: 0.0.3b2
3
+ Version: 0.0.4
4
4
  Summary: Codon model package
5
5
  Author: CodonTeam
6
6
  Requires-Python: >=3.8
@@ -1,5 +1,5 @@
1
1
  from typing import Optional
2
2
 
3
- __version__ = '0.0.3b2'
3
+ __version__ = '0.0.4'
4
4
 
5
5
  __seed__: Optional[int] = None
@@ -1,5 +1,6 @@
1
1
  import torch
2
2
  import torch.nn as nn
3
+ import torch.nn.functional as F
3
4
 
4
5
  from typing import Callable, Any, Iterator, Union
5
6
 
@@ -132,7 +133,7 @@ class BasicModel(nn.Module):
132
133
 
133
134
  return total
134
135
 
135
- def load_pretrained(self, path: str) -> None:
136
+ def load_pretrained(self, path: str) -> 'BasicModel':
136
137
  '''
137
138
  Load a pretrained model from a file.
138
139
 
@@ -141,7 +142,7 @@ class BasicModel(nn.Module):
141
142
  '''
142
143
  if path.endswith('.safetensors'):
143
144
  safe_load_model(self, path)
144
- return
145
+ return self
145
146
 
146
147
  state_dict = torch.load(path, map_location=self.device)
147
148
 
@@ -152,8 +153,10 @@ class BasicModel(nn.Module):
152
153
  state_dict = state_dict['state_dict']
153
154
 
154
155
  self.load_state_dict(state_dict)
156
+
157
+ return self
155
158
 
156
- def save_pretrained(self, path: str) -> None:
159
+ def save_pretrained(self, path: str) -> 'BasicModel':
157
160
  '''
158
161
  Save the model to a file.
159
162
 
@@ -165,3 +168,26 @@ class BasicModel(nn.Module):
165
168
  else:
166
169
  state_dict = self.state_dict()
167
170
  torch.save(state_dict, path)
171
+ return self
172
+
173
+ def freeze(self) -> 'BasicModel':
174
+ '''
175
+ Freeze all parameters in the model by setting requires_grad to False.
176
+
177
+ Returns:
178
+ BasicModel: The model itself for method chaining.
179
+ '''
180
+ for param in self.parameters():
181
+ param.requires_grad = False
182
+ return self
183
+
184
+ def unfreeze(self) -> 'BasicModel':
185
+ '''
186
+ Unfreeze all parameters in the model by setting requires_grad to True.
187
+
188
+ Returns:
189
+ BasicModel: The model itself for method chaining.
190
+ '''
191
+ for param in self.parameters():
192
+ param.requires_grad = True
193
+ return self
@@ -31,6 +31,11 @@ from .transformer import (
31
31
  TransformerMoEDecoder,
32
32
  _TransformerDecoder,
33
33
  )
34
+ from .manifold import (
35
+ MainfoldLoss,
36
+ BasicManifoldLinear, RiemannianManifoldLinear,
37
+ BasicManifoldConv2d, RiemannianManifoldConv2d
38
+ )
34
39
 
35
40
  __all__ = [
36
41
  # attention
@@ -80,4 +85,10 @@ __all__ = [
80
85
  'TransformerDecoderOutput',
81
86
  'TransformerDenseDecoder',
82
87
  'TransformerMoEDecoder',
88
+ # manifold
89
+ 'MainfoldLoss',
90
+ 'BasicManifoldLinear',
91
+ 'RiemannianManifoldLinear',
92
+ 'BasicManifoldConv2d',
93
+ 'RiemannianManifoldConv2d'
83
94
  ]
@@ -0,0 +1,9 @@
1
+ from .hebian import HebianOutput, Hebian
2
+ from .predictive import PredictiveCodingOutput, PredictiveCoding
3
+
4
+ __all__ = [
5
+ 'HebianOutput',
6
+ 'Hebian',
7
+ 'PredictiveCodingOutput',
8
+ 'PredictiveCoding'
9
+ ]
@@ -50,6 +50,8 @@ class BasicLoRA(BasicModel):
50
50
  super().__init__()
51
51
  self.gradient_checkpointing = gradient_checkpointing
52
52
  self.original_layer = original_layer
53
+ self.lora_dropout_p = lora_dropout
54
+ self.merge_weights = merge_weights
53
55
 
54
56
  # Freeze original layer
55
57
  for p in self.original_layer.parameters():
@@ -1,11 +1,231 @@
1
+ from codon.base import *
2
+
1
3
  import torch.nn.functional as F
2
4
 
3
- from codon.base import *
5
+ from typing import Tuple
6
+ from dataclasses import dataclass
7
+
8
+ from codon.ops.manifold import riemannian_manifold_linear, riemannian_manifold_conv2d
9
+
10
+
11
+ @dataclass
12
+ class MainfoldLoss:
13
+ '''
14
+ Dataclass for storing manifold-related loss components.
15
+
16
+ Attributes:
17
+ cosine (torch.Tensor): The cosine similarity loss.
18
+ laplacian (torch.Tensor): The Laplacian regularization loss.
19
+ '''
20
+ cosine: torch.Tensor
21
+ laplacian: torch.Tensor
22
+
23
+ def factor_loss(self, factor_cos: float = 0.013, factor_lap: float = 0.012) -> torch.Tensor:
24
+ '''
25
+ Calculates the weighted sum of cosine and Laplacian losses.
26
+
27
+ Args:
28
+ factor_cos (float): The weight factor for the cosine loss.
29
+ factor_lap (float): The weight factor for the Laplacian loss.
30
+
31
+ Returns:
32
+ torch.Tensor: The calculated total loss value.
33
+ '''
34
+ return self.cosine * factor_cos + self.laplacian * factor_lap
35
+
36
+
37
+ class BasicManifoldLinear(BasicModel):
38
+ '''
39
+ Base class for manifold-based neural network layers.
40
+
41
+ Attributes:
42
+ in_features (int): Size of each input sample.
43
+ out_features (int): Size of each output sample.
44
+ k_neighbors (int): Number of nearest neighbors to consider for Laplacian loss.
45
+ weight (nn.Parameter): The learnable weights of the layer.
46
+ '''
47
+
48
+ def __init__(
49
+ self,
50
+ in_features: int,
51
+ out_features: int,
52
+ k_neighbors: int = 2,
53
+ ) -> None:
54
+ '''
55
+ Initializes the BasicManifoldLinear layer.
4
56
 
5
- import math
6
- from typing import Tuple, Union
57
+ Args:
58
+ in_features (int): Size of each input sample.
59
+ out_features (int): Size of each output sample.
60
+ k_neighbors (int): Number of nearest neighbors for the Laplacian graph.
61
+ '''
62
+ super().__init__()
7
63
 
8
- from .manifold import MainfoldLoss
64
+ self.in_features = in_features
65
+ self.out_features = out_features
66
+ self.k_neighbors = min(k_neighbors, out_features - 1)
67
+
68
+ self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
69
+
70
+ @property
71
+ def loss_cosine(self) -> torch.Tensor:
72
+ '''
73
+ Calculates the cosine similarity penalty loss among the weight vectors.
74
+
75
+ Returns:
76
+ torch.Tensor: The computed cosine penalty loss.
77
+ '''
78
+ w_norm = F.normalize(self.weight, p=2, dim=1)
79
+
80
+ C = torch.matmul(w_norm, w_norm.T)
81
+ I = torch.eye(self.out_features, device=C.device)
82
+
83
+ return torch.sum((C * (1 - I)) ** 2) / (self.out_features * (self.out_features - 1))
84
+
85
+ @property
86
+ def loss_laplacian(self) -> torch.Tensor:
87
+ '''
88
+ Calculates the Laplacian regularization loss based on k-nearest neighbors.
89
+
90
+ Returns:
91
+ torch.Tensor: The computed Laplacian regularization loss.
92
+ '''
93
+ w_norm = F.normalize(self.weight, p=2, dim=1)
94
+
95
+ C = torch.matmul(w_norm, w_norm.T)
96
+ I = torch.eye(self.out_features, device=C.device)
97
+
98
+ _, topk_idx = torch.topk(C, self.k_neighbors + 1, dim=1)
99
+
100
+ A = torch.zeros_like(C)
101
+ A.scatter_(1, topk_idx, 1.0)
102
+ A = A - I
103
+ A = torch.max(A, A.T)
104
+
105
+ return torch.sum(A * (1.0 - C)) / torch.sum(A)
106
+
107
+ def compute_loss(self) -> MainfoldLoss:
108
+ '''
109
+ Computes both the cosine and Laplacian losses and returns them in a MainfoldLoss object.
110
+
111
+ Returns:
112
+ MainfoldLoss: An object containing the computed cosine and Laplacian losses.
113
+ '''
114
+ w_norm = F.normalize(self.weight, p=2, dim=1)
115
+
116
+ C = torch.matmul(w_norm, w_norm.T)
117
+ I = torch.eye(self.out_features, device=C.device)
118
+
119
+ loss_cos = torch.sum((C * (1 - I)) ** 2) / (self.out_features * (self.out_features - 1))
120
+
121
+ _, topk_idx = torch.topk(C, self.k_neighbors + 1, dim=1)
122
+
123
+ A = torch.zeros_like(C)
124
+ A.scatter_(1, topk_idx, 1.0)
125
+ A = A - I
126
+ A = torch.max(A, A.T)
127
+
128
+ loss_lap = torch.sum(A * (1.0 - C)) / torch.sum(A)
129
+
130
+ return MainfoldLoss(cosine=loss_cos, laplacian=loss_lap)
131
+
132
+ def extra_repr(self) -> str:
133
+ '''
134
+ Sets the extra representation of the module for printing.
135
+ '''
136
+ return f'in_features={self.in_features}, out_features={self.out_features}, k_neighbors={self.k_neighbors}'
137
+
138
+
139
+ class RiemannianManifoldLinear(BasicManifoldLinear):
140
+ '''
141
+ A linear layer projecting data onto a Riemannian manifold (hypersphere).
142
+
143
+ Attributes:
144
+ kappa (nn.Parameter): Concentration parameter for the von Mises-Fisher (vMF) distribution.
145
+ lambda_rate (nn.Parameter): Gravitational attraction coefficient.
146
+ scale (nn.Parameter): Vector amplifier for the hyperspherical network.
147
+ bias (nn.Parameter): Manifold bias vector.
148
+ '''
149
+
150
+ def __init__(
151
+ self,
152
+ in_features: int,
153
+ out_features: int,
154
+ kappa_init: float = 2.0,
155
+ lambda_init: float = 0.1,
156
+ scale_init: float = 15.0,
157
+ k_neighbors: int = 2,
158
+ rule: str = 'near'
159
+ ) -> None:
160
+ '''
161
+ Initializes the RiemannianManifoldLinear layer.
162
+
163
+ Args:
164
+ in_features (int): Size of each input sample.
165
+ out_features (int): Size of each output sample.
166
+ kappa_init (float): Initial value for the vMF concentration parameter.
167
+ lambda_init (float): Initial value for the gravitational attraction coefficient.
168
+ scale_init (float): Initial value for the vector amplifier scale.
169
+ k_neighbors (int): Number of nearest neighbors for the Laplacian graph.
170
+ rule (str): Attraction rule, either 'near' or 'far'.
171
+ '''
172
+ super().__init__(
173
+ in_features=in_features,
174
+ out_features=out_features,
175
+ k_neighbors=k_neighbors
176
+ )
177
+
178
+ self.kappa_init = kappa_init
179
+ self.lambda_init = lambda_init
180
+ self.scale_init = scale_init
181
+ self.rule = rule.lower()
182
+
183
+ if not self.rule in ['far', 'near']:
184
+ raise ValueError(f"Invalid rule: {self.rule}, must be 'far' or 'near'")
185
+
186
+ # Concentration parameter for the vMF distribution
187
+ self.kappa = nn.Parameter(torch.tensor(float(kappa_init)))
188
+
189
+ # Gravitational attraction coefficient
190
+ self.lambda_rate = nn.Parameter(torch.tensor(float(lambda_init)))
191
+
192
+ # Vector amplifier for the hyperspherical network
193
+ self.scale = nn.Parameter(torch.ones(out_features) * scale_init)
194
+
195
+ # Manifold bias vector
196
+ self.bias = nn.Parameter(torch.zeros(out_features))
197
+
198
+ self.reset_parameters()
199
+
200
+ def reset_parameters(self) -> None:
201
+ '''
202
+ Resets the parameters of the layer.
203
+ '''
204
+ nn.init.normal_(self.weight, 0, 0.01)
205
+
206
+ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
207
+ '''
208
+ Defines the computation performed at every call.
209
+
210
+ Args:
211
+ input_tensor (torch.Tensor): The input data with shape (batch_size, in_features).
212
+
213
+ Returns:
214
+ torch.Tensor: The output data with shape (batch_size, out_features).
215
+ '''
216
+ return riemannian_manifold_linear(
217
+ input_tensor=input_tensor,
218
+ weight=self.weight,
219
+ kappa=self.kappa,
220
+ lambda_rate=self.lambda_rate,
221
+ scale=self.scale,
222
+ bias=self.bias,
223
+ rule=self.rule
224
+ )
225
+
226
+ def extra_repr(self) -> str:
227
+ main_str = super().extra_repr()
228
+ return f'{main_str}, rule={self.rule}, kappa={self.kappa.item():.4f}, lambda={self.lambda_rate.item():.4f}'
9
229
 
10
230
 
11
231
  class BasicManifoldConv2d(BasicModel):
@@ -126,6 +346,16 @@ class BasicManifoldConv2d(BasicModel):
126
346
  loss_lap = torch.sum(A * (1.0 - C)) / torch.sum(A)
127
347
 
128
348
  return MainfoldLoss(cosine=loss_cos, laplacian=loss_lap)
349
+
350
+ def extra_repr(self) -> str:
351
+ s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
352
+ ', stride={stride}')
353
+ if self.padding != 0:
354
+ s += ', padding={padding}'
355
+ if self.dilation != 1:
356
+ s += ', dilation={dilation}'
357
+ s += f', rule={self.rule}, use_norm={self.use_norm}'
358
+ return s.format(**self.__dict__)
129
359
 
130
360
 
131
361
  class RiemannianManifoldConv2d(BasicManifoldConv2d):
@@ -138,6 +368,7 @@ class RiemannianManifoldConv2d(BasicManifoldConv2d):
138
368
  scale (nn.Parameter): Vector amplifier for the hyperspherical network.
139
369
  bias (nn.Parameter): Manifold bias vector.
140
370
  weight_ones (torch.Tensor): Fixed all-ones kernel for computing patch norm rapidly.
371
+ use_norm (bool): Whether to scale the output by the input patch norm.
141
372
  '''
142
373
 
143
374
  def __init__(
@@ -152,7 +383,9 @@ class RiemannianManifoldConv2d(BasicManifoldConv2d):
152
383
  lambda_init: float = 0.1,
153
384
  scale_init: float = 15.0,
154
385
  k_neighbors: int = 2,
155
- rule: str = 'near'
386
+ rule: str = 'near',
387
+ use_norm_gate: bool = False,
388
+ use_norm: bool = False
156
389
  ) -> None:
157
390
  '''
158
391
  Initializes the RiemannianManifoldConv2d layer.
@@ -169,6 +402,7 @@ class RiemannianManifoldConv2d(BasicManifoldConv2d):
169
402
  scale_init (float): Initial value for the vector amplifier scale.
170
403
  k_neighbors (int): Number of nearest neighbors for the Laplacian graph.
171
404
  rule (str): Attraction rule, either 'near' or 'far'.
405
+ use_norm (bool): Whether to scale the output by the input patch norm. Default: True.
172
406
  '''
173
407
  super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, k_neighbors)
174
408
 
@@ -177,6 +411,8 @@ class RiemannianManifoldConv2d(BasicManifoldConv2d):
177
411
  self.lambda_rate = nn.Parameter(torch.tensor(float(lambda_init)))
178
412
  self.scale = nn.Parameter(torch.ones(out_channels) * scale_init)
179
413
  self.bias = nn.Parameter(torch.zeros(out_channels))
414
+ self.use_norm_gate = use_norm_gate
415
+ self.use_norm = use_norm
180
416
 
181
417
  # All-ones kernel for ultra-fast calculation of patch norm
182
418
  weight_ones = torch.ones(1, in_channels, *self.kernel_size)
@@ -199,132 +435,17 @@ class RiemannianManifoldConv2d(BasicManifoldConv2d):
199
435
  Returns:
200
436
  torch.Tensor: The output manifold projection tensor.
201
437
  '''
202
- # 1. Weight normalization
203
- w_flat = self.weight.view(self.out_channels, -1)
204
- w_norm_flat = F.normalize(w_flat, p=2, dim=1)
205
- w_norm = w_norm_flat.view_as(self.weight)
206
-
207
- # 2. Ultra-fast calculation of the norm for each sliding patch of the input image
208
- # x_sq: [batch, 1, H_out, W_out]
209
- x_sq = F.conv2d(input_tensor ** 2, self.weight_ones, stride=self.stride, padding=self.padding, dilation=self.dilation)
210
- x_norm_val = torch.sqrt(torch.clamp(x_sq, min=1e-6))
211
-
212
- # 3. Calculate Cosine Feature Map
213
- # cosine: [batch, out_channels, H_out, W_out]
214
- conv_proj = F.conv2d(input_tensor, w_norm, stride=self.stride, padding=self.padding, dilation=self.dilation)
215
- cosine = conv_proj / (x_norm_val + 1e-6)
216
- cosine = torch.clamp(cosine, -1.0 + 1e-6, 1.0 - 1e-6)
217
-
218
- # 4. vMF gravitational field calculation (applied pixel-wise)
219
- theta = torch.acos(cosine)
220
- exp_val = torch.exp(self.kappa * (cosine - 1.0))
221
- attraction = exp_val if self.rule == 'near' else 1.0 - exp_val
222
-
223
- # 5. Riemannian geodesic pullback
224
- safe_lambda = torch.clamp(self.lambda_rate, 1e-6, 1.0 - 1e-4)
225
- effective_theta = theta * (1.0 - safe_lambda * attraction)
226
-
227
- # 6. Reconstruct the output (note the shape broadcasting)
228
- scale_view = self.scale.view(1, -1, 1, 1)
229
- bias_view = self.bias.view(1, -1, 1, 1)
230
-
231
- output = scale_view * torch.cos(effective_theta) + bias_view
232
-
233
- return output
234
-
235
-
236
- class EuclideanManifoldConv2d(BasicManifoldConv2d):
237
- '''
238
- A 2D convolutional layer simulating a manifold structure in Euclidean space.
239
-
240
- Attributes:
241
- tau (nn.Parameter): Temperature or radius parameter for the basin of attraction.
242
- lambda_rate (nn.Parameter): Gravitational strength parameter.
243
- bias (nn.Parameter): Translation bias vector.
244
- weight_ones (torch.Tensor): Fixed all-ones kernel for computing patch norm rapidly.
245
- '''
246
-
247
- def __init__(
248
- self,
249
- in_channels: int,
250
- out_channels: int,
251
- kernel_size: int,
252
- stride: int = 1,
253
- padding: int = 0,
254
- dilation: int = 1,
255
- tau_init: float = 5.0,
256
- lambda_init: float = 0.5,
257
- k_neighbors: int = 2,
258
- rule: str = 'near'
259
- ) -> None:
260
- '''
261
- Initializes the EuclideanManifoldConv2d layer.
262
-
263
- Args:
264
- in_channels (int): Number of channels in the input image.
265
- out_channels (int): Number of channels produced by the convolution.
266
- kernel_size (int): Size of the convolving kernel.
267
- stride (int): Stride of the convolution. Default: 1.
268
- padding (int): Zero-padding added to both sides of the input. Default: 0.
269
- dilation (int): Spacing between kernel elements. Default: 1.
270
- tau_init (float): Initial value for the basin temperature/radius.
271
- lambda_init (float): Initial value for the gravitational strength.
272
- k_neighbors (int): Number of nearest neighbors for the Laplacian graph.
273
- rule (str): Attraction rule, either 'near' or 'far'.
274
- '''
275
- super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, k_neighbors)
276
-
277
- self.rule = rule.lower()
278
- self.tau = nn.Parameter(torch.tensor(float(tau_init)))
279
- self.lambda_rate = nn.Parameter(torch.tensor(float(lambda_init)))
280
- self.bias = nn.Parameter(torch.Tensor(out_channels))
281
-
282
- weight_ones = torch.ones(1, in_channels, *self.kernel_size)
283
- self.register_buffer('weight_ones', weight_ones)
284
-
285
- self.reset_parameters()
286
-
287
- def reset_parameters(self) -> None:
288
- '''
289
- Resets the parameters of the layer using Kaiming uniform initialization.
290
- '''
291
- nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
292
- if self.bias is not None:
293
- fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
294
- bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
295
- nn.init.uniform_(self.bias, -bound, bound)
296
-
297
- def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
298
- '''
299
- Defines the computation performed at every call.
300
-
301
- Args:
302
- input_tensor (torch.Tensor): The input data tensor.
303
-
304
- Returns:
305
- torch.Tensor: The output manifold projection tensor.
306
- '''
307
- # 1. Base physical projection
308
- # base_proj: [batch, out_channels, H_out, W_out]
309
- base_proj = F.conv2d(input_tensor, self.weight, stride=self.stride, padding=self.padding, dilation=self.dilation)
310
-
311
- # 2. Ultra-fast algebraic expansion of the squared L2 distance for local patches
312
- # ||patch - W||^2 = ||patch||^2 + ||W||^2 - 2<patch, W>
313
- x_sq = F.conv2d(input_tensor ** 2, self.weight_ones, stride=self.stride, padding=self.padding, dilation=self.dilation)
314
- w_sq = torch.sum(self.weight ** 2, dim=(1,2,3)).view(1, -1, 1, 1)
315
-
316
- dist_sq = x_sq + w_sq - 2 * base_proj
317
- dist_sq = torch.clamp(dist_sq, min=1e-6)
318
-
319
- # 3. Compute the attraction index
320
- exp_val = torch.exp(-dist_sq / (self.tau ** 2 + 1e-8))
321
- attraction = exp_val if self.rule == 'near' else 1.0 - exp_val
322
-
323
- # 4. Gravitational correction
324
- safe_lambda = torch.clamp(self.lambda_rate, 1e-6, 1.0 - 1e-4)
325
- correction = safe_lambda * attraction * (w_sq - base_proj)
326
-
327
- # 5. Combine outputs
328
- output = base_proj + correction + self.bias.view(1, -1, 1, 1)
329
-
330
- return output
438
+ return riemannian_manifold_conv2d(
439
+ input_tensor=input_tensor,
440
+ weight=self.weight,
441
+ weight_ones=self.weight_ones,
442
+ kappa=self.kappa,
443
+ lambda_rate=self.lambda_rate,
444
+ scale=self.scale,
445
+ bias=self.bias,
446
+ stride=self.stride,
447
+ padding=self.padding,
448
+ dilation=self.dilation,
449
+ rule=self.rule,
450
+ use_norm=self.use_norm
451
+ )