unaiverse 0.1.6__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of unaiverse might be problematic. See the registry's report page for more details.

Files changed (50) hide show
  1. unaiverse/__init__.py +19 -0
  2. unaiverse/agent.py +2008 -0
  3. unaiverse/agent_basics.py +1846 -0
  4. unaiverse/clock.py +191 -0
  5. unaiverse/dataprops.py +1209 -0
  6. unaiverse/hsm.py +1880 -0
  7. unaiverse/modules/__init__.py +18 -0
  8. unaiverse/modules/cnu/__init__.py +17 -0
  9. unaiverse/modules/cnu/cnus.py +536 -0
  10. unaiverse/modules/cnu/layers.py +261 -0
  11. unaiverse/modules/cnu/psi.py +60 -0
  12. unaiverse/modules/hl/__init__.py +15 -0
  13. unaiverse/modules/hl/hl_utils.py +411 -0
  14. unaiverse/modules/networks.py +1509 -0
  15. unaiverse/modules/utils.py +680 -0
  16. unaiverse/networking/__init__.py +16 -0
  17. unaiverse/networking/node/__init__.py +18 -0
  18. unaiverse/networking/node/connpool.py +1261 -0
  19. unaiverse/networking/node/node.py +2223 -0
  20. unaiverse/networking/node/profile.py +446 -0
  21. unaiverse/networking/node/tokens.py +79 -0
  22. unaiverse/networking/p2p/__init__.py +198 -0
  23. unaiverse/networking/p2p/go.mod +127 -0
  24. unaiverse/networking/p2p/go.sum +548 -0
  25. unaiverse/networking/p2p/golibp2p.py +18 -0
  26. unaiverse/networking/p2p/golibp2p.pyi +135 -0
  27. unaiverse/networking/p2p/lib.go +2714 -0
  28. unaiverse/networking/p2p/lib.go.sha256 +1 -0
  29. unaiverse/networking/p2p/lib_types.py +312 -0
  30. unaiverse/networking/p2p/message_pb2.py +63 -0
  31. unaiverse/networking/p2p/messages.py +265 -0
  32. unaiverse/networking/p2p/mylogger.py +77 -0
  33. unaiverse/networking/p2p/p2p.py +929 -0
  34. unaiverse/networking/p2p/proto-go/message.pb.go +616 -0
  35. unaiverse/networking/p2p/unailib.cpython-311-darwin.so +0 -0
  36. unaiverse/streamlib/__init__.py +15 -0
  37. unaiverse/streamlib/streamlib.py +210 -0
  38. unaiverse/streams.py +770 -0
  39. unaiverse/utils/__init__.py +16 -0
  40. unaiverse/utils/ask_lone_wolf.json +27 -0
  41. unaiverse/utils/lone_wolf.json +19 -0
  42. unaiverse/utils/misc.py +305 -0
  43. unaiverse/utils/sandbox.py +293 -0
  44. unaiverse/utils/server.py +435 -0
  45. unaiverse/world.py +175 -0
  46. unaiverse-0.1.6.dist-info/METADATA +365 -0
  47. unaiverse-0.1.6.dist-info/RECORD +50 -0
  48. unaiverse-0.1.6.dist-info/WHEEL +6 -0
  49. unaiverse-0.1.6.dist-info/licenses/LICENSE +43 -0
  50. unaiverse-0.1.6.dist-info/top_level.txt +1 -0
@@ -0,0 +1,18 @@
1
+ """
2
+ █████ █████ ██████ █████ █████ █████ █████ ██████████ ███████████ █████████ ██████████
3
+ ░░███ ░░███ ░░██████ ░░███ ░░███ ░░███ ░░███ ░░███░░░░░█░░███░░░░░███ ███░░░░░███░░███░░░░░█
4
+ ░███ ░███ ░███░███ ░███ ██████ ░███ ░███ ░███ ░███ █ ░ ░███ ░███ ░███ ░░░ ░███ █ ░
5
+ ░███ ░███ ░███░░███░███ ░░░░░███ ░███ ░███ ░███ ░██████ ░██████████ ░░█████████ ░██████
6
+ ░███ ░███ ░███ ░░██████ ███████ ░███ ░░███ ███ ░███░░█ ░███░░░░░███ ░░░░░░░░███ ░███░░█
7
+ ░███ ░███ ░███ ░░█████ ███░░███ ░███ ░░░█████░ ░███ ░ █ ░███ ░███ ███ ░███ ░███ ░ █
8
+ ░░████████ █████ ░░█████░░████████ █████ ░░███ ██████████ █████ █████░░█████████ ██████████
9
+ ░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░ ░░░░░ ░░░ ░░░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░░ ░░░░░░░░░░
10
+ A Collectionless AI Project (https://collectionless.ai)
11
+ Registration/Login: https://unaiverse.io
12
+ Code Repositories: https://github.com/collectionlessai/
13
+ Main Developers: Stefano Melacci (Project Leader), Christian Di Maio, Tommaso Guidi
14
+ """
15
+ from . import cnu
16
+ from . import hl
17
+ from . import networks
18
+ from . import utils
@@ -0,0 +1,17 @@
1
+ """
2
+ █████ █████ ██████ █████ █████ █████ █████ ██████████ ███████████ █████████ ██████████
3
+ ░░███ ░░███ ░░██████ ░░███ ░░███ ░░███ ░░███ ░░███░░░░░█░░███░░░░░███ ███░░░░░███░░███░░░░░█
4
+ ░███ ░███ ░███░███ ░███ ██████ ░███ ░███ ░███ ░███ █ ░ ░███ ░███ ░███ ░░░ ░███ █ ░
5
+ ░███ ░███ ░███░░███░███ ░░░░░███ ░███ ░███ ░███ ░██████ ░██████████ ░░█████████ ░██████
6
+ ░███ ░███ ░███ ░░██████ ███████ ░███ ░░███ ███ ░███░░█ ░███░░░░░███ ░░░░░░░░███ ░███░░█
7
+ ░███ ░███ ░███ ░░█████ ███░░███ ░███ ░░░█████░ ░███ ░ █ ░███ ░███ ███ ░███ ░███ ░ █
8
+ ░░████████ █████ ░░█████░░████████ █████ ░░███ ██████████ █████ █████░░█████████ ██████████
9
+ ░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░ ░░░░░ ░░░ ░░░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░░ ░░░░░░░░░░
10
+ A Collectionless AI Project (https://collectionless.ai)
11
+ Registration/Login: https://unaiverse.io
12
+ Code Repositories: https://github.com/collectionlessai/
13
+ Main Developers: Stefano Melacci (Project Leader), Christian Di Maio, Tommaso Guidi
14
+ """
15
+ from . import cnus
16
+ from . import layers
17
+ from . import psi
@@ -0,0 +1,536 @@
1
+ """
2
+ █████ █████ ██████ █████ █████ █████ █████ ██████████ ███████████ █████████ ██████████
3
+ ░░███ ░░███ ░░██████ ░░███ ░░███ ░░███ ░░███ ░░███░░░░░█░░███░░░░░███ ███░░░░░███░░███░░░░░█
4
+ ░███ ░███ ░███░███ ░███ ██████ ░███ ░███ ░███ ░███ █ ░ ░███ ░███ ░███ ░░░ ░███ █ ░
5
+ ░███ ░███ ░███░░███░███ ░░░░░███ ░███ ░███ ░███ ░██████ ░██████████ ░░█████████ ░██████
6
+ ░███ ░███ ░███ ░░██████ ███████ ░███ ░░███ ███ ░███░░█ ░███░░░░░███ ░░░░░░░░███ ░███░░█
7
+ ░███ ░███ ░███ ░░█████ ███░░███ ░███ ░░░█████░ ░███ ░ █ ░███ ░███ ███ ░███ ░███ ░ █
8
+ ░░████████ █████ ░░█████░░████████ █████ ░░███ ██████████ █████ █████░░█████████ ██████████
9
+ ░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░ ░░░░░ ░░░ ░░░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░░ ░░░░░░░░░░
10
+ A Collectionless AI Project (https://collectionless.ai)
11
+ Registration/Login: https://unaiverse.io
12
+ Code Repositories: https://github.com/collectionlessai/
13
+ Main Developers: Stefano Melacci (Project Leader), Christian Di Maio, Tommaso Guidi
14
+ """
15
+ import math
16
+ import torch
17
+ from .psi import psi
18
+ import torch.nn.functional as F
19
+
20
+
21
class CNUs(torch.nn.Module):
    """Layer of q "continual neural units".

    Each of the q neurons owns m (key, memory-unit) pairs: keys K [q, m, d] and
    memories M [q, m, u]. The input is projected onto the key space by the psi
    function, matched against the keys by dot-product attention, and the output
    weights W [b, q, u] are a softmax-weighted blend of the top-delta memory
    units. Keys can be learned by gradient ('grad_WTA') or by an ad-hoc
    winner-take-all rule ('ad_hoc_WTA') that moves the winning key toward the
    input and optionally "scrambles" (replaces) weakly-used keys/memories.
    """

    def __init__(self, q=1, d=2, m=3, u=4, delta=3,
                 gamma_alpha=0.1, tau_alpha=0.5, tau_mu=100, tau_eta=100,
                 upd_m="WTA", upd_k="ad_hoc_WTA",
                 beta_k=0.001,
                 psi_fn="identity",
                 scramble=False):
        """
        :param q: number of neurons
        :param d: size of each key
        :param m: number of keys/memory units
        :param u: size of each memory unit
        :param gamma_alpha: softmax temperature (key matching)
        :param tau_alpha: threshold on the attention score of the winning key, to eventually trigger scrambling
        :param tau_mu: number of steps below which a key is considered to be not-used enough
        :param tau_eta: number of steps after which a key is considered old
        :param delta: number of top attention responses to select (top-delta)
        :param upd_m: update memory strategy (None, 'WTA')
        :param upd_k: update key strategy (None, 'ad_hoc_WTA', 'grad_WTA')
        :param beta_k: learning rate for key-update purposes when upd_k is 'ad_hoc_WTA'
        :param psi_fn: function to project the neuron input onto the key space
        :param scramble: triggers the key/memory scrambling routine when upd_k is 'ad_hoc_WTA'
        """

        super(CNUs, self).__init__()
        assert upd_m in (None, 'WTA'), "Unknown value for upd_m, it must be None or 'WTA'"
        assert upd_k in (None, 'ad_hoc_WTA', 'grad_WTA'), "Unknown value for upd_k, it must be " \
                                                          "None, 'ad_hoc_WTA', or 'grad_WTA'"
        assert upd_m is None or (upd_m == 'WTA' and upd_k is not None), \
            "If upd_m is 'WTA', then upd_k must be ad_hoc_WTA or grad_WTA (it cannot be None)"
        self.q = q
        self.d = d
        self.m = m
        self.u = u
        self.gamma_alpha = gamma_alpha
        self.tau_alpha = tau_alpha
        self.tau_mu = tau_mu
        self.tau_eta = tau_eta
        self.scramble = scramble
        self.delta = min(delta, self.m)  # Cannot attend to more keys than exist
        self.upd_m = upd_m
        self.upd_k = upd_k
        self.beta_k = beta_k
        self.psi_fn = psi_fn
        self.debug = False  # Temporarily used
        self.reset_memories = True  # If False, reset_parameters() keeps the current memories

        # Creating keys (self.K) and memories (self.M)
        # Memories are always learnable; keys are a plain (non-grad) buffer when
        # they are updated by the ad-hoc rule, and a Parameter otherwise.
        self.M = torch.nn.Parameter(torch.empty((self.q, self.m, self.u), dtype=torch.float32))
        if self.upd_k == "ad_hoc_WTA":
            self.register_buffer('K', torch.zeros((self.q, self.m, self.d)))
        else:
            self.K = torch.nn.Parameter(torch.empty((self.q, self.m, self.d), dtype=torch.float32))

        # Buffers for ad_hoc_WTA key updates (average usefulness register buffer "mu" and age "eta")
        # eta starts at tau_eta so that every key is immediately considered "old" (scramble-eligible).
        if self.upd_k == "ad_hoc_WTA":
            self.register_buffer('mu', torch.zeros(self.q, m, dtype=torch.float))
            self.register_buffer('eta', torch.ones((self.q, m), dtype=torch.float) * self.tau_eta)
            if self.debug:
                self.register_buffer('key_counter', torch.zeros(self.q, m, dtype=torch.float))
        else:
            self.mu = None
            self.eta = None

        # Scrambling stats (per-neuron count of scrambling events)
        self.register_buffer('scrambling_count', torch.zeros(self.q, dtype=torch.long))

        # Initializing memories and keys
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize keys (always), memories (if enabled) and counters."""
        self.__reset_keys()
        if self.reset_memories:
            self.__reset_memories()
        self.__reset_counters()

    def compute_weights(self, x):
        """Match input x [b, input_size] to keys and blend memories into W [b, q, u].

        During training with upd_k == 'ad_hoc_WTA' this also updates keys and
        the usage/age counters (and possibly scrambles keys/memories) in place.
        """

        # Shortcuts (notice that "self.delta" is called "k" in shortcuts, while "self.delta-1" is called "z")
        q, m, u, d, k = self.q, self.m, self.u, self.d, self.delta
        b = x.shape[0]
        M_qmu = self.M

        # Ensuring keys are normalized (not needed with ad_hoc_WTA updates)
        if self.upd_k != 'ad_hoc_WTA':
            self.__normalize_keys()
        else:
            x = x.detach()  # In ad-hoc WTA, no gradient is propagated to the layers below through key-matching

        # Mapping the input to the key space using the psi function
        x_bd = psi(x, self.psi_fn, key_size=d, normalize=True)

        # Finding the top responses and indices for the attention procedure
        top_responses_bqk, top_indices_bqk = self.__top_k_attention(x_bd)

        # Probabilities (softmax over the top-delta responses, scaled by gamma_alpha / sqrt(d))
        top_alpha_bqk = torch.softmax((self.gamma_alpha / math.sqrt(d)) * top_responses_bqk, dim=2)

        if self.debug:

            # Getting the top-1 indices for the current mini-batch
            top1_indices_qb = top_indices_bqk[..., 0].t()
            self.key_counter.data.scatter_add_(dim=1,
                                               index=top1_indices_qb,
                                               src=torch.ones_like(top1_indices_qb, dtype=self.key_counter.dtype))

        # Updating keys with the ad-hoc scheme (also refreshing top-stuff: responses, indices, alpha)
        if self.training and self.upd_k == 'ad_hoc_WTA':
            top_responses_bqk, top_indices_bqk, top_alpha_bqk = \
                self.__update_keys_and_counters(x_bd, top_responses_bqk, top_indices_bqk, top_alpha_bqk)

        # Reading memories and blending them
        if self.upd_m is None:

            # Preparing to read memory units and to blend them
            M_exp_bqmu = M_qmu.view(1, q, m, u).expand(b, q, m, u)

            # Getting top memory units
            top_M_bqku = torch.gather(M_exp_bqmu, dim=2,
                                      index=top_indices_bqk.view(b, q, k, 1).expand(b, q, k, u))

            # Mixing memory units by attention scores
            # -> top_alpha_bqk: [b,q,k], that we un-squeeze to [b,q,1,k]
            # -> top_M_bqku: [b,q,k,u]
            # -> W_bqu: matmul([(b,q),1,k], [(b,q),k,u]) = [b,q,1,u] that we squeeze to [b,q,u]
            W_bqu = torch.matmul(top_alpha_bqk.view(b, q, 1, k), top_M_bqku).squeeze(2)

        elif self.upd_m == 'WTA':

            # Preparing to read memory units and to blend them
            M_exp_bqmu = M_qmu.view(1, q, m, u).expand(b, q, m, u)

            # Dealing with top-1 stuff (only the winning memory keeps its gradient path)
            top1_M_exp_bq1u = torch.gather(M_exp_bqmu, dim=2,
                                           index=top_indices_bqk[..., 0:1].view(b, q, 1, 1).expand(b, q, 1, u))

            # Mixing memory units by attention scores
            # -> top1_alpha_bqk: [b,q,k], that we select to [b,k,1] un-squeeze to [b,q,1,1]
            # -> top1_M_exp_bq1u: [b,q,1,u]
            # -> W_bqu: [b,q,1,1] * [b,q,1,u] = [b,q,1,u] that we squeeze to [b,q,u]
            top1_W_bqu = (top_alpha_bqk[..., 0:1].view(b, q, 1, 1) * top1_M_exp_bq1u).squeeze(2)

            # Dealing with top-2-and-following stuff (detached: WTA trains only the winner)
            top2on_M_exp_bqzu = torch.gather(M_exp_bqmu.detach(), dim=2,
                                             index=top_indices_bqk[..., 1:].view(b, q, k-1, 1).expand(b, q, k-1, u))
            top2on_alpha_bqz = top_alpha_bqk[:, :, 1:]
            if self.upd_k == 'grad_WTA':
                top2on_alpha_bqz = top2on_alpha_bqz.detach()

            # Mixing memory units by attention scores
            # -> top2on_alpha_bqz: [b,q,k-1], that we un-squeeze to [b,q,1,k-1]
            # -> top2on_M_exp_bqzu: [b,q,k-1,u]
            # -> W_bqu: matmul([(b,q),1,k-1], [(b,q),k-1,u]) = [b,q,1,u] that we squeeze to [b,q,u]
            top2on_W_bqu = torch.matmul(top2on_alpha_bqz.view(b, q, 1, k-1), top2on_M_exp_bqzu).squeeze(2)

            # Merging top1 and top-2-and-following stuff
            W_bqu = top1_W_bqu + top2on_W_bqu

        else:

            # What is going on?
            raise NotImplementedError

        return W_bqu

    def forward(self, x):
        # Subclasses are expected to define the actual forward pass
        raise NotImplementedError

    @torch.no_grad()
    def __reset_memories(self):
        """Uniformly initialize M in [-1/sqrt(u), 1/sqrt(u)]."""
        bound = 1. / math.sqrt(self.u)
        torch.nn.init.uniform_(self.M, -bound, bound)

    @torch.no_grad()
    def __reset_keys(self):
        """Uniformly initialize K, then L2-normalize each key."""
        bound = 1. / math.sqrt(self.d)
        torch.nn.init.uniform_(self.K, -bound, bound)
        self.K.data = F.normalize(self.K, p=2.0, dim=2, eps=1e-12, out=None)

    def __reset_counters(self):
        """Zero the usage counters and reset ages to tau_eta (keys start 'old')."""
        if self.mu is not None:
            # NOTE(review): this rebinds self.mu (shadowing the registered buffer)
            # while eta below is filled in place — confirm this asymmetry is intended
            self.mu = torch.zeros_like(self.mu)
        if self.eta is not None:
            self.eta.fill_(self.tau_eta)

    def reset_counter(self):
        """Zero the debug-only per-key hit counter (no-op unless self.debug)."""
        if self.debug:
            self.key_counter.data = torch.zeros_like(self.key_counter)

    def __top_k_attention(self, x_bd):
        """Return the top-delta key responses and their indices, shapes [b, q, delta]."""

        # Matmul([b,d], [d,qm]) = [b,qm], then reshaped (view) to [b,q,m]
        responses_bqm = torch.matmul(x_bd, self.K.view(self.q * self.m, self.d).t()).view(-1, self.q, self.m)
        top_responses_bqk, top_indices_bqk = torch.topk(responses_bqm, k=self.delta, dim=2, largest=True, sorted=True)
        return top_responses_bqk, top_indices_bqk

    @torch.no_grad()
    def __normalize_keys(self, ids=None):
        """L2-normalize keys in place; returns the normalized keys.

        :param ids: None or a vector with "self.q" elements, with the indices of the keys to consider for each neuron
        """
        if ids is not None:
            # Normalize only one selected key per neuron
            ids_exp_q1d = ids.view(self.q, 1, 1).expand(self.q, 1, self.d)
            keys_q1d = torch.gather(self.K, dim=1, index=ids_exp_q1d)
            keys_q1d = F.normalize(keys_q1d, p=2.0, dim=2, eps=1e-12, out=None)
            self.K.scatter_(dim=1, index=ids_exp_q1d, src=keys_q1d)
            return keys_q1d
        else:
            # Normalize all keys of all neurons
            self.K.data = F.normalize(self.K, p=2.0, dim=2, eps=1e-12, out=None)
            return self.K

    def __update_keys_and_counters(self, x_bd, top_responses_bqk, top_indices_bqk, top_alphas_bqk):
        """Ad-hoc WTA key update: optionally scramble, move winning keys toward
        the input, refresh the top-delta attention data, and update mu/eta.

        Returns the refreshed (top_responses, top_indices, top_alphas)."""

        # Saving some shortcuts, notice that "self.delta" is called "k" here
        b, q, d, m, k, u = x_bd.shape[0], self.q, self.d, self.m, self.delta, self.u
        K_qmd = self.K
        mu_qm, eta_qm = self.mu, self.eta

        # Getting the top-1 indices for the current mini-batch
        top1_indices_qb = top_indices_bqk[..., 0].t()

        # Temporarily used
        if self.debug:
            K_initial_qmd = self.K.clone()
            M_initial_qmd = self.M.clone()
            mu_initial_qm = self.mu.clone()
            eta_initial_qm = self.eta.clone()
            top1_indices_initial_qb = top1_indices_qb.clone()
        else:
            K_initial_qmd, M_initial_qmd, mu_initial_qm = None, None, None
            eta_initial_qm, top1_indices_initial_qb = None, None

        # Determining if we need to scramble keys and memories (if scrambling is enabled)
        # (up to one key/memory per neuron, even when the batch size is greater than one)
        if self.scramble:

            # Computing a boolean mask that tells what neurons should be subject to scrambling (scramble_q),
            # a boolean mask associated with the weak keys,
            # and the indices of the elements of x_bd (batch elements) that should replace the
            # scrambled keys (weak_batch_elements_q)
            scramble_q, weak_keys_mask_qm, weak_batch_elements_q = \
                self.__evaluate_scrambling_conditions(top_responses_bqk)

            # Temporarily used
            if self.debug:
                self.__debug_pre_scrambling(top_indices_bqk, top_alphas_bqk, scramble_q,
                                            weak_keys_mask_qm, weak_batch_elements_q, x_bd)

            # If at least one neuron requires a scrambling operation, we do scramble! (in-place)
            if torch.any(scramble_q):
                self.__scramble(x_bd, scramble_q, weak_keys_mask_qm, weak_batch_elements_q, top1_indices_qb)

                # Temporarily used
                if self.debug:
                    self.__debug_changes_with_respect_to(K_initial_qmd, M_initial_qmd, mu_initial_qm, eta_initial_qm,
                                                         top1_indices_initial_qb, top1_indices_qb,
                                                         msg="Right after scrambling...")

        # Computing variations to apply to winning keys, eventually using an adaptive learning rate
        key_variations_qbd = (self.beta_k * x_bd).view(1, b, d).expand(q, b, d)  # Delta to add to each winning key

        # Updating the winning keys (one winning key per neuron, for each batch example)
        K_qmd.scatter_add_(dim=1,
                           index=top1_indices_qb.view(q, b, 1).expand(q, b, d),
                           src=key_variations_qbd)  # Adding variations to winning keys

        # Recomputing or updating responses (in the case of batch size > 1, we recompute them all)
        top_responses_bqk, top_indices_bqk, top1_indices_qb = \
            self.__update_top_k_attention(x_bd, top_responses_bqk, top_indices_bqk, top1_indices_qb)

        # Recomputing the softmax over the (now updated) top responses
        top_alphas_bqk = torch.softmax((self.gamma_alpha / math.sqrt(d)) * top_responses_bqk, dim=2)  # [b,q,k]

        # Resetting ages of winning keys
        eta_qm.scatter_(1, top1_indices_qb, 0.)

        # Updating counters: usages ("mu") for the winning keys and ages ("eta") for all the keys
        mu_qm.scatter_add_(dim=1,
                           index=top1_indices_qb,
                           src=torch.ones_like(top1_indices_qb, dtype=mu_qm.dtype))  # Winning keys
        eta_qm += b  # All the keys

        # Temporarily used
        if self.debug:
            self.__debug_changes_with_respect_to(K_initial_qmd, M_initial_qmd, mu_initial_qm, eta_initial_qm,
                                                 top1_indices_initial_qb, top1_indices_qb,
                                                 msg="At the end of the whole key-and-counters update procedure...")

        return top_responses_bqk, top_indices_bqk, top_alphas_bqk

    def __evaluate_scrambling_conditions(self, top_responses_bqk):
        """Decide, per neuron, whether scrambling should happen.

        Returns (scramble_q boolean mask, weak_keys_mask_qm boolean mask,
        weak_batch_elements_q indices of the batch elements to use as new keys)."""

        # Shortcuts
        mu_qm = self.mu
        eta_qm = self.eta

        # Computing the max of alphas (we take the smallest of them in case of batch sizes greater than one)
        top1_responses_bq = top_responses_bqk[..., 0]  # Max of responses
        max_of_responses_q, weak_batch_elements_q = torch.min(top1_responses_bq, dim=0)

        # Finding the weak keys, if any (boolean mask): under-used AND old
        weak_keys_mask_qm = torch.logical_and(mu_qm < self.tau_mu, eta_qm >= self.tau_eta)

        # Determining on what neurons scrambling should be applied (boolean mask):
        # the best response is too weak AND at least one weak key exists
        scramble_q = torch.logical_and(max_of_responses_q < self.tau_alpha, torch.any(weak_keys_mask_qm, dim=1))

        return scramble_q, weak_keys_mask_qm, weak_batch_elements_q

    def __scramble(self, x_bd, scramble_q, weak_keys_mask_qm, weak_batch_elements_q, top1_indices_qb):
        """Replace (in place) one weak key per scrambling neuron with a batch
        sample, move the winner's memory onto it, and reset its usage count.
        Also updates top1_indices_qb in place."""

        # Shortcuts
        q, d, u = self.q, self.d, self.u
        K_qmd, M_qmu, mu_qm, eta_qm = self.K, self.M, self.mu, self.eta

        # Stats
        self.scrambling_count[scramble_q] += 1

        # Finding the indices of the candidate keys to scramble (one per neuron): the oldest weak key
        scramble_candidates_keys_q = torch.max(eta_qm * weak_keys_mask_qm.to(torch.float), dim=1)[1]

        # Neurons that must and must-not be subject to scrambling operations
        # (float masks so they can be used as multiplicative blend factors below)
        scramble_q = scramble_q.to(torch.float)
        no_scramble_q = torch.logical_not(scramble_q).to(torch.float)  # Boolean mask

        # Scrambling keys
        scramble_candidates_keys_exp_q1d = scramble_candidates_keys_q.view(q, 1, 1).expand(q, 1, d)
        scramble_q11 = scramble_q.view(q, 1, 1)
        no_scramble_q11 = no_scramble_q.view(q, 1, 1)

        new_keys_q1d = x_bd.gather(dim=0, index=weak_batch_elements_q.view(q, 1).expand(q, d)).view(q, 1, d)
        old_keys_q1d = K_qmd.gather(dim=1, index=scramble_candidates_keys_exp_q1d)

        # Blend: scrambling neurons take the new key, the others keep the old one
        K_qmd.scatter_(dim=1,
                       index=scramble_candidates_keys_exp_q1d,
                       src=new_keys_q1d * scramble_q11 + old_keys_q1d * no_scramble_q11)

        # Scrambling memories (M is a Parameter, so this must bypass autograd)
        with torch.no_grad():
            scramble_candidates_keys_exp_q1u = scramble_candidates_keys_q.view(q, 1, 1).expand(q, 1, u)
            weak_keys_q1 = top1_indices_qb.gather(dim=1, index=weak_batch_elements_q.view(q, 1))
            weak_keys_exp_q1u = weak_keys_q1.view(q, 1, 1).expand(q, 1, u)

            new_memories_q1u = M_qmu.gather(dim=1, index=weak_keys_exp_q1u)
            old_memories_q1u = M_qmu.gather(dim=1, index=scramble_candidates_keys_exp_q1u)

            M_qmu.scatter_(
                dim=1,
                index=scramble_candidates_keys_exp_q1u,
                src=new_memories_q1u * scramble_q11 + old_memories_q1u * no_scramble_q11)

        # Updating indices of the winning keys (in-place!)
        weak_batch_elements_q1 = weak_batch_elements_q.view(q, 1)
        scramble_q1 = scramble_q.view(q, 1)
        no_scramble_q1 = no_scramble_q.view(q, 1)

        new_top1_indices_q1 = scramble_candidates_keys_q.view(q, 1)
        old_top1_indices_q1 = top1_indices_qb.gather(dim=1, index=weak_batch_elements_q1)

        top1_indices_qb.scatter_(
            dim=1, index=weak_batch_elements_q1,
            src=(new_top1_indices_q1 * scramble_q1 + old_top1_indices_q1 * no_scramble_q1).to(torch.long))

        # Resetting to zero the usage counts ("mu") for keys that were scrambled
        scramble_candidates_keys_q1 = scramble_candidates_keys_q.view(q, 1)

        new_values = 0.
        old_values_q1 = mu_qm.gather(dim=1, index=scramble_candidates_keys_q1)

        mu_qm.scatter_(
            dim=1,
            index=scramble_candidates_keys_q1,
            src=new_values * scramble_q1 + old_values_q1 * no_scramble_q1)

    def __update_top_k_attention(self, x_bd, top_responses_bqk, top_indices_bqk, top1_indices_qb):
        """Refresh the top-delta attention data after the keys have changed.

        For b == 1 only the winning response is recomputed; for b > 1 all keys
        are re-normalized and the full top-k attention is recomputed."""
        q, d, b = self.q, self.d, x_bd.shape[0]

        if b == 1:

            # Normalizing the winning keys (recall that b=1 here)
            normalized_winning_keys_q1d = self.__normalize_keys(ids=top1_indices_qb)  # Recall that b=1 here

            # Updating responses (the response with the updated key is recomputed, for each neuron)
            response_winning_bq1 = torch.matmul(x_bd, normalized_winning_keys_q1d.view(q, d).t()).view(b, q, 1)
            top_responses_bqk[:, :, 0] = response_winning_bq1.squeeze()

        elif b > 1:

            # Normalizing all the keys (for simplicity - with large batch sizes it is likely easier/faster)
            self.__normalize_keys()

            # Recomputing all responses, re-determining the top-responses
            top_responses_bqk, top_indices_bqk = self.__top_k_attention(x_bd)

            # Re-transposing top-1 indices
            top1_indices_qb = top_indices_bqk[..., 0].t()

        return top_responses_bqk, top_indices_bqk, top1_indices_qb

    def __str__(self):
        # Human-readable dump of the hyper-parameters
        s = "- q = " + str(self.q) + "\n"
        s += "- d = " + str(self.d) + "\n"
        s += "- m = " + str(self.m) + "\n"
        s += "- u = " + str(self.u) + "\n"
        s += "- delta = " + str(self.delta) + "\n"
        s += "- gamma_alpha = " + str(self.gamma_alpha) + "\n"
        s += "- tau_alpha = " + str(self.tau_alpha) + "\n"
        s += "- tau_mu = " + str(self.tau_mu) + "\n"
        s += "- tau_eta = " + str(self.tau_eta) + "\n"
        s += "- upd_m = " + str(self.upd_m) + "\n"
        s += "- upd_k = " + str(self.upd_k) + "\n"
        s += "- beta_k = " + str(self.beta_k) + "\n"
        s += "- psi_fn = " + self.psi_fn + "\n"
        s += "- scramble = " + str(self.scramble)
        return s

    def __debug_pre_scrambling(self, top_indices_bqk, top_alphas_bqk, scramble_q,
                               weak_keys_mask_qm, weak_batch_elements_q, x_bd):
        """Debug-only: print per-(batch element, neuron) scrambling conditions
        BEFORE any scrambling is actually applied."""
        b, q, k, d, m = top_indices_bqk.shape[0], self.q, self.delta, self.d, self.m
        mu_qm, eta_qm = self.mu, self.eta

        # Finding the indices of the candidate keys to scramble (one per neuron)
        scramble_candidates_keys_q = torch.max(eta_qm * weak_keys_mask_qm.to(torch.float), dim=1)[1]

        print("*** __debug_pre_scrambling ***")
        print("No scrambling is applied at all...not yet...")
        s = ""
        torch.set_printoptions(profile='full', linewidth=2000)
        for i in range(0, b):
            if i > 0:
                s += "\n"
            for j in range(0, q):
                if j > 0:
                    s += "\n"
                s += "[batch element " + str(i)
                s += ", neuron " + str(j) + "] " if q > 1 else "]"
                s += "\n- tau_alpha, tau_mu, tau_eta: " + str(self.tau_alpha) + ", " \
                     + str(self.tau_mu) + ", " + str(self.tau_eta)
                s += "\n- mu: " + str(mu_qm[j, :])
                s += "\n- eta: " + str(eta_qm[j, :])
                s += "\n- top-keys (alphas): "
                for a in range(0, k):
                    s += str(top_indices_bqk[i, j, a].item())
                    s += " ({0:.3g})".format(top_alphas_bqk[i, j, a].item())
                    if a < k - 1:
                        s += ", "
                s += "\n- scramble? " + str(scramble_q[j].item())
                if scramble_q[j].item() is True:
                    s += "\n - what key? " + str(scramble_candidates_keys_q[j].item())
                    s += "\n - with what batch element? " + str(weak_batch_elements_q[j].item())
                    s += "\n - data of such element? "
                    s += str(x_bd[weak_batch_elements_q[j], 0:min(3, d)])
                    if d > 3:
                        s += " ~printing only the first 3 components"
        print(s)
        torch.set_printoptions(profile='default')

    def __debug_changes_with_respect_to(self, K_initial_qmd, M_initial_qmd, mu_initial_qm, eta_initial_qm,
                                        top1_indices_initial_qb, top1_indices_qb, msg=None):
        """Debug-only: print, per neuron, what changed in K, M, top-1 indices,
        mu and eta with respect to the provided initial snapshots."""
        q, m, b, u, d = self.q, self.m, top1_indices_initial_qb.shape[1], self.u, self.d

        # Boolean change masks (keys use a small tolerance; the rest use exact comparison)
        changed_keys_qm = torch.greater(torch.max(torch.abs(self.K - K_initial_qmd), dim=2)[0], 1e-5)
        changed_memories_qm = torch.greater(torch.max(torch.abs(self.M - M_initial_qmd), dim=2)[0], 0.)
        changed_top1_indices_qb = torch.greater(torch.abs(top1_indices_initial_qb.to(torch.float) -
                                                          top1_indices_qb.to(torch.float)), 0.)
        changed_mus_qm = torch.greater(torch.abs(self.mu - mu_initial_qm), 0.)
        changed_etas_qm = torch.greater(torch.abs(self.eta - eta_initial_qm), 0.)

        print("*** __debug_changes_with_respect_to ***")
        if msg is not None:
            print(msg)
        s = ""
        for j in range(0, q):
            if j > 0:
                s += "\n"
            s += "[neuron " + str(j) + "]"
            num_changed_keys = torch.sum(changed_keys_qm[j, :]).item()
            num_changed_memories = torch.sum(changed_memories_qm[j, :]).item()
            num_changed_top1_indices = torch.sum(changed_top1_indices_qb[j, :]).item()
            num_changed_mus = torch.sum(changed_mus_qm[j, :]).item()
            num_changed_etas = torch.sum(changed_etas_qm[j, :]).item()
            s += "\n- #changed keys: " + str(num_changed_keys)
            if num_changed_keys > 0:
                if d > 3:
                    s += " ~printing only the first 3 components"
                for t in range(0, m):
                    if changed_keys_qm[j, t].item() is True:
                        s += "\n - key " + str(t) + ": " + str(self.K[j, t, 0:min(3, d)])
                        s += " (it was: " + str(K_initial_qmd[j, t, 0:min(3, d)]) + ")"
            s += "\n- #changed memory units: " + str(num_changed_memories)
            if num_changed_memories > 0:
                if u > 3:
                    s += " ~printing only the first 3 components"
                for t in range(0, m):
                    if changed_memories_qm[j, t].item() is True:
                        s += "\n - mem " + str(t) + ": " + str(self.M[j, t, 0:min(3, u)])
                        s += " (it was: " + str(M_initial_qmd[j, t, 0:min(3, u)]) + ")"
            s += "\n- #changed top1 indices: " + str(num_changed_top1_indices)
            if num_changed_top1_indices > 0:
                for i in range(0, b):
                    if changed_top1_indices_qb[j, i].item() is True:
                        s += "\n - batch element " + str(i) + ": " + str(top1_indices_qb[j, i].item())
                        s += " (it was: " + str(top1_indices_initial_qb[j, i].item()) + ")"
            s += "\n- #changed mus: " + str(num_changed_mus)
            if num_changed_mus > 0:
                for t in range(0, m):
                    if changed_mus_qm[j, t].item() is True:
                        s += "\n - mu " + str(t) + ": " + str(self.mu[j, t].item())
                        s += " (it was: " + str(mu_initial_qm[j, t].item()) + ")"
            s += "\n- #changed etas: " + str(num_changed_etas)
            if num_changed_etas > 0:
                for t in range(0, m):
                    if changed_etas_qm[j, t].item() is True:
                        s += "\n - eta " + str(t) + ": " + str(self.eta[j, t].item())
                        s += " (it was: " + str(eta_initial_qm[j, t].item()) + ")"
        print(s)