homa 0.1.5.tar.gz → 0.3.18.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of homa might be problematic.

Files changed (189)
  1. homa-0.3.18/PKG-INFO +75 -0
  2. homa-0.3.18/README.md +64 -0
  3. {homa-0.1.5 → homa-0.3.18}/pyproject.toml +1 -1
  4. homa-0.3.18/src/homa/activations/APLU.py +49 -0
  5. homa-0.3.18/src/homa/activations/ActivationFunction.py +6 -0
  6. homa-0.3.18/src/homa/activations/AdaptiveActivationFunction.py +15 -0
  7. homa-0.3.18/src/homa/activations/BaseDLReLU.py +34 -0
  8. homa-0.3.18/src/homa/activations/CaLU.py +13 -0
  9. homa-0.3.18/src/homa/activations/DLReLU.py +6 -0
  10. homa-0.3.18/src/homa/activations/ERF.py +10 -0
  11. homa-0.3.18/src/homa/activations/Elliot.py +10 -0
  12. homa-0.3.18/src/homa/activations/ExpExpish.py +9 -0
  13. homa-0.3.18/src/homa/activations/ExponentialDLReLU.py +6 -0
  14. homa-0.3.18/src/homa/activations/ExponentialSwish.py +10 -0
  15. homa-0.3.18/src/homa/activations/GCU.py +9 -0
  16. homa-0.3.18/src/homa/activations/GaLU.py +11 -0
  17. homa-0.3.18/src/homa/activations/GaussianReLU.py +50 -0
  18. homa-0.3.18/src/homa/activations/GeneralizedSwish.py +10 -0
  19. homa-0.3.18/src/homa/activations/Gish.py +11 -0
  20. homa-0.3.18/src/homa/activations/LaLU.py +11 -0
  21. homa-0.3.18/src/homa/activations/LogLogish.py +10 -0
  22. homa-0.3.18/src/homa/activations/LogSigmoid.py +10 -0
  23. homa-0.3.18/src/homa/activations/Logish.py +10 -0
  24. homa-0.3.18/src/homa/activations/MeLU.py +11 -0
  25. homa-0.3.18/src/homa/activations/MexicanReLU.py +49 -0
  26. homa-0.3.18/src/homa/activations/MinSin.py +10 -0
  27. homa-0.3.18/src/homa/activations/NReLU.py +12 -0
  28. homa-0.3.18/src/homa/activations/NoisyReLU.py +6 -0
  29. homa-0.3.18/src/homa/activations/PLogish.py +6 -0
  30. homa-0.3.18/src/homa/activations/ParametricLogish.py +13 -0
  31. homa-0.3.18/src/homa/activations/Phish.py +11 -0
  32. homa-0.3.18/src/homa/activations/RReLU.py +16 -0
  33. homa-0.3.18/src/homa/activations/RandomizedSlopedReLU.py +7 -0
  34. homa-0.3.18/src/homa/activations/SGELU.py +12 -0
  35. homa-0.3.18/src/homa/activations/SReLU.py +37 -0
  36. homa-0.3.18/src/homa/activations/SelfArctan.py +9 -0
  37. homa-0.3.18/src/homa/activations/ShiftedReLU.py +10 -0
  38. homa-0.3.18/src/homa/activations/SigmoidDerivative.py +10 -0
  39. homa-0.3.18/src/homa/activations/SineReLU.py +11 -0
  40. homa-0.3.18/src/homa/activations/SlopedReLU.py +13 -0
  41. homa-0.3.18/src/homa/activations/SmallGaLU.py +11 -0
  42. homa-0.3.18/src/homa/activations/Smish.py +9 -0
  43. homa-0.3.18/src/homa/activations/SoftsignRReLU.py +17 -0
  44. homa-0.3.18/src/homa/activations/Suish.py +11 -0
  45. homa-0.3.18/src/homa/activations/TBSReLU.py +13 -0
  46. homa-0.3.18/src/homa/activations/TSReLU.py +10 -0
  47. homa-0.3.18/src/homa/activations/TangentBipolarSigmoidReLU.py +6 -0
  48. homa-0.3.18/src/homa/activations/TangentSigmoidReLU.py +6 -0
  49. homa-0.3.18/src/homa/activations/TeLU.py +9 -0
  50. homa-0.3.18/src/homa/activations/TripleStateSwish.py +15 -0
  51. homa-0.3.18/src/homa/activations/WideMeLU.py +15 -0
  52. homa-0.3.18/src/homa/activations/__init__.py +49 -0
  53. homa-0.3.18/src/homa/activations/learnable/AOAF.py +16 -0
  54. homa-0.3.18/src/homa/activations/learnable/AReLU.py +19 -0
  55. homa-0.3.18/src/homa/activations/learnable/DPReLU.py +16 -0
  56. homa-0.3.18/src/homa/activations/learnable/DualLine.py +18 -0
  57. homa-0.3.18/src/homa/activations/learnable/FReLU.py +14 -0
  58. homa-0.3.18/src/homa/activations/learnable/LeLeLU.py +14 -0
  59. homa-0.3.18/src/homa/activations/learnable/PERU.py +16 -0
  60. homa-0.3.18/src/homa/activations/learnable/PiLU.py +18 -0
  61. homa-0.3.18/src/homa/activations/learnable/ShiLU.py +16 -0
  62. homa-0.3.18/src/homa/activations/learnable/StarReLU.py +16 -0
  63. homa-0.3.18/src/homa/activations/learnable/__init__.py +10 -0
  64. homa-0.3.18/src/homa/activations/learnable/concerns/ChannelBased.py +38 -0
  65. homa-0.3.18/src/homa/activations/learnable/concerns/__init__.py +1 -0
  66. homa-0.3.18/src/homa/cli/Commands/Command.py +2 -0
  67. homa-0.3.18/src/homa/cli/Commands/InitCommand.py +34 -0
  68. homa-0.3.18/src/homa/cli/Commands/__init__.py +2 -0
  69. {homa-0.1.5 → homa-0.3.18}/src/homa/cli/HomaCommand.py +4 -0
  70. homa-0.3.18/src/homa/core/__init__.py +0 -0
  71. homa-0.3.18/src/homa/core/concerns/MovesNetworkToDevice.py +13 -0
  72. homa-0.3.18/src/homa/core/concerns/TracksTime.py +7 -0
  73. homa-0.3.18/src/homa/core/concerns/__init__.py +2 -0
  74. {homa-0.1.5 → homa-0.3.18}/src/homa/device.py +4 -0
  75. {homa-0.1.5 → homa-0.3.18}/src/homa/ensemble/Ensemble.py +6 -6
  76. homa-0.3.18/src/homa/ensemble/concerns/CalculatesMetricNecessities.py +24 -0
  77. {homa-0.1.5 → homa-0.3.18}/src/homa/ensemble/concerns/PredictsProbabilities.py +6 -2
  78. {homa-0.1.5 → homa-0.3.18}/src/homa/ensemble/concerns/ReportsClassificationMetrics.py +3 -2
  79. homa-0.3.18/src/homa/ensemble/concerns/ReportsEnsembleAccuracy.py +11 -0
  80. {homa-0.1.5 → homa-0.3.18}/src/homa/ensemble/concerns/ReportsEnsembleF1.py +2 -2
  81. {homa-0.1.5 → homa-0.3.18}/src/homa/ensemble/concerns/ReportsEnsembleKappa.py +2 -2
  82. homa-0.3.18/src/homa/ensemble/concerns/ReportsEnsembleSize.py +11 -0
  83. homa-0.3.18/src/homa/ensemble/concerns/ReportsLogits.py +38 -0
  84. homa-0.3.18/src/homa/ensemble/concerns/SavesEnsembleModels.py +13 -0
  85. homa-0.3.18/src/homa/ensemble/concerns/StoresModels.py +32 -0
  86. {homa-0.1.5 → homa-0.3.18}/src/homa/ensemble/concerns/__init__.py +3 -3
  87. homa-0.3.18/src/homa/ensemble/utils.py +9 -0
  88. homa-0.3.18/src/homa/graph/GraphAttention.py +13 -0
  89. homa-0.3.18/src/homa/graph/__init__.py +1 -0
  90. homa-0.3.18/src/homa/graph/modules/GraphAttentionHeadModule.py +37 -0
  91. homa-0.3.18/src/homa/graph/modules/MultiHeadGraphAttentionModule.py +22 -0
  92. homa-0.3.18/src/homa/graph/modules/__init__.py +2 -0
  93. homa-0.3.18/src/homa/loss/LogitNormLoss.py +12 -0
  94. homa-0.3.18/src/homa/loss/Loss.py +5 -0
  95. homa-0.3.18/src/homa/loss/__init__.py +2 -0
  96. homa-0.3.18/src/homa/rl/DQN.py +2 -0
  97. homa-0.3.18/src/homa/rl/DRQN.py +5 -0
  98. homa-0.3.18/src/homa/rl/DiversityIsAllYouNeed.py +96 -0
  99. homa-0.3.18/src/homa/rl/SoftActorCritic.py +70 -0
  100. homa-0.3.18/src/homa/rl/__init__.py +4 -0
  101. homa-0.3.18/src/homa/rl/buffers/Buffer.py +13 -0
  102. homa-0.3.18/src/homa/rl/buffers/DiversityIsAllYouNeedBuffer.py +50 -0
  103. homa-0.3.18/src/homa/rl/buffers/ImageBuffer.py +5 -0
  104. homa-0.3.18/src/homa/rl/buffers/SoftActorCriticBuffer.py +60 -0
  105. homa-0.3.18/src/homa/rl/buffers/__init__.py +4 -0
  106. homa-0.3.18/src/homa/rl/buffers/concerns/HasRecordAlternatives.py +12 -0
  107. homa-0.3.18/src/homa/rl/buffers/concerns/ResetsCollection.py +9 -0
  108. homa-0.3.18/src/homa/rl/buffers/concerns/__init__.py +2 -0
  109. homa-0.3.18/src/homa/rl/diayn/Actor.py +54 -0
  110. homa-0.3.18/src/homa/rl/diayn/Critic.py +41 -0
  111. homa-0.3.18/src/homa/rl/diayn/Discriminator.py +45 -0
  112. homa-0.3.18/src/homa/rl/diayn/__init__.py +3 -0
  113. homa-0.3.18/src/homa/rl/diayn/modules/ContinuousActorModule.py +42 -0
  114. homa-0.3.18/src/homa/rl/diayn/modules/CriticModule.py +28 -0
  115. homa-0.3.18/src/homa/rl/diayn/modules/DiscriminatorModule.py +24 -0
  116. homa-0.3.18/src/homa/rl/diayn/modules/__init__.py +3 -0
  117. homa-0.3.18/src/homa/rl/sac/SoftActor.py +70 -0
  118. homa-0.3.18/src/homa/rl/sac/SoftCritic.py +97 -0
  119. homa-0.3.18/src/homa/rl/sac/__init__.py +2 -0
  120. homa-0.3.18/src/homa/rl/sac/modules/DualSoftCriticModule.py +22 -0
  121. homa-0.3.18/src/homa/rl/sac/modules/SoftActorModule.py +35 -0
  122. homa-0.3.18/src/homa/rl/sac/modules/SoftCriticModule.py +30 -0
  123. homa-0.3.18/src/homa/rl/sac/modules/__init__.py +3 -0
  124. homa-0.3.18/src/homa/rl/utils.py +7 -0
  125. homa-0.1.5/src/homa/vision/ClassificationModel.py → homa-0.3.18/src/homa/vision/Classifier.py +1 -1
  126. homa-0.3.18/src/homa/vision/Resnet.py +13 -0
  127. homa-0.3.18/src/homa/vision/StochasticClassifier.py +29 -0
  128. homa-0.3.18/src/homa/vision/StochasticSwin.py +11 -0
  129. homa-0.3.18/src/homa/vision/Swin.py +25 -0
  130. homa-0.3.18/src/homa/vision/__init__.py +5 -0
  131. homa-0.3.18/src/homa/vision/concerns/HasLabels.py +13 -0
  132. homa-0.1.5/src/homa/vision/concerns/ReportsLogits.py → homa-0.3.18/src/homa/vision/concerns/HasLogits.py +4 -1
  133. homa-0.1.5/src/homa/vision/concerns/Predicts.py → homa-0.3.18/src/homa/vision/concerns/HasProbabilities.py +2 -2
  134. {homa-0.1.5 → homa-0.3.18}/src/homa/vision/concerns/ReportsAccuracy.py +2 -4
  135. {homa-0.1.5 → homa-0.3.18}/src/homa/vision/concerns/Trainable.py +4 -3
  136. {homa-0.1.5 → homa-0.3.18}/src/homa/vision/concerns/__init__.py +3 -2
  137. homa-0.3.18/src/homa/vision/modules/SwinModule.py +31 -0
  138. homa-0.3.18/src/homa/vision/modules/__init__.py +2 -0
  139. homa-0.3.18/src/homa/vision/utils.py +12 -0
  140. homa-0.3.18/src/homa.egg-info/PKG-INFO +75 -0
  141. homa-0.3.18/src/homa.egg-info/SOURCES.txt +154 -0
  142. homa-0.1.5/PKG-INFO +0 -21
  143. homa-0.1.5/README.md +0 -10
  144. homa-0.1.5/src/homa/activations/__init__.py +0 -2
  145. homa-0.1.5/src/homa/activations/classes/APLU.py +0 -48
  146. homa-0.1.5/src/homa/activations/classes/GALU.py +0 -51
  147. homa-0.1.5/src/homa/activations/classes/MELU.py +0 -50
  148. homa-0.1.5/src/homa/activations/classes/PDELU.py +0 -39
  149. homa-0.1.5/src/homa/activations/classes/SReLU.py +0 -49
  150. homa-0.1.5/src/homa/activations/classes/SmallGALU.py +0 -39
  151. homa-0.1.5/src/homa/activations/classes/StochasticActivation.py +0 -20
  152. homa-0.1.5/src/homa/activations/classes/WideMELU.py +0 -61
  153. homa-0.1.5/src/homa/activations/classes/__init__.py +0 -8
  154. homa-0.1.5/src/homa/activations/utils.py +0 -27
  155. homa-0.1.5/src/homa/ensemble/concerns/CalculatesMetricNecessities.py +0 -20
  156. homa-0.1.5/src/homa/ensemble/concerns/HasNetwork.py +0 -5
  157. homa-0.1.5/src/homa/ensemble/concerns/HasStateDicts.py +0 -8
  158. homa-0.1.5/src/homa/ensemble/concerns/RecordsStateDictionaries.py +0 -23
  159. homa-0.1.5/src/homa/ensemble/concerns/ReportsEnsembleAccuracy.py +0 -10
  160. homa-0.1.5/src/homa/ensemble/concerns/ReportsLogits.py +0 -13
  161. homa-0.1.5/src/homa/ensemble/concerns/ReportsSize.py +0 -11
  162. homa-0.1.5/src/homa/torch/__init__.py +0 -1
  163. homa-0.1.5/src/homa/torch/helpers.py +0 -6
  164. homa-0.1.5/src/homa/vision/Resnet.py +0 -13
  165. homa-0.1.5/src/homa/vision/StochasticResnet.py +0 -8
  166. homa-0.1.5/src/homa/vision/__init__.py +0 -3
  167. homa-0.1.5/src/homa/vision/modules/StochasticResnetModule.py +0 -9
  168. homa-0.1.5/src/homa/vision/modules/__init__.py +0 -2
  169. homa-0.1.5/src/homa/vision/utils.py +0 -21
  170. homa-0.1.5/src/homa.egg-info/PKG-INFO +0 -21
  171. homa-0.1.5/src/homa.egg-info/SOURCES.txt +0 -61
  172. homa-0.1.5/tests/test_ensemble.py +0 -28
  173. homa-0.1.5/tests/test_resnet.py +0 -34
  174. homa-0.1.5/tests/test_stochastic_resnet.py +0 -20
  175. {homa-0.1.5 → homa-0.3.18}/setup.cfg +0 -0
  176. {homa-0.1.5 → homa-0.3.18}/src/homa/__init__.py +0 -0
  177. {homa-0.1.5 → homa-0.3.18}/src/homa/cli/namespaces/CacheNamespace.py +0 -0
  178. {homa-0.1.5 → homa-0.3.18}/src/homa/cli/namespaces/MakeNamespace.py +0 -0
  179. {homa-0.1.5 → homa-0.3.18}/src/homa/cli/namespaces/__init__.py +0 -0
  180. {homa-0.1.5 → homa-0.3.18}/src/homa/ensemble/__init__.py +0 -0
  181. {homa-0.1.5 → homa-0.3.18}/src/homa/settings.py +0 -0
  182. {homa-0.1.5 → homa-0.3.18}/src/homa/utils.py +0 -0
  183. {homa-0.1.5 → homa-0.3.18}/src/homa/vision/Model.py +0 -0
  184. {homa-0.1.5 → homa-0.3.18}/src/homa/vision/concerns/ReportsMetrics.py +0 -0
  185. {homa-0.1.5 → homa-0.3.18}/src/homa/vision/modules/ResnetModule.py +0 -0
  186. {homa-0.1.5 → homa-0.3.18}/src/homa.egg-info/dependency_links.txt +0 -0
  187. {homa-0.1.5 → homa-0.3.18}/src/homa.egg-info/entry_points.txt +0 -0
  188. {homa-0.1.5 → homa-0.3.18}/src/homa.egg-info/requires.txt +0 -0
  189. {homa-0.1.5 → homa-0.3.18}/src/homa.egg-info/top_level.txt +0 -0
homa-0.3.18/PKG-INFO ADDED
@@ -0,0 +1,75 @@
+Metadata-Version: 2.4
+Name: homa
+Version: 0.3.18
+Summary: A curated list of machine learning and deep learning helpers.
+Author-email: Taha Shieenavaz <tahashieenavaz@gmail.com>
+Requires-Python: >=3.7
+Description-Content-Type: text/markdown
+Requires-Dist: numpy
+Requires-Dist: torch
+Requires-Dist: fire
+
+# Core
+
+### Device Management
+
+```py
+from homa import cpu, mps, cuda, device
+
+torch.tensor([1, 2, 3, 4, 5]).to(cpu())
+torch.tensor([1, 2, 3, 4, 5]).to(cuda())
+torch.tensor([1, 2, 3, 4, 5]).to(mps())
+torch.tensor([1, 2, 3, 4, 5]).to(device())
+```
+
+# Vision
+
+## Resnet
+
+This is the standard ResNet50 module.
+
+You can train the model with a `DataLoader` object.
+
+```py
+from homa.vision import Resnet
+
+model = Resnet(num_classes=10, lr=0.001)
+for epoch in range(10):
+    model.train(train_dataloader)
+```
+
+Similarly, you can manually unpack the data from the `DataLoader` yourself.
+
+```py
+from homa.vision import Resnet
+
+model = Resnet(num_classes=10, lr=0.001)
+for epoch in range(10):
+    for x, y in train_dataloader:
+        model.train(x, y)
+```
+
+## StochasticResnet
+
+This is a ResNet module whose activation functions are randomly replaced with activations drawn from a pool of candidates. Read more in the [paper](https://www.mdpi.com/1424-8220/22/16/6129).
+
+You can train the model with a `DataLoader` object.
+
+```py
+from homa.vision import StochasticResnet
+
+model = StochasticResnet(num_classes=10, lr=0.001)
+for epoch in range(10):
+    model.train(train_dataloader)
+```
+
+Similarly, you can manually unpack the data from the `DataLoader` yourself.
+
+```py
+from homa.vision import StochasticResnet
+
+model = StochasticResnet(num_classes=10, lr=0.001)
+for epoch in range(10):
+    for x, y in train_dataloader:
+        model.train(x, y)
+```
homa-0.3.18/README.md ADDED
@@ -0,0 +1,64 @@
+# Core
+
+### Device Management
+
+```py
+from homa import cpu, mps, cuda, device
+
+torch.tensor([1, 2, 3, 4, 5]).to(cpu())
+torch.tensor([1, 2, 3, 4, 5]).to(cuda())
+torch.tensor([1, 2, 3, 4, 5]).to(mps())
+torch.tensor([1, 2, 3, 4, 5]).to(device())
+```
+
+# Vision
+
+## Resnet
+
+This is the standard ResNet50 module.
+
+You can train the model with a `DataLoader` object.
+
+```py
+from homa.vision import Resnet
+
+model = Resnet(num_classes=10, lr=0.001)
+for epoch in range(10):
+    model.train(train_dataloader)
+```
+
+Similarly, you can manually unpack the data from the `DataLoader` yourself.
+
+```py
+from homa.vision import Resnet
+
+model = Resnet(num_classes=10, lr=0.001)
+for epoch in range(10):
+    for x, y in train_dataloader:
+        model.train(x, y)
+```
+
+## StochasticResnet
+
+This is a ResNet module whose activation functions are randomly replaced with activations drawn from a pool of candidates. Read more in the [paper](https://www.mdpi.com/1424-8220/22/16/6129).
+
+You can train the model with a `DataLoader` object.
+
+```py
+from homa.vision import StochasticResnet
+
+model = StochasticResnet(num_classes=10, lr=0.001)
+for epoch in range(10):
+    model.train(train_dataloader)
+```
+
+Similarly, you can manually unpack the data from the `DataLoader` yourself.
+
+```py
+from homa.vision import StochasticResnet
+
+model = StochasticResnet(num_classes=10, lr=0.001)
+for epoch in range(10):
+    for x, y in train_dataloader:
+        model.train(x, y)
+```
{homa-0.1.5 → homa-0.3.18}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "homa"
-version = "0.1.5"
+version = "0.3.18"
 description = "A curated list of machine learning and deep learning helpers."
 authors = [
     { name="Taha Shieenavaz", email="tahashieenavaz@gmail.com" },
homa-0.3.18/src/homa/activations/APLU.py ADDED
@@ -0,0 +1,49 @@
+import torch
+
+
+class APLU(torch.nn.Module):
+    def __init__(
+        self, channels: int | None = None, n: int = 2, init_b: str = "linspace"
+    ):
+        super().__init__()
+        self.n = n
+        self.init_b = init_b
+        if channels is None:
+            self.register_parameter("a", None)
+            self.register_parameter("b", None)
+        else:
+            self._init_params(channels, device=None, dtype=None)
+
+    def _init_params(self, channels, device, dtype):
+        a = torch.zeros(channels, self.n, device=device, dtype=dtype)
+        if self.init_b == "linspace":
+            b = (
+                torch.linspace(-1.0, 1.0, steps=self.n, device=device, dtype=dtype)
+                .expand(channels, -1)
+                .contiguous()
+            )
+        else:
+            b = torch.empty(channels, self.n, device=device, dtype=dtype).uniform_(
+                -1.0, 1.0
+            )
+        self.a = torch.nn.Parameter(a)
+        self.b = torch.nn.Parameter(b)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.a is None or self.b is None:
+            self._init_params(x.shape[1], device=x.device, dtype=x.dtype)
+
+        y = torch.nn.functional.relu(x)
+        x_exp = x.unsqueeze(-1)
+        expand_shape = (
+            (
+                1,
+                x.shape[1],
+            )
+            + (1,) * (x.dim() - 2)
+            + (self.n,)
+        )
+        a = self.a.view(*expand_shape)
+        b = self.b.view(*expand_shape)
+        hinges = (-x_exp + b).clamp_max(0.0)
+        return y + (a * hinges).sum(dim=-1)
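A note on the hunk above: when `channels` is omitted, APLU creates its `a` and `b` parameters lazily on the first forward pass, inferring the channel count from the input. A minimal usage sketch, assuming the class is re-exported from `homa.activations`; the batch shape and the import path are illustrative assumptions, not taken from the diff:

```py
import torch

from homa.activations import APLU  # assumed re-export path

act = APLU(n=2)                  # channels=None: no parameters allocated yet
x = torch.randn(8, 64, 16, 16)   # hypothetical NCHW batch with 64 channels
y = act(x)                       # first call allocates a and b with shape (64, 2)
print(act.a.shape, act.b.shape)  # torch.Size([64, 2]) torch.Size([64, 2])
```

Because the parameters only exist after the first forward pass, an optimizer built earlier would not see them; passing an explicit `channels` at construction avoids that.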
homa-0.3.18/src/homa/activations/ActivationFunction.py ADDED
@@ -0,0 +1,6 @@
+import torch
+
+
+class ActivationFunction(torch.nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
homa-0.3.18/src/homa/activations/AdaptiveActivationFunction.py ADDED
@@ -0,0 +1,15 @@
+import torch
+from .ActivationFunction import ActivationFunction
+
+
+class AdaptiveActivationFunction(ActivationFunction):
+    def __init__(self):
+        super().__init__()
+
+    def __repr__(self):
+        arguments_text = ""
+
+        if hasattr(self, "num_channels"):
+            arguments_text = f"channels={self.num_channels}"
+
+        return f"{self.__class__.__name__}({arguments_text})"
homa-0.3.18/src/homa/activations/BaseDLReLU.py ADDED
@@ -0,0 +1,34 @@
+import torch
+from .ActivationFunction import ActivationFunction
+
+
+class BaseDLReLU(ActivationFunction):
+    def __init__(self, a: float = 0.01, init_mse: float = 1.0, mode: str = "linear"):
+        super().__init__()
+        assert 0.0 < a < 1.0, "a must be in (0,1)"
+        assert mode in ("linear", "exp")
+        self.a = float(a)
+        self.mode = mode
+        self.register_buffer("prev_mse", torch.tensor(float(init_mse)))
+
+    @torch.no_grad()
+    def set_prev_mse(self, mse_value):
+        if isinstance(mse_value, torch.Tensor):
+            mse_value = float(mse_value.detach().cpu().item())
+        self.prev_mse.fill_(mse_value)
+
+    @torch.no_grad()
+    def update_from_loss(self, loss_tensor: torch.Tensor):
+        self.set_prev_mse(loss_tensor)
+
+    def forward(self, z: torch.Tensor, mse_prev: torch.Tensor | float | None = None):
+        b_t = (
+            self.prev_mse
+            if mse_prev is None
+            else (torch.as_tensor(mse_prev, device=z.device, dtype=z.dtype))
+        )
+        if self.mode == "linear":
+            slope = self.a * b_t
+        else:
+            slope = self.a * torch.exp(-b_t)
+        return torch.where(z >= 0, z, slope * z)
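For context on the hunk above: the negative-side slope is `a * prev_mse` (or `a * exp(-prev_mse)` in the exponential variant), and `prev_mse` is a buffer the caller refreshes between epochs. A minimal sketch of wiring it into a regression loop, assuming `DLReLU` is re-exported from `homa.activations`; the model, data, and optimizer are placeholders:

```py
import torch

from homa.activations import DLReLU  # assumed re-export path

act = DLReLU(a=0.01)  # slope on negative inputs = 0.01 * prev_mse
model = torch.nn.Sequential(torch.nn.Linear(8, 8), act, torch.nn.Linear(8, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = torch.nn.MSELoss()

for epoch in range(3):
    x, y = torch.randn(32, 8), torch.randn(32, 1)  # stand-in batch
    loss = criterion(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    act.update_from_loss(loss)  # the next epoch's forward reads this as prev_mse
```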
homa-0.3.18/src/homa/activations/CaLU.py ADDED
@@ -0,0 +1,13 @@
+import torch
+import math
+from .ActivationFunction import ActivationFunction
+
+
+class CaLU(ActivationFunction):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        a = torch.arctan(x) / math.pi
+        b = 0.5
+        return x * (a + b)
homa-0.3.18/src/homa/activations/DLReLU.py ADDED
@@ -0,0 +1,6 @@
+from .BaseDLReLU import BaseDLReLU
+
+
+class DLReLU(BaseDLReLU):
+    def __init__(self, a: float = 0.01, init_mse: float = 1.0):
+        super().__init__(a=a, init_mse=init_mse, mode="linear")
homa-0.3.18/src/homa/activations/ERF.py ADDED
@@ -0,0 +1,10 @@
+import torch
+
+
+class ERF(torch.nn.Module):
+    def __init__(self, alpha=1.0):
+        super().__init__()
+        self.alpha = torch.nn.Parameter(torch.tensor(alpha))
+
+    def forward(self, x: torch.Tensor):
+        return x * torch.erf(self.alpha * x)
homa-0.3.18/src/homa/activations/Elliot.py ADDED
@@ -0,0 +1,10 @@
+import torch
+from .ActivationFunction import ActivationFunction
+
+
+class Elliot(ActivationFunction):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return 0.5 + torch.div(0.5 * x, 1 + torch.abs(x))
homa-0.3.18/src/homa/activations/ExpExpish.py ADDED
@@ -0,0 +1,9 @@
+import torch
+
+
+class ExpExpish(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x * torch.exp(-torch.exp(-x))
homa-0.3.18/src/homa/activations/ExponentialDLReLU.py ADDED
@@ -0,0 +1,6 @@
+from .BaseDLReLU import BaseDLReLU
+
+
+class ExponentialDLReLU(BaseDLReLU):
+    def __init__(self, a: float = 0.01, init_mse: float = 1.0):
+        super().__init__(a=a, init_mse=init_mse, mode="exp")
homa-0.3.18/src/homa/activations/ExponentialSwish.py ADDED
@@ -0,0 +1,10 @@
+import torch
+from .ActivationFunction import ActivationFunction
+
+
+class ExponentialSwish(ActivationFunction):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.exp(-x) * torch.sigmoid(x)
homa-0.3.18/src/homa/activations/GCU.py ADDED
@@ -0,0 +1,9 @@
+import torch
+
+
+class GCU(torch.nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def forward(self, x: torch.Tensor):
+        return x * torch.cos(x)
homa-0.3.18/src/homa/activations/GaLU.py ADDED
@@ -0,0 +1,11 @@
+from .GaussianReLU import GaussianReLU
+
+
+class GaLU(GaussianReLU):
+    def __init__(
+        self,
+        channels: int | None = None,
+        max_input: float = 1.0,
+    ):
+        self.hats = [(2.0, 2.0), (1.0, 1.0), (3.0, 1.0)]
+        super().__init__(self.hats, channels=channels, max_input=max_input)
homa-0.3.18/src/homa/activations/GaussianReLU.py ADDED
@@ -0,0 +1,50 @@
+import torch
+from typing import Sequence, Tuple
+
+
+class GaussianReLU(torch.nn.Module):
+    def __init__(
+        self,
+        alphas_lambdas: Sequence[Tuple[float, float]],
+        channels: int | None = None,
+        max_input: float = 1.0,
+    ):
+        super().__init__()
+        self.M = float(max_input)
+        self.register_buffer(
+            "alphas", torch.tensor([a for a, _ in alphas_lambdas], dtype=torch.float32)
+        )
+        self.register_buffer(
+            "lambdas", torch.tensor([l for _, l in alphas_lambdas], dtype=torch.float32)
+        )
+        self.K = len(alphas_lambdas)
+
+        if channels is None:
+            self.register_parameter("c0", None)  # per-channel (PReLU slope)
+            self.register_parameter("c", None)  # (C, K) coefficients
+        else:
+            self._init_params(channels, None, None)
+
+    def _init_params(self, C: int, device, dtype):
+        self.c0 = torch.nn.Parameter(torch.zeros(C, device=device, dtype=dtype))
+        self.c = torch.nn.Parameter(torch.zeros(C, self.K, device=device, dtype=dtype))
+
+    def _expand_param(self, p: torch.Tensor, x: torch.Tensor, add_K: bool = False):
+        shape = (
+            (1, x.shape[1]) + (1,) * (x.dim() - 2) + ((p.shape[-1],) if add_K else ())
+        )
+        return p.view(shape)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.c0 is None or self.c is None:
+            self._init_params(x.shape[1], x.device, x.dtype)
+        c0 = self._expand_param(self.c0, x)
+        y = torch.nn.functional.relu(x) - c0 * torch.nn.functional.relu(-x)
+        a = self.alphas.to(x.device, x.dtype).view(*((1,) * x.dim()), -1)
+        l = self.lambdas.to(x.device, x.dtype).view(*((1,) * x.dim()), -1)
+        xE = x.unsqueeze(-1)
+        term1 = (l * self.M - (xE - a * self.M).abs()).clamp_min(0.0)
+        term2 = ((xE - a * self.M - 2 * l * self.M).abs() - l * self.M).clamp_max(0.0)
+        hats = term1 + term2
+        c = self._expand_param(self.c, x, add_K=True)  # (1,C,...,K)
+        return y + (c * hats).sum(dim=-1)
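To see what the two clamped terms in `forward` compute: for each `(alpha, lambda)` pair, `term1` is a positive triangular bump of height `lambda * M` centred at `alpha * M`, and `term2` is a matching negative bump centred two half-widths to the right, so each pair contributes one up-down wave whose per-channel coefficient is learned. A tiny numeric check, with the pair and `M` chosen purely for illustration:

```py
import torch

# One (alpha, lambda) pair, max_input M = 1.0, evaluated at a few points.
a, l, M = 1.0, 1.0, 1.0
x = torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0])

term1 = (l * M - (x - a * M).abs()).clamp_min(0.0)               # peak +1 at x = 1
term2 = ((x - a * M - 2 * l * M).abs() - l * M).clamp_max(0.0)   # trough -1 at x = 3
print(term1 + term2)  # tensor([ 0.,  1.,  0., -1.,  0.])
```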
homa-0.3.18/src/homa/activations/GeneralizedSwish.py ADDED
@@ -0,0 +1,10 @@
+import torch
+from .ActivationFunction import ActivationFunction
+
+
+class GeneralizedSwish(ActivationFunction):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return x * torch.sigmoid(torch.exp(-x))
homa-0.3.18/src/homa/activations/Gish.py ADDED
@@ -0,0 +1,11 @@
+import torch
+from .ActivationFunction import ActivationFunction
+
+
+class Gish(ActivationFunction):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        a = -torch.exp(x)
+        return x * torch.log(2 - torch.exp(a))
homa-0.3.18/src/homa/activations/LaLU.py ADDED
@@ -0,0 +1,11 @@
+import torch
+from .ActivationFunction import ActivationFunction
+
+
+class LaLU(ActivationFunction):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        phi_laplace = torch.where(x >= 0, 1 - 0.5 * torch.exp(-x), 0.5 * torch.exp(x))
+        return x * phi_laplace
homa-0.3.18/src/homa/activations/LogLogish.py ADDED
@@ -0,0 +1,10 @@
+import torch
+
+
+class LogLogish(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x: torch.Tensor):
+        a = -torch.exp(x)
+        return x * (1 - torch.exp(a))
homa-0.3.18/src/homa/activations/LogSigmoid.py ADDED
@@ -0,0 +1,10 @@
+import torch
+from .ActivationFunction import ActivationFunction
+
+
+class LogSigmoid(ActivationFunction):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x: torch.Tensor):
+        return torch.log(torch.sigmoid(x))
homa-0.3.18/src/homa/activations/Logish.py ADDED
@@ -0,0 +1,10 @@
+import torch
+from .ActivationFunction import ActivationFunction
+
+
+class Logish(ActivationFunction):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x * torch.log1p(torch.sigmoid(x))
homa-0.3.18/src/homa/activations/MeLU.py ADDED
@@ -0,0 +1,11 @@
+from .MexicanReLU import MexicanReLU
+
+
+class MeLU(MexicanReLU):
+    def __init__(self, channels: int | None = None, max_input: float = 1.0):
+        self.hats = [
+            (2.0, 2.0),
+            (1.0, 1.0),
+            (3.0, 1.0),
+        ]
+        super().__init__(self.hats, channels=channels, max_input=max_input)
homa-0.3.18/src/homa/activations/MexicanReLU.py ADDED
@@ -0,0 +1,49 @@
+import torch
+from typing import Sequence, Tuple
+
+
+class MexicanReLU(torch.nn.Module):
+    def __init__(
+        self,
+        alphas_lambdas: Sequence[Tuple[float, float]],
+        channels: int | None = None,
+        max_input: float = 1.0,
+    ):
+        super().__init__()
+        self.M = float(max_input)
+        self.register_buffer(
+            "alphas", torch.tensor([a for a, _ in alphas_lambdas], dtype=torch.float32)
+        )
+        self.register_buffer(
+            "lambdas", torch.tensor([l for _, l in alphas_lambdas], dtype=torch.float32)
+        )
+        self.K = len(alphas_lambdas)
+
+        if channels is None:
+            self.register_parameter("c0", None)  # PReLU negative slope (per-channel)
+            self.register_parameter("c", None)  # (C, K) coefficients
+        else:
+            self._init_params(channels, device=None, dtype=None)
+
+    def _init_params(self, C: int, device, dtype):
+        self.c0 = torch.nn.Parameter(torch.zeros(C, device=device, dtype=dtype))
+        self.c = torch.nn.Parameter(torch.zeros(C, self.K, device=device, dtype=dtype))
+
+    def _expand_param(self, p: torch.Tensor, x: torch.Tensor, n_extra: int = 0):
+        shape = (
+            (1, x.shape[1]) + (1,) * (x.dim() - 2) + ((p.shape[-1],) if n_extra else ())
+        )
+        return p.view(shape)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.c0 is None or self.c is None:
+            self._init_params(x.shape[1], x.device, x.dtype)
+        c0 = self._expand_param(self.c0, x)
+        y = torch.nn.functional.relu(x) - c0 * torch.nn.functional.relu(-x)
+        xE = x.unsqueeze(-1)
+        cE = self._expand_param(self.c, x, n_extra=1)
+        aE = self.alphas.to(x.device, x.dtype).view(*((1,) * x.dim()), -1)  # (..., K)
+        lE = self.lambdas.to(x.device, x.dtype).view(*((1,) * x.dim()), -1)  # (..., K)
+        hats = (lE * self.M - (xE - aE * self.M).abs()).clamp_min(0.0)
+        y = y + (cE * hats).sum(dim=-1)
+        return y
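MeLU earlier in this diff is just this module preloaded with three `(alpha, lambda)` hat pairs; a custom set can be passed directly. A short sketch, with hat values and shapes picked only for illustration and `homa.activations` assumed as the re-export path:

```py
import torch

from homa.activations import MexicanReLU  # assumed re-export path

# Two hats: each (alpha, lambda) pair adds a learnable triangular bump
# centred at alpha * max_input with half-width lambda * max_input.
act = MexicanReLU([(2.0, 2.0), (0.5, 0.5)], channels=32, max_input=1.0)
x = torch.randn(4, 32, 8, 8)
print(act(x).shape)  # torch.Size([4, 32, 8, 8])
```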
homa-0.3.18/src/homa/activations/MinSin.py ADDED
@@ -0,0 +1,10 @@
+import torch
+from .ActivationFunction import ActivationFunction
+
+
+class MinSin(ActivationFunction):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.min(x, torch.sin(x))
homa-0.3.18/src/homa/activations/NReLU.py ADDED
@@ -0,0 +1,12 @@
+import torch
+from .ActivationFunction import ActivationFunction
+
+
+class NReLU(ActivationFunction):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x: torch.Tensor):
+        sigma = x.std(unbiased=False)
+        a = torch.randn_like(x) * sigma
+        return torch.where(x >= 0, x + a, 0)
homa-0.3.18/src/homa/activations/NoisyReLU.py ADDED
@@ -0,0 +1,6 @@
+from .NReLU import NReLU
+
+
+class NoisyReLU(NReLU):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
homa-0.3.18/src/homa/activations/PLogish.py ADDED
@@ -0,0 +1,6 @@
+from .ParametricLogish import ParametricLogish
+
+
+class PLogish(ParametricLogish):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
homa-0.3.18/src/homa/activations/ParametricLogish.py ADDED
@@ -0,0 +1,13 @@
+import torch
+
+
+class ParametricLogish(torch.nn.Module):
+    def __init__(self, alpha: float = 1.0, beta: float = 10.0):
+        super().__init__()
+        self.alpha = alpha
+        self.beta = beta
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        a = torch.sigmoid(self.beta * x)
+        b = torch.log(1 + a)
+        return self.alpha * x * b
homa-0.3.18/src/homa/activations/Phish.py ADDED
@@ -0,0 +1,11 @@
+import torch
+from .ActivationFunction import ActivationFunction
+
+
+class Phish(ActivationFunction):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        a = torch.nn.functional.gelu(x)
+        return x * torch.tanh(a)
homa-0.3.18/src/homa/activations/RReLU.py ADDED
@@ -0,0 +1,16 @@
+import torch
+from .ActivationFunction import ActivationFunction
+
+
+class RReLU(ActivationFunction):
+    def __init__(self, lower: int = 3, upper: int = 8):
+        super(RReLU, self).__init__()
+        self.lower = lower
+        self.upper = upper
+
+    def forward(self, x):
+        if self.training:
+            a = torch.empty_like(x).uniform_(self.lower, self.upper)
+        else:
+            a = (self.lower + self.upper) / 2.0
+        return torch.where(x >= 0, x, x / a)
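One behavioural detail of the hunk above: in training mode each element gets a freshly sampled divisor from `[lower, upper)`, while in eval mode the fixed midpoint `(lower + upper) / 2` is used, so results change with `.train()` / `.eval()`. A quick sketch; the input values are illustrative and the re-export path is assumed:

```py
import torch

from homa.activations import RReLU  # assumed re-export path

act = RReLU(lower=3, upper=8)
x = torch.tensor([[-6.0, 6.0]])

act.train()
print(act(x))  # negative entry divided by a random draw from [3, 8)

act.eval()
print(act(x))  # tensor([[-1.0909,  6.0000]]) since (3 + 8) / 2 = 5.5
```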
homa-0.3.18/src/homa/activations/RandomizedSlopedReLU.py ADDED
@@ -0,0 +1,7 @@
+import random
+from .SlopedReLU import SlopedReLU
+
+
+class RandomizedSlopedReLU(SlopedReLU):
+    def __init__(self):
+        super().__init__(alpha=random.uniform(1, 10))
homa-0.3.18/src/homa/activations/SGELU.py ADDED
@@ -0,0 +1,12 @@
+import torch
+import math
+from .ActivationFunction import ActivationFunction
+
+
+class SGELU(ActivationFunction):
+    def __init__(self, alpha: float = 0.1, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.alpha = alpha
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.alpha * x * torch.erf(x / math.sqrt(2))