ai-edge-torch-nightly 0.3.0.dev20250105-py3-none-any.whl → 0.3.0.dev20250107-py3-none-any.whl

ai_edge_torch/_config.py CHANGED
@@ -22,6 +22,18 @@ import os
 
 __all__ = ["config"]
 
 
+def _get_bool_env_var(name: str, default: bool) -> bool:
+  var = os.environ.get(name, str(default))
+  var = var.lower().strip()
+  if var in ("y", "yes", "t", "true", "on", "1"):
+    return True
+  elif var in ("n", "no", "f", "false", "off", "0"):
+    return False
+  else:
+    logging.warning("Invalid %s value is ignored: %s.", name, var)
+    return default
+
+
 class _Config:
   """ai-edge-torch global configs."""
 
@@ -33,20 +45,25 @@ class _Config:
     To use torch_xla as the lowering backend, set environment variable
     `USE_TORCH_XLA` to "true".
     """
-    var = os.environ.get("USE_TORCH_XLA", "false")
-    var = var.lower().strip()
-    if var in ("y", "yes", "t", "true", "on", "1"):
-      return True
-    elif var in ("n", "no", "f", "false", "off", "0"):
-      return False
-    else:
-      logging.warning("Invalid USE_TORCH_XLA value is ignored: %s.", var)
-      return False
+    return _get_bool_env_var("USE_TORCH_XLA", default=False)
 
   @property
   def in_oss(self) -> bool:
     """True if the code is not running in the Google internal environment."""
     return True
 
+  @property
+  def enable_group_norm_composite(self) -> bool:
+    """True if group norm should be lowered as a StableHLO composite.
+
+    Currently only supports NHWC group norm generated by
+    OptimizeLayoutTransposesPass.
+    """
+    return _get_bool_env_var("ENABLE_GROUP_NORM_COMPOSITE", default=False)
+
+  @enable_group_norm_composite.setter
+  def enable_group_norm_composite(self, value: bool):
+    os.environ["ENABLE_GROUP_NORM_COMPOSITE"] = "y" if value else "n"
+
 
 config = _Config()
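A minimal sketch of how these boolean flags behave in practice, assuming only that the package re-exports `config` at the top level (the fx-pass code below reads it as `ai_edge_torch.config`):

    import os

    import ai_edge_torch

    # Any of "y", "yes", "t", "true", "on", "1" (case-insensitive) enables a flag.
    os.environ["ENABLE_GROUP_NORM_COMPOSITE"] = "YES"
    assert ai_edge_torch.config.enable_group_norm_composite

    # The setter writes a canonical "y"/"n" back into the environment, so the
    # value round-trips through os.environ rather than through instance state.
    ai_edge_torch.config.enable_group_norm_composite = False
    assert os.environ["ENABLE_GROUP_NORM_COMPOSITE"] == "n"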
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py CHANGED
@@ -17,6 +17,7 @@
 import dataclasses
 import operator
 
+import ai_edge_torch
 from ai_edge_torch import lowertools
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_rewrite
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import utils
@@ -155,6 +156,7 @@ def _qdq_layout_sensitive_inputs_getter(node: Node):
 @layout_sensitive_inputs_getters.register(
     aten._native_batch_norm_legit_no_training
 )
+@layout_sensitive_inputs_getters.register(aten.group_norm)
 @layout_sensitive_inputs_getters.register(aten.native_group_norm)
 def _first_arg_getter(node):
   return [node.args[0]]
@@ -188,6 +190,17 @@ def _aten_norm_checker(node):
   return NHWCable(can_be=len(val[0].shape) == 4, must_be=False)
 
 
+@nhwcable_node_checkers.register(aten.group_norm)
+def _aten_group_norm_checker(node):
+  val = node.meta.get("val")
+  if not hasattr(val, "shape"):
+    return NHWCable(can_be=False, must_be=False)
+
+  can_be = len(val.shape) == 4
+  must_be = can_be and ai_edge_torch.config.enable_group_norm_composite
+  return NHWCable(can_be=can_be, must_be=must_be)
+
+
 @nhwcable_node_checkers.register(aten.native_group_norm)
 def _aten_native_group_norm_checker(node):
   val = node.meta.get("val")
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py CHANGED
@@ -16,6 +16,7 @@
 
 import operator
 
+import ai_edge_torch
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import op_func_registry
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import utils
@@ -23,6 +24,7 @@ import torch
 import torch.utils._pytree as pytree
 
 aten = torch.ops.aten
+StableHLOCompositeBuilder = ai_edge_torch.hlfb.StableHLOCompositeBuilder
 
 __all__ = ["rewrite_nhwc_node", "has_nhwc_rewriter"]
 
@@ -342,6 +344,39 @@ def _aten__native_batch_norm_legit_no_training(node):
   node.target = batch_norm
 
 
+@rewriters.register(aten.group_norm.default)
+def _aten_group_norm(node):
+  def group_norm(input, num_groups: int, weight=None, bias=None, eps=1e-5):
+    is_composite_supported = (
+        ai_edge_torch.config.enable_group_norm_composite
+        and weight is not None
+        and bias is not None
+    )
+
+    builder = None
+    if is_composite_supported:
+      builder = StableHLOCompositeBuilder(
+          name="odml.group_norm",
+          attr={
+              "num_groups": num_groups,
+              "epsilon": eps,
+              "reduction_axes": [3],
+              "channel_axis": 3,
+          },
+      )
+      input, weight, bias = builder.mark_inputs(input, weight, bias)
+
+    input = utils.tensor_to_nchw(input)
+    output = aten.group_norm.default(input, num_groups, weight, bias, eps=eps)
+    output = utils.tensor_to_nhwc(output)
+
+    if builder is not None:
+      output = builder.mark_outputs(output)
+    return output
+
+  node.target = group_norm
+
+
 @rewriters.register(aten.native_group_norm.default)
 def _aten_native_group_norm(node):
 
@@ -354,6 +389,7 @@ def _aten_native_group_norm(node):
      flattened_inner_size: int,
      num_groups: int,
      eps: float,
+      **kwargs,
  ):
    input_reshaped = torch.reshape(
        input,
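Taken together with the checker above, a 4-D group norm can now be exported as a single `odml.group_norm` composite. A hedged end-to-end sketch, assuming the public `ai_edge_torch.convert` entry point and an NCHW sample input (the pass only fires for rank-4 tensors):

    import torch
    import ai_edge_torch

    class Model(torch.nn.Module):

      def __init__(self):
        super().__init__()
        self.gn = torch.nn.GroupNorm(num_groups=4, num_channels=16)

      def forward(self, x):
        return self.gn(x)

    ai_edge_torch.config.enable_group_norm_composite = True
    sample = (torch.rand(1, 16, 32, 32),)  # rank-4 NCHW input
    edge_model = ai_edge_torch.convert(Model().eval(), sample)
    # The exported graph should contain an "odml.group_norm" composite whose
    # channel_axis=3 attribute reflects the NHWC layout chosen by the pass.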
ai_edge_torch/generative/examples/gemma/gemma2.py CHANGED
@@ -15,14 +15,13 @@
 
 """Example of building a Gemma2 model."""
 
-from typing import List, Optional, Tuple
+from typing import Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
@@ -104,12 +103,17 @@ class Gemma2(nn.Module):
         config.embedding_dim,
         config.final_norm_config,
     )
-    self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
-    )
     # Gemma2 has same hyper parameters for each layer except for attention
     # types. Use the first layer.
     attn_config = config.block_config(0).attn_config
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.kv_cache_max,
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
+        base=attn_config.rotary_base,
+    )
+    self.mask_cache = attn_utils.build_causal_mask_cache(
+        size=config.kv_cache_max,
+    )
     self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
         size=config.kv_cache_max,
         window_size=attn_config.sliding_window_size,
@@ -136,48 +140,29 @@ class Gemma2(nn.Module):
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
-
-    # token embeddings of shape (b, t, n_embd)
-    input_embeds = self.tok_embedding(tokens)
-    # RoPE parameters are the same for all blocks. Use the first layer.
-    attn_config = self.config.block_config(0).attn_config
-    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
-    rope = rotary_pos_emb.build_rope(
-        input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
-    )
-    mask = [self.get_attention_mask(
-        self.config.block_config(i).attn_config.attn_type, input_pos
-    ) for i in range(self.config.num_layers)]
-
-    return self._forward_with_embeds(
-        input_embeds, rope, mask, input_pos, kv_cache, export_config
-    )
-
-  def _forward_with_embeds(
-      self,
-      input_embeds: torch.Tensor,
-      rope: Tuple[torch.Tensor, torch.Tensor],
-      mask: List[torch.Tensor],
-      input_pos: torch.Tensor,
-      kv_cache: kv_utils.KVCache,
-      export_config: Optional[model_builder.ExportConfig] = None,
-  ) -> dict[torch.Tensor, kv_utils.KVCache]:
-    """Forwards the model with input embeddings."""
     assert len(self.transformer_blocks) == len(kv_cache.caches), (
         "The number of transformer blocks and the number of KV cache entries"
         " must be the same."
     )
 
-    if self.config.embedding_scale is not None:
-      input_embeds = input_embeds * self.config.embedding_scale
-    x = input_embeds
-    updated_kv_entries = []
+    cos, sin = self.rope_cache
+    cos = cos.index_select(0, input_pos)
+    sin = sin.index_select(0, input_pos)
+
+    # token embeddings of shape (b, t, n_embd)
+    x = self.tok_embedding(tokens)
+    x = x * (self.config.embedding_dim**0.5)
+
+    updated_kv_entires = []
     for i, block in enumerate(self.transformer_blocks):
+      mask = self.get_attention_mask(
+          block.config.attn_config.attn_type, input_pos
+      )
       kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, rope, mask[i], input_pos, kv_entry)
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entries.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
     if export_config is not None:
       if (
@@ -243,13 +228,11 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   )
 
   num_layers = 26
-  embedding_dim = 2304
   config = cfg.ModelConfig(
       vocab_size=256000,
       num_layers=num_layers,
       max_seq_len=8192,
-      embedding_dim=embedding_dim,
-      embedding_scale=embedding_dim**0.5,
+      embedding_dim=2304,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
      final_norm_config=norm_config,
@@ -266,7 +249,6 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
  config.num_layers = 2
  config.max_seq_len = 2 * kv_cache_max_len
  config.embedding_dim = 128
-  config.embedding_scale = config.embedding_dim**0.5
  config.block_configs = config.block_configs[: config.num_layers]
  for block_config in config.block_configs:
    block_config.attn_config.num_heads = 4
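`build_rope_cache` itself is not shown in this diff; a rough stand-in for what the constructor relies on (an assumption inferred from how `cos`/`sin` are consumed via `index_select(0, input_pos)` and later widened with `repeat(1, 2)`):

    import torch

    def build_rope_cache_sketch(size: int, dim: int, base: int = 10_000):
      """Hypothetical equivalent of attn_utils.build_rope_cache.

      Returns per-position (cos, sin) tables of shape (size, dim // 2) that
      can be row-indexed with input_pos, matching the usage above.
      """
      theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
      pos = torch.arange(size).float()
      radians = torch.outer(pos, theta)  # (size, dim // 2)
      return torch.cos(radians), torch.sin(radians)

    cos, sin = build_rope_cache_sketch(size=1024, dim=64)
    input_pos = torch.tensor([0, 1, 2, 3])
    cos, sin = cos.index_select(0, input_pos), sin.index_select(0, input_pos)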
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py CHANGED
@@ -72,14 +72,14 @@ class ToyModelWithKVCache(torch.nn.Module):
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.max_seq_len]
 
-    updated_kv_entries = []
+    updated_kv_entires = []
     for i, block in enumerate(self.transformer_blocks):
       kv_entry = kv_cache.caches[i] if kv_cache else None
       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entries.append(kv_entry)
+        updated_kv_entires.append(kv_entry)
 
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
     if export_config is not None:
       if (
ai_edge_torch/generative/layers/attention.py CHANGED
@@ -26,6 +26,33 @@ import torch
 from torch import nn
 
 
+def _embed_rope(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    n_elem: int,
+    rope: Tuple[torch.Tensor, torch.Tensor],
+) -> Tuple[torch.Tensor, torch.Tensor]:
+  """Embed rotary positional embedding for query and key.
+
+  Args:
+    q (torch.Tensor): query tensor.
+    k (torch.Tensor): key tensor.
+    n_elem (int): number of elements to embed rotary positional embedding.
+    rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
+  """
+  if n_elem > 0:
+    cos, sin = rope
+    q_roped = rotary_pos_emb.apply_rope(
+        q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
+    )
+    k_roped = rotary_pos_emb.apply_rope(
+        k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
+    )
+    q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
+    k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
+  return q, k
+
+
 class TransformerBlock(nn.Module):
 
   def __init__(
@@ -211,8 +238,7 @@ class CausalSelfAttention(nn.Module):
     if rope is not None:
       # Compute rotary positional embedding for query and key.
       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-      cos, sin = rope
-      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
+      q, k = _embed_rope(q, k, n_elem, rope)
 
     if kv_cache is not None:
       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
@@ -348,8 +374,7 @@ class CrossAttention(nn.Module):
     if rope is not None:
       # Compute rotary positional embedding for query and key.
       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-      cos, sin = rope
-      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
+      q, k = _embed_rope(q, k, n_elem, rope)
 
     if kv_cache is not None:
       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
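`_embed_rope` applies RoPE only to the first `n_elem` channels of each head, which is what makes `rotary_percentage < 1.0` work: the remaining channels are concatenated back unchanged. A small shape sketch, under an assumed (B, T, H, head_dim) layout:

    import torch

    B, T, H, D = 1, 8, 4, 64
    n_elem = int(0.5 * D)  # rotary_percentage = 0.5: only 32 channels rotate

    q = torch.randn(B, T, H, D)
    cos = torch.randn(T, n_elem // 2)  # half-width tables from the rope cache
    sin = torch.randn(T, n_elem // 2)

    # repeat(1, 2) widens (T, n_elem/2) -> (T, n_elem) to match q[..., :n_elem];
    # channels q[..., n_elem:] bypass RoPE entirely.
    assert cos.repeat(1, 2).shape == (T, n_elem)
    assert q[..., :n_elem].shape == (B, T, H, n_elem)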
ai_edge_torch/generative/layers/normalization.py CHANGED
@@ -80,6 +80,7 @@ class RMSNorm(torch.nn.Module):
     output = self._norm(x.float()).type_as(x)
     return output * w
 
+
 class GroupNorm(torch.nn.Module):
 
   def __init__(
@@ -115,16 +116,7 @@ class GroupNorm(torch.nn.Module):
     Returns:
       torch.Tensor: output tensor after applying GroupNorm.
     """
-    if self.enable_hlfb:
-      return group_norm_with_hlfb(
-          x,
-          self.weight,
-          self.bias,
-          self.group_num,
-          self.eps,
-      )
-    else:
-      return F.group_norm(x, self.group_num, self.weight, self.bias, self.eps)
+    return F.group_norm(x, self.group_num, self.weight, self.bias, self.eps)
 
 
 class LayerNorm(torch.nn.Module):
@@ -169,46 +161,6 @@ class LayerNorm(torch.nn.Module):
     )
 
 
-def group_norm_with_hlfb(
-    x: torch.Tensor,
-    w: torch.Tensor,
-    b: torch.Tensor,
-    num_groups: int,
-    eps: float,
-):
-  """Group Normalization with high-level function boundary enabled.
-
-  Args:
-    x (torch.Tensor): Input tensor for Group Normalization, with BCHW shape.
-    w (torch.Tensor): The weight tensor for the normalization.
-    b (torch.Tensor): The bias tensor for the normalization.
-    num_groups (int): Number of groups to separate the channels into.
-    eps (float): A small float value to ensure numerical stability.
-
-  Returns:
-    The output tensor of Group Normalization.
-  """
-  x = torch.permute(x, (0, 2, 3, 1))
-
-  builder = StableHLOCompositeBuilder(
-      name="odml.group_norm",
-      attr={
-          "num_groups": num_groups,
-          "epsilon": eps,
-          "reduction_axes": [3],
-          "channel_axis": 3,
-      },
-  )
-  x, w, b = builder.mark_inputs(x, w, b)
-  x = torch.permute(x, (0, 3, 1, 2))
-  y = F.group_norm(x, num_groups, weight=w, bias=b, eps=eps)
-  y = torch.permute(y, (0, 2, 3, 1))
-  y = builder.mark_outputs(y)
-
-  y = torch.permute(y, (0, 3, 1, 2))
-  return y
-
-
 def rms_norm_with_hlfb(
     x: torch.Tensor,
     w: torch.Tensor,
ai_edge_torch/generative/layers/rotary_position_embedding.py CHANGED
@@ -32,64 +32,57 @@ def apply_rope(
   """
   x = x.transpose(1, 2)
   head_size = x.size(-1)
-  x1, x2 = torch.split(x, head_size // 2, dim=-1)
-  left = x1 * cos - x2 * sin
-  right = x2 * cos + x1 * sin
-  roped = torch.cat([left, right], dim=-1)
+  x1 = x[..., : head_size // 2]  # (B, nh, T, hs/2)
+  x2 = x[..., head_size // 2 :]  # (B, nh, T, hs/2)
+  rotated = torch.cat((-x2, x1), dim=-1)  # (B, nh, T, hs)
+  roped = (x * cos) + (rotated * sin)
   return roped.transpose(1, 2).type_as(x)
 
 
-def build_rope(
+def apply_rope_inline(
+    q: torch.Tensor,
+    k: torch.Tensor,
     input_pos: torch.Tensor,
     n_elem: int,
-    head_dim: int,
     base: int = 10_000,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Computes rotary positional embedding cosine and sine tensors.
+  """Computes rotary positional embedding inline for a query and key.
 
   Args:
+    q: the query tensor.
+    k: the key tensor.
     input_pos: the sequence indices for the query and key
     n_elem: number of elements of the head dimension for RoPE computation
-    base: the base of the exponentiated value for RoPE.
 
   Returns:
-    cos, sin tensors
+    output the RoPE'd query and key.
   """
 
   if n_elem <= 0:
-    return None, None
+    return q, k
 
   theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
   freq_exponents = (2.0 / n_elem) * torch.arange(
-      head_dim // 2, dtype=torch.float32
+      q.shape[-1] // 2, dtype=torch.float32
   )
   timescale = float(base) ** freq_exponents
   radians = input_pos.clone().unsqueeze(0).unsqueeze(-1) / timescale.unsqueeze(
       0
   ).unsqueeze(0)
-  cos = torch.cos(radians)
-  sin = torch.sin(radians)
-  return cos, sin
-
+  cos = torch.cos(radians).type_as(q)
+  sin = torch.sin(radians).type_as(q)
 
-def apply_rope_inline(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    cos: torch.Tensor,
-    sin: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Computes rotary positional embedding inline for a query and key.
-
-  Args:
-    q: the query tensor.
-    k: the key tensor.
-    cos: the cosine tensor.
-    sin: the sine tensor.
-
-  Returns:
-    output the RoPE'd query and key.
-  """
+  def apply(x, sin, cos):
+    x = x.transpose(1, 2)
+    b, h, s, d = x.shape
+    ans = torch.split(x, d // 2, dim=-1)
+    x1, x2 = ans
+    left = x1 * cos - x2 * sin
+    right = x2 * cos + x1 * sin
+    res = torch.cat([left, right], dim=-1)
+    res = res.transpose(1, 2)
+    return res
 
-  q_roped = apply_rope(q, cos, sin)
-  k_roped = apply_rope(k, cos, sin)
+  q_roped = apply(q, sin, cos)
+  k_roped = apply(k, sin, cos)
   return q_roped, k_roped
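The rewritten `apply_rope` (rotate-half form) and the removed split/concat form agree whenever `cos`/`sin` carry the same values in both halves of the last dimension, which is exactly what `cos.repeat(1, 2)` in `_embed_rope` guarantees. A quick equivalence check (standalone sketch, not library code):

    import torch

    def rope_split(x, cos, sin):
      # the removed formulation: split, rotate the halves explicitly, re-concat
      x1, x2 = torch.split(x, x.size(-1) // 2, dim=-1)
      return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)

    def rope_rotate_half(x, cos2, sin2):
      # the new formulation: cos2/sin2 are half-width tables repeated to full width
      d = x.size(-1)
      rotated = torch.cat((-x[..., d // 2 :], x[..., : d // 2]), dim=-1)
      return x * cos2 + rotated * sin2

    x = torch.randn(2, 8, 16)
    cos, sin = torch.randn(8, 8), torch.randn(8, 8)
    lhs = rope_split(x, cos, sin)
    rhs = rope_rotate_half(x, cos.repeat(1, 2), sin.repeat(1, 2))
    assert torch.allclose(lhs, rhs, atol=1e-6)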
ai_edge_torch/generative/utilities/model_builder.py CHANGED
@@ -24,7 +24,6 @@ from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
 from torch import nn
@@ -86,6 +85,13 @@ class DecoderOnlyModel(nn.Module):
         config.embedding_dim,
         config.final_norm_config,
     )
+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = config.block_config(0).attn_config
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.kv_cache_max,
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
+        base=attn_config.rotary_base,
+    )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
     )
@@ -107,22 +113,16 @@ class DecoderOnlyModel(nn.Module):
 
     # token embeddings of shape (b, t, n_embd)
     input_embeds = self.tok_embedding(tokens)
-
-    # ROPE parameters for all attn_configs are the same. Take the first one.
-    attn_config = self.config.block_config(0).attn_config
-    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
-    rope = rotary_pos_emb.build_rope(
-        input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
-    )
-
+    cos, sin = self.rope_cache
+    rope = (cos.index_select(0, input_pos), sin.index_select(0, input_pos))
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.kv_cache_max]
 
-    return self._forward_with_embeds(
+    return self.forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache, export_config
     )
 
-  def _forward_with_embeds(
+  def forward_with_embeds(
       self,
      input_embeds: torch.Tensor,
      rope: Tuple[torch.Tensor, torch.Tensor],
@@ -141,13 +141,13 @@ class DecoderOnlyModel(nn.Module):
     if self.config.embedding_scale is not None:
       x = x * self.config.embedding_scale
 
-    updated_kv_entries = []
+    updated_kv_entires = []
     for i, block in enumerate(self.transformer_blocks):
       kv_entry = kv_cache.caches[i] if kv_cache else None
       x, kv_entry = block(x, rope, mask, input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entries.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
     if export_config is not None:
       if (
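Renaming `_forward_with_embeds` to `forward_with_embeds` makes the embeddings entry point public, so callers that already have embeddings (e.g. a multimodal prefix) can skip token lookup. A hypothetical helper mirroring the body of `forward` above:

    import torch

    def prefill_with_embeds(model, input_embeds, kv_cache, export_config=None):
      """Sketch only: drives DecoderOnlyModel.forward_with_embeds directly.

      input_embeds: (1, seq_len, embedding_dim), bypassing model.tok_embedding.
      """
      seq_len = input_embeds.shape[1]
      input_pos = torch.arange(seq_len)
      cos, sin = model.rope_cache
      rope = (cos.index_select(0, input_pos), sin.index_select(0, input_pos))
      mask = model.mask_cache.index_select(2, input_pos)
      mask = mask[:, :, :, : model.config.kv_cache_max]
      return model.forward_with_embeds(
          input_embeds, rope, mask, input_pos, kv_cache, export_config
      )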
ai_edge_torch/generative/utilities/verifier.py CHANGED
@@ -16,7 +16,7 @@
 """Common utility functions to verify the reauthored models."""
 
 import logging
-from typing import Any,List
+from typing import Any, List, Optional
 
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.utilities.model_builder import ExportConfig
@@ -134,7 +134,7 @@ class ReauthoredModelWrapper(ModelWrapper):
       prompts: torch.Tensor,
       max_new_tokens: int,
       pixel_values: torch.Tensor = None,
-      eos_token_id: int = 1,
+      eos_token_id: Optional[int] = None,
   ) -> torch.IntTensor:
     input_ids = prompts[0].int().tolist()
     tokens = torch.tensor([input_ids])
@@ -146,7 +146,7 @@ class ReauthoredModelWrapper(ModelWrapper):
       )
       generated_token = logits[0][-1].argmax().item()
       input_ids.append(generated_token)
-      if generated_token == eos_token_id:
+      if eos_token_id is not None and generated_token == eos_token_id:
         break
       tokens = torch.tensor([[generated_token]])
      input_pos = torch.tensor([len(input_ids) - 1])
@@ -253,7 +253,7 @@ def verify_model_with_prompts(
   outputs_reauthored = reauthored_model.generate(
       prompt_tokens,
       max_new_tokens,
-      eos_token_id=tokenizer.tokenizer.eos_token_id,
+      eos_token_id=getattr(tokenizer.tokenizer, "eos_token_id", None),
   )
   response_reauthored = tokenizer.decode(outputs_reauthored[0])
   logging.info("outputs from reauthored model: [[%s]]", response_reauthored)
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.3.0.dev20250105"
+__version__ = "0.3.0.dev20250107"
ai_edge_torch_nightly-0.3.0.dev20250107.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20250105
+Version: 0.3.0.dev20250107
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
ai_edge_torch_nightly-0.3.0.dev20250107.dist-info/RECORD CHANGED
@@ -1,9 +1,9 @@
 ai_edge_torch/__init__.py,sha256=rq9ZtMJLG8yYNC4tNE4rpl94UAUClZW7f4GAr6HBVDQ,1208
-ai_edge_torch/_config.py,sha256=QIrerb6uHMahRvMilmhodJ_6jfiRps3qgLOBeidPnS4,1614
+ai_edge_torch/_config.py,sha256=PKtOtBOup-cM0wBdQxby6HzuhLhIC3oq-TBG8FF4znE,2161
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=rEruohWdKGtxlBLh9SF_NnC4pbAqrOU4MKG598yJRHY,706
+ai_edge_torch/version.py,sha256=X0ZEB5T3xcR8MsIE8VOHDAdHnCZTzJLBQQ9j2xZ4_qA,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=_PoH0E1gbbsWhLGwDRwUtW2G_IgNzNF7pKQbn9ct6-4,5778
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -16,9 +16,9 @@ ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=qb4J
 ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=NxT-iCOHq3r3jeZ8qhNoPXV5w8l2eRMu4yEcBri3NxY,2398
 ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=f1IUVWyhioOClsMiZzLyynoW2R17U83vA-7Q-3pGPM4,2126
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=lxnoH-WGLeiQIF8XjMGodjiZEFTxucl7g05N7MR9OPk,796
-ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=hDsl9AHzmyuSWsdHOSO114l4nBUgUdAOUWafMTipMgA,7629
+ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=S_Bniv6jY16oOoFUzlyECQ0I2HDjG2D1MOI-QYPk3jQ,8061
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=4RyGUwR22bZqkn_TnptenFJodc_Q43f4_SBG7gmTbos,1621
-ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=NW37V6QYdPOZOVhqLcmssVk-VAeO4ECk_CrbEBh4B0E,12740
+ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=zoAZ2TXKvxUnWnT11U4tx2uF0J5kkNXydgaW7JzfkXI,13811
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=bsYnudRlXp1PJlu4GF25KSogSkBGQPSaecBrUTONKaw,1031
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=t94Am3iPbYQekg-rrtc-jS_aDWtEgAAj7pAKHrG0-9U,10563
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py,sha256=YLMttMg5PdvXTtQ8lxpKb434UGVvYVALV1-xeuH4UGc,2131
@@ -47,7 +47,7 @@ ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIX
 ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=GhwtQZ1xuMyKJl8qdxU6uKavQnlm5US9xhKJvdmgACc,2309
 ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=hsy4Gd7Inchi0p_Cc5yecH6vr9A7X4MvmQNfTt8N2sQ,2311
 ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=w8oWYibZzvEvCDyp39EYyAWmjgJljhzdYPyFCfAWxZA,3497
-ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=VTM2nO3TqK2d1DyEb2MiHc-Tyw2lMcUXyOhvg0H5ENY,10147
+ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=whQ6DEnmhmj9hd5OyaoEI-FUNJ4m302vY3Swo_IqQcA,9285
 ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
 ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
 ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
@@ -109,7 +109,7 @@ ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=l01oYyJo77INzRwN4xqX
 ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/test_models/convert_toy_model.py,sha256=6-WaNHckq_LlXMVTh8x90MGWeWq2bu_T_XQd3w9FnGg,3261
 ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=4113jZK-Hu3kYop__WTc8Bq-bG6YzQtADbxHtYPEB4w,5036
-ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=C9dzJFK3TybxKpM1vSdLjOKftkJ72DGjr8YR4H7vCe8,4664
+ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=WMl1iuCE8So9FDnxPV0OTMzuPngQUTO61g8rfnBLyB4,4664
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=5rgbTIxHoFg8sTnzrGA_ekT-HJEt9p7Dla7cIY874jU,2338
 ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=mhJ18rb9sxrYRzv1YSzhbNs97oUZck99avZDcUO2oV8,2800
@@ -117,14 +117,14 @@ ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299f
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/attention.py,sha256=_OmamS3f0m_JtW73ljwGLwFPeMLL837JCLY-dJ3iRUg,12453
+ai_edge_torch/generative/layers/attention.py,sha256=aOoVM1hY7qjvzVQI1-m26p_f9qoTLzXXIy8dNtU8xC4,13199
 ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
 ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
 ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
 ai_edge_torch/generative/layers/kv_cache.py,sha256=DhHIggaOQ2IAY4aRuMAuCLWZv1dBz5PYtmOEjkx9EQY,6291
 ai_edge_torch/generative/layers/model_config.py,sha256=viX51T_naJ9sPpPxPoMnSueBPYE2zxWNOD0xn0f-_bM,7510
-ai_edge_torch/generative/layers/normalization.py,sha256=h2btgRHMMjOcyLm8adEmcT0pG6imq4QcWblKJK5MYXA,7479
-ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=zbFTNgQdOT-tcKK1QaIX6fG-50syYwQX_ZbLhg2C98c,2691
+ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
+ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=xxWtlVsGGJkEyXC6PwznubyhJnLPEfSpHOORE_hgxss,2670
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=9jKzOfiBQ66bp1ZnVIAoREIifVNFx4aTlQeYMAx2_pA,29062
@@ -149,12 +149,12 @@ ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5l
 ai_edge_torch/generative/utilities/converter.py,sha256=hIwWUWjgPvWLATtsYYG6RWbFQWhOr2RpPlMrd-4Am9U,5959
 ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
 ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
-ai_edge_torch/generative/utilities/model_builder.py,sha256=S08WNqVKCmxd2QjtMlwETd7J97UnlME_bTKdz5LMkGU,6352
+ai_edge_torch/generative/utilities/model_builder.py,sha256=rfD6INxunvDVdiUfTUxD7yy0dRxL74W7kVmZsxUjpOQ,6379
 ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
 ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
-ai_edge_torch/generative/utilities/verifier.py,sha256=awO-sQrEpsFxIkZw72ysWZenYEmkLOLOuj62o2c7XeQ,11994
+ai_edge_torch/generative/utilities/verifier.py,sha256=6lnBU9Cy5GanB8JWK3-2_VU3PxqunDWGe-SgSLba5Yw,12065
 ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
 ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
 ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
@@ -203,8 +203,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20250105.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20250105.dist-info/METADATA,sha256=d8fPEhT1HG6ZlbX2joNTeIpEQNqth8LduM_W6aQZQn8,1966
-ai_edge_torch_nightly-0.3.0.dev20250105.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
-ai_edge_torch_nightly-0.3.0.dev20250105.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20250105.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20250107.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20250107.dist-info/METADATA,sha256=p2F-coQaq7CbpMOkQLVnpFB01cCKqftVRGZ4dCVu8Ck,1966
+ai_edge_torch_nightly-0.3.0.dev20250107.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.3.0.dev20250107.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20250107.dist-info/RECORD,,