RRAEsTorch: rraestorch-0.1.6-py3-none-any.whl → rraestorch-0.1.7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
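The three hunks below remove RRAEsTorch/wrappers/__init__.py, RRAEsTorch/wrappers/wrappers.py, and the 0.1.6 RECORD (the file paths are recovered from the RECORD hunk at the end). Code that imported the wrapper helpers directly will therefore fail against 0.1.7. A minimal, hypothetical compatibility check, assuming only what the diff shows (that the wrappers subpackage is gone) and nothing about where, if anywhere, the helpers moved:

# Hypothetical compatibility check; only the removal shown in this diff
# is assumed. Where the helpers live in 0.1.7, if anywhere, is unknown.
try:
    from RRAEsTorch.wrappers import vmap_wrap, norm_wrap  # present in 0.1.6
except ImportError:
    # Removed in 0.1.7; pin rraestorch==0.1.6 or vendor the module.
    vmap_wrap = norm_wrap = None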
RRAEsTorch/wrappers/__init__.py
@@ -1 +0,0 @@
- from .wrappers import vmap_wrap, norm_wrap
RRAEsTorch/wrappers/wrappers.py
@@ -1,237 +0,0 @@
- import torch
- import dataclasses
- from dataclasses import dataclass
- from torch.func import vmap
-
- @dataclass(frozen=True)
- class NormParams:
-     pass
-
- @dataclass(frozen=True)
- class MeanStdParams(NormParams):
-     mean: float
-     std: float
-
- @dataclass(frozen=True)
- class MinMaxParams(NormParams):
-     min: float
-     max: float
-
- def find_norm_funcs(
-     array=None,
-     norm_typ="None",
-     params=None,
- ):
-     """ Function that finds norm and inv_norm functions
-
-     The functions can be defined either based on an array, in which
-     case the required parameters (e.g., mean, min, etc.) are computed
-     from that array. Otherwise, the parameters can be given explicitly
-     via params (see below for examples).
-
-     Parameters
-     ----------
-     array : input array that is to be normalized
-     norm_typ : type of normalization; "minmax" and "meanstd" are supported
-     params : normalization parameters, if they are to be given manually
-         instead of being computed from array. For example:
-         for mean/std normalization, params = MeanStdParams(mean=0.12, std=0.23)
-         for min/max normalization, params = MinMaxParams(min=-1, max=3)
-
-     Returns
-     -------
-     The norm function, the inverse norm function, and the resolved
-     NormParams instance.
-     """
-     if norm_typ != "None":
-         assert (params is not None) or (
-             array is not None
-         ), "Either params or array must be provided to set norm parameters"
-
-         assert not (
-             params is not None and array is not None
-         ), "Only one of params or array must be provided to set norm parameters"
-
-     match norm_typ:
-         case "minmax":
-             if params is None:
-                 params = MinMaxParams(min=torch.min(array), max=torch.max(array))
-             norm_fn = lambda self, x: (x - params.min) / (params.max - params.min)
-             inv_norm_fn = lambda self, x: x * (params.max - params.min) + params.min
-         case "meanstd":
-             if params is None:
-                 params = MeanStdParams(mean=torch.mean(array), std=torch.std(array))
-             norm_fn = lambda self, x: (x - params.mean) / params.std
-             inv_norm_fn = lambda self, x: x * params.std + params.mean
-         case "None":
-             if params is None:
-                 params = NormParams()
-             norm_fn = lambda self, x: x
-             inv_norm_fn = lambda self, x: x
-         case _:
-             raise NotImplementedError(f"norm_typ {norm_typ!r} is not implemented.")
-
-     return norm_fn, inv_norm_fn, params
-
- def norm_in_wrap(base_cls, array=None, norm_typ="None", params=None, methods_to_wrap=["__call__"]):
-     """ Wrapper that normalizes the input of methods of a dataclass-based module class
-
-     The parameters of normalization can be either computed from an array or given
-     by the user.
-
-     Parameters
-     ----------
-     base_cls : base class (a dataclass-based module class) whose methods will
-         be modified
-     array : the array from which norm parameters are to be computed
-     norm_typ : type of normalization; "meanstd" and "minmax" are supported
-     params : normalization parameters, if they are to be given manually.
-         For example:
-         for mean/std normalization, params = MeanStdParams(mean=0.12, std=0.23)
-         for min/max normalization, params = MinMaxParams(min=-1, max=3)
-     methods_to_wrap : names of the methods in base_cls to be wrapped
-
-     Returns
-     -------
-     A new subclass of base_cls whose methods normalize the input when called.
-     """
-     norm_in, inv_norm_in, params_in = find_norm_funcs(array, norm_typ, params)
-     def norm_in_decorator(fn):
-         def wrapped(self, x, *args, **kwargs):
-             result = fn(self, norm_in(self, x), *args, **kwargs)
-             return result
-         return wrapped, {"norm_in": (callable, norm_in), "inv_norm_in": (callable, inv_norm_in), "params_in": (NormParams, params_in)}
-     return make_wrapped(base_cls, norm_in_decorator, methods_to_wrap)
-
- def inv_norm_out_wrap(base_cls, array=None, norm_typ="None", params=None, methods_to_wrap=["__call__"]):
-     """ Wrapper that de-normalizes the output of methods of a dataclass-based module class
-
-     The parameters of normalization can be either computed from an array or given
-     by the user.
-
-     Parameters
-     ----------
-     base_cls : base class (a dataclass-based module class) whose methods will
-         be modified
-     array : the array from which norm parameters are to be computed
-     norm_typ : type of normalization; "meanstd" and "minmax" are supported
-     params : normalization parameters, if they are to be given manually.
-         For example:
-         for mean/std normalization, params = MeanStdParams(mean=0.12, std=0.23)
-         for min/max normalization, params = MinMaxParams(min=-1, max=3)
-     methods_to_wrap : names of the methods in base_cls to be wrapped
-
-     Returns
-     -------
-     A new subclass of base_cls with de-normalized methods.
-     The given methods accept an additional keyword argument "keep_normalized";
-     if it is passed as True, the methods revert to their original behavior.
-     """
-     norm_out, inv_norm_out, params_out = find_norm_funcs(array, norm_typ, params)
-
-     def norm_out_decorator(fn):
-         def wrapped(self, x, *args, keep_normalized=False, **kwargs):
-             if keep_normalized:
-                 result = fn(self, x, *args, **kwargs)
-             else:
-                 result = inv_norm_out(self, fn(self, x, *args, **kwargs))
-             return result
-         return wrapped, {"norm_out": (callable, norm_out), "inv_norm_out": (callable, inv_norm_out), "params_out": (NormParams, params_out)}
-     return make_wrapped(base_cls, norm_out_decorator, methods_to_wrap)
-
- def norm_wrap(base_cls, array_in=None, norm_typ_in="None", params_in=None, array_out=None, norm_typ_out="None", params_out=None, methods_to_wrap_in=["__call__"], methods_to_wrap_out=["__call__"]):
-     """ Wrapper that normalizes methods of a dataclass-based module class
-
-     Parameters
-     ----------
-     base_cls : base class (a dataclass-based module class) whose methods will
-         be modified
-     methods_to_wrap_in / methods_to_wrap_out : names of the methods in base_cls
-         whose inputs / outputs are to be wrapped
-     ... : the other parameters are explained in norm_in_wrap and inv_norm_out_wrap
-
-     Returns
-     -------
-     A new subclass of base_cls with normalized methods.
-     """
-     after_in = norm_in_wrap(base_cls, array_in, norm_typ_in, params_in, methods_to_wrap_in)
-     after_out = inv_norm_out_wrap(after_in, array_out, norm_typ_out, params_out, methods_to_wrap_out)
-     return after_out
-
- def vmap_wrap(base_cls, map_axis, count=1, methods_to_wrap=["__call__"]):
-     """ Wrapper that vectorizes methods of a dataclass-based module class
-
-     Parameters
-     ----------
-     base_cls : base class (a dataclass-based module class) whose methods will
-         be modified
-     map_axis : axis along which to vectorize the methods
-     count : how many times to vectorize the methods (nested vmap)
-     methods_to_wrap : names of the methods in base_cls to be wrapped
-
-     Returns
-     -------
-     A new subclass of base_cls with vectorized methods.
-     The given methods accept an additional keyword argument "no_map";
-     if it is passed as True, the methods revert to their original behavior.
-     """
-     def vmap_decorator(fn):
-         def wrapped(self, x, *args, no_map=False, **kwargs):
-             if (map_axis is None) or no_map:
-                 return fn(self, x, *args, **kwargs)
-             f = lambda x: fn(self, x, *args, **kwargs)
-             for _ in range(count):
-                 f = vmap(f, in_dims=(map_axis,), out_dims=map_axis)
-             out = f(x)
-             return out
-         return wrapped, {}
-     return make_wrapped(base_cls, vmap_decorator, methods_to_wrap)
-
- def make_wrapped(base_cls, decorator, methods_to_wrap=["__call__"]):
-     """
-     Create a subclass of base_cls with the specified methods wrapped by decorator.
-
-     Parameters
-     ----------
-     base_cls : original class to wrap
-     decorator : the wanted modification to the methods given below
-     methods_to_wrap : list of method names (strings) to decorate
-
-     Returns
-     -------
-     A new subclass of base_cls with decorated methods.
-     """
-     attrs = {}
-     annotations = {}
-     seen_fields = set()
-
-     for method_name in methods_to_wrap:
-         if not hasattr(base_cls, method_name):
-             raise AttributeError(f"Method {method_name} not found in {base_cls}")
-
-         original = getattr(base_cls, method_name)
-         wrapped_method, extra_fields = decorator(original)
-         attrs[method_name] = wrapped_method
-
-         # Skip fields that were already added by a previously wrapped method
-         for field_name in extra_fields:
-             if field_name in seen_fields:
-                 continue
-
-             field_type, default_value = extra_fields[field_name]
-             annotations[field_name] = field_type
-
-             attrs[field_name] = dataclasses.field(default=default_value)
-
-             seen_fields.add(field_name)
-
-     if annotations:
-         attrs["__annotations__"] = annotations
-
-     Wrapped = type(base_cls.__name__, (base_cls,), attrs)
-     return Wrapped
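For context on what is being dropped, the removed module composed class wrappers: norm_wrap builds input normalization and output de-normalization around a class's methods, and vmap_wrap vectorizes them with torch.func.vmap. Below is a minimal usage sketch against the 0.1.6 wheel; Scale is a hypothetical stand-in for the package's real dataclass-based module classes, and MeanStdParams is imported from the inner wrappers module because the removed __init__ re-exported only vmap_wrap and norm_wrap:

import torch
from dataclasses import dataclass
from RRAEsTorch.wrappers import vmap_wrap, norm_wrap    # 0.1.6 only
from RRAEsTorch.wrappers.wrappers import MeanStdParams  # not re-exported

# Hypothetical stand-in for the package's module classes.
@dataclass(frozen=True)
class Scale:
    factor: float

    def __call__(self, x):
        return self.factor * x

# vmap_wrap: vectorize __call__ along axis 0 with torch.func.vmap.
BatchedScale = vmap_wrap(Scale, map_axis=0)
m = BatchedScale(factor=2.0)
y = m(torch.randn(8, 3))             # mapped over the batch axis
y0 = m(torch.randn(3), no_map=True)  # no_map=True restores the plain call

# norm_wrap: min/max-normalize inputs computed from data, de-normalize
# outputs with explicitly given mean/std parameters.
data = torch.randn(100)
NormScale = norm_wrap(
    Scale,
    array_in=data, norm_typ_in="minmax",
    params_out=MeanStdParams(mean=0.0, std=2.0), norm_typ_out="meanstd",
)
n = NormScale(factor=2.0)
z = n(torch.randn(3))                            # normalized in, de-normalized out
z_raw = n(torch.randn(3), keep_normalized=True)  # skip output de-normalization

The keyword arguments no_map and keep_normalized come from the wrapped methods shown in the diff; everything else in the sketch (Scale, data, the chosen parameters) is illustrative.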
rraestorch-0.1.6.dist-info/RECORD
@@ -1,26 +0,0 @@
- RRAEsTorch/__init__.py,sha256=f234R6usRCqIgmBmiXyZNIHa7VrDe5E-KZO0Y6Ek5AQ,33
- RRAEsTorch/config.py,sha256=bQPwc_2KTvhglH_WIRSb5_6CpUQQj9AGpfqBp8_kuys,2931
- RRAEsTorch/AE_base/AE_base.py,sha256=Eeo_I7p5P-357rnOmCuFxosJgmBg4KPyMA8n70sTV7U,3368
- RRAEsTorch/AE_base/__init__.py,sha256=95YfMgEWzIFAkm--Ci-a9YPSGfCs2PDAK2sbfScT7oo,24
- RRAEsTorch/AE_classes/AE_classes.py,sha256=oDpDzQasPbtK2L9vDLiG4VQdKH02VRCagOYT1-FAldo,18063
- RRAEsTorch/AE_classes/__init__.py,sha256=inM2_YPJG8T-lwx-CUg-zL2EMltmROQAlNZeZmnvVGA,27
- RRAEsTorch/tests/test_AE_classes_CNN.py,sha256=bEE9JnTo84t9w0a4kw1W74L51eLGjBB8trrlAG938RE,3182
- RRAEsTorch/tests/test_AE_classes_MLP.py,sha256=Cr1_uP7lag6RPQC1UhN2O7RFW4BEx1cd0Z-Y6VgrWRg,2718
- RRAEsTorch/tests/test_fitting_CNN.py,sha256=8i6oUZFS-DpaTh1VsRsd1rGG_Me7R3Kf1vO4lPSxYns,2797
- RRAEsTorch/tests/test_fitting_MLP.py,sha256=6Ggy5iAJmqMEJPDoon7fYkwFS8LWmosTOFBQuAQzszc,3353
- RRAEsTorch/tests/test_mains.py,sha256=ivTXP7NypSlgmB9VR5g0yq5VEuPZJGOibDqBMjOxHow,1021
- RRAEsTorch/tests/test_save.py,sha256=_zYcDZtz_HWzHHWT6uJVq7ynsFmlH5v6nDE-loZQ4zo,1997
- RRAEsTorch/tests/test_stable_SVD.py,sha256=OimHPqw4f22qndyRzwJfNvTzzjP2CM-yHtfXCqkMBuA,1230
- RRAEsTorch/tests/test_wrappers.py,sha256=Ike4IfMUx2Qic3f3_cBikgFPEU1WW5TuH1jT_r2NgvY,2215
- RRAEsTorch/trackers/__init__.py,sha256=3c9qcUMZiUfVr93rxFp6l11lIDthyK3PCY_-P-sNX3I,25
- RRAEsTorch/trackers/trackers.py,sha256=Pn1ejMxMjAtvgDazFFwa3qiZhogG5GtXj4UIIFiBpuY,9127
- RRAEsTorch/training_classes/__init__.py,sha256=K_Id4yhw640jp2JN15-0E4wJi4sPadi1fFRgovMV3kw,101
- RRAEsTorch/training_classes/training_classes.py,sha256=HU8Ksz1-2WwOMuwyGiWdkQ_vrrgBEwaeQT4avs4jd2E,37870
- RRAEsTorch/utilities/__init__.py,sha256=NtlizCcRW4qcsULXxWfjPk265rLJst0-GqWLRah2yDY,26
- RRAEsTorch/utilities/utilities.py,sha256=JfLkAPEC8fzwgM32LEcXVe0tA4C7UBgsrkuh6noUA_4,53372
- RRAEsTorch/wrappers/__init__.py,sha256=txiLh4ylnuvPlapagz7DiAslmjllOzTqwCDL2dFr6dM,44
- RRAEsTorch/wrappers/wrappers.py,sha256=9Rmq2RS_EkZvsg96SKrt1HFIP35sF0xyPI0goV0ujOs,9659
- rraestorch-0.1.6.dist-info/METADATA,sha256=BfnB-vhx0m-d79hCd8UgZT0c8GPHFRwVn8x-M8k_h6E,3028
- rraestorch-0.1.6.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
- rraestorch-0.1.6.dist-info/licenses/LICENSE,sha256=QQKXj7kEw2yUJxPDzHUnEPRO7YweIU6AAlwPLnhtinA,1090
- rraestorch-0.1.6.dist-info/RECORD,,
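Each RECORD entry has the standard wheel form path,sha256=<digest>,size-in-bytes; the RECORD file itself is listed with empty hash and size fields, as the wheel specification requires.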