ai-edge-torch-nightly 0.2.0.dev20240805__py3-none-any.whl → 0.2.0.dev20240807__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic.
- ai_edge_torch/__init__.py +5 -5
- ai_edge_torch/{convert → _convert}/conversion.py +40 -50
- ai_edge_torch/_convert/conversion_utils.py +64 -0
- ai_edge_torch/{convert → _convert}/converter.py +83 -43
- ai_edge_torch/{convert → _convert}/fx_passes/__init__.py +9 -9
- ai_edge_torch/{convert → _convert}/fx_passes/build_aten_composite_pass.py +51 -26
- ai_edge_torch/{convert → _convert}/fx_passes/build_interpolate_composite_pass.py +11 -8
- ai_edge_torch/{convert → _convert}/fx_passes/canonicalize_pass.py +3 -4
- ai_edge_torch/{convert → _convert}/fx_passes/inject_mlir_debuginfo_pass.py +2 -2
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_check.py +7 -5
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_mark.py +2 -0
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +1 -0
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +14 -6
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +5 -6
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +17 -14
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +3 -2
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/pass_body.py +15 -17
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/utils.py +2 -0
- ai_edge_torch/_convert/signature.py +100 -0
- ai_edge_torch/{convert → _convert}/test/test_convert.py +50 -52
- ai_edge_torch/{convert → _convert}/test/test_convert_composites.py +16 -12
- ai_edge_torch/{convert → _convert}/test/test_convert_multisig.py +6 -4
- ai_edge_torch/{convert → _convert}/test/test_to_channel_last_io.py +5 -4
- ai_edge_torch/{convert → _convert}/to_channel_last_io.py +4 -1
- ai_edge_torch/config.py +24 -0
- ai_edge_torch/conftest.py +20 -0
- ai_edge_torch/debug/culprit.py +22 -22
- ai_edge_torch/debug/test/test_culprit.py +4 -3
- ai_edge_torch/debug/test/test_search_model.py +5 -5
- ai_edge_torch/debug/utils.py +11 -2
- ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +3 -3
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +4 -1
- ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/experimental/phi/phi2.py +4 -1
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +4 -5
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +4 -1
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/gemma/gemma.py +4 -1
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/phi2/phi2.py +4 -1
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -0
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +3 -2
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +57 -20
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +20 -9
- ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +1 -0
- ai_edge_torch/generative/examples/t5/t5.py +2 -2
- ai_edge_torch/generative/examples/t5/t5_attention.py +15 -13
- ai_edge_torch/generative/examples/test_models/toy_model.py +4 -1
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +6 -5
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +7 -7
- ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +4 -1
- ai_edge_torch/generative/fx_passes/__init__.py +2 -2
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +4 -3
- ai_edge_torch/generative/layers/attention.py +35 -26
- ai_edge_torch/generative/layers/attention_utils.py +23 -12
- ai_edge_torch/generative/layers/builder.py +0 -1
- ai_edge_torch/generative/layers/feed_forward.py +6 -10
- ai_edge_torch/generative/layers/kv_cache.py +0 -1
- ai_edge_torch/generative/layers/model_config.py +2 -5
- ai_edge_torch/generative/layers/normalization.py +5 -7
- ai_edge_torch/generative/layers/rotary_position_embedding.py +3 -3
- ai_edge_torch/generative/layers/unet/blocks_2d.py +33 -26
- ai_edge_torch/generative/layers/unet/model_config.py +14 -15
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +14 -0
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +0 -2
- ai_edge_torch/generative/quantize/quant_recipe.py +8 -6
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +2 -1
- ai_edge_torch/generative/test/test_experimental_ekv.py +6 -7
- ai_edge_torch/generative/test/{loader_test.py → test_loader.py} +4 -3
- ai_edge_torch/generative/test/test_model_conversion.py +24 -25
- ai_edge_torch/generative/test/test_quantize.py +10 -5
- ai_edge_torch/generative/utilities/loader.py +12 -12
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +69 -24
- ai_edge_torch/generative/utilities/t5_loader.py +12 -13
- ai_edge_torch/hlfb/__init__.py +1 -1
- ai_edge_torch/hlfb/mark_pattern/__init__.py +9 -6
- ai_edge_torch/hlfb/mark_pattern/passes.py +23 -3
- ai_edge_torch/hlfb/mark_pattern/pattern.py +23 -23
- ai_edge_torch/hlfb/test/test_mark_pattern.py +13 -12
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +8 -6
- ai_edge_torch/{convert/fx_passes/optimize_layout_transposes_pass → lowertools}/__init__.py +1 -1
- ai_edge_torch/lowertools/_shim.py +80 -0
- ai_edge_torch/lowertools/common_utils.py +89 -0
- ai_edge_torch/lowertools/odml_torch_utils.py +201 -0
- ai_edge_torch/{convert/conversion_utils.py → lowertools/torch_xla_utils.py} +35 -214
- ai_edge_torch/model.py +14 -9
- ai_edge_torch/quantize/pt2e_quantizer.py +22 -9
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +13 -12
- ai_edge_torch/quantize/quant_config.py +7 -7
- ai_edge_torch/testing/model_coverage/model_coverage.py +19 -10
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240805.dist-info → ai_edge_torch_nightly-0.2.0.dev20240807.dist-info}/METADATA +1 -1
- ai_edge_torch_nightly-0.2.0.dev20240807.dist-info/RECORD +141 -0
- ai_edge_torch_nightly-0.2.0.dev20240805.dist-info/RECORD +0 -133
- /ai_edge_torch/{convert → _convert}/__init__.py +0 -0
- /ai_edge_torch/{convert → _convert}/fx_passes/_pass_base.py +0 -0
- /ai_edge_torch/{convert → _convert}/test/__init__.py +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240805.dist-info → ai_edge_torch_nightly-0.2.0.dev20240807.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240805.dist-info → ai_edge_torch_nightly-0.2.0.dev20240807.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240805.dist-info → ai_edge_torch_nightly-0.2.0.dev20240807.dist-info}/top_level.txt +0 -0
@@ -33,17 +33,17 @@ def convert_phi2_to_tflite(
 quantize: bool = True,
 ):
 """An example method for converting a Phi-2 model to multi-signature
-tflite model.
 
+tflite model.
 Args:
-checkpoint_path (str): The filepath to the model checkpoint, or
-
+checkpoint_path (str): The filepath to the model checkpoint, or directory
+holding the checkpoint.
 prefill_seq_len (int, optional): The maximum size of prefill input tensor.
 Defaults to 512.
 kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
 including both prefill and decode. Defaults to 1024.
-quantize (bool, optional): Whether the model should be quanized.
-
+quantize (bool, optional): Whether the model should be quanized. Defaults
+to True.
 """
 pytorch_model = phi2.build_model(
 checkpoint_path, kv_cache_max_len=kv_cache_max_len

@@ -68,7 +68,9 @@ class Phi2(nn.Module):
 )
 self.rope_cache = attn_utils.build_rope_cache(
 size=config.kv_cache_max,
-dim=int(
+dim=int(
+config.attn_config.rotary_percentage * config.attn_config.head_dim
+),
 base=10_000,
 condense_ratio=1,
 dtype=torch.float32,

@@ -118,6 +120,7 @@ class Phi2(nn.Module):
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
 attn_config = cfg.AttentionConfig(
 num_heads=32,
+head_dim=80,
 num_query_groups=32,
 rotary_percentage=0.4,
 qkv_use_bias=True,

@@ -21,7 +21,7 @@ import os
 from pathlib import Path
 
 import ai_edge_torch
-from ai_edge_torch.generative.examples.experimental.tiny_llama import tiny_llama
+from ai_edge_torch.generative.examples.experimental.tiny_llama import tiny_llama
 from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch

@@ -33,8 +33,7 @@ def convert_tiny_llama_to_tflite(
 kv_cache_max_len: int = 1024,
 quantize: bool = True,
 ):
-"""An example
-tflite model.
+"""An example for converting TinyLlama model to multi-signature tflite model.
 
 Args:
 checkpoint_path (str): The filepath to the model checkpoint, or directory

@@ -43,8 +42,8 @@ def convert_tiny_llama_to_tflite(
 Defaults to 512.
 kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
 including both prefill and decode. Defaults to 1024.
-quantize (bool, optional): Whether the model should be quanized.
-
+quantize (bool, optional): Whether the model should be quanized. Defaults
+to True.
 """
 pytorch_model = tiny_llama.build_model(
 checkpoint_path, kv_cache_max_len=kv_cache_max_len

@@ -70,7 +70,9 @@ class TinyLLamma(nn.Module):
 )
 self.rope_cache = attn_utils.build_rope_cache(
 size=config.kv_cache_max,
-dim=int(
+dim=int(
+config.attn_config.rotary_percentage * config.attn_config.head_dim
+),
 base=10_000,
 condense_ratio=1,
 dtype=torch.float32,

@@ -121,6 +123,7 @@ class TinyLLamma(nn.Module):
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
 attn_config = cfg.AttentionConfig(
 num_heads=32,
+head_dim=64,
 num_query_groups=4,
 rotary_percentage=1.0,
 )

@@ -28,17 +28,17 @@ def convert_gemma_to_tflite(
 kv_cache_max_len: int = 1024,
 quantize: bool = True,
 ):
-"""
-tflite model.
+"""Converts a Gemma 2B model to multi-signature tflite model.
 
 Args:
-checkpoint_path (str): The filepath to the model checkpoint, or directory
+checkpoint_path (str): The filepath to the model checkpoint, or directory
+holding the checkpoint.
 prefill_seq_len (int, optional): The maximum size of prefill input tensor.
 Defaults to 512.
 kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
 including both prefill and decode. Defaults to 1024.
-quantize (bool, optional): Whether the model should be quanized.
-
+quantize (bool, optional): Whether the model should be quanized. Defaults
+to True.
 """
 pytorch_model = gemma.build_2b_model(
 checkpoint_path, kv_cache_max_len=kv_cache_max_len

@@ -68,7 +68,9 @@ class Gemma(nn.Module):
 )
 self.rope_cache = attn_utils.build_rope_cache(
 size=config.kv_cache_max,
-dim=int(
+dim=int(
+config.attn_config.rotary_percentage * config.attn_config.head_dim
+),
 base=10_000,
 condense_ratio=1,
 dtype=torch.float32,

@@ -113,6 +115,7 @@ class Gemma(nn.Module):
 def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
 attn_config = cfg.AttentionConfig(
 num_heads=8,
+head_dim=256,
 num_query_groups=1,
 rotary_percentage=1.0,
 )

@@ -28,17 +28,17 @@ def convert_phi2_to_tflite(
 kv_cache_max_len: int = 1024,
 quantize: bool = True,
 ):
-"""
-tflite model.
+"""Converts a Phi-2 model to multi-signature tflite model.
 
 Args:
-checkpoint_path (str): The filepath to the model checkpoint, or directory
+checkpoint_path (str): The filepath to the model checkpoint, or directory
+holding the checkpoint.
 prefill_seq_len (int, optional): The maximum size of prefill input tensor.
 Defaults to 512.
 kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
 including both prefill and decode. Defaults to 1024.
-quantize (bool, optional): Whether the model should be quanized.
-
+quantize (bool, optional): Whether the model should be quanized. Defaults
+to True.
 """
 pytorch_model = phi2.build_model(
 checkpoint_path, kv_cache_max_len=kv_cache_max_len

@@ -63,7 +63,9 @@ class Phi2(nn.Module):
 )
 self.rope_cache = attn_utils.build_rope_cache(
 size=config.kv_cache_max,
-dim=int(
+dim=int(
+config.attn_config.rotary_percentage * config.attn_config.head_dim
+),
 base=10_000,
 condense_ratio=1,
 dtype=torch.float32,

@@ -107,6 +109,7 @@ class Phi2(nn.Module):
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
 attn_config = cfg.AttentionConfig(
 num_heads=32,
+head_dim=80,
 num_query_groups=32,
 rotary_percentage=0.4,
 qkv_use_bias=True,

@@ -49,6 +49,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 
 class CLIP(nn.Module):
 """CLIP text encoder
+
 For details, see https://arxiv.org/abs/2103.00020
 """
 

@@ -92,6 +93,7 @@ def get_model_config() -> cfg.ModelConfig:
 
 attn_config = cfg.AttentionConfig(
 num_heads=num_heads,
+head_dim=embedding_dim // num_heads,
 num_query_groups=num_query_groups,
 rotary_percentage=0.0,
 qkv_use_bias=True,

@@ -15,9 +15,9 @@
 
 import ai_edge_torch.generative.layers.builder as layers_builder
 import ai_edge_torch.generative.layers.model_config as layers_cfg
-
+from ai_edge_torch.generative.layers.unet import blocks_2d
 import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
-
+from ai_edge_torch.generative.utilities import stable_diffusion_loader
 import torch
 from torch import nn
 

@@ -288,6 +288,7 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
 normalization_config=norm_config,
 attention_config=layers_cfg.AttentionConfig(
 num_heads=1,
+head_dim=block_out_channels[-1],
 num_query_groups=1,
 qkv_use_bias=True,
 output_proj_use_bias=True,

@@ -15,9 +15,9 @@
 
 import ai_edge_torch.generative.layers.builder as layers_builder
 import ai_edge_torch.generative.layers.model_config as layers_cfg
-
+from ai_edge_torch.generative.layers.unet import blocks_2d
 import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
-
+from ai_edge_torch.generative.utilities import stable_diffusion_loader
 import torch
 from torch import nn
 

@@ -195,6 +195,31 @@ TENSOR_NAMES = stable_diffusion_loader.DiffusionModelLoader.TensorNames(
 )
 
 
+def build_attention_config(
+num_heads,
+dim,
+num_query_groups,
+rotary_percentage=0.0,
+qkv_transpose_before_split=True,
+qkv_use_bias=False,
+output_proj_use_bias=True,
+enable_kv_cache=False,
+qkv_fused_interleaved=False,
+):
+
+return layers_cfg.AttentionConfig(
+num_heads=num_heads,
+head_dim=dim // num_heads,
+num_query_groups=num_query_groups,
+rotary_percentage=rotary_percentage,
+qkv_transpose_before_split=qkv_transpose_before_split,
+qkv_use_bias=qkv_use_bias,
+output_proj_use_bias=output_proj_use_bias,
+enable_kv_cache=enable_kv_cache,
+qkv_fused_interleaved=qkv_fused_interleaved,
+)
+
+
 class TimeEmbedding(nn.Module):
 
 def __init__(self, in_dim, out_dim):

@@ -267,17 +292,6 @@ class Diffusion(nn.Module):
 config.in_channels, block_out_channels[0], kernel_size=3, padding=1
 )
 
-attention_config = layers_cfg.AttentionConfig(
-num_heads=config.transformer_num_attention_heads,
-num_query_groups=config.transformer_num_attention_heads,
-rotary_percentage=0.0,
-qkv_transpose_before_split=True,
-qkv_use_bias=False,
-output_proj_use_bias=True,
-enable_kv_cache=False,
-qkv_fused_interleaved=False,
-)
-
 # Down encoders.
 down_encoders = []
 output_channel = block_out_channels[0]

@@ -312,7 +326,11 @@ class Diffusion(nn.Module):
 dim=output_channel,
 attention_batch_size=config.transformer_batch_size,
 normalization_config=config.transformer_norm_config,
-attention_config=
+attention_config=build_attention_config(
+num_heads=config.transformer_num_attention_heads,
+dim=output_channel,
+num_query_groups=config.transformer_num_attention_heads,
+),
 enable_hlfb=False,
 ),
 cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(

@@ -320,7 +338,11 @@ class Diffusion(nn.Module):
 cross_dim=config.transformer_cross_attention_dim,
 attention_batch_size=config.transformer_batch_size,
 normalization_config=config.transformer_norm_config,
-attention_config=
+attention_config=build_attention_config(
+num_heads=config.transformer_num_attention_heads,
+dim=output_channel,
+num_query_groups=config.transformer_num_attention_heads,
+),
 enable_hlfb=False,
 ),
 pre_conv_normalization_config=config.transformer_pre_conv_norm_config,

@@ -374,7 +396,11 @@ class Diffusion(nn.Module):
 dim=mid_block_channels,
 attention_batch_size=config.transformer_batch_size,
 normalization_config=config.transformer_norm_config,
-attention_config=
+attention_config=build_attention_config(
+num_heads=config.transformer_num_attention_heads,
+dim=mid_block_channels,
+num_query_groups=config.transformer_num_attention_heads,
+),
 enable_hlfb=False,
 ),
 cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(

@@ -382,7 +408,11 @@ class Diffusion(nn.Module):
 cross_dim=config.transformer_cross_attention_dim,
 attention_batch_size=config.transformer_batch_size,
 normalization_config=config.transformer_norm_config,
-attention_config=
+attention_config=build_attention_config(
+num_heads=config.transformer_num_attention_heads,
+dim=mid_block_channels,
+num_query_groups=config.transformer_num_attention_heads,
+),
 enable_hlfb=False,
 ),
 pre_conv_normalization_config=config.transformer_pre_conv_norm_config,

@@ -437,7 +467,11 @@ class Diffusion(nn.Module):
 dim=output_channel,
 attention_batch_size=config.transformer_batch_size,
 normalization_config=config.transformer_norm_config,
-attention_config=
+attention_config=build_attention_config(
+num_heads=config.transformer_num_attention_heads,
+dim=output_channel,
+num_query_groups=config.transformer_num_attention_heads,
+),
 enable_hlfb=False,
 ),
 cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(

@@ -445,7 +479,11 @@ class Diffusion(nn.Module):
 cross_dim=config.transformer_cross_attention_dim,
 attention_batch_size=config.transformer_batch_size,
 normalization_config=config.transformer_norm_config,
-attention_config=
+attention_config=build_attention_config(
+num_heads=config.transformer_num_attention_heads,
+dim=output_channel,
+num_query_groups=config.transformer_num_attention_heads,
+),
 enable_hlfb=False,
 ),
 pre_conv_normalization_config=config.transformer_pre_conv_norm_config,

@@ -543,7 +581,6 @@ def get_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
 
 Retruns:
 The configuration of diffusion model of Stable Diffusion v1.5.
-
 """
 in_channels = 4
 out_channels = 4

@@ -127,7 +127,9 @@ def run_tflite_pipeline(
 input_image: Optional[Image.Image] = None,
 ):
 """Run stable diffusion pipeline with tflite model.
+
 model:
+
 StableDiffsuion model.
 prompt:
 The prompt to guide the image generation.

@@ -136,27 +138,36 @@ def run_tflite_pipeline(
 uncond_prompt:
 The prompt not to guide the image generation.
 cfg_scale:
-Guidance scale of classifier-free guidance. Higher guidance scale encourages
-
+Guidance scale of classifier-free guidance. Higher guidance scale encourages
+to generate
+images that are closely linked to the text `prompt`, usually at the expense
+of lower
 image quality.
 height:
 The height in pixels of the generated image.
 width:
 The width in pixels of the generated image.
 sampler:
-A sampler to be used to denoise the encoded image latents. Can be one of
+A sampler to be used to denoise the encoded image latents. Can be one of
+`k_lms, `k_euler`,
 or `k_euler_ancestral`.
 n_inference_steps:
-The number of denoising steps. More denoising steps usually lead to a higher
+The number of denoising steps. More denoising steps usually lead to a higher
+quality image at the
 expense of slower inference. This parameter will be modulated by `strength`.
 seed:
 A seed to make generation deterministic.
 strength:
-Conceptually, indicates how much to transform the reference `input_image`.
-
-
-
-
+Conceptually, indicates how much to transform the reference `input_image`.
+Must be between 0 and 1.
+`input_image` will be used as a starting point, adding more noise to it the
+larger the `strength`.
+The number of denoising steps depends on the amount of noise initially
+added. When `strength` is 1,
+added noise will be maximum and the denoising process will run for the full
+number of iterations
+specified in `n_inference_steps`. A value of 1, therefore, essentially
+ignores `input_image`.
 input_image:
 Image which is served as the starting point for the image generation.
 """

@@ -28,6 +28,7 @@ class SamplerInterface(abc.ABC):
 @abc.abstractmethod
 def set_strength(self, strength: float = 1) -> None:
 """Set the strength of initial step.
+
 Conceptually, indicates how much to transform the reference `input_images`.
 """
 return NotImplemented

@@ -17,14 +17,13 @@
 import copy
 import os
 from pathlib import Path
-from typing import Optional
+from typing import Optional
 
 from ai_edge_torch.generative.examples.t5.t5_attention import EncoderDecoderBlock # NOQA
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.t5_loader as loading_utils
-import numpy as np
 import torch
 import torch.nn as nn
 

@@ -371,6 +370,7 @@ class T5Decoder(nn.Module):
 def get_model_config_t5() -> cfg.ModelConfig:
 attn_config = cfg.AttentionConfig(
 num_heads=12,
+head_dim=64,
 num_query_groups=12,
 qkv_use_bias=False,
 relative_attention_num_buckets=32,

@@ -37,10 +37,10 @@ class EncoderDecoderBlock(nn.Module):
 """Initialize an instance of the EncoderDecoderBlock.
 
 Args:
-config (cfg.ModelConfig): the configuration object
-
-has_relative_attention_bias (bool): whether the
-
+config (cfg.ModelConfig): the configuration object for this transformer
+block.
+has_relative_attention_bias (bool): whether the self attention block has
+relative bias.
 """
 
 super().__init__()

@@ -143,8 +143,10 @@ class T5Attention(CrossAttention):
 Args:
 dim (int): causal attention's input/output dimmension.
 config (cfg.AttentionConfig): attention specific configurations.
-norm_config (cfg.NormalizationConfig): normalization configure before
-
+norm_config (cfg.NormalizationConfig): normalization configure before
+attention.
+kv_cache_max (int): determines the size of the KV Cache buffer, if
+enabled.
 enable_hlfb (bool): whether hlfb is enabled or not.
 has_relative_attention_bias (bool): whether we compute relative bias.
 """

@@ -185,7 +187,7 @@ class T5Attention(CrossAttention):
 ) # batch size, sequence length, embedding dimensionality (n_embd)
 query_states = self.q_projection(x)
 query_states = query_states.reshape(
-B, T, -1, self.head_dim
+B, T, -1, self.config.head_dim
 ) # (B, T, nh_q, hs)
 
 if key_value_states is not None:

@@ -198,13 +200,13 @@ class T5Attention(CrossAttention):
 ) # batch size, sequence length, embedding dimensionality (n_embd)
 key_states = self.k_projection(key_value_states)
 value_states = self.v_projection(key_value_states)
-key_states = key_states.reshape(kvB, kvT, -1, self.head_dim)
-value_states = value_states.reshape(kvB, kvT, -1, self.head_dim)
+key_states = key_states.reshape(kvB, kvT, -1, self.config.head_dim)
+value_states = value_states.reshape(kvB, kvT, -1, self.config.head_dim)
 else:
 key_states = self.k_projection(x)
 value_states = self.v_projection(x)
-key_states = key_states.reshape(B, T, -1, self.head_dim)
-value_states = value_states.reshape(B, T, -1, self.head_dim)
+key_states = key_states.reshape(B, T, -1, self.config.head_dim)
+value_states = value_states.reshape(B, T, -1, self.config.head_dim)
 
 if key_value_states is None and self.kv_cache is not None:
 key_states, value_states = self.kv_cache.update_cache(

@@ -221,7 +223,7 @@ class T5Attention(CrossAttention):
 0
 ) # shape (1, num_heads, query_length, key_length)
 else:
-# position_bias = torch.zeros(B, self.n_heads, T, self.head_dim, dtype=torch.float32)
+# position_bias = torch.zeros(B, self.n_heads, T, self.config.head_dim, dtype=torch.float32)
 position_bias = torch.zeros_like(mask, dtype=torch.float32)
 
 mask = mask + position_bias

@@ -229,7 +231,7 @@ class T5Attention(CrossAttention):
 query_states,
 key_states,
 value_states,
-self.head_dim,
+self.config.head_dim,
 mask=mask,
 scale=1.0,
 )

@@ -43,7 +43,9 @@ class ToySingleLayerModel(torch.nn.Module):
 )
 self.rope_cache = attn_utils.build_rope_cache(
 size=config.max_seq_len,
-dim=int(
+dim=int(
+config.attn_config.rotary_percentage * config.attn_config.head_dim
+),
 base=10_000,
 condense_ratio=1,
 dtype=torch.float32,

@@ -72,6 +74,7 @@ class ToySingleLayerModel(torch.nn.Module):
 def get_model_config() -> cfg.ModelConfig:
 attn_config = cfg.AttentionConfig(
 num_heads=32,
+head_dim=4,
 num_query_groups=4,
 rotary_percentage=1.0,
 enable_kv_cache=False,

@@ -17,6 +17,7 @@
 from typing import Tuple
 
 import ai_edge_torch
+from ai_edge_torch import lowertools
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.builder as builder
 from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils

@@ -24,7 +25,6 @@ from ai_edge_torch.generative.layers.experimental.attention import TransformerBl
 import ai_edge_torch.generative.layers.model_config as cfg
 import torch
 import torch.nn as nn
-import torch_xla
 
 RoPECache = Tuple[torch.Tensor, torch.Tensor]
 

@@ -46,7 +46,9 @@ class ToyModelWithExternalKV(torch.nn.Module):
 )
 self.rope_cache = attn_utils.build_rope_cache(
 size=config.max_seq_len,
-dim=int(
+dim=int(
+config.attn_config.rotary_percentage * config.attn_config.head_dim
+),
 base=10_000,
 condense_ratio=1,
 dtype=torch.float32,

@@ -84,13 +86,12 @@ class ToyModelWithExternalKV(torch.nn.Module):
 
 def _export_stablehlo_mlir(model, args):
 ep = torch.export.export(model, args)
-
-return stablehlo_gm.get_stablehlo_text()
+return lowertools.exported_program_to_mlir_text(ep)
 
 
 def get_model_config() -> cfg.ModelConfig:
 attn_config = cfg.AttentionConfig(
-num_heads=32, num_query_groups=4, rotary_percentage=1.0
+num_heads=32, head_dim=4, num_query_groups=4, rotary_percentage=1.0
 )
 ff_config = cfg.FeedForwardConfig(
 type=cfg.FeedForwardType.GATED,
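The hunk above replaces the removed torch_xla-based export in the toy example with the new ai_edge_torch.lowertools helper (the same change appears again in the toy_model_with_kv_cache hunks below). A minimal, self-contained sketch of that flow follows; the TinyModel module and the example input shape are illustrative stand-ins, not code from the package:

import torch
import torch.nn as nn

from ai_edge_torch import lowertools


class TinyModel(nn.Module):
  """Illustrative stand-in for the toy models in this diff."""

  def forward(self, x):
    return torch.nn.functional.relu(x)


model = TinyModel().eval()
example_args = (torch.randn(1, 8),)

# Export to an ExportedProgram, then lower it to StableHLO MLIR text,
# mirroring _export_stablehlo_mlir in the hunks above.
ep = torch.export.export(model, example_args)
mlir_text = lowertools.exported_program_to_mlir_text(ep)
print(mlir_text)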
@@ -13,17 +13,16 @@
 # limitations under the License.
 # ==============================================================================
 # A toy example which has basic transformer block (w/ KV-Cache).
-from typing import
+from typing import Tuple
 
 import ai_edge_torch
+from ai_edge_torch import lowertools
 from ai_edge_torch.generative.layers.attention import TransformerBlock
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
-import numpy as np
 import torch
 import torch.nn as nn
-import torch_xla
 
 RoPECache = Tuple[torch.Tensor, torch.Tensor]
 

@@ -45,7 +44,9 @@ class ToyModelWithKV(torch.nn.Module):
 )
 self.rope_cache = attn_utils.build_rope_cache(
 size=config.max_seq_len,
-dim=int(
+dim=int(
+config.attn_config.rotary_percentage * config.attn_config.head_dim
+),
 base=10_000,
 condense_ratio=1,
 dtype=torch.float32,

@@ -72,13 +73,12 @@ class ToyModelWithKV(torch.nn.Module):
 
 def _export_stablehlo_mlir(model, args):
 ep = torch.export.export(model, args)
-
-return stablehlo_gm.get_stablehlo_text()
+return lowertools.exported_program_to_mlir_text(ep)
 
 
 def get_model_config() -> cfg.ModelConfig:
 attn_config = cfg.AttentionConfig(
-num_heads=32, num_query_groups=4, rotary_percentage=1.0
+num_heads=32, head_dim=4, num_query_groups=4, rotary_percentage=1.0
 )
 ff_config = cfg.FeedForwardConfig(
 type=cfg.FeedForwardType.GATED,

@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
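A change that recurs across most of the model hunks in this diff: cfg.AttentionConfig now carries an explicit head_dim field, and the RoPE cache dimension is computed from config.attn_config.head_dim rather than a model-level head dimension. A minimal sketch of the new pattern, using the Phi-2 numbers shown in the hunks above (the size of 1024 matches the kv_cache_max default documented in the convert hunks and is used here only for illustration):

import torch

import ai_edge_torch.generative.layers.attention_utils as attn_utils
import ai_edge_torch.generative.layers.model_config as cfg

# head_dim is now an explicit field on AttentionConfig (values taken from the Phi-2 hunks).
attn_config = cfg.AttentionConfig(
    num_heads=32,
    head_dim=80,
    num_query_groups=32,
    rotary_percentage=0.4,
    qkv_use_bias=True,
)

# The RoPE cache dimension is derived from the explicit head_dim.
rope_cache = attn_utils.build_rope_cache(
    size=1024,  # kv_cache_max; 1024 is the documented default
    dim=int(attn_config.rotary_percentage * attn_config.head_dim),
    base=10_000,
    condense_ratio=1,
    dtype=torch.float32,
)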