ai-edge-torch-nightly 0.2.0.dev20240806__py3-none-any.whl → 0.3.0.dev20240809__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic.

Files changed (104)
  1. ai_edge_torch/__init__.py +5 -5
  2. ai_edge_torch/{convert → _convert}/conversion.py +40 -50
  3. ai_edge_torch/_convert/conversion_utils.py +64 -0
  4. ai_edge_torch/{convert → _convert}/converter.py +83 -43
  5. ai_edge_torch/{convert → _convert}/fx_passes/__init__.py +9 -9
  6. ai_edge_torch/{convert → _convert}/fx_passes/build_aten_composite_pass.py +51 -26
  7. ai_edge_torch/{convert → _convert}/fx_passes/build_interpolate_composite_pass.py +11 -8
  8. ai_edge_torch/{convert → _convert}/fx_passes/canonicalize_pass.py +3 -4
  9. ai_edge_torch/{convert → _convert}/fx_passes/inject_mlir_debuginfo_pass.py +2 -2
  10. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  11. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_check.py +7 -5
  12. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_mark.py +2 -0
  13. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +1 -0
  14. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +14 -6
  15. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +5 -6
  16. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +17 -14
  17. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +3 -2
  18. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/pass_body.py +15 -17
  19. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/utils.py +2 -0
  20. ai_edge_torch/_convert/signature.py +100 -0
  21. ai_edge_torch/{convert → _convert}/test/test_convert.py +50 -52
  22. ai_edge_torch/{convert → _convert}/test/test_convert_composites.py +16 -12
  23. ai_edge_torch/{convert → _convert}/test/test_convert_multisig.py +6 -4
  24. ai_edge_torch/{convert → _convert}/test/test_to_channel_last_io.py +5 -4
  25. ai_edge_torch/{convert → _convert}/to_channel_last_io.py +4 -1
  26. ai_edge_torch/config.py +24 -0
  27. ai_edge_torch/conftest.py +20 -0
  28. ai_edge_torch/debug/culprit.py +22 -22
  29. ai_edge_torch/debug/test/test_culprit.py +4 -3
  30. ai_edge_torch/debug/test/test_search_model.py +5 -5
  31. ai_edge_torch/debug/utils.py +11 -2
  32. ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +3 -3
  33. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +4 -1
  34. ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +5 -5
  35. ai_edge_torch/generative/examples/experimental/phi/phi2.py +4 -1
  36. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +4 -5
  37. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +4 -1
  38. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +5 -5
  39. ai_edge_torch/generative/examples/gemma/gemma.py +4 -1
  40. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +5 -5
  41. ai_edge_torch/generative/examples/phi2/phi2.py +4 -1
  42. ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -0
  43. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +3 -2
  44. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +57 -20
  45. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +20 -9
  46. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +1 -0
  47. ai_edge_torch/generative/examples/t5/t5.py +2 -2
  48. ai_edge_torch/generative/examples/t5/t5_attention.py +15 -13
  49. ai_edge_torch/generative/examples/test_models/toy_model.py +4 -1
  50. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +6 -5
  51. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +7 -7
  52. ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
  53. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +5 -5
  54. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +4 -1
  55. ai_edge_torch/generative/fx_passes/__init__.py +2 -2
  56. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +4 -3
  57. ai_edge_torch/generative/layers/attention.py +35 -26
  58. ai_edge_torch/generative/layers/attention_utils.py +23 -12
  59. ai_edge_torch/generative/layers/builder.py +0 -1
  60. ai_edge_torch/generative/layers/feed_forward.py +6 -10
  61. ai_edge_torch/generative/layers/kv_cache.py +0 -1
  62. ai_edge_torch/generative/layers/model_config.py +2 -5
  63. ai_edge_torch/generative/layers/normalization.py +5 -7
  64. ai_edge_torch/generative/layers/rotary_position_embedding.py +3 -3
  65. ai_edge_torch/generative/layers/unet/blocks_2d.py +33 -26
  66. ai_edge_torch/generative/layers/unet/model_config.py +14 -15
  67. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +14 -0
  68. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +0 -2
  69. ai_edge_torch/generative/quantize/quant_recipe.py +8 -6
  70. ai_edge_torch/generative/quantize/quant_recipe_utils.py +2 -1
  71. ai_edge_torch/generative/test/test_experimental_ekv.py +6 -7
  72. ai_edge_torch/generative/test/{loader_test.py → test_loader.py} +4 -3
  73. ai_edge_torch/generative/test/test_model_conversion.py +24 -25
  74. ai_edge_torch/generative/test/test_quantize.py +10 -5
  75. ai_edge_torch/generative/utilities/loader.py +12 -12
  76. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +69 -24
  77. ai_edge_torch/generative/utilities/t5_loader.py +12 -13
  78. ai_edge_torch/hlfb/__init__.py +1 -1
  79. ai_edge_torch/hlfb/mark_pattern/__init__.py +9 -6
  80. ai_edge_torch/hlfb/mark_pattern/passes.py +23 -3
  81. ai_edge_torch/hlfb/mark_pattern/pattern.py +23 -23
  82. ai_edge_torch/hlfb/test/test_mark_pattern.py +13 -12
  83. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +8 -6
  84. ai_edge_torch/{convert/fx_passes/optimize_layout_transposes_pass → lowertools}/__init__.py +1 -1
  85. ai_edge_torch/lowertools/_shim.py +80 -0
  86. ai_edge_torch/lowertools/common_utils.py +89 -0
  87. ai_edge_torch/lowertools/odml_torch_utils.py +211 -0
  88. ai_edge_torch/lowertools/torch_xla_utils.py +273 -0
  89. ai_edge_torch/model.py +14 -9
  90. ai_edge_torch/quantize/pt2e_quantizer.py +22 -9
  91. ai_edge_torch/quantize/pt2e_quantizer_utils.py +13 -12
  92. ai_edge_torch/quantize/quant_config.py +7 -7
  93. ai_edge_torch/testing/model_coverage/model_coverage.py +19 -10
  94. ai_edge_torch/version.py +1 -1
  95. {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.3.0.dev20240809.dist-info}/METADATA +1 -1
  96. ai_edge_torch_nightly-0.3.0.dev20240809.dist-info/RECORD +141 -0
  97. ai_edge_torch/convert/conversion_utils.py +0 -439
  98. ai_edge_torch_nightly-0.2.0.dev20240806.dist-info/RECORD +0 -133
  99. /ai_edge_torch/{convert → _convert}/__init__.py +0 -0
  100. /ai_edge_torch/{convert → _convert}/fx_passes/_pass_base.py +0 -0
  101. /ai_edge_torch/{convert → _convert}/test/__init__.py +0 -0
  102. {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.3.0.dev20240809.dist-info}/LICENSE +0 -0
  103. {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.3.0.dev20240809.dist-info}/WHEEL +0 -0
  104. {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.3.0.dev20240809.dist-info}/top_level.txt +0 -0
@@ -28,17 +28,17 @@ def convert_tiny_llama_to_tflite(
     kv_cache_max_len: int = 1024,
     quantize: bool = True,
 ):
-  """An example method for converting TinyLlama model to multi-signature
-  tflite model.
+  """Converts TinyLlama model to multi-signature tflite model.
 
   Args:
-    checkpoint_path (str): The filepath to the model checkpoint, or directory holding the checkpoint.
+    checkpoint_path (str): The filepath to the model checkpoint, or directory
+      holding the checkpoint.
     prefill_seq_len (int, optional): The maximum size of prefill input tensor.
       Defaults to 512.
     kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
      including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quanized.
-      Defaults to True.
+    quantize (bool, optional): Whether the model should be quanized. Defaults
+      to True.
   """
   pytorch_model = tiny_llama.build_model(
       checkpoint_path, kv_cache_max_len=kv_cache_max_len
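
For reference, a minimal usage sketch implied by the signature and docstring above; the import path and the checkpoint location are illustrative assumptions, not part of this diff:

  # Hypothetical caller of the converted example script.
  from ai_edge_torch.generative.examples.tiny_llama import convert_to_tflite

  convert_to_tflite.convert_tiny_llama_to_tflite(
      "/path/to/tiny_llama_checkpoint",  # placeholder checkpoint path
      prefill_seq_len=512,
      kv_cache_max_len=1024,
      quantize=True,
  )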
@@ -64,7 +64,9 @@ class TinyLLamma(nn.Module):
     )
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
-        dim=int(config.attn_config.rotary_percentage * config.head_dim),
+        dim=int(
+            config.attn_config.rotary_percentage * config.attn_config.head_dim
+        ),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -109,6 +111,7 @@ class TinyLLamma(nn.Module):
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   attn_config = cfg.AttentionConfig(
       num_heads=32,
+      head_dim=64,
       num_query_groups=4,
       rotary_percentage=1.0,
   )
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from ai_edge_torch.convert.fx_passes import CanonicalizePass
-from ai_edge_torch.convert.fx_passes import run_passes
+from ai_edge_torch._convert.fx_passes import CanonicalizePass
+from ai_edge_torch._convert.fx_passes import run_passes
 from ai_edge_torch.generative.fx_passes.remove_sdpa_zero_mask_pass import RemoveSDPACompositeZeroMaskPass  # NOQA
 import torch
 
@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassBase
-from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassResult  # NOQA
+from ai_edge_torch import lowertools
+from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassBase
+from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassResult
 import torch
 
 
@@ -27,7 +28,7 @@ class RemoveSDPACompositeZeroMaskPass(ExportedProgramPassBase):
     for node in graph.nodes:
       if not (
           node.op == "call_function"
-          and node.target == torch.ops.xla.mark_tensor.default
+          and node.target == lowertools.mark_tensor_op
       ):
         continue
 
@@ -24,7 +24,6 @@ from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_
 from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention_with_hlfb  # NOQA
 import torch
 from torch import nn
-import torch.nn.functional as F
 
 
 def _embed_rope(
@@ -60,8 +59,8 @@ class TransformerBlock(nn.Module):
     """Initialize an instance of the TransformerBlock.
 
     Args:
-      config (cfg.ModelConfig): the configuration object
-        for this transformer block.
+      config (cfg.ModelConfig): the configuration object for this transformer
+        block.
     """
 
     super().__init__()
@@ -131,20 +130,23 @@ class CausalSelfAttention(nn.Module):
       batch_size (int): batch size of the input tensor.
       dim (int): causal attention's input/output dimmension.
       config (cfg.AttentionConfig): attention specific configurations.
-      kv_cache_max (int): determines the size of the KV Cache buffer, if enabled.
+      kv_cache_max (int): determines the size of the KV Cache buffer, if
+        enabled.
       enable_hlfb (bool): whether hlfb is enabled or not.
     """
     super().__init__()
-    self.head_dim = dim // config.num_heads
-    shape = (config.num_heads + 2 * config.num_query_groups) * self.head_dim
-    # Key, query, value projections for all heads.
-    self.qkv_projection = nn.Linear(dim, shape, bias=config.qkv_use_bias)
-    self.output_projection = nn.Linear(
-        dim, dim, bias=config.output_proj_use_bias
-    )
     self.config = config
     self.kv_cache = None
     self.batch_size = batch_size
+    qkv_shape = (
+        config.num_heads + 2 * config.num_query_groups
+    ) * config.head_dim
+    output_shape = config.num_heads * config.head_dim
+    # Key, query, value projections for all heads.
+    self.qkv_projection = nn.Linear(dim, qkv_shape, bias=config.qkv_use_bias)
+    self.output_projection = nn.Linear(
+        output_shape, dim, bias=config.output_proj_use_bias
+    )
 
     # Build a k/v cache with size (batch_size, kv_cache_max, n_heads, head_dim).
     if config.enable_kv_cache:
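
To make the new projection sizes concrete, here is a small worked example using the TinyLlama attention values that appear elsewhere in this diff (num_heads=32, num_query_groups=4, head_dim=64); the numbers are illustrative, not a claim about every model in the package:

  num_heads, num_query_groups, head_dim = 32, 4, 64
  qkv_shape = (num_heads + 2 * num_query_groups) * head_dim  # (32 + 8) * 64 = 2560
  output_shape = num_heads * head_dim                        # 32 * 64 = 2048
  # qkv_projection maps the input dim to 2560 (all query, key and value heads);
  # output_projection now maps 2048 (query heads only) back to the input dim.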
@@ -152,7 +154,7 @@ class CausalSelfAttention(nn.Module):
           batch_size,
           kv_cache_max,
           config.num_query_groups,
-          self.head_dim,
+          config.head_dim,
           enable_hlfb,
       )
 
@@ -169,6 +171,7 @@ class CausalSelfAttention(nn.Module):
       input_pos: Optional[torch.Tensor] = None,
   ) -> torch.Tensor:
     """Forward function of the CausalSelfAttention layer, which can support
+
     MQA, GQA and MHA.
 
     Args:
@@ -193,7 +196,7 @@ class CausalSelfAttention(nn.Module):
     q_per_kv = self.config.num_heads // self.config.num_query_groups
     # Each group has >=1 queries, 1 key, and 1 value.
     if self.config.qkv_transpose_before_split:
-      qkv = qkv.view(B, T, -1, self.head_dim)
+      qkv = qkv.view(B, T, -1, self.config.head_dim)
       q, k, v = qkv.split(
           (
               q_per_kv * self.config.num_query_groups,
@@ -205,22 +208,27 @@ class CausalSelfAttention(nn.Module):
     else:
       qkv = qkv.view(B, T, self.config.num_query_groups, -1)
       q, k, v = qkv.split(
-          (q_per_kv * self.head_dim, self.head_dim, self.head_dim), dim=-1
+          (
+              q_per_kv * self.config.head_dim,
+              self.config.head_dim,
+              self.config.head_dim,
+          ),
+          dim=-1,
       )
 
-      q = q.reshape(B, T, -1, self.head_dim)
-      k = k.reshape(B, T, -1, self.head_dim)
-      v = v.reshape(B, T, -1, self.head_dim)
+      q = q.reshape(B, T, -1, self.config.head_dim)
+      k = k.reshape(B, T, -1, self.config.head_dim)
+      v = v.reshape(B, T, -1, self.config.head_dim)
 
     # Compute rotary positional embedding for query and key.
-    n_elem = int(self.config.rotary_percentage * self.head_dim)
+    n_elem = int(self.config.rotary_percentage * self.config.head_dim)
     q, k = _embed_rope(q, k, n_elem, rope)
 
     if self.kv_cache is not None:
       # TODO(haoliang): Handle when execeeding max sequence length.
       k, v = self.kv_cache.update_cache(input_pos, k, v)
 
-    y = self.sdpa_func(q, k, v, self.head_dim, mask=mask)
+    y = self.sdpa_func(q, k, v, self.config.head_dim, mask=mask)
     y = y.reshape(B, T, E)
 
     # Compute the output projection.
@@ -274,12 +282,12 @@ class CrossAttention(nn.Module):
       query_dim (int): query tensor's dimension.
       cross_dim (int): cross attention's dimensions, for key and value tensors.
       config (cfg.AttentionConfig): attention specific configurations.
-      kv_cache_max (int): determines the size of the KV Cache buffer, if enabled.
+      kv_cache_max (int): determines the size of the KV Cache buffer, if
+        enabled.
       enable_hlfb (bool): whether hlfb is enabled or not.
     """
     super().__init__()
     self.config = config
-    self.head_dim = query_dim // config.num_heads
     self.n_heads = config.num_heads
     self.q_projection = nn.Linear(
         query_dim, query_dim, bias=config.qkv_use_bias
@@ -301,7 +309,7 @@ class CrossAttention(nn.Module):
           batch_size,
           kv_cache_max,
           config.num_query_groups,
-          self.head_dim,
+          self.config.head_dim,
           enable_hlfb,
       )
 
@@ -324,7 +332,8 @@ class CrossAttention(nn.Module):
       x (torch.Tensor): the target tensor, with shape [B, target_seq_len, ...].
       y (torch.Tensor): the source tensor, with shape [B, source_seq_len, ...].
       rope (Tuple[torch.Tensor, torch.Tensor]): the optional input rope tensor.
-      mask (torch.Tensor): the optional mask tensor can be broadcaseted to shape [B, n_heads, target_seq_len, source_seq_len].
+      mask (torch.Tensor): the optional mask tensor can be broadcaseted to shape
+        [B, n_heads, target_seq_len, source_seq_len].
       input_pos (torch.Tensor): the optional input position tensor.
 
     Returns:
@@ -338,13 +347,13 @@ class CrossAttention(nn.Module):
     k = self.k_projection(y)
     v = self.v_projection(y)
 
-    interim_shape = (batch_size, -1, self.n_heads, self.head_dim)
+    interim_shape = (batch_size, -1, self.n_heads, self.config.head_dim)
     q = q.view(interim_shape)
     k = k.view(interim_shape)
     v = v.view(interim_shape)
 
     # Compute rotary positional embedding for query and key.
-    n_elem = int(self.config.rotary_percentage * self.head_dim)
+    n_elem = int(self.config.rotary_percentage * self.config.head_dim)
     q, k = _embed_rope(q, k, n_elem, rope)
 
     if self.kv_cache is not None:
@@ -354,7 +363,7 @@ class CrossAttention(nn.Module):
       mask = torch.zeros(
           (batch_size, 1, target_seq_len, source_seq_len), dtype=torch.float32
       )
-    y = self.sdpa_func(q, k, v, self.head_dim, mask=mask)
+    y = self.sdpa_func(q, k, v, self.config.head_dim, mask=mask)
     y = y.reshape(batch_size, target_seq_len, -1)
 
     # Compute the output projection.
@@ -28,7 +28,9 @@ def build_rope_cache(
     dtype: torch.dtype = torch.float32,
     device: torch.device = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Precompute Rotary Positional Embedding Sin and Cos values for quick lookups
+  """Precomputes Rotary Positional Embeddings.
+
+  Precompute Rotary Positional Embedding Sin and Cos values for quick lookup
   during the inference.
 
   Args:
@@ -84,16 +86,22 @@ def relative_position_bucket(
     num_buckets: int,
     max_distance: int,
 ) -> torch.Tensor:
-  """
-  Adapted from Mesh Tensorflow:
+  """Adapted from Mesh Tensorflow:
+
   https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
 
-  Translate relative position to a bucket number for relative attention. The relative position is defined as
-  memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
-  position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
-  small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
-  positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
-  This should allow for more graceful generalization to longer sequences than the model has been trained on
+  Translate relative position to a bucket number for relative attention. The
+  relative position is defined as
+  memory_position - query_position, i.e. the distance in tokens from the
+  attending position to the attended-to
+  position. If bidirectional=False, then positive relative positions are
+  invalid. We use smaller buckets for
+  small absolute relative_position and larger buckets for larger absolute
+  relative_positions. All relative
+  positions >=max_distance map to the same bucket. All relative positions
+  <=-max_distance map to the same bucket.
+  This should allow for more graceful generalization to longer sequences than
+  the model has been trained on
 
   Args:
     relative_position: an int32 Tensor
@@ -102,7 +110,8 @@ def relative_position_bucket(
     max_distance: an integer for max distance.
 
   Returns:
-    a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
+    a Tensor with the same shape as relative_position, containing int32 values
+    in the range [0, num_buckets)
   """
   relative_buckets = 0
   if bidirectional:
@@ -119,7 +128,8 @@ def relative_position_bucket(
   max_exact = num_buckets // 2
   is_small = relative_position < max_exact
 
-  # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+  # The other half of the buckets are for logarithmically bigger bins in
+  # positions up to max_distance
   relative_position_if_large = max_exact + (
       torch.log(relative_position.float() / max_exact)
       / math.log(max_distance / max_exact)
@@ -148,7 +158,8 @@ def build_relative_position_buckets(
   Args:
     query_length: an integer of length of current query tensor.
     key_length: an integer of length of current key tensor.
-    bidirectional: a boolean - whether the attention is bidirectional, default is True.
+    bidirectional: a boolean - whether the attention is bidirectional, default
+      is True.
     num_buckets: an integer for number of buckets, default is 32.
    max_distance: an integer for max distance, default is 128.
 
@@ -26,7 +26,6 @@ class GeGLU(nn.Module):
 
   GeGLU(x) = (xW+b) * GELU(xV+c)
   See: https://arxiv.org/abs/2002.05202v1
-
   """
 
   def __init__(self, d_in: int, d_out: int):
@@ -33,11 +33,9 @@ class SequentialFeedForward(nn.Module):
   ):
     """Init function for feedforward layer.
 
-    Args:
-      dim(int): embedding size.
-      hidden_dim(int): hidden dim size of the feedforward layer.
-      activation(Callable): activation function used in this block.
-      use_bias(Boolean): whether to use bias. Default is false.
+    Args: dim(int): embedding size. hidden_dim(int): hidden dim size of the
+    feedforward layer. activation(Callable): activation function used in this
+    block. use_bias(Boolean): whether to use bias. Default is false.
     """
     super().__init__()
     self.act = activation
@@ -71,11 +69,9 @@ class GatedFeedForward(nn.Module):
   ):
     """Init function for feedforward layer.
 
-    Args:
-      dim(int): embedding size.
-      hidden_dim(int): hidden dim size of the feedforward layer.
-      activation(Callable): activation function used in this block.
-      use_bias(Boolean): whether to use bias. Default is false.
+    Args: dim(int): embedding size. hidden_dim(int): hidden dim size of the
+    feedforward layer. activation(Callable): activation function used in this
+    block. use_bias(Boolean): whether to use bias. Default is false.
     """
     super().__init__()
     self.act = activation
@@ -17,7 +17,6 @@
 from ai_edge_torch.hlfb import StableHLOCompositeBuilder
 import torch
 from torch import nn
-import torch_xla
 
 
 class KVCache(nn.Module):
@@ -55,9 +55,10 @@ class FeedForwardType(enum.Enum):
 
 @dataclass
 class AttentionConfig:
-  """Attention moduel's parameters."""
+  """Attention model's parameters."""
 
   num_heads: int
+  head_dim: int
   # Used to determine number of groups in grouped query attention (GQA)
   # https://arxiv.org/pdf/2305.13245.pdf
   num_query_groups: Optional[int]
@@ -156,7 +157,3 @@ class ModelConfig:
       return self.kv_cache_max_len
     else:
       return self.max_seq_len
-
-  @property
-  def head_dim(self) -> int:
-    return self.embedding_dim // self.attn_config.num_heads
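
The property removed above pairs with the new explicit head_dim field on AttentionConfig: head_dim is no longer derived as embedding_dim // num_heads but set directly on the attention config. A minimal sketch of the new access pattern, based on the hunks in this diff; the import alias for the config module is an assumption:

  from ai_edge_torch.generative.layers import model_config as cfg

  attn_config = cfg.AttentionConfig(
      num_heads=32,
      head_dim=64,  # previously computed as embedding_dim // num_heads
      num_query_groups=4,
      rotary_percentage=1.0,
  )
  # Callers now read the value from the attention config, e.g.:
  n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)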
@@ -21,12 +21,12 @@ import torch
 class RMSNorm(torch.nn.Module):
 
   def __init__(self, dim: int, eps: float = 1e-6, zero_centered_gamma=False):
-    """
-    Initialize the RMSNorm layer.
+    """Initialize the RMSNorm layer.
 
     Args:
       dim (int): dimension of the input tensor.
-      eps (float): A small float value to ensure numerical stability (default: 1e-6).
+      eps (float): A small float value to ensure numerical stability (default:
+        1e-6).
     """
     super().__init__()
     self.eps = eps
@@ -34,8 +34,7 @@ class RMSNorm(torch.nn.Module):
     self.zero_centered_gamma = zero_centered_gamma
 
   def _norm(self, x):
-    """
-    Apply RMSNorm normalization.
+    """Apply RMSNorm normalization.
 
     Args:
       x (torch.Tensor): input tensor.
@@ -46,8 +45,7 @@ class RMSNorm(torch.nn.Module):
     return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
 
   def forward(self, x):
-    """
-    Running the forward pass of RMSNorm layer.
+    """Running the forward pass of RMSNorm layer.
 
     Args:
       x (torch.Tensor): input tensor.
@@ -22,9 +22,9 @@ def apply_rope(
   """Computes rotary positional embedding.
 
   Args:
-    x(torch.Tensor): the input tensor.
-    cos(torch.Tensor): cosine value for the rope.
-    sin(torch.Tensor): sin value for the rope.
+    x: the input tensor.
+    cos: cosine value for the rope.
+    sin: sin value for the rope.
 
   Returns:
     output tensor of RoPE.
@@ -105,7 +105,6 @@ class AttentionBlock2D(nn.Module):
   """2D self attention block
 
   x = SelfAttention(Norm(input_tensor)) + x
-
   """
 
   def __init__(self, config: unet_cfg.AttentionBlock2DConfig):
@@ -161,14 +160,14 @@ class CrossAttentionBlock2D(nn.Module):
   """2D cross attention block
 
   x = CrossAttention(Norm(input_tensor), context) + x
-
   """
 
   def __init__(self, config: unet_cfg.CrossAttentionBlock2DConfig):
     """Initialize an instance of the AttentionBlock2D.
 
     Args:
-      config (unet_cfg.CrossAttentionBlock2DConfig): the configuration of this block.
+      config (unet_cfg.CrossAttentionBlock2DConfig): the configuration of this
+        block.
     """
     super().__init__()
     self.config = config
@@ -191,7 +190,8 @@ class CrossAttentionBlock2D(nn.Module):
 
     Args:
       input_tensor (torch.Tensor): the input tensor.
-      context_tensor (torch.Tensor): the context tensor to apply cross attention on.
+      context_tensor (torch.Tensor): the context tensor to apply cross attention
+        on.
 
     Returns:
       output activation tensor after cross attention.
@@ -220,7 +220,6 @@ class FeedForwardBlock2D(nn.Module):
   """2D feed forward block
 
   x = w2(Activation(w1(Norm(x)))) + x
-
   """
 
   def __init__(
@@ -291,15 +290,14 @@ class TransformerBlock2D(nn.Module):
           └─────────┬─────────┘
 
            hidden_states
-
-
   """
 
   def __init__(self, config: unet_cfg.TransformerBlock2DConfig):
     """Initialize an instance of the TransformerBlock2D.
 
     Args:
-      config (unet_cfg.TransformerBlock2Dconfig): the configuration of this block.
+      config (unet_cfg.TransformerBlock2Dconfig): the configuration of this
+        block.
     """
     super().__init__()
     self.config = config
@@ -329,7 +327,8 @@ class TransformerBlock2D(nn.Module):
 
     Args:
       input_tensor (torch.Tensor): the input tensor.
-      context_tensor (torch.Tensor): the context tensor to apply cross attention on.
+      context_tensor (torch.Tensor): the context tensor to apply cross attention
+        on.
 
     Returns:
       output activation tensor after transformer block.
@@ -377,7 +376,8 @@ class DownEncoderBlock2D(nn.Module):
     """Initialize an instance of the DownEncoderBlock2D.
 
     Args:
-      config (unet_cfg.DownEncoderBlock2DConfig): the configuration of this block.
+      config (unet_cfg.DownEncoderBlock2DConfig): the configuration of this
+        block.
     """
     super().__init__()
     self.config = config
@@ -418,10 +418,13 @@ class DownEncoderBlock2D(nn.Module):
 
     Args:
       input_tensor (torch.Tensor): the input tensor.
-      time_emb (torch.Tensor): optional time embedding tensor, if the block is configured to accept
-        time embedding.
-      context_tensor (torch.Tensor): optional context tensor, if the block if configured to use transofrmer block.
-      output_hidden_states (bool): whether to output hidden states, usually for skip connections.
+      time_emb (torch.Tensor): optional time embedding tensor, if the block is
+        configured to accept time embedding.
+      context_tensor (torch.Tensor): optional context tensor, if the block if
+        configured to use transofrmer block.
+      output_hidden_states (bool): whether to output hidden states, usually for
+        skip connections.
+
     Returns:
       output hidden_states tensor after DownEncoderBlock2D.
     """
@@ -523,9 +526,10 @@ class UpDecoderBlock2D(nn.Module):
 
     Args:
      input_tensor (torch.Tensor): the input tensor.
-      time_emb (torch.Tensor): optional time embedding tensor, if the block is configured to accept
-        time embedding.
-      context_tensor (torch.Tensor): optional context tensor, if the block if configured to use transofrmer block.
+      time_emb (torch.Tensor): optional time embedding tensor, if the block is
+        configured to accept time embedding.
+      context_tensor (torch.Tensor): optional context tensor, if the block if
+        configured to use transofrmer block.
 
     Returns:
       output hidden_states tensor after UpDecoderBlock2D.
@@ -576,7 +580,8 @@ class SkipUpDecoderBlock2D(nn.Module):
     """Initialize an instance of the SkipUpDecoderBlock2D.
 
     Args:
-      config (unet_cfg.SkipUpDecoderBlock2DConfig): the configuration of this block.
+      config (unet_cfg.SkipUpDecoderBlock2DConfig): the configuration of this
+        block.
     """
     super().__init__()
     self.config = config
@@ -632,10 +637,12 @@ class SkipUpDecoderBlock2D(nn.Module):
 
     Args:
       input_tensor (torch.Tensor): the input tensor.
-      skip_connection_tensors (List[torch.Tensor]): the skip connection tensors from encoder blocks.
-      time_emb (torch.Tensor): optional time embedding tensor, if the block is configured to accept
-        time embedding.
-      context_tensor (torch.Tensor): optional context tensor, if the block if configured to use transofrmer block.
+      skip_connection_tensors (List[torch.Tensor]): the skip connection tensors
+        from encoder blocks.
+      time_emb (torch.Tensor): optional time embedding tensor, if the block is
+        configured to accept time embedding.
+      context_tensor (torch.Tensor): optional context tensor, if the block if
+        configured to use transofrmer block.
 
     Returns:
       output hidden_states tensor after SkipUpDecoderBlock2D.
@@ -738,10 +745,10 @@ class MidBlock2D(nn.Module):
 
     Args:
       input_tensor (torch.Tensor): the input tensor.
-      time_emb (torch.Tensor): optional time embedding tensor, if the block is configured to accept
-        time embedding.
-      context_tensor (torch.Tensor): optional context tensor, if the block if configured to use
-        transofrmer block.
+      time_emb (torch.Tensor): optional time embedding tensor, if the block is
+        configured to accept time embedding.
+      context_tensor (torch.Tensor): optional context tensor, if the block if
+        configured to use transofrmer block.
 
     Returns:
       output hidden_states tensor after MidBlock2D.