unaiverse 0.1.6__cp314-cp314t-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of unaiverse might be problematic.
- unaiverse/__init__.py +19 -0
- unaiverse/agent.py +2008 -0
- unaiverse/agent_basics.py +1846 -0
- unaiverse/clock.py +191 -0
- unaiverse/dataprops.py +1209 -0
- unaiverse/hsm.py +1880 -0
- unaiverse/modules/__init__.py +18 -0
- unaiverse/modules/cnu/__init__.py +17 -0
- unaiverse/modules/cnu/cnus.py +536 -0
- unaiverse/modules/cnu/layers.py +261 -0
- unaiverse/modules/cnu/psi.py +60 -0
- unaiverse/modules/hl/__init__.py +15 -0
- unaiverse/modules/hl/hl_utils.py +411 -0
- unaiverse/modules/networks.py +1509 -0
- unaiverse/modules/utils.py +680 -0
- unaiverse/networking/__init__.py +16 -0
- unaiverse/networking/node/__init__.py +18 -0
- unaiverse/networking/node/connpool.py +1261 -0
- unaiverse/networking/node/node.py +2223 -0
- unaiverse/networking/node/profile.py +446 -0
- unaiverse/networking/node/tokens.py +79 -0
- unaiverse/networking/p2p/__init__.py +198 -0
- unaiverse/networking/p2p/go.mod +127 -0
- unaiverse/networking/p2p/go.sum +548 -0
- unaiverse/networking/p2p/golibp2p.py +18 -0
- unaiverse/networking/p2p/golibp2p.pyi +135 -0
- unaiverse/networking/p2p/lib.go +2714 -0
- unaiverse/networking/p2p/lib.go.sha256 +1 -0
- unaiverse/networking/p2p/lib_types.py +312 -0
- unaiverse/networking/p2p/message_pb2.py +63 -0
- unaiverse/networking/p2p/messages.py +265 -0
- unaiverse/networking/p2p/mylogger.py +77 -0
- unaiverse/networking/p2p/p2p.py +929 -0
- unaiverse/networking/p2p/proto-go/message.pb.go +616 -0
- unaiverse/networking/p2p/unailib.cpython-314t-darwin.so +0 -0
- unaiverse/streamlib/__init__.py +15 -0
- unaiverse/streamlib/streamlib.py +210 -0
- unaiverse/streams.py +770 -0
- unaiverse/utils/__init__.py +16 -0
- unaiverse/utils/ask_lone_wolf.json +27 -0
- unaiverse/utils/lone_wolf.json +19 -0
- unaiverse/utils/misc.py +305 -0
- unaiverse/utils/sandbox.py +293 -0
- unaiverse/utils/server.py +435 -0
- unaiverse/world.py +175 -0
- unaiverse-0.1.6.dist-info/METADATA +365 -0
- unaiverse-0.1.6.dist-info/RECORD +50 -0
- unaiverse-0.1.6.dist-info/WHEEL +6 -0
- unaiverse-0.1.6.dist-info/licenses/LICENSE +43 -0
- unaiverse-0.1.6.dist-info/top_level.txt +1 -0
@@ -0,0 +1,411 @@
"""
UnAIverse  (ASCII art banner)

A Collectionless AI Project (https://collectionless.ai)
Registration/Login: https://unaiverse.io
Code Repositories: https://github.com/collectionlessai/
Main Developers: Stefano Melacci (Project Leader), Christian Di Maio, Tommaso Guidi
"""
import torch
import torch.nn as nn
from typing import Iterable, Dict, Any


def _euler_step(o: torch.Tensor | Dict[str, torch.Tensor], do: torch.Tensor | Dict[str, torch.Tensor],
                step_size: float, decay: float | None = None, in_place: bool = False) \
        -> torch.Tensor | Dict[str, torch.Tensor] | None:
    """Euler step, vanilla case.

    Params:
        o: tensor or dict of data to update (warning: it will be updated here when in_place is True).
        do: tensor or dict of derivatives w.r.t. time of the data we want to update.
        step_size: the step size of the Euler method.
        decay: the weight-decay-like scalar coefficient that tunes the strength of the weight-decay regularization.
        in_place: whether to overwrite the input data or to return a new tensor (or dict of new tensors).

    Returns:
        A tensor or dict of Tensors with the same size of the input 'o'. It could be a new tensor or dict
        (if in_place is False) or 'o' itself, updated in-place (if in_place is True).
    """
    assert type(o) is type(do), f'Input should either be two lists or two dicts, got {type(o)} and {type(do)}.'

    if isinstance(o, dict):
        assert set(o.keys()) == set(do.keys()), 'Dictionaries should have the same keys.'
        oo = dict.fromkeys(o)
        for k in o.keys():
            if not in_place:
                if decay is None or decay == 0.:
                    oo[k] = o[k] + step_size * do[k]
                else:
                    oo[k] = (1. - decay) * o[k] + step_size * do[k]
            else:
                if decay is None or decay == 0.:
                    o[k].add_(do[k], alpha=step_size)
                else:
                    o[k].mul_(1. - decay).add_(do[k], alpha=step_size)
    elif isinstance(o, torch.Tensor):
        if not in_place:
            if decay is None or decay == 0.:
                oo = o + step_size * do
            else:
                oo = (1. - decay) * o + step_size * do
        else:
            oo = None
            if decay is None or decay == 0.:
                o.add_(do, alpha=step_size)
            else:
                o.mul_(1. - decay).add_(do, alpha=step_size)
    else:
        raise Exception(f'Input to this function should be either tensor or dict, got {type(o)}.')

    if not in_place:
        return oo
    else:
        return o
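

# Minimal usage sketch of _euler_step (illustrative only; the toy tensors below are not part
# of this module). It shows the out-of-place update o + step_size * do and the in-place
# variant with the weight-decay-like term.
def _example_euler_step() -> None:
    w = torch.zeros(3)
    dw = torch.ones(3)
    w_new = _euler_step(w, dw, step_size=0.1)                     # out-of-place: 0 + 0.1 * 1
    _euler_step(w, dw, step_size=0.1, decay=0.01, in_place=True)  # in-place: w * 0.99 + 0.1 * dw
    assert torch.allclose(w_new, torch.full((3,), 0.1))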


def _init(val: float | str, data_shape: torch.Size, device, dtype: torch.dtype, non_negative: bool = False) \
        -> torch.Tensor:
    """Initialize a tensor to a constant value, to random values, to zeros, to ones, or to an alternating pattern.

    Params:
        val: a float value or a string in ['zeros', 'random', 'ones', 'alternating'].
        data_shape: the shape of the target tensor.
        device: the device where the tensor will be stored.
        dtype: the data type of the tensor.
        non_negative: whether to create something non-negative.

    Returns:
        An initialized tensor.
    """
    assert type(val) is float or val in ['zeros', 'random', 'ones', 'alternating'], (
        'Invalid initialization: ' + str(val))

    if isinstance(val, float):
        t = torch.full(data_shape, val, device=device, dtype=dtype)
        if non_negative:
            t = torch.abs(t)
        return t
    elif val == 'random':
        t = torch.randn(data_shape, device=device, dtype=dtype)
        if non_negative:
            t = torch.abs(t)
        return t
    elif val == 'zeros':
        return torch.zeros(data_shape, device=device, dtype=dtype)
    elif val == 'ones':
        return torch.ones(data_shape, device=device, dtype=dtype)
    elif val == 'alternating':

        # Initialize the state as alternating pairs of (0, 1) (to be used with BlockSkewSymmetric)
        # data_shape is (batch_size, xi_shape)
        assert len(data_shape) == 2, (f"xi should be initialized as (batch_size, xi_shape), "
                                      f"got xi with {len(data_shape)} dimensions.")
        order = data_shape[1] // 2
        batch_size = data_shape[0]
        return torch.tensor([[0., 1.]], device=device, dtype=dtype).repeat(batch_size, order)
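

# Illustrative sketch of the supported initializations (toy shapes, not part of the module API).
# 'alternating' fills each row with (0, 1) pairs, matching the state layout expected by the
# BlockSkewSymmetric case mentioned in the comment above.
def _example_init() -> None:
    xi = _init('alternating', torch.Size((2, 4)), device='cpu', dtype=torch.float32)
    # xi -> [[0., 1., 0., 1.], [0., 1., 0., 1.]]
    c = _init(-0.5, torch.Size((2, 4)), device='cpu', dtype=torch.float32, non_negative=True)
    # c  -> a (2, 4) tensor filled with 0.5 (the sign is dropped by non_negative)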


def _init_state_and_costate(model: nn.Module, batch_size: int = 1) \
        -> Dict[str, Dict[str, torch.Tensor | Dict[str, torch.Tensor]]]:
    """Initialize the state and costate dictionaries (keys are 'xi' and 'w')."""

    # Getting device and dtype from the model parameters
    device = next(model.parameters()).device
    dtype = next(model.parameters()).dtype

    # Creating state and costate
    x = dict(xi=model.h_init, w={})
    p = dict(xi=torch.zeros_like(x['xi'], device=device, dtype=dtype), w={})

    # Initialize state and costate for the weights of the state network
    x['w'] = {par_name: par for par_name, par in dict(model.named_parameters()).items() if par.requires_grad}
    p['w'] = {par_name: torch.zeros_like(par, device=device, dtype=dtype) for par_name, par in x['w'].items()}

    return {'x': x, 'p': p}
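

# Sketch of the structure returned above, using a toy module (an assumption: the real models in
# this package expose an 'h_init' attribute holding the initial neuron state).
def _example_init_state_and_costate() -> None:
    class _Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = nn.Linear(2, 2)
            self.h_init = torch.zeros(1, 2)

    s = _init_state_and_costate(_Toy())
    # s['x']['xi'] is the neuron state, s['x']['w'] maps parameter names to the live parameters,
    # s['p'] holds zero-initialized costates with matching shapes.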


def _get_grad(a: torch.Tensor | Dict[str, torch.Tensor]) -> torch.Tensor | Dict[str, torch.Tensor]:
    """Collects gradients from a tensor or from a dict of Tensors.

    Params:
        a: tensor or dict of Tensors.

    Returns:
        Tensor or dict with the 'grad' field of each input Tensor (zeros when a 'grad' field is missing).
    """
    if isinstance(a, dict):
        g = {_k_a: a[_k_a].grad if a[_k_a].grad is not None else torch.zeros_like(a[_k_a]) for _k_a in a.keys()}
    elif isinstance(a, torch.Tensor):
        g = a.grad if a.grad is not None else torch.zeros_like(a)
    else:
        raise Exception(f'Input to this function should be either list or dict, got {type(a)}.')

    return g


def _apply1(a: torch.Tensor | Dict[str, torch.Tensor], op: torch.func) \
        -> torch.Tensor | Dict[str, torch.Tensor]:
    """Apply an operation to a tensor or to each element of a dict of Tensors.

    Params:
        a: tensor or dict of Tensors.
        op: operation to be applied to 'a' (it could be a lambda expression).

    Returns:
        Tensor or dict of Tensors with the result of the operation (same size of the input).
    """
    if isinstance(a, dict):
        oo = {_k_a: op(a[_k_a]) for _k_a in a.keys()}
    elif isinstance(a, torch.Tensor):
        oo = op(a)
    else:
        raise Exception(f'Input to this function should be either list or dict, got {type(a)}.')

    return oo


def _apply2(a: torch.Tensor | Dict[str, torch.Tensor], b: torch.Tensor | Dict[str, torch.Tensor],
            op: torch.func) -> torch.Tensor | Dict[str, torch.Tensor]:
    """Apply an operation involving each pair of elements stored into two tensors or two dicts of Tensors.

    Params:
        a: first tensor or dict of Tensors.
        b: second tensor or dict of Tensors.
        op: operation to be applied to each pair of elements (it could be a lambda expression).

    Returns:
        Tensor or dict of Tensors with the result of the operation (same size of each input).
    """
    assert type(a) is type(b), f'Type of the inputs to this function should match, got {type(a)} and {type(b)} instead.'

    if isinstance(a, dict):
        assert set(a.keys()) == set(b.keys()), 'Dictionaries should have the same keys.'
        oo = {_k_a: op(a[_k_a], b[_k_a]) for _k_a in a.keys()}
    elif isinstance(a, torch.Tensor):
        oo = op(a, b)
    else:
        raise Exception(f'Input to this function should be either list or dict, got {type(a)}.')

    return oo
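

# Illustrative sketch of _apply1/_apply2 on dicts of toy tensors (not part of the module API):
# the same lambda is broadcast over every entry, and keys must match for the binary version.
def _example_apply() -> None:
    a = {'w': torch.ones(2), 'b': torch.zeros(2)}
    b = {'w': torch.full((2,), 2.), 'b': torch.ones(2)}
    doubled = _apply1(a, lambda t: 2. * t)        # {'w': [2., 2.], 'b': [0., 0.]}
    summed = _apply2(a, b, lambda t, u: t + u)    # {'w': [3., 3.], 'b': [1., 1.]}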


def _copy_inplace(a: torch.Tensor | Dict[str, torch.Tensor], b: torch.Tensor | Dict[str, torch.Tensor],
                  detach: bool = False) -> None:
    """Copies 'in-place' the values of a tensor or dict of Tensors (b) into another one (a).

    Params:
        a: tensor or dict of Tensors (destination).
        b: another tensor or dict of Tensors, same sizes of the one above (source).
        detach: whether to detach 'b' before copying.
    """
    assert type(a) is type(b), f'Type of the inputs to this function should match, got {type(a)} and {type(b)} instead.'

    if detach:
        b = _detach(b)
    if isinstance(a, torch.Tensor):
        a.copy_(b)
    elif isinstance(a, dict):
        for _k_b in b.keys():
            a[_k_b].copy_(b[_k_b])
    else:
        raise Exception(f'Inputs to this function should be either lists or dicts, got {type(a)} and {type(b)}.')


def _copy(a: torch.Tensor | Dict[str, torch.Tensor], detach: bool = False) \
        -> torch.Tensor | Dict[str, torch.Tensor]:
    """Copies the values of a tensor or dict of Tensors into a new one.

    Params:
        a: tensor or dict of Tensors.
        detach: whether to detach 'a' before cloning.

    Returns:
        A new tensor or dict of Tensors with the same sizes (and values) of the one above.
    """
    if detach:
        a = _detach(a)
    if isinstance(a, torch.Tensor):
        b = a.clone()
    elif isinstance(a, dict):
        b = {k: v.clone() for k, v in a.items()}
    else:
        raise Exception(f'Input to this function should be either list or dict, got {type(a)}.')

    return b


def _zero_grad(tensors: torch.Tensor | Dict[str, torch.Tensor], set_to_none: bool = False) -> None:
    """Zeroes the gradient field of a tensor or dict of Tensors.

    Params:
        tensors: a Tensor or a dict of Tensors with requires_grad activated.
        set_to_none: forces all 'grad' fields to be set to None.
    """
    if isinstance(tensors, dict):
        list_ten = list(tensors.values())
    elif isinstance(tensors, torch.Tensor):
        list_ten = [tensors, ]
    else:
        raise Exception(f'Input to this function should be either list or dict, or a Tensor, got {type(tensors)}.')
    for a in list_ten:
        if a.grad is not None:
            if set_to_none:
                a.grad = None
            else:
                if a.grad.grad_fn is not None:
                    a.grad.detach_()
                else:
                    a.grad.requires_grad_(False)
                a.grad.zero_()


def _zero(tensors: torch.Tensor | Dict[str, torch.Tensor], detach: bool = False) \
        -> torch.Tensor | Dict[str, torch.Tensor]:
    """Returns a zeroed copy of a tensor or dict of Tensors.

    Params:
        tensors: a tensor or dict of Tensors.
        detach: whether to detach the input before building the zeroed copy.
    """
    if detach:
        tensors = _detach(tensors)
    if isinstance(tensors, torch.Tensor):
        b = torch.zeros_like(tensors)
    elif isinstance(tensors, dict):
        b = {k: torch.zeros_like(v) for k, v in tensors.items()}
    else:
        raise Exception(f'Input to this function should be either list or dict, got {type(tensors)}.')

    return b


def _zero_inplace(tensors: torch.Tensor | Dict[str, torch.Tensor], detach: bool = False) -> None:
    """Zeroes a tensor or dict of Tensors (in-place).

    Params:
        tensors: a tensor or dict of Tensors.
        detach: whether to detach the input before zeroing.
    """
    if detach:
        tensors = _detach(tensors)
    if isinstance(tensors, dict):
        for a in tensors.values():
            a.zero_()
    elif isinstance(tensors, torch.Tensor):
        tensors.zero_()
    else:
        raise ValueError('Unsupported type.')


def _detach(tensors: torch.Tensor | Dict[str, torch.Tensor]) -> torch.Tensor | Dict[str, torch.Tensor]:
    """Detaches a tensor or dict of Tensors (not-in-place).

    Params:
        tensors: a tensor or dict of Tensors.

    Returns:
        A tensor or dict of detached Tensors.
    """
    if isinstance(tensors, dict):
        oo = dict.fromkeys(tensors)
        for k, a in tensors.items():
            oo[k] = a.detach()
    elif isinstance(tensors, torch.Tensor):
        oo = tensors.detach()
    else:
        raise ValueError('Unsupported type.')
    return oo
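

# Combined sketch of the copy/zero/detach helpers on a dict of toy parameters (illustrative only).
# The in-place copy runs under torch.no_grad() because the destination is a leaf that requires grad.
def _example_copy_zero_detach() -> None:
    params = {'w': torch.randn(2, requires_grad=True)}
    snapshot = _copy(params, detach=True)       # detached clones of the current values
    zeros = _zero(params)                       # new zero tensors with matching shapes
    (params['w'] * 2.).sum().backward()
    _zero_grad(params, set_to_none=True)        # drop the gradients just produced
    with torch.no_grad():
        _copy_inplace(params, snapshot, detach=True)  # restore the snapshot values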


class HL:
    def __init__(self, models: torch.nn.Module | Iterable[Dict[str, torch.nn.Module | Any]], *,
                 gamma=1., flip=-1., theta=0.1, beta=1., reset_neuron_costate=False, reset_weight_costate=False,
                 local=True):
        """
        Args:
            models: a single torch.nn.Module, or a list of parameter groups (dicts), each containing:
                - 'params': the model to be optimized;
                - 'gamma', 'beta', 'theta', etc.: hyperparameters for the group, overriding the keyword defaults.
        """

        # Set defaults
        defaults = dict(params=None, gamma=gamma, flip=flip, theta=theta, beta=beta,
                        reset_neuron_costate=reset_neuron_costate, reset_weight_costate=reset_weight_costate,
                        local=local)

        # Ensure models is a list of dicts and assign the specified values
        if isinstance(models, torch.nn.Module):
            models = [{**defaults, 'params': models}]

        self.param_groups = []
        for group in models:
            assert 'params' in group, "Each parameter group must contain a 'params' key storing the model."
            self.param_groups.append({**defaults, **group})

        # Store the optimizer state for each model in a list of dicts, not to be confused with the state of the model
        self.state = [_init_state_and_costate(group['params']) for group in self.param_groups]
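
    # Parameter groups behave like the torch.optim convention: per-group entries win over the
    # keyword defaults, e.g. (illustrative call, 'model_a'/'model_b' are hypothetical names)
    # HL([{'params': model_a}, {'params': model_b, 'theta': 0.05, 'local': False}]).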

    @torch.no_grad()
    def step(self):
        """Perform one optimization step for all parameter groups."""

        for group, state in zip(self.param_groups, self.state):
            model = group['params']
            delta = model.delta

            # Copy the state (of the model) just to track it during the optimization and get the costate;
            # the locality of these operations is handled by the model
            state['x']['xi'] = model.h
            dp_xi = _get_grad(model.h)
            _euler_step(state['p']['xi'], dp_xi, step_size=-delta * group['flip'],
                        decay=-group['flip'] * group['theta'], in_place=True)

            # Copy the weights from the network just to track them during the optimization and get the costates
            dp_w = {}
            for name, param in model.named_parameters():
                state['x']['w'][name] = param
                dp_w[name] = _get_grad(param)

            if group['local']:

                # Local HL uses the old costates to update the weights
                d_w = state['p']['w']
                _euler_step(state['x']['w'], d_w, step_size=-delta * group['beta'], decay=None, in_place=True)
                _euler_step(state['p']['w'], dp_w, step_size=-delta * group['flip'],
                            decay=-group['flip'] * group['theta'], in_place=True)
            else:

                # Non-local HL updates the costates before updating the weights
                d_w = _euler_step(state['p']['w'], dp_w, step_size=-delta * group['flip'],
                                  decay=-group['flip'] * group['theta'], in_place=True)
                _euler_step(state['x']['w'], d_w, step_size=-delta * group['beta'], decay=None, in_place=True)
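
    # With the default flip = -1, the calls above amount to, for each group:
    #   p_xi <- (1 - theta) * p_xi + delta * grad_h(H)    (neuron costate)
    #   p_w  <- (1 - theta) * p_w  + delta * grad_w(H)    (weight costates)
    #   w    <- w - delta * beta * p_w                    (weights)
    # where H is the value produced by compute_hamiltonian() below and the gradients come
    # from the user's call to backward() on it.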

    def compute_hamiltonian(self, *potential_terms: torch.Tensor) -> torch.Tensor:
        """Computes the Hamiltonian for all models."""

        # The number of potential terms provided should be equal to the number of models
        assert len(potential_terms) == len(self.param_groups), "A potential term for each model is expected."

        ham = torch.tensor(0., dtype=potential_terms[0].dtype, device=potential_terms[0].device)
        for group, state, potential_term in zip(self.param_groups, self.state, potential_terms):
            model = group['params']
            ham += group['gamma'] * potential_term + torch.dot(model.dh.view(-1), state['p']['xi'].view(-1)).real
        return ham

    def zero_grad(self, set_to_none: bool = False) -> None:
        """Zeroes the gradients and resets co-states if needed."""

        for group, state in zip(self.param_groups, self.state):
            model = group['params']
            _zero_grad(model.h, set_to_none)
            for param in model.parameters():
                _zero_grad(param, set_to_none)

            # Reset the costates, if requested
            if group['reset_neuron_costate']:
                _zero_inplace(state['p']['xi'], detach=True)
            if group['reset_weight_costate']:
                _zero_inplace(state['p']['w'], detach=True)
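

# --------------------------------------------------------------------------------------------
# End-to-end usage sketch (illustrative only). HL is written for models that expose an initial
# neuron state 'h_init', the current state 'h' (receiving gradients), its time derivative 'dh'
# and an integration step 'delta'; the tiny module below is a stand-in with those attributes,
# not one of the networks shipped in unaiverse.modules.
# --------------------------------------------------------------------------------------------
class _TinyStateModel(nn.Module):

    def __init__(self, n: int = 4):
        super().__init__()
        self.w = nn.Parameter(0.1 * torch.randn(n, n))
        self.delta = 0.1
        self.h_init = torch.zeros(1, n)
        self.h = self.h_init.clone().requires_grad_(True)
        self.dh = torch.zeros_like(self.h)

    def forward(self, u: torch.Tensor) -> torch.Tensor:
        self.dh = torch.tanh(self.h @ self.w + u)                                # state derivative
        self.h = (self.h + self.delta * self.dh).detach().requires_grad_(True)   # Euler state update
        return self.h


def _example_hl_loop(steps: int = 3) -> None:
    model = _TinyStateModel()
    opt = HL(model, gamma=1., theta=0.1, beta=1., local=True)
    for _ in range(steps):
        y = model(torch.randn(1, 4))
        potential = (y ** 2).sum()                 # toy potential term, one per parameter group
        ham = opt.compute_hamiltonian(potential)   # gamma * potential + <dh, neuron costate>
        ham.backward()                             # gradients of the Hamiltonian w.r.t. h and the weights
        opt.step()                                 # Euler updates of costates and weights
        opt.zero_grad(set_to_none=True)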