titans-pytorch 0.1.6__py3-none-any.whl → 0.1.8__py3-none-any.whl
This diff compares publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- titans_pytorch/titans.py +27 -16
- {titans_pytorch-0.1.6.dist-info → titans_pytorch-0.1.8.dist-info}/METADATA +1 -1
- titans_pytorch-0.1.8.dist-info/RECORD +8 -0
- titans_pytorch-0.1.6.dist-info/RECORD +0 -8
- {titans_pytorch-0.1.6.dist-info → titans_pytorch-0.1.8.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.1.6.dist-info → titans_pytorch-0.1.8.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/titans.py
CHANGED
@@ -44,6 +44,9 @@ def default(v, d):
 def identity(t):
     return t
 
+def pair(v):
+    return (v, v) if not isinstance(v, tuple) else v
+
 def round_down_multiple(seq, mult):
     return seq // mult * mult
 
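The new `pair` helper normalizes a scalar-or-tuple argument: a scalar is duplicated into a 2-tuple, a tuple passes through unchanged. A minimal sketch of the behavior (values are illustrative only):

def pair(v):
    # duplicate a scalar into a 2-tuple; leave tuples untouched
    return (v, v) if not isinstance(v, tuple) else v

assert pair(64) == (64, 64)      # scalar duplicated
assert pair((1, 64)) == (1, 64)  # tuple passes through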
@@ -161,14 +164,16 @@ class GatedResidualMemoryMLP(Module):
     def __init__(
         self,
         dim,
-        depth
+        depth,
+        expansion_factor = 2.
     ):
         super().__init__()
-
+        dim_hidden = int(dim * expansion_factor)
 
         self.weights = ParameterList([
             ParameterList([
-                Parameter(torch.randn(dim, dim)),
+                Parameter(torch.randn(dim, dim_hidden)),
+                Parameter(torch.randn(dim_hidden, dim)),
                 Parameter(torch.randn(dim * 2, dim)),
             ]) for _ in range(depth)
         ])
@@ -182,16 +187,17 @@ class GatedResidualMemoryMLP(Module):
         self,
         x
     ):
-        for weight, to_gates in self.weights:
+        for weight1, weight2, to_gates in self.weights:
             res = x
 
-            x = x @ weight
-            x = F.silu(x)
+            hidden = x @ weight1
+            hidden = F.silu(hidden)
+            branch_out = hidden @ weight2
 
             # gated residual
 
-            gates = cat((x, res), dim = -1) @ to_gates
-            x = res.lerp(x, gates.sigmoid())
+            gates = cat((branch_out, res), dim = -1) @ to_gates
+            x = res.lerp(branch_out, gates.sigmoid())
 
         return x @ self.final_proj
 
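Net effect of the two hunks above: each layer's single square weight becomes a SiLU-activated two-matrix expansion (dim → dim_hidden → dim) feeding the same gated residual. A standalone sketch of one layer in plain torch, with hypothetical sizes:

import torch
import torch.nn.functional as F
from torch import cat

dim = 64
dim_hidden = int(dim * 2.)               # expansion_factor = 2., the new default

weight1 = torch.randn(dim, dim_hidden)   # new: project up to the expanded hidden
weight2 = torch.randn(dim_hidden, dim)   # new: project back down
to_gates = torch.randn(dim * 2, dim)     # unchanged: gate over [branch, residual]

x = torch.randn(2, 16, dim)
res = x

hidden = F.silu(x @ weight1)
branch_out = hidden @ weight2

gates = cat((branch_out, res), dim = -1) @ to_gates
x = res.lerp(branch_out, gates.sigmoid())  # gated residual interpolation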
@@ -287,7 +293,7 @@ class NeuralMemory(Module):
     def __init__(
         self,
         dim,
-        chunk_size = 1,
+        chunk_size: int | tuple[int, int] = 1,
         dim_head = None,
         heads = 1,
         model: Module | None = None,
@@ -310,6 +316,8 @@ class NeuralMemory(Module):
         super().__init__()
         dim_head = default(dim_head, dim)
 
+        self.retrieve_chunk_size, self.store_chunk_size = pair(chunk_size)
+
         # norms
 
         self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
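Taken together with the new `pair` helper, `chunk_size` may now be a single int (same chunk size for retrieval and storage, as before) or a 2-tuple `(retrieve, store)`. A hypothetical instantiation, assuming the remaining constructor defaults suffice to build the memory model:

from titans_pytorch.titans import NeuralMemory

mem = NeuralMemory(dim = 384, chunk_size = 64)       # one size for both, as in 0.1.6
mem = NeuralMemory(dim = 384, chunk_size = (1, 64))  # new: retrieve per token, store per 64

assert mem.retrieve_chunk_size == 1
assert mem.store_chunk_size == 64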
@@ -377,6 +385,10 @@ class NeuralMemory(Module):
         self.empty_memory_embed = nn.Parameter(torch.zeros(dim))
         nn.init.normal_(self.empty_memory_embed, std = 0.02)
 
+        # `chunk_size` refers to chunk size used for storing to memory model weights
+
+        chunk_size = self.store_chunk_size
+
         # whether to use averaging of chunks, or attention pooling
 
         if not attn_pool_chunks:
@@ -448,11 +460,11 @@ class NeuralMemory(Module):
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]],
         return_aux_kv_loss = False
     ):
-        seq_len = seq.shape[-2]
+        seq_len, chunk_size = seq.shape[-2], self.store_chunk_size
 
         # handle edge case
 
-        if seq_len < self.chunk_size:
+        if seq_len < chunk_size:
             past_weight, _ = past_state
             return TensorDict(past_weight).clone().zero_(), self.zero
 
@@ -461,8 +473,7 @@ class NeuralMemory(Module):
         # curtail sequence by multiple of the chunk size
         # only a complete chunk of the sequence provides the memory for the next chunk
 
-
-        round_down_seq_len = round_down_multiple(seq_len, self.chunk_size)
+        round_down_seq_len = round_down_multiple(seq_len, chunk_size)
 
         seq = seq[:, :round_down_seq_len]
 
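For intuition, `round_down_multiple` curtails the stored sequence to whole chunks; with a hypothetical store chunk size of 64:

def round_down_multiple(seq, mult):
    return seq // mult * mult

assert round_down_multiple(1021, 64) == 960  # 61 trailing tokens are dropped
assert round_down_multiple(128, 64) == 128   # complete chunks are kept whole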
@@ -594,12 +605,12 @@ class NeuralMemory(Module):
         seq,
         past_weights: dict[str, Tensor] | None = None,
     ):
-        chunk_size = self.chunk_size
+        chunk_size = self.retrieve_chunk_size
         batch, seq_len = seq.shape[:2]
 
         seq = self.retrieve_norm(seq)
 
-        if seq_len < self.chunk_size:
+        if seq_len < chunk_size:
             return self.init_empty_memory_embed(batch, seq_len)
 
         seq = seq[:, (chunk_size - 1):]
@@ -671,7 +682,7 @@ class NeuralMemory(Module):
     ):
         batch, seq_len = seq.shape[:2]
 
-        if seq_len < self.chunk_size:
+        if seq_len < self.retrieve_chunk_size:
             out = self.init_empty_memory_embed(batch, seq_len)
 
             if not return_aux_kv_loss:
titans_pytorch-0.1.8.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=YowJzQC2p3fWgzFCe9jXrw1b3wWPKN-dcLs5sX-3r8Y,19123
+titans_pytorch/titans.py,sha256=qRUw-Lad_dkMqV7ASMNoGLgxYwGD-maAadetAd_qmc8,21031
+titans_pytorch-0.1.8.dist-info/METADATA,sha256=0-m6h7GERineU8N9_2cW6nCuXs96twFEwVYkHVuuuLM,4747
+titans_pytorch-0.1.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.1.8.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.1.8.dist-info/RECORD,,
titans_pytorch-0.1.6.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=YowJzQC2p3fWgzFCe9jXrw1b3wWPKN-dcLs5sX-3r8Y,19123
-titans_pytorch/titans.py,sha256=VMcPcKsoR3G13Um62Aa1HbdwrrV60ljPhP-yF40x90I,20555
-titans_pytorch-0.1.6.dist-info/METADATA,sha256=LJW26WfT9WB-0NfokLLHhcRpWnt76jwkXMt_FSTI3SM,4747
-titans_pytorch-0.1.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.1.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.1.6.dist-info/RECORD,,
{titans_pytorch-0.1.6.dist-info → titans_pytorch-0.1.8.dist-info}/WHEEL
File without changes

{titans_pytorch-0.1.6.dist-info → titans_pytorch-0.1.8.dist-info}/licenses/LICENSE
File without changes