ai-edge-torch-nightly 0.1.dev202405131930__py3-none-any.whl → 0.2.0.dev20240527__py3-none-any.whl

This diff shows the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of ai-edge-torch-nightly has been flagged as possibly problematic.

@@ -0,0 +1,14 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
@@ -0,0 +1,106 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import math
+
+ import torch
+ from torch import _decomp
+ from torch import nn
+ from torch._prims_common import mask_tensor
+ from torch._prims_common.wrappers import out_wrapper
+ from torch.nn import functional as F
+
+
+ def triu(a):
+   h, w = a.shape[-2:]
+   mask = (
+       torch.arange(w, device=a.device).unsqueeze(-2)
+       - torch.arange(h, device=a.device).unsqueeze(-1)
+   ) >= 1
+   mask = torch.broadcast_to(mask, a.shape)
+   return torch.ops.aten.logical_and(a, mask).contiguous()
+
+
+ # _decomp.decomposition_table[torch.ops.aten.triu.default] = triu
+
+
+ class SelfAttention(nn.Module):
+
+   def __init__(self, n_heads, d_embed, in_proj_bias=True, out_proj_bias=True):
+     super().__init__()
+     self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias=in_proj_bias)
+     self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
+     self.n_heads = n_heads
+     self.d_head = d_embed // n_heads
+
+   def forward(self, x, causal_mask=False):
+     input_shape = x.shape
+     batch_size, sequence_length, d_embed = input_shape
+     interim_shape = (batch_size, sequence_length, self.n_heads, self.d_head)
+
+     q, k, v = self.in_proj(x).chunk(3, dim=-1)
+
+     q = q.view(interim_shape).transpose(1, 2)
+     k = k.view(interim_shape).transpose(1, 2)
+     v = v.view(interim_shape).transpose(1, 2)
+
+     weight = q @ k.transpose(-1, -2)
+     if causal_mask:
+       # mask = torch.ones_like(weight, dtype=torch.bool).triu(1)
+       mask = triu(torch.ones_like(weight, dtype=torch.bool))
+       weight.masked_fill_(mask, -torch.inf)
+     weight /= math.sqrt(self.d_head)
+     weight = F.softmax(weight, dim=-1)
+
+     output = weight @ v
+     output = output.transpose(1, 2)
+     output = output.reshape(input_shape)
+     output = self.out_proj(output)
+     return output
+
+
+ class CrossAttention(nn.Module):
+
+   def __init__(self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True):
+     super().__init__()
+     self.q_proj = nn.Linear(d_embed, d_embed, bias=in_proj_bias)
+     self.k_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
+     self.v_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
+     self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
+     self.n_heads = n_heads
+     self.d_head = d_embed // n_heads
+
+   def forward(self, x, y):
+     input_shape = x.shape
+     batch_size, sequence_length, d_embed = input_shape
+     interim_shape = (batch_size, -1, self.n_heads, self.d_head)
+
+     q = self.q_proj(x)
+     k = self.k_proj(y)
+     v = self.v_proj(y)
+
+     q = q.view(interim_shape).transpose(1, 2)
+     k = k.view(interim_shape).transpose(1, 2)
+     v = v.view(interim_shape).transpose(1, 2)
+
+     weight = q @ k.transpose(-1, -2)
+     weight /= math.sqrt(self.d_head)
+     weight = F.softmax(weight, dim=-1)
+
+     output = weight @ v
+     output = output.transpose(1, 2).contiguous()
+     output = output.view(input_shape)
+     output = self.out_proj(output)
+     return output
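For reference, a minimal usage sketch of the two attention blocks added above; the shapes are assumptions chosen to match the CLIP and UNet configurations elsewhere in this diff, not part of the package:

import torch
from ai_edge_torch.generative.examples.stable_diffusion.attention import CrossAttention
from ai_edge_torch.generative.examples.stable_diffusion.attention import SelfAttention

self_attn = SelfAttention(n_heads=12, d_embed=768)
x = torch.randn(1, 77, 768)           # (batch, tokens, d_embed)
out = self_attn(x, causal_mask=True)  # same shape as x: (1, 77, 768)

cross_attn = CrossAttention(n_heads=8, d_embed=320, d_cross=768)
latent = torch.randn(1, 64 * 64, 320)  # flattened UNet feature map
context = torch.randn(1, 77, 768)      # text embeddings
out = cross_attn(latent, context)      # (1, 4096, 320)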
@@ -0,0 +1,78 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import torch
+ from torch import nn
+ from torch._prims_common import mask_tensor
+ from torch._prims_common.wrappers import out_wrapper
+
+ from ai_edge_torch.generative.examples.stable_diffusion.attention import SelfAttention  # NOQA
+
+
+ class CLIPEmbedding(nn.Module):
+
+   def __init__(self, n_vocab: int, n_embd: int, n_token: int):
+     super().__init__()
+     self.token_embedding = nn.Embedding(n_vocab, n_embd)
+     self.position_value = nn.Parameter(torch.zeros((n_token, n_embd)))
+
+   def forward(self, tokens):
+     x = self.token_embedding(tokens)
+     x += self.position_value
+     return x
+
+
+ class CLIPLayer(nn.Module):
+
+   def __init__(self, n_head: int, n_embd: int):
+     super().__init__()
+     self.layernorm_1 = nn.LayerNorm(n_embd)
+     self.attention = SelfAttention(n_head, n_embd)
+     self.layernorm_2 = nn.LayerNorm(n_embd)
+     self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
+     self.linear_2 = nn.Linear(4 * n_embd, n_embd)
+
+   def forward(self, x):
+     residue = x
+     x = self.layernorm_1(x)
+     x = self.attention(x, causal_mask=True)
+     x += residue
+
+     residue = x
+     x = self.layernorm_2(x)
+     x = self.linear_1(x)
+     x = x * torch.sigmoid(1.702 * x)  # QuickGELU activation function
+     x = self.linear_2(x)
+     x += residue
+
+     return x
+
+
+ class CLIP(nn.Module):
+
+   def __init__(self):
+     super().__init__()
+     self.embedding = CLIPEmbedding(49408, 768, 77)
+     self.layers = nn.ModuleList([CLIPLayer(12, 768) for i in range(12)])
+     self.layernorm = nn.LayerNorm(768)
+
+   def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
+     tokens = tokens.type(torch.long)
+
+     state = self.embedding(tokens)
+     for layer in self.layers:
+       state = layer(state)
+     output = self.layernorm(state)
+     return output
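A small usage sketch for the CLIP text encoder above; the 49408/768/77 sizes come from its constructor, and the token ids below are arbitrary placeholders rather than a real tokenized prompt:

import torch
from ai_edge_torch.generative.examples.stable_diffusion.clip import CLIP

clip = CLIP().eval()
tokens = torch.randint(0, 49408, (1, 77), dtype=torch.long)  # padded prompt ids
with torch.no_grad():
  text_embeddings = clip(tokens)  # (1, 77, 768), used as cross-attention context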
@@ -0,0 +1,111 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from torch import nn
+ from torch.nn import functional as F
+
+ from ai_edge_torch.generative.examples.stable_diffusion.attention import SelfAttention  # NOQA
+
+
+ class AttentionBlock(nn.Module):
+
+   def __init__(self, channels):
+     super().__init__()
+     self.groupnorm = nn.GroupNorm(32, channels)
+     self.attention = SelfAttention(1, channels)
+
+   def forward(self, x):
+     residue = x
+     x = self.groupnorm(x)
+
+     n, c, h, w = x.shape
+     x = x.view((n, c, h * w))
+     x = x.transpose(-1, -2)
+     x = self.attention(x)
+     x = x.transpose(-1, -2)
+     x = x.view((n, c, h, w))
+
+     x += residue
+     return x
+
+
+ class ResidualBlock(nn.Module):
+
+   def __init__(self, in_channels, out_channels):
+     super().__init__()
+     self.groupnorm_1 = nn.GroupNorm(32, in_channels)
+     self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+
+     self.groupnorm_2 = nn.GroupNorm(32, out_channels)
+     self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
+
+     if in_channels == out_channels:
+       self.residual_layer = nn.Identity()
+     else:
+       self.residual_layer = nn.Conv2d(
+           in_channels, out_channels, kernel_size=1, padding=0
+       )
+
+   def forward(self, x):
+     residue = x
+
+     x = self.groupnorm_1(x)
+     x = F.silu(x)
+     x = self.conv_1(x)
+
+     x = self.groupnorm_2(x)
+     x = F.silu(x)
+     x = self.conv_2(x)
+
+     return x + self.residual_layer(residue)
+
+
+ class Decoder(nn.Sequential):
+
+   def __init__(self):
+     super().__init__(
+         nn.Conv2d(4, 4, kernel_size=1, padding=0),
+         nn.Conv2d(4, 512, kernel_size=3, padding=1),
+         ResidualBlock(512, 512),
+         AttentionBlock(512),
+         ResidualBlock(512, 512),
+         ResidualBlock(512, 512),
+         ResidualBlock(512, 512),
+         ResidualBlock(512, 512),
+         nn.Upsample(scale_factor=2),
+         nn.Conv2d(512, 512, kernel_size=3, padding=1),
+         ResidualBlock(512, 512),
+         ResidualBlock(512, 512),
+         ResidualBlock(512, 512),
+         nn.Upsample(scale_factor=2),
+         nn.Conv2d(512, 512, kernel_size=3, padding=1),
+         ResidualBlock(512, 256),
+         ResidualBlock(256, 256),
+         ResidualBlock(256, 256),
+         nn.Upsample(scale_factor=2),
+         nn.Conv2d(256, 256, kernel_size=3, padding=1),
+         ResidualBlock(256, 128),
+         ResidualBlock(128, 128),
+         ResidualBlock(128, 128),
+         nn.GroupNorm(32, 128),
+         nn.SiLU(),
+         nn.Conv2d(128, 3, kernel_size=3, padding=1),
+     )
+
+   def forward(self, x):
+     x = x / 0.18215
+     for module in self:
+       x = module(x)
+     return x
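As a rough usage sketch, the Decoder above maps a 4-channel latent back to an RGB image, upsampling by 8x; the sizes below assume a 512x512 Stable Diffusion setup and are illustrative only:

import torch
from ai_edge_torch.generative.examples.stable_diffusion.decoder import Decoder

decoder = Decoder().eval()
latent = torch.randn(1, 4, 64, 64)  # latent produced by the diffusion loop
with torch.no_grad():
  image = decoder(latent)           # (1, 3, 512, 512), roughly in [-1, 1]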
@@ -0,0 +1,498 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+ from ai_edge_torch.generative.examples.stable_diffusion.attention import CrossAttention  # NOQA
+ from ai_edge_torch.generative.examples.stable_diffusion.attention import SelfAttention  # NOQA
+
+
+ class TimeEmbedding(nn.Module):
+
+   def __init__(self, n_embd):
+     super().__init__()
+     self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
+     self.linear_2 = nn.Linear(4 * n_embd, 4 * n_embd)
+
+   def forward(self, x):
+     x = self.linear_1(x)
+     x = F.silu(x)
+     x = self.linear_2(x)
+     return x
+
+
+ class ResidualBlock(nn.Module):
+
+   def __init__(self, in_channels, out_channels, n_time=1280):
+     super().__init__()
+     self.groupnorm_feature = nn.GroupNorm(32, in_channels)
+     self.conv_feature = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+     self.linear_time = nn.Linear(n_time, out_channels)
+
+     self.groupnorm_merged = nn.GroupNorm(32, out_channels)
+     self.conv_merged = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
+
+     if in_channels == out_channels:
+       self.residual_layer = nn.Identity()
+     else:
+       self.residual_layer = nn.Conv2d(
+           in_channels, out_channels, kernel_size=1, padding=0
+       )
+
+   def forward(self, feature, time):
+     residue = feature
+
+     feature = self.groupnorm_feature(feature)
+     feature = F.silu(feature)
+     feature = self.conv_feature(feature)
+
+     time = F.silu(time)
+     time = self.linear_time(time)
+
+     merged = feature + time.unsqueeze(-1).unsqueeze(-1)
+     merged = self.groupnorm_merged(merged)
+     merged = F.silu(merged)
+     merged = self.conv_merged(merged)
+
+     return merged + self.residual_layer(residue)
+
+
+ class AttentionBlock(nn.Module):
+
+   def __init__(self, n_head: int, n_embd: int, d_context=768):
+     super().__init__()
+     channels = n_head * n_embd
+
+     self.groupnorm = nn.GroupNorm(32, channels, eps=1e-6)
+     self.conv_input = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
+
+     self.layernorm_1 = nn.LayerNorm(channels)
+     self.attention_1 = SelfAttention(n_head, channels, in_proj_bias=False)
+     self.layernorm_2 = nn.LayerNorm(channels)
+     self.attention_2 = CrossAttention(n_head, channels, d_context, in_proj_bias=False)
+     self.layernorm_3 = nn.LayerNorm(channels)
+     self.linear_geglu_1 = nn.Linear(channels, 4 * channels * 2)
+     self.linear_geglu_2 = nn.Linear(4 * channels, channels)
+
+     self.conv_output = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
+
+   def forward(self, x, context):
+     residue_long = x
+
+     x = self.groupnorm(x)
+     x = self.conv_input(x)
+
+     n, c, h, w = x.shape
+     x = x.view((n, c, h * w))  # (n, c, hw)
+     x = x.transpose(-1, -2)  # (n, hw, c)
+
+     residue_short = x
+     x = self.layernorm_1(x)
+     x = self.attention_1(x)
+     x += residue_short
+
+     residue_short = x
+     x = self.layernorm_2(x)
+     x = self.attention_2(x, context)
+     x += residue_short
+
+     residue_short = x
+     x = self.layernorm_3(x)
+     x, gate = self.linear_geglu_1(x).chunk(2, dim=-1)
+     x = x * F.gelu(gate)
+     x = self.linear_geglu_2(x)
+     x += residue_short
+
+     x = x.transpose(-1, -2)  # (n, c, hw)
+     x = x.view((n, c, h, w))  # (n, c, h, w)
+
+     return self.conv_output(x) + residue_long
+
+
+ class Upsample(nn.Module):
+
+   def __init__(self, channels):
+     super().__init__()
+     self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
+
+   def forward(self, x):
+     x = F.interpolate(x, scale_factor=2, mode='nearest')
+     return self.conv(x)
+
+
+ class SwitchSequential(nn.Sequential):
+
+   def forward(self, x, context, time):
+     for layer in self:
+       if isinstance(layer, AttentionBlock):
+         x = layer(x, context)
+       elif isinstance(layer, ResidualBlock):
+         x = layer(x, time)
+       else:
+         x = layer(x)
+     return x
+
+
+ class UNet(nn.Module):
+
+   def __init__(self):
+     super().__init__()
+     self.encoders = nn.ModuleList(
+         [
+             SwitchSequential(nn.Conv2d(4, 320, kernel_size=3, padding=1)),
+             SwitchSequential(ResidualBlock(320, 320), AttentionBlock(8, 40)),
+             SwitchSequential(ResidualBlock(320, 320), AttentionBlock(8, 40)),
+             SwitchSequential(nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)),
+             SwitchSequential(ResidualBlock(320, 640), AttentionBlock(8, 80)),
+             SwitchSequential(ResidualBlock(640, 640), AttentionBlock(8, 80)),
+             SwitchSequential(nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)),
+             SwitchSequential(ResidualBlock(640, 1280), AttentionBlock(8, 160)),
+             SwitchSequential(ResidualBlock(1280, 1280), AttentionBlock(8, 160)),
+             SwitchSequential(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1)),
+             SwitchSequential(ResidualBlock(1280, 1280)),
+             SwitchSequential(ResidualBlock(1280, 1280)),
+         ]
+     )
+     self.bottleneck = SwitchSequential(
+         ResidualBlock(1280, 1280),
+         AttentionBlock(8, 160),
+         ResidualBlock(1280, 1280),
+     )
+
+     self.decoders = nn.ModuleList(
+         [
+             SwitchSequential(ResidualBlock(2560, 1280)),
+             SwitchSequential(ResidualBlock(2560, 1280)),
+             SwitchSequential(ResidualBlock(2560, 1280), Upsample(1280)),
+             SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)),
+             SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)),
+             SwitchSequential(
+                 ResidualBlock(1920, 1280), AttentionBlock(8, 160), Upsample(1280)
+             ),
+             SwitchSequential(ResidualBlock(1920, 640), AttentionBlock(8, 80)),
+             SwitchSequential(ResidualBlock(1280, 640), AttentionBlock(8, 80)),
+             SwitchSequential(
+                 ResidualBlock(960, 640), AttentionBlock(8, 80), Upsample(640)
+             ),
+             SwitchSequential(ResidualBlock(960, 320), AttentionBlock(8, 40)),
+             SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)),
+             SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)),
+         ]
+     )
+
+   def forward(self, x, context, time):
+     skip_connections = []
+     for layers in self.encoders:
+       x = layers(x, context, time)
+       skip_connections.append(x)
+
+     x = self.bottleneck(x, context, time)
+
+     # print('x shape:')
+     # print(list(x.shape))
+     # print('time shape:')
+     # print(list(time.shape))
+
+     for layers in self.decoders:
+       x = torch.cat((x, skip_connections.pop()), dim=1)
+       x = layers(x, context, time)
+
+     return x
+
+
+ # The encoder component.
+ class UNetEncoder(nn.Module):
+
+   def __init__(self):
+     super().__init__()
+     self.time_embedding = TimeEmbedding(320)
+     self.encoders = nn.ModuleList(
+         [
+             SwitchSequential(nn.Conv2d(4, 320, kernel_size=3, padding=1)),
+             SwitchSequential(ResidualBlock(320, 320), AttentionBlock(8, 40)),
+             SwitchSequential(ResidualBlock(320, 320), AttentionBlock(8, 40)),
+             SwitchSequential(nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)),
+             SwitchSequential(ResidualBlock(320, 640), AttentionBlock(8, 80)),
+             SwitchSequential(ResidualBlock(640, 640), AttentionBlock(8, 80)),
+             SwitchSequential(nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)),
+             SwitchSequential(ResidualBlock(640, 1280), AttentionBlock(8, 160)),
+             SwitchSequential(ResidualBlock(1280, 1280), AttentionBlock(8, 160)),
+             SwitchSequential(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1)),
+             SwitchSequential(ResidualBlock(1280, 1280)),
+             SwitchSequential(ResidualBlock(1280, 1280)),
+         ]
+     )
+
+   def forward(self, x, context, time):
+     time_embedding = self.time_embedding(time)
+     skip_connections = []
+     for layers in self.encoders:
+       x = layers(x, context, time_embedding)
+       skip_connections.append(x)
+
+     return x, skip_connections, time_embedding
+
+
+ class UNetBottleNeck(nn.Module):
+
+   def __init__(self):
+     super().__init__()
+     self.bottleneck = SwitchSequential(
+         ResidualBlock(1280, 1280),
+         AttentionBlock(8, 160),
+         ResidualBlock(1280, 1280),
+     )
+
+   def forward(self, x, context, time):
+     x = self.bottleneck(x, context, time)
+     # print('shape')
+     # print(list(x.shape))
+     return x
+
+
+ # Unet decoder.
+ class UNetDecoder1(nn.Module):
+
+   def __init__(self):
+     super().__init__()
+     self.decoders = nn.ModuleList(
+         [
+             SwitchSequential(ResidualBlock(2560, 1280)),
+             SwitchSequential(ResidualBlock(2560, 1280)),
+             SwitchSequential(ResidualBlock(2560, 1280), Upsample(1280)),
+             SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)),
+         ]
+     )
+
+   def forward(self, x, context, time, s9, s10, s11, s12):
+     x = torch.cat((x, s12), dim=1)
+     x = self.decoders[0](x, context, time)
+     x = torch.cat((x, s11), dim=1)
+     x = self.decoders[1](x, context, time)
+     x = torch.cat((x, s10), dim=1)
+     x = self.decoders[2](x, context, time)
+     x = torch.cat((x, s9), dim=1)
+     x = self.decoders[3](x, context, time)
+
+     return x
+
+
+ class UNetDecoder2(nn.Module):
+
+   def __init__(self):
+     super().__init__()
+     self.decoders = nn.ModuleList(
+         [
+             SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)),
+             SwitchSequential(
+                 ResidualBlock(1920, 1280), AttentionBlock(8, 160), Upsample(1280)
+             ),
+             SwitchSequential(ResidualBlock(1920, 640), AttentionBlock(8, 80)),
+             SwitchSequential(ResidualBlock(1280, 640), AttentionBlock(8, 80)),
+         ]
+     )
+
+   def forward(self, x, context, time, s5, s6, s7, s8):
+     x = torch.cat((x, s8), dim=1)
+     x = self.decoders[0](x, context, time)
+     x = torch.cat((x, s7), dim=1)
+     x = self.decoders[1](x, context, time)
+     x = torch.cat((x, s6), dim=1)
+     x = self.decoders[2](x, context, time)
+     x = torch.cat((x, s5), dim=1)
+     x = self.decoders[3](x, context, time)
+     return x
+
+
+ class UNetDecoder3(nn.Module):
+
+   def __init__(self):
+     super().__init__()
+     self.decoders = nn.ModuleList(
+         [
+             SwitchSequential(
+                 ResidualBlock(960, 640), AttentionBlock(8, 80), Upsample(640)
+             ),
+             SwitchSequential(ResidualBlock(960, 320), AttentionBlock(8, 40)),
+             SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)),
+             SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)),
+         ]
+     )
+     self.final = FinalLayer(320, 4)
+
+   def forward(self, x, context, time, s1, s2, s3, s4):
+     x = torch.cat((x, s4), dim=1)
+     x = self.decoders[0](x, context, time)
+     x = torch.cat((x, s3), dim=1)
+     x = self.decoders[1](x, context, time)
+     x = torch.cat((x, s2), dim=1)
+     x = self.decoders[2](x, context, time)
+     x = torch.cat((x, s1), dim=1)
+     x = self.decoders[3](x, context, time)
+
+     x = self.final(x)
+     return x
+
+
+ class UNetDecoder(nn.Module):
+
+   def __init__(self):
+     super().__init__()
+     self.decoders = nn.ModuleList(
+         [
+             SwitchSequential(ResidualBlock(2560, 1280)),
+             SwitchSequential(ResidualBlock(2560, 1280)),
+             SwitchSequential(ResidualBlock(2560, 1280), Upsample(1280)),
+             SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)),
+             SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)),
+             SwitchSequential(
+                 ResidualBlock(1920, 1280), AttentionBlock(8, 160), Upsample(1280)
+             ),
+             SwitchSequential(ResidualBlock(1920, 640), AttentionBlock(8, 80)),
+             SwitchSequential(ResidualBlock(1280, 640), AttentionBlock(8, 80)),
+             SwitchSequential(
+                 ResidualBlock(960, 640), AttentionBlock(8, 80), Upsample(640)
+             ),
+             SwitchSequential(ResidualBlock(960, 320), AttentionBlock(8, 40)),
+             SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)),
+             SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)),
+         ]
+     )
+     self.final = FinalLayer(320, 4)
+
+   def forward(
+       self, x, context, time, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12
+   ):
+     x = torch.cat((x, s12), dim=1)
+     x = self.decoders[0](x, context, time)
+     x = torch.cat((x, s11), dim=1)
+     x = self.decoders[1](x, context, time)
+     x = torch.cat((x, s10), dim=1)
+     x = self.decoders[2](x, context, time)
+     x = torch.cat((x, s9), dim=1)
+     x = self.decoders[3](x, context, time)
+     x = torch.cat((x, s8), dim=1)
+     x = self.decoders[4](x, context, time)
+     x = torch.cat((x, s7), dim=1)
+     x = self.decoders[5](x, context, time)
+     x = torch.cat((x, s6), dim=1)
+     x = self.decoders[6](x, context, time)
+     x = torch.cat((x, s5), dim=1)
+     x = self.decoders[7](x, context, time)
+     x = torch.cat((x, s4), dim=1)
+     x = self.decoders[8](x, context, time)
+     x = torch.cat((x, s3), dim=1)
+     x = self.decoders[9](x, context, time)
+     x = torch.cat((x, s2), dim=1)
+     x = self.decoders[10](x, context, time)
+     x = torch.cat((x, s1), dim=1)
+     x = self.decoders[11](x, context, time)
+
+     x = self.final(x)
+
+     return x
+
+
+ class FinalLayer(nn.Module):
+
+   def __init__(self, in_channels, out_channels):
+     super().__init__()
+     self.groupnorm = nn.GroupNorm(32, in_channels)
+     self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+
+   def forward(self, x):
+     x = self.groupnorm(x)
+     x = F.silu(x)
+     x = self.conv(x)
+     return x
+
+
+ class Diffusion(nn.Module):
+
+   def __init__(self):
+     super().__init__()
+     self.time_embedding = TimeEmbedding(320)
+     self.unet = UNet()
+     self.final = FinalLayer(320, 4)
+
+   def forward(self, latent, context, time):
+     time = self.time_embedding(time)
+     # print('time:')
+     # print(list(time.shape))
+     output = self.unet(latent, context, time)
+     output = self.final(output)
+     return output
+
+
+ # Calling code as if Diffusion is split into multiple parts.
+ class DiffusionSplitted(nn.Module):
+
+   def __init__(self):
+     super().__init__()
+     self.unet_encoder = UNetEncoder()
+     self.bottleneck = UNetBottleNeck()
+     self.unet_decoder1 = UNetDecoder1()
+     self.unet_decoder2 = UNetDecoder2()
+     self.unet_decoder3 = UNetDecoder3()
+
+   def get_skip_connections(self, latent, context, time):
+     _, skip_connections, _ = self.unet_encoder(latent, context, time)
+     return skip_connections
+
+   def forward(self, latent, context, time):
+     output, skip_connections, time = self.unet_encoder(latent, context, time)
+     # print("output shape of unet encoder...")
+     # print(list(output.shape))
+     # print("output shape of time...")
+     # print(list(time.shape))
+     output = self.bottleneck(output, context, time)
+     # print("output shape of bn")
+     # print(list(output.shape))
+     output = self.unet_decoder1(
+         output,
+         context,
+         time,
+         skip_connections[8],
+         skip_connections[9],
+         skip_connections[10],
+         skip_connections[11],
+     )
+     # print("output shape of d1:")
+     # print(list(output.shape))
+
+     output = self.unet_decoder2(
+         output,
+         context,
+         time,
+         skip_connections[4],
+         skip_connections[5],
+         skip_connections[6],
+         skip_connections[7],
+     )
+
+     # print("output shape of d2:")
+     # print(list(output.shape))
+     output = self.unet_decoder3(
+         output,
+         context,
+         time,
+         skip_connections[0],
+         skip_connections[1],
+         skip_connections[2],
+         skip_connections[3],
+     )
+     return output
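For orientation, one denoising step through the monolithic Diffusion module might look like the sketch below; the 320-dimensional time vector matches TimeEmbedding(320) and get_time_embedding in util.py, while the latent and context sizes are assumptions for a 512x512 configuration:

import torch
from ai_edge_torch.generative.examples.stable_diffusion.diffusion import Diffusion
from ai_edge_torch.generative.examples.stable_diffusion.util import get_time_embedding

model = Diffusion().eval()
latent = torch.randn(1, 4, 64, 64)       # noisy latent
context = torch.randn(1, 77, 768)        # CLIP text embeddings
time = get_time_embedding(500)           # (1, 320) sinusoidal timestep embedding
with torch.no_grad():
  noise_pred = model(latent, context, time)  # (1, 4, 64, 64) predicted noise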
@@ -0,0 +1,65 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+ from ai_edge_torch.generative.examples.stable_diffusion.decoder import AttentionBlock  # NOQA
+ from ai_edge_torch.generative.examples.stable_diffusion.decoder import ResidualBlock  # NOQA
+
+
+ class Encoder(nn.Sequential):
+
+   def __init__(self):
+     super().__init__(
+         nn.Conv2d(3, 128, kernel_size=3, padding=1),
+         ResidualBlock(128, 128),
+         ResidualBlock(128, 128),
+         nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=0),
+         ResidualBlock(128, 256),
+         ResidualBlock(256, 256),
+         nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=0),
+         ResidualBlock(256, 512),
+         ResidualBlock(512, 512),
+         nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=0),
+         ResidualBlock(512, 512),
+         ResidualBlock(512, 512),
+         ResidualBlock(512, 512),
+         AttentionBlock(512),
+         ResidualBlock(512, 512),
+         nn.GroupNorm(32, 512),
+         nn.SiLU(),
+         nn.Conv2d(512, 8, kernel_size=3, padding=1),
+         nn.Conv2d(8, 8, kernel_size=1, padding=0),
+     )
+
+   def forward(self, x, noise):
+     for module in self:
+       if getattr(module, 'stride', None) == (
+           2,
+           2,
+       ):  # Padding at downsampling should be asymmetric (see #8)
+         x = F.pad(x, (0, 1, 0, 1))
+       x = module(x)
+
+     mean, log_variance = torch.chunk(x, 2, dim=1)
+     log_variance = torch.clamp(log_variance, -30, 20)
+     variance = log_variance.exp()
+     stdev = variance.sqrt()
+     x = mean + stdev * noise
+
+     x *= 0.18215
+     return x
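A usage sketch for the VAE Encoder above; note that forward takes an explicit noise tensor for the reparameterization step. The shapes assume a 512x512 input and are illustrative only:

import torch
from ai_edge_torch.generative.examples.stable_diffusion.encoder import Encoder

encoder = Encoder().eval()
image = torch.randn(1, 3, 512, 512)  # input image, roughly in [-1, 1]
noise = torch.randn(1, 4, 64, 64)    # sampling noise for the latent
with torch.no_grad():
  latent = encoder(image, noise)     # (1, 4, 64, 64), already scaled by 0.18215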
@@ -0,0 +1,62 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import os
+
+ import numpy as np
+ import torch
+
+
+ def get_time_embedding(timestep):
+   freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160)
+   x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]
+   return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)
+
+
+ def get_alphas_cumprod(beta_start=0.00085, beta_end=0.0120, n_training_steps=1000):
+   betas = (
+       np.linspace(beta_start**0.5, beta_end**0.5, n_training_steps, dtype=np.float32)
+       ** 2
+   )
+   alphas = 1.0 - betas
+   alphas_cumprod = np.cumprod(alphas, axis=0)
+   return alphas_cumprod
+
+
+ def get_file_path(filename, url=None):
+   module_location = os.path.dirname(os.path.abspath(__file__))
+   parent_location = os.path.dirname(module_location)
+   file_location = os.path.join(parent_location, "data", filename)
+   return file_location
+
+
+ def move_channel(image, to):
+   if to == "first":
+     return image.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
+   elif to == "last":
+     return image.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
+   else:
+     raise ValueError("to must be one of the following: first, last")
+
+
+ def rescale(x, old_range, new_range, clamp=False):
+   old_min, old_max = old_range
+   new_min, new_max = new_range
+   x -= old_min
+   x *= (new_max - new_min) / (old_max - old_min)
+   x += new_min
+   if clamp:
+     x = x.clamp(new_min, new_max)
+   return x
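Two quick worked examples for the helpers above (values are illustrative):

import torch
from ai_edge_torch.generative.examples.stable_diffusion import util

emb = util.get_time_embedding(999)
print(emb.shape)  # torch.Size([1, 320]): 160 cosine plus 160 sine features

alphas_cumprod = util.get_alphas_cumprod()
print(alphas_cumprod.shape)  # (1000,), decreasing from about 0.99915

x = torch.tensor([0.0, 127.5, 255.0])
print(util.rescale(x, (0, 255), (-1, 1)))  # tensor([-1., 0., 1.])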
@@ -0,0 +1,80 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ # Testing weight loader utilities.
+
+ import os
+ import tempfile
+ import unittest
+
+ import safetensors.torch
+ import torch
+
+ from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+ from ai_edge_torch.generative.utilities import loader as loading_utils
+
+
+ class TestLoader(unittest.TestCase):
+   """Unit tests that check weight loader."""
+
+   def test_load_safetensors(self):
+     with tempfile.TemporaryDirectory() as temp_dir:
+       file_path = os.path.join(temp_dir, "test.safetensors")
+       test_data = {"weight": torch.randn(20, 10), "bias": torch.randn(20)}
+       safetensors.torch.save_file(test_data, file_path)
+
+       loaded_tensors = loading_utils.load_safetensors(file_path)
+       self.assertIn("weight", loaded_tensors)
+       self.assertIn("bias", loaded_tensors)
+
+   def test_load_statedict(self):
+     with tempfile.TemporaryDirectory() as temp_dir:
+       file_path = os.path.join(temp_dir, "test.pt")
+       model = torch.nn.Linear(10, 5)
+       state_dict = model.state_dict()
+       torch.save(state_dict, file_path)
+
+       loaded_tensors = loading_utils.load_pytorch_statedict(file_path)
+       self.assertIn("weight", loaded_tensors)
+       self.assertIn("bias", loaded_tensors)
+
+   def test_model_loader(self):
+     with tempfile.TemporaryDirectory() as temp_dir:
+       file_path = os.path.join(temp_dir, "test.safetensors")
+       test_weights = {
+           "lm_head.weight": torch.randn((32000, 2048)),
+           "model.embed_tokens.weight": torch.randn((32000, 2048)),
+           "model.layers.0.input_layernorm.weight": torch.randn((2048,)),
+           "model.layers.0.mlp.down_proj.weight": torch.randn((2048, 5632)),
+           "model.layers.0.mlp.gate_proj.weight": torch.randn((5632, 2048)),
+           "model.layers.0.mlp.up_proj.weight": torch.randn((5632, 2048)),
+           "model.layers.0.post_attention_layernorm.weight": torch.randn((2048,)),
+           "model.layers.0.self_attn.k_proj.weight": torch.randn((256, 2048)),
+           "model.layers.0.self_attn.o_proj.weight": torch.randn((2048, 2048)),
+           "model.layers.0.self_attn.q_proj.weight": torch.randn((2048, 2048)),
+           "model.layers.0.self_attn.v_proj.weight": torch.randn((256, 2048)),
+           "model.norm.weight": torch.randn((2048,)),
+       }
+       safetensors.torch.save_file(test_weights, file_path)
+       cfg = tiny_llama.get_model_config()
+       cfg.num_layers = 1
+       model = tiny_llama.TinyLLamma(cfg)
+
+       loader = loading_utils.ModelLoader(file_path, tiny_llama.TENSOR_NAMES)
+       # If load() returns successfully, it means all the tensors were initialized.
+       loader.load(model, strict=True)
+
+
+ if __name__ == "__main__":
+   unittest.main()
@@ -228,14 +228,14 @@ class ModelLoader:
  q_name = self._names.attn_query_proj.format(idx)
  k_name = self._names.attn_key_proj.format(idx)
  v_name = self._names.attn_value_proj.format(idx)
- converted_state[f"{prefix}.atten_func.attn.weight"] = self._fuse_qkv(
+ converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = self._fuse_qkv(
  config,
  state.pop(f"{q_name}.weight"),
  state.pop(f"{k_name}.weight"),
  state.pop(f"{v_name}.weight"),
  )
  if config.attn_config.qkv_use_bias:
- converted_state[f"{prefix}.atten_func.attn.bias"] = self._fuse_qkv(
+ converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = self._fuse_qkv(
  config,
  state.pop(f"{q_name}.bias"),
  state.pop(f"{k_name}.bias"),
@@ -243,9 +243,13 @@ class ModelLoader:
  )

  o_name = self._names.attn_output_proj.format(idx)
- converted_state[f"{prefix}.atten_func.proj.weight"] = state.pop(f"{o_name}.weight")
+ converted_state[f"{prefix}.atten_func.output_projection.weight"] = state.pop(
+     f"{o_name}.weight"
+ )
  if config.attn_config.output_proj_use_bias:
- converted_state[f"{prefix}.atten_func.proj.bias"] = state.pop(f"{o_name}.bias")
+ converted_state[f"{prefix}.atten_func.output_projection.bias"] = state.pop(
+     f"{o_name}.bias"
+ )

  def _map_norm(
  self,
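The hunks above only rename the keys written into converted_state (atten_func.attn and atten_func.proj become atten_func.qkv_projection and atten_func.output_projection, presumably matching renamed attention submodules in this release); caller-facing usage is unchanged, e.g. the pattern exercised by the new loader_test.py:

from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
from ai_edge_torch.generative.utilities import loader as loading_utils

model = tiny_llama.TinyLLamma(tiny_llama.get_model_config())
loader = loading_utils.ModelLoader("checkpoint.safetensors", tiny_llama.TENSOR_NAMES)
loader.load(model, strict=True)  # maps q/k/v/o checkpoint tensors onto the fused projections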
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ai-edge-torch-nightly
- Version: 0.1.dev202405131930
+ Version: 0.2.0.dev20240527
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -24,8 +24,8 @@ Requires-Python: >=3.9, <3.12
  Description-Content-Type: text/markdown
  License-File: LICENSE
  Requires-Dist: numpy
- Requires-Dist: safetensors
  Requires-Dist: scipy
+ Requires-Dist: safetensors
  Requires-Dist: tabulate
  Requires-Dist: torch ==2.4.*

@@ -38,6 +38,13 @@ ai_edge_torch/generative/examples/gemma/gemma.py,sha256=YF4Ua-1lnL3qhQnh1sY5-HlY
  ai_edge_torch/generative/examples/phi2/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/phi2/convert_to_tflite.py,sha256=6nOuwx9q3AUlYcjXRRXSr_3M2JKqdJ-vUf-uE3VFYHE,2512
  ai_edge_torch/generative/examples/phi2/phi2.py,sha256=VvigzPQ_LJHeADTsMliwFwPe2BcnOhFgKDqr_WZ2JQ8,5540
+ ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
+ ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=Lo4Dq7a3Kg-lyH56iqGtqCo5UaClQHRCTDdNagXGTo8,3535
+ ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=8W4X9PKdnMsWGxXbBfm5OX6mX4XhvaMZ2gZw8yCTScY,2410
+ ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=beLCtogA32oYT2nlATpyT-1xzkyPF8zi4v3kfHpw6Mc,3239
+ ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=nnsfgjSeL16U3TVdjTkRycaoWA2ChFeitx2RjGLpwyA,16200
+ ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=X6ekByU19KNHNh5OaztZEROv-QwcCwVm1xiJjm2SCoo,2251
+ ai_edge_torch/generative/examples/stable_diffusion/util.py,sha256=pG_dsV4xIaB7B8MgoRgSXBvLCVqDlF6bNunPN3GIm-s,2046
  ai_edge_torch/generative/examples/t5/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=bWtwtUacvJOEDUpuYvLTgkP7oTkXKJA-Tf4FPxlD1Cw,4536
  ai_edge_torch/generative/examples/t5/t5.py,sha256=q2gG5RRo7RgNzvHXYC0Juh6Tgt5d_RTMSWFaYvOKiZU,21065
@@ -65,10 +72,11 @@ ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=9ktL7fT8C5j1dnY_7
  ai_edge_torch/generative/quantize/quant_recipes.py,sha256=CRA2ENevS-3usHqidWDe2wrf_epILE_7Hx-XfZQ9buk,1798
  ai_edge_torch/generative/quantize/supported_schemes.py,sha256=OQ4ghQXknA1PPjuY-xBgAmOpaIBgYFM8F2YAIot06hE,1345
  ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
+ ai_edge_torch/generative/test/loader_test.py,sha256=N88CbrLW7Q2x1EyurwdXQ6YjsA-ySQcPxpZH3QOGp-M,3317
  ai_edge_torch/generative/test/test_model_conversion.py,sha256=1NfZxKo9Gx6CmVfd86K1FkmsNQnjzIV1ojBS85UGvT0,6500
  ai_edge_torch/generative/test/test_quantize.py,sha256=f70sH1ZFzdCwYj0MG-eg54WOC4LasR0D8CTUYpjxZYM,3728
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
- ai_edge_torch/generative/utilities/loader.py,sha256=QrGZ3JlEN_tn8j6EdZOxVt_0u3yB5vBrR3KJtNaAwV8,10029
+ ai_edge_torch/generative/utilities/loader.py,sha256=c-ZOIDBVnat_5l2W5sWU7HQm7CL-wducS8poSu5PlUg,10107
  ai_edge_torch/generative/utilities/t5_loader.py,sha256=guDTv-12UUvJGl4eDvvZX3t4rRKewfXO8SpcYXM6gbc,16156
  ai_edge_torch/hlfb/__init__.py,sha256=rrje8a2iuKboBoV96bVq7nlS9HsnuEMbHE5JiWmCxFA,752
  ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=2VXnHcGf23VOuP-1GriGIpuL98leBB8twp_qaScMnmc,4799
@@ -84,8 +92,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=ExThdTXqnWmGC3-F6sdXbXr8nYzkEe_qCz
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=EIyKz-HY70DguWuSrJal8LpYXQ5ZSEUf3ZrVl7jikFM,4286
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/METADATA,sha256=lQcAb0esNisYUqkzDRHamW4S9luvrJ4QU75042IAqWc,1750
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/RECORD,,
+ ai_edge_torch_nightly-0.2.0.dev20240527.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ ai_edge_torch_nightly-0.2.0.dev20240527.dist-info/METADATA,sha256=nbZoIm0s6CWdrMkaffTrpz-XooKzTR1q0SQ17rs-AKU,1748
+ ai_edge_torch_nightly-0.2.0.dev20240527.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ ai_edge_torch_nightly-0.2.0.dev20240527.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ ai_edge_torch_nightly-0.2.0.dev20240527.dist-info/RECORD,,