ASAC-pytorch 0.0.1__tar.gz → 0.0.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,165 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import nn, tensor
5
+ from torch.nn import Module, Linear
6
+ import torch.nn.functional as F
7
+
8
+ from einops import einsum
9
+ from einops.layers.torch import Rearrange
10
+
11
+ from x_transformers import Decoder
12
+
13
+ from x_mlps_pytorch import MLP
14
+
15
+ from vector_quantize_pytorch import VectorQuantize
16
+
17
+ from ema_pytorch import EMA
18
+
19
+ from torch_einops_utils import pack_with_inverse
20
+
21
+ # helpers
22
+
23
def exists(v):
    """Report whether ``v`` carries a value, i.e. is anything other than ``None``."""
    return not (v is None)
25
+
26
def default(v, d):
    """Return ``v``, or the fallback ``d`` when ``v`` is ``None``."""
    if not exists(v):
        return d
    return v
28
+
29
+ # attention
30
+
31
class Attention(Module):
    """Multi-head self-attention with an optional attention schema.

    When an ``attn_schema`` module is given, it is called on the raw
    (pre-softmax) attention scores. Its first output replaces those scores,
    optionally with the original scores added back as a residual.
    """

    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        attn_schema: Module | None = None,
        attn_add_residual = True # they had to add a residual for stability
    ):
        """
        Args:
            dim: feature width of the input tokens and of the output.
            dim_head: width of each attention head.
            heads: number of attention heads.
            attn_schema: optional module applied to the attention scores; it must
                return a 3-tuple ``(scores, indices, aux_loss)``.
            attn_add_residual: add the original scores back onto the schema output.
                Only meaningful when ``attn_schema`` is given.
        """
        super().__init__()
        self.scale = dim_head ** -0.5
        dim_inner = dim_head * heads

        self.to_qkv = Linear(dim, dim_inner * 3, bias = False)
        self.combine_heads = Linear(dim_inner, dim, bias = False)

        self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
        self.merge_heads = Rearrange('b h n d -> b n (h d)')

        self.attn_schema = attn_schema

        # store a plain bool - `attn_add_residual and attn_schema` would evaluate to the
        # schema Module itself, which nn.Module.__setattr__ registers a second time as a
        # submodule (duplicating its parameters in state_dict)

        self.attn_add_residual = bool(attn_add_residual) and exists(attn_schema)

        # non-persistent scalar zero, returned as the aux loss when there is no schema

        self.register_buffer('zero', tensor(0.), persistent = False)

    def forward(
        self,
        tokens, # (b h w d)
    ):
        """Attend over all positions of ``tokens``.

        All axes between the batch and feature axes are packed into a single
        sequence axis before attention and restored afterwards.

        Returns:
            ``(attended, indices, aux_loss)``:
            - ``attended`` has the same shape as ``tokens``.
            - ``indices`` is the schema's second output, or ``None`` without a schema.
            - ``aux_loss`` is the schema's third output, or a scalar zero without a schema.
        """
        tokens, inverse_pack = pack_with_inverse(tokens, 'b * d')

        q, k, v = self.to_qkv(tokens).chunk(3, dim = -1)
        q, k, v = (self.split_heads(t) for t in (q, k, v))

        q = q * self.scale

        sim = einsum(q, k, 'b h i d, b h j d -> b h i j')

        orig_sim = sim

        # the proposal - optionally replace the scores with the attention schema's output

        aux_loss = self.zero
        indices = None # must be bound even when there is no schema, since it is returned

        if exists(self.attn_schema):
            sim, indices, aux_loss = self.attn_schema(orig_sim)

            if self.attn_add_residual:
                sim = sim + orig_sim

        # attend

        attn = sim.softmax(dim = -1)

        # aggregate and combine out

        out = einsum(attn, v, 'b h i j, b h j d -> b h i d')

        out = self.merge_heads(out)
        attended = self.combine_heads(out)

        # bring back the packed dimensions

        attended = inverse_pack(attended)

        return attended, indices, aux_loss
96
+
97
+ # attention autoencoder
98
+
99
class AttentionSchema(Module):
    """Vector-quantized autoencoder over attention scores.

    The scores are flattened over every non-batch axis, encoded to a bottleneck,
    quantized with ``VectorQuantize``, decoded and reshaped back. The
    reconstruction is returned so the caller can use it in place of the scores.
    """

    def __init__(
        self,
        dim,
        dim_bottleneck,
        kl_div_loss = True,
        detach_target = True,
        **vq_kwargs
    ):
        """
        Args:
            dim: input width of the encoder and output width of the decoder. It is
                applied to the scores flattened over all non-batch axes, so presumably
                it must equal heads * n * n for n tokens - TODO confirm.
            dim_bottleneck: width of the bottleneck that gets vector quantized.
            kl_div_loss: use a KL-divergence reconstruction loss instead of MSE.
            detach_target: stop reconstruction-loss gradients from reaching the
                input scores.
            **vq_kwargs: forwarded to ``VectorQuantize``.
        """
        super().__init__()
        self.encoder = MLP(dim, dim_bottleneck, activation = nn.LeakyReLU())

        self.vq = VectorQuantize(dim_bottleneck, **vq_kwargs)

        self.decoder = MLP(dim_bottleneck, dim, activation = nn.LeakyReLU())

        self.kl_div_loss = kl_div_loss
        self.detach_target = detach_target

    def forward(
        self,
        attn_sim,
        return_loss = None
    ):
        """Reconstruct ``attn_sim`` through the quantized bottleneck.

        Args:
            attn_sim: attention scores (logits), with the batch as the first axis.
            return_loss: compute the training loss. Defaults to ``self.training``.

        Returns:
            ``(recon, indices, loss)``:
            - ``recon`` has the shape of ``attn_sim``.
            - ``indices`` come from the quantizer.
            - ``loss`` is the reconstruction loss plus the quantizer's commit loss
              when ``return_loss`` is truthy, otherwise a scalar zero.
        """
        return_loss = default(return_loss, self.training)

        attn_features, inverse_pack = pack_with_inverse(attn_sim, 'b *')

        encoded = self.encoder(attn_features)

        quantized, indices, commit_loss = self.vq(encoded)

        decoded = self.decoder(quantized)

        recon = inverse_pack(decoded)

        # previously `total_loss` was unbound on this path, raising NameError in eval mode

        if not return_loss:
            return recon, indices, attn_sim.new_zeros(())

        # loss, mse as in paper or reverse kl

        target = attn_sim.detach() if self.detach_target else attn_sim

        if self.kl_div_loss:
            # F.kl_div(input, target) computes KL(target || exp(input)) with `input` in log space.
            # `recon` holds unnormalized logits, so it must be normalized before use as the
            # target distribution; passing it in log space (log_target) keeps this stable
            recon_loss = F.kl_div(
                target.log_softmax(dim = -1),
                recon.log_softmax(dim = -1),
                reduction = 'batchmean',
                log_target = True
            )
        else:
            recon_loss = F.mse_loss(recon, target)

        # total

        total_loss = recon_loss + commit_loss

        return recon, indices, total_loss
157
+
158
+ # class
159
+
160
class ASAC(Module):
    """Top-level ASAC model - currently a placeholder with no parameters.

    ``forward`` is the identity and hands its input back untouched.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` unchanged (the very same object)."""
        passthrough = x
        return passthrough
@@ -0,0 +1 @@
1
+ from ASAC.ASAC import ASAC
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ASAC-pytorch
3
- Version: 0.0.1
3
+ Version: 0.0.2
4
4
  Summary: Implementation of Attention Schema-based Attention Control (ASAC)
5
5
  Project-URL: Homepage, https://pypi.org/project/ASAC/
6
6
  Project-URL: Repository, https://codeberg.org/lucidrains/ASAC
@@ -36,8 +36,12 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
36
36
  Requires-Python: >=3.10
37
37
  Requires-Dist: einops>=0.8.1
38
38
  Requires-Dist: einx>=0.3.0
39
+ Requires-Dist: ema-pytorch
39
40
  Requires-Dist: torch-einops-utils>=0.1.2
40
41
  Requires-Dist: torch>=2.5
42
+ Requires-Dist: vector-quantize-pytorch
43
+ Requires-Dist: x-mlps-pytorch
44
+ Requires-Dist: x-transformers
41
45
  Provides-Extra: examples
42
46
  Provides-Extra: test
43
47
  Requires-Dist: pytest; extra == 'test'
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "ASAC-pytorch"
3
- version = "0.0.1"
3
+ version = "0.0.2"
4
4
  description = "Implementation of Attention Schema-based Attention Control (ASAC)"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -26,8 +26,12 @@ classifiers=[
26
26
  dependencies = [
27
27
  "einx>=0.3.0",
28
28
  "einops>=0.8.1",
29
+ "ema-pytorch",
29
30
  "torch>=2.5",
30
31
  "torch-einops-utils>=0.1.2",
32
+ "vector-quantize-pytorch",
33
+ "x-transformers",
34
+ "x-mlps-pytorch"
31
35
  ]
32
36
 
33
37
  [project.urls]
File without changes
File without changes
File without changes
File without changes
File without changes