ai-edge-torch-nightly 0.3.0.dev20241218__py3-none-any.whl → 0.3.0.dev20241224__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26)
  1. ai_edge_torch/generative/examples/gemma/gemma1.py +3 -2
  2. ai_edge_torch/generative/examples/gemma/gemma2.py +43 -25
  3. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -3
  4. ai_edge_torch/generative/examples/paligemma/decoder.py +14 -5
  5. ai_edge_torch/generative/examples/paligemma/decoder2.py +174 -0
  6. ai_edge_torch/generative/examples/paligemma/paligemma.py +30 -15
  7. ai_edge_torch/generative/examples/paligemma/verify.py +36 -9
  8. ai_edge_torch/generative/examples/paligemma/verify_decoder2.py +72 -0
  9. ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py +24 -7
  10. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
  11. ai_edge_torch/generative/layers/attention.py +4 -29
  12. ai_edge_torch/generative/layers/rotary_position_embedding.py +34 -27
  13. ai_edge_torch/generative/test/test_model_conversion_large.py +28 -9
  14. ai_edge_torch/generative/utilities/model_builder.py +14 -14
  15. ai_edge_torch/generative/utilities/verifier.py +22 -22
  16. ai_edge_torch/odml_torch/export.py +6 -1
  17. ai_edge_torch/odml_torch/jax_bridge/__init__.py +4 -1
  18. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
  19. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +1 -2
  20. ai_edge_torch/odml_torch/lowerings/_rand.py +142 -0
  21. ai_edge_torch/version.py +1 -1
  22. {ai_edge_torch_nightly-0.3.0.dev20241218.dist-info → ai_edge_torch_nightly-0.3.0.dev20241224.dist-info}/METADATA +1 -1
  23. {ai_edge_torch_nightly-0.3.0.dev20241218.dist-info → ai_edge_torch_nightly-0.3.0.dev20241224.dist-info}/RECORD +26 -23
  24. {ai_edge_torch_nightly-0.3.0.dev20241218.dist-info → ai_edge_torch_nightly-0.3.0.dev20241224.dist-info}/LICENSE +0 -0
  25. {ai_edge_torch_nightly-0.3.0.dev20241218.dist-info → ai_edge_torch_nightly-0.3.0.dev20241224.dist-info}/WHEEL +0 -0
  26. {ai_edge_torch_nightly-0.3.0.dev20241218.dist-info → ai_edge_torch_nightly-0.3.0.dev20241224.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/paligemma/verify_decoder2.py
@@ -0,0 +1,72 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Verifies the reauthored decoder of PaliGemma2 3B model."""
+
+ import logging
+
+ from absl import app
+ from absl import flags
+ from ai_edge_torch.generative.examples.paligemma import decoder2
+ from ai_edge_torch.generative.utilities import transformers_verifier
+ from ai_edge_torch.generative.utilities import verifier
+ import kagglehub
+ import transformers
+
+ _PROMPTS = flags.DEFINE_multi_string(
+     "prompts",
+     "What is the meaning of life?",
+     "The input prompts to generate answers.",
+ )
+ _MAX_NEW_TOKENS = flags.DEFINE_integer(
+     "max_new_tokens",
+     30,
+     "The maximum size of the generated tokens.",
+ )
+
+
+ def main(_):
+   checkpoint = kagglehub.model_download(
+       "google/paligemma-2/transformers/paligemma2-3b-pt-224"
+   )
+   logging.info("Loading the original model from: %s", checkpoint)
+   original_full_model = (
+       transformers.PaliGemmaForConditionalGeneration.from_pretrained(checkpoint)
+   )
+   original_language_model = original_full_model.eval().language_model
+
+   logging.info("Building the reauthored model from: %s", checkpoint)
+   reauthored_model = decoder2.build_decoder2(checkpoint)
+
+   logging.info("Loading the tokenizer from: %s", checkpoint)
+   # It works only when GemmaTokenizerFast is available. In some environments,
+   # use_fast=False doesn't work either if the tokenizer cannot load the
+   # sentencepiece model file properly.
+   processor = transformers.AutoProcessor.from_pretrained(checkpoint)
+
+   verifier.verify_reauthored_model(
+       original_model=transformers_verifier.TransformersModelWrapper(
+           original_language_model
+       ),
+       reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+       tokenizer=verifier.TokenizerWrapper(processor.tokenizer),
+       generate_prompts=_PROMPTS.value,
+       max_new_tokens=_MAX_NEW_TOKENS.value,
+       atol=1e-04,
+   )
+
+
+ if __name__ == "__main__":
+   app.run(main)
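
With this new script in place, verifying the reauthored PaliGemma2 decoder should be a matter of running it with the absl flags it defines. A hypothetical invocation (downloading the Kaggle checkpoint may require kagglehub authentication):

    python ai_edge_torch/generative/examples/paligemma/verify_decoder2.py \
        --prompts="What is the meaning of life?" \
        --max_new_tokens=30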

ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py
@@ -20,31 +20,48 @@ import pathlib
  from absl import app
  from absl import flags
  from ai_edge_torch.generative.examples.paligemma import image_encoder
+ import kagglehub
  from PIL import Image
  import requests
  import torch
  import transformers

+ _VERSION = flags.DEFINE_enum(
+     "version",
+     "2",
+     ["1", "2"],
+     "The version of PaliGemma vision model to verify.",
+ )
  _IMAGE_URL = flags.DEFINE_string(
      "image_url",
      "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
      "The image URI to encode.",
  )

+ _CHECKPOINT = {
+     "1": "google/paligemma-3b-mix-224",
+     "2": "google/paligemma-2/transformers/paligemma2-3b-pt-224",
+ }
+

  def main(_):
-   checkpoint = "google/paligemma-3b-mix-224"
+   if _VERSION.value == "1":
+     checkpoint = _CHECKPOINT[_VERSION.value]
+     # Locate the cached dir.
+     cached_config_file = transformers.utils.cached_file(
+         checkpoint, transformers.utils.CONFIG_NAME
+     )
+     reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+   else:
+     checkpoint = kagglehub.model_download(_CHECKPOINT[_VERSION.value])
+     reauthored_checkpoint = checkpoint
+
    logging.info("Loading the original model from: %s", checkpoint)
    original_full_model = (
        transformers.PaliGemmaForConditionalGeneration.from_pretrained(checkpoint)
    )
    original_vision_model = original_full_model.eval().vision_tower

-   # Locate the cached dir.
-   cached_config_file = transformers.utils.cached_file(
-       checkpoint, transformers.utils.CONFIG_NAME
-   )
-   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
    logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
    reauthored_model = image_encoder.build_image_encoder(reauthored_checkpoint)

@@ -69,7 +86,7 @@ def main(_):
    try:
      assert torch.allclose(
-         outputs_original, outputs_reauthored, atol=1e-04, rtol=1e-04
+         outputs_original, outputs_reauthored, atol=1e-03, rtol=1e-04
      )
    except AssertionError as e:
      logging.error("*** FAILED *** verify with an image")

ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py
@@ -72,14 +72,14 @@ class ToyModelWithKVCache(torch.nn.Module):
      mask = self.mask_cache.index_select(2, input_pos)
      mask = mask[:, :, :, : self.config.max_seq_len]

-     updated_kv_entires = []
+     updated_kv_entries = []
      for i, block in enumerate(self.transformer_blocks):
        kv_entry = kv_cache.caches[i] if kv_cache else None
        x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
        if kv_entry:
-         updated_kv_entires.append(kv_entry)
+         updated_kv_entries.append(kv_entry)

-     updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+     updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))

      if export_config is not None:
        if (

ai_edge_torch/generative/layers/attention.py
@@ -26,33 +26,6 @@ import torch
  from torch import nn


- def _embed_rope(
-     q: torch.Tensor,
-     k: torch.Tensor,
-     n_elem: int,
-     rope: Tuple[torch.Tensor, torch.Tensor],
- ) -> Tuple[torch.Tensor, torch.Tensor]:
-   """Embed rotary positional embedding for query and key.
-
-   Args:
-     q (torch.Tensor): query tensor.
-     k (torch.Tensor): key tensor.
-     n_elem (int): number of elements to embed rotarty positional embedding.
-     rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
-   """
-   if n_elem > 0:
-     cos, sin = rope
-     q_roped = rotary_pos_emb.apply_rope(
-         q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
-     )
-     k_roped = rotary_pos_emb.apply_rope(
-         k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
-     )
-     q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
-     k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
-   return q, k
-
-
  class TransformerBlock(nn.Module):

    def __init__(
@@ -238,7 +211,8 @@ class CausalSelfAttention(nn.Module):
      if rope is not None:
        # Compute rotary positional embedding for query and key.
        n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-       q, k = _embed_rope(q, k, n_elem, rope)
+       cos, sin = rope
+       q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)

      if kv_cache is not None:
        kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
@@ -374,7 +348,8 @@ class CrossAttention(nn.Module):
      if rope is not None:
        # Compute rotary positional embedding for query and key.
        n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-       q, k = _embed_rope(q, k, n_elem, rope)
+       cos, sin = rope
+       q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)

      if kv_cache is not None:
        kv_cache = kv_utils.update(kv_cache, input_pos, k, v)

ai_edge_torch/generative/layers/rotary_position_embedding.py
@@ -32,57 +32,64 @@ def apply_rope(
    """
    x = x.transpose(1, 2)
    head_size = x.size(-1)
-   x1 = x[..., : head_size // 2]  # (B, nh, T, hs/2)
-   x2 = x[..., head_size // 2 :]  # (B, nh, T, hs/2)
-   rotated = torch.cat((-x2, x1), dim=-1)  # (B, nh, T, hs)
-   roped = (x * cos) + (rotated * sin)
+   x1, x2 = torch.split(x, head_size // 2, dim=-1)
+   left = x1 * cos - x2 * sin
+   right = x2 * cos + x1 * sin
+   roped = torch.cat([left, right], dim=-1)
    return roped.transpose(1, 2).type_as(x)


- def apply_rope_inline(
-     q: torch.Tensor,
-     k: torch.Tensor,
+ def build_rope(
      input_pos: torch.Tensor,
      n_elem: int,
+     head_dim: int,
      base: int = 10_000,
  ) -> Tuple[torch.Tensor, torch.Tensor]:
-   """Computes rotary positional embedding inline for a query and key.
+   """Computes rotary positional embedding cosine and sine tensors.

    Args:
-     q: the query tensor.
-     k: the key tensor.
      input_pos: the sequence indices for the query and key
      n_elem: number of elements of the head dimension for RoPE computation
+     base: the base of the exponentiated value for RoPE.

    Returns:
-     output the RoPE'd query and key.
+     cos, sin tensors
    """

    if n_elem <= 0:
-     return q, k
+     return None, None

    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
    freq_exponents = (2.0 / n_elem) * torch.arange(
-       q.shape[-1] // 2, dtype=torch.float32
+       head_dim // 2, dtype=torch.float32
    )
    timescale = float(base) ** freq_exponents
    radians = input_pos.clone().unsqueeze(0).unsqueeze(-1) / timescale.unsqueeze(
        0
    ).unsqueeze(0)
-   cos = torch.cos(radians).type_as(q)
-   sin = torch.sin(radians).type_as(q)
+   cos = torch.cos(radians)
+   sin = torch.sin(radians)
+   return cos, sin
+

-   def apply(x, sin, cos):
-     x = x.transpose(1, 2)
-     b, h, s, d = x.shape
-     ans = torch.split(x, d // 2, dim=-1)
-     x1, x2 = ans
-     left = x1 * cos - x2 * sin
-     right = x2 * cos + x1 * sin
-     res = torch.cat([left, right], dim=-1)
-     res = res.transpose(1, 2)
-     return res
+ def apply_rope_inline(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     cos: torch.Tensor,
+     sin: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+   """Computes rotary positional embedding inline for a query and key.
+
+   Args:
+     q: the query tensor.
+     k: the key tensor.
+     cos: the cosine tensor.
+     sin: the sine tensor.
+
+   Returns:
+     output the RoPE'd query and key.
+   """

-   q_roped = apply(q, sin, cos)
-   k_roped = apply(k, sin, cos)
+   q_roped = apply_rope(q, cos, sin)
+   k_roped = apply_rope(k, cos, sin)
    return q_roped, k_roped
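
The net effect of this hunk is that RoPE is split into a position-only build step (build_rope) and a tensor-level apply step (apply_rope_inline), so callers no longer need a precomputed rope cache. A minimal sketch of the new call pattern, assuming a rotary_percentage of 1.0 and purely illustrative tensor shapes:

    import torch
    import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb

    # Illustrative shapes: batch=1, seq_len=4, num_heads=8, head_dim=64.
    q = torch.randn(1, 4, 8, 64)
    k = torch.randn(1, 4, 8, 64)
    input_pos = torch.arange(4)

    # Build cos/sin from the positions used in this forward pass; with
    # rotary_percentage=1.0, n_elem equals head_dim.
    cos, sin = rotary_pos_emb.build_rope(input_pos, n_elem=64, head_dim=64, base=10_000)

    # Apply the same cos/sin to both query and key.
    q_roped, k_roped = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)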

ai_edge_torch/generative/test/test_model_conversion_large.py
@@ -21,6 +21,8 @@ from ai_edge_torch.generative.examples.gemma import gemma1
  from ai_edge_torch.generative.examples.gemma import gemma2
  from ai_edge_torch.generative.examples.llama import llama
  from ai_edge_torch.generative.examples.openelm import openelm
+ from ai_edge_torch.generative.examples.paligemma import decoder
+ from ai_edge_torch.generative.examples.paligemma import decoder2
  from ai_edge_torch.generative.examples.paligemma import paligemma
  from ai_edge_torch.generative.examples.phi import phi2
  from ai_edge_torch.generative.examples.phi import phi3
@@ -171,13 +173,9 @@ class TestModelConversion(googletest.TestCase):
      pytorch_model = amd_llama_135m.AmdLlama(config).eval()
      self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)

-   @googletest.skipIf(
-       ai_edge_torch.config.in_oss,
-       reason="tests with custom ops are not supported in oss",
-   )
-   def disabled_test_paligemma(self):
-     config = paligemma.get_fake_model_config()
-     pytorch_model = paligemma.PaliGemma(config).eval()
+   def _test_paligemma_model(self, decoder_class, decoder_config, atol, rtol):
+     config = paligemma.get_fake_model_config(decoder_config)
+     pytorch_model = paligemma.PaliGemma(config, decoder_class).eval()

      image_embedding_config = config.image_encoder_config.image_embedding
      num_patches = (
@@ -215,11 +213,32 @@ class TestModelConversion(googletest.TestCase):
              kv,
              pixel_values=pixel_values,
              signature_name="prefill_pixel",
-             atol=1e-3,
-             rtol=1e-5,
+             atol=atol,
+             rtol=rtol,
          )
      )

+   @googletest.skipIf(
+       ai_edge_torch.config.in_oss,
+       reason="tests with custom ops are not supported in oss",
+   )
+   def disabled_test_paligemma1(self):
+     self._test_paligemma_model(
+         decoder.Decoder, decoder.get_fake_decoder_config, atol=1e-3, rtol=1e-5
+     )
+
+   @googletest.skipIf(
+       ai_edge_torch.config.in_oss,
+       reason="tests with custom ops are not supported in oss",
+   )
+   def disabled_test_paligemma2(self):
+     self._test_paligemma_model(
+         decoder2.Decoder2,
+         decoder2.get_fake_decoder2_config,
+         atol=1e-3,
+         rtol=1e-5,
+     )
+
    @googletest.skipIf(
        ai_edge_torch.config.in_oss,
        reason="tests with custom ops are not supported in oss",

ai_edge_torch/generative/utilities/model_builder.py
@@ -24,6 +24,7 @@ from ai_edge_torch.generative.layers import builder
  from ai_edge_torch.generative.layers import kv_cache as kv_utils
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
  import ai_edge_torch.generative.layers.model_config as cfg
+ import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
  import ai_edge_torch.generative.utilities.loader as loading_utils
  import torch
  from torch import nn
@@ -85,13 +86,6 @@ class DecoderOnlyModel(nn.Module):
        config.embedding_dim,
        config.final_norm_config,
    )
-   # ROPE parameters for all attn_configs are the same. Take the first one.
-   attn_config = config.block_config(0).attn_config
-   self.rope_cache = attn_utils.build_rope_cache(
-       size=config.kv_cache_max,
-       dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-       base=attn_config.rotary_base,
-   )
    self.mask_cache = attn_utils.build_causal_mask_cache(
        size=config.kv_cache_max,
    )
@@ -113,16 +107,22 @@

    # token embeddings of shape (b, t, n_embd)
    input_embeds = self.tok_embedding(tokens)
-   cos, sin = self.rope_cache
-   rope = (cos.index_select(0, input_pos), sin.index_select(0, input_pos))
+
+   # ROPE parameters for all attn_configs are the same. Take the first one.
+   attn_config = self.config.block_config(0).attn_config
+   n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+   rope = rotary_pos_emb.build_rope(
+       input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
+   )
+
    mask = self.mask_cache.index_select(2, input_pos)
    mask = mask[:, :, :, : self.config.kv_cache_max]

-   return self.forward_with_embeds(
+   return self._forward_with_embeds(
        input_embeds, rope, mask, input_pos, kv_cache, export_config
    )

- def forward_with_embeds(
+ def _forward_with_embeds(
      self,
      input_embeds: torch.Tensor,
      rope: Tuple[torch.Tensor, torch.Tensor],
@@ -141,13 +141,13 @@ class DecoderOnlyModel(nn.Module):
    if self.config.embedding_scale is not None:
      x = x * self.config.embedding_scale

-   updated_kv_entires = []
+   updated_kv_entries = []
    for i, block in enumerate(self.transformer_blocks):
      kv_entry = kv_cache.caches[i] if kv_cache else None
      x, kv_entry = block(x, rope, mask, input_pos, kv_entry)
      if kv_entry:
-       updated_kv_entires.append(kv_entry)
-   updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+       updated_kv_entries.append(kv_entry)
+   updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))

    if export_config is not None:
      if (

ai_edge_torch/generative/utilities/verifier.py
@@ -16,7 +16,7 @@
  """Common utility functions to verify the reauthored models."""

  import logging
- from typing import List
+ from typing import Any, List

  from ai_edge_torch.generative.layers import kv_cache as kv_utils
  from ai_edge_torch.generative.utilities.model_builder import ExportConfig
@@ -87,6 +87,10 @@ class ReauthoredModelWrapper(ModelWrapper):
      """Returns an initialized KV cache."""
      return kv_utils.KVCache.from_model_config(self.model.config)

+   def _get_extra_args_for_forward(self) -> dict[str, Any]:
+     """Returns extra arguments for the forward() method."""
+     return {}
+
    def _forward_with_kv_cache(
        self,
        tokens: torch.Tensor,
@@ -105,26 +109,15 @@ class ReauthoredModelWrapper(ModelWrapper):
      Returns:
        The output logits and the updated KV cache.
      """
-     # Verification requires logit outputs on prefill for comparison.
-     if (
-         self.export_config is not None
-         and not self.export_config.output_logits_on_prefill
-     ):
-       raise ValueError("Verifier requires logit output on prefill.")
-     # Since the reauthored model doesn't include keyword arguments, pass
-     # pixel_values only when it is not None. Otherwise, it may raise an error.
-     if pixel_values is None:
-       output = self.model.forward(
-           tokens, input_pos, kv_cache, export_config=self.export_config
-       )
-     else:
-       output = self.model.forward(
-           tokens,
-           input_pos,
-           kv_cache,
-           pixel_values=pixel_values,
-           export_config=self.export_config,
-       )
+     extra_args = self._get_extra_args_for_forward()
+     if self.export_config is not None:
+       # Verification requires logit outputs on prefill for comparison.
+       if not self.export_config.output_logits_on_prefill:
+         raise ValueError("Verifier requires logit output on prefill.")
+       extra_args["export_config"] = self.export_config
+     if pixel_values is not None:
+       extra_args["pixel_values"] = pixel_values
+     output = self.model.forward(tokens, input_pos, kv_cache, **extra_args)
      return output["logits"], output["kv_cache"]

    def forward(
@@ -141,6 +134,7 @@ class ReauthoredModelWrapper(ModelWrapper):
        prompts: torch.Tensor,
        max_new_tokens: int,
        pixel_values: torch.Tensor = None,
+       eos_token_id: int = 1,
    ) -> torch.IntTensor:
      input_ids = prompts[0].int().tolist()
      tokens = torch.tensor([input_ids])
@@ -152,6 +146,8 @@ class ReauthoredModelWrapper(ModelWrapper):
        )
        generated_token = logits[0][-1].argmax().item()
        input_ids.append(generated_token)
+       if generated_token == eos_token_id:
+         break
        tokens = torch.tensor([[generated_token]])
        input_pos = torch.tensor([len(input_ids) - 1])
        pixel_values = None  # Pass only for the first time.
@@ -254,7 +250,11 @@ def verify_model_with_prompts(
    logging.info("outputs_from_original_model: [[%s]]", response_original)

    logging.info("Generating answer with the reauthored model...")
-   outputs_reauthored = reauthored_model.generate(prompt_tokens, max_new_tokens)
+   outputs_reauthored = reauthored_model.generate(
+       prompt_tokens,
+       max_new_tokens,
+       eos_token_id=tokenizer.tokenizer.eos_token_id,
+   )
    response_reauthored = tokenizer.decode(outputs_reauthored[0])
    logging.info("outputs from reauthored model: [[%s]]", response_reauthored)

ai_edge_torch/odml_torch/export.py
@@ -198,7 +198,12 @@ class MlirLowered:
    # build, which may not have the same StableHLO version as what used in
    # TFLite converter. Therefore we always serialize MLIR module in VHLO.
    # TODO(b/362798610) Build MLIR pybinding in ai-edge-torch release.
-   target_version = stablehlo.get_minimum_version()
+   if stablehlo.get_api_version() < 9:
+     target_version = stablehlo.get_minimum_version()
+   else:
+     target_version = stablehlo.get_version_from_compatibility_requirement(
+         stablehlo.StablehloCompatibilityRequirement.WEEK_4
+     )
    module_bytecode = xla_extension.mlir.serialize_portable_artifact(
        self.module_bytecode, target_version
    )

ai_edge_torch/odml_torch/jax_bridge/__init__.py
@@ -12,4 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- from ai_edge_torch.odml_torch.jax_bridge._wrap import wrap
+ from ai_edge_torch.odml_torch.jax_bridge import _wrap
+ from ai_edge_torch.odml_torch.jax_bridge import utils
+
+ wrap = _wrap.wrap

ai_edge_torch/odml_torch/lowerings/__init__.py
@@ -18,6 +18,7 @@ from . import _convolution
  from . import _jax_lowerings
  from . import _layer_norm
  from . import _quantized_decomposed
+ from . import _rand
  from . import context
  from . import registry
  from . import utils

ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py
@@ -26,6 +26,7 @@ import torch_xla2.ops.ops_registry # Import to load torch_xla2 ops

  LoweringContext = context.LoweringContext

+
  @functools.cache
  def _log_usage(op):
    logging.warning("Use jax lowering: %s", str(op))
@@ -184,8 +185,6 @@ lower_by_torch_xla2(torch.ops.aten.permute_copy)
  lower_by_torch_xla2(torch.ops.aten.pixel_shuffle)
  lower_by_torch_xla2(torch.ops.aten.pow)
  lower_by_torch_xla2(torch.ops.aten.prod)
- lower_by_torch_xla2(torch.ops.aten.rand)
- lower_by_torch_xla2(torch.ops.aten.randn)
  lower_by_torch_xla2(torch.ops.aten.reciprocal)
  lower_by_torch_xla2(torch.ops.aten.reflection_pad1d)
  lower_by_torch_xla2(torch.ops.aten.relu)