ai-edge-torch-nightly 0.3.0.dev20250205__py3-none-any.whl → 0.3.0.dev20250207__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/qwen_vl/decoder.py +49 -4
- ai_edge_torch/generative/examples/qwen_vl/image_encoder.py +8 -5
- ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py +211 -0
- ai_edge_torch/generative/examples/qwen_vl/verify.py +143 -0
- ai_edge_torch/generative/examples/qwen_vl/verify_decoder.py +1 -1
- ai_edge_torch/odml_torch/debuginfo/__init__.py +1 -1
- ai_edge_torch/odml_torch/debuginfo/_build.py +24 -0
- ai_edge_torch/odml_torch/export.py +6 -1
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250205.dist-info → ai_edge_torch_nightly-0.3.0.dev20250207.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250205.dist-info → ai_edge_torch_nightly-0.3.0.dev20250207.dist-info}/RECORD +14 -12
- {ai_edge_torch_nightly-0.3.0.dev20250205.dist-info → ai_edge_torch_nightly-0.3.0.dev20250207.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250205.dist-info → ai_edge_torch_nightly-0.3.0.dev20250207.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250205.dist-info → ai_edge_torch_nightly-0.3.0.dev20250207.dist-info}/top_level.txt +0 -0
 
ai_edge_torch/generative/examples/qwen_vl/decoder.py

@@ -15,16 +15,61 @@

 """Example of building decoder for Qwen 2.5 VL models."""

+from typing import Optional, Tuple
+
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
-from torch import nn
+import torch

 TENSOR_NAMES = model_builder.TENSOR_NAMES


 class Decoder(model_builder.DecoderOnlyModel):
-  """A decoder for Qwen-VL model built from the Edge Generative API layers."""
-
+  """A decoder for Qwen-VL model built from the Edge Generative API layers.
+
+  Besides a tensor of text token IDs, forward() can also take a tensor of
+  embeddings which may include text or image or both.
+  """
+
+  @torch.inference_mode
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+      input_embeds: torch.Tensor = None,
+      rope: Tuple[torch.Tensor, torch.Tensor] = None,
+      mask: Optional[torch.Tensor] = None,
+      export_config: Optional[model_builder.ExportConfig] = None,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    if input_embeds is None:
+      _, seq_len = tokens.size()
+      assert self.config.max_seq_len >= seq_len, (
+          f"Cannot forward sequence of length {seq_len}, max seq length is only"
+          f" {self.config.max_seq_len}"
+      )
+      # token embeddings of shape (b, t, n_embd)
+      input_embeds = self.tok_embedding(tokens)
+
+    if rope is None:
+      # ROPE parameters for all attn_configs are the same. Take the first one.
+      attn_config = self.config.block_config(0).attn_config
+      n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+      rope = self.config.build_rope(input_pos, n_elem, attn_config.rotary_base)
+
+    if mask is None:
+      mask = self.mask_cache.index_select(2, input_pos)
+      mask = mask[:, :, :, : self.config.kv_cache_max]
+
+    return self._forward_with_embeds(
+        input_embeds,
+        rope,
+        mask,
+        input_pos,
+        kv_cache,
+        export_config=export_config,
+    )


 def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
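Note: a minimal usage sketch, not part of the diff. It uses the fake decoder config so the model stays small, and assumes (as in other ai_edge_torch examples) that the config exposes embedding_dim. The point is that the new forward() only reads tokens when input_embeds is None, so a caller such as qwen_vl.py below can pass pre-merged text/image embeddings directly.

import torch
from ai_edge_torch.generative.examples.qwen_vl import decoder
from ai_edge_torch.generative.layers import kv_cache as kv_utils

config = decoder.get_fake_decoder_config()
model = decoder.Decoder(config).eval()

seq_len = 8
# Stand-in for merged text+image embeddings of shape (batch, seq, embedding_dim).
embeds = torch.zeros((1, seq_len, config.embedding_dim))
out = model.forward(
    tokens=None,  # ignored because input_embeds is provided
    input_pos=torch.arange(seq_len),
    kv_cache=kv_utils.KVCache.from_model_config(config),
    input_embeds=embeds,
)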
@@ -82,7 +127,7 @@ def get_fake_decoder_config(**kwargs) -> cfg.ModelConfig:
   return config


-def build_decoder(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_decoder(checkpoint_path: str, **kwargs) -> torch.nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_decoder_config(**kwargs),
ai_edge_torch/generative/examples/qwen_vl/image_encoder.py

@@ -356,6 +356,12 @@ def get_fake_image_encoder_config() -> QwenVLImageConfig:
 def build_image_encoder(checkpoint_path: str) -> QwenVLImageEncoder:
   config = get_image_encoder_config()
   encoder = QwenVLImageEncoder(config)
+  load_image_encoder(checkpoint_path, encoder)
+  encoder.eval()
+  return encoder
+
+
+def load_image_encoder(checkpoint_path: str, encoder: QwenVLImageEncoder):
   loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
   # Loose the strictness because only image encoder is being loaded.
   loader.load(encoder, strict=False)

@@ -365,15 +371,12 @@ def build_image_encoder(checkpoint_path: str) -> QwenVLImageEncoder:
   state = merger_loader.get_state()
   w1_state = dict()
   w1_state["weight"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_up_proj}.weight")
-  if config.merger_config.use_bias:
+  if encoder.config.merger_config.use_bias:
     w1_state["bias"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_up_proj}.bias")
   encoder.merger.w1.load_state_dict(w1_state)

   w2_state = dict()
   w2_state["weight"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_down_proj}.weight")
-  if config.merger_config.use_bias:
+  if encoder.config.merger_config.use_bias:
     w2_state["bias"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_down_proj}.bias")
   encoder.merger.w2.load_state_dict(w2_state)
-
-  encoder.eval()
-  return encoder
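Note: the effect of this refactor is that checkpoint loading becomes reusable. build_image_encoder() still returns a ready encoder, while the new load_image_encoder() can populate an encoder owned by a larger module, which is exactly what qwen_vl.build_model() does below. A sketch of the reuse from outside, not part of the diff (the checkpoint path is a placeholder):

from ai_edge_torch.generative.examples.qwen_vl import image_encoder

# Construct the encoder yourself, then reuse the new loader on it.
encoder = image_encoder.QwenVLImageEncoder(image_encoder.get_image_encoder_config())
image_encoder.load_image_encoder("/path/to/checkpoint", encoder)  # placeholder path
encoder.eval()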
ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py

@@ -0,0 +1,211 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building a full-stack of Qwen 2.5 VL model."""
+
+import dataclasses
+from typing import List, Optional, Tuple
+
+from ai_edge_torch.generative.examples.qwen_vl import decoder
+from ai_edge_torch.generative.examples.qwen_vl import image_encoder
+import ai_edge_torch.generative.layers.kv_cache as kv_utils
+import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities import model_builder
+import ai_edge_torch.generative.utilities.loader as loading_utils
+import torch
+from torch import nn
+
+
+@dataclasses.dataclass
+class QwenVLConfig:
+  """Qwen VL model configurations."""
+
+  image_encoder_config: image_encoder.QwenVLImageConfig
+  decoder_config: cfg.ModelConfig
+  image_token_id: int
+  mrope_section: List[int]
+
+
+class QwenVL(nn.Module):
+  """Qwen VL model from the Edge Generative API."""
+
+  def __init__(self, config: QwenVLConfig):
+    super().__init__()
+
+    self.image_encoder = image_encoder.QwenVLImageEncoder(
+        config.image_encoder_config
+    )
+    self.decoder = decoder.Decoder(config.decoder_config)
+    # The amount of adjustment in input_pos to calculate RoPE properly in
+    # forward() calls after image is handled.
+    self.rope_pos_adjust = 0
+    self.config = config
+
+  @torch.inference_mode
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+      mask: Optional[torch.Tensor] = None,
+      pixel_values: torch.Tensor = None,
+      grid_thw: torch.Tensor = None,
+      export_config: Optional[model_builder.ExportConfig] = None,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    if pixel_values is None:
+      return self.decoder(
+          tokens=tokens,
+          input_pos=input_pos,
+          kv_cache=kv_cache,
+          mask=mask,
+          rope=self._build_text_rope(input_pos),
+          input_embeds=None,
+          export_config=export_config,
+      )
+
+    input_embeds = self.decoder.tok_embedding(tokens)
+    image_embeds = self.image_encoder(pixel_values, grid_thw).unsqueeze(0)
+
+    # Merging image_embeds into text_embeds as PaliGemmaForConditionalGeneration
+    # can be done like:
+    #
+    #   image_mask = tokens == self.config.image_token_id
+    #   image_mask = image_mask.unsqueeze(-1).expand_as(input_embeds)
+    #   input_embeds = input_embeds.masked_scatter(image_mask, image_embeds)
+    #
+    # Unfortunately, torch.Tensor.masked_scatter can't be lowered on CPU.
+    # Assume that image is put at the beginning of the input sequence wrapped
+    # with vision_start and vision_end tokens.
+    input_embeds = torch.cat(
+        (
+            input_embeds[:, :1, :],
+            image_embeds,
+            input_embeds[:, image_embeds.shape[1] + 1 :, :],
+        ),
+        dim=1,
+    )
+
+    return self.decoder(
+        tokens=None,
+        input_pos=input_pos,
+        kv_cache=kv_cache,
+        mask=mask,
+        input_embeds=input_embeds,
+        rope=self._build_multimodal_rope(input_pos, grid_thw),
+        export_config=export_config,
+    )
+
+  def _build_rope(
+      self, rope_pos: torch.Tensor
+  ) -> Tuple[torch.Tensor, torch.Tensor]:
+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = self.config.decoder_config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    return self.config.decoder_config.build_rope(
+        rope_pos, n_elem, attn_config.rotary_base
+    )
+
+  def _build_text_rope(
+      self, input_pos: torch.Tensor
+  ) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Reset rope_pos_adjust to 0 when input sequence starts from scratch, i.e.
+    # input_pos[0] = 0.
+    if input_pos[0] == 0:
+      self.rope_pos_adjust = 0
+    return self._build_rope(input_pos + self.rope_pos_adjust)
+
+  def _build_multimodal_rope(
+      self, input_pos: torch.Tensor, grid_thw: torch.Tensor
+  ) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Builds RoPE of multimodal input for the Qwen VL model.
+
+    It's copied from Qwen2_5_VLForConditionalGeneration.get_rope_index() and
+    simplified based on the assumption that an image is put at the beginning of
+    the input sequence with vision start and vision end tokens.
+    """
+    spatial_merge_size = self.config.image_encoder_config.spatial_merge_size
+    height = grid_thw[0][1] // spatial_merge_size
+    width = grid_thw[0][2] // spatial_merge_size
+    image_pos_max = max(height, width)
+    image_pos_count = height * width
+
+    # The position of the vision end token and text tokens after the image.
+    text_pos_start = image_pos_max + 1
+    text_pos_count = len(input_pos) - image_pos_count - 1
+    text_pos = torch.arange(text_pos_start, text_pos_start + text_pos_count)
+    # Set input_pos_adjust since text_pos_start has changed.
+    self.rope_pos_adjust = image_pos_max - image_pos_count
+
+    temporal_rope = self._build_image_text_rope(
+        torch.ones(image_pos_count, dtype=torch.int), text_pos
+    )
+    height_rope = self._build_image_text_rope(
+        torch.arange(1, height + 1).view(-1, 1).expand(-1, width).flatten(),
+        text_pos,
+    )
+    width_rope = self._build_image_text_rope(
+        torch.arange(1, width + 1).view(1, -1).expand(height, -1).flatten(),
+        text_pos,
+    )
+
+    return (
+        self._merge_ropes(temporal_rope[0], height_rope[0], width_rope[0]),
+        self._merge_ropes(temporal_rope[1], height_rope[1], width_rope[1]),
+    )
+
+  def _build_image_text_rope(
+      self, image_pos: torch.Tensor, text_pos: torch.Tensor
+  ) -> Tuple[torch.Tensor, torch.Tensor]:
+    return self._build_rope(
+        torch.cat((torch.zeros(1, dtype=torch.int), image_pos, text_pos))
+    )
+
+  def _merge_ropes(self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
+    """Merges RoPE tensors based on apply_multimodal_rotary_pos_emb()."""
+    split = torch.stack([a, b, c]).split(self.config.mrope_section, dim=-1)
+    return torch.cat([m[i % 3] for i, m in enumerate(split)], dim=-1)
+
+
+def get_model_config(**kwargs) -> QwenVLConfig:
+  """Returns the model config for a Qwen 2.5 VL model.
+
+  Returns:
+    The model config for a Qwen 2.5 VL model.
+  """
+  return QwenVLConfig(
+      image_encoder_config=image_encoder.get_image_encoder_config(),
+      decoder_config=decoder.get_decoder_config(**kwargs),
+      image_token_id=151655,
+      mrope_section=[16, 24, 24],
+  )
+
+
+def get_fake_model_config(**kwargs) -> QwenVLConfig:
+  return QwenVLConfig(
+      image_encoder_config=image_encoder.get_fake_image_encoder_config(),
+      decoder_config=decoder.get_fake_decoder_config(**kwargs),
+      image_token_id=127,
+  )
+
+
+def build_model(checkpoint_path: str, **kwargs) -> QwenVL:
+  config = get_model_config(**kwargs)
+  model = QwenVL(config)
+  image_encoder.load_image_encoder(checkpoint_path, model.image_encoder)
+  # Load the parameters of decoder.
+  loader = loading_utils.ModelLoader(checkpoint_path, decoder.TENSOR_NAMES)
+  loader.load(model.decoder, strict=False)
+  model.eval()
+  return model
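Note: a small illustration, not part of the diff, of the section-wise merge that _merge_ropes() performs. The cos/sin tensors derived from the temporal, height and width position ids are split into chunks of widths mrope_section = [16, 24, 24] and reassembled by cycling through the three sources, so each band of the head dimension ends up rotated by a different coordinate. The tensors below are stand-ins, not real RoPE values.

import torch

mrope_section = [16, 24, 24]      # as in get_model_config() above
t = torch.zeros(10, 64)           # stand-in "temporal" cos values
h = torch.ones(10, 64)            # stand-in "height" cos values
w = torch.full((10, 64), 2.0)     # stand-in "width" cos values

# Same logic as QwenVL._merge_ropes().
split = torch.stack([t, h, w]).split(mrope_section, dim=-1)
merged = torch.cat([m[i % 3] for i, m in enumerate(split)], dim=-1)

# Columns 0-15 come from t, 16-39 from h, and 40-63 from w.
assert merged.shape == (10, 64)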
ai_edge_torch/generative/examples/qwen_vl/verify.py

@@ -0,0 +1,143 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored Qwen 2.5 VL model."""
+
+import logging
+import pathlib
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.qwen_vl import qwen_vl
+from ai_edge_torch.generative.layers import kv_cache
+from ai_edge_torch.generative.utilities import verifier
+from PIL import Image
+import requests
+import torch
+import transformers
+
+_IMAGE_URL = flags.DEFINE_string(
+    "image_url",
+    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
+    "The image URI to encode.",
+)
+_PROMPTS = flags.DEFINE_string(
+    "prompts",
+    "<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>",
+    "The input prompts to generate answers.",
+)
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
+
+
+class ReauthoredQwenVLWrapper(verifier.ReauthoredModelWrapper):
+  """Reauthored Qwen VL model wrapper."""
+
+  def __init__(self, model: torch.nn.Module):
+    super().__init__(model)
+    self.grid_thw = None
+
+  def _init_kv_cache(self):
+    return kv_cache.KVCache.from_model_config(self.model.config.decoder_config)
+
+  def _get_extra_args_for_forward(self):
+    return {"grid_thw": self.grid_thw}
+
+
+def main(_):
+  checkpoint = "Qwen/Qwen2.5-VL-3B-Instruct"
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = (
+      transformers.Qwen2_5_VLForConditionalGeneration.from_pretrained(
+          checkpoint
+      )
+  )
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = qwen_vl.build_model(reauthored_checkpoint)
+
+  logging.info("Loading the processor from: %s", checkpoint)
+  processor = transformers.AutoProcessor.from_pretrained(checkpoint)
+
+  logging.info("Loading the image from: %s", _IMAGE_URL.value)
+  image = Image.open(requests.get(_IMAGE_URL.value, stream=True).raw)
+  inputs = processor(text=_PROMPTS.value, images=image, return_tensors="pt")
+
+  logging.info("Verifying the reauthored model with model.forward()...")
+  logging.info("Forwarding the original model...")
+  outputs_original = original_model.forward(
+      input_ids=inputs["input_ids"],
+      pixel_values=inputs["pixel_values"],
+      image_grid_thw=inputs["image_grid_thw"],
+  )
+  outputs_original = outputs_original.logits
+  logging.info("outputs_original: %s", outputs_original)
+
+  logging.info("Forwarding the reauthored model...")
+  wrapped_reauthored_model = ReauthoredQwenVLWrapper(reauthored_model)
+  wrapped_reauthored_model.grid_thw = inputs["image_grid_thw"]
+  outputs_reauthored = wrapped_reauthored_model.forward(
+      tokens=inputs["input_ids"],
+      pixel_values=inputs["pixel_values"],
+  )
+  logging.info("outputs_reauthored: %s", outputs_reauthored)
+
+  try:
+    assert torch.allclose(outputs_original, outputs_reauthored, atol=1e-01)
+  except AssertionError as e:
+    logging.error("*** FAILED *** verify with forward()")
+    raise e
+  else:
+    logging.info("*** PASSED *** verify with forward()")
+
+  logging.info("Verifying the reauthored model with model.generate()...")
+  logging.info("Generating answer with the original model...")
+  outputs_original = original_model.generate(
+      **inputs, max_new_tokens=_MAX_NEW_TOKENS.value
+  )
+  response_original = processor.decode(
+      outputs_original[0], skip_special_tokens=True
+  )
+  logging.info("outputs_from_original_model: [[%s]]", response_original)
+
+  logging.info("Generating answer with the reauthored model...")
+  outputs_reauthored = wrapped_reauthored_model.generate(
+      prompts=inputs["input_ids"],
+      pixel_values=inputs["pixel_values"],
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+  )
+  response_reauthored = processor.decode(
+      outputs_reauthored[0], skip_special_tokens=True
+  )
+  logging.info("outputs from reauthored model: [[%s]]", response_reauthored)
+
+  try:
+    assert response_original == response_reauthored
+  except AssertionError as e:
+    logging.error("*** FAILED *** verify with generate()")
+    raise e
+  else:
+    logging.info("*** PASSED *** verify with generate()")
+
+
+if __name__ == "__main__":
+  app.run(main)
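Note: for reference, a sketch (not part of the diff) of what the wrapper above ultimately exercises, namely calling the reauthored QwenVL.forward() directly with processor outputs. The checkpoint path is a placeholder; in verify.py it is resolved from the Hugging Face cache.

import requests
import torch
import transformers
from ai_edge_torch.generative.examples.qwen_vl import qwen_vl
from ai_edge_torch.generative.layers import kv_cache as kv_utils
from PIL import Image

processor = transformers.AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
image = Image.open(
    requests.get(
        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
        stream=True,
    ).raw
)
inputs = processor(
    text="<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>",
    images=image,
    return_tensors="pt",
)

model = qwen_vl.build_model("/path/to/checkpoint")  # placeholder path
outputs = model.forward(
    tokens=inputs["input_ids"],
    input_pos=torch.arange(inputs["input_ids"].shape[1]),
    kv_cache=kv_utils.KVCache.from_model_config(model.config.decoder_config),
    pixel_values=inputs["pixel_values"],
    grid_thw=inputs["image_grid_thw"],
)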
ai_edge_torch/odml_torch/debuginfo/__init__.py

@@ -12,5 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from ._build import build_mlir_debuginfo
+from ._build import build_mlir_debuginfo, build_mlir_file_debuginfo
 from ._op_polyfill import write_mlir_debuginfo_op
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 import torch
+import re


 def _class_fullname(cls):

@@ -34,6 +35,29 @@ def _get_hierarchy(node: torch.fx.Node):
   return hierachy_str


+def _get_canonical_filename(filename):
+  """Remove unnecessary path prefix to make the filename more readable.
+
+  This should be factored out so that pattern is a global option that a user
+  can override.
+  """
+
+  # TODO: We should add a config option to provide a regex to strip from the
+  # debug info. Currently absolute path is used.
+  return filename
+
+
+def build_mlir_file_debuginfo(node: torch.fx.Node):
+  """Build the file and line info for the given node's lowerings in MLIR."""
+
+  if not node.stack_trace:
+    return None, None
+
+  # Note: This uses internal APIs and may break in the future.
+  pt_trace = torch.fx.graph._parse_stack_trace(node.stack_trace)
+  return _get_canonical_filename(pt_trace.file), int(pt_trace.lineno)
+
+
 def build_mlir_debuginfo(node: torch.fx.Node):
   """Build the debuginfo string for the given node's lowerings in MLIR."""

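The new build_mlir_file_debuginfo helper returns a (filename, line) pair for an FX node, or (None, None) when the node has no recorded stack trace; it relies on the internal torch.fx.graph._parse_stack_trace, as the comment in the hunk notes. A hedged sketch of how the helper could be exercised on an exported graph; the toy module and the use of torch.export below are illustrative assumptions, not part of this diff:

import torch

from ai_edge_torch.odml_torch import debuginfo


class TinyModel(torch.nn.Module):

  def forward(self, x):
    return torch.relu(x) + 1


# Export a toy model so its FX nodes carry Python stack traces.
exported = torch.export.export(TinyModel(), (torch.randn(2, 3),))
for node in exported.graph.nodes:
  if node.op == "call_function":
    # (filename, line) of the source that produced this op, or (None, None).
    file, line = debuginfo.build_mlir_file_debuginfo(node)
    print(node.name, file, line)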
@@ -93,7 +93,12 @@ class LoweringInterpreter(torch.fx.Interpreter):
     if info is None:
       return ir.Location.unknown()

-    return ir.Location.name(name=info)
+    (file, line) = debuginfo.build_mlir_file_debuginfo(node)
+    fileinfo = None
+    if file is not None:
+      fileinfo = ir.Location.file(filename=file, line=line, col=0)
+
+    return ir.Location.name(name=info, childLoc=fileinfo)

   def run_node(self, node: torch.fx.Node):
     loc = self._build_loc(node)
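With this change, _build_loc attaches a file location as the child of the existing name location, so each lowered op keeps both the FX module-hierarchy name and the originating Python file and line. A rough sketch of the resulting location structure using the upstream MLIR Python bindings; the `mlir.ir` import path and the literal name/file/line values are assumptions for illustration only:

from mlir import ir

with ir.Context():
  # Child location: the Python source file and line that produced the op.
  fileinfo = ir.Location.file(filename="my_model.py", line=42, col=0)
  # Parent location: the module-hierarchy name, wrapping the file location.
  loc = ir.Location.name(name="TinyModel/relu", childLoc=fileinfo)
  print(loc)  # loc("TinyModel/relu"("my_model.py":42:0))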
         
ai_edge_torch/version.py  CHANGED

{ai_edge_torch_nightly-0.3.0.dev20250205.dist-info → ai_edge_torch_nightly-0.3.0.dev20250207.dist-info}/METADATA  CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20250205
+Version: 0.3.0.dev20250207
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI

{ai_edge_torch_nightly-0.3.0.dev20250205.dist-info → ai_edge_torch_nightly-0.3.0.dev20250207.dist-info}/RECORD  CHANGED
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
 ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=9V9FbxtqLT70Tzmv_G0qlbqixmVc0pPPJs22C_iBlHE,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -94,9 +94,11 @@ ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=tqvXVGNdDehda
 ai_edge_torch/generative/examples/qwen/qwen.py,sha256=Zi_qiQ1JPokXZ95jgSEnQp3F-LKzFCvWvFLKhJjnASo,4199
 ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
 ai_edge_torch/generative/examples/qwen_vl/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
-ai_edge_torch/generative/examples/qwen_vl/decoder.py,sha256=
-ai_edge_torch/generative/examples/qwen_vl/image_encoder.py,sha256=
-ai_edge_torch/generative/examples/qwen_vl/
+ai_edge_torch/generative/examples/qwen_vl/decoder.py,sha256=0x4iDg2cBe3PFnjVce3nj7g2rjagGHcKqRCfbASNxA8,4402
+ai_edge_torch/generative/examples/qwen_vl/image_encoder.py,sha256=OYyF0bLVYJno9azmKDqX3gT8ojYYWEyp_F8nLtltPWs,13544
+ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py,sha256=Uzl1ZPkdYIaHN9QxezqxNwagZiGOHf1VreWnqgRQwf8,7627
+ai_edge_torch/generative/examples/qwen_vl/verify.py,sha256=2GPi0Vay4a69EwBSOfPMCMjE9PTwPOQus5j2KN7HE7I,5031
+ai_edge_torch/generative/examples/qwen_vl/verify_decoder.py,sha256=xPWoOBLh2eK12KEhELLYymfL7xvc0chmYC98c6x37oo,2602
 ai_edge_torch/generative/examples/qwen_vl/verify_image_encoder.py,sha256=lQR8p6Zp7PxDN_erMf-FKLIn_Rv4BGyQHjDbModFkeY,2946
 ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=megskv1oiPhwHSnguoG7zV-esXp1Ns_FPeMLAYKhDb0,2522
@@ -195,14 +197,14 @@ ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1
 ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
 ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
 ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
-ai_edge_torch/odml_torch/export.py,sha256=
+ai_edge_torch/odml_torch/export.py,sha256=LDyZUehM1lmT3y2bGeA94rMGRUTLxzIUm4DTlCA8tQc,13640
 ai_edge_torch/odml_torch/export_utils.py,sha256=QeA37Irlty6AiIBuqmHmJgn3lqahBQ5xsh6IKRoKm1g,4774
 ai_edge_torch/odml_torch/tf_integration.py,sha256=NN29WeXmHZ0S1RPDFHUnBi2DEjMvAtwczStPYIsQ1w8,4849
 ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
 ai_edge_torch/odml_torch/composite/mark_tensor.py,sha256=U--rwl-XkWKgkdXCXDn6yySug8FR66o1YFUAIoSaWW4,3523
 ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py,sha256=2Y52E_gLeoXpMcPpV-svXsgN3JbEIjnPVjm0xkpTUdQ,3319
-ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=
-ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=
+ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=DoE3HgAtV_GNKGBDGzH2Lb7JUHvyH7TUqWbDZIObr34,789
+ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=sjpYeqgdbDmD7lhp80yc8jfWq-HxX3xuQ58ND8ZeU-I,2213
 ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW-1YElU9BPWzKtJA2eKWOI,1739
 ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=e9Oa4J3An9FYr3zM0OzjzyNNitEeg-IoBUGNfUxsDSA,798
 ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
@@ -227,8 +229,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
+ai_edge_torch_nightly-0.3.0.dev20250207.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20250207.dist-info/METADATA,sha256=pvcJfgIOezx3rNegfvMIVrkFXmZuqnnE_zMzC9Wt37k,1966
+ai_edge_torch_nightly-0.3.0.dev20250207.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.3.0.dev20250207.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20250207.dist-info/RECORD,,

File without changes

File without changes