titans-pytorch 0.0.31__tar.gz → 0.0.34__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.0.31 → titans_pytorch-0.0.34}/PKG-INFO +1 -1
- {titans_pytorch-0.0.31 → titans_pytorch-0.0.34}/pyproject.toml +1 -1
- {titans_pytorch-0.0.31 → titans_pytorch-0.0.34}/titans_pytorch/mac_transformer.py +70 -9
- {titans_pytorch-0.0.31 → titans_pytorch-0.0.34}/train.py +0 -3
- titans_pytorch-0.0.34/train_mac.py +129 -0
- {titans_pytorch-0.0.31 → titans_pytorch-0.0.34}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.31 → titans_pytorch-0.0.34}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.31 → titans_pytorch-0.0.34}/.gitignore +0 -0
- {titans_pytorch-0.0.31 → titans_pytorch-0.0.34}/LICENSE +0 -0
- {titans_pytorch-0.0.31 → titans_pytorch-0.0.34}/README.md +0 -0
- {titans_pytorch-0.0.31 → titans_pytorch-0.0.34}/data/README.md +0 -0
- {titans_pytorch-0.0.31 → titans_pytorch-0.0.34}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.31 → titans_pytorch-0.0.34}/fig1.png +0 -0
- {titans_pytorch-0.0.31 → titans_pytorch-0.0.34}/fig2.png +0 -0
- {titans_pytorch-0.0.31 → titans_pytorch-0.0.34}/requirements.txt +0 -0
- {titans_pytorch-0.0.31 → titans_pytorch-0.0.34}/tests/test_titans.py +0 -0
- {titans_pytorch-0.0.31 → titans_pytorch-0.0.34}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.31 → titans_pytorch-0.0.34}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.31 → titans_pytorch-0.0.34}/titans_pytorch/titans.py +0 -0
- {titans_pytorch-0.0.31 → titans_pytorch-0.0.34}/titans_pytorch/titans_attn_memory.py +0 -0
titans_pytorch/mac_transformer.py

@@ -17,6 +17,10 @@ from hyper_connections import get_init_and_expand_reduce_stream_functions
 from axial_positional_embedding import ContinuousAxialPositionalEmbedding
 from rotary_embedding_torch import RotaryEmbedding
 
+# proposed neural memory
+
+from titans_pytorch.titans import NeuralMemory
+
 # constants
 
 LinearNoBias = partial(Linear, bias = False)
@@ -46,13 +50,20 @@ def pad_and_segment_with_inverse(seq, segment_len):
     next_seq_len_mult = round_up_multiple(seq_len, segment_len)
 
     padding = next_seq_len_mult - seq_len
-    seq = F.pad(seq, (0, 0, 0, padding))
+    needs_pad = padding > 0
+
+    if needs_pad:
+        seq = F.pad(seq, (0, 0, 0, padding))
 
     seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
 
     def inverse(out):
         out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
-        return out[:, :-padding]
+
+        if needs_pad:
+            out = out[:, :-padding]
+
+        return out
 
     return seq, inverse
 
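The guard matters because with `padding == 0` the old `out[:, :-padding]` slice returns an empty tensor. Below is a minimal, self-contained sketch of the round trip, reimplementing the helper from the hunk above purely for illustration (the standalone `round_up_multiple` is an assumption, not imported from the package):

```python
import math

import torch
import torch.nn.functional as F
from einops import rearrange

def round_up_multiple(n, mult):
    return math.ceil(n / mult) * mult

def pad_and_segment_with_inverse(seq, segment_len):
    batch, seq_len = seq.shape[:2]

    padding = round_up_multiple(seq_len, segment_len) - seq_len
    needs_pad = padding > 0

    if needs_pad:
        # pad the sequence dimension on the right up to a multiple of segment_len
        seq = F.pad(seq, (0, 0, 0, padding))

    # fold each window into the batch dimension
    seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)

    def inverse(out):
        out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
        # only crop when padding was actually added - out[:, :-0] would be empty
        return out[:, :-padding] if needs_pad else out

    return seq, inverse

x = torch.randn(2, 100, 16)                  # 100 is not a multiple of 64
segments, inverse = pad_and_segment_with_inverse(x, 64)

assert segments.shape == (4, 64, 16)         # 2 windows per batch element
assert torch.allclose(inverse(segments), x)  # padding is cropped back off
```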
@@ -161,7 +172,9 @@ class MemoryAsContextTransformer(Module):
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
-        num_residual_streams = 4
+        num_residual_streams = 4,
+        neural_memory_kwargs: dict = dict(),
+        neural_memory_layers: tuple[int, ...] | None = None,
     ):
         super().__init__()
 
@@ -181,8 +194,25 @@ class MemoryAsContextTransformer(Module):
         init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
 
         self.layers = ModuleList([])
+        self.neural_mem_layers = ModuleList([])
+
+        layers = tuple(range(1, depth + 1))
+        neural_memory_layers = set(default(neural_memory_layers, layers))
+
+        for layer in layers:
+
+            # neural memory
+
+            mem = None
+
+            if num_longterm_mem_tokens > 0 and layer in neural_memory_layers:
+                mem = NeuralMemory(dim = dim, chunk_size = num_longterm_mem_tokens)
+                mem = init_hyper_conn(dim = dim, branch = mem)
+
+            self.neural_mem_layers.append(mem)
+
+            # attention and feedforward
 
-        for _ in range(depth):
             attn = SegmentedAttention(
                 dim = dim,
                 dim_head = dim_head,
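A small sketch of how the constructor's layer selection resolves; `default` is assumed to be the usual value-or-fallback helper already defined in the file, and the settings mirror `train_mac.py` further down:

```python
def default(value, fallback):
    # assumed helper: fall back when the argument is left as None
    return value if value is not None else fallback

depth = 8
num_longterm_mem_tokens = 16
neural_memory_layers = (3, 4)            # same setting as train_mac.py

layers = tuple(range(1, depth + 1))      # layers are 1-indexed
selected = set(default(neural_memory_layers, layers))

has_neural_mem = [
    num_longterm_mem_tokens > 0 and layer in selected
    for layer in layers
]

print(has_neural_mem)  # [False, False, True, True, False, False, False, False]
```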
@@ -203,7 +233,14 @@ class MemoryAsContextTransformer(Module):
 
         self.to_logits = LinearNoBias(dim, num_tokens)
 
-    def forward(self, x):
+    def forward(
+        self,
+        x,
+        return_loss = False
+    ):
+
+        if return_loss:
+            x, labels = x[:, :-1], x[:, 1:]
 
         # math
 
@@ -221,7 +258,7 @@ class MemoryAsContextTransformer(Module):
         x, inverse_segment = pad_and_segment_with_inverse(x, segment_len)
 
         mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
-        x = 
+        x = cat((mems, x), dim = -2)
 
         x = inverse_segment(x)
 
@@ -235,8 +272,27 @@ class MemoryAsContextTransformer(Module):
 
         x = self.expand_streams(x)
 
-        for attn, ff in self.layers:
+        for (attn, ff), maybe_neural_mem in zip(self.layers, self.neural_mem_layers):
+
+            if exists(maybe_neural_mem):
+                batch_streams = x.shape[0]
+
+                x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
+
+                longterm_mems, x = x[:, :num_longterm_mem_tokens], x[:, num_longterm_mem_tokens:]
+
+                longterm_mems = rearrange(longterm_mems, '(b w) n d -> b (w n) d', b = batch_streams)
+
+                longterm_mems = maybe_neural_mem(longterm_mems)
+
+                longterm_mems = rearrange(longterm_mems, 'b (w n) d -> (b w) n d', n = num_longterm_mem_tokens)
+
+                x = cat((longterm_mems, x), dim = -2)
+
+                x = inverse_segment(x)
+
             x = attn(x)
+
             x = ff(x)
 
         x = self.reduce_streams(x)
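For the per-layer branch above, here is a shape walk-through with illustrative sizes (batch 2, 4 windows of 64 tokens, 16 long-term memory tokens per window, dim 8); the neural memory call itself is elided:

```python
import torch
from einops import rearrange

b, w, n, d = 2, 4, 16, 8                    # batch, windows, mem tokens per window, dim
segment_len = 64

x = torch.randn(b * w, n + segment_len, d)  # windowed tokens, memory tokens first

# split off the per-window long-term memory tokens
longterm_mems, rest = x[:, :n], x[:, n:]

# regroup them contiguously per batch element, so the neural memory
# sees one sequence of w * n memory tokens
longterm_mems = rearrange(longterm_mems, '(b w) n d -> b (w n) d', b = b)
assert longterm_mems.shape == (b, w * n, d)

# ... longterm_mems = maybe_neural_mem(longterm_mems) would run here ...

# scatter them back to their windows and re-prepend to the segment tokens
longterm_mems = rearrange(longterm_mems, 'b (w n) d -> (b w) n d', n = n)
out = torch.cat((longterm_mems, rest), dim = -2)
assert out.shape == x.shape
```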
@@ -245,7 +301,7 @@ class MemoryAsContextTransformer(Module):
 
         x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
 
-        x = x[:, 
+        x = x[:, num_longterm_mem_tokens:]
 
         x = inverse_segment(x)
 
@@ -253,4 +309,9 @@ class MemoryAsContextTransformer(Module):
 
         x = self.norm(x)
 
-        return self.to_logits(x)
+        logits = self.to_logits(x)
+
+        if not return_loss:
+            return logits
+
+        return F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
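A short sketch of the two call modes this forward now supports, assuming a `MemoryAsContextTransformer` instance named `model` built (and moved to GPU) as in `train_mac.py` below:

```python
import torch

ids = torch.randint(0, 256, (1, 1024)).cuda()  # byte-level token ids on the model's device

logits = model(ids)                            # (1, 1024, 256): next-token logits per position
loss = model(ids, return_loss = True)          # scalar; inputs are ids[:, :-1], targets ids[:, 1:]
loss.backward()
```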
train.py

@@ -63,11 +63,8 @@ def decode_tokens(tokens):
 titans_neural_memory = NeuralMemory(
     dim = 384,
     chunk_size = 4,
-    pre_rmsnorm = True,
-    post_rmsnorm = True,
     dim_head = 64,
     heads = 4,
-    max_grad_norm = 1.,
     use_accelerated_scan = True,
     default_mlp_kwargs = dict(
         depth = NEURAL_MEMORY_DEPTH
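The trimmed construction keeps the same call pattern; a minimal standalone sketch, with the single-tensor return shape inferred from how `mac_transformer.py` invokes the memory above:

```python
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 4,
    dim_head = 64,
    heads = 4
)

seq = torch.randn(2, 64, 384)    # sequence length kept a multiple of chunk_size
retrieved = mem(seq)             # retrieved values, same shape as the input
assert retrieved.shape == seq.shape
```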
titans_pytorch-0.0.34/train_mac.py

@@ -0,0 +1,129 @@
+import random
+import tqdm
+import gzip
+import numpy as np
+
+import torch
+from torch import nn
+from torch.optim import Adam
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, Dataset
+
+from titans_pytorch.mac_transformer import MemoryAsContextTransformer
+
+# constants
+
+NUM_BATCHES = int(1e5)
+BATCH_SIZE = 4
+GRADIENT_ACCUMULATE_EVERY = 4
+LEARNING_RATE = 2e-4
+VALIDATE_EVERY  = 100
+GENERATE_EVERY  = 500
+GENERATE_LENGTH = 512
+SHOULD_GENERATE = False
+SEQ_LEN = 512
+
+PROJECT_NAME = 'titans-mac-transformer'
+WANDB_ONLINE = False # turn this on to pipe experiment to cloud
+GLOBAL_LAYERS = (2, 4)
+NEURAL_MEMORY_DEPTH = 2
+WINDOW_SIZE = 64
+RUN_NAME = 'mac'
+
+# wandb experiment tracker
+
+import wandb
+wandb.init(project = PROJECT_NAME, mode = 'disabled' if not WANDB_ONLINE else 'online')
+wandb.run.name = RUN_NAME
+wandb.run.save()
+
+# helpers
+
+def cycle(loader):
+    while True:
+        for data in loader:
+            yield data
+
+def decode_token(token):
+    return str(chr(max(32, token)))
+
+def decode_tokens(tokens):
+    return ''.join(list(map(decode_token, tokens)))
+
+# instantiate memory-as-context transformer
+
+model = MemoryAsContextTransformer(
+    num_tokens = 256,
+    dim = 384,
+    depth = 8,
+    segment_len = WINDOW_SIZE,
+    num_persist_mem_tokens = 16,
+    num_longterm_mem_tokens = 16,
+    neural_memory_layers = (3, 4),
+    neural_memory_kwargs = dict(
+        default_mlp_kwargs = dict(
+            depth = NEURAL_MEMORY_DEPTH
+        )
+    )
+).cuda()
+
+# prepare enwik8 data
+
+with gzip.open('./data/enwik8.gz') as file:
+    data = np.frombuffer(file.read(int(95e6)), dtype = np.uint8).copy()
+    data_train, data_val = np.split(data, [int(90e6)])
+    data_train, data_val = map(torch.from_numpy, (data_train, data_val))
+
+class TextSamplerDataset(Dataset):
+    def __init__(self, data, seq_len):
+        super().__init__()
+        self.data = data
+        self.seq_len = seq_len
+
+    def __getitem__(self, index):
+        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
+        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
+        return full_seq.cuda()
+
+    def __len__(self):
+        return self.data.size(0) // self.seq_len
+
+train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
+val_dataset   = TextSamplerDataset(data_val, SEQ_LEN)
+train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
+val_loader    = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))
+
+# optimizer
+
+optim = Adam(model.parameters(), lr=LEARNING_RATE)
+
+# training
+
+for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
+    model.train()
+
+    for __ in range(GRADIENT_ACCUMULATE_EVERY):
+        loss = model(next(train_loader), return_loss = True)
+        loss.backward()
+
+    print(f'training loss: {loss.item()}')
+    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
+    optim.step()
+    optim.zero_grad()
+    wandb.log(dict(loss = loss.item()))
+
+    if i % VALIDATE_EVERY == 0:
+        model.eval()
+        with torch.no_grad():
+            loss = model(next(val_loader), return_loss = True)
+            print(f'validation loss: {loss.item()}')
+
+    if SHOULD_GENERATE and i % GENERATE_EVERY == 0:
+        model.eval()
+        inp = random.choice(val_dataset)[:-1]
+        prime = decode_tokens(inp)
+        print(f'%s \n\n %s', (prime, '*' * 100))
+
+        sample = model.generate(inp[None, ...], GENERATE_LENGTH, use_kv_cache = False)
+        output_str = decode_tokens(sample[0])
+        print(output_str)
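`SHOULD_GENERATE` is off by default, so the script's `model.generate(...)` branch is never exercised. Below is a minimal greedy-decoding sketch that uses only the logits path defined in `mac_transformer.py`; the helper is illustrative and not part of the package, while `model`, `decode_tokens`, `val_dataset`, and `GENERATE_LENGTH` come from the script above:

```python
import torch

@torch.no_grad()
def greedy_generate(model, prompt, length):
    # prompt: (1, n) tensor of byte token ids on the model's device
    model.eval()
    out = prompt

    for _ in range(length):
        logits = model(out)                  # (1, n, 256) next-token logits
        next_token = logits[:, -1].argmax(dim = -1, keepdim = True)
        out = torch.cat((out, next_token), dim = -1)

    return out[:, prompt.shape[-1]:]         # return only the newly generated tokens

# usage, mirroring the generation branch of the script:
# inp = random.choice(val_dataset)[:-1]
# sampled = greedy_generate(model, inp[None, :], GENERATE_LENGTH)
# print(decode_tokens(sampled[0].cpu()))
```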