unaiverse-0.1.6-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of unaiverse might be problematic.
- unaiverse/__init__.py +19 -0
- unaiverse/agent.py +2008 -0
- unaiverse/agent_basics.py +1846 -0
- unaiverse/clock.py +191 -0
- unaiverse/dataprops.py +1209 -0
- unaiverse/hsm.py +1880 -0
- unaiverse/modules/__init__.py +18 -0
- unaiverse/modules/cnu/__init__.py +17 -0
- unaiverse/modules/cnu/cnus.py +536 -0
- unaiverse/modules/cnu/layers.py +261 -0
- unaiverse/modules/cnu/psi.py +60 -0
- unaiverse/modules/hl/__init__.py +15 -0
- unaiverse/modules/hl/hl_utils.py +411 -0
- unaiverse/modules/networks.py +1509 -0
- unaiverse/modules/utils.py +680 -0
- unaiverse/networking/__init__.py +16 -0
- unaiverse/networking/node/__init__.py +18 -0
- unaiverse/networking/node/connpool.py +1261 -0
- unaiverse/networking/node/node.py +2223 -0
- unaiverse/networking/node/profile.py +446 -0
- unaiverse/networking/node/tokens.py +79 -0
- unaiverse/networking/p2p/__init__.py +198 -0
- unaiverse/networking/p2p/go.mod +127 -0
- unaiverse/networking/p2p/go.sum +548 -0
- unaiverse/networking/p2p/golibp2p.py +18 -0
- unaiverse/networking/p2p/golibp2p.pyi +135 -0
- unaiverse/networking/p2p/lib.go +2714 -0
- unaiverse/networking/p2p/lib.go.sha256 +1 -0
- unaiverse/networking/p2p/lib_types.py +312 -0
- unaiverse/networking/p2p/message_pb2.py +63 -0
- unaiverse/networking/p2p/messages.py +265 -0
- unaiverse/networking/p2p/mylogger.py +77 -0
- unaiverse/networking/p2p/p2p.py +929 -0
- unaiverse/networking/p2p/proto-go/message.pb.go +616 -0
- unaiverse/networking/p2p/unailib.cpython-311-aarch64-linux-gnu.so +0 -0
- unaiverse/streamlib/__init__.py +15 -0
- unaiverse/streamlib/streamlib.py +210 -0
- unaiverse/streams.py +770 -0
- unaiverse/utils/__init__.py +16 -0
- unaiverse/utils/ask_lone_wolf.json +27 -0
- unaiverse/utils/lone_wolf.json +19 -0
- unaiverse/utils/misc.py +305 -0
- unaiverse/utils/sandbox.py +293 -0
- unaiverse/utils/server.py +435 -0
- unaiverse/world.py +175 -0
- unaiverse-0.1.6.dist-info/METADATA +365 -0
- unaiverse-0.1.6.dist-info/RECORD +50 -0
- unaiverse-0.1.6.dist-info/WHEEL +7 -0
- unaiverse-0.1.6.dist-info/licenses/LICENSE +43 -0
- unaiverse-0.1.6.dist-info/top_level.txt +1 -0
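Since a wheel is just a zip archive, the file listing above can be verified locally. A minimal sketch using only the Python standard library (it assumes the wheel for this release has already been downloaded into the working directory):

import zipfile

# Hypothetical local check: list every file shipped in the unaiverse 0.1.6 wheel.
WHEEL = "unaiverse-0.1.6-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl"

with zipfile.ZipFile(WHEEL) as wf:
    for info in wf.infolist():
        print(info.filename, info.file_size)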
--- /dev/null
+++ unaiverse/modules/__init__.py
@@ -0,0 +1,18 @@
+"""
+█████ █████ ██████ █████ █████ █████ █████ ██████████ ███████████ █████████ ██████████
+░░███ ░░███ ░░██████ ░░███ ░░███ ░░███ ░░███ ░░███░░░░░█░░███░░░░░███ ███░░░░░███░░███░░░░░█
+░███ ░███ ░███░███ ░███ ██████ ░███ ░███ ░███ ░███ █ ░ ░███ ░███ ░███ ░░░ ░███ █ ░
+░███ ░███ ░███░░███░███ ░░░░░███ ░███ ░███ ░███ ░██████ ░██████████ ░░█████████ ░██████
+░███ ░███ ░███ ░░██████ ███████ ░███ ░░███ ███ ░███░░█ ░███░░░░░███ ░░░░░░░░███ ░███░░█
+░███ ░███ ░███ ░░█████ ███░░███ ░███ ░░░█████░ ░███ ░ █ ░███ ░███ ███ ░███ ░███ ░ █
+░░████████ █████ ░░█████░░████████ █████ ░░███ ██████████ █████ █████░░█████████ ██████████
+░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░ ░░░░░ ░░░ ░░░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░░ ░░░░░░░░░░
+A Collectionless AI Project (https://collectionless.ai)
+Registration/Login: https://unaiverse.io
+Code Repositories: https://github.com/collectionlessai/
+Main Developers: Stefano Melacci (Project Leader), Christian Di Maio, Tommaso Guidi
+"""
+from . import cnu
+from . import hl
+from . import networks
+from . import utils
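The hunk above makes unaiverse/modules a package that re-exports its four submodules. A minimal sketch of the resulting import surface (assuming the wheel is installed into the current environment):

import unaiverse.modules as modules

# The four submodules re-exported by unaiverse/modules/__init__.py:
print(modules.cnu)       # the cnu subpackage (next hunk)
print(modules.hl)        # the hl subpackage
print(modules.networks)  # network definitions (networks.py)
print(modules.utils)     # shared utilities (utils.py)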
--- /dev/null
+++ unaiverse/modules/cnu/__init__.py
@@ -0,0 +1,17 @@
+"""
+█████ █████ ██████ █████ █████ █████ █████ ██████████ ███████████ █████████ ██████████
+░░███ ░░███ ░░██████ ░░███ ░░███ ░░███ ░░███ ░░███░░░░░█░░███░░░░░███ ███░░░░░███░░███░░░░░█
+░███ ░███ ░███░███ ░███ ██████ ░███ ░███ ░███ ░███ █ ░ ░███ ░███ ░███ ░░░ ░███ █ ░
+░███ ░███ ░███░░███░███ ░░░░░███ ░███ ░███ ░███ ░██████ ░██████████ ░░█████████ ░██████
+░███ ░███ ░███ ░░██████ ███████ ░███ ░░███ ███ ░███░░█ ░███░░░░░███ ░░░░░░░░███ ░███░░█
+░███ ░███ ░███ ░░█████ ███░░███ ░███ ░░░█████░ ░███ ░ █ ░███ ░███ ███ ░███ ░███ ░ █
+░░████████ █████ ░░█████░░████████ █████ ░░███ ██████████ █████ █████░░█████████ ██████████
+░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░ ░░░░░ ░░░ ░░░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░░ ░░░░░░░░░░
+A Collectionless AI Project (https://collectionless.ai)
+Registration/Login: https://unaiverse.io
+Code Repositories: https://github.com/collectionlessai/
+Main Developers: Stefano Melacci (Project Leader), Christian Di Maio, Tommaso Guidi
+"""
+from . import cnus
+from . import layers
+from . import psi
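Likewise, unaiverse/modules/cnu/__init__.py re-exports the three modules of the cnu subpackage, which map onto the files listed at the top of this diff. A minimal sketch of reaching the class defined in the next hunk (again assuming an installed wheel):

from unaiverse.modules.cnu.cnus import CNUs
from unaiverse.modules.cnu.psi import psi

print(CNUs)  # the memory module defined in cnus.py (next hunk)
print(psi)   # the key-space projection that CNUs.compute_weights relies on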
--- /dev/null
+++ unaiverse/modules/cnu/cnus.py
@@ -0,0 +1,536 @@
+"""
+█████ █████ ██████ █████ █████ █████ █████ ██████████ ███████████ █████████ ██████████
+░░███ ░░███ ░░██████ ░░███ ░░███ ░░███ ░░███ ░░███░░░░░█░░███░░░░░███ ███░░░░░███░░███░░░░░█
+░███ ░███ ░███░███ ░███ ██████ ░███ ░███ ░███ ░███ █ ░ ░███ ░███ ░███ ░░░ ░███ █ ░
+░███ ░███ ░███░░███░███ ░░░░░███ ░███ ░███ ░███ ░██████ ░██████████ ░░█████████ ░██████
+░███ ░███ ░███ ░░██████ ███████ ░███ ░░███ ███ ░███░░█ ░███░░░░░███ ░░░░░░░░███ ░███░░█
+░███ ░███ ░███ ░░█████ ███░░███ ░███ ░░░█████░ ░███ ░ █ ░███ ░███ ███ ░███ ░███ ░ █
+░░████████ █████ ░░█████░░████████ █████ ░░███ ██████████ █████ █████░░█████████ ██████████
+░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░ ░░░░░ ░░░ ░░░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░░ ░░░░░░░░░░
+A Collectionless AI Project (https://collectionless.ai)
+Registration/Login: https://unaiverse.io
+Code Repositories: https://github.com/collectionlessai/
+Main Developers: Stefano Melacci (Project Leader), Christian Di Maio, Tommaso Guidi
+"""
+import math
+import torch
+from .psi import psi
+import torch.nn.functional as F
+
+
+class CNUs(torch.nn.Module):
+
+    def __init__(self, q=1, d=2, m=3, u=4, delta=3,
+                 gamma_alpha=0.1, tau_alpha=0.5, tau_mu=100, tau_eta=100,
+                 upd_m="WTA", upd_k="ad_hoc_WTA",
+                 beta_k=0.001,
+                 psi_fn="identity",
+                 scramble=False):
+        """
+        :param q: number of neurons
+        :param d: size of each key
+        :param m: number of keys/memory units
+        :param u: size of each memory unit
+        :param gamma_alpha: softmax temperature (key matching)
+        :param tau_alpha: threshold on the attention score of the winning key, to eventually trigger scrambling
+        :param tau_mu: number of steps below which a key is considered to be not-used enough
+        :param tau_eta: number of steps after which a key is considered old
+        :param delta: number of top attention responses to select (top-delta)
+        :param upd_m: update memory strategy (None, 'WTA')
+        :param upd_k: update key strategy (None, 'ad_hoc_WTA', 'grad_WTA')
+        :param beta_k: learning rate for key-update purposes when upd_k is 'ad_hoc_WTA'
+        :param psi_fn: function to project the neuron input onto the key space
+        :param scramble: triggers the key/memory scrambling routine when upd_k is 'ad_hoc_WTA'
+        """
+
+        super(CNUs, self).__init__()
+        assert upd_m in (None, 'WTA'), "Unknown value for upd_m, it must be None or 'WTA'"
+        assert upd_k in (None, 'ad_hoc_WTA', 'grad_WTA'), "Unknown value for upd_k, it must be " \
+                                                          "None, 'ad_hoc_WTA', or 'grad_WTA'"
+        assert upd_m is None or (upd_m == 'WTA' and upd_k is not None), \
+            "If upd_m is 'WTA', then upd_k must be ad_hoc_WTA or grad_WTA (it cannot be None)"
+        self.q = q
+        self.d = d
+        self.m = m
+        self.u = u
+        self.gamma_alpha = gamma_alpha
+        self.tau_alpha = tau_alpha
+        self.tau_mu = tau_mu
+        self.tau_eta = tau_eta
+        self.scramble = scramble
+        self.delta = min(delta, self.m)
+        self.upd_m = upd_m
+        self.upd_k = upd_k
+        self.beta_k = beta_k
+        self.psi_fn = psi_fn
+        self.debug = False  # Temporarily used
+        self.reset_memories = True
+
+        # Creating keys (self.K) and memories (self.M)
+        self.M = torch.nn.Parameter(torch.empty((self.q, self.m, self.u), dtype=torch.float32))
+        if self.upd_k == "ad_hoc_WTA":
+            self.register_buffer('K', torch.zeros((self.q, self.m, self.d)))
+        else:
+            self.K = torch.nn.Parameter(torch.empty((self.q, self.m, self.d), dtype=torch.float32))
+
+        # Buffers for ad_hoc_WTA key updates (average usefulness register buffer "mu" and age "eta")
+        if self.upd_k == "ad_hoc_WTA":
+            self.register_buffer('mu', torch.zeros(self.q, m, dtype=torch.float))
+            self.register_buffer('eta', torch.ones((self.q, m), dtype=torch.float) * self.tau_eta)
+            if self.debug:
+                self.register_buffer('key_counter', torch.zeros(self.q, m, dtype=torch.float))
+        else:
+            self.mu = None
+            self.eta = None
+
+        # Scrambling stats
+        self.register_buffer('scrambling_count', torch.zeros(self.q, dtype=torch.long))
+
+        # Initializing memories and keys
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        self.__reset_keys()
+        if self.reset_memories:
+            self.__reset_memories()
+        self.__reset_counters()
+
+    def compute_weights(self, x):
+
+        # Shortcuts (notice that "self.delta" is called "k" in shortcuts, while "self.delta-1" is called "z")
+        q, m, u, d, k = self.q, self.m, self.u, self.d, self.delta
+        b = x.shape[0]
+        M_qmu = self.M
+
+        # Ensuring keys are normalized (not needed with ad_hoc_WTA updates)
+        if self.upd_k != 'ad_hoc_WTA':
+            self.__normalize_keys()
+        else:
+            x = x.detach()  # In ad-hoc WTA, no gradient is propagated to the layers below through key-matching
+
+        # Mapping the input to the key space using the psi function
+        x_bd = psi(x, self.psi_fn, key_size=d, normalize=True)
+
+        # Finding the top responses and indices for the attention procedure
+        top_responses_bqk, top_indices_bqk = self.__top_k_attention(x_bd)
+
+        # Probabilities
+        top_alpha_bqk = torch.softmax((self.gamma_alpha / math.sqrt(d)) * top_responses_bqk, dim=2)
+
+        if self.debug:
+
+            # Getting the top-1 indices for the current mini-batch
+            top1_indices_qb = top_indices_bqk[..., 0].t()
+            self.key_counter.data.scatter_add_(dim=1,
+                                               index=top1_indices_qb,
+                                               src=torch.ones_like(top1_indices_qb, dtype=self.key_counter.dtype))
+
+        # Updating keys with the ad-hoc scheme (also refreshing top-stuff: responses, indices, alpha)
+        if self.training and self.upd_k == 'ad_hoc_WTA':
+            top_responses_bqk, top_indices_bqk, top_alpha_bqk = \
+                self.__update_keys_and_counters(x_bd, top_responses_bqk, top_indices_bqk, top_alpha_bqk)
+
+        # Reading memories and blending them
+        if self.upd_m is None:
+
+            # Preparing to read memory units and to blend them
+            M_exp_bqmu = M_qmu.view(1, q, m, u).expand(b, q, m, u)
+
+            # Getting top memory units
+            top_M_bqku = torch.gather(M_exp_bqmu, dim=2,
+                                      index=top_indices_bqk.view(b, q, k, 1).expand(b, q, k, u))
+
+            # Mixing memory units by attention scores
+            # -> top_alpha_bqk: [b,q,k], that we un-squeeze to [b,q,1,k]
+            # -> top_M_bqku: [b,q,k,u]
+            # -> W_bqu: matmul([(b,q),1,k], [(b,q),k,u]) = [b,q,1,u] that we squeeze to [b,q,u]
+            W_bqu = torch.matmul(top_alpha_bqk.view(b, q, 1, k), top_M_bqku).squeeze(2)
+
+        elif self.upd_m == 'WTA':
+
+            # Preparing to read memory units and to blend them
+            M_exp_bqmu = M_qmu.view(1, q, m, u).expand(b, q, m, u)
+
+            # Dealing with top-1 stuff
+            top1_M_exp_bq1u = torch.gather(M_exp_bqmu, dim=2,
+                                           index=top_indices_bqk[..., 0:1].view(b, q, 1, 1).expand(b, q, 1, u))
+
+            # Mixing memory units by attention scores
+            # -> top1_alpha_bqk: [b,q,k], that we select to [b,k,1] un-squeeze to [b,q,1,1]
+            # -> top1_M_exp_bq1u: [b,q,1,u]
+            # -> W_bqu: [b,q,1,1] * [b,q,1,u] = [b,q,1,u] that we squeeze to [b,q,u]
+            top1_W_bqu = (top_alpha_bqk[..., 0:1].view(b, q, 1, 1) * top1_M_exp_bq1u).squeeze(2)
+
+            # Dealing with top-2-and-following stuff
+            top2on_M_exp_bqzu = torch.gather(M_exp_bqmu.detach(), dim=2,
+                                             index=top_indices_bqk[..., 1:].view(b, q, k-1, 1).expand(b, q, k-1, u))
+            top2on_alpha_bqz = top_alpha_bqk[:, :, 1:]
+            if self.upd_k == 'grad_WTA':
+                top2on_alpha_bqz = top2on_alpha_bqz.detach()
+
+            # Mixing memory units by attention scores
+            # -> top2on_alpha_bqz: [b,q,k-1], that we un-squeeze to [b,q,1,k-1]
+            # -> top2on_M_exp_bqzu: [b,q,k-1,u]
+            # -> W_bqu: matmul([(b,q),1,k-1], [(b,q),k-1,u]) = [b,q,1,u] that we squeeze to [b,q,u]
+            top2on_W_bqu = torch.matmul(top2on_alpha_bqz.view(b, q, 1, k-1), top2on_M_exp_bqzu).squeeze(2)
+
+            # Merging top1 and top-2-and-following stuff
+            W_bqu = top1_W_bqu + top2on_W_bqu
+
+        else:
+
+            # What is going on?
+            raise NotImplementedError
+
+        return W_bqu
+
+    def forward(self, x):
+        raise NotImplementedError
+
+    @torch.no_grad()
+    def __reset_memories(self):
+        bound = 1. / math.sqrt(self.u)
+        torch.nn.init.uniform_(self.M, -bound, bound)
+
+    @torch.no_grad()
+    def __reset_keys(self):
+        bound = 1. / math.sqrt(self.d)
+        torch.nn.init.uniform_(self.K, -bound, bound)
+        self.K.data = F.normalize(self.K, p=2.0, dim=2, eps=1e-12, out=None)
+
+    def __reset_counters(self):
+        if self.mu is not None:
+            self.mu = torch.zeros_like(self.mu)
+        if self.eta is not None:
+            self.eta.fill_(self.tau_eta)
+
+    def reset_counter(self):
+        if self.debug:
+            self.key_counter.data = torch.zeros_like(self.key_counter)
+
+    def __top_k_attention(self, x_bd):
+
+        # Matmul([b,d], [d,qm]) = [b,qm], then reshaped (view) to [b,q,m]
+        responses_bqm = torch.matmul(x_bd, self.K.view(self.q * self.m, self.d).t()).view(-1, self.q, self.m)
+        top_responses_bqk, top_indices_bqk = torch.topk(responses_bqm, k=self.delta, dim=2, largest=True, sorted=True)
+        return top_responses_bqk, top_indices_bqk
+
+    @torch.no_grad()
+    def __normalize_keys(self, ids=None):
+        """
+        :param ids: None or a vector with "self.q" elements, with the indices of the keys to consider for each neuron
+        """
+        if ids is not None:
+            ids_exp_q1d = ids.view(self.q, 1, 1).expand(self.q, 1, self.d)
+            keys_q1d = torch.gather(self.K, dim=1, index=ids_exp_q1d)
+            keys_q1d = F.normalize(keys_q1d, p=2.0, dim=2, eps=1e-12, out=None)
+            self.K.scatter_(dim=1, index=ids_exp_q1d, src=keys_q1d)
+            return keys_q1d
+        else:
+            self.K.data = F.normalize(self.K, p=2.0, dim=2, eps=1e-12, out=None)
+            return self.K
+
+    def __update_keys_and_counters(self, x_bd, top_responses_bqk, top_indices_bqk, top_alphas_bqk):
+
+        # Saving some shortcuts, notice that "self.delta" is called "k" here
+        b, q, d, m, k, u = x_bd.shape[0], self.q, self.d, self.m, self.delta, self.u
+        K_qmd = self.K
+        mu_qm, eta_qm = self.mu, self.eta
+
+        # Getting the top-1 indices for the current mini-batch
+        top1_indices_qb = top_indices_bqk[..., 0].t()
+
+        # Temporarily used
+        if self.debug:
+            K_initial_qmd = self.K.clone()
+            M_initial_qmd = self.M.clone()
+            mu_initial_qm = self.mu.clone()
+            eta_initial_qm = self.eta.clone()
+            top1_indices_initial_qb = top1_indices_qb.clone()
+        else:
+            K_initial_qmd, M_initial_qmd, mu_initial_qm = None, None, None
+            eta_initial_qm, top1_indices_initial_qb = None, None
+
+        # Determining if we need to scramble keys and memories (if scrambling is enabled)
+        # (up to one key/memory per neuron, even when the batch size is greater than one)
+        if self.scramble:
+
+            # Computing a boolean mask that tells what neurons should be subject to scrambling (scramble_q),
+            # a boolean mask associated with the weak keys,
+            # and the indices of the elements of x_bd (batch elements) that should replace the
+            # scrambled keys (weak_batch_elements_q)
+            scramble_q, weak_keys_mask_qm, weak_batch_elements_q = \
+                self.__evaluate_scrambling_conditions(top_responses_bqk)
+
+            # Temporarily used
+            if self.debug:
+                self.__debug_pre_scrambling(top_indices_bqk, top_alphas_bqk, scramble_q,
+                                            weak_keys_mask_qm, weak_batch_elements_q, x_bd)
+
+            # If at least one neuron requires a scrambling operation, we do scramble! (in-place)
+            if torch.any(scramble_q):
+                self.__scramble(x_bd, scramble_q, weak_keys_mask_qm, weak_batch_elements_q, top1_indices_qb)
+
+            # Temporarily used
+            if self.debug:
+                self.__debug_changes_with_respect_to(K_initial_qmd, M_initial_qmd, mu_initial_qm, eta_initial_qm,
+                                                     top1_indices_initial_qb, top1_indices_qb,
+                                                     msg="Right after scrambling...")
+
+        # Computing variations to apply to winning keys, eventually using an adaptive learning rate
+        key_variations_qbd = (self.beta_k * x_bd).view(1, b, d).expand(q, b, d)  # Delta to add to each winning key
+
+        # Updating the winning keys (one winning key per neuron, for each batch example)
+        K_qmd.scatter_add_(dim=1,
+                           index=top1_indices_qb.view(q, b, 1).expand(q, b, d),
+                           src=key_variations_qbd)  # Adding variations to winning keys
+
+        # Recomputing or updating responses (in the case of batch size > 1, we recompute them all)
+        top_responses_bqk, top_indices_bqk, top1_indices_qb = \
+            self.__update_top_k_attention(x_bd, top_responses_bqk, top_indices_bqk, top1_indices_qb)
+
+        # Recomputing the softmax over the (now updated) top responses
+        top_alphas_bqk = torch.softmax((self.gamma_alpha / math.sqrt(d)) * top_responses_bqk, dim=2)  # [b,q,k]
+
+        # Resetting ages of winning keys
+        eta_qm.scatter_(1, top1_indices_qb, 0.)
+
+        # Updating counters: usages ("mu") for the winning keys and ages ("eta") for all the keys
+        mu_qm.scatter_add_(dim=1,
+                           index=top1_indices_qb,
+                           src=torch.ones_like(top1_indices_qb, dtype=mu_qm.dtype))  # Winning keys
+        eta_qm += b  # All the keys
+
+        # Temporarily used
+        if self.debug:
+            self.__debug_changes_with_respect_to(K_initial_qmd, M_initial_qmd, mu_initial_qm, eta_initial_qm,
+                                                 top1_indices_initial_qb, top1_indices_qb,
+                                                 msg="At the end of the whole key-and-counters update procedure...")
+
+        return top_responses_bqk, top_indices_bqk, top_alphas_bqk
+
+    def __evaluate_scrambling_conditions(self, top_responses_bqk):
+
+        # Shortcuts
+        mu_qm = self.mu
+        eta_qm = self.eta
+
+        # Computing the max of alphas (we take the smallest of them in case of batch sizes greater than one)
+        top1_responses_bq = top_responses_bqk[..., 0]  # Max of responses
+        max_of_responses_q, weak_batch_elements_q = torch.min(top1_responses_bq, dim=0)
+
+        # Finding the weak keys, if any (boolean mask)
+        weak_keys_mask_qm = torch.logical_and(mu_qm < self.tau_mu, eta_qm >= self.tau_eta)
+
+        # Determining on what neurons scrambling should be applied (boolean mask)
+        scramble_q = torch.logical_and(max_of_responses_q < self.tau_alpha, torch.any(weak_keys_mask_qm, dim=1))
+
+        return scramble_q, weak_keys_mask_qm, weak_batch_elements_q
+
+    def __scramble(self, x_bd, scramble_q, weak_keys_mask_qm, weak_batch_elements_q, top1_indices_qb):
+
+        # Shortcuts
+        q, d, u = self.q, self.d, self.u
+        K_qmd, M_qmu, mu_qm, eta_qm = self.K, self.M, self.mu, self.eta
+
+        # Stats
+        self.scrambling_count[scramble_q] += 1
+
+        # Finding the indices of the candidate keys to scramble (one per neuron)
+        scramble_candidates_keys_q = torch.max(eta_qm * weak_keys_mask_qm.to(torch.float), dim=1)[1]
+
+        # Neurons that must and must-not be subject to scrambling operations
+        scramble_q = scramble_q.to(torch.float)
+        no_scramble_q = torch.logical_not(scramble_q).to(torch.float)  # Boolean mask
+
+        # Scrambling keys
+        scramble_candidates_keys_exp_q1d = scramble_candidates_keys_q.view(q, 1, 1).expand(q, 1, d)
+        scramble_q11 = scramble_q.view(q, 1, 1)
+        no_scramble_q11 = no_scramble_q.view(q, 1, 1)
+
+        new_keys_q1d = x_bd.gather(dim=0, index=weak_batch_elements_q.view(q, 1).expand(q, d)).view(q, 1, d)
+        old_keys_q1d = K_qmd.gather(dim=1, index=scramble_candidates_keys_exp_q1d)
+
+        K_qmd.scatter_(dim=1,
+                       index=scramble_candidates_keys_exp_q1d,
+                       src=new_keys_q1d * scramble_q11 + old_keys_q1d * no_scramble_q11)
+
+        # Scrambling memories
+        with torch.no_grad():
+            scramble_candidates_keys_exp_q1u = scramble_candidates_keys_q.view(q, 1, 1).expand(q, 1, u)
+            weak_keys_q1 = top1_indices_qb.gather(dim=1, index=weak_batch_elements_q.view(q, 1))
+            weak_keys_exp_q1u = weak_keys_q1.view(q, 1, 1).expand(q, 1, u)
+
+            new_memories_q1u = M_qmu.gather(dim=1, index=weak_keys_exp_q1u)
+            old_memories_q1u = M_qmu.gather(dim=1, index=scramble_candidates_keys_exp_q1u)
+
+            M_qmu.scatter_(
+                dim=1,
+                index=scramble_candidates_keys_exp_q1u,
+                src=new_memories_q1u * scramble_q11 + old_memories_q1u * no_scramble_q11)
+
+        # Updating indices of the winning keys (in-place!)
+        weak_batch_elements_q1 = weak_batch_elements_q.view(q, 1)
+        scramble_q1 = scramble_q.view(q, 1)
+        no_scramble_q1 = no_scramble_q.view(q, 1)
+
+        new_top1_indices_q1 = scramble_candidates_keys_q.view(q, 1)
+        old_top1_indices_q1 = top1_indices_qb.gather(dim=1, index=weak_batch_elements_q1)
+
+        top1_indices_qb.scatter_(
+            dim=1, index=weak_batch_elements_q1,
+            src=(new_top1_indices_q1 * scramble_q1 + old_top1_indices_q1 * no_scramble_q1).to(torch.long))
+
+        # Resetting to zero the usage counts ("mu") for keys that were scrambled
+        scramble_candidates_keys_q1 = scramble_candidates_keys_q.view(q, 1)
+
+        new_values = 0.
+        old_values_q1 = mu_qm.gather(dim=1, index=scramble_candidates_keys_q1)
+
+        mu_qm.scatter_(
+            dim=1,
+            index=scramble_candidates_keys_q1,
+            src=new_values * scramble_q1 + old_values_q1 * no_scramble_q1)
+
+    def __update_top_k_attention(self, x_bd, top_responses_bqk, top_indices_bqk, top1_indices_qb):
+        q, d, b = self.q, self.d, x_bd.shape[0]
+
+        if b == 1:
+
+            # Normalizing the winning keys (recall that b=1 here)
+            normalized_winning_keys_q1d = self.__normalize_keys(ids=top1_indices_qb)  # Recall that b=1 here
+
+            # Updating responses (the response with the updated key is recomputed, for each neuron)
+            response_winning_bq1 = torch.matmul(x_bd, normalized_winning_keys_q1d.view(q, d).t()).view(b, q, 1)
+            top_responses_bqk[:, :, 0] = response_winning_bq1.squeeze()
+
+        elif b > 1:
+
+            # Normalizing all the keys (for simplicity - with large batch sizes it is likely easier/faster)
+            self.__normalize_keys()
+
+            # Recomputing all responses, re-determining the top-responses
+            top_responses_bqk, top_indices_bqk = self.__top_k_attention(x_bd)
+
+            # Re-transposing top-1 indices
+            top1_indices_qb = top_indices_bqk[..., 0].t()
+
+        return top_responses_bqk, top_indices_bqk, top1_indices_qb
+
+    def __str__(self):
+        s = "- q = " + str(self.q) + "\n"
+        s += "- d = " + str(self.d) + "\n"
+        s += "- m = " + str(self.m) + "\n"
+        s += "- u = " + str(self.u) + "\n"
+        s += "- delta = " + str(self.delta) + "\n"
+        s += "- gamma_alpha = " + str(self.gamma_alpha) + "\n"
+        s += "- tau_alpha = " + str(self.tau_alpha) + "\n"
+        s += "- tau_mu = " + str(self.tau_mu) + "\n"
+        s += "- tau_eta = " + str(self.tau_eta) + "\n"
+        s += "- upd_m = " + str(self.upd_m) + "\n"
+        s += "- upd_k = " + str(self.upd_k) + "\n"
+        s += "- beta_k = " + str(self.beta_k) + "\n"
+        s += "- psi_fn = " + self.psi_fn + "\n"
+        s += "- scramble = " + str(self.scramble)
+        return s
+
+    def __debug_pre_scrambling(self, top_indices_bqk, top_alphas_bqk, scramble_q,
+                               weak_keys_mask_qm, weak_batch_elements_q, x_bd):
+        b, q, k, d, m = top_indices_bqk.shape[0], self.q, self.delta, self.d, self.m
+        mu_qm, eta_qm = self.mu, self.eta
+
+        # Finding the indices of the candidate keys to scramble (one per neuron)
+        scramble_candidates_keys_q = torch.max(eta_qm * weak_keys_mask_qm.to(torch.float), dim=1)[1]
+
+        print("*** __debug_pre_scrambling ***")
+        print("No scrambling is applied at all...not yet...")
+        s = ""
+        torch.set_printoptions(profile='full', linewidth=2000)
+        for i in range(0, b):
+            if i > 0:
+                s += "\n"
+            for j in range(0, q):
+                if j > 0:
+                    s += "\n"
+                s += "[batch element " + str(i)
+                s += ", neuron " + str(j) + "] " if q > 1 else "]"
+                s += "\n- tau_alpha, tau_mu, tau_eta: " + str(self.tau_alpha) + ", " \
+                     + str(self.tau_mu) + ", " + str(self.tau_eta)
+                s += "\n- mu: " + str(mu_qm[j, :])
+                s += "\n- eta: " + str(eta_qm[j, :])
+                s += "\n- top-keys (alphas): "
+                for a in range(0, k):
+                    s += str(top_indices_bqk[i, j, a].item())
+                    s += " ({0:.3g})".format(top_alphas_bqk[i, j, a].item())
+                    if a < k - 1:
+                        s += ", "
+                s += "\n- scramble? " + str(scramble_q[j].item())
+                if scramble_q[j].item() is True:
+                    s += "\n - what key? " + str(scramble_candidates_keys_q[j].item())
+                    s += "\n - with what batch element? " + str(weak_batch_elements_q[j].item())
+                    s += "\n - data of such element? "
+                    s += str(x_bd[weak_batch_elements_q[j], 0:min(3, d)])
+                    if d > 3:
+                        s += " ~printing only the first 3 components"
+        print(s)
+        torch.set_printoptions(profile='default')
+
+    def __debug_changes_with_respect_to(self, K_initial_qmd, M_initial_qmd, mu_initial_qm, eta_initial_qm,
+                                        top1_indices_initial_qb, top1_indices_qb, msg=None):
+        q, m, b, u, d = self.q, self.m, top1_indices_initial_qb.shape[1], self.u, self.d
+
+        changed_keys_qm = torch.greater(torch.max(torch.abs(self.K - K_initial_qmd), dim=2)[0], 1e-5)
+        changed_memories_qm = torch.greater(torch.max(torch.abs(self.M - M_initial_qmd), dim=2)[0], 0.)
+        changed_top1_indices_qb = torch.greater(torch.abs(top1_indices_initial_qb.to(torch.float) -
+                                                          top1_indices_qb.to(torch.float)), 0.)
+        changed_mus_qm = torch.greater(torch.abs(self.mu - mu_initial_qm), 0.)
+        changed_etas_qm = torch.greater(torch.abs(self.eta - eta_initial_qm), 0.)
+
+        print("*** __debug_changes_with_respect_to ***")
+        if msg is not None:
+            print(msg)
+        s = ""
+        for j in range(0, q):
+            if j > 0:
+                s += "\n"
+            s += "[neuron " + str(j) + "]"
+            num_changed_keys = torch.sum(changed_keys_qm[j, :]).item()
+            num_changed_memories = torch.sum(changed_memories_qm[j, :]).item()
+            num_changed_top1_indices = torch.sum(changed_top1_indices_qb[j, :]).item()
+            num_changed_mus = torch.sum(changed_mus_qm[j, :]).item()
+            num_changed_etas = torch.sum(changed_etas_qm[j, :]).item()
+            s += "\n- #changed keys: " + str(num_changed_keys)
+            if num_changed_keys > 0:
+                if d > 3:
+                    s += " ~printing only the first 3 components"
+                for t in range(0, m):
+                    if changed_keys_qm[j, t].item() is True:
+                        s += "\n - key " + str(t) + ": " + str(self.K[j, t, 0:min(3, d)])
+                        s += " (it was: " + str(K_initial_qmd[j, t, 0:min(3, d)]) + ")"
+            s += "\n- #changed memory units: " + str(num_changed_memories)
+            if num_changed_memories > 0:
+                if u > 3:
+                    s += " ~printing only the first 3 components"
+                for t in range(0, m):
+                    if changed_memories_qm[j, t].item() is True:
+                        s += "\n - mem " + str(t) + ": " + str(self.M[j, t, 0:min(3, u)])
+                        s += " (it was: " + str(M_initial_qmd[j, t, 0:min(3, u)]) + ")"
+            s += "\n- #changed top1 indices: " + str(num_changed_top1_indices)
+            if num_changed_top1_indices > 0:
+                for i in range(0, b):
+                    if changed_top1_indices_qb[j, i].item() is True:
+                        s += "\n - batch element " + str(i) + ": " + str(top1_indices_qb[j, i].item())
+                        s += " (it was: " + str(top1_indices_initial_qb[j, i].item()) + ")"
+            s += "\n- #changed mus: " + str(num_changed_mus)
+            if num_changed_mus > 0:
+                for t in range(0, m):
+                    if changed_mus_qm[j, t].item() is True:
+                        s += "\n - mu " + str(t) + ": " + str(self.mu[j, t].item())
+                        s += " (it was: " + str(mu_initial_qm[j, t].item()) + ")"
+            s += "\n- #changed etas: " + str(num_changed_etas)
+            if num_changed_etas > 0:
+                for t in range(0, m):
+                    if changed_etas_qm[j, t].item() is True:
+                        s += "\n - eta " + str(t) + ": " + str(self.eta[j, t].item())
+                        s += " (it was: " + str(eta_initial_qm[j, t].item()) + ")"
+        print(s)
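In compute_weights above, each of the q neurons projects the input onto the key space with psi, matches it against its m keys, keeps the top-delta responses, turns them into attention scores via softmax((gamma_alpha / sqrt(d)) * top_responses), and blends the selected memory units into a [b, q, u] read. A minimal smoke test of that read path (the input width of 8 assumes the 'identity' psi_fn passes a d-sized input unchanged to the key space; psi itself lives in psi.py, outside this hunk):

import torch
from unaiverse.modules.cnu.cnus import CNUs

torch.manual_seed(0)

# q=4 neurons, m=16 keys of size d=8, memory units of size u=32, top-delta=3 read;
# upd_m=None selects the plain attention blend, scramble=False skips key scrambling.
layer = CNUs(q=4, d=8, m=16, u=32, delta=3, upd_m=None, upd_k='ad_hoc_WTA', scramble=False)
layer.train()

x = torch.randn(5, 8)         # batch of b=5 inputs
W = layer.compute_weights(x)  # blended memory read
print(W.shape)                # expected: torch.Size([5, 4, 32]) = [b, q, u]

Note that forward deliberately raises NotImplementedError: CNUs behaves as a base module, and it is up to whatever is built on top of it to decide how the blended read W is consumed.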