ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240914__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. ai_edge_torch/_convert/conversion.py +2 -1
  2. ai_edge_torch/_convert/fx_passes/__init__.py +5 -41
  3. ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +3 -4
  4. ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +3 -4
  5. ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +3 -4
  6. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +4 -5
  7. ai_edge_torch/config.py +4 -1
  8. ai_edge_torch/fx_pass_base.py +101 -0
  9. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +35 -16
  10. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +29 -10
  11. ai_edge_torch/generative/examples/gemma/gemma.py +52 -32
  12. ai_edge_torch/generative/examples/gemma/gemma2.py +87 -60
  13. ai_edge_torch/generative/examples/{experimental/gemma → openelm}/convert_to_tflite.py +16 -18
  14. ai_edge_torch/generative/examples/openelm/openelm.py +237 -0
  15. ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +15 -16
  16. ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +48 -45
  17. ai_edge_torch/generative/examples/{experimental/tiny_llama → smollm}/convert_to_tflite.py +16 -17
  18. ai_edge_torch/generative/examples/smollm/smollm.py +131 -0
  19. ai_edge_torch/generative/examples/stable_diffusion/clip.py +12 -6
  20. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -1
  21. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +20 -20
  22. ai_edge_torch/generative/examples/t5/t5.py +43 -30
  23. ai_edge_torch/generative/examples/t5/t5_attention.py +18 -13
  24. ai_edge_torch/generative/examples/test_models/toy_model.py +15 -13
  25. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +75 -34
  26. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +29 -10
  27. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +57 -36
  28. ai_edge_torch/generative/fx_passes/__init__.py +4 -4
  29. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +3 -4
  30. ai_edge_torch/generative/layers/attention.py +84 -73
  31. ai_edge_torch/generative/layers/builder.py +38 -14
  32. ai_edge_torch/generative/layers/feed_forward.py +26 -8
  33. ai_edge_torch/generative/layers/kv_cache.py +163 -51
  34. ai_edge_torch/generative/layers/model_config.py +61 -33
  35. ai_edge_torch/generative/layers/normalization.py +158 -0
  36. ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
  37. ai_edge_torch/generative/quantize/example.py +2 -2
  38. ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +12 -24
  39. ai_edge_torch/generative/test/test_loader.py +1 -1
  40. ai_edge_torch/generative/test/test_model_conversion.py +77 -62
  41. ai_edge_torch/generative/test/test_model_conversion_large.py +61 -68
  42. ai_edge_torch/generative/test/test_quantize.py +5 -5
  43. ai_edge_torch/generative/test/utils.py +54 -0
  44. ai_edge_torch/generative/utilities/loader.py +28 -15
  45. ai_edge_torch/generative/utilities/t5_loader.py +21 -20
  46. ai_edge_torch/odml_torch/export.py +40 -0
  47. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
  48. ai_edge_torch/odml_torch/lowerings/_basic.py +44 -0
  49. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -2
  50. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
  51. ai_edge_torch/version.py +1 -1
  52. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/METADATA +1 -1
  53. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/RECORD +59 -63
  54. ai_edge_torch/_convert/fx_passes/_pass_base.py +0 -53
  55. ai_edge_torch/_convert/fx_passes/canonicalize_pass.py +0 -35
  56. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
  57. ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
  58. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
  59. ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
  60. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
  61. ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
  62. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
  63. /ai_edge_torch/generative/examples/{experimental → openelm}/__init__.py +0 -0
  64. /ai_edge_torch/generative/examples/{experimental/gemma → phi}/__init__.py +0 -0
  65. /ai_edge_torch/generative/examples/{experimental/phi → smollm}/__init__.py +0 -0
  66. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/LICENSE +0 -0
  67. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/WHEEL +0 -0
  68. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/layers/feed_forward.py
@@ -30,18 +30,27 @@ class SequentialFeedForward(nn.Module):
       hidden_dim: int,
       activation: Callable[[torch.Tensor], torch.Tensor],
       use_bias=False,
+      use_glu=False,
       pre_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
       post_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
   ):
     """Init function for feedforward layer.
 
-    Args: dim(int): embedding size. hidden_dim(int): hidden dim size of the
-      feedforward layer. activation(Callable): activation function used in this
-      block. use_bias(Boolean): whether to use bias. Default is false.
+    Args:
+      dim (int): embedding size.
+      hidden_dim (int): hidden dim size of the feedforward layer.
+      activation (Callable): activation function used in this block.
+      use_bias (Boolean): whether to use bias. Default is false.
+      use_glu (Boolean): whether to use glu in activation. Default is false.
+      pre_ff_norm (Callable): pre feedforward norm. Default is None.
+      post_ff_norm (Callable): post feedforward norm. Default is None.
     """
     super().__init__()
     self.act = activation
-    self.w1 = nn.Linear(dim, hidden_dim, bias=use_bias)
+    if use_glu:
+      self.w1 = nn.Linear(dim, hidden_dim * 2, bias=use_bias)
+    else:
+      self.w1 = nn.Linear(dim, hidden_dim, bias=use_bias)
     self.w2 = nn.Linear(hidden_dim, dim, bias=use_bias)
     self.pre_ff_norm = pre_ff_norm if pre_ff_norm else lambda x: x
     self.post_ff_norm = post_ff_norm if post_ff_norm else lambda x: x
@@ -72,18 +81,27 @@ class GatedFeedForward(nn.Module):
       hidden_dim: int,
       activation: Callable[[torch.Tensor], torch.Tensor],
       use_bias=False,
+      use_glu=False,
       pre_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
       post_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
   ):
     """Init function for feedforward layer.
 
-    Args: dim(int): embedding size. hidden_dim(int): hidden dim size of the
-      feedforward layer. activation(Callable): activation function used in this
-      block. use_bias(Boolean): whether to use bias. Default is false.
+    Args:
+      dim (int): embedding size.
+      hidden_dim (int): hidden dim size of the feedforward layer.
+      activation (Callable): activation function used in this block.
+      use_bias (Boolean): whether to use bias. Default is false.
+      use_glu (Boolean): whether to use glu in activation. Default is false.
+      pre_ff_norm (Callable): pre feedforward norm. Default is None.
+      post_ff_norm (Callable): post feedforward norm. Default is None.
     """
     super().__init__()
     self.act = activation
-    self.w1 = nn.Linear(dim, hidden_dim, bias=use_bias)
+    if use_glu:
+      self.w1 = nn.Linear(dim, hidden_dim * 2, bias=use_bias)
+    else:
+      self.w1 = nn.Linear(dim, hidden_dim, bias=use_bias)
     self.w2 = nn.Linear(hidden_dim, dim, bias=use_bias)
     self.w3 = nn.Linear(dim, hidden_dim, bias=use_bias)
     self.pre_ff_norm = pre_ff_norm if pre_ff_norm else lambda x: x
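Note on the use_glu additions in the two hunks above: when use_glu=True, w1 projects to hidden_dim * 2 and the activation is expected to gate one half of that projection with the other, so w2 still receives a hidden_dim-sized input. The sketch below is illustrative only; the actual gated activation (and whether the gate is the front or the back half, see ActivationConfig.gate_is_front in the model_config.py hunks further down) is presumably assembled in builder.py, which is also changed in this release but not shown here.

import torch
import torch.nn.functional as F

def silu_glu(x: torch.Tensor) -> torch.Tensor:
  # Hypothetical SiLU-gated linear unit over a doubled projection: split the
  # last dimension into (gate, value) halves and gate one with the other.
  gate, value = x.chunk(2, dim=-1)
  return F.silu(gate) * value

x = torch.randn(1, 4, 16)           # (batch, seq, dim)
w1 = torch.nn.Linear(16, 32 * 2)    # hidden_dim = 32, doubled when use_glu=True
h = silu_glu(w1(x))                 # -> (1, 4, 32), matching w2's input size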
ai_edge_torch/generative/layers/kv_cache.py
@@ -12,72 +12,184 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# `nn.Module` which implements a KV cache.
 
-from ai_edge_torch.hlfb import StableHLOCompositeBuilder
+"""Utility functions for externalized KV Cache."""
+
+import dataclasses
+from typing import List, Tuple
+
+from ai_edge_torch import hlfb
+from ai_edge_torch.generative.layers import model_config
 import torch
-from torch import nn
+import torch.utils._pytree as pytree
 
+BATCH_SIZE = 1
 
-class KVCache(nn.Module):
 
-  def __init__(
-      self, batch_size, kv_cache_max, n_heads, head_dim, enable_hlfb=False
-  ):
-    """Initializes the KVCache layer.
+@dataclasses.dataclass
+class KVCacheEntry:
+  """A single cache entry that includes K and V caches.
 
-    Args:
-      batch_size (int): batch size. Currently only batch size 1 is supported.
-      kv_cache_max (int): the max length of KV cache.
-      n_heads (int): number of kv heads.
-      head_dim (int): the head dimension size.
-      enable_hlfb (bool): whether hlfb is enabled or not.
-    """
-    super().__init__()
-    cache_shape = (batch_size, kv_cache_max, n_heads, head_dim)
-    self.register_buffer("k_cache", torch.zeros(cache_shape), persistent=False)
-    self.register_buffer("v_cache", torch.zeros(cache_shape), persistent=False)
-    self.enable_hlfb = enable_hlfb
-    self.kv_cache_max = kv_cache_max
+  The chaches are built based on the provided config with the shape of
+  (batch_size=1, kv_cache_max, num_query_groups, head_dim).
+  """
 
-  def update_cache(self, input_pos, k_val, v_val):
-    """Update an entry in the KV cache.
+  k_cache: torch.Tensor
+  v_cache: torch.Tensor
 
-    Args:
-      input_pos (torch.Tensor): the input position.
-      k_val (torch.Tensor): the new `key` value.
-      v_val (torch.Tensor): the new `value` value.
+  @classmethod
+  def from_model_config(
+      cls,
+      kv_cache_max: int,
+      config: model_config.AttentionConfig,
+      dtype: torch.dtype = torch.float32,
+      device: torch.device = None,
+  ) -> "KVCacheEntry":
+    """Build an instance of the class based on model config."""
+    shape = (BATCH_SIZE, kv_cache_max, config.num_query_groups, config.head_dim)
+    k = torch.zeros(shape, dtype=dtype, device=device)
+    v = torch.zeros(shape, dtype=dtype, device=device)
+    obj = cls(k_cache=k, v_cache=v)
+    return obj
 
-    Returns:
-      The updated key and value tensor.
-    """
-    if self.enable_hlfb:
-      return self.update_cache_with_hlfb(input_pos, k_val, v_val)
 
-    updated_k = self.k_cache.index_copy_(1, input_pos, k_val)
-    updated_v = self.v_cache.index_copy_(1, input_pos, v_val)
-    # Here we need a clone otherwise dynamo export will fail.
-    return torch.clone(updated_k), torch.clone(updated_v)
+@dataclasses.dataclass
+class KVCache:
+  """A utility class for holding KV cache entries per layer."""
 
-  def update_cache_with_hlfb(self, input_pos, k_val, v_val):
-    """Update an entry in the KV cache and enable high-level function boundary.
+  caches: Tuple[KVCacheEntry, ...]
+
+  @classmethod
+  def from_model_config(
+      cls,
+      config: model_config.ModelConfig,
+      dtype: torch.dtype = torch.float32,
+      device: torch.device = None,
+  ) -> "KVCache":
+    """Build an instance of the class based on model config.
 
     Args:
-      input_pos (torch.Tensor): the input position.
-      k_val (torch.Tensor): the new `key` value.
-      v_val (torch.Tensor): the new `value` value.
+      config (ModelConfig): Model config used for building the cache.
+      dtype (torch.dtype, optional): The data type of the cache tensor.
+        Defaults to torch.float32.
+      device (torch.device, optional): The device placement of the cache
+        tensors. Defaults to None.
 
     Returns:
-      The updated key and value tensor.
+      KVCache: The created cache object.
     """
+    caches = [
+        KVCacheEntry.from_model_config(
+            config.kv_cache_max,
+            config.block_config(idx).attn_config,
+            dtype,
+            device,
+        )
+        for idx in range(config.num_layers)
+    ]
+    obj = cls(caches=tuple(caches))
+    return obj
 
-    builder = StableHLOCompositeBuilder(
-        name="odml.update_kv_cache", attr={"kv_cache_max": self.kv_cache_max}
-    )
-    k_cache, v_cache, input_pos, k_val, v_val = builder.mark_inputs(
-        self.k_cache, self.v_cache, input_pos, k_val, v_val
+  def flatten(self) -> List[torch.Tensor]:
+    """Flatten the cache entries into a list of tensors with order k_i, v_i."""
+    flattened, _ = _flatten_kvc(self)
+    return flattened
+
+
+def _flatten_kvc(kvc: KVCache) -> Tuple[List[str], List[str]]:
+  flattened = []
+  flat_names = []
+  none_names = []
+  for i, kv_entry in enumerate(kvc.caches):
+    flattened.append(kv_entry.k_cache)
+    flat_names.append(f"k_{i}")
+    flattened.append(kv_entry.v_cache)
+    flat_names.append(f"v_{i}")
+  return flattened, [flat_names, none_names]
+
+
+def _flatten_kvc_with_keys(kvc: KVCache) -> Tuple[List, List]:
+  flattened, (flat_names, none_names) = _flatten_kvc(kvc)
+  return [
+      (pytree.MappingKey(k), v) for k, v in zip(flat_names, flattened)
+  ], flat_names
+
+
+def _unflatten_kvc(
+    values: List[torch.Tensor], context: Tuple[List, List]
+) -> KVCache:
+  assert len(values) % 2 == 0, "Found odd number of K and V entries."
+  num_layers = len(values) // 2
+  flat_names = context[0]
+  kv_entries = []
+  for i in range(num_layers):
+    k_cache_idx = flat_names.index(f"k_{i}")
+    v_cache_idx = flat_names.index(f"v_{i}")
+    kv_entries.append(
+        KVCacheEntry(k_cache=values[k_cache_idx], v_cache=values[v_cache_idx])
     )
-    updated_k = k_cache.index_copy_(1, input_pos, k_val)
-    updated_v = v_cache.index_copy_(1, input_pos, v_val)
-    updated_k, updated_v = builder.mark_outputs(updated_k, updated_v)
-    return updated_k, updated_v
+  obj = KVCache(tuple(kv_entries))
+  return obj
+
+
+pytree.register_pytree_node(
+    KVCache,
+    _flatten_kvc,
+    _unflatten_kvc,
+    flatten_with_keys_fn=_flatten_kvc_with_keys,
+    serialized_type_name="",
+)
+
+
+def update(
+    cache: KVCacheEntry,
+    input_pos: torch.Tensor,
+    k_slice: torch.Tensor,
+    v_slice: torch.Tensor,
+    enable_hlfb: bool = True,
+) -> KVCacheEntry:
+  """Out of place update of Cache buffer.
+
+  Args:
+    cache (KVCacheEntry): The original cache buffer.
+    input_pos (torch.Tensor): The update slice positions.
+    k_slice (torch.Tensor): The K slice to be updated in the new cache.
+    v_slice (torch.Tensor): The V slice to be updated in the new cache.
+    enable_hlfb (bool, optional): Whether the op is annotated for export with
+      High Level Function Boundary. Defaults to True.
+
+  Returns:
+    KVCacheEntry: The updated KVCache entry based on the passed inputs.
+  """
+  update_func = _update_kv_hlfb_impl if enable_hlfb else _update_kv_base_impl
+  return update_func(cache, input_pos, k_slice, v_slice)
+
+
+def _update_kv_base_impl(
+    cache: KVCacheEntry,
+    input_pos: torch.Tensor,
+    k_slice: torch.Tensor,
+    v_slice: torch.Tensor,
+) -> KVCacheEntry:
+  """Update the cache buffer without High Level Function Boundary annotation."""
+  k = cache.k_cache.index_copy(1, input_pos.to(torch.long), k_slice)
+  v = cache.v_cache.index_copy(1, input_pos.to(torch.long), v_slice)
+  updated_cache = KVCacheEntry(k, v)
+  return updated_cache
+
+
+def _update_kv_hlfb_impl(
+    cache: KVCacheEntry,
+    input_pos: torch.Tensor,
+    k_slice: torch.Tensor,
+    v_slice: torch.Tensor,
+) -> KVCacheEntry:
+  """Update the cache buffer with High Level Function Boundary annotation."""
+  builder = hlfb.StableHLOCompositeBuilder(name="odml.update_external_kv_cache")
+  k_cache, v_cache, input_pos, k_slice, v_slice = builder.mark_inputs(
+      cache.k_cache, cache.v_cache, input_pos, k_slice, v_slice
+  )
+  k = k_cache.index_copy(1, input_pos.to(torch.long), k_slice)
+  v = v_cache.index_copy(1, input_pos.to(torch.long), v_slice)
+  k, v = builder.mark_outputs(k, v)
+  return KVCacheEntry(k, v)
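Note on the hunk above: KVCache is now a pytree-registered dataclass that is passed through the model instead of living in nn.Module buffers, and update() returns a new KVCacheEntry out of place. A minimal usage sketch, reusing gemma.get_fake_model_config() from the quantize example later in this diff; the slice shapes, the enable_hlfb=False flag, and the way the updated entry is threaded back are illustrative, not prescribed by the diff:

import torch
from ai_edge_torch.generative.examples.gemma import gemma
from ai_edge_torch.generative.layers import kv_cache as kv_utils

config = gemma.get_fake_model_config()
cache = kv_utils.KVCache.from_model_config(config, dtype=torch.float32)

# During decode, an attention layer produces K/V slices for the current
# position; the caller rebuilds the cache with the updated entry.
input_pos = torch.tensor([5], dtype=torch.int)
entry = cache.caches[0]
k_slice = torch.zeros(1, 1, entry.k_cache.shape[2], entry.k_cache.shape[3])
v_slice = torch.zeros_like(k_slice)
updated = kv_utils.update(entry, input_pos, k_slice, v_slice, enable_hlfb=False)
new_cache = kv_utils.KVCache(caches=(updated,) + cache.caches[1:])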
ai_edge_torch/generative/layers/model_config.py
@@ -16,7 +16,7 @@
 from dataclasses import dataclass
 from dataclasses import field
 import enum
-from typing import Optional, Sequence
+from typing import Optional, Sequence, Union
 
 
 @enum.unique
@@ -30,6 +30,7 @@ class ActivationType(enum.Enum):
   GELU_QUICK = enum.auto()
   GE_GLU = enum.auto()
   RELU = enum.auto()
+  SILU_GLU = enum.auto()
 
 
 @enum.unique
@@ -58,6 +59,18 @@ class AttentionType(enum.Enum):
   LOCAL_SLIDING = enum.auto()
 
 
+@dataclass
+class NormalizationConfig:
+  """Normalizater parameters."""
+
+  type: NormalizationType = NormalizationType.NONE
+  enable_hlfb: bool = False
+  epsilon: float = 1e-5
+  zero_centered: bool = False
+  # Number of groups used in group normalization.
+  group_num: Optional[float] = None
+
+
 @dataclass
 class AttentionConfig:
   """Attention model's parameters."""
@@ -81,12 +94,20 @@ class AttentionConfig:
   # Whether to use bias with attention output projection.
   output_proj_use_bias: bool = False
   enable_kv_cache: bool = True
+  # The normalization applied to query projection's output.
+  query_norm_config: NormalizationConfig = field(
+      default_factory=NormalizationConfig
+  )
+  # The normalization applied to key projection's output.
+  key_norm_config: NormalizationConfig = field(
+      default_factory=NormalizationConfig
+  )
   relative_attention_num_buckets: int = 0
   relative_attention_max_distance: int = 0
   # Softcap on the output logits.
   logit_softcap: Optional[float] = None
-  # The types of attention used in the layers of the model.
-  attn_types: Optional[Sequence[AttentionType]] = None
+  # The type of attention.
+  attn_type: Optional[AttentionType] = None
   # The size of the sliding window used for local attention.
   sliding_window_size: Optional[int] = None
 
@@ -94,20 +115,9 @@ class AttentionConfig:
 @dataclass
 class ActivationConfig:
   type: ActivationType = ActivationType.LINEAR
-  # Dimension of input and output, used in GeGLU.
-  dim_in: Optional[int] = None
-  dim_out: Optional[int] = None
-
-
-@dataclass
-class NormalizationConfig:
-  """Normalizater parameters."""
-
-  type: NormalizationType = NormalizationType.NONE
-  epsilon: float = 1e-5
-  zero_centered: bool = False
-  # Number of groups used in group normalization.
-  group_num: Optional[float] = None
+  # Whether to GLU gate is the front part instead of the back part of input
+  # when ActivationType is `GE_GLU` or `SILU_GLU`.
+  gate_is_front: bool = False
 
 
 @dataclass
@@ -129,13 +139,8 @@ class FeedForwardConfig:
 
 
 @dataclass
-class ModelConfig:
-  """Base configurations for building a transformer architecture."""
-
-  vocab_size: int
-  num_layers: int
-  max_seq_len: int
-  embedding_dim: int
+class TransformerBlockConfig:
+  """TransformerBlock module's parameters."""
 
   attn_config: AttentionConfig
   ff_config: FeedForwardConfig
@@ -147,15 +152,33 @@ class ModelConfig:
   post_attention_norm_config: NormalizationConfig = field(
       default_factory=NormalizationConfig
   )
+  # If set to True, only attn_config.pre_attention_norm is applied to the input
+  # and the decode's output is computed as `output = input + attn_out + ff_out`
+  # where attention and feed forward are called with pre_attention_norm's
+  # output.
+  parallel_residual: bool = False
+  # The Attention computation will include relative positional bias.
+  relative_attention: bool = False
+
+
+@dataclass
+class ModelConfig:
+  """Base configurations for building a transformer architecture."""
+
+  vocab_size: int
+  num_layers: int
+  max_seq_len: int
+  embedding_dim: int
+
+  # TransformerBlockConfig for each layer block. If a single
+  # TransformerBlockConfig is provided, it will be used for all layers.
+  block_configs: Union[TransformerBlockConfig, Sequence[TransformerBlockConfig]]
+
   # The normalization applied before LM head.
   final_norm_config: NormalizationConfig = field(
       default_factory=NormalizationConfig
   )
 
-  # If set to True, only pre_attention_norm is applied to the input and the
-  # decode's output is computed as `output = input + attn_out + ff_out` where
-  # attention and feed forward are called with pre_attention_norm's output.
-  parallel_residual: bool = False
   # Use bias term within LLM's HEAD.
   lm_head_use_bias: bool = False
   # Whether to turn on high-level function boundary.
@@ -164,9 +187,6 @@ class ModelConfig:
   # The maximum sequence length of the KV cache. Should not exceed max_seq_len.
   kv_cache_max_len: int = 0
 
-  # The Attention computation will include relative positional bias.
-  relative_attention: bool = False
-
   # Default batch size of the exported model. Default value is 1.
   batch_size: int = 1
 
@@ -177,5 +197,13 @@ class ModelConfig:
   def kv_cache_max(self) -> int:
     if self.kv_cache_max_len > 0:
       return self.kv_cache_max_len
-    else:
-      return self.max_seq_len
+    return self.max_seq_len
+
+  def block_config(self, idx: int) -> TransformerBlockConfig:
+    if isinstance(self.block_configs, TransformerBlockConfig):
+      return self.block_configs
+    if idx < 0 or idx >= len(self.block_configs):
+      raise ValueError(
+          f"Index {idx} is out of range for layer configs: {self.block_configs}"
+      )
+    return self.block_configs[idx]
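Note on the model_config.py hunks above: per-layer parameters now live in TransformerBlockConfig, and ModelConfig.block_configs accepts either one shared block config or a sequence with one entry per layer; block_config(idx) resolves either form. A minimal sketch, again borrowing gemma.get_fake_model_config() from the quantize example below; the dataclasses.replace() step is only an illustration of the per-layer form:

import dataclasses
from ai_edge_torch.generative.examples.gemma import gemma

config = gemma.get_fake_model_config()
block0 = config.block_config(0)     # TransformerBlockConfig for layer 0
attn0 = block0.attn_config          # per-layer attention settings (attn_type, norms, ...)

# block_configs may also be a sequence with one entry per layer; block_config()
# then indexes it and raises ValueError for out-of-range indices.
per_layer = dataclasses.replace(
    config,
    block_configs=tuple(config.block_config(i) for i in range(config.num_layers)),
)
assert per_layer.block_config(0).attn_config is attn0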
ai_edge_torch/generative/layers/normalization.py
@@ -14,7 +14,10 @@
 # ==============================================================================
 # Common normalization layers.
 
+from ai_edge_torch.hlfb import StableHLOCompositeBuilder
 import torch
+from torch import nn
+import torch.nn.functional as F
 
 
 # Implementation for RMSNorm from: https://arxiv.org/abs/1910.07467
@@ -58,3 +61,158 @@ class RMSNorm(torch.nn.Module):
       return output * (1 + self.weight)
     else:
       return output * self.weight
+
+
+class GroupNorm(torch.nn.Module):
+
+  def __init__(
+      self,
+      group_num: int,
+      dim: int,
+      eps: float = 1e-5,
+      enable_hlfb: bool = False,
+  ):
+    """Initialize the GroupNorm layer.
+
+    Args:
+      group_num (int): Number of groups to separate the channels into.
+      dim (int): Dimension of the input tensor.
+      eps (float): A small float value to ensure numerical stability (default:
+        1e-6).
+      enable_hlfb (bool): Whether to convert this normalization into a single
+        op.
+    """
+    super().__init__()
+    self.enable_hlfb = enable_hlfb
+    self.group_num = group_num
+    self.eps = eps
+    self.weight = torch.nn.Parameter(torch.ones(dim))
+    self.bias = torch.nn.Parameter(torch.ones(dim))
+
+  def forward(self, x):
+    """Running the forward pass of GroupNorm layer.
+
+    Args:
+      x (torch.Tensor): input tensor.
+
+    Returns:
+      torch.Tensor: output tensor after applying GroupNorm.
+    """
+    if self.enable_hlfb:
+      return group_norm_with_hlfb(
+          x,
+          self.weight,
+          self.bias,
+          self.group_num,
+          self.eps,
+      )
+    else:
+      return F.group_norm(x, self.group_num, self.weight, self.bias, self.eps)
+
+
+class LayerNorm(torch.nn.Module):
+
+  def __init__(self, dim: int, eps: float = 1e-5, enable_hlfb: bool = False):
+    """Initialize the LayerNorm layer.
+
+    Args:
+      dim (int): dimension of the input tensor.
+      eps (float): A small float value to ensure numerical stability (default:
+        1e-6).
+      enable_hlfb (bool): Whether to convert this normalization into a single
+        op.
+    """
+    super().__init__()
+    self.enable_hlfb = enable_hlfb
+    self.eps = eps
+    self.weight = torch.nn.Parameter(torch.ones(dim))
+    self.bias = torch.nn.Parameter(torch.ones(dim))
+
+  def forward(self, x):
+    """Running the forward pass of LayerNorm layer.
+
+    Args:
+      x (torch.Tensor): input tensor.
+
+    Returns:
+      torch.Tensor: output tensor after applying LayerNorm.
+    """
+    if self.enable_hlfb:
+      return layer_norm_with_hlfb(
+          x,
+          self.weight,
+          self.bias,
+          self.eps,
+      )
+    else:
+      return F.layer_norm(
+          x,
+          x.shape,
+          self.weight.broadcast_to(x.shape),
+          self.bias.broadcast_to(x.shape),
+          self.eps,
+      )
+
+
+def group_norm_with_hlfb(
+    x: torch.Tensor,
+    w: torch.Tensor,
+    b: torch.Tensor,
+    num_groups: int,
+    eps: float,
+):
+  """Group Normalization with high-level function boundary enabled.
+
+  Args:
+    x (torch.Tensor): Input tensor for Group Normalization, with BCHW shape.
+    w (torch.Tensor): The weight tensor for the normalization.
+    b (torch.Tensor): The bias tensor for the normalization.
+    num_groups (int): Number of groups to separate the channels into.
+    eps (float): A small float value to ensure numerical stability.
+
+  Returns:
+    The output tensor of Group Normalization.
+  """
+  x = torch.permute(x, (0, 2, 3, 1))
+
+  builder = StableHLOCompositeBuilder(
+      name="odml.group_norm", attr={"num_groups": num_groups, "eps": eps}
+  )
+  x, w, b = builder.mark_inputs(x, w, b)
+  x = torch.permute(x, (0, 3, 1, 2))
+  y = F.group_norm(x, num_groups, weight=w, bias=b, eps=eps)
+  y = torch.permute(y, (0, 2, 3, 1))
+  y = builder.mark_outputs(y)
+
+  y = torch.permute(y, (0, 3, 1, 2))
+  return y
+
+
+def layer_norm_with_hlfb(
+    x: torch.Tensor,
+    w: torch.Tensor,
+    b: torch.Tensor,
+    eps: float,
+):
+  """Layer Normalization with high-level function boundary enabled.
+
+  Args:
+    x (torch.Tensor): Input tensor for Layer Normalization.
+    w (torch.Tensor): The weight tensor for the normalization.
+    b (torch.Tensor): The bias tensor for the normalization.
+    eps (float): A small float value to ensure numerical stability.
+
+  Returns:
+    The output tensor of Layer Normalization.
+  """
+  builder = StableHLOCompositeBuilder(name="odml.layer_norm", attr={"eps": eps})
+  x, w, b = builder.mark_inputs(x, w, b)
+  y = F.layer_norm(
+      x,
+      x.shape,
+      weight=w.broadcast_to(x.shape),
+      bias=b.broadcast_to(x.shape),
+      eps=eps,
+  )
+  y = builder.mark_outputs(y)
+  return y
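Note on the new normalization layers above: both GroupNorm and LayerNorm fall back to the plain functional ops when enable_hlfb=False and wrap the computation in a StableHLO composite (odml.group_norm / odml.layer_norm) when enable_hlfb=True, presumably to pair with the new odml_torch/lowerings/_layer_norm.py listed in the file summary. A minimal eager-mode sketch with HLFB disabled; enabling it is intended for modules that go through ai_edge_torch conversion:

import torch
from ai_edge_torch.generative.layers import normalization

# LayerNorm over a (batch, seq, dim) activation.
ln = normalization.LayerNorm(dim=256, eps=1e-5, enable_hlfb=False)
y = ln(torch.randn(1, 8, 256))

# GroupNorm expects BCHW input; the HLFB path permutes to BHWC around the
# marked composite and back, as shown in group_norm_with_hlfb above.
gn = normalization.GroupNorm(group_num=8, dim=64, eps=1e-5, enable_hlfb=False)
z = gn(torch.randn(1, 64, 16, 16))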
ai_edge_torch/generative/layers/unet/blocks_2d.py
@@ -122,7 +122,6 @@ class AttentionBlock2D(nn.Module):
         config.attention_batch_size,
         config.dim,
         config.attention_config,
-        0,
         enable_hlfb=config.enable_hlfb,
     )
 
@@ -180,7 +179,6 @@ class CrossAttentionBlock2D(nn.Module):
         config.query_dim,
         config.cross_dim,
         config.attention_config,
-        0,
         enable_hlfb=config.enable_hlfb,
     )
 
ai_edge_torch/generative/quantize/example.py
@@ -25,9 +25,9 @@ def main():
   config = gemma.get_fake_model_config()
   model = gemma.Gemma(config)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
+  tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
   tokens[0, :4] = idx
-  input_pos = torch.arange(0, 10)
+  input_pos = torch.arange(0, 10, dtype=torch.int)
 
   # Create a quantization recipe to be applied to the model
   quant_config = quant_recipes.full_int8_dynamic_recipe()
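Note on the hunk above: the sample inputs for the quantize example move from torch.long to torch.int (int32), matching the integer token and position inputs used by the reworked examples elsewhere in this release; the in-place assignment still accepts the int64 tensor from numpy and casts it. A tiny sketch of the same preparation in isolation:

import numpy as np
import torch

idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))   # int64 from numpy
tokens = torch.full((1, 10), 0, dtype=torch.int)   # int32 sample input
tokens[0, :4] = idx                                # the assignment casts to int32
input_pos = torch.arange(0, 10, dtype=torch.int)
assert tokens.dtype == input_pos.dtype == torch.int32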