ai-edge-torch-nightly 0.2.0.dev20240730__py3-none-any.whl → 0.2.0.dev20240805__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of ai-edge-torch-nightly might be problematic.
- ai_edge_torch/__init__.py +1 -0
- ai_edge_torch/convert/conversion.py +12 -8
- ai_edge_torch/convert/conversion_utils.py +38 -20
- ai_edge_torch/convert/converter.py +11 -5
- ai_edge_torch/convert/fx_passes/__init__.py +3 -4
- ai_edge_torch/convert/fx_passes/_pass_base.py +6 -2
- ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +34 -24
- ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +11 -10
- ai_edge_torch/convert/fx_passes/canonicalize_pass.py +2 -3
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +18 -7
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +4 -3
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +6 -4
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +9 -5
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +1 -2
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +14 -10
- ai_edge_torch/convert/test/test_convert.py +39 -16
- ai_edge_torch/convert/test/test_convert_composites.py +115 -86
- ai_edge_torch/convert/test/test_convert_multisig.py +18 -10
- ai_edge_torch/convert/test/test_to_channel_last_io.py +1 -2
- ai_edge_torch/convert/to_channel_last_io.py +6 -2
- ai_edge_torch/debug/culprit.py +41 -16
- ai_edge_torch/debug/test/test_culprit.py +4 -3
- ai_edge_torch/debug/test/test_search_model.py +4 -3
- ai_edge_torch/debug/utils.py +3 -1
- ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +4 -3
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +10 -8
- ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +7 -4
- ai_edge_torch/generative/examples/experimental/phi/phi2.py +10 -8
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +1 -2
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +10 -8
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +4 -3
- ai_edge_torch/generative/examples/gemma/gemma.py +13 -9
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +7 -4
- ai_edge_torch/generative/examples/phi2/phi2.py +13 -9
- ai_edge_torch/generative/examples/stable_diffusion/attention.py +3 -1
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +20 -9
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +26 -13
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +15 -7
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +47 -16
- ai_edge_torch/generative/examples/stable_diffusion/encoder.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +42 -12
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +4 -1
- ai_edge_torch/generative/examples/stable_diffusion/util.py +9 -3
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +8 -5
- ai_edge_torch/generative/examples/t5/t5.py +158 -125
- ai_edge_torch/generative/examples/t5/t5_attention.py +15 -7
- ai_edge_torch/generative/examples/test_models/toy_model.py +7 -5
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +3 -4
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +4 -5
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +4 -3
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +10 -8
- ai_edge_torch/generative/fx_passes/__init__.py +1 -2
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +6 -3
- ai_edge_torch/generative/layers/attention.py +19 -11
- ai_edge_torch/generative/layers/builder.py +3 -4
- ai_edge_torch/generative/layers/kv_cache.py +4 -3
- ai_edge_torch/generative/layers/model_config.py +6 -2
- ai_edge_torch/generative/layers/rotary_position_embedding.py +3 -1
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +1 -2
- ai_edge_torch/generative/layers/unet/blocks_2d.py +69 -21
- ai_edge_torch/generative/layers/unet/builder.py +7 -4
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +9 -4
- ai_edge_torch/generative/quantize/example.py +2 -3
- ai_edge_torch/generative/quantize/quant_recipe.py +2 -1
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +10 -0
- ai_edge_torch/generative/quantize/quant_recipes.py +8 -0
- ai_edge_torch/generative/test/loader_test.py +5 -4
- ai_edge_torch/generative/test/test_experimental_ekv.py +22 -11
- ai_edge_torch/generative/test/test_model_conversion.py +2 -3
- ai_edge_torch/generative/test/test_quantize.py +45 -47
- ai_edge_torch/generative/utilities/loader.py +55 -28
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +86 -33
- ai_edge_torch/generative/utilities/t5_loader.py +77 -48
- ai_edge_torch/hlfb/mark_pattern/__init__.py +2 -3
- ai_edge_torch/hlfb/mark_pattern/pattern.py +16 -7
- ai_edge_torch/hlfb/test/test_mark_pattern.py +4 -3
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +12 -6
- ai_edge_torch/model.py +8 -5
- ai_edge_torch/quantize/pt2e_quantizer.py +30 -15
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +30 -11
- ai_edge_torch/quantize/quant_config.py +6 -2
- ai_edge_torch/testing/model_coverage/model_coverage.py +11 -7
- ai_edge_torch/version.py +16 -0
- {ai_edge_torch_nightly-0.2.0.dev20240730.dist-info → ai_edge_torch_nightly-0.2.0.dev20240805.dist-info}/METADATA +5 -5
- ai_edge_torch_nightly-0.2.0.dev20240805.dist-info/RECORD +133 -0
- {ai_edge_torch_nightly-0.2.0.dev20240730.dist-info → ai_edge_torch_nightly-0.2.0.dev20240805.dist-info}/WHEEL +1 -1
- ai_edge_torch_nightly-0.2.0.dev20240730.dist-info/RECORD +0 -132
- {ai_edge_torch_nightly-0.2.0.dev20240730.dist-info → ai_edge_torch_nightly-0.2.0.dev20240805.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240730.dist-info → ai_edge_torch_nightly-0.2.0.dev20240805.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/layers/attention.py

```diff
@@ -16,16 +16,15 @@
 
 from typing import Optional, Tuple
 
-import torch
-from torch import nn
-import torch.nn.functional as F
-
 import ai_edge_torch.generative.layers.builder as builder
 from ai_edge_torch.generative.layers.kv_cache import KVCache
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention  # NOQA
 from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention_with_hlfb  # NOQA
+import torch
+from torch import nn
+import torch.nn.functional as F
 
 
 def _embed_rope(
@@ -140,7 +139,9 @@ class CausalSelfAttention(nn.Module):
     shape = (config.num_heads + 2 * config.num_query_groups) * self.head_dim
     # Key, query, value projections for all heads.
     self.qkv_projection = nn.Linear(dim, shape, bias=config.qkv_use_bias)
-    self.output_projection = nn.Linear(dim, dim, bias=config.output_proj_use_bias)
+    self.output_projection = nn.Linear(
+        dim, dim, bias=config.output_proj_use_bias
+    )
     self.config = config
     self.kv_cache = None
     self.batch_size = batch_size
@@ -181,9 +182,10 @@ class CausalSelfAttention(nn.Module):
     """
     # Batch size, sequence length, embedding dimensionality.
    B, T, E = x.size()
-    assert (
-        B == self.batch_size
-    ), "batch size of input tensor must match with the batch size specified in the model configuration."
+    assert B == self.batch_size, (
+        "batch size of input tensor must match with the batch size specified in"
+        " the model configuration."
+    )
 
     qkv = self.qkv_projection(x)
 
@@ -279,9 +281,15 @@ class CrossAttention(nn.Module):
     self.config = config
     self.head_dim = query_dim // config.num_heads
     self.n_heads = config.num_heads
-    self.q_projection = nn.Linear(query_dim, query_dim, bias=config.qkv_use_bias)
-    self.k_projection = nn.Linear(cross_dim, query_dim, bias=config.qkv_use_bias)
-    self.v_projection = nn.Linear(cross_dim, query_dim, bias=config.qkv_use_bias)
+    self.q_projection = nn.Linear(
+        query_dim, query_dim, bias=config.qkv_use_bias
+    )
+    self.k_projection = nn.Linear(
+        cross_dim, query_dim, bias=config.qkv_use_bias
+    )
+    self.v_projection = nn.Linear(
+        cross_dim, query_dim, bias=config.qkv_use_bias
+    )
     self.output_projection = nn.Linear(
         query_dim, query_dim, bias=config.output_proj_use_bias
     )
```
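In the second hunk above, `qkv_projection` fuses the query, key, and value weights for grouped-query attention, which is why its output width is `(num_heads + 2 * num_query_groups) * head_dim`. A standalone sketch of how such a fused output splits back into Q, K, and V, assuming the common layout of all query heads followed by one key and one value head per group (sizes below are illustrative, not from the diff):

```python
import torch

# Illustrative sizes, not taken from the diff.
num_heads, num_query_groups, head_dim, dim = 8, 2, 64, 512
B, T = 1, 16

qkv_projection = torch.nn.Linear(
    dim, (num_heads + 2 * num_query_groups) * head_dim
)
qkv = qkv_projection(torch.randn(B, T, dim))

# Assumed layout: query heads first, then one K and one V head per group.
q, k, v = qkv.split(
    [
        num_heads * head_dim,
        num_query_groups * head_dim,
        num_query_groups * head_dim,
    ],
    dim=-1,
)
print(q.shape, k.shape, v.shape)
# torch.Size([1, 16, 512]) torch.Size([1, 16, 128]) torch.Size([1, 16, 128])
```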
ai_edge_torch/generative/layers/builder.py

```diff
@@ -13,13 +13,12 @@
 # limitations under the License.
 # ==============================================================================
 # Builder class for individual components.
-import torch
-from torch import nn
-import torch.nn.functional as F
-
 import ai_edge_torch.generative.layers.feed_forward as feed_forward
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.layers.normalization as normalization
+import torch
+from torch import nn
+import torch.nn.functional as F
 
 
 class GeGLU(nn.Module):
```
ai_edge_torch/generative/layers/kv_cache.py

```diff
@@ -14,16 +14,17 @@
 # ==============================================================================
 # `nn.Module` which implements a KV cache.
 
+from ai_edge_torch.hlfb import StableHLOCompositeBuilder
 import torch
 from torch import nn
 import torch_xla
 
-from ai_edge_torch.hlfb import StableHLOCompositeBuilder
-
 
 class KVCache(nn.Module):
 
-  def __init__(self, batch_size, kv_cache_max, n_heads, head_dim, enable_hlfb=False):
+  def __init__(
+      self, batch_size, kv_cache_max, n_heads, head_dim, enable_hlfb=False
+  ):
     """Initializes the KVCache layer.
 
     Args:
```
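For reference, a minimal construction sketch using the reformatted `__init__` signature above. The sizes are illustrative, and, per the `StableHLOCompositeBuilder` import moved in this hunk, `enable_hlfb=True` presumably routes the cache update through a StableHLO composite; the cache's update API itself is not shown in this hunk.

```python
from ai_edge_torch.generative.layers.kv_cache import KVCache

# Illustrative sizes only; real values come from the model config.
cache = KVCache(
    batch_size=1,
    kv_cache_max=1024,
    n_heads=8,
    head_dim=64,
    enable_hlfb=False,
)
```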
ai_edge_torch/generative/layers/model_config.py

```diff
@@ -124,9 +124,13 @@ class ModelConfig:
       default_factory=NormalizationConfig
   )
   # The normalization applied to feed forward's input.
-  pre_ff_norm_config: NormalizationConfig = field(default_factory=NormalizationConfig)
+  pre_ff_norm_config: NormalizationConfig = field(
+      default_factory=NormalizationConfig
+  )
   # The normalization applied before LM head.
-  final_norm_config: NormalizationConfig = field(default_factory=NormalizationConfig)
+  final_norm_config: NormalizationConfig = field(
+      default_factory=NormalizationConfig
+  )
 
   # If set to True, only pre_attention_norm is applied to the input and the
   # decode's output is computed as `output = input + attn_out + ff_out` where
```
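The `field(default_factory=...)` pattern being rewrapped here is how dataclasses take mutable defaults; a generic illustration (the class names below are simplified stand-ins, not the real configs):

```python
from dataclasses import dataclass, field

@dataclass
class NormConfig:
  epsilon: float = 1e-5

@dataclass
class ModelCfg:
  # A bare `NormConfig()` default would be shared across instances (and is
  # rejected outright on newer Python versions); default_factory constructs
  # a fresh instance for each ModelCfg.
  norm: NormConfig = field(default_factory=NormConfig)

print(ModelCfg().norm.epsilon)  # 1e-5
```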
ai_edge_torch/generative/layers/rotary_position_embedding.py

```diff
@@ -16,7 +16,9 @@
 import torch
 
 
-def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
+def apply_rope(
+    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+) -> torch.Tensor:
   """Computes rotary positional embedding.
 
   Args:
```
ai_edge_torch/generative/layers/scaled_dot_product_attention.py

```diff
@@ -17,11 +17,10 @@
 import math
 from typing import Optional
 
+from ai_edge_torch.hlfb import StableHLOCompositeBuilder
 import torch
 import torch.nn.functional as F
 
-from ai_edge_torch.hlfb import StableHLOCompositeBuilder
-
 
 def scaled_dot_product_attention(
     q: torch.Tensor,
```
ai_edge_torch/generative/layers/unet/blocks_2d.py

```diff
@@ -15,15 +15,14 @@
 
 from typing import List, Optional, Tuple
 
-import torch
-from torch import nn
-
 from ai_edge_torch.generative.layers.attention import CrossAttention
 from ai_edge_torch.generative.layers.attention import SelfAttention
 import ai_edge_torch.generative.layers.builder as layers_builder
 import ai_edge_torch.generative.layers.model_config as layers_cfg
 import ai_edge_torch.generative.layers.unet.builder as unet_builder
 import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
+import torch
+from torch import nn
 
 
 class ResidualBlock2D(nn.Module):
@@ -41,7 +40,11 @@ class ResidualBlock2D(nn.Module):
         config.in_channels, config.normalization_config
     )
     self.conv_1 = nn.Conv2d(
-        config.in_channels, config.out_channels, kernel_size=3, stride=1, padding=1
+        config.in_channels,
+        config.out_channels,
+        kernel_size=3,
+        stride=1,
+        padding=1,
     )
     if config.time_embedding_channels is not None:
       self.time_emb_proj = nn.Linear(
@@ -53,14 +56,22 @@ class ResidualBlock2D(nn.Module):
         config.out_channels, config.normalization_config
     )
     self.conv_2 = nn.Conv2d(
-        config.out_channels, config.out_channels, kernel_size=3, stride=1, padding=1
+        config.out_channels,
+        config.out_channels,
+        kernel_size=3,
+        stride=1,
+        padding=1,
     )
     self.act_fn = layers_builder.get_activation(config.activation_config)
     if config.in_channels == config.out_channels:
       self.residual_layer = nn.Identity()
     else:
       self.residual_layer = nn.Conv2d(
-          config.in_channels, config.out_channels, kernel_size=1, stride=1, padding=0
+          config.in_channels,
+          config.out_channels,
+          kernel_size=1,
+          stride=1,
+          padding=0,
       )
 
   def forward(
@@ -105,7 +116,9 @@ class AttentionBlock2D(nn.Module):
     """
     super().__init__()
     self.config = config
-    self.norm = layers_builder.build_norm(config.dim, config.normalization_config)
+    self.norm = layers_builder.build_norm(
+        config.dim, config.normalization_config
+    )
     self.attention = SelfAttention(
         config.attention_batch_size,
         config.dim,
@@ -125,7 +138,10 @@ class AttentionBlock2D(nn.Module):
     """
     residual = input_tensor
     B, C, H, W = input_tensor.shape
-    if self.config.normalization_config.type == layers_cfg.NormalizationType.GROUP_NORM:
+    if (
+        self.config.normalization_config.type
+        == layers_cfg.NormalizationType.GROUP_NORM
+    ):
       x = self.norm(input_tensor)
       x = x.view(B, C, H * W)
       x = x.transpose(-1, -2)
@@ -156,7 +172,9 @@ class CrossAttentionBlock2D(nn.Module):
     """
     super().__init__()
     self.config = config
-    self.norm = layers_builder.build_norm(config.query_dim, config.normalization_config)
+    self.norm = layers_builder.build_norm(
+        config.query_dim, config.normalization_config
+    )
     self.attention = CrossAttention(
         config.attention_batch_size,
         config.query_dim,
@@ -180,7 +198,10 @@ class CrossAttentionBlock2D(nn.Module):
     """
     residual = input_tensor
     B, C, H, W = input_tensor.shape
-    if self.config.normalization_config.type == layers_cfg.NormalizationType.GROUP_NORM:
+    if (
+        self.config.normalization_config.type
+        == layers_cfg.NormalizationType.GROUP_NORM
+    ):
       x = self.norm(input_tensor)
       x = x.view(B, C, H * W)
       x = x.transpose(-1, -2)
@@ -209,7 +230,9 @@ class FeedForwardBlock2D(nn.Module):
     super().__init__()
     self.config = config
     self.act = layers_builder.get_activation(config.activation_config)
-    self.norm = layers_builder.build_norm(config.dim, config.normalization_config)
+    self.norm = layers_builder.build_norm(
+        config.dim, config.normalization_config
+    )
     if config.activation_config.type == layers_cfg.ActivationType.GE_GLU:
       self.w1 = nn.Identity()
       self.w2 = nn.Linear(config.hidden_dim, config.dim)
@@ -220,7 +243,10 @@ class FeedForwardBlock2D(nn.Module):
   def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
     residual = input_tensor
     B, C, H, W = input_tensor.shape
-    if self.config.normalization_config.type == layers_cfg.NormalizationType.GROUP_NORM:
+    if (
+        self.config.normalization_config.type
+        == layers_cfg.NormalizationType.GROUP_NORM
+    ):
       x = self.norm(input_tensor)
       x = x.view(B, C, H * W)
       x = x.transpose(-1, -2)
@@ -287,7 +313,9 @@ class TransformerBlock2D(nn.Module):
         padding=0,
     )
     self.self_attention = AttentionBlock2D(config.attention_block_config)
-    self.cross_attention = CrossAttentionBlock2D(config.cross_attention_block_config)
+    self.cross_attention = CrossAttentionBlock2D(
+        config.cross_attention_block_config
+    )
     self.feed_forward = FeedForwardBlock2D(config.feed_forward_block_config)
     self.conv_out = nn.Conv2d(
         config.attention_block_config.dim,
@@ -371,7 +399,9 @@ class DownEncoderBlock2D(nn.Module):
     if config.transformer_block_config:
       transformers.append(TransformerBlock2D(config.transformer_block_config))
     self.resnets = nn.ModuleList(resnets)
-    self.transformers = nn.ModuleList(transformers) if len(transformers) > 0 else None
+    self.transformers = (
+        nn.ModuleList(transformers) if len(transformers) > 0 else None
+    )
     if config.add_downsample:
       self.downsampler = unet_builder.build_downsampling(config.sampling_config)
     else:
@@ -467,12 +497,18 @@ class UpDecoderBlock2D(nn.Module):
     if config.transformer_block_config:
       transformers.append(TransformerBlock2D(config.transformer_block_config))
     self.resnets = nn.ModuleList(resnets)
-    self.transformers = nn.ModuleList(transformers) if len(transformers) > 0 else None
+    self.transformers = (
+        nn.ModuleList(transformers) if len(transformers) > 0 else None
+    )
     if config.add_upsample:
       self.upsampler = unet_builder.build_upsampling(config.sampling_config)
       if config.upsample_conv:
         self.upsample_conv = nn.Conv2d(
-            config.out_channels, config.out_channels, kernel_size=3, stride=1, padding=1
+            config.out_channels,
+            config.out_channels,
+            kernel_size=3,
+            stride=1,
+            padding=1,
         )
     else:
       self.upsampler = None
@@ -548,9 +584,13 @@ class SkipUpDecoderBlock2D(nn.Module):
     transformers = []
     for i in range(config.num_layers):
       res_skip_channels = (
-          config.in_channels if (i == config.num_layers - 1) else config.out_channels
+          config.in_channels
+          if (i == config.num_layers - 1)
+          else config.out_channels
+      )
+      resnet_in_channels = (
+          config.prev_out_channels if i == 0 else config.out_channels
       )
-      resnet_in_channels = config.prev_out_channels if i == 0 else config.out_channels
       resnets.append(
           ResidualBlock2D(
               unet_cfg.ResidualBlock2DConfig(
@@ -565,12 +605,18 @@ class SkipUpDecoderBlock2D(nn.Module):
     if config.transformer_block_config:
       transformers.append(TransformerBlock2D(config.transformer_block_config))
     self.resnets = nn.ModuleList(resnets)
-    self.transformers = nn.ModuleList(transformers) if len(transformers) > 0 else None
+    self.transformers = (
+        nn.ModuleList(transformers) if len(transformers) > 0 else None
+    )
     if config.add_upsample:
       self.upsampler = unet_builder.build_upsampling(config.sampling_config)
       if config.upsample_conv:
         self.upsample_conv = nn.Conv2d(
-            config.out_channels, config.out_channels, kernel_size=3, stride=1, padding=1
+            config.out_channels,
+            config.out_channels,
+            kernel_size=3,
+            stride=1,
+            padding=1,
         )
     else:
       self.upsampler = None
@@ -678,7 +724,9 @@ class MidBlock2D(nn.Module):
     )
     self.resnets = nn.ModuleList(resnets)
     self.attentions = nn.ModuleList(attentions) if len(attentions) > 0 else None
-    self.transformers = nn.ModuleList(transformers) if len(transformers) > 0 else None
+    self.transformers = (
+        nn.ModuleList(transformers) if len(transformers) > 0 else None
+    )
 
   def forward(
       self,
```
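Several of these blocks share the reshape visible in the GROUP_NORM branches: a `(B, C, H, W)` feature map is flattened to `(B, H*W, C)` so each spatial position becomes a token for attention. A self-contained sketch of just that transformation:

```python
import torch

B, C, H, W = 1, 8, 4, 4
input_tensor = torch.randn(B, C, H, W)

# Flatten the spatial dims, then move channels last: each of the H*W
# positions becomes a C-dimensional token, as in the hunks above.
x = input_tensor.view(B, C, H * W)
x = x.transpose(-1, -2)
assert x.shape == (B, H * W, C)
```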
ai_edge_torch/generative/layers/unet/builder.py

```diff
@@ -14,9 +14,8 @@
 # ==============================================================================
 # Builder utils for individual components.
 
-from torch import nn
-
 import ai_edge_torch.generative.layers.unet.model_config as unet_config
+from torch import nn
 
 
 def build_upsampling(config: unet_config.UpSamplingConfig):
@@ -30,10 +29,14 @@ def build_upsampling(config: unet_config.UpSamplingConfig):
 
 def build_downsampling(config: unet_config.DownSamplingConfig):
   if config.mode == unet_config.SamplingType.AVERAGE:
-    return nn.AvgPool2d(config.kernel_size, config.stride, padding=config.padding)
+    return nn.AvgPool2d(
+        config.kernel_size, config.stride, padding=config.padding
+    )
   elif config.mode == unet_config.SamplingType.CONVOLUTION:
     out_channels = (
-        config.in_channels if config.out_channels is None else config.out_channels
+        config.in_channels
+        if config.out_channels is None
+        else config.out_channels
     )
     padding = (0, 1, 0, 1) if config.padding == 0 else config.padding
     return nn.Conv2d(
```
ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py

```diff
@@ -16,7 +16,6 @@
 import json
 
 from ai_edge_quantizer import quantizer
-
 from ai_edge_torch.generative.quantize import quant_attrs
 from ai_edge_torch.generative.quantize import quant_recipe
 
@@ -44,7 +43,9 @@ def _get_nbits_from_dtype(dtype: quant_attrs.Dtype) -> int:
   raise ValueError('Unimplemented number of bits')
 
 
-def _get_dtype_from_dtype(dtype: quant_attrs.Dtype) -> quantizer.qtyping.TensorDataType:
+def _get_dtype_from_dtype(
+    dtype: quant_attrs.Dtype,
+) -> quantizer.qtyping.TensorDataType:
   if dtype == quant_attrs.Dtype.FP32 or dtype == quant_attrs.Dtype.FP16:
     return quantizer.qtyping.TensorDataType.FLOAT
   else:
@@ -59,7 +60,9 @@ def _get_execution_mode_from_mode(mode: quant_attrs.Mode) -> _OpExecutionMode:
   raise ValueError('Unimplemented execution mode')
 
 
-def _get_channelwise_from_granularity(granularity: quant_attrs.Granularity) -> bool:
+def _get_channelwise_from_granularity(
+    granularity: quant_attrs.Granularity,
+) -> bool:
   if granularity == quant_attrs.Granularity.CHANNELWISE:
     return True
   elif granularity == quant_attrs.Granularity.NONE:
@@ -87,7 +90,9 @@ def _set_quant_config(
         weight_tensor_config=_TensorQuantConfig(
             num_bits=_get_nbits_from_dtype(layer_recipe.weight_dtype),
             symmetric=True,
-            channel_wise=_get_channelwise_from_granularity(layer_recipe.granularity),
+            channel_wise=_get_channelwise_from_granularity(
+                layer_recipe.granularity
+            ),
             dtype=_get_dtype_from_dtype(layer_recipe.weight_dtype),
         ),
         execution_mode=_get_execution_mode_from_mode(layer_recipe.mode),
```
ai_edge_torch/generative/quantize/example.py

```diff
@@ -13,12 +13,11 @@
 # limitations under the License.
 # ==============================================================================
 
-import numpy as np
-import torch
-
 import ai_edge_torch
 from ai_edge_torch.generative.examples.gemma import gemma
 from ai_edge_torch.generative.quantize import quant_recipes
+import numpy as np
+import torch
 
 
 def main():
```
ai_edge_torch/generative/quantize/quant_recipe_utils.py

```diff
@@ -41,6 +41,16 @@ def create_layer_quant_int8_dynamic() -> quant_recipe.LayerQuantRecipe:
   )
 
 
+def create_layer_quant_int8_weight_only() -> quant_recipe.LayerQuantRecipe:
+  return quant_recipe.LayerQuantRecipe(
+      activation_dtype=quant_attrs.Dtype.FP32,
+      weight_dtype=quant_attrs.Dtype.INT8,
+      mode=quant_attrs.Mode.WEIGHT_ONLY,
+      algorithm=quant_attrs.Algorithm.MIN_MAX,
+      granularity=quant_attrs.Granularity.CHANNELWISE,
+  )
+
+
 def create_layer_quant_fp16() -> quant_recipe.LayerQuantRecipe:
   return quant_recipe.LayerQuantRecipe(
       activation_dtype=quant_attrs.Dtype.FP32,
```
ai_edge_torch/generative/quantize/quant_recipes.py

```diff
@@ -40,6 +40,14 @@ def full_int8_dynamic_recipe() -> quant_config.QuantConfig:
   )
 
 
+def full_int8_weight_only_recipe() -> quant_config.QuantConfig:
+  return quant_config.QuantConfig(
+      generative_recipe=quant_recipe.GenerativeQuantRecipe(
+          default=quant_recipe_utils.create_layer_quant_int8_weight_only(),
+      )
+  )
+
+
 def full_fp16_recipe() -> quant_config.QuantConfig:
   return quant_config.QuantConfig(
       generative_recipe=quant_recipe.GenerativeQuantRecipe(
```
ai_edge_torch/generative/test/loader_test.py

```diff
@@ -18,11 +18,10 @@ import os
 import tempfile
 import unittest
 
-import safetensors.torch
-import torch
-
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
 from ai_edge_torch.generative.utilities import loader as loading_utils
+import safetensors.torch
+import torch
 
 
 class TestLoader(unittest.TestCase):
@@ -59,7 +58,9 @@ class TestLoader(unittest.TestCase):
         "model.layers.0.mlp.down_proj.weight": torch.randn((2048, 5632)),
         "model.layers.0.mlp.gate_proj.weight": torch.randn((5632, 2048)),
         "model.layers.0.mlp.up_proj.weight": torch.randn((5632, 2048)),
-        "model.layers.0.post_attention_layernorm.weight": torch.randn((2048,)),
+        "model.layers.0.post_attention_layernorm.weight": torch.randn((
+            2048,
+        )),
         "model.layers.0.self_attn.k_proj.weight": torch.randn((256, 2048)),
         "model.layers.0.self_attn.o_proj.weight": torch.randn((2048, 2048)),
         "model.layers.0.self_attn.q_proj.weight": torch.randn((2048, 2048)),
```
ai_edge_torch/generative/test/test_experimental_ekv.py

```diff
@@ -16,20 +16,23 @@
 
 import unittest
 
-import numpy as np
-import torch
-
 from ai_edge_torch.generative.examples.experimental.gemma import gemma
 from ai_edge_torch.generative.examples.experimental.phi import phi2
 from ai_edge_torch.generative.examples.experimental.tiny_llama import tiny_llama  # NOQA
 from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+import numpy as np
+import torch
 
 
 class TestExternalKVLayers(unittest.TestCase):
 
-  def _get_test_config(self, num_layers, head_dim, num_query_groups, kv_cache_max_len):
-    attn_config = cfg.AttentionConfig(num_heads=1, num_query_groups=num_query_groups)
+  def _get_test_config(
+      self, num_layers, head_dim, num_query_groups, kv_cache_max_len
+  ):
+    attn_config = cfg.AttentionConfig(
+        num_heads=1, num_query_groups=num_query_groups
+    )
     config = cfg.ModelConfig(
         kv_cache_max_len=kv_cache_max_len,
         embedding_dim=head_dim,
@@ -56,23 +59,31 @@ class TestExternalKVLayers(unittest.TestCase):
     entry = kv.caches[0]
     # single-slice update
     input_pos = torch.tensor([1])
-    k_slice = v_slice = torch.full((1, 1, NUM_QG, HEAD_DIM), 5, dtype=torch.float)
+    k_slice = v_slice = torch.full(
+        (1, 1, NUM_QG, HEAD_DIM), 5, dtype=torch.float
+    )
     updated_entry = kv_utils.update(entry, input_pos, k_slice, v_slice)
     self.assertEqual(
-        updated_entry.k_cache.numpy().flatten().tolist(), [0, 0, 5, 5, 0, 0, 0, 0]
+        updated_entry.k_cache.numpy().flatten().tolist(),
+        [0, 0, 5, 5, 0, 0, 0, 0],
     )
     self.assertEqual(
-        updated_entry.v_cache.numpy().flatten().tolist(), [0, 0, 5, 5, 0, 0, 0, 0]
+        updated_entry.v_cache.numpy().flatten().tolist(),
+        [0, 0, 5, 5, 0, 0, 0, 0],
     )
     # multi-slice update
     input_pos = torch.tensor([0, 3])
-    k_slice = v_slice = torch.full((1, 2, NUM_QG, HEAD_DIM), 7, dtype=torch.float)
+    k_slice = v_slice = torch.full(
+        (1, 2, NUM_QG, HEAD_DIM), 7, dtype=torch.float
+    )
     updated_entry = kv_utils.update(entry, input_pos, k_slice, v_slice)
     self.assertEqual(
-        updated_entry.k_cache.numpy().flatten().tolist(), [7, 7, 0, 0, 0, 0, 7, 7]
+        updated_entry.k_cache.numpy().flatten().tolist(),
+        [7, 7, 0, 0, 0, 0, 7, 7],
     )
     self.assertEqual(
-        updated_entry.v_cache.numpy().flatten().tolist(), [7, 7, 0, 0, 0, 0, 7, 7]
+        updated_entry.v_cache.numpy().flatten().tolist(),
+        [7, 7, 0, 0, 0, 0, 7, 7],
     )
 
   def test_serialization(self):
```
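The expected lists in these assertions follow from a scatter-by-position update of a zero-initialized cache; the sizes below (`kv_cache_max=4`, one query group, `head_dim=2`) are inferred from the 8-element flattened caches in the test. A self-contained sketch reproducing the same arithmetic with `index_copy_`; `kv_utils.update` itself may be implemented differently:

```python
import torch

kv_cache_max, num_qg, head_dim = 4, 1, 2

# Single-slice update: write a length-1 slice of 5s at sequence position 1.
k_cache = torch.zeros(1, kv_cache_max, num_qg, head_dim)
k_cache.index_copy_(1, torch.tensor([1]), torch.full((1, 1, num_qg, head_dim), 5.0))
print(k_cache.flatten().tolist())  # [0.0, 0.0, 5.0, 5.0, 0.0, 0.0, 0.0, 0.0]

# Multi-slice update on a fresh cache: 7s at positions 0 and 3.
k_cache = torch.zeros(1, kv_cache_max, num_qg, head_dim)
k_cache.index_copy_(1, torch.tensor([0, 3]), torch.full((1, 2, num_qg, head_dim), 7.0))
print(k_cache.flatten().tolist())  # [7.0, 7.0, 0.0, 0.0, 0.0, 0.0, 7.0, 7.0]
```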
ai_edge_torch/generative/test/test_model_conversion.py

```diff
@@ -18,15 +18,14 @@ import os
 import tempfile
 import unittest
 
-import numpy as np
-import torch
-
 import ai_edge_torch
 from ai_edge_torch.generative.examples.gemma import gemma
 from ai_edge_torch.generative.examples.phi2 import phi2
 from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache  # NOQA
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
 from ai_edge_torch.testing import model_coverage
+import numpy as np
+import torch
 
 
 class TestModelConversion(unittest.TestCase):
```
|