ai-edge-torch-nightly 0.2.0.dev20240714__py3-none-any.whl

This diff shows the content of a publicly available package version released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in that registry.

Potentially problematic release: this version of ai-edge-torch-nightly might be problematic.

Files changed (121)
  1. ai_edge_torch/__init__.py +31 -0
  2. ai_edge_torch/convert/__init__.py +14 -0
  3. ai_edge_torch/convert/conversion.py +117 -0
  4. ai_edge_torch/convert/conversion_utils.py +400 -0
  5. ai_edge_torch/convert/converter.py +202 -0
  6. ai_edge_torch/convert/fx_passes/__init__.py +59 -0
  7. ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
  8. ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +225 -0
  9. ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +123 -0
  10. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
  11. ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
  12. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  13. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
  14. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
  15. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
  16. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
  17. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +215 -0
  18. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
  19. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
  20. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +293 -0
  21. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
  22. ai_edge_torch/convert/test/__init__.py +14 -0
  23. ai_edge_torch/convert/test/test_convert.py +311 -0
  24. ai_edge_torch/convert/test/test_convert_composites.py +192 -0
  25. ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
  26. ai_edge_torch/convert/test/test_to_channel_last_io.py +96 -0
  27. ai_edge_torch/convert/to_channel_last_io.py +85 -0
  28. ai_edge_torch/debug/__init__.py +17 -0
  29. ai_edge_torch/debug/culprit.py +464 -0
  30. ai_edge_torch/debug/test/__init__.py +14 -0
  31. ai_edge_torch/debug/test/test_culprit.py +133 -0
  32. ai_edge_torch/debug/test/test_search_model.py +50 -0
  33. ai_edge_torch/debug/utils.py +48 -0
  34. ai_edge_torch/experimental/__init__.py +14 -0
  35. ai_edge_torch/generative/__init__.py +14 -0
  36. ai_edge_torch/generative/examples/__init__.py +14 -0
  37. ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
  38. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
  39. ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
  40. ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
  41. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
  42. ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
  43. ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
  44. ai_edge_torch/generative/examples/stable_diffusion/attention.py +106 -0
  45. ai_edge_torch/generative/examples/stable_diffusion/clip.py +115 -0
  46. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +142 -0
  47. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +317 -0
  48. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +573 -0
  49. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +118 -0
  50. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +222 -0
  51. ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
  52. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +61 -0
  53. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +65 -0
  54. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +73 -0
  55. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +38 -0
  56. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +108 -0
  57. ai_edge_torch/generative/examples/stable_diffusion/util.py +71 -0
  58. ai_edge_torch/generative/examples/t5/__init__.py +14 -0
  59. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
  60. ai_edge_torch/generative/examples/t5/t5.py +608 -0
  61. ai_edge_torch/generative/examples/t5/t5_attention.py +231 -0
  62. ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
  63. ai_edge_torch/generative/examples/test_models/toy_model.py +122 -0
  64. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +161 -0
  65. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
  66. ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
  67. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
  68. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
  69. ai_edge_torch/generative/fx_passes/__init__.py +31 -0
  70. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +47 -0
  71. ai_edge_torch/generative/layers/__init__.py +14 -0
  72. ai_edge_torch/generative/layers/attention.py +354 -0
  73. ai_edge_torch/generative/layers/attention_utils.py +169 -0
  74. ai_edge_torch/generative/layers/builder.py +131 -0
  75. ai_edge_torch/generative/layers/feed_forward.py +95 -0
  76. ai_edge_torch/generative/layers/kv_cache.py +83 -0
  77. ai_edge_torch/generative/layers/model_config.py +158 -0
  78. ai_edge_torch/generative/layers/normalization.py +62 -0
  79. ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
  80. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +117 -0
  81. ai_edge_torch/generative/layers/unet/__init__.py +14 -0
  82. ai_edge_torch/generative/layers/unet/blocks_2d.py +711 -0
  83. ai_edge_torch/generative/layers/unet/builder.py +47 -0
  84. ai_edge_torch/generative/layers/unet/model_config.py +269 -0
  85. ai_edge_torch/generative/quantize/__init__.py +14 -0
  86. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
  87. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +148 -0
  88. ai_edge_torch/generative/quantize/example.py +45 -0
  89. ai_edge_torch/generative/quantize/quant_attrs.py +68 -0
  90. ai_edge_torch/generative/quantize/quant_recipe.py +151 -0
  91. ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
  92. ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
  93. ai_edge_torch/generative/quantize/supported_schemes.py +32 -0
  94. ai_edge_torch/generative/test/__init__.py +14 -0
  95. ai_edge_torch/generative/test/loader_test.py +80 -0
  96. ai_edge_torch/generative/test/test_model_conversion.py +235 -0
  97. ai_edge_torch/generative/test/test_quantize.py +162 -0
  98. ai_edge_torch/generative/utilities/__init__.py +15 -0
  99. ai_edge_torch/generative/utilities/loader.py +328 -0
  100. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +924 -0
  101. ai_edge_torch/generative/utilities/t5_loader.py +483 -0
  102. ai_edge_torch/hlfb/__init__.py +16 -0
  103. ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
  104. ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
  105. ai_edge_torch/hlfb/mark_pattern/pattern.py +273 -0
  106. ai_edge_torch/hlfb/test/__init__.py +14 -0
  107. ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
  108. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
  109. ai_edge_torch/model.py +142 -0
  110. ai_edge_torch/quantize/__init__.py +16 -0
  111. ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
  112. ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
  113. ai_edge_torch/quantize/quant_config.py +81 -0
  114. ai_edge_torch/testing/__init__.py +14 -0
  115. ai_edge_torch/testing/model_coverage/__init__.py +16 -0
  116. ai_edge_torch/testing/model_coverage/model_coverage.py +132 -0
  117. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/LICENSE +202 -0
  118. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/METADATA +38 -0
  119. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD +121 -0
  120. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/WHEEL +5 -0
  121. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/top_level.txt +1 -0
@@ -0,0 +1,231 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ # Attention modules for the T5 encoder-decoder model family.
+
+ from typing import Optional, Tuple
+
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+ from ai_edge_torch.generative.layers.attention import CrossAttention
+ import ai_edge_torch.generative.layers.builder as builder
+ from ai_edge_torch.generative.layers.kv_cache import KVCache
+ import ai_edge_torch.generative.layers.model_config as cfg
+ from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention # NOQA
+ from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention_with_hlfb # NOQA
+
+ BATCH_SIZE = 1
+
+
+ class EncoderDecoderBlock(nn.Module):
+
+   def __init__(
+       self, config: cfg.ModelConfig, has_relative_attention_bias: bool = False
+   ) -> None:
+     """Initializes an instance of the EncoderDecoderBlock.
+
+     Args:
+       config (cfg.ModelConfig): the configuration object for this transformer
+         block.
+       has_relative_attention_bias (bool): whether the self-attention block has
+         relative attention bias.
+     """
+     super().__init__()
+     self.atten_func = T5Attention(
+         BATCH_SIZE,
+         config.embedding_dim,
+         config.attn_config,
+         config.pre_attention_norm_config,
+         config.kv_cache_max,
+         config.enable_hlfb,
+         has_relative_attention_bias=has_relative_attention_bias,
+     )
+     # For a decoder, we add a cross attention.
+     if config.is_decoder:
+       self.cross_atten_func = T5Attention(
+           BATCH_SIZE,
+           config.embedding_dim,
+           config.attn_config,
+           config.pre_attention_norm_config,
+           config.kv_cache_max,
+           config.enable_hlfb,
+           # Cross attention does not have relative attention bias.
+           has_relative_attention_bias=False,
+       )
+     else:
+       self.cross_atten_func = None
+
+     self.pre_ff_norm = builder.build_norm(
+         config.embedding_dim, config.pre_ff_norm_config
+     )
+     self.ff = builder.build_ff(config.embedding_dim, config.ff_config)
+     self.config = config
+
+   def forward(
+       self,
+       x: torch.Tensor,
+       input_pos: Optional[torch.Tensor] = None,
+       mask: Optional[torch.Tensor] = None,
+       relative_position: Optional[torch.Tensor] = None,
+       position_bias: Optional[torch.Tensor] = None,
+       encoder_hidden_states: Optional[torch.Tensor] = None,
+       encoder_attention_mask: Optional[torch.Tensor] = None,
+       encoder_decoder_position_bias: Optional[torch.Tensor] = None,
+   ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+     """Forward function of the EncoderDecoderBlock.
+
+     Args:
+       x (torch.Tensor): the input tensor.
+       input_pos (torch.Tensor): the optional input position tensor.
+       mask (torch.Tensor): the optional self-attention mask tensor.
+       relative_position (torch.Tensor): the optional relative position tensor.
+       position_bias (torch.Tensor): the optional self-attention position bias.
+       encoder_hidden_states (torch.Tensor): optional encoder outputs used as
+         key/value states by the decoder's cross attention.
+       encoder_attention_mask (torch.Tensor): the optional cross-attention mask.
+       encoder_decoder_position_bias (torch.Tensor): the optional cross-attention
+         position bias.
+
+     Returns:
+       The output activation of this transformer block, along with the
+       self-attention and cross-attention position biases.
+     """
+     hidden_states, position_bias = self.atten_func(
+         x,
+         input_pos=input_pos,
+         mask=mask,
+         relative_position=relative_position,
+         position_bias=position_bias,
+     )
+
+     attn_out = hidden_states + x
+
+     if self.cross_atten_func:
+       hidden_states, encoder_decoder_position_bias = self.cross_atten_func(
+           attn_out,
+           input_pos=input_pos,
+           key_value_states=encoder_hidden_states,
+           mask=encoder_attention_mask,
+           relative_position=relative_position,
+           position_bias=encoder_decoder_position_bias,
+       )
+       attn_out = hidden_states + attn_out
+
+     forwarded = self.pre_ff_norm(attn_out)
+     forwarded = self.ff(forwarded)
+     hidden_states = attn_out + forwarded
+
+     # encoder_decoder_position_bias comes from the cross attention.
+     return hidden_states, position_bias, encoder_decoder_position_bias
+
+
+ class T5Attention(CrossAttention):
+
+   def __init__(
+       self,
+       batch: int,
+       dim: int,
+       config: cfg.AttentionConfig,
+       norm_config: cfg.NormalizationConfig,
+       kv_cache_max: int,
+       enable_hlfb: bool,
+       has_relative_attention_bias: bool = False,
+   ) -> None:
+     """Initializes an instance of T5Attention.
+
+     Args:
+       batch (int): batch size of the input tensor.
+       dim (int): causal attention's input/output dimension.
+       config (cfg.AttentionConfig): attention-specific configurations.
+       norm_config (cfg.NormalizationConfig): the normalization applied before
+         attention.
+       kv_cache_max (int): determines the size of the KV Cache buffer, if enabled.
+       enable_hlfb (bool): whether hlfb is enabled or not.
+       has_relative_attention_bias (bool): whether to compute relative bias.
+     """
+     super().__init__(batch, dim, dim, config, kv_cache_max, enable_hlfb)
+     self.pre_atten_norm = builder.build_norm(dim, norm_config)
+
+     self.has_relative_attention_bias = has_relative_attention_bias
+     self.relative_attention_num_buckets = config.relative_attention_num_buckets
+     if self.has_relative_attention_bias:
+       self.relative_attention_bias = nn.Embedding(
+           self.relative_attention_num_buckets, self.n_heads
+       )
+
+   def forward(
+       self,
+       x: torch.Tensor,
+       input_pos: Optional[torch.Tensor] = None,
+       key_value_states: Optional[torch.Tensor] = None,
+       mask: Optional[torch.Tensor] = None,
+       relative_position: Optional[torch.Tensor] = None,
+       position_bias: Optional[torch.Tensor] = None,
+   ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """Forward function of the T5Attention layer.
+
+     Args:
+       x (torch.Tensor): the input tensor.
+       input_pos (torch.Tensor): the optional input position tensor.
+       key_value_states (torch.Tensor): optional key/value states for cross
+         attention; self attention is used when omitted.
+       mask (torch.Tensor): the optional mask tensor.
+       relative_position (torch.Tensor): the optional relative position tensor.
+       position_bias (torch.Tensor): the optional position bias tensor.
+
+     Returns:
+       The output activation of this attention layer and the position bias.
+     """
+     x = self.pre_atten_norm(x)
+     B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
+     query_states = self.q_projection(x)
+     query_states = query_states.reshape(B, T, -1, self.head_dim)  # (B, T, nh_q, hs)
+
+     if key_value_states is not None:
+       # Cross attention: keys and values are projected from the encoder states.
+       kvB, kvT, kvC = key_value_states.size()
+       key_states = self.k_projection(key_value_states)
+       value_states = self.v_projection(key_value_states)
+       key_states = key_states.reshape(kvB, kvT, -1, self.head_dim)
+       value_states = value_states.reshape(kvB, kvT, -1, self.head_dim)
+     else:
+       # Self attention: keys and values are projected from the input itself.
+       key_states = self.k_projection(x)
+       value_states = self.v_projection(x)
+       key_states = key_states.reshape(B, T, -1, self.head_dim)
+       value_states = value_states.reshape(B, T, -1, self.head_dim)
+
+     if key_value_states is None and self.kv_cache is not None:
+       key_states, value_states = self.kv_cache.update_cache(
+           input_pos, key_states, value_states
+       )
+
+     if position_bias is None:
+       # Handle the encoder case first.
+       if self.has_relative_attention_bias:
+         position_bias = self.relative_attention_bias(
+             relative_position
+         )  # shape (query_length, key_length, num_heads)
+         position_bias = position_bias.permute([0, 1, 4, 2, 3]).squeeze(
+             0
+         )  # shape (1, num_heads, query_length, key_length)
+       else:
+         # position_bias = torch.zeros(B, self.n_heads, T, self.head_dim, dtype=torch.float32)
+         position_bias = torch.zeros_like(mask, dtype=torch.float32)
+
+     mask = mask + position_bias
+     y = self.sdpa_func(
+         query_states, key_states, value_states, self.head_dim, mask=mask, scale=1.0
+     )
+     y = y.reshape(B, T, C)  # Re-assemble all head outputs side by side.
+     # Output projection.
+     y = self.output_projection(y)
+     return y, position_bias
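For reference, here is a minimal, self-contained sketch (plain PyTorch, with hypothetical shapes and names rather than the ai_edge_torch API) of how T5Attention folds the learned relative-position bias into the additive attention mask before scaled-dot-product attention:

    import torch

    num_heads, T, num_buckets = 4, 8, 32

    # Learned bias table: one scalar per (bucket, head), like the nn.Embedding above.
    bias_table = torch.nn.Embedding(num_buckets, num_heads)

    # Hypothetical precomputed bucket ids of shape (1, T, T); in the real layer
    # these arrive via the `relative_position` argument.
    relative_position = torch.randint(0, num_buckets, (1, T, T))

    # Lookup gives (1, T, T, num_heads); move heads forward -> (1, num_heads, T, T).
    position_bias = bias_table(relative_position).permute(0, 3, 1, 2)

    # An additive causal mask is -inf above the diagonal and 0 elsewhere. Adding the
    # bias keeps masked positions at -inf and offsets visible positions per head,
    # which is the `mask = mask + position_bias` step above.
    mask = torch.triu(torch.full((T, T), float('-inf')), diagonal=1)
    scores_bias = mask + position_bias  # broadcasts to (1, num_heads, T, T)
    print(scores_bias.shape)  # torch.Size([1, 4, 8, 8])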
@@ -0,0 +1,14 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
@@ -0,0 +1,122 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ # A toy example which has a single-layer transformer block.
+ from typing import Tuple
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ import ai_edge_torch
+ from ai_edge_torch.generative.layers.attention import TransformerBlock
+ import ai_edge_torch.generative.layers.attention_utils as attn_utils
+ import ai_edge_torch.generative.layers.builder as builder
+ import ai_edge_torch.generative.layers.model_config as cfg
+
+ RoPECache = Tuple[torch.Tensor, torch.Tensor]
+ KV_CACHE_MAX_LEN = 100
+
+
+ class ToySingleLayerModel(torch.nn.Module):
+
+   def __init__(self, config: cfg.ModelConfig) -> None:
+     super().__init__()
+     self.lm_head = nn.Linear(
+         config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
+     )
+     self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
+     self.transformer_block = TransformerBlock(config)
+     self.final_norm = builder.build_norm(
+         config.embedding_dim,
+         config.final_norm_config,
+     )
+     self.rope_cache = attn_utils.build_rope_cache(
+         size=config.max_seq_len,
+         dim=int(config.attn_config.rotary_percentage * config.head_dim),
+         base=10_000,
+         condense_ratio=1,
+         dtype=torch.float32,
+         device=torch.device('cpu'),
+     )
+     self.mask_cache = attn_utils.build_causal_mask_cache(
+         size=config.max_seq_len, dtype=torch.float32, device=torch.device('cpu')
+     )
+     self.config = config
+
+   @torch.inference_mode
+   def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
+     x = self.tok_embedding(idx)
+     cos, sin = self.rope_cache
+
+     cos = cos.index_select(0, input_pos)
+     sin = sin.index_select(0, input_pos)
+     mask = self.mask_cache.index_select(2, input_pos)
+     mask = mask[:, :, :, : self.config.max_seq_len]
+
+     x = self.transformer_block(x, (cos, sin), mask, input_pos)
+     x = self.final_norm(x)
+     return self.lm_head(x)
+
+
+ def get_model_config() -> cfg.ModelConfig:
+   attn_config = cfg.AttentionConfig(
+       num_heads=32, num_query_groups=4, rotary_percentage=1.0, enable_kv_cache=False
+   )
+   ff_config = cfg.FeedForwardConfig(
+       type=cfg.FeedForwardType.GATED,
+       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
+       intermediate_size=256,
+   )
+   norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+   config = cfg.ModelConfig(
+       vocab_size=400,
+       num_layers=1,
+       max_seq_len=KV_CACHE_MAX_LEN,
+       embedding_dim=128,
+       attn_config=attn_config,
+       ff_config=ff_config,
+       pre_attention_norm_config=norm_config,
+       pre_ff_norm_config=norm_config,
+       final_norm_config=norm_config,
+   )
+   return config
+
+
+ def define_and_run() -> None:
+   model = ToySingleLayerModel(get_model_config())
+   idx = torch.unsqueeze(torch.arange(0, KV_CACHE_MAX_LEN), 0)
+   input_pos = torch.arange(0, KV_CACHE_MAX_LEN)
+   print('running an inference')
+   print(
+       model.forward(
+           idx,
+           input_pos,
+       )
+   )
+
+   # Convert model to tflite.
+   print('converting model to tflite')
+   edge_model = ai_edge_torch.convert(
+       model,
+       (
+           idx,
+           input_pos,
+       ),
+   )
+   edge_model.export('/tmp/toy_model.tflite')
+
+
+ if __name__ == '__main__':
+   define_and_run()
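As a hedged usage sketch (standard TensorFlow Lite Python API, not part of this package), the exported artifact can be sanity-checked with the TFLite interpreter; the path matches the export call above, and the expected output shape is an assumption from the toy config (batch 1, seq_len 100, vocab_size 400):

    import numpy as np
    import tensorflow as tf

    interpreter = tf.lite.Interpreter(model_path='/tmp/toy_model.tflite')
    interpreter.allocate_tensors()

    # Feed zeros of the right shape/dtype to every input, then read the first output.
    for detail in interpreter.get_input_details():
        interpreter.set_tensor(
            detail['index'], np.zeros(detail['shape'], dtype=detail['dtype'])
        )
    interpreter.invoke()
    logits = interpreter.get_tensor(interpreter.get_output_details()[0]['index'])
    print(logits.shape)  # expected (1, 100, 400)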
@@ -0,0 +1,161 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ # A toy example which has a basic transformer block (w/ externalized KV-Cache).
+
+ from typing import List, Tuple
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch_xla
+
+ import ai_edge_torch
+ import ai_edge_torch.generative.layers.attention_utils as attn_utils
+ import ai_edge_torch.generative.layers.builder as builder
+ from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock # NOQA
+ import ai_edge_torch.generative.layers.model_config as cfg
+
+ RoPECache = Tuple[torch.Tensor, torch.Tensor]
+
+
+ class ToyModelWithExternalKV(torch.nn.Module):
+
+   def __init__(self, config: cfg.ModelConfig) -> None:
+     super().__init__()
+     self.lm_head = nn.Linear(
+         config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
+     )
+     self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
+     self.transformer_blocks = nn.ModuleList(
+         TransformerBlock(config) for _ in range(config.num_layers)
+     )
+     self.final_norm = builder.build_norm(
+         config.embedding_dim,
+         config.final_norm_config,
+     )
+     self.rope_cache = attn_utils.build_rope_cache(
+         size=config.max_seq_len,
+         dim=int(config.attn_config.rotary_percentage * config.head_dim),
+         base=10_000,
+         condense_ratio=1,
+         dtype=torch.float32,
+         device=torch.device('cpu'),
+     )
+     self.mask_cache = attn_utils.build_causal_mask_cache(
+         size=config.max_seq_len, dtype=torch.float32, device=torch.device('cpu')
+     )
+     self.config = config
+
+   def forward(
+       self,
+       idx: torch.Tensor,
+       input_pos: torch.Tensor,
+       k_caches: torch.Tensor,
+       v_caches: torch.Tensor,
+   ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+     x = self.tok_embedding(idx)
+     cos, sin = self.rope_cache
+     cos = cos.index_select(0, input_pos)
+     sin = sin.index_select(0, input_pos)
+     mask = self.mask_cache.index_select(2, input_pos)
+     mask = mask[:, :, :, : self.config.max_seq_len]
+
+     # The caches live outside the model: each block reads its layer's slice and
+     # hands back the updated tensors, which are returned to the caller.
+     for i, block in enumerate(self.transformer_blocks):
+       input_k, input_v = k_caches[i], v_caches[i]
+       x, (updated_k, updated_v) = block(
+           x, (cos, sin), mask, input_pos, (input_k, input_v)
+       )
+       k_caches[i], v_caches[i] = updated_k, updated_v
+
+     x = self.final_norm(x)
+     return self.lm_head(x), k_caches, v_caches
+
+
+ def _export_stablehlo_mlir(model, args):
+   ep = torch.export.export(model, args)
+   stablehlo_gm = torch_xla.stablehlo.exported_program_to_stablehlo(ep)
+   return stablehlo_gm.get_stablehlo_text()
+
+
+ def get_model_config() -> cfg.ModelConfig:
+   attn_config = cfg.AttentionConfig(
+       num_heads=32, num_query_groups=4, rotary_percentage=1.0
+   )
+   ff_config = cfg.FeedForwardConfig(
+       type=cfg.FeedForwardType.GATED,
+       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
+       intermediate_size=256,
+   )
+   norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+   config = cfg.ModelConfig(
+       vocab_size=150,
+       num_layers=2,
+       max_seq_len=100,
+       embedding_dim=128,
+       attn_config=attn_config,
+       ff_config=ff_config,
+       pre_attention_norm_config=norm_config,
+       pre_ff_norm_config=norm_config,
+       final_norm_config=norm_config,
+       enable_hlfb=True,
+   )
+   return config
+
+
+ def get_sample_prefill_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
+   idx = torch.unsqueeze(torch.arange(0, 100), 0)
+   input_pos = torch.arange(0, 100)
+   return idx, input_pos
+
+
+ def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
+   idx = torch.tensor([[1]], dtype=torch.long)
+   input_pos = torch.tensor([10])
+   return idx, input_pos
+
+
+ def define_and_run() -> None:
+   dump_mlir = False
+
+   config = get_model_config()
+   model = ToyModelWithExternalKV(config)
+   print('running an inference')
+   k_caches = torch.zeros((2, 1, 100, 4, 4), dtype=torch.float32)
+   v_caches = torch.zeros((2, 1, 100, 4, 4), dtype=torch.float32)
+
+   idx, input_pos = get_sample_prefill_inputs()
+   decode_idx, decode_input_pos = get_sample_decode_inputs()
+   print(model.forward(idx, input_pos, k_caches, v_caches))
+
+   if dump_mlir:
+     mlir_text = _export_stablehlo_mlir(model, (idx, input_pos, k_caches, v_caches))
+     with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
+       f.write(mlir_text)
+
+   # Convert model to tflite with 2 signatures (prefill + decode).
+   # TODO(b/344014416): currently conversion will fail, because we generate an
+   # int64 index in the dynamic update slice op.
+   print('converting toy model to tflite with 2 signatures (prefill + decode)')
+   edge_model = (
+       ai_edge_torch.signature('prefill', model, (idx, input_pos, k_caches, v_caches))
+       .signature('decode', model, (decode_idx, decode_input_pos, k_caches, v_caches))
+       .convert()
+   )
+   edge_model.export('/tmp/toy_external_kv_cache.tflite')
+
+
+ if __name__ == '__main__':
+   with torch.inference_mode():
+     define_and_run()
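For reference, a minimal sketch of the externalized cache update each block performs, using plain PyTorch and the cache layout from define_and_run above (num_layers 2, batch 1, max_seq_len 100, num_query_groups 4, head_dim 4); this is an illustration under those assumptions, not the internal TransformerBlock code:

    import torch

    max_len, n_groups, head_dim = 100, 4, 4
    k_cache = torch.zeros(1, max_len, n_groups, head_dim)  # one layer's K cache
    input_pos = torch.tensor([10])                         # a decode step writes slot 10
    new_k = torch.randn(1, 1, n_groups, head_dim)          # key for the current token

    # A functional (out-of-place) update keeps the graph export-friendly; it lowers
    # to a dynamic-update-slice, which the TODO above notes currently emits an
    # int64 index.
    k_cache = k_cache.index_copy(1, input_pos, new_k)
    print(bool(k_cache[0, 10].abs().sum() > 0))  # True: the slot was written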
@@ -0,0 +1,143 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ # A toy example which has a basic transformer block (w/ KV-Cache).
+ from typing import List, Tuple
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch_xla
+
+ import ai_edge_torch
+ from ai_edge_torch.generative.layers.attention import TransformerBlock
+ import ai_edge_torch.generative.layers.attention_utils as attn_utils
+ import ai_edge_torch.generative.layers.builder as builder
+ import ai_edge_torch.generative.layers.model_config as cfg
+
+ RoPECache = Tuple[torch.Tensor, torch.Tensor]
+
+
+ class ToyModelWithKV(torch.nn.Module):
+
+   def __init__(self, config: cfg.ModelConfig) -> None:
+     super().__init__()
+     self.lm_head = nn.Linear(
+         config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
+     )
+     self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
+     self.transformer_blocks = nn.ModuleList(
+         TransformerBlock(config) for _ in range(config.num_layers)
+     )
+     self.final_norm = builder.build_norm(
+         config.embedding_dim,
+         config.final_norm_config,
+     )
+     self.rope_cache = attn_utils.build_rope_cache(
+         size=config.max_seq_len,
+         dim=int(config.attn_config.rotary_percentage * config.head_dim),
+         base=10_000,
+         condense_ratio=1,
+         dtype=torch.float32,
+         device=torch.device('cpu'),
+     )
+     self.mask_cache = attn_utils.build_causal_mask_cache(
+         size=config.max_seq_len, dtype=torch.float32, device=torch.device('cpu')
+     )
+     self.config = config
+
+   @torch.inference_mode
+   def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
+     x = self.tok_embedding(idx)
+     cos, sin = self.rope_cache
+     cos = cos.index_select(0, input_pos)
+     sin = sin.index_select(0, input_pos)
+     mask = self.mask_cache.index_select(2, input_pos)
+     mask = mask[:, :, :, : self.config.max_seq_len]
+     for i, block in enumerate(self.transformer_blocks):
+       x = block(x, (cos, sin), mask, input_pos)
+     x = self.final_norm(x)
+     return self.lm_head(x)
+
+
+ def _export_stablehlo_mlir(model, args):
+   ep = torch.export.export(model, args)
+   stablehlo_gm = torch_xla.stablehlo.exported_program_to_stablehlo(ep)
+   return stablehlo_gm.get_stablehlo_text()
+
+
+ def get_model_config() -> cfg.ModelConfig:
+   attn_config = cfg.AttentionConfig(
+       num_heads=32, num_query_groups=4, rotary_percentage=1.0
+   )
+   ff_config = cfg.FeedForwardConfig(
+       type=cfg.FeedForwardType.GATED,
+       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
+       intermediate_size=256,
+   )
+   norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+   config = cfg.ModelConfig(
+       vocab_size=150,
+       num_layers=2,
+       max_seq_len=500,
+       embedding_dim=128,
+       attn_config=attn_config,
+       ff_config=ff_config,
+       pre_attention_norm_config=norm_config,
+       pre_ff_norm_config=norm_config,
+       final_norm_config=norm_config,
+       enable_hlfb=True,
+   )
+   return config
+
+
+ def get_sample_prefill_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
+   idx = torch.unsqueeze(torch.arange(0, 100), 0)
+   input_pos = torch.arange(0, 100)
+   return idx, input_pos
+
+
+ def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
+   idx = torch.tensor([[1]], dtype=torch.long)
+   input_pos = torch.tensor([10], dtype=torch.int64)
+   return idx, input_pos
+
+
+ def define_and_run() -> None:
+   dump_mlir = False
+
+   config = get_model_config()
+   model = ToyModelWithKV(config)
+   print('running an inference')
+   idx, input_pos = get_sample_prefill_inputs()
+   decode_idx, decode_input_pos = get_sample_decode_inputs()
+   print(model.forward(idx, input_pos))
+
+   if dump_mlir:
+     mlir_text = _export_stablehlo_mlir(model, (idx, input_pos))
+     with open('/tmp/toy_model_with_kv.stablehlo.mlir', 'w') as f:
+       f.write(mlir_text)
+
+   # Convert model to tflite with 2 signatures (prefill + decode).
+   print('converting toy model to tflite with 2 signatures (prefill + decode)')
+   edge_model = (
+       ai_edge_torch.signature('prefill', model, (idx, input_pos))
+       .signature('decode', model, (decode_idx, decode_input_pos))
+       .convert()
+   )
+   edge_model.export('/tmp/toy_kv_cache.tflite')
+
+
+ if __name__ == '__main__':
+   define_and_run()
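A hedged follow-up sketch using the standard TFLite Python API: the exported model's two entry points can be discovered and invoked via signature runners. The keys 'prefill' and 'decode' come from the .signature() calls above; input tensor names are read at runtime rather than assumed:

    import tensorflow as tf

    interpreter = tf.lite.Interpreter(model_path='/tmp/toy_kv_cache.tflite')
    # Maps each signature key to its input/output tensor names,
    # e.g. {'prefill': {...}, 'decode': {...}}.
    print(interpreter.get_signature_list())

    prefill = interpreter.get_signature_runner('prefill')
    decode = interpreter.get_signature_runner('decode')
    # Each runner is called with keyword tensors named per get_signature_list(),
    # e.g. outputs = prefill(**{name: array for name, array in zip(names, arrays)}).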