ai-edge-torch-nightly 0.2.0.dev20240714__py3-none-any.whl → 0.3.0.dev20240926__py3-none-any.whl

Files changed (169)
  1. ai_edge_torch/__init__.py +5 -4
  2. ai_edge_torch/_convert/conversion.py +112 -0
  3. ai_edge_torch/_convert/conversion_utils.py +64 -0
  4. ai_edge_torch/{convert → _convert}/converter.py +94 -48
  5. ai_edge_torch/_convert/fx_passes/__init__.py +22 -0
  6. ai_edge_torch/{convert → _convert}/fx_passes/build_aten_composite_pass.py +107 -44
  7. ai_edge_torch/{convert → _convert}/fx_passes/build_interpolate_composite_pass.py +23 -20
  8. ai_edge_torch/{convert → _convert}/fx_passes/inject_mlir_debuginfo_pass.py +5 -6
  9. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/__init__.py +1 -1
  10. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_check.py +39 -9
  11. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_mark.py +2 -0
  12. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +1 -0
  13. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +17 -8
  14. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +9 -8
  15. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +31 -18
  16. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +2 -2
  17. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/pass_body.py +34 -24
  18. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/utils.py +2 -0
  19. ai_edge_torch/_convert/signature.py +66 -0
  20. ai_edge_torch/_convert/test/test_convert.py +495 -0
  21. ai_edge_torch/_convert/test/test_convert_composites.py +234 -0
  22. ai_edge_torch/_convert/test/test_convert_multisig.py +189 -0
  23. ai_edge_torch/{convert → _convert}/test/test_to_channel_last_io.py +5 -5
  24. ai_edge_torch/{convert → _convert}/to_channel_last_io.py +10 -3
  25. ai_edge_torch/config.py +27 -0
  26. ai_edge_torch/conftest.py +20 -0
  27. ai_edge_torch/debug/culprit.py +72 -40
  28. ai_edge_torch/debug/test/test_culprit.py +7 -5
  29. ai_edge_torch/debug/test/test_search_model.py +8 -7
  30. ai_edge_torch/debug/utils.py +14 -3
  31. ai_edge_torch/fx_pass_base.py +101 -0
  32. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +68 -0
  33. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +68 -0
  34. ai_edge_torch/generative/examples/gemma/{gemma.py → gemma1.py} +69 -55
  35. ai_edge_torch/generative/examples/gemma/gemma2.py +267 -0
  36. ai_edge_torch/generative/examples/gemma/verify_gemma1.py +56 -0
  37. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +57 -0
  38. ai_edge_torch/generative/examples/gemma/verify_util.py +143 -0
  39. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +68 -0
  40. ai_edge_torch/generative/examples/openelm/openelm.py +206 -0
  41. ai_edge_torch/generative/examples/openelm/verify.py +64 -0
  42. ai_edge_torch/generative/examples/phi/__init__.py +14 -0
  43. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +68 -0
  44. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +68 -0
  45. ai_edge_torch/generative/examples/{phi2 → phi}/phi2.py +70 -51
  46. ai_edge_torch/generative/examples/phi/phi3.py +286 -0
  47. ai_edge_torch/generative/examples/phi/verify.py +65 -0
  48. ai_edge_torch/generative/examples/phi/verify_phi3.py +70 -0
  49. ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
  50. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +68 -0
  51. ai_edge_torch/generative/examples/smollm/smollm.py +101 -0
  52. ai_edge_torch/generative/examples/smollm/verify.py +62 -0
  53. ai_edge_torch/generative/examples/stable_diffusion/attention.py +3 -1
  54. ai_edge_torch/generative/examples/stable_diffusion/clip.py +83 -13
  55. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +27 -14
  56. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +74 -9
  57. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +179 -37
  58. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +4 -3
  59. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +83 -58
  60. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +4 -3
  61. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +4 -3
  62. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +4 -3
  63. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +1 -0
  64. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +4 -1
  65. ai_edge_torch/generative/examples/stable_diffusion/util.py +9 -3
  66. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +28 -25
  67. ai_edge_torch/generative/examples/t5/t5.py +208 -159
  68. ai_edge_torch/generative/examples/t5/t5_attention.py +45 -30
  69. ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
  70. ai_edge_torch/generative/examples/test_models/toy_model.py +69 -41
  71. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +50 -64
  72. ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
  73. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +41 -39
  74. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +67 -54
  75. ai_edge_torch/generative/examples/tiny_llama/verify.py +64 -0
  76. ai_edge_torch/generative/fx_passes/__init__.py +4 -5
  77. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +10 -7
  78. ai_edge_torch/generative/layers/attention.py +141 -102
  79. ai_edge_torch/generative/layers/attention_utils.py +53 -12
  80. ai_edge_torch/generative/layers/builder.py +37 -7
  81. ai_edge_torch/generative/layers/feed_forward.py +39 -14
  82. ai_edge_torch/generative/layers/kv_cache.py +162 -50
  83. ai_edge_torch/generative/layers/model_config.py +84 -30
  84. ai_edge_torch/generative/layers/normalization.py +185 -7
  85. ai_edge_torch/generative/layers/rotary_position_embedding.py +6 -4
  86. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +48 -21
  87. ai_edge_torch/generative/layers/unet/blocks_2d.py +136 -77
  88. ai_edge_torch/generative/layers/unet/builder.py +7 -4
  89. ai_edge_torch/generative/layers/unet/model_config.py +17 -15
  90. ai_edge_torch/generative/quantize/example.py +7 -8
  91. ai_edge_torch/generative/quantize/quant_recipe.py +10 -7
  92. ai_edge_torch/generative/quantize/quant_recipe_utils.py +12 -1
  93. ai_edge_torch/generative/quantize/quant_recipes.py +8 -0
  94. ai_edge_torch/generative/test/test_kv_cache.py +120 -0
  95. ai_edge_torch/generative/test/{loader_test.py → test_loader.py} +9 -7
  96. ai_edge_torch/generative/test/test_model_conversion.py +124 -188
  97. ai_edge_torch/generative/test/test_model_conversion_large.py +251 -0
  98. ai_edge_torch/generative/test/test_quantize.py +76 -60
  99. ai_edge_torch/generative/test/utils.py +54 -0
  100. ai_edge_torch/generative/utilities/converter.py +82 -0
  101. ai_edge_torch/generative/utilities/loader.py +120 -57
  102. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +165 -57
  103. ai_edge_torch/generative/utilities/t5_loader.py +110 -81
  104. ai_edge_torch/generative/utilities/verifier.py +247 -0
  105. ai_edge_torch/hlfb/__init__.py +1 -1
  106. ai_edge_torch/hlfb/mark_pattern/__init__.py +9 -7
  107. ai_edge_torch/hlfb/mark_pattern/passes.py +23 -3
  108. ai_edge_torch/hlfb/mark_pattern/pattern.py +39 -30
  109. ai_edge_torch/hlfb/test/test_mark_pattern.py +46 -20
  110. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +24 -11
  111. ai_edge_torch/lowertools/__init__.py +18 -0
  112. ai_edge_torch/lowertools/_shim.py +80 -0
  113. ai_edge_torch/lowertools/common_utils.py +142 -0
  114. ai_edge_torch/lowertools/odml_torch_utils.py +255 -0
  115. ai_edge_torch/lowertools/test_utils.py +60 -0
  116. ai_edge_torch/lowertools/torch_xla_utils.py +284 -0
  117. ai_edge_torch/{generative/quantize/ai_edge_quantizer_glue → lowertools}/translate_recipe.py +29 -14
  118. ai_edge_torch/model.py +53 -18
  119. ai_edge_torch/odml_torch/__init__.py +20 -0
  120. ai_edge_torch/odml_torch/_torch_future.py +61 -0
  121. ai_edge_torch/odml_torch/_torch_library.py +19 -0
  122. ai_edge_torch/odml_torch/composite/__init__.py +16 -0
  123. ai_edge_torch/odml_torch/composite/mark_tensor.py +120 -0
  124. ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py +106 -0
  125. ai_edge_torch/odml_torch/debuginfo/__init__.py +16 -0
  126. ai_edge_torch/odml_torch/debuginfo/_build.py +43 -0
  127. ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py +55 -0
  128. ai_edge_torch/odml_torch/export.py +357 -0
  129. ai_edge_torch/odml_torch/export_utils.py +168 -0
  130. ai_edge_torch/odml_torch/jax_bridge/__init__.py +15 -0
  131. ai_edge_torch/odml_torch/jax_bridge/_wrap.py +150 -0
  132. ai_edge_torch/odml_torch/jax_bridge/utils.py +75 -0
  133. ai_edge_torch/odml_torch/lowerings/__init__.py +25 -0
  134. ai_edge_torch/odml_torch/lowerings/_basic.py +258 -0
  135. ai_edge_torch/odml_torch/lowerings/_batch_norm.py +65 -0
  136. ai_edge_torch/odml_torch/lowerings/_convolution.py +241 -0
  137. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +252 -0
  138. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
  139. ai_edge_torch/odml_torch/lowerings/context.py +42 -0
  140. ai_edge_torch/odml_torch/lowerings/registry.py +96 -0
  141. ai_edge_torch/odml_torch/lowerings/utils.py +185 -0
  142. ai_edge_torch/odml_torch/passes/__init__.py +38 -0
  143. ai_edge_torch/odml_torch/tf_integration.py +194 -0
  144. ai_edge_torch/quantize/pt2e_quantizer.py +52 -24
  145. ai_edge_torch/quantize/pt2e_quantizer_utils.py +43 -23
  146. ai_edge_torch/quantize/quant_config.py +13 -9
  147. ai_edge_torch/testing/model_coverage/model_coverage.py +29 -16
  148. ai_edge_torch/version.py +16 -0
  149. {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/METADATA +7 -3
  150. ai_edge_torch_nightly-0.3.0.dev20240926.dist-info/RECORD +177 -0
  151. {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/WHEEL +1 -1
  152. ai_edge_torch/convert/conversion.py +0 -117
  153. ai_edge_torch/convert/conversion_utils.py +0 -400
  154. ai_edge_torch/convert/fx_passes/__init__.py +0 -59
  155. ai_edge_torch/convert/fx_passes/_pass_base.py +0 -49
  156. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +0 -37
  157. ai_edge_torch/convert/test/test_convert.py +0 -311
  158. ai_edge_torch/convert/test/test_convert_composites.py +0 -192
  159. ai_edge_torch/convert/test/test_convert_multisig.py +0 -139
  160. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +0 -66
  161. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -64
  162. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -161
  163. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
  164. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD +0 -121
  165. /ai_edge_torch/{convert → _convert}/__init__.py +0 -0
  166. /ai_edge_torch/{convert → _convert}/test/__init__.py +0 -0
  167. /ai_edge_torch/generative/examples/{phi2 → openelm}/__init__.py +0 -0
  168. {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/LICENSE +0 -0
  169. {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/top_level.txt +0 -0

ai_edge_torch/generative/examples/t5/t5.py
@@ -17,17 +17,15 @@
 import copy
 import os
 from pathlib import Path
-from typing import Optional, Tuple
-
-import numpy as np
-import torch
-import torch.nn as nn
+from typing import Optional

 from ai_edge_torch.generative.examples.t5.t5_attention import EncoderDecoderBlock  # NOQA
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.t5_loader as loading_utils
+import torch
+import torch.nn as nn

 ENCDEC_TENSOR_NAMES = {
     "ff_up_proj": "{prefix}.block.{}.layer.{num}.DenseReluDense.wi",
@@ -36,9 +34,11 @@ ENCDEC_TENSOR_NAMES = {
     "attn_key_proj": "{prefix}.block.{}.layer.0.SelfAttention.k",
     "attn_value_proj": "{prefix}.block.{}.layer.0.SelfAttention.v",
     "attn_output_proj": "{prefix}.block.{}.layer.0.SelfAttention.o",
-    "relative_attn_bias": "{prefix}.block.0.layer.0.SelfAttention.relative_attention_bias",
+    "relative_attn_bias": (
+        "{prefix}.block.0.layer.0.SelfAttention.relative_attention_bias"
+    ),
     "pre_attn_norm": "{prefix}.block.{}.layer.0.layer_norm",
-    "pre_ff_norm": "{prefix}.block.{}.layer.1.layer_norm",
+    "post_attn_norm": "{prefix}.block.{}.layer.1.layer_norm",
     "final_norm": "{prefix}.final_layer_norm",
 }

@@ -52,13 +52,19 @@ class T5Stack(nn.Module):
     self.config = config
     self.embed_tokens = embed_tokens
     self.is_decoder = config.is_decoder
-    self.transformer_blocks = nn.ModuleList(
-        [
-            EncoderDecoderBlock(config, has_relative_attention_bias=bool(i == 0))
-            for i in range(config.num_layers)
-        ]
+    # T5 has only one block config.
+    block_config = config.block_config(0)
+    self.transformer_blocks = nn.ModuleList([
+        EncoderDecoderBlock(
+            block_config,
+            config,
+            has_relative_attention_bias=bool(idx == 0),
+        )
+        for idx in range(config.num_layers)
+    ])
+    self.final_norm = builder.build_norm(
+        config.embedding_dim, config.final_norm_config
     )
-    self.final_norm = builder.build_norm(config.embedding_dim, config.final_norm_config)

   def forward(
       self,
@@ -73,23 +79,23 @@
           torch.Tensor
       ] = None,  # should be for decoder case
   ):
-    input_shape = input_ids.size()
     inputs_embeds = self.embed_tokens(input_ids)
-    batch_size, seq_length = input_shape
     hidden_states = inputs_embeds
     position_bias = None
     encoder_decoder_position_bias = None
-    for i, layer_module in enumerate(self.transformer_blocks):
+    for _, layer_module in enumerate(self.transformer_blocks):
       # EncoderDecoderBlock.forward
-      hidden_states, position_bias, encoder_decoder_position_bias = layer_module(
-          hidden_states,
-          input_pos,
-          mask=attention_mask,
-          relative_position=relative_position,
-          position_bias=position_bias,
-          encoder_hidden_states=encoder_hidden_states,
-          encoder_attention_mask=encoder_attention_mask,
-          encoder_decoder_position_bias=encoder_decoder_position_bias,
+      hidden_states, position_bias, encoder_decoder_position_bias = (
+          layer_module(
+              hidden_states,
+              input_pos,
+              mask=attention_mask,
+              relative_position=relative_position,
+              position_bias=position_bias,
+              encoder_hidden_states=encoder_hidden_states,
+              encoder_attention_mask=encoder_attention_mask,
+              encoder_decoder_position_bias=encoder_decoder_position_bias,
+          )
       )

     hidden_states = self.final_norm(hidden_states)
@@ -109,7 +115,8 @@ class T5(nn.Module):

     encoder_config = copy.deepcopy(config)
     encoder_config.is_decoder = False
-    encoder_config.attn_config.enable_kv_cache = False
+    # T5 has only one block config.
+    encoder_config.block_config(0).attn_config.enable_kv_cache = False
     self.encoder = T5Stack(encoder_config, self.tok_embedding)

     decoder_config = copy.deepcopy(config)
@@ -130,23 +137,27 @@
     )

     self.dec_attn_mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max, dtype=torch.float32, device=torch.device("cpu")
+        size=config.kv_cache_max,
+        dtype=torch.float32,
+        device=torch.device("cpu"),
     )

+    # T5 has only one block config.
+    attn_config = config.block_config(0).attn_config
     self.enc_rel_pos_mask = attn_utils.build_relative_position_buckets(
         bidirectional=True,
         query_length=config.kv_cache_max,
         key_length=config.kv_cache_max,
-        num_buckets=config.attn_config.relative_attention_num_buckets,
-        max_distance=config.attn_config.relative_attention_max_distance,
+        num_buckets=attn_config.relative_attention_num_buckets,
+        max_distance=attn_config.relative_attention_max_distance,
     )

     self.dec_rel_pos_mask = attn_utils.build_relative_position_buckets(
         bidirectional=False,
         query_length=config.kv_cache_max,
         key_length=config.kv_cache_max,
-        num_buckets=config.attn_config.relative_attention_num_buckets,
-        max_distance=config.attn_config.relative_attention_max_distance,
+        num_buckets=attn_config.relative_attention_num_buckets,
+        max_distance=attn_config.relative_attention_max_distance,
     )

   @torch.inference_mode
@@ -159,9 +170,10 @@
       pad_mask: torch.Tensor,
   ) -> torch.Tensor:
     B, T = input_ids.size()
-    assert (
-        self.config.max_seq_len >= T
-    ), f"Cannot forward sequence of length {T}, max seq length is only {self.config.max_seq_len}"
+    assert self.config.max_seq_len >= T, (
+        f"Cannot forward sequence of length {T}, max seq length is only"
+        f" {self.config.max_seq_len}"
+    )

     enc_mask = self.enc_attn_mask_cache.index_select(2, input_pos)
     enc_mask = enc_mask[:, :, :, : self.config.kv_cache_max]
@@ -170,10 +182,18 @@
     dec_mask = self.dec_attn_mask_cache.index_select(2, decoder_input_pos)
     dec_mask = dec_mask[:, :, :, : self.config.kv_cache_max]
     enc_relative_position = self.enc_rel_pos_mask.index_select(2, input_pos)
-    enc_relative_position = enc_relative_position[:, :, :, : self.config.kv_cache_max]
-    dec_relative_position = self.enc_rel_pos_mask.index_select(2, decoder_input_pos)
-    dec_relative_position = dec_relative_position[:, :, :, : self.config.kv_cache_max]
-    enc_attention_mask = self.enc_attn_mask_cache.index_select(2, decoder_input_pos)
+    enc_relative_position = enc_relative_position[
+        :, :, :, : self.config.kv_cache_max
+    ]
+    dec_relative_position = self.enc_rel_pos_mask.index_select(
+        2, decoder_input_pos
+    )
+    dec_relative_position = dec_relative_position[
+        :, :, :, : self.config.kv_cache_max
+    ]
+    enc_attention_mask = self.enc_attn_mask_cache.index_select(
+        2, decoder_input_pos
+    )
     # Mask off any "pad" tokens that shouldn't contribute to cross attention
     enc_attention_mask[:, :, :, :] += pad_mask

@@ -210,12 +230,15 @@ class T5Encoder(nn.Module):

     self.config = config
     # Construct model layers.
-    assert embedding_layer != None, "Passed in embedding layer should not be None!"
+    assert (
+        embedding_layer != None
+    ), "Passed in embedding layer should not be None!"
     self.tok_embedding = embedding_layer

     encoder_config = copy.deepcopy(config)
     encoder_config.is_decoder = False
-    encoder_config.attn_config.enable_kv_cache = False
+    # T5 has only one block config.
+    encoder_config.block_config(0).attn_config.enable_kv_cache = False
     self.encoder = T5Stack(encoder_config, self.tok_embedding)

     self.enc_attn_mask_cache = (
@@ -228,12 +251,14 @@
         .unsqueeze(0)
     )

+    # T5 has only one block config.
+    attn_config = config.block_config(0).attn_config
     self.enc_rel_pos_mask = attn_utils.build_relative_position_buckets(
         bidirectional=True,
         query_length=config.kv_cache_max,
         key_length=config.kv_cache_max,
-        num_buckets=config.attn_config.relative_attention_num_buckets,
-        max_distance=config.attn_config.relative_attention_max_distance,
+        num_buckets=attn_config.relative_attention_num_buckets,
+        max_distance=attn_config.relative_attention_max_distance,
     )

   @torch.inference_mode
@@ -244,16 +269,19 @@
       pad_mask: torch.Tensor,
   ) -> torch.Tensor:
     B, T = input_ids.size()
-    assert (
-        self.config.max_seq_len >= T
-    ), f"Cannot forward sequence of length {T}, max seq length is only {self.config.max_seq_len}"
+    assert self.config.max_seq_len >= T, (
+        f"Cannot forward sequence of length {T}, max seq length is only"
+        f" {self.config.max_seq_len}"
+    )

     enc_mask = self.enc_attn_mask_cache.index_select(2, input_pos)
     enc_mask = enc_mask[:, :, :, : self.config.kv_cache_max]
     # Mask off any "pad" tokens that shouldn't contribute to self-attention
     enc_mask[:, :, :, :] += pad_mask
     enc_relative_position = self.enc_rel_pos_mask.index_select(2, input_pos)
-    enc_relative_position = enc_relative_position[:, :, :, : self.config.kv_cache_max]
+    enc_relative_position = enc_relative_position[
+        :, :, :, : self.config.kv_cache_max
+    ]

     # Convert encoder inputs in embeddings if needed
     encoder_hidden_states = self.encoder(
@@ -273,7 +301,9 @@ class T5Decoder(nn.Module):

     self.config = config
     # Construct model layers.
-    assert embedding_layer != None, "Passed in embedding layer should not be None!"
+    assert (
+        embedding_layer != None
+    ), "Passed in embedding layer should not be None!"
     self.tok_embedding = embedding_layer

     decoder_config = copy.deepcopy(config)
@@ -293,16 +323,20 @@
         .unsqueeze(0)
     )

+    # T5 has only one block config.
+    attn_config = config.block_config(0).attn_config
     self.enc_rel_pos_mask = attn_utils.build_relative_position_buckets(
         bidirectional=True,
         query_length=config.kv_cache_max,
         key_length=config.kv_cache_max,
-        num_buckets=config.attn_config.relative_attention_num_buckets,
-        max_distance=config.attn_config.relative_attention_max_distance,
+        num_buckets=attn_config.relative_attention_num_buckets,
+        max_distance=attn_config.relative_attention_max_distance,
     )

     self.dec_attn_mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max, dtype=torch.float32, device=torch.device("cpu")
+        size=config.kv_cache_max,
+        dtype=torch.float32,
+        device=torch.device("cpu"),
     )

   @torch.inference_mode
@@ -315,9 +349,15 @@
   ) -> torch.Tensor:
     dec_mask = self.dec_attn_mask_cache.index_select(2, decoder_input_pos)
     dec_mask = dec_mask[:, :, :, : self.config.kv_cache_max]
-    dec_relative_position = self.enc_rel_pos_mask.index_select(2, decoder_input_pos)
-    dec_relative_position = dec_relative_position[:, :, :, : self.config.kv_cache_max]
-    enc_attention_mask = self.enc_attn_mask_cache.index_select(2, decoder_input_pos)
+    dec_relative_position = self.enc_rel_pos_mask.index_select(
+        2, decoder_input_pos
+    )
+    dec_relative_position = dec_relative_position[
+        :, :, :, : self.config.kv_cache_max
+    ]
+    enc_attention_mask = self.enc_attn_mask_cache.index_select(
+        2, decoder_input_pos
+    )
     # Mask off any "pad" tokens that shouldn't contribute to cross attention
     enc_attention_mask[:, :, :, :] += pad_mask

@@ -342,6 +382,7 @@
 def get_model_config_t5() -> cfg.ModelConfig:
   attn_config = cfg.AttentionConfig(
       num_heads=12,
+      head_dim=64,
       num_query_groups=12,
       qkv_use_bias=False,
       relative_attention_num_buckets=32,
@@ -357,19 +398,20 @@ def get_model_config_t5() -> cfg.ModelConfig:
       type=cfg.NormalizationType.RMS_NORM,
       epsilon=1e-6,
   )
-
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      relative_attention=True,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
   config = cfg.ModelConfig(
       vocab_size=32128,
       num_layers=12,
       max_seq_len=512,
       embedding_dim=768,
-      attn_config=attn_config,
-      relative_attention=True,
-      ff_config=ff_config,
-      pre_attention_norm_config=norm_config,
-      pre_ff_norm_config=norm_config,
+      block_configs=block_config,
       final_norm_config=norm_config,
-      parallel_residual=False,
       lm_head_use_bias=False,
       enable_hlfb=True,
   )
@@ -390,7 +432,7 @@ def build_t5_model(checkpoint_path: str) -> nn.Module:
       "cross_attn_value_proj": "{prefix}.block.{}.layer.1.EncDecAttention.v",
       "cross_attn_output_proj": "{prefix}.block.{}.layer.1.EncDecAttention.o",
       # In the decoder, the FF is layer 2 in the Transformer block
-      "pre_ff_norm": "{prefix}.block.{}.layer.2.layer_norm",
+      "post_attn_norm": "{prefix}.block.{}.layer.2.layer_norm",
       # In the decoder, the cross attention is layer 1 in the Transformer block
      "pre_cross_attn_norm": "{prefix}.block.{}.layer.1.layer_norm",
   }
@@ -446,7 +488,7 @@ def build_t5_decoder_model(
       "cross_attn_value_proj": "{prefix}.block.{}.layer.1.EncDecAttention.v",
       "cross_attn_output_proj": "{prefix}.block.{}.layer.1.EncDecAttention.o",
       # In the decoder, the FF is layer 2 in the Transformer block
-      "pre_ff_norm": "{prefix}.block.{}.layer.2.layer_norm",
+      "post_attn_norm": "{prefix}.block.{}.layer.2.layer_norm",
       # In the decoder, the cross attention is layer 1 in the Transformer block
      "pre_cross_attn_norm": "{prefix}.block.{}.layer.1.layer_norm",
   }
@@ -470,104 +512,101 @@ def build_t5_decoder_model(


 def get_sample_encoder_input_ids() -> torch.Tensor:
-  idx = torch.tensor(
-      [
-          [
-              3856,
-              27111,
-              10,
-              4425,
-              51,
-              4008,
-              31,
-              7,
-              2306,
-              16576,
-              47,
-              4381,
-              16,
-              8,
-              3414,
-              13,
-              1410,
-              16,
-              932,
-              11,
-              1515,
-              2766,
-              6,
-              11,
-              4838,
-              16,
-              23964,
-              16,
-              1797,
-              13,
-              24,
-              215,
-              5,
-              94,
-              47,
-              2017,
-              168,
-              1204,
-              57,
-              6800,
-              7,
-              11,
-              9443,
-              38,
-              3673,
-              8,
-              4016,
-              13,
-              66,
-              70,
-              14234,
-              5,
-              2449,
-              1215,
-              83,
-              17,
-              16,
-              8782,
-              70,
-              723,
-              30,
-              8,
-              6162,
-              13,
-              1410,
-              12,
-              48,
-              833,
-              250,
-              13,
-              149,
-              231,
-              79,
-              1858,
-              16576,
-              5,
-              1,
-          ]
-      ]
-  )
+  idx = torch.tensor([[
+      3856,
+      27111,
+      10,
+      4425,
+      51,
+      4008,
+      31,
+      7,
+      2306,
+      16576,
+      47,
+      4381,
+      16,
+      8,
+      3414,
+      13,
+      1410,
+      16,
+      932,
+      11,
+      1515,
+      2766,
+      6,
+      11,
+      4838,
+      16,
+      23964,
+      16,
+      1797,
+      13,
+      24,
+      215,
+      5,
+      94,
+      47,
+      2017,
+      168,
+      1204,
+      57,
+      6800,
+      7,
+      11,
+      9443,
+      38,
+      3673,
+      8,
+      4016,
+      13,
+      66,
+      70,
+      14234,
+      5,
+      2449,
+      1215,
+      83,
+      17,
+      16,
+      8782,
+      70,
+      723,
+      30,
+      8,
+      6162,
+      13,
+      1410,
+      12,
+      48,
+      833,
+      250,
+      13,
+      149,
+      231,
+      79,
+      1858,
+      16576,
+      5,
+      1,
+  ]])
   return idx


 def define_and_run_t5(checkpoint_path: str) -> None:
-  t5_goldens = torch.load("t5_lm_logits.pt")
+  current_dir = Path(__file__).parent.resolve()
+  t5_goldens = torch.load(current_dir / "t5_lm_logits.pt")

   model = build_t5_model(checkpoint_path)

   idx = get_sample_encoder_input_ids()
-  tokens = torch.full((1, 512), 0, dtype=torch.long, device="cpu")
+  tokens = torch.full((1, 512), 0, dtype=torch.int, device="cpu")
   tokens[0, :77] = idx
-  input_pos = torch.arange(0, 512)
+  input_pos = torch.arange(0, 512, dtype=torch.int)

-  decode_d_token = torch.tensor([[0]], dtype=torch.int64)
-  decode_d_input_pos = torch.tensor([0], dtype=torch.int64)
+  decode_d_token = torch.tensor([[0]], dtype=torch.int)
+  decode_d_input_pos = torch.tensor([0], dtype=torch.int)
   pad_mask = torch.zeros([model.config.kv_cache_max], dtype=torch.float32)
   pad_mask[77:] = float("-inf")
   lm_logits = model.forward(
@@ -579,20 +618,30 @@ def define_and_run_t5(checkpoint_path: str) -> None:

 # TODO(haoliang): Move those tests.
 def define_and_run_t5_split(checkpoint_path: str) -> None:
-  t5_goldens = torch.load("t5_lm_logits.pt")
+  current_dir = Path(__file__).parent.resolve()
+  t5_goldens = torch.load(current_dir / "t5_lm_logits.pt")
+
   config = get_model_config_t5()
-  embedding_layer = nn.Embedding(config.vocab_size, config.embedding_dim, padding_idx=0)
-  t5_encoder_model = build_t5_encoder_model(config, embedding_layer, checkpoint_path)
-  t5_decoder_model = build_t5_decoder_model(config, embedding_layer, checkpoint_path)
+  embedding_layer = nn.Embedding(
+      config.vocab_size, config.embedding_dim, padding_idx=0
+  )
+  t5_encoder_model = build_t5_encoder_model(
+      config, embedding_layer, checkpoint_path
+  )
+  t5_decoder_model = build_t5_decoder_model(
+      config, embedding_layer, checkpoint_path
+  )
   idx = get_sample_encoder_input_ids()

-  tokens = torch.full((1, 512), 0, dtype=torch.long, device="cpu")
+  tokens = torch.full((1, 512), 0, dtype=torch.int, device="cpu")
   tokens[0, :77] = idx
-  input_pos = torch.arange(0, 512)
+  input_pos = torch.arange(0, 512, dtype=torch.int)

-  decode_d_token = torch.tensor([[0]], dtype=torch.int64)
-  decode_d_input_pos = torch.tensor([0], dtype=torch.int64)
-  pad_mask = torch.zeros([t5_encoder_model.config.kv_cache_max], dtype=torch.float32)
+  decode_d_token = torch.tensor([[0]], dtype=torch.int)
+  decode_d_input_pos = torch.tensor([0], dtype=torch.int)
+  pad_mask = torch.zeros(
+      [t5_encoder_model.config.kv_cache_max], dtype=torch.float32
+  )
   pad_mask[77:] = float("-inf")
   hidden_states = t5_encoder_model.forward(tokens, input_pos, pad_mask)
   lm_logits = t5_decoder_model.forward(