homa 0.2.0__py3-none-any.whl → 0.2.95__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (90)
  1. homa/activations/APLU.py +49 -0
  2. homa/activations/ActivationFunction.py +6 -0
  3. homa/activations/AdaptiveActivationFunction.py +15 -0
  4. homa/activations/BaseDLReLU.py +34 -0
  5. homa/activations/CaLU.py +13 -0
  6. homa/activations/DLReLU.py +6 -0
  7. homa/activations/ERF.py +10 -0
  8. homa/activations/Elliot.py +10 -0
  9. homa/activations/ExpExpish.py +9 -0
  10. homa/activations/ExponentialDLReLU.py +6 -0
  11. homa/activations/ExponentialSwish.py +10 -0
  12. homa/activations/GCU.py +9 -0
  13. homa/activations/GaLU.py +11 -0
  14. homa/activations/GaussianReLU.py +50 -0
  15. homa/activations/GeneralizedSwish.py +10 -0
  16. homa/activations/Gish.py +11 -0
  17. homa/activations/LaLU.py +11 -0
  18. homa/activations/LogLogish.py +10 -0
  19. homa/activations/LogSigmoid.py +10 -0
  20. homa/activations/Logish.py +10 -0
  21. homa/activations/MeLU.py +11 -0
  22. homa/activations/MexicanReLU.py +49 -0
  23. homa/activations/MinSin.py +10 -0
  24. homa/activations/NReLU.py +12 -0
  25. homa/activations/NoisyReLU.py +6 -0
  26. homa/activations/PLogish.py +6 -0
  27. homa/activations/ParametricLogish.py +13 -0
  28. homa/activations/Phish.py +11 -0
  29. homa/activations/RReLU.py +16 -0
  30. homa/activations/RandomizedSlopedReLU.py +7 -0
  31. homa/activations/SGELU.py +12 -0
  32. homa/activations/SReLU.py +37 -0
  33. homa/activations/SelfArctan.py +9 -0
  34. homa/activations/ShiftedReLU.py +10 -0
  35. homa/activations/SigmoidDerivative.py +10 -0
  36. homa/activations/SineReLU.py +11 -0
  37. homa/activations/SlopedReLU.py +13 -0
  38. homa/activations/SmallGaLU.py +11 -0
  39. homa/activations/Smish.py +9 -0
  40. homa/activations/SoftsignRReLU.py +17 -0
  41. homa/activations/Suish.py +11 -0
  42. homa/activations/TBSReLU.py +13 -0
  43. homa/activations/TSReLU.py +10 -0
  44. homa/activations/TangentBipolarSigmoidReLU.py +6 -0
  45. homa/activations/TangentSigmoidReLU.py +6 -0
  46. homa/activations/TeLU.py +9 -0
  47. homa/activations/TripleStateSwish.py +15 -0
  48. homa/activations/WideMeLU.py +15 -0
  49. homa/activations/__init__.py +49 -2
  50. homa/activations/learnable/AOAF.py +16 -0
  51. homa/activations/learnable/AReLU.py +19 -0
  52. homa/activations/learnable/DPReLU.py +16 -0
  53. homa/activations/learnable/DualLine.py +18 -0
  54. homa/activations/learnable/FReLU.py +14 -0
  55. homa/activations/learnable/LeLeLU.py +14 -0
  56. homa/activations/learnable/PERU.py +16 -0
  57. homa/activations/learnable/PiLU.py +18 -0
  58. homa/activations/learnable/ShiLU.py +16 -0
  59. homa/activations/learnable/StarReLU.py +16 -0
  60. homa/activations/learnable/__init__.py +10 -0
  61. homa/activations/learnable/concerns/ChannelBased.py +38 -0
  62. homa/activations/learnable/concerns/__init__.py +1 -0
  63. homa/cli/Commands/Command.py +2 -0
  64. homa/cli/Commands/InitCommand.py +34 -0
  65. homa/cli/Commands/__init__.py +2 -0
  66. homa/cli/HomaCommand.py +4 -0
  67. homa/ensemble/concerns/StoresModels.py +9 -3
  68. homa/vision/{ClassificationModel.py → Classifier.py} +1 -1
  69. homa/vision/Resnet.py +2 -2
  70. homa/vision/StochasticClassifier.py +23 -22
  71. homa/vision/StochasticSwin.py +3 -1
  72. homa/vision/Swin.py +4 -3
  73. homa/vision/__init__.py +2 -1
  74. homa/vision/utils.py +12 -0
  75. {homa-0.2.0.dist-info → homa-0.2.95.dist-info}/METADATA +1 -1
  76. homa-0.2.95.dist-info/RECORD +113 -0
  77. homa/activations/classes/APLU.py +0 -86
  78. homa/activations/classes/GALU.py +0 -67
  79. homa/activations/classes/MELU.py +0 -70
  80. homa/activations/classes/PDELU.py +0 -54
  81. homa/activations/classes/SReLU.py +0 -69
  82. homa/activations/classes/SmallGALU.py +0 -58
  83. homa/activations/classes/WideMELU.py +0 -90
  84. homa/activations/classes/__init__.py +0 -7
  85. homa/activations/utils.py +0 -27
  86. homa/vision/StochasticResnet.py +0 -9
  87. homa-0.2.0.dist-info/RECORD +0 -58
  88. {homa-0.2.0.dist-info → homa-0.2.95.dist-info}/WHEEL +0 -0
  89. {homa-0.2.0.dist-info → homa-0.2.95.dist-info}/entry_points.txt +0 -0
  90. {homa-0.2.0.dist-info → homa-0.2.95.dist-info}/top_level.txt +0 -0
homa-0.2.95.dist-info/RECORD ADDED
@@ -0,0 +1,113 @@
1
+ homa/__init__.py,sha256=NBYFKizG8UASiz5HLsEBqzXNGlWr78xm4sLr5hxKvjU,46
2
+ homa/device.py,sha256=9kKXfpYfnEk2cFQWPfcJrVloHgC_SSbP4I8IRY9TYk4,343
3
+ homa/settings.py,sha256=CPZDPvs1380O7SY7FcSKol8kBVFVVYFgSJl3YEyJuZ0,263
4
+ homa/utils.py,sha256=dPp6TItJwWxBqxmkMzUuCtX_BzdPT-kMOZyXRGVMCbQ,70
5
+ homa/activations/APLU.py,sha256=cUf6LUjY8TewXe_V1avO_7IcOtY66Hd6Dyk_1K4R3Ms,1555
6
+ homa/activations/ActivationFunction.py,sha256=XUw7Pa5E-CPG6rPL8Us_pDH7xCZqY0c2P9xtnJMyX44,141
7
+ homa/activations/AdaptiveActivationFunction.py,sha256=p_bqAq7527UOhVm47kdUtgdC1DlApxgiLOA4ZPBFdCE,386
8
+ homa/activations/BaseDLReLU.py,sha256=iRmDhhbFaO8N9G8u5M01s8-y-09t7poP96oA6uQkVq8,1186
9
+ homa/activations/CaLU.py,sha256=n0drKwp4GstHql69p4S58KeVctdaQ1B5oK_AIoI_okk,331
10
+ homa/activations/DLReLU.py,sha256=Q8l2zpR5q_tSgfgbz90uDXbXMbBT3b_7BWKw6JbtpQE,191
11
+ homa/activations/ERF.py,sha256=tDgHbo7UNFU93XPlcQCBRRxPMksr-FOE19mlsqfzmU8,252
12
+ homa/activations/Elliot.py,sha256=RDxERH9vFh6FYwtZXKHMDmLVG2ia1UfOoW18Gm2_8hM,298
13
+ homa/activations/ExpExpish.py,sha256=iq_uOmmV9EIz2eKowEy7SCeW-OMgGcEeMcivTnPc-Y0,202
14
+ homa/activations/ExponentialDLReLU.py,sha256=aVtah3c4sokB-aSdbVa5F_uq06IyXHwovnHtXlKGYlw,199
15
+ homa/activations/ExponentialSwish.py,sha256=nJtGu1TRHa2GSQ35w2MN0HEWzFogVvA9R2pGEkFvJX4,266
16
+ homa/activations/GCU.py,sha256=hXwty6WPovnhPGAxQDd4bIixujdoMOORN-77imVri7s,199
17
+ homa/activations/GaLU.py,sha256=5QHnHsUsLAy28s-LTxtwRN-t1hO1tg9xtWmkzE1T7Ck,308
18
+ homa/activations/GaussianReLU.py,sha256=ufNeVnod6dxkPLmdd9ye-xt0SIWap2dehX14_YxSZVM,2051
19
+ homa/activations/GeneralizedSwish.py,sha256=zv6CX83cOTVnN0yoCIXIvIgkjXLnmm_T_LsvyoN7lOY,236
20
+ homa/activations/Gish.py,sha256=Xohk5tTmeGTmQ4PXtHF5sPBDikmoNTjdEJzy2KPDmOI,249
21
+ homa/activations/LaLU.py,sha256=UiulXzSTmnoU_Gp8qKigFoL6efonqbldUlsBBlm9mB8,356
22
+ homa/activations/LogLogish.py,sha256=lfVRNhnDGbYYakTsUmePmr5azkzz_NQwEy6NvSSD-Do,205
23
+ homa/activations/LogSigmoid.py,sha256=PUvr84dRRd6L-VZ_9UeWAN9lhUFr2Otj8VrAIQ3eOEM,239
24
+ homa/activations/Logish.py,sha256=CnL-10b76C2EaDm56N4n2GYCaYJUKl_k7H82UBJI5to,257
25
+ homa/activations/MeLU.py,sha256=f13h2AAQCwp9soR3RWbMAA4Bl38oqRdBAsdzh6Bf4k8,321
26
+ homa/activations/MexicanReLU.py,sha256=vfDa1lWI-PgY4ztDY34aeBMaJ2rOyAYt5ifZBG0DS0c,1946
27
+ homa/activations/MinSin.py,sha256=JzQsmuffRAGGcD40nlz2ZnOGhQMEU0JYBIeFHIC1qcE,250
28
+ homa/activations/NReLU.py,sha256=mX4B2OXw28M8zyd6RpkaSoOCGZuB3FaksO4oFyr3YD8,314
29
+ homa/activations/NoisyReLU.py,sha256=2YFkOS_h8EijvCugYQTfq_gQg5uNEkcypcm0iDgEHIg,134
30
+ homa/activations/PLogish.py,sha256=ia_V0xewAS6mmX3G8JNQTxWl66ea8P3MuRLQkEPv-I8,165
31
+ homa/activations/ParametricLogish.py,sha256=grCGG61xTDytA4iOK3kS70V9m2bYoiSmuilJV3U8vIw,360
32
+ homa/activations/Phish.py,sha256=CLAV1fLHRAq-GxBut-_FsSYJRMlk5sOFVlcXs3G3w9c,280
33
+ homa/activations/RReLU.py,sha256=ILpkmoWk8WatXusrPqSLu15xMWQALwRQVZjhzwmw1PM,476
34
+ homa/activations/RandomizedSlopedReLU.py,sha256=O20XX3vRRmkERxwhLSNgue-fn0qSRoF7rlIN1LSWlyI,169
35
+ homa/activations/SGELU.py,sha256=AaNmXRoFQ68Xsgt4sSWMZxnSCTR5OD5ZEuqxxg1mvfg,358
36
+ homa/activations/SReLU.py,sha256=xyChK3G2HPpM7C8icQNfMzrOm142boDLY31n9yXqPtg,1472
37
+ homa/activations/SelfArctan.py,sha256=Sq3yWGXjxdP32J-rSZ38BQ5S_XErr5H1ZyPsMF1VKfI,193
38
+ homa/activations/ShiftedReLU.py,sha256=JVsf2F6C13PRICjnVOSVEsx9IoQ9rcM2TFn55DguZQs,229
39
+ homa/activations/SigmoidDerivative.py,sha256=4PPT-QX4MW9ySKU4Qv9K-y--lxlqFxvKVviC2S3e6Z0,274
40
+ homa/activations/SineReLU.py,sha256=gzYF1ZEZAFYmUuABWJf18LIer1oPAS38i_5NLcIhP-I,357
41
+ homa/activations/SlopedReLU.py,sha256=j6YfM4msg6It-ANbpMzEaXkiHvgEdhFNFbB6NkY6KpE,421
42
+ homa/activations/SmallGaLU.py,sha256=ERrK-g3QMZTNFDzUyiSLAovymEpV5h1x1696CN5K4Zg,289
43
+ homa/activations/Smish.py,sha256=hsr5FS4KywsCmsuFUKP-4pKoXkJK0hhRVDleq_CFGX0,198
44
+ homa/activations/SoftsignRReLU.py,sha256=bBSjYDLUVKxXPyaJExYXndEO3oORnP3M6NKoU-hiCCQ,564
45
+ homa/activations/Suish.py,sha256=I459CV24NV1JlLbko4oHUOh98fxoLaM-2SH71pVMcwA,279
46
+ homa/activations/TBSReLU.py,sha256=ZfYY_M6msDimJAOHr1HyrG1HHnWiJ7hnZ5hWjCFPecU,320
47
+ homa/activations/TSReLU.py,sha256=gbU0Q7zhf3X6oWvKUSry6sVdRhuaxIQt8keFH3WsxV8,256
48
+ homa/activations/TangentBipolarSigmoidReLU.py,sha256=YtrFHkFbEbx7aeIpIRc9TLCxhpveUFSgAnvTKaKLZ4E,156
49
+ homa/activations/TangentSigmoidReLU.py,sha256=C47UK6ADWsG2ueaZe9FUt-sPBzeuBLkiNjpkDZOCYGc,146
50
+ homa/activations/TeLU.py,sha256=qU5x0EskjQs6d5rCtbL91C6cMAm8vjDnjQNMX0LcEt8,180
51
+ homa/activations/TripleStateSwish.py,sha256=UG5BGY29wUEJaryClB2rDM90s0jt5vMJF9Kv-5M4Rgo,507
52
+ homa/activations/WideMeLU.py,sha256=ieJjTjnK9JJtApPFGpmTynu3G8YlyH5jw6qnhkJkStI,421
53
+ homa/activations/__init__.py,sha256=2GHNqrOp6WoLAtFFJcSj6j4GP-W8-YAYRZGX9vZbcmU,1659
54
+ homa/activations/learnable/AOAF.py,sha256=1ArhgpI6PfCRePgvFq8VqKDQ9rDMHZb0bm6g4Tiz13s,510
55
+ homa/activations/learnable/AReLU.py,sha256=Pfyv_7EEwGgW4_UyKc8CiSg7lhTcO7LZ7uIUeVQWLpA,737
56
+ homa/activations/learnable/DPReLU.py,sha256=xQhYTJ0-mfRGdld950xoTh8c9O08WIY50K0FjPtVVFs,507
57
+ homa/activations/learnable/DualLine.py,sha256=cgqyE7dVqXflT8ulCuOyKQQa09FYSj8vJkeVUEOaeIU,600
58
+ homa/activations/learnable/FReLU.py,sha256=qQ8GjjWWGeoE6qW9tw49mZPs29app0QK1AFOuMc5ASU,413
59
+ homa/activations/learnable/LeLeLU.py,sha256=ya2m60QRcpVlTwMejJTgMTxM3RRHF0RgNe72_EdD1-U,425
60
+ homa/activations/learnable/PERU.py,sha256=y2OxRLIA1HTUnFyRHs0zgLhLMJhQz9Q4F6QrqBSkQ00,513
61
+ homa/activations/learnable/PiLU.py,sha256=w7LkBBs_hr07pvizUie5Z49UkHg3O8LHA-wFK4hbnjE,612
62
+ homa/activations/learnable/ShiLU.py,sha256=35VC1pCAWMaxHKWYBeXd2DrXn1tepvQaT7a-KwoNdHY,475
63
+ homa/activations/learnable/StarReLU.py,sha256=hrscp-A0HnIvebFPLGr86K5Uf_U--EWtpNDqdNgonA0,485
64
+ homa/activations/learnable/__init__.py,sha256=yDzcgM_n5sNEU0kz9I0aVgGihpw_2RvtkCCylaTCPEQ,260
65
+ homa/activations/learnable/concerns/ChannelBased.py,sha256=pSKnWOKVOdb0GoiBobSSUANaZPGNwT9rxBnJUpZ9Eac,1206
66
+ homa/activations/learnable/concerns/__init__.py,sha256=CubRRYQEQMAK2-igsYKD8tcyesPOYoZYF_IlHzRZXi4,39
67
+ homa/cli/HomaCommand.py,sha256=w-Dg6dFpoXbQx2tvWSLdND2pdhqB2cPSORyi4MfY8XY,307
68
+ homa/cli/Commands/Command.py,sha256=DnmsEwpaxdQaLjzyYBO7qtIQTLwYzyhJS31YazA1IHg,24
69
+ homa/cli/Commands/InitCommand.py,sha256=3whh2mWLuevXpUyRpDEMbo_KNeAIdO2aLMFnC2nz_0c,1159
70
+ homa/cli/Commands/__init__.py,sha256=PYKkcG06R5LqLnp2x8otuimzRpL4oMbziL3xEMkCffc,66
71
+ homa/cli/namespaces/CacheNamespace.py,sha256=QXGljzj287stzTx0y_MXnqvCgPLqd7WjSPop2WDe14E,784
72
+ homa/cli/namespaces/MakeNamespace.py,sha256=5G6LHk3lDkXROz7uq4jYE0DyO_V7JvnhJ33IFCiqYro,590
73
+ homa/cli/namespaces/__init__.py,sha256=zAKUGPH4wcacxfH5Qvidp-uOuHdfzhan6kvVI6eMKA8,84
74
+ homa/ensemble/Ensemble.py,sha256=GNkXEV7Nli8lHSTQ3qTTCTeSBwST1PLZS5wxpKpeC5U,290
75
+ homa/ensemble/__init__.py,sha256=1pk2W-NbgfDFh9WLKZVLUk2E3PTjVZ5Bap9dQEnrs9o,31
76
+ homa/ensemble/concerns/CalculatesMetricNecessities.py,sha256=QccROg_FOp_X2T_lZDg8p1DMZhPYdO-7aEdnebRXMsY,825
77
+ homa/ensemble/concerns/PredictsProbabilities.py,sha256=7rmI66DzE7-QGoJgZEk-9fu5YQvJW-4ZnMn_dWEEhqU,440
78
+ homa/ensemble/concerns/ReportsClassificationMetrics.py,sha256=bg__cdCKp2U1H9qN1aOJH4BoX98oIvt8XaPDGApJhSM,395
79
+ homa/ensemble/concerns/ReportsEnsembleAccuracy.py,sha256=AX5X3VGOm7DfdonW0N7FFgUwEr7wnsojRSVEULEii7c,380
80
+ homa/ensemble/concerns/ReportsEnsembleF1.py,sha256=hdtdCQrWaFJNUn1KP9cAmi_q_EA4FYnpkBMlYLjzRZg,296
81
+ homa/ensemble/concerns/ReportsEnsembleKappa.py,sha256=ZRbtrFCTD84EDql6ZL1xeWtTLFxpO5Y5tQaUlR6_0jw,300
82
+ homa/ensemble/concerns/ReportsLogits.py,sha256=vTGuC9NR4rno3Mkbm0MhL8f7YopuCErGyjIorxamKTM,461
83
+ homa/ensemble/concerns/ReportsSize.py,sha256=S7lo_Wu6rDnuqyAcv6AI6jspaBhcpfsirpp9RVD8c20,238
84
+ homa/ensemble/concerns/StoresModels.py,sha256=tfql0sr_Y27cHEJxZkc9AUQYlQRe0HtbN4JD940lKqY,1001
85
+ homa/ensemble/concerns/__init__.py,sha256=X0F_b2Jsv0XpiNhYwJsl-dfPsBOdEeW53LQPE4xQD0w,479
86
+ homa/loss/LogitNormLoss.py,sha256=LJMzRA1WoJ7aDYTV-FYGhgo8DMkcpv7e8_74qiJ4zT8,386
87
+ homa/loss/Loss.py,sha256=COUr_idShYgAP8xKCxcaXbyUyAoJg7IOON0ARTQykmQ,21
88
+ homa/loss/__init__.py,sha256=4mPVzme2_-M64bgBu1cANIfBFAL0voa5I71-ceMr_qk,64
89
+ homa/torch/__init__.py,sha256=HTxCVaw1TLgpHMH8guB3hHYQ80cX6_fSEoPT_hz2Y8w,23
90
+ homa/torch/helpers.py,sha256=CLbTCXRrroM0n4PfM-K_xFavs4dCZJEu_L7hdgb1DCI,134
91
+ homa/vision/Classifier.py,sha256=bAypqREQVuPamnc8hpbLCwmW9Uly3T1rvrlbMxXp1eA,61
92
+ homa/vision/Model.py,sha256=JIeVpHJwirHfsDfYYbLsu0kt7bGf4nhMQGIOagUDKw4,22
93
+ homa/vision/Resnet.py,sha256=Uitf58bEzIKkZd-F4FTvJ8nmhoFHlzZjJTvBPXEt2Iw,513
94
+ homa/vision/StochasticClassifier.py,sha256=6-o0TaH4iWXiPFefL7DOdLr3ZrTnjnJ9PIgQLlygN8w,497
95
+ homa/vision/StochasticSwin.py,sha256=FggzfaVYrP4fnjAFcdMpDozwQHc7CQhl3iRw78oQh0o,425
96
+ homa/vision/Swin.py,sha256=W3XbfUTrjaIhMH8fI_whPP6XO9fVA2R34LlGfQ1hoyo,508
97
+ homa/vision/__init__.py,sha256=w5OkcmdU6Ik5wHIJzeV1Z2UElQtvCsUZks1Q-xViSVg,153
98
+ homa/vision/utils.py,sha256=WB2b7eMDaf6UO3SuS7cB6IJk-9NRQesLavuzWUZRZyg,389
99
+ homa/vision/concerns/HasLabels.py,sha256=fM6nHLeQaEaWDlV6R8NQ5hgOSiwspPxOIwj-nvYXbP0,321
100
+ homa/vision/concerns/HasLogits.py,sha256=oStX4NCV7zwxI7Vj23M8wQSlY1xoSmAYJ_6cBNJpVCk,290
101
+ homa/vision/concerns/HasProbabilities.py,sha256=m1_ObS2BNYO-WVCNVMiHXzC3XAsyb88_0N4BWVDwCw0,221
102
+ homa/vision/concerns/ReportsAccuracy.py,sha256=DD0YTr5i8JMllIJTQn88Dn711yjZ2uiecaTi7WqpOEw,986
103
+ homa/vision/concerns/ReportsMetrics.py,sha256=93Hw_JBUbwfkrJNJA1xFSQ4cqRwzbSv4nPU524PGF6I,169
104
+ homa/vision/concerns/Trainable.py,sha256=SRCW3XpG9_DQgubyqhALlYDHwAWNzVVFjshUv1ecuEQ,988
105
+ homa/vision/concerns/__init__.py,sha256=mrw1YvN-GpQPvMwDF00KxnFkksPKo23RWM4KRioURsg,234
106
+ homa/vision/modules/ResnetModule.py,sha256=eFudBnILD6OmgQtcW_CQQ8aZ62NEa4HyZ15-lobTtt0,712
107
+ homa/vision/modules/SwinModule.py,sha256=h7wq1YdKoN6-7C3FVFA0bpkAET_30002iTRbjZxziFQ,714
108
+ homa/vision/modules/__init__.py,sha256=zVMYB9IAO_xZylC1-N3p8ymHgEkAE2sBbuVz8K5Y1kk,74
109
+ homa-0.2.95.dist-info/METADATA,sha256=Tt_dtrzp2O9_bhBkhZAjId_k_kRQI6z9ze6aQJhId_s,1760
110
+ homa-0.2.95.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
111
+ homa-0.2.95.dist-info/entry_points.txt,sha256=tJZzjs-f2QvFe3ES8Qta8IE5sAbeE8-cyZ_UtbgqG4s,51
112
+ homa-0.2.95.dist-info/top_level.txt,sha256=tmOfy2tuaAwc3W5-i6j61_vYJsXgR4ivBWkhJ3ZtJDc,5
113
+ homa-0.2.95.dist-info/RECORD,,
homa/activations/classes/APLU.py DELETED
@@ -1,86 +0,0 @@
1
- import torch
2
- from torch import nn
3
- from torch.nn.parameter import Parameter, UninitializedParameter
4
- import torch.nn.functional as F
5
-
6
-
7
- class APLU(nn.Module):
8
- def __init__(self, max_input: float = 1.0):
9
- super().__init__()
10
- self.max_input = float(max_input)
11
- self.alpha = UninitializedParameter()
12
- self.beta = UninitializedParameter()
13
- self.gamma = UninitializedParameter()
14
- self.xi = UninitializedParameter()
15
- self.psi = UninitializedParameter()
16
- self.mu = UninitializedParameter()
17
- self._num_channels = None
18
-
19
- def _initialize_parameters(self, x: torch.Tensor):
20
- if x.ndim < 2:
21
- raise ValueError(
22
- f"Input tensor must have at least 2 dimensions (N, C), but got shape {tuple(x.shape)}"
23
- )
24
-
25
- channels = int(x.shape[1])
26
- self._num_channels = channels
27
- param_shape = [1] * x.ndim
28
- param_shape[1] = channels
29
-
30
- with torch.no_grad():
31
- self.alpha = Parameter(
32
- torch.zeros(param_shape, dtype=x.dtype, device=x.device)
33
- )
34
- self.beta = Parameter(
35
- torch.zeros(param_shape, dtype=x.dtype, device=x.device)
36
- )
37
- self.gamma = Parameter(
38
- torch.zeros(param_shape, dtype=x.dtype, device=x.device)
39
- )
40
- self.xi = Parameter(
41
- torch.empty(param_shape, dtype=x.dtype, device=x.device).uniform_(
42
- 0.0, self.max_input
43
- )
44
- )
45
- self.psi = Parameter(
46
- torch.empty(param_shape, dtype=x.dtype, device=x.device).uniform_(
47
- 0.0, self.max_input
48
- )
49
- )
50
- self.mu = Parameter(
51
- torch.empty(param_shape, dtype=x.dtype, device=x.device).uniform_(
52
- 0.0, self.max_input
53
- )
54
- )
55
-
56
- def reset_parameters(self):
57
- if isinstance(self.alpha, UninitializedParameter):
58
- return
59
-
60
- with torch.no_grad():
61
- self.alpha.zero_()
62
- self.beta.zero_()
63
- self.gamma.zero_()
64
- self.xi.uniform_(0.0, self.max_input)
65
- self.psi.uniform_(0.0, self.max_input)
66
- self.mu.uniform_(0.0, self.max_input)
67
-
68
- def forward(self, x: torch.Tensor):
69
- if isinstance(self.alpha, UninitializedParameter):
70
- self._initialize_parameters(x)
71
-
72
- if x.ndim < 2:
73
- raise ValueError(
74
- f"Input tensor must have at least 2 dimensions (N, C), but got shape {tuple(x.shape)}"
75
- )
76
- if self._num_channels is not None and x.shape[1] != self._num_channels:
77
- raise RuntimeError(
78
- f"APLU was initialized with C={self._num_channels} but got C={x.shape[1]}. "
79
- "Create a new APLU for a different channel size."
80
- )
81
-
82
- a = F.relu(x)
83
- b = self.alpha * F.relu(-x + self.xi)
84
- c = self.beta * F.relu(-x + self.psi)
85
- d = self.gamma * F.relu(-x + self.mu)
86
- return a + b + c + d
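For reference, the removed APLU implementation above creates its six per-channel parameters lazily on the first forward pass, and because alpha, beta and gamma start at zero the first call reduces to a plain ReLU. A minimal usage sketch, assuming the class above and the pre-0.2.95 import path homa.activations.classes:

import torch
from homa.activations.classes import APLU  # pre-0.2.95 import path

act = APLU(max_input=1.0)
x = torch.randn(8, 16, 32, 32)   # (N, C, H, W); any input with at least 2 dims is accepted
y = act(x)                       # first call materializes (1, 16, 1, 1) parameters on x's device/dtype
assert y.shape == x.shape
assert torch.allclose(y, torch.relu(x))  # alpha/beta/gamma are zero-initialized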
homa/activations/classes/GALU.py DELETED
@@ -1,67 +0,0 @@
1
- import torch
2
- from torch import nn
3
- from torch.nn.parameter import Parameter, UninitializedParameter
4
- import torch.nn.functional as F
5
-
6
-
7
- class GALU(nn.Module):
8
- def __init__(self, max_input: float = 1.0):
9
- super().__init__()
10
- if max_input <= 0:
11
- raise ValueError("max_input must be positive.")
12
- self.max_input = float(max_input)
13
- self.alpha: torch.Tensor = UninitializedParameter()
14
- self.beta: torch.Tensor = UninitializedParameter()
15
- self.gamma: torch.Tensor = UninitializedParameter()
16
- self.delta: torch.Tensor = UninitializedParameter()
17
-
18
- def _initialize_parameters(self, x: torch.Tensor):
19
- if x.ndim < 2:
20
- raise ValueError(
21
- f"Input tensor must have at least 2 dimensions (N, C), but got shape {tuple(x.shape)}"
22
- )
23
- param_shape = [1] * x.ndim
24
- param_shape[1] = int(x.shape[1])
25
- zeros = torch.zeros(param_shape, dtype=x.dtype, device=x.device)
26
- with torch.no_grad():
27
- for name in ("alpha", "beta", "gamma", "delta"):
28
- setattr(self, name, Parameter(zeros.clone()))
29
-
30
- def reset_parameters(self):
31
- for name in ("alpha", "beta", "gamma", "delta"):
32
- p = getattr(self, name)
33
- if not isinstance(p, UninitializedParameter):
34
- with torch.no_grad():
35
- p.zero_()
36
-
37
- def forward(self, x: torch.Tensor):
38
- if isinstance(self.alpha, UninitializedParameter):
39
- self._initialize_parameters(x)
40
-
41
- if x.ndim < 2:
42
- raise ValueError(
43
- f"Input tensor must have at least 2 dimensions (N, C), but got shape {tuple(x.shape)}"
44
- )
45
- if not isinstance(self.alpha, UninitializedParameter) and x.shape[1] != self.alpha.shape[1]:
46
- raise RuntimeError(
47
- f"GALU was initialized with C={self.alpha.shape[1]} but got C={x.shape[1]}. "
48
- "Create a new GALU for a different channel size."
49
- )
50
-
51
- x_norm = x / self.max_input
52
- zero = x.new_zeros(1)
53
- part_prelu = F.relu(x_norm) + self.alpha * torch.minimum(x_norm, zero)
54
- part_beta = self.beta * (
55
- F.relu(1.0 - torch.abs(x_norm - 1.0))
56
- + torch.minimum(torch.abs(x_norm - 3.0) - 1.0, zero)
57
- )
58
- part_gamma = self.gamma * (
59
- F.relu(0.5 - torch.abs(x_norm - 0.5))
60
- + torch.minimum(torch.abs(x_norm - 1.5) - 0.5, zero)
61
- )
62
- part_delta = self.delta * (
63
- F.relu(0.5 - torch.abs(x_norm - 2.5))
64
- + torch.minimum(torch.abs(x_norm - 3.5) - 0.5, zero)
65
- )
66
- z = part_prelu + part_beta + part_gamma + part_delta
67
- return z * self.max_input
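Because GALU's parameters start as UninitializedParameter placeholders and are only created inside the first forward call, an optimizer should be built from module.parameters() after a dummy forward pass. A short sketch of that ordering, assuming the class above and the pre-0.2.95 import path:

import torch
from torch import optim
from homa.activations.classes import GALU  # pre-0.2.95 import path

galu = GALU(max_input=1.0)
galu(torch.randn(2, 8, 4, 4))                      # materializes alpha/beta/gamma/delta with shape (1, 8, 1, 1)
optimizer = optim.SGD(galu.parameters(), lr=1e-2)  # safe: the per-channel parameters now exist with fixed shapes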
homa/activations/classes/MELU.py DELETED
@@ -1,70 +0,0 @@
1
- import torch
2
- from torch import nn
3
- import torch.nn.functional as F
4
-
5
-
6
- class MELU(nn.Module):
7
- def __init__(self, maxInput: float = 1.0):
8
- super().__init__()
9
- self.maxInput = float(maxInput)
10
- self._num_channels = None
11
- self.register_parameter("alpha", None)
12
- self.register_parameter("beta", None)
13
- self.register_parameter("gamma", None)
14
- self.register_parameter("delta", None)
15
- self.register_parameter("xi", None)
16
- self.register_parameter("psi", None)
17
-
18
- def _ensure_parameters(self, x: torch.Tensor):
19
- if x.dim() != 4:
20
- raise ValueError(
21
- f"Expected 4D input (N, C, H, W), got {x.dim()}D with shape {tuple(x.shape)}"
22
- )
23
- c = int(x.shape[1])
24
- if self._num_channels is None:
25
- self._num_channels = c
26
- elif c != self._num_channels:
27
- raise RuntimeError(
28
- f"MELU was initialized with C={self._num_channels} but got C={c}. "
29
- "Create a new MELU for a different channel size."
30
- )
31
-
32
- if self.alpha is None:
33
- shape = (1, c, 1, 1)
34
- device, dtype = x.device, x.dtype
35
- for name in ("alpha", "beta", "gamma", "delta", "xi", "psi"):
36
- setattr(
37
- self,
38
- name,
39
- nn.Parameter(torch.zeros(shape, dtype=dtype, device=device)),
40
- )
41
-
42
- def reset_parameters(self):
43
- for p in (self.alpha, self.beta, self.gamma, self.delta, self.xi, self.psi):
44
- if p is not None:
45
- with torch.no_grad():
46
- p.zero_()
47
-
48
- def forward(self, X: torch.Tensor) -> torch.Tensor:
49
- self._ensure_parameters(X)
50
-
51
- X_norm = X / self.maxInput
52
- Y = torch.roll(X_norm, shifts=-1, dims=1)
53
-
54
- term1 = F.relu(X_norm)
55
- term2 = self.alpha * torch.clamp(X_norm, max=0)
56
-
57
- dist_sq_beta = (X_norm - 2) ** 2 + (Y - 2) ** 2
58
- dist_sq_gamma = (X_norm - 1) ** 2 + (Y - 1) ** 2
59
- dist_sq_delta = (X_norm - 1) ** 2 + (Y - 3) ** 2
60
- dist_sq_xi = (X_norm - 3) ** 2 + (Y - 1) ** 2
61
- dist_sq_psi = (X_norm - 3) ** 2 + (Y - 3) ** 2
62
-
63
- term3 = self.beta * torch.sqrt(F.relu(2 - dist_sq_beta))
64
- term4 = self.gamma * torch.sqrt(F.relu(1 - dist_sq_gamma))
65
- term5 = self.delta * torch.sqrt(F.relu(1 - dist_sq_delta))
66
- term6 = self.xi * torch.sqrt(F.relu(1 - dist_sq_xi))
67
- term7 = self.psi * torch.sqrt(F.relu(1 - dist_sq_psi))
68
-
69
- Z_norm = term1 + term2 + term3 + term4 + term5 + term6 + term7
70
- return Z_norm * self.maxInput
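Since every learnable coefficient in the removed MELU starts at zero, the initial response is exactly ReLU; the hat-shaped terms only contribute once training moves alpha through psi away from zero. A quick sanity check, assuming the class above and the pre-0.2.95 import path:

import torch
from homa.activations.classes import MELU  # pre-0.2.95 import path

melu = MELU(maxInput=1.0)
x = torch.randn(4, 3, 8, 8)                     # strictly 4D (N, C, H, W) input is required
assert torch.allclose(melu(x), torch.relu(x))   # all terms except term1 vanish at initialization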
homa/activations/classes/PDELU.py DELETED
@@ -1,54 +0,0 @@
1
- import torch
2
- from torch import nn
3
- import torch.nn.functional as F
4
-
5
-
6
- class PDELU(nn.Module):
7
- def __init__(self, theta: float = 0.5):
8
- super().__init__()
9
- if theta == 1.0:
10
- raise ValueError(
11
- "theta cannot be 1.0, as it would cause a division by zero."
12
- )
13
- self.theta = float(theta)
14
- self._power_val = 1.0 / (1.0 - self.theta)
15
- self.register_parameter("alpha", None)
16
- self._num_channels = None
17
-
18
- def _ensure_parameters(self, x: torch.Tensor):
19
- if x.ndim < 2:
20
- raise ValueError(
21
- f"Input tensor must have at least 2 dimensions (N, C), but got shape {tuple(x.shape)}"
22
- )
23
-
24
- c = int(x.shape[1])
25
- if self._num_channels is None:
26
- self._num_channels = c
27
- elif c != self._num_channels:
28
- raise RuntimeError(
29
- f"PDELU was initialized with C={self._num_channels} but got C={c}. "
30
- "Create a new PDELU for a different channel size."
31
- )
32
-
33
- if self.alpha is None:
34
- param_shape = [1] * x.ndim
35
- param_shape[1] = c
36
- self.alpha = nn.Parameter(
37
- torch.full(param_shape, 0.1, dtype=x.dtype, device=x.device)
38
- )
39
-
40
- def reset_parameters(self):
41
- if self.alpha is not None:
42
- with torch.no_grad():
43
- self.alpha.fill_(0.1)
44
-
45
- def forward(self, x: torch.Tensor):
46
- self._ensure_parameters(x)
47
-
48
- positive_part = F.relu(x)
49
- inner_term = F.relu(1.0 + (1.0 - self.theta) * x)
50
- powered_term = torch.pow(inner_term, self._power_val)
51
- subtracted_term = powered_term - 1.0
52
- zero = torch.zeros(1, dtype=x.dtype, device=x.device)
53
- negative_part = self.alpha * torch.minimum(subtracted_term, zero)
54
- return positive_part + negative_part
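The removed PDELU forward evaluates f(x) = max(x, 0) + alpha * min((1 + (1 - theta) * x)_+ ** (1 / (1 - theta)) - 1, 0) with a per-channel alpha initialized to 0.1. A functional restatement of the same computation, for reference (hypothetical helper, not part of the package):

import torch
import torch.nn.functional as F

def pdelu_reference(x: torch.Tensor, alpha: torch.Tensor, theta: float = 0.5) -> torch.Tensor:
    # Same algebra as PDELU.forward above; alpha is expected to broadcast over (1, C, 1, ...)
    power = 1.0 / (1.0 - theta)
    inner = F.relu(1.0 + (1.0 - theta) * x)
    negative = alpha * torch.minimum(inner.pow(power) - 1.0, torch.zeros_like(x))
    return F.relu(x) + negative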
homa/activations/classes/SReLU.py DELETED
@@ -1,69 +0,0 @@
1
- import torch
2
- from torch import nn
3
-
4
-
5
- class SReLU(nn.Module):
6
- def __init__(
7
- self,
8
- alpha_init: float = 0.0,
9
- beta_init: float = 0.0,
10
- gamma_init: float = 1.0,
11
- delta_init: float = 1.0,
12
- ):
13
- super().__init__()
14
- self.alpha_init_val = float(alpha_init)
15
- self.beta_init_val = float(beta_init)
16
- self.gamma_init_val = float(gamma_init)
17
- self.delta_init_val = float(delta_init)
18
- self._num_channels = None
19
- self.register_parameter("alpha", None)
20
- self.register_parameter("beta", None)
21
- self.register_parameter("gamma", None)
22
- self.register_parameter("delta", None)
23
-
24
- def _ensure_parameters(self, x: torch.Tensor):
25
- if x.dim() != 4:
26
- raise ValueError(
27
- f"Expected 4D input (N, C, H, W), got {x.dim()}D with shape {tuple(x.shape)}"
28
- )
29
- c = int(x.shape[1])
30
- if self._num_channels is None:
31
- self._num_channels = c
32
- elif c != self._num_channels:
33
- raise RuntimeError(
34
- f"SReLU was initialized with C={self._num_channels} but got C={c}. "
35
- "Create a new SReLU for different channel sizes."
36
- )
37
-
38
- if self.alpha is None:
39
- shape = (1, c, 1, 1)
40
- device, dtype = x.device, x.dtype
41
- self.alpha = nn.Parameter(
42
- torch.full(shape, self.alpha_init_val, dtype=dtype, device=device)
43
- )
44
- self.beta = nn.Parameter(
45
- torch.full(shape, self.beta_init_val, dtype=dtype, device=device)
46
- )
47
- self.gamma = nn.Parameter(
48
- torch.full(shape, self.gamma_init_val, dtype=dtype, device=device)
49
- )
50
- self.delta = nn.Parameter(
51
- torch.full(shape, self.delta_init_val, dtype=dtype, device=device)
52
- )
53
-
54
- def reset_parameters(self):
55
- if self.alpha is not None:
56
- with torch.no_grad():
57
- self.alpha.fill_(self.alpha_init_val)
58
- self.beta.fill_(self.beta_init_val)
59
- self.gamma.fill_(self.gamma_init_val)
60
- self.delta.fill_(self.delta_init_val)
61
-
62
- def forward(self, x: torch.Tensor) -> torch.Tensor:
63
- self._ensure_parameters(x)
64
-
65
- start = self.beta + self.alpha * (x - self.beta)
66
- finish = self.delta + self.gamma * (x - self.delta)
67
- out = torch.where(x < self.beta, start, x)
68
- out = torch.where(x > self.delta, finish, out)
69
- return out
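In the removed SReLU, beta and delta act as the lower and upper thresholds and alpha and gamma as the slopes outside them: y = beta + alpha * (x - beta) for x < beta, y = x in between, and y = delta + gamma * (x - delta) for x > delta; the default initialization (alpha = beta = 0, gamma = delta = 1) therefore reduces to ReLU. A scalar restatement of the same piecewise form (hypothetical helper, not part of the package):

import torch

def srelu_reference(x, alpha=0.0, beta=0.0, gamma=1.0, delta=1.0):
    # Piecewise-linear form used by SReLU.forward above
    lower = beta + alpha * (x - beta)     # applied where x < beta
    upper = delta + gamma * (x - delta)   # applied where x > delta
    out = torch.where(x < beta, lower, x)
    return torch.where(x > delta, upper, out)

x = torch.linspace(-2.0, 2.0, 9)
assert torch.allclose(srelu_reference(x), torch.relu(x))  # defaults reduce to ReLU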
homa/activations/classes/SmallGALU.py DELETED
@@ -1,58 +0,0 @@
1
- import torch
2
- from torch import nn
3
- from torch.nn.parameter import Parameter
4
- import torch.nn.functional as F
5
-
6
-
7
- class SmallGALU(nn.Module):
8
- def __init__(self, max_input: float = 1.0):
9
- super().__init__()
10
- if max_input <= 0:
11
- raise ValueError("max_input must be positive.")
12
- self.max_input = float(max_input)
13
- self.register_parameter("alpha", None)
14
- self.register_parameter("beta", None)
15
- self._num_channels = None
16
-
17
- def _initialize_parameters(self, x: torch.Tensor):
18
- if x.ndim < 2:
19
- raise ValueError(
20
- f"Input tensor must have at least 2 dimensions (N, C), but got shape {tuple(x.shape)}"
21
- )
22
- self._num_channels = int(x.shape[1])
23
- param_shape = [1] * x.ndim
24
- param_shape[1] = self._num_channels
25
- device = x.device
26
- dtype = x.dtype
27
- self.alpha = Parameter(torch.zeros(param_shape, dtype=dtype, device=device))
28
- self.beta = Parameter(torch.zeros(param_shape, dtype=dtype, device=device))
29
-
30
- def reset_parameters(self):
31
- if self.alpha is not None:
32
- with torch.no_grad():
33
- self.alpha.zero_()
34
- self.beta.zero_()
35
-
36
- def forward(self, x: torch.Tensor):
37
- if self.alpha is None:
38
- self._initialize_parameters(x)
39
- else:
40
- if x.ndim < 2:
41
- raise ValueError(
42
- f"Input tensor must have at least 2 dimensions (N, C), but got shape {tuple(x.shape)}"
43
- )
44
- if x.shape[1] != self._num_channels:
45
- raise RuntimeError(
46
- f"SmallGALU was initialized with C={self._num_channels} but got C={x.shape[1]}. "
47
- "Create a new SmallGALU for a different channel size."
48
- )
49
-
50
- x_norm = x / self.max_input
51
- zero = torch.zeros(1, dtype=x.dtype, device=x.device)
52
- part_prelu = F.relu(x_norm) + self.alpha * torch.minimum(x_norm, zero)
53
- part_beta = self.beta * (
54
- F.relu(1.0 - torch.abs(x_norm - 1.0))
55
- + torch.minimum(torch.abs(x_norm - 3.0) - 1.0, zero)
56
- )
57
- z = part_prelu + part_beta
58
- return z * self.max_input
homa/activations/classes/WideMELU.py DELETED
@@ -1,90 +0,0 @@
1
- import torch
2
- from torch import nn
3
- import torch.nn.functional as F
4
-
5
-
6
- class WideMELU(nn.Module):
7
- def __init__(self, maxInput: float = 1.0):
8
- super().__init__()
9
- self.maxInput = float(maxInput)
10
- self._num_channels = None
11
- self.register_parameter("alpha", None)
12
- self.register_parameter("beta", None)
13
- self.register_parameter("gamma", None)
14
- self.register_parameter("delta", None)
15
- self.register_parameter("xi", None)
16
- self.register_parameter("psi", None)
17
- self.register_parameter("theta", None)
18
- self.register_parameter("lam", None)
19
-
20
- def _ensure_parameters(self, x: torch.Tensor):
21
- if x.dim() != 4:
22
- raise ValueError(
23
- f"Expected 4D input (N, C, H, W), got {x.dim()}D with shape {tuple(x.shape)}"
24
- )
25
-
26
- c = int(x.shape[1])
27
- if self._num_channels is None:
28
- self._num_channels = c
29
- elif c != self._num_channels:
30
- raise RuntimeError(
31
- f"WideMELU was initialized with C={self._num_channels} but got C={c}. "
32
- "Create a new WideMELU for different channel sizes."
33
- )
34
-
35
- if self.alpha is None:
36
- shape = (1, c, 1, 1)
37
- device, dtype = x.device, x.dtype
38
- for name in (
39
- "alpha",
40
- "beta",
41
- "gamma",
42
- "delta",
43
- "xi",
44
- "psi",
45
- "theta",
46
- "lam",
47
- ):
48
- param = nn.Parameter(torch.zeros(shape, dtype=dtype, device=device))
49
- setattr(self, name, param)
50
-
51
- def reset_parameters(self):
52
- params = (
53
- self.alpha,
54
- self.beta,
55
- self.gamma,
56
- self.delta,
57
- self.xi,
58
- self.psi,
59
- self.theta,
60
- self.lam,
61
- )
62
- for p in params:
63
- if p is not None:
64
- with torch.no_grad():
65
- p.zero_()
66
-
67
- def forward(self, x: torch.Tensor) -> torch.Tensor:
68
- self._ensure_parameters(x)
69
-
70
- X_norm = x / self.maxInput
71
- Y = torch.roll(X_norm, shifts=-1, dims=1)
72
-
73
- term1 = F.relu(X_norm)
74
- term2 = self.alpha * torch.clamp(X_norm, max=0)
75
- dist_sq_beta = (X_norm - 2) ** 2 + (Y - 2) ** 2
76
- dist_sq_gamma = (X_norm - 1) ** 2 + (Y - 1) ** 2
77
- dist_sq_delta = (X_norm - 1) ** 2 + (Y - 3) ** 2
78
- dist_sq_xi = (X_norm - 3) ** 2 + (Y - 1) ** 2
79
- dist_sq_psi = (X_norm - 3) ** 2 + (Y - 3) ** 2
80
- dist_sq_theta = (X_norm - 1) ** 2 + (Y - 2) ** 2
81
- dist_sq_lambda = (X_norm - 3) ** 2 + (Y - 2) ** 2
82
- term3 = self.beta * torch.sqrt(F.relu(2 - dist_sq_beta))
83
- term4 = self.gamma * torch.sqrt(F.relu(1 - dist_sq_gamma))
84
- term5 = self.delta * torch.sqrt(F.relu(1 - dist_sq_delta))
85
- term6 = self.xi * torch.sqrt(F.relu(1 - dist_sq_xi))
86
- term7 = self.psi * torch.sqrt(F.relu(1 - dist_sq_psi))
87
- term8 = self.theta * torch.sqrt(F.relu(1 - dist_sq_theta))
88
- term9 = self.lam * torch.sqrt(F.relu(1 - dist_sq_lambda))
89
- Z_norm = term1 + term2 + term3 + term4 + term5 + term6 + term7 + term8 + term9
90
- return Z_norm * self.maxInput
homa/activations/classes/__init__.py DELETED
@@ -1,7 +0,0 @@
1
- from .APLU import APLU
2
- from .GALU import GALU
3
- from .SmallGALU import SmallGALU
4
- from .MELU import MELU
5
- from .WideMELU import WideMELU
6
- from .PDELU import PDELU
7
- from .SReLU import SReLU
homa/activations/utils.py DELETED
@@ -1,27 +0,0 @@
1
- import torch
2
-
3
-
4
- def negative_part(x):
5
- return torch.minimum(x, torch.zeros_like(x))
6
-
7
-
8
- def positive_part(x):
9
- return torch.maximum(x, torch.zeros_like(x))
10
-
11
-
12
- def as_channel_parameters(parameter: torch.Tensor, x: torch.Tensor):
13
- shape = [1] * x.dim()
14
- shape[1] = -1
15
- return parameter.view(*shape)
16
-
17
-
18
- def device_compatibility_check(model, x: torch.Tensor):
19
- for p in model.parameters():
20
- if p.device != x.device or p.dtype != x.dtype:
21
- p.data = p.data.to(device=x.device, dtype=x.dtype)
22
-
23
-
24
- def phi_hat(x, a, lam):
25
- term_pos = torch.maximum(lam - torch.abs(x - a), torch.zeros_like(x))
26
- term_neg = torch.minimum(torch.abs(x - (a + 2 * lam)) - lam, torch.zeros_like(x))
27
- return term_pos + term_neg
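The removed helpers above are broadcasting and shape utilities: as_channel_parameters reshapes a flat per-channel vector so it broadcasts against an (N, C, ...) activation, and phi_hat builds the shifted hat function max(lam - |x - a|, 0) + min(|x - (a + 2*lam)| - lam, 0) used by the GALU-family activations. A short usage sketch, assuming the pre-0.2.95 module path homa.activations.utils:

import torch
from homa.activations.utils import as_channel_parameters, phi_hat  # module removed in 0.2.95

weights = torch.tensor([0.1, 0.2, 0.3])   # one value per channel (C = 3)
x = torch.randn(5, 3, 16, 16)             # (N, C, H, W)
w = as_channel_parameters(weights, x)     # viewed as (1, 3, 1, 1) so it broadcasts over x
assert w.shape == (1, 3, 1, 1)
assert (w * x).shape == x.shape

hat = phi_hat(torch.linspace(0.0, 4.0, 5), a=1.0, lam=1.0)
# triangular hat of height lam centred at x == a, mirrored negative hat centred at x == a + 2*lam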
homa/vision/StochasticResnet.py DELETED
@@ -1,9 +0,0 @@
1
- import torch
2
- from .Resnet import Resnet
3
- from .StochasticClassifier import StochasticClassifier
4
-
5
-
6
- class StochasticResnet(Resnet, StochasticClassifier):
7
- def __init__(self, *args, **kwargs):
8
- super().__init__(*args, **kwargs)
9
- self.replace_activations(torch.nn.ReLU)
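StochasticResnet was dropped rather than renamed (there is no StochasticResnet entry in the new RECORD above). If the old behaviour is still needed against 0.2.95, the removed class is small enough to recreate locally; a sketch, assuming Resnet, StochasticClassifier and replace_activations keep the interfaces shown in this diff:

import torch
from homa.vision.Resnet import Resnet                              # module still ships in 0.2.95
from homa.vision.StochasticClassifier import StochasticClassifier  # module still ships in 0.2.95

class StochasticResnet(Resnet, StochasticClassifier):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.replace_activations(torch.nn.ReLU)  # swap ReLU activations, as the removed class did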