titans-pytorch 0.4.6__py3-none-any.whl → 0.4.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,7 +15,7 @@ from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
15
15
 
16
16
  from tensordict import TensorDict
17
17
 
18
- from titans_pytorch.associative_scan import AssocScan
18
+ from assoc_scan import AssocScan
19
19
 
20
20
  from titans_pytorch.memory_models import(
21
21
  MemoryMLP,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.4.6
3
+ Version: 0.4.7
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -34,7 +34,7 @@ Classifier: License :: OSI Approved :: MIT License
34
34
  Classifier: Programming Language :: Python :: 3.9
35
35
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
36
36
  Requires-Python: >=3.9
37
- Requires-Dist: accelerated-scan>=0.2.0
37
+ Requires-Dist: assoc-scan
38
38
  Requires-Dist: axial-positional-embedding>=0.3.10
39
39
  Requires-Dist: einops>=0.8.0
40
40
  Requires-Dist: einx>=0.3.0
@@ -0,0 +1,8 @@
1
+ titans_pytorch/__init__.py,sha256=sVTOuRUkaIYabFExdLY6s1qXm1UwHHz_J19H8ZV-X74,338
2
+ titans_pytorch/mac_transformer.py,sha256=tz72141G5t3AOnxSVsOLtLptGtl8T7zROUvaTw2_XCY,26960
3
+ titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
4
+ titans_pytorch/neural_memory.py,sha256=EhHptv-9q3PUTJwX9kKAdYMfWueM-JB_kZ3SmRoAdjM,33356
5
+ titans_pytorch-0.4.7.dist-info/METADATA,sha256=MP0qHzoAM0AZuWg0gL2VOnmpx9HXdHwo5xx2CL0ugso,6797
6
+ titans_pytorch-0.4.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
+ titans_pytorch-0.4.7.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ titans_pytorch-0.4.7.dist-info/RECORD,,
@@ -1,178 +0,0 @@
1
- from __future__ import annotations
2
- import math
3
- from typing import Callable
4
-
5
- import torch
6
- from torch import Tensor
7
- from torch.nn import Module
8
- import torch.nn.functional as F
9
-
10
- from einops import rearrange, repeat, reduce, pack, unpack
11
-
12
- # taken from S5-pytorch repository
13
- # https://github.com/i404788/s5-pytorch/blob/74e2fdae00b915a62c914bf3615c0b8a4279eb84/s5/jax_compat.py#L51-L134
14
-
15
- # helper functions
16
-
17
- def exists(v):
18
- return v is not None
19
-
20
- def default(*args):
21
- for arg in args:
22
- if exists(arg):
23
- return arg
24
- return None
25
-
26
- def pad_at_dim(t, pad, dim = -1, value = 0.):
27
- dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
28
- zeros = ((0, 0) * dims_from_right)
29
- return F.pad(t, (*zeros, *pad), value = value)
30
-
31
- def pack_one_with_inverse(t, pattern):
32
- packed, packed_shape = pack([t], pattern)
33
-
34
- def inverse(out, inv_pattern = None):
35
- inv_pattern = default(inv_pattern, pattern)
36
- return unpack(out, packed_shape, inv_pattern)[0]
37
-
38
- return packed, inverse
39
-
40
- # the operator that is needed
41
-
42
- @torch.jit.script
43
- def binary_operator(
44
- a: tuple[Tensor, Tensor],
45
- b: tuple[Tensor, Tensor]
46
- ):
47
- a_i, kv_i = a
48
- a_j, kv_j = b
49
- return a_j * a_i, torch.addcmul(kv_j, a_j, kv_i)
50
-
51
- # Pytorch impl. of jax.lax.associative_scan
52
- # made specifically for axis of 1 (sequence of tokens for autoregressive modeling)
53
-
54
- def associative_scan(
55
- operator: Callable,
56
- elems: tuple[Tensor, Tensor]
57
- ):
58
- num_elems = int(elems[0].shape[1])
59
-
60
- if not all(int(elem.shape[1]) == num_elems for elem in elems[1:]):
61
- raise ValueError('Array inputs to associative_scan must have the same '
62
- 'first dimension. (saw: {})'
63
- .format([elem.shape for elem in elems]))
64
-
65
- def _scan(elems):
66
- """Perform scan on `elems`."""
67
- num_elems = elems[0].shape[1]
68
-
69
- if num_elems < 2:
70
- return elems
71
-
72
- # Combine adjacent pairs of elements.
73
-
74
- reduced_elems = operator(
75
- [elem[:, :-1:2] for elem in elems],
76
- [elem[:, 1::2] for elem in elems])
77
-
78
- # Recursively compute scan for partially reduced tensors.
79
-
80
- odd_elems = _scan(reduced_elems)
81
-
82
- if num_elems % 2 == 0:
83
- even_elems = operator(
84
- [e[:, :-1] for e in odd_elems],
85
- [e[:, 2::2] for e in elems])
86
- else:
87
- even_elems = operator(
88
- odd_elems,
89
- [e[:, 2::2] for e in elems])
90
-
91
- # The first element of a scan is the same as the first element
92
- # of the original `elems`.
93
-
94
- even_elems = [
95
- torch.cat([elem[:, :1], result], dim=1)
96
- for (elem, result) in zip(elems, even_elems)]
97
-
98
- return list(map(_interleave, even_elems, odd_elems))
99
-
100
- return _scan(elems)
101
-
102
- def _interleave(a, b):
103
- a_axis_len, b_axis_len = a.shape[1], b.shape[1]
104
- output_axis_len = a_axis_len + b_axis_len
105
-
106
- if (a_axis_len == (b_axis_len + 1)):
107
- b = pad_at_dim(b, (0, 1), dim = 1)
108
-
109
- stacked = torch.stack([a, b], dim=2)
110
- interleaved = torch.flatten(stacked, start_dim=1, end_dim=2)
111
-
112
- return interleaved[:, :output_axis_len]
113
-
114
- # associative scan wrapper around naive and accelerated version
115
-
116
- class AssocScan(Module):
117
- def __init__(
118
- self,
119
- use_accelerated = False
120
- ):
121
- super().__init__()
122
- self.use_accelerated = use_accelerated
123
-
124
- def forward(
125
- self,
126
- gates,
127
- inputs,
128
- prev = None,
129
- remove_prev = None
130
- ):
131
- remove_prev = default(remove_prev, exists(prev))
132
-
133
- inputs, inverse_pack_weight_shape = pack_one_with_inverse(inputs, 'b n *')
134
- gates, _ = pack_one_with_inverse(gates, 'b n *')
135
-
136
- if exists(prev):
137
- prev, _ = pack_one_with_inverse(prev, 'b *')
138
-
139
- if exists(prev):
140
- inputs, _ = pack([prev, inputs], 'b * d')
141
- gates = pad_at_dim(gates, (1, 0), value = 1., dim = -2)
142
-
143
- if not self.use_accelerated:
144
- _, out = associative_scan(binary_operator, (gates, inputs))
145
-
146
- if remove_prev:
147
- out = out[:, 1:]
148
-
149
- return inverse_pack_weight_shape(out)
150
-
151
- from accelerated_scan.triton import scan as triton_scan
152
- from accelerated_scan.warp import scan as warp_scan
153
-
154
- scan = triton_scan if gates.is_cuda else warp_scan
155
-
156
- def accelerate_scan_fn(gates, inputs):
157
- gates = gates.expand_as(inputs)
158
- gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))
159
-
160
- seq_len = gates.shape[-1]
161
- next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))
162
-
163
- gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
164
- inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
165
-
166
- outputs = scan(gates.contiguous(), inputs.contiguous())
167
-
168
- outputs = outputs[..., :seq_len]
169
- outputs = rearrange(outputs, 'b d n -> b n d')
170
-
171
- return outputs
172
-
173
- out = accelerate_scan_fn(gates, inputs)
174
-
175
- if remove_prev:
176
- out = out[:, 1:]
177
-
178
- return inverse_pack_weight_shape(out)
@@ -1,9 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=sVTOuRUkaIYabFExdLY6s1qXm1UwHHz_J19H8ZV-X74,338
2
- titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
3
- titans_pytorch/mac_transformer.py,sha256=tz72141G5t3AOnxSVsOLtLptGtl8T7zROUvaTw2_XCY,26960
4
- titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
5
- titans_pytorch/neural_memory.py,sha256=CsB-Wd3T_N_HluRLjjwB2gPdlLymJmmhuxiRJhQuXkA,33377
6
- titans_pytorch-0.4.6.dist-info/METADATA,sha256=aWhQMQrjBLzUPmvtH-LY47r4ayz_ts70gGyDMMyJ6Sc,6810
7
- titans_pytorch-0.4.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
- titans_pytorch-0.4.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
- titans_pytorch-0.4.6.dist-info/RECORD,,