RRAEsTorch 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- RRAEsTorch/AE_classes/AE_classes.py +14 -12
- RRAEsTorch/tests/test_AE_classes_CNN.py +20 -26
- RRAEsTorch/tests/test_AE_classes_MLP.py +20 -28
- RRAEsTorch/tests/test_fitting_CNN.py +14 -14
- RRAEsTorch/tests/test_fitting_MLP.py +11 -13
- RRAEsTorch/tests/test_save.py +11 -11
- RRAEsTorch/training_classes/training_classes.py +55 -115
- {rraestorch-0.1.6.dist-info → rraestorch-0.1.7.dist-info}/METADATA +1 -1
- rraestorch-0.1.7.dist-info/RECORD +22 -0
- RRAEsTorch/tests/test_wrappers.py +0 -56
- RRAEsTorch/utilities/utilities.py +0 -1561
- RRAEsTorch/wrappers/__init__.py +0 -1
- RRAEsTorch/wrappers/wrappers.py +0 -237
- rraestorch-0.1.6.dist-info/RECORD +0 -26
- {rraestorch-0.1.6.dist-info → rraestorch-0.1.7.dist-info}/WHEEL +0 -0
- {rraestorch-0.1.6.dist-info → rraestorch-0.1.7.dist-info}/licenses/LICENSE +0 -0
RRAEsTorch/wrappers/__init__.py
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
from .wrappers import vmap_wrap, norm_wrap
|
RRAEsTorch/wrappers/wrappers.py
DELETED
|
@@ -1,237 +0,0 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
import dataclasses
|
|
3
|
-
from dataclasses import dataclass
|
|
4
|
-
from torch.func import vmap
|
|
5
|
-
|
|
6
|
-
@dataclass(frozen=True)
class NormParams:
    """Immutable, field-less base type for normalization parameters.

    Concrete normalization schemes subclass it and declare their own fields.
    It is also used on its own to describe the identity ("None") normalization.
    """
|
|
9
|
-
|
|
10
|
-
@dataclass(frozen=True)
class MeanStdParams(NormParams):
    """Immutable parameters of the mean/std normalization ``(x - mean) / std``."""

    # Annotated as float, but find_norm_funcs stores the results of
    # torch.mean / torch.std here; dataclasses do not enforce annotations.
    mean: float
    std: float
|
|
14
|
-
|
|
15
|
-
@dataclass(frozen=True)
class MinMaxParams(NormParams):
    """Immutable parameters of the min/max normalization ``(x - min) / (max - min)``."""

    # Annotated as float, but find_norm_funcs stores the results of
    # torch.min / torch.max here; dataclasses do not enforce annotations.
    min: float
    max: float
|
|
19
|
-
|
|
20
|
-
def find_norm_funcs(
|
|
21
|
-
array=None,
|
|
22
|
-
norm_typ="None",
|
|
23
|
-
params=None,
|
|
24
|
-
):
|
|
25
|
-
""" Function that finds norm and inv_norm functions
|
|
26
|
-
|
|
27
|
-
The functions can be defined either based on an array (in_train),
|
|
28
|
-
In this case, the required parameters (e.g., mean, min, etc.) are
|
|
29
|
-
found using in_train. Otherwise, the parameters can be given
|
|
30
|
-
explicitly using params_in (check below for some examples).
|
|
31
|
-
|
|
32
|
-
Parameters
|
|
33
|
-
----------
|
|
34
|
-
in_train : input array that is to be normalized
|
|
35
|
-
norm_in: Type of normalization, "minmax" and "meanstd" supported
|
|
36
|
-
params_in: Normalization parameters if they are to be given manually
|
|
37
|
-
instead of being computed from in_train.
|
|
38
|
-
These for example can be:
|
|
39
|
-
For mean/std normalization, params_out = {"mean": 0.12, "std": 0.23}
|
|
40
|
-
For min/max normalization, params_out = {"min": -1, "max": 3,}
|
|
41
|
-
Returns
|
|
42
|
-
-------
|
|
43
|
-
A new subclass of base_cls with methods that normalize the input when called.
|
|
44
|
-
"""
|
|
45
|
-
if norm_typ != "None":
|
|
46
|
-
assert (params is not None) or (
|
|
47
|
-
array is not None
|
|
48
|
-
), "Either params or in_train must be provided to set norm parameters"
|
|
49
|
-
|
|
50
|
-
assert not (
|
|
51
|
-
params is not None and array is not None
|
|
52
|
-
), "Only one of params or in_train must be provided to set norm parameters"
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
match norm_typ:
|
|
56
|
-
case "minmax":
|
|
57
|
-
if params is None:
|
|
58
|
-
params = MinMaxParams(min=torch.min(array), max=torch.max(array))
|
|
59
|
-
else:
|
|
60
|
-
params = params
|
|
61
|
-
norm_fn = lambda self, x: (x - params.min) / (params.max - params.min)
|
|
62
|
-
inv_norm_fn = lambda self, x: x * (params.max - params.min) + params.min
|
|
63
|
-
case "meanstd":
|
|
64
|
-
if params is None:
|
|
65
|
-
params = MeanStdParams(mean=torch.mean(array), std=torch.std(array))
|
|
66
|
-
else:
|
|
67
|
-
params = params
|
|
68
|
-
norm_fn = lambda self, x: (x - params.mean) / params.std
|
|
69
|
-
inv_norm_fn = lambda self, x: x * params.std + params.mean
|
|
70
|
-
case "None":
|
|
71
|
-
if params is None:
|
|
72
|
-
params = NormParams()
|
|
73
|
-
else:
|
|
74
|
-
params = params
|
|
75
|
-
norm_fn = lambda self, x: x
|
|
76
|
-
inv_norm_fn = lambda self, x: x
|
|
77
|
-
case _:
|
|
78
|
-
raise NotImplementedError(f"norm_in specified {norm_typ} is not implemented.")
|
|
79
|
-
|
|
80
|
-
return norm_fn, inv_norm_fn, params
|
|
81
|
-
|
|
82
|
-
def norm_in_wrap(
    base_cls,
    array=None,
    norm_typ="None",
    params=None,
    methods_to_wrap=("__call__",),
):
    """Return a subclass of ``base_cls`` whose methods normalize their input.

    Each wrapped method normalizes its first argument after ``self`` and then
    calls the original method. The normalization parameters are either
    computed from ``array`` or given explicitly through ``params`` (see
    ``find_norm_funcs``).

    Parameters
    ----------
    base_cls : type
        Class whose methods will be wrapped; the result is a subclass of it.
    array : torch.Tensor, optional
        Data from which the normalization parameters are computed.
    norm_typ : str
        ``"meanstd"``, ``"minmax"`` or ``"None"``.
    params : optional
        Explicit normalization parameters, e.g. a ``MeanStdParams`` or
        ``MinMaxParams`` instance (read by attribute, e.g. ``params.mean``).
    methods_to_wrap : iterable of str
        Names of the methods of ``base_cls`` to wrap.

    Returns
    -------
    type
        A subclass of ``base_cls`` whose wrapped methods normalize their input.
        ``norm_in``, ``inv_norm_in`` and ``params_in`` are passed to
        ``make_wrapped`` as extra attributes of the new class.
    """
    import functools

    norm_in, inv_norm_in, params_in = find_norm_funcs(array, norm_typ, params)

    def norm_in_decorator(fn):
        # functools.wraps keeps the original method's name and docstring.
        @functools.wraps(fn)
        def wrapped(self, x, *args, **kwargs):
            return fn(self, norm_in(self, x), *args, **kwargs)

        # Extra attributes for the new class: name -> (annotation, value).
        return wrapped, {
            "norm_in": (callable, norm_in),
            "inv_norm_in": (callable, inv_norm_in),
            "params_in": (dict, params_in),
        }

    return make_wrapped(base_cls, norm_in_decorator, methods_to_wrap)
|
|
110
|
-
|
|
111
|
-
def inv_norm_out_wrap(
    base_cls,
    array=None,
    norm_typ="None",
    params=None,
    methods_to_wrap=("__call__",),
):
    """Return a subclass of ``base_cls`` whose methods de-normalize their output.

    Each wrapped method calls the original method and passes its result
    through the inverse normalization. The normalization parameters are
    either computed from ``array`` or given explicitly through ``params``
    (see ``find_norm_funcs``).

    Parameters
    ----------
    base_cls : type
        Class whose methods will be wrapped; the result is a subclass of it.
    array : torch.Tensor, optional
        Data from which the normalization parameters are computed.
    norm_typ : str
        ``"meanstd"``, ``"minmax"`` or ``"None"``.
    params : optional
        Explicit normalization parameters, e.g. a ``MeanStdParams`` or
        ``MinMaxParams`` instance (read by attribute, e.g. ``params.mean``).
    methods_to_wrap : iterable of str
        Names of the methods of ``base_cls`` to wrap.

    Returns
    -------
    type
        A subclass of ``base_cls`` with de-normalized methods. The wrapped
        methods accept an extra keyword argument ``keep_normalized``; when it
        is ``True`` the raw (still normalized) result is returned.
        ``norm_out``, ``inv_norm_out`` and ``params_out`` are passed to
        ``make_wrapped`` as extra attributes of the new class.
    """
    import functools

    norm_out, inv_norm_out, params_out = find_norm_funcs(array, norm_typ, params)

    def norm_out_decorator(fn):
        # functools.wraps keeps the original method's name and docstring.
        @functools.wraps(fn)
        def wrapped(self, x, *args, keep_normalized=False, **kwargs):
            result = fn(self, x, *args, **kwargs)
            if keep_normalized:
                return result
            return inv_norm_out(self, result)

        # Extra attributes for the new class: name -> (annotation, value).
        return wrapped, {
            "norm_out": (callable, norm_out),
            "inv_norm_out": (callable, inv_norm_out),
            "params_out": (dict, params_out),
        }

    return make_wrapped(base_cls, norm_out_decorator, methods_to_wrap)
|
|
145
|
-
|
|
146
|
-
def norm_wrap(
    base_cls,
    array_in=None,
    norm_typ_in="None",
    params_in=None,
    array_out=None,
    norm_typ_out="None",
    params_out=None,
    methods_to_wrap_in=("__call__",),
    methods_to_wrap_out=("__call__",),
):
    """Return a subclass of ``base_cls`` that normalizes inputs and de-normalizes outputs.

    ``norm_in_wrap`` is applied first. ``inv_norm_out_wrap`` is then applied
    to the resulting class, so a method wrapped by both computes
    ``inv_norm_out(method(norm_in(x)))``.

    Parameters
    ----------
    base_cls : type
        Class whose methods will be wrapped.
    array_in, norm_typ_in, params_in :
        Forwarded to ``norm_in_wrap`` as ``array``, ``norm_typ`` and ``params``.
    array_out, norm_typ_out, params_out :
        Forwarded to ``inv_norm_out_wrap`` as ``array``, ``norm_typ`` and ``params``.
    methods_to_wrap_in : iterable of str
        Methods whose input is normalized.
    methods_to_wrap_out : iterable of str
        Methods whose output is de-normalized.

    Returns
    -------
    type
        A new subclass (two levels deep) of ``base_cls`` with normalized methods.
    """
    after_in = norm_in_wrap(base_cls, array_in, norm_typ_in, params_in, methods_to_wrap_in)
    return inv_norm_out_wrap(after_in, array_out, norm_typ_out, params_out, methods_to_wrap_out)
|
|
163
|
-
|
|
164
|
-
def vmap_wrap(base_cls, map_axis, count=1, methods_to_wrap=("__call__",)):
    """Return a subclass of ``base_cls`` whose methods are vectorized with ``vmap``.

    Parameters
    ----------
    base_cls : type
        Class whose methods will be wrapped; the result is a subclass of it.
    map_axis : int or None
        Axis of the first argument after ``self`` that is mapped over. The
        same axis is used for the output. ``None`` disables mapping.
    count : int
        Number of nested ``vmap`` levels to apply, each over ``map_axis``.
    methods_to_wrap : iterable of str
        Names of the methods of ``base_cls`` to wrap.

    Returns
    -------
    type
        A subclass of ``base_cls`` with vectorized methods. The wrapped
        methods accept an extra keyword argument ``no_map``; when it is
        ``True`` the original method is called without mapping.
    """
    import functools

    def vmap_decorator(fn):
        # functools.wraps keeps the original method's name and docstring.
        @functools.wraps(fn)
        def wrapped(self, x, *args, no_map=False, **kwargs):
            if map_axis is None or no_map:
                return fn(self, x, *args, **kwargs)

            # Only ``x`` is mapped over; ``self``, ``args`` and ``kwargs`` are
            # captured by the closure and are not mapped.
            def call_one(inp):
                return fn(self, inp, *args, **kwargs)

            mapped = call_one
            for _ in range(count):
                mapped = vmap(mapped, in_dims=(map_axis,), out_dims=map_axis)
            return mapped(x)

        # vmap wrapping adds no extra attributes to the new class.
        return wrapped, {}

    return make_wrapped(base_cls, vmap_decorator, methods_to_wrap)
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
def make_wrapped(base_cls, decorator, methods_to_wrap=("__call__",)):
    """Create a subclass of ``base_cls`` whose listed methods are decorated.

    Parameters
    ----------
    base_cls : type
        Original class to wrap. The new class inherits from it and reuses its
        ``__name__``.
    decorator : callable
        Called once per method as ``decorator(original)``. It must return a
        pair ``(wrapped_method, extra_fields)``, where ``extra_fields`` maps
        attribute names to ``(annotation, value)`` tuples.
    methods_to_wrap : iterable of str
        Names of the methods of ``base_cls`` to decorate.

    Returns
    -------
    type
        A new subclass of ``base_cls`` with the decorated methods and the
        extra fields attached as class attributes.

    Raises
    ------
    AttributeError
        If a name in ``methods_to_wrap`` is not an attribute of ``base_cls``.
    """
    attrs = {}
    annotations = {}
    # Dataclass-style bases (e.g. equinox-like modules whose metaclass turns
    # every subclass into a dataclass) expect ``dataclasses.field``
    # declarations. For any other base, such as a plain ``torch.nn.Module``,
    # nothing processes ``Field`` objects, so ``instance.<name>`` would return
    # the ``Field`` itself. Store the raw value in that case.
    use_fields = dataclasses.is_dataclass(base_cls)

    for method_name in methods_to_wrap:
        if not hasattr(base_cls, method_name):
            raise AttributeError(f"Method {method_name} not found in {base_cls}")

        original = getattr(base_cls, method_name)
        wrapped_method, extra_fields = decorator(original)
        attrs[method_name] = wrapped_method

        # When several methods share one decorator, the first declaration of
        # each field wins.
        for field_name, (field_type, default_value) in extra_fields.items():
            if field_name in annotations:
                continue
            annotations[field_name] = field_type
            attrs[field_name] = (
                dataclasses.field(default=default_value) if use_fields else default_value
            )

    if annotations:
        attrs["__annotations__"] = annotations

    return type(base_cls.__name__, (base_cls,), attrs)
|
|
@@ -1,26 +0,0 @@
|
|
|
1
|
-
RRAEsTorch/__init__.py,sha256=f234R6usRCqIgmBmiXyZNIHa7VrDe5E-KZO0Y6Ek5AQ,33
|
|
2
|
-
RRAEsTorch/config.py,sha256=bQPwc_2KTvhglH_WIRSb5_6CpUQQj9AGpfqBp8_kuys,2931
|
|
3
|
-
RRAEsTorch/AE_base/AE_base.py,sha256=Eeo_I7p5P-357rnOmCuFxosJgmBg4KPyMA8n70sTV7U,3368
|
|
4
|
-
RRAEsTorch/AE_base/__init__.py,sha256=95YfMgEWzIFAkm--Ci-a9YPSGfCs2PDAK2sbfScT7oo,24
|
|
5
|
-
RRAEsTorch/AE_classes/AE_classes.py,sha256=oDpDzQasPbtK2L9vDLiG4VQdKH02VRCagOYT1-FAldo,18063
|
|
6
|
-
RRAEsTorch/AE_classes/__init__.py,sha256=inM2_YPJG8T-lwx-CUg-zL2EMltmROQAlNZeZmnvVGA,27
|
|
7
|
-
RRAEsTorch/tests/test_AE_classes_CNN.py,sha256=bEE9JnTo84t9w0a4kw1W74L51eLGjBB8trrlAG938RE,3182
|
|
8
|
-
RRAEsTorch/tests/test_AE_classes_MLP.py,sha256=Cr1_uP7lag6RPQC1UhN2O7RFW4BEx1cd0Z-Y6VgrWRg,2718
|
|
9
|
-
RRAEsTorch/tests/test_fitting_CNN.py,sha256=8i6oUZFS-DpaTh1VsRsd1rGG_Me7R3Kf1vO4lPSxYns,2797
|
|
10
|
-
RRAEsTorch/tests/test_fitting_MLP.py,sha256=6Ggy5iAJmqMEJPDoon7fYkwFS8LWmosTOFBQuAQzszc,3353
|
|
11
|
-
RRAEsTorch/tests/test_mains.py,sha256=ivTXP7NypSlgmB9VR5g0yq5VEuPZJGOibDqBMjOxHow,1021
|
|
12
|
-
RRAEsTorch/tests/test_save.py,sha256=_zYcDZtz_HWzHHWT6uJVq7ynsFmlH5v6nDE-loZQ4zo,1997
|
|
13
|
-
RRAEsTorch/tests/test_stable_SVD.py,sha256=OimHPqw4f22qndyRzwJfNvTzzjP2CM-yHtfXCqkMBuA,1230
|
|
14
|
-
RRAEsTorch/tests/test_wrappers.py,sha256=Ike4IfMUx2Qic3f3_cBikgFPEU1WW5TuH1jT_r2NgvY,2215
|
|
15
|
-
RRAEsTorch/trackers/__init__.py,sha256=3c9qcUMZiUfVr93rxFp6l11lIDthyK3PCY_-P-sNX3I,25
|
|
16
|
-
RRAEsTorch/trackers/trackers.py,sha256=Pn1ejMxMjAtvgDazFFwa3qiZhogG5GtXj4UIIFiBpuY,9127
|
|
17
|
-
RRAEsTorch/training_classes/__init__.py,sha256=K_Id4yhw640jp2JN15-0E4wJi4sPadi1fFRgovMV3kw,101
|
|
18
|
-
RRAEsTorch/training_classes/training_classes.py,sha256=HU8Ksz1-2WwOMuwyGiWdkQ_vrrgBEwaeQT4avs4jd2E,37870
|
|
19
|
-
RRAEsTorch/utilities/__init__.py,sha256=NtlizCcRW4qcsULXxWfjPk265rLJst0-GqWLRah2yDY,26
|
|
20
|
-
RRAEsTorch/utilities/utilities.py,sha256=JfLkAPEC8fzwgM32LEcXVe0tA4C7UBgsrkuh6noUA_4,53372
|
|
21
|
-
RRAEsTorch/wrappers/__init__.py,sha256=txiLh4ylnuvPlapagz7DiAslmjllOzTqwCDL2dFr6dM,44
|
|
22
|
-
RRAEsTorch/wrappers/wrappers.py,sha256=9Rmq2RS_EkZvsg96SKrt1HFIP35sF0xyPI0goV0ujOs,9659
|
|
23
|
-
rraestorch-0.1.6.dist-info/METADATA,sha256=BfnB-vhx0m-d79hCd8UgZT0c8GPHFRwVn8x-M8k_h6E,3028
|
|
24
|
-
rraestorch-0.1.6.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
25
|
-
rraestorch-0.1.6.dist-info/licenses/LICENSE,sha256=QQKXj7kEw2yUJxPDzHUnEPRO7YweIU6AAlwPLnhtinA,1090
|
|
26
|
-
rraestorch-0.1.6.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|