ai-edge-torch-nightly 0.3.0.dev20241216__py3-none-any.whl → 0.3.0.dev20241220__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/conversion.py +5 -1
- ai_edge_torch/_convert/converter.py +8 -0
- ai_edge_torch/generative/examples/gemma/gemma2.py +14 -15
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
- ai_edge_torch/generative/layers/attention.py +4 -29
- ai_edge_torch/generative/layers/rotary_position_embedding.py +34 -27
- ai_edge_torch/generative/utilities/model_builder.py +11 -12
- ai_edge_torch/lowertools/_shim.py +4 -0
- ai_edge_torch/lowertools/odml_torch_utils.py +4 -0
- ai_edge_torch/lowertools/torch_xla_utils.py +6 -0
- ai_edge_torch/odml_torch/export.py +4 -0
- ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py +5 -2
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241216.dist-info → ai_edge_torch_nightly-0.3.0.dev20241220.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241216.dist-info → ai_edge_torch_nightly-0.3.0.dev20241220.dist-info}/RECORD +18 -18
- {ai_edge_torch_nightly-0.3.0.dev20241216.dist-info → ai_edge_torch_nightly-0.3.0.dev20241220.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241216.dist-info → ai_edge_torch_nightly-0.3.0.dev20241220.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241216.dist-info → ai_edge_torch_nightly-0.3.0.dev20241220.dist-info}/top_level.txt +0 -0
| @@ -78,7 +78,8 @@ def convert_signatures( | |
| 78 78 | 
             
                *,
         | 
| 79 79 | 
             
                strict_export: Union[Literal["auto"], bool] = True,
         | 
| 80 80 | 
             
                quant_config: Optional[qcfg.QuantConfig] = None,
         | 
| 81 | 
            -
                _tfl_converter_flags: Optional[dict[str, Any]],
         | 
| 81 | 
            +
                _tfl_converter_flags: Optional[dict[str, Any]] = None,
         | 
| 82 | 
            +
                _saved_model_dir: Optional[str] = None,
         | 
| 82 83 | 
             
            ) -> model.TfLiteModel:
         | 
| 83 84 | 
             
              """Converts a list of `signature.Signature`s and embeds them into one `model.TfLiteModel`.
         | 
| 84 85 |  | 
| @@ -93,6 +94,8 @@ def convert_signatures( | |
| 93 94 | 
             
                  quant_config: User-defined quantization method and scheme of the model.
         | 
| 94 95 | 
             
                  _tfl_converter_flags: A nested dictionary allowing setting flags for the
         | 
| 95 96 | 
             
                    underlying tflite converter.
         | 
| 97 | 
            +
                  _saved_model_dir: Directory for the intermediate saved model. If not
         | 
| 98 | 
            +
                    specified, a random temporary directory would be used.
         | 
| 96 99 |  | 
| 97 100 | 
             
              Returns:
         | 
| 98 101 | 
             
                The converted `model.TfLiteModel` object.
         | 
| @@ -140,6 +143,7 @@ def convert_signatures( | |
| 140 143 | 
             
                  signatures,
         | 
| 141 144 | 
             
                  quant_config=quant_config,
         | 
| 142 145 | 
             
                  _tfl_converter_flags=_tfl_converter_flags,
         | 
| 146 | 
            +
                  _saved_model_dir=_saved_model_dir,
         | 
| 143 147 | 
             
              )
         | 
| 144 148 |  | 
| 145 149 | 
             
              return model.TfLiteModel(tflite_model)
         | 
| @@ -106,6 +106,7 @@ class Converter: | |
| 106 106 | 
             
                  quant_config: Optional[qcfg.QuantConfig] = None,
         | 
| 107 107 | 
             
                  dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
         | 
| 108 108 | 
             
                  _ai_edge_converter_flags: Optional[dict[str, Any]] = None,
         | 
| 109 | 
            +
                  _saved_model_dir: Optional[str] = None,
         | 
| 109 110 | 
             
              ) -> model.TfLiteModel:
         | 
| 110 111 | 
             
                """Finalizes the conversion and produces an edge model.
         | 
| 111 112 |  | 
| @@ -139,6 +140,8 @@ class Converter: | |
| 139 140 | 
             
                    of this function and so needs to be treated as such. Please do not rely
         | 
| 140 141 | 
             
                    on this parameter except for local debugging as this can be removed in a
         | 
| 141 142 | 
             
                    future release.
         | 
| 143 | 
            +
                  _saved_model_dir: Directory for the intermediate saved model. If not
         | 
| 144 | 
            +
                    specified, a random temporary directory would be used.
         | 
| 142 145 |  | 
| 143 146 | 
             
                Returns:
         | 
| 144 147 | 
             
                  The converted edge model.
         | 
| @@ -171,6 +174,7 @@ class Converter: | |
| 171 174 | 
             
                    strict_export=strict_export,
         | 
| 172 175 | 
             
                    quant_config=quant_config,
         | 
| 173 176 | 
             
                    _tfl_converter_flags=_ai_edge_converter_flags,
         | 
| 177 | 
            +
                    _saved_model_dir=_saved_model_dir,
         | 
| 174 178 | 
             
                )
         | 
| 175 179 |  | 
| 176 180 |  | 
| @@ -216,6 +220,7 @@ def convert( | |
| 216 220 | 
             
                quant_config: Optional[qcfg.QuantConfig] = None,
         | 
| 217 221 | 
             
                dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
         | 
| 218 222 | 
             
                _ai_edge_converter_flags: Optional[dict[str, Any]] = None,
         | 
| 223 | 
            +
                _saved_model_dir: Optional[str] = None,
         | 
| 219 224 | 
             
            ) -> model.TfLiteModel:
         | 
| 220 225 | 
             
              """Converts a PyTorch model to an edge model with a default signature.
         | 
| 221 226 |  | 
| @@ -240,6 +245,8 @@ def convert( | |
| 240 245 | 
             
                  this function and so needs to be treated as such. Please do not rely on
         | 
| 241 246 | 
             
                  this parameter except for local debugging as this can be removed in a
         | 
| 242 247 | 
             
                  future release.
         | 
| 248 | 
            +
                _saved_model_dir: Directory for the intermediate saved model. If not
         | 
| 249 | 
            +
                  specified, a random temporary directory would be used.
         | 
| 243 250 |  | 
| 244 251 | 
             
              Returns:
         | 
| 245 252 | 
             
                The converted edge model.
         | 
| @@ -259,4 +266,5 @@ def convert( | |
| 259 266 | 
             
                  quant_config=quant_config,
         | 
| 260 267 | 
             
                  dynamic_shapes=dynamic_shapes,
         | 
| 261 268 | 
             
                  _ai_edge_converter_flags=_ai_edge_converter_flags,
         | 
| 269 | 
            +
                  _saved_model_dir=_saved_model_dir,
         | 
| 262 270 | 
             
              )
         | 
| @@ -22,6 +22,7 @@ from ai_edge_torch.generative.layers import builder | |
| 22 22 | 
             
            from ai_edge_torch.generative.layers import kv_cache as kv_utils
         | 
| 23 23 | 
             
            import ai_edge_torch.generative.layers.attention_utils as attn_utils
         | 
| 24 24 | 
             
            import ai_edge_torch.generative.layers.model_config as cfg
         | 
| 25 | 
            +
            import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
         | 
| 25 26 | 
             
            from ai_edge_torch.generative.utilities import model_builder
         | 
| 26 27 | 
             
            import ai_edge_torch.generative.utilities.loader as loading_utils
         | 
| 27 28 | 
             
            import torch
         | 
| @@ -103,17 +104,12 @@ class Gemma2(nn.Module): | |
| 103 104 | 
             
                    config.embedding_dim,
         | 
| 104 105 | 
             
                    config.final_norm_config,
         | 
| 105 106 | 
             
                )
         | 
| 106 | 
            -
                # Gemma2 has same hyper parameters for each layer except for attention
         | 
| 107 | 
            -
                # types. Use the first layer.
         | 
| 108 | 
            -
                attn_config = config.block_config(0).attn_config
         | 
| 109 | 
            -
                self.rope_cache = attn_utils.build_rope_cache(
         | 
| 110 | 
            -
                    size=config.kv_cache_max,
         | 
| 111 | 
            -
                    dim=int(attn_config.rotary_percentage * attn_config.head_dim),
         | 
| 112 | 
            -
                    base=attn_config.rotary_base,
         | 
| 113 | 
            -
                )
         | 
| 114 107 | 
             
                self.mask_cache = attn_utils.build_causal_mask_cache(
         | 
| 115 108 | 
             
                    size=config.kv_cache_max,
         | 
| 116 109 | 
             
                )
         | 
| 110 | 
            +
                # Gemma2 has same hyper parameters for each layer except for attention
         | 
| 111 | 
            +
                # types. Use the first layer.
         | 
| 112 | 
            +
                attn_config = config.block_config(0).attn_config
         | 
| 117 113 | 
             
                self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
         | 
| 118 114 | 
             
                    size=config.kv_cache_max,
         | 
| 119 115 | 
             
                    window_size=attn_config.sliding_window_size,
         | 
| @@ -145,24 +141,27 @@ class Gemma2(nn.Module): | |
| 145 141 | 
             
                    " must be the same."
         | 
| 146 142 | 
             
                )
         | 
| 147 143 |  | 
| 148 | 
            -
                 | 
| 149 | 
            -
                 | 
| 150 | 
            -
                 | 
| 144 | 
            +
                # RoPE parameters are the same for all blocks. Use the first layer.
         | 
| 145 | 
            +
                attn_config = self.config.block_config(0).attn_config
         | 
| 146 | 
            +
                n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
         | 
| 147 | 
            +
                rope = rotary_pos_emb.build_rope(
         | 
| 148 | 
            +
                    input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
         | 
| 149 | 
            +
                )
         | 
| 151 150 |  | 
| 152 151 | 
             
                # token embeddings of shape (b, t, n_embd)
         | 
| 153 152 | 
             
                x = self.tok_embedding(tokens)
         | 
| 154 153 | 
             
                x = x * (self.config.embedding_dim**0.5)
         | 
| 155 154 |  | 
| 156 | 
            -
                 | 
| 155 | 
            +
                updated_kv_entries = []
         | 
| 157 156 | 
             
                for i, block in enumerate(self.transformer_blocks):
         | 
| 158 157 | 
             
                  mask = self.get_attention_mask(
         | 
| 159 158 | 
             
                      block.config.attn_config.attn_type, input_pos
         | 
| 160 159 | 
             
                  )
         | 
| 161 160 | 
             
                  kv_entry = kv_cache.caches[i] if kv_cache else None
         | 
| 162 | 
            -
                  x, kv_entry = block(x,  | 
| 161 | 
            +
                  x, kv_entry = block(x, rope, mask, input_pos, kv_entry)
         | 
| 163 162 | 
             
                  if kv_entry:
         | 
| 164 | 
            -
                     | 
| 165 | 
            -
                updated_kv_cache = kv_utils.KVCache(tuple( | 
| 163 | 
            +
                    updated_kv_entries.append(kv_entry)
         | 
| 164 | 
            +
                updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
         | 
| 166 165 |  | 
| 167 166 | 
             
                if export_config is not None:
         | 
| 168 167 | 
             
                  if (
         | 
| @@ -72,14 +72,14 @@ class ToyModelWithKVCache(torch.nn.Module): | |
| 72 72 | 
             
                mask = self.mask_cache.index_select(2, input_pos)
         | 
| 73 73 | 
             
                mask = mask[:, :, :, : self.config.max_seq_len]
         | 
| 74 74 |  | 
| 75 | 
            -
                 | 
| 75 | 
            +
                updated_kv_entries = []
         | 
| 76 76 | 
             
                for i, block in enumerate(self.transformer_blocks):
         | 
| 77 77 | 
             
                  kv_entry = kv_cache.caches[i] if kv_cache else None
         | 
| 78 78 | 
             
                  x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
         | 
| 79 79 | 
             
                  if kv_entry:
         | 
| 80 | 
            -
                     | 
| 80 | 
            +
                    updated_kv_entries.append(kv_entry)
         | 
| 81 81 |  | 
| 82 | 
            -
                updated_kv_cache = kv_utils.KVCache(tuple( | 
| 82 | 
            +
                updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
         | 
| 83 83 |  | 
| 84 84 | 
             
                if export_config is not None:
         | 
| 85 85 | 
             
                  if (
         | 
| @@ -26,33 +26,6 @@ import torch | |
| 26 26 | 
             
            from torch import nn
         | 
| 27 27 |  | 
| 28 28 |  | 
| 29 | 
            -
            def _embed_rope(
         | 
| 30 | 
            -
                q: torch.Tensor,
         | 
| 31 | 
            -
                k: torch.Tensor,
         | 
| 32 | 
            -
                n_elem: int,
         | 
| 33 | 
            -
                rope: Tuple[torch.Tensor, torch.Tensor],
         | 
| 34 | 
            -
            ) -> Tuple[torch.Tensor, torch.Tensor]:
         | 
| 35 | 
            -
              """Embed rotary positional embedding for query and key.
         | 
| 36 | 
            -
             | 
| 37 | 
            -
              Args:
         | 
| 38 | 
            -
                q (torch.Tensor): query tensor.
         | 
| 39 | 
            -
                k (torch.Tensor): key tensor.
         | 
| 40 | 
            -
                n_elem (int): number of elements to embed rotarty positional embedding.
         | 
| 41 | 
            -
                rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
         | 
| 42 | 
            -
              """
         | 
| 43 | 
            -
              if n_elem > 0:
         | 
| 44 | 
            -
                cos, sin = rope
         | 
| 45 | 
            -
                q_roped = rotary_pos_emb.apply_rope(
         | 
| 46 | 
            -
                    q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
         | 
| 47 | 
            -
                )
         | 
| 48 | 
            -
                k_roped = rotary_pos_emb.apply_rope(
         | 
| 49 | 
            -
                    k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
         | 
| 50 | 
            -
                )
         | 
| 51 | 
            -
                q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
         | 
| 52 | 
            -
                k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
         | 
| 53 | 
            -
              return q, k
         | 
| 54 | 
            -
             | 
| 55 | 
            -
             | 
| 56 29 | 
             
            class TransformerBlock(nn.Module):
         | 
| 57 30 |  | 
| 58 31 | 
             
              def __init__(
         | 
| @@ -238,7 +211,8 @@ class CausalSelfAttention(nn.Module): | |
| 238 211 | 
             
                if rope is not None:
         | 
| 239 212 | 
             
                  # Compute rotary positional embedding for query and key.
         | 
| 240 213 | 
             
                  n_elem = int(self.config.rotary_percentage * self.config.head_dim)
         | 
| 241 | 
            -
                   | 
| 214 | 
            +
                  cos, sin = rope
         | 
| 215 | 
            +
                  q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
         | 
| 242 216 |  | 
| 243 217 | 
             
                if kv_cache is not None:
         | 
| 244 218 | 
             
                  kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
         | 
| @@ -374,7 +348,8 @@ class CrossAttention(nn.Module): | |
| 374 348 | 
             
                if rope is not None:
         | 
| 375 349 | 
             
                  # Compute rotary positional embedding for query and key.
         | 
| 376 350 | 
             
                  n_elem = int(self.config.rotary_percentage * self.config.head_dim)
         | 
| 377 | 
            -
                   | 
| 351 | 
            +
                  cos, sin = rope
         | 
| 352 | 
            +
                  q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
         | 
| 378 353 |  | 
| 379 354 | 
             
                if kv_cache is not None:
         | 
| 380 355 | 
             
                  kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
         | 
| @@ -32,57 +32,64 @@ def apply_rope( | |
| 32 32 | 
             
              """
         | 
| 33 33 | 
             
              x = x.transpose(1, 2)
         | 
| 34 34 | 
             
              head_size = x.size(-1)
         | 
| 35 | 
            -
              x1 = x | 
| 36 | 
            -
               | 
| 37 | 
            -
               | 
| 38 | 
            -
              roped = ( | 
| 35 | 
            +
              x1, x2 = torch.split(x, head_size // 2, dim=-1)
         | 
| 36 | 
            +
              left = x1 * cos - x2 * sin
         | 
| 37 | 
            +
              right = x2 * cos + x1 * sin
         | 
| 38 | 
            +
              roped = torch.cat([left, right], dim=-1)
         | 
| 39 39 | 
             
              return roped.transpose(1, 2).type_as(x)
         | 
| 40 40 |  | 
| 41 41 |  | 
| 42 | 
            -
            def  | 
| 43 | 
            -
                q: torch.Tensor,
         | 
| 44 | 
            -
                k: torch.Tensor,
         | 
| 42 | 
            +
            def build_rope(
         | 
| 45 43 | 
             
                input_pos: torch.Tensor,
         | 
| 46 44 | 
             
                n_elem: int,
         | 
| 45 | 
            +
                head_dim: int,
         | 
| 47 46 | 
             
                base: int = 10_000,
         | 
| 48 47 | 
             
            ) -> Tuple[torch.Tensor, torch.Tensor]:
         | 
| 49 | 
            -
              """Computes rotary positional embedding  | 
| 48 | 
            +
              """Computes rotary positional embedding cosine and sine tensors.
         | 
| 50 49 |  | 
| 51 50 | 
             
              Args:
         | 
| 52 | 
            -
                q: the query tensor.
         | 
| 53 | 
            -
                k: the key tensor.
         | 
| 54 51 | 
             
                input_pos: the sequence indices for the query and key
         | 
| 55 52 | 
             
                n_elem: number of elements of the head dimension for RoPE computation
         | 
| 53 | 
            +
                base: the base of the exponentiated value for RoPE.
         | 
| 56 54 |  | 
| 57 55 | 
             
              Returns:
         | 
| 58 | 
            -
                 | 
| 56 | 
            +
                cos, sin tensors
         | 
| 59 57 | 
             
              """
         | 
| 60 58 |  | 
| 61 59 | 
             
              if n_elem <= 0:
         | 
| 62 | 
            -
                return  | 
| 60 | 
            +
                return None, None
         | 
| 63 61 |  | 
| 64 62 | 
             
              theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
         | 
| 65 63 | 
             
              freq_exponents = (2.0 / n_elem) * torch.arange(
         | 
| 66 | 
            -
                   | 
| 64 | 
            +
                  head_dim // 2, dtype=torch.float32
         | 
| 67 65 | 
             
              )
         | 
| 68 66 | 
             
              timescale = float(base) ** freq_exponents
         | 
| 69 67 | 
             
              radians = input_pos.clone().unsqueeze(0).unsqueeze(-1) / timescale.unsqueeze(
         | 
| 70 68 | 
             
                  0
         | 
| 71 69 | 
             
              ).unsqueeze(0)
         | 
| 72 | 
            -
              cos = torch.cos(radians) | 
| 73 | 
            -
              sin = torch.sin(radians) | 
| 70 | 
            +
              cos = torch.cos(radians)
         | 
| 71 | 
            +
              sin = torch.sin(radians)
         | 
| 72 | 
            +
              return cos, sin
         | 
| 73 | 
            +
             | 
| 74 74 |  | 
| 75 | 
            -
             | 
| 76 | 
            -
                 | 
| 77 | 
            -
                 | 
| 78 | 
            -
                 | 
| 79 | 
            -
                 | 
| 80 | 
            -
             | 
| 81 | 
            -
             | 
| 82 | 
            -
             | 
| 83 | 
            -
             | 
| 84 | 
            -
                 | 
| 75 | 
            +
            def apply_rope_inline(
         | 
| 76 | 
            +
                q: torch.Tensor,
         | 
| 77 | 
            +
                k: torch.Tensor,
         | 
| 78 | 
            +
                cos: torch.Tensor,
         | 
| 79 | 
            +
                sin: torch.Tensor,
         | 
| 80 | 
            +
            ) -> Tuple[torch.Tensor, torch.Tensor]:
         | 
| 81 | 
            +
              """Computes rotary positional embedding inline for a query and key.
         | 
| 82 | 
            +
             | 
| 83 | 
            +
              Args:
         | 
| 84 | 
            +
                q: the query tensor.
         | 
| 85 | 
            +
                k: the key tensor.
         | 
| 86 | 
            +
                cos: the cosine tensor.
         | 
| 87 | 
            +
                sin: the sine tensor.
         | 
| 88 | 
            +
             | 
| 89 | 
            +
              Returns:
         | 
| 90 | 
            +
                output the RoPE'd query and key.
         | 
| 91 | 
            +
              """
         | 
| 85 92 |  | 
| 86 | 
            -
              q_roped =  | 
| 87 | 
            -
              k_roped =  | 
| 93 | 
            +
              q_roped = apply_rope(q, cos, sin)
         | 
| 94 | 
            +
              k_roped = apply_rope(k, cos, sin)
         | 
| 88 95 | 
             
              return q_roped, k_roped
         | 
| @@ -24,6 +24,7 @@ from ai_edge_torch.generative.layers import builder | |
| 24 24 | 
             
            from ai_edge_torch.generative.layers import kv_cache as kv_utils
         | 
| 25 25 | 
             
            import ai_edge_torch.generative.layers.attention_utils as attn_utils
         | 
| 26 26 | 
             
            import ai_edge_torch.generative.layers.model_config as cfg
         | 
| 27 | 
            +
            import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
         | 
| 27 28 | 
             
            import ai_edge_torch.generative.utilities.loader as loading_utils
         | 
| 28 29 | 
             
            import torch
         | 
| 29 30 | 
             
            from torch import nn
         | 
| @@ -85,13 +86,6 @@ class DecoderOnlyModel(nn.Module): | |
| 85 86 | 
             
                    config.embedding_dim,
         | 
| 86 87 | 
             
                    config.final_norm_config,
         | 
| 87 88 | 
             
                )
         | 
| 88 | 
            -
                # ROPE parameters for all attn_configs are the same. Take the first one.
         | 
| 89 | 
            -
                attn_config = config.block_config(0).attn_config
         | 
| 90 | 
            -
                self.rope_cache = attn_utils.build_rope_cache(
         | 
| 91 | 
            -
                    size=config.kv_cache_max,
         | 
| 92 | 
            -
                    dim=int(attn_config.rotary_percentage * attn_config.head_dim),
         | 
| 93 | 
            -
                    base=attn_config.rotary_base,
         | 
| 94 | 
            -
                )
         | 
| 95 89 | 
             
                self.mask_cache = attn_utils.build_causal_mask_cache(
         | 
| 96 90 | 
             
                    size=config.kv_cache_max,
         | 
| 97 91 | 
             
                )
         | 
| @@ -113,11 +107,16 @@ class DecoderOnlyModel(nn.Module): | |
| 113 107 |  | 
| 114 108 | 
             
                # token embeddings of shape (b, t, n_embd)
         | 
| 115 109 | 
             
                input_embeds = self.tok_embedding(tokens)
         | 
| 116 | 
            -
                cos, sin = self.rope_cache
         | 
| 117 | 
            -
                rope = (cos.index_select(0, input_pos), sin.index_select(0, input_pos))
         | 
| 118 110 | 
             
                mask = self.mask_cache.index_select(2, input_pos)
         | 
| 119 111 | 
             
                mask = mask[:, :, :, : self.config.kv_cache_max]
         | 
| 120 112 |  | 
| 113 | 
            +
                # ROPE parameters for all attn_configs are the same. Take the first one.
         | 
| 114 | 
            +
                attn_config = self.config.block_config(0).attn_config
         | 
| 115 | 
            +
                n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
         | 
| 116 | 
            +
                rope = rotary_pos_emb.build_rope(
         | 
| 117 | 
            +
                    input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
         | 
| 118 | 
            +
                )
         | 
| 119 | 
            +
             | 
| 121 120 | 
             
                return self.forward_with_embeds(
         | 
| 122 121 | 
             
                    input_embeds, rope, mask, input_pos, kv_cache, export_config
         | 
| 123 122 | 
             
                )
         | 
| @@ -141,13 +140,13 @@ class DecoderOnlyModel(nn.Module): | |
| 141 140 | 
             
                if self.config.embedding_scale is not None:
         | 
| 142 141 | 
             
                  x = x * self.config.embedding_scale
         | 
| 143 142 |  | 
| 144 | 
            -
                 | 
| 143 | 
            +
                updated_kv_entries = []
         | 
| 145 144 | 
             
                for i, block in enumerate(self.transformer_blocks):
         | 
| 146 145 | 
             
                  kv_entry = kv_cache.caches[i] if kv_cache else None
         | 
| 147 146 | 
             
                  x, kv_entry = block(x, rope, mask, input_pos, kv_entry)
         | 
| 148 147 | 
             
                  if kv_entry:
         | 
| 149 | 
            -
                     | 
| 150 | 
            -
                updated_kv_cache = kv_utils.KVCache(tuple( | 
| 148 | 
            +
                    updated_kv_entries.append(kv_entry)
         | 
| 149 | 
            +
                updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
         | 
| 151 150 |  | 
| 152 151 | 
             
                if export_config is not None:
         | 
| 153 152 | 
             
                  if (
         | 
| @@ -50,6 +50,7 @@ def exported_programs_to_tflite( | |
| 50 50 | 
             
                *,
         | 
| 51 51 | 
             
                quant_config: Optional[qcfg.QuantConfig] = None,
         | 
| 52 52 | 
             
                _tfl_converter_flags: Optional[dict[str, Any]] = None,
         | 
| 53 | 
            +
                _saved_model_dir: Optional[str] = None
         | 
| 53 54 | 
             
            ):
         | 
| 54 55 | 
             
              """Converts a list of ExportedProgram to a TFLite model.
         | 
| 55 56 |  | 
| @@ -57,6 +58,8 @@ def exported_programs_to_tflite( | |
| 57 58 | 
             
                exported_programs: A list of ExportedProgram.
         | 
| 58 59 | 
             
                signatures: A list of Signature.
         | 
| 59 60 | 
             
                quant_config: A QuantConfig.
         | 
| 61 | 
            +
                _saved_model_dir: Directory for the intermediate saved model. If not
         | 
| 62 | 
            +
                  specified, a random temporary directory would be used.
         | 
| 60 63 | 
             
                _tfl_converter_flags: A dict of flags for TFLiteConverter.
         | 
| 61 64 |  | 
| 62 65 | 
             
              Returns:
         | 
| @@ -79,4 +82,5 @@ def exported_programs_to_tflite( | |
| 79 82 | 
             
                  signatures,
         | 
| 80 83 | 
             
                  quant_config=quant_config,
         | 
| 81 84 | 
             
                  _tfl_converter_flags=_tfl_converter_flags,
         | 
| 85 | 
            +
                  _saved_model_dir=_saved_model_dir,
         | 
| 82 86 | 
             
              )
         | 
| @@ -138,6 +138,7 @@ def merged_bundle_to_tfl_model( | |
| 138 138 | 
             
                *,
         | 
| 139 139 | 
             
                quant_config: Optional[qcfg.QuantConfig] = None,
         | 
| 140 140 | 
             
                _tfl_converter_flags: dict = {},
         | 
| 141 | 
            +
                _saved_model_dir: Optional[str] = None,
         | 
| 141 142 | 
             
            ):
         | 
| 142 143 | 
             
              tf_state_dict = merged_bundle.bundles[0].state_dict
         | 
| 143 144 |  | 
| @@ -173,6 +174,9 @@ def merged_bundle_to_tfl_model( | |
| 173 174 | 
             
              # We need to temporarily save since TFLite's from_concrete_functions does not
         | 
| 174 175 | 
             
              # allow providing names for each of the concrete functions.
         | 
| 175 176 | 
             
              with tempfile.TemporaryDirectory() as temp_dir_path:
         | 
| 177 | 
            +
                if _saved_model_dir is not None:
         | 
| 178 | 
            +
                  temp_dir_path = _saved_model_dir
         | 
| 179 | 
            +
             | 
| 176 180 | 
             
                tf.saved_model.save(
         | 
| 177 181 | 
             
                    tf_module,
         | 
| 178 182 | 
             
                    temp_dir_path,
         | 
| @@ -192,6 +192,7 @@ def merged_bundle_to_tfl_model( | |
| 192 192 | 
             
                *,
         | 
| 193 193 | 
             
                quant_config: Optional[qcfg.QuantConfig] = None,
         | 
| 194 194 | 
             
                _tfl_converter_flags: dict = {},
         | 
| 195 | 
            +
                _saved_model_dir: Optional[str] = None,
         | 
| 195 196 | 
             
            ) -> None:
         | 
| 196 197 | 
             
              """Converts a StableHLOGraphModule to a tflite model.
         | 
| 197 198 |  | 
| @@ -200,6 +201,8 @@ def merged_bundle_to_tfl_model( | |
| 200 201 | 
             
                signatures: List of signatures from which names of the signatures is
         | 
| 201 202 | 
             
                extracted.
         | 
| 202 203 | 
             
                quant_config: User-defined quantization method and scheme of the model.
         | 
| 204 | 
            +
                _saved_model_dir: Directory for the intermediate saved model. If not
         | 
| 205 | 
            +
                  specified, a random temporary directory would be used.
         | 
| 203 206 | 
             
                _tfl_converter_flags: A nested dictionary allowing setting flags for the
         | 
| 204 207 | 
             
                underlying tflite converter.
         | 
| 205 208 | 
             
              """
         | 
| @@ -246,6 +249,9 @@ def merged_bundle_to_tfl_model( | |
| 246 249 | 
             
              # We need to temporarily save since TFLite's from_concrete_functions does not
         | 
| 247 250 | 
             
              # allow providing names for each of the concrete functions.
         | 
| 248 251 | 
             
              with tempfile.TemporaryDirectory() as temp_dir_path:
         | 
| 252 | 
            +
                if _saved_model_dir is not None:
         | 
| 253 | 
            +
                  temp_dir_path = _saved_model_dir
         | 
| 254 | 
            +
             | 
| 249 255 | 
             
                tf.saved_model.save(
         | 
| 250 256 | 
             
                    tf_module,
         | 
| 251 257 | 
             
                    temp_dir_path,
         | 
| @@ -304,9 +304,13 @@ def exported_program_to_mlir( | |
| 304 304 | 
             
              )
         | 
| 305 305 |  | 
| 306 306 | 
             
              _convert_i64_to_i32(exported_program)
         | 
| 307 | 
            +
             | 
| 307 308 | 
             
              exported_program = _torch_future.safe_run_decompositions(
         | 
| 308 309 | 
             
                  exported_program, lowerings.decompositions()
         | 
| 309 310 | 
             
              )
         | 
| 311 | 
            +
             | 
| 312 | 
            +
              # Passes below mutate the exported program to a state not executable by torch.
         | 
| 313 | 
            +
              # Do not call run_decompositions after applying the passes.
         | 
| 310 314 | 
             
              _convert_q_dq_per_channel_args_to_list(exported_program)
         | 
| 311 315 |  | 
| 312 316 | 
             
              with export_utils.create_ir_context() as context, ir.Location.unknown():
         | 
| @@ -52,10 +52,13 @@ def _uniform_quantized_type( | |
| 52 52 | 
             
                assert isinstance(scale, (list, tuple))
         | 
| 53 53 | 
             
                assert isinstance(zero_point, (list, tuple))
         | 
| 54 54 |  | 
| 55 | 
            +
                scale = list(scale)
         | 
| 56 | 
            +
                zero_point = list(zero_point)
         | 
| 57 | 
            +
             | 
| 55 58 | 
             
                if len(scale) == 1:
         | 
| 56 | 
            -
                  scale  | 
| 59 | 
            +
                  scale = scale * channel_axis_size
         | 
| 57 60 | 
             
                if len(zero_point) == 1:
         | 
| 58 | 
            -
                  zero_point  | 
| 61 | 
            +
                  zero_point = zero_point * channel_axis_size
         | 
| 59 62 |  | 
| 60 63 | 
             
                assert len(scale) == len(zero_point) == channel_axis_size
         | 
| 61 64 | 
             
                scale_zp_strs = []
         | 
    
        ai_edge_torch/version.py
    CHANGED
    
    
| @@ -1,6 +1,6 @@ | |
| 1 1 | 
             
            Metadata-Version: 2.1
         | 
| 2 2 | 
             
            Name: ai-edge-torch-nightly
         | 
| 3 | 
            -
            Version: 0.3.0. | 
| 3 | 
            +
            Version: 0.3.0.dev20241220
         | 
| 4 4 | 
             
            Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
         | 
| 5 5 | 
             
            Home-page: https://github.com/google-ai-edge/ai-edge-torch
         | 
| 6 6 | 
             
            Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
         | 
| @@ -3,11 +3,11 @@ ai_edge_torch/_config.py,sha256=QIrerb6uHMahRvMilmhodJ_6jfiRps3qgLOBeidPnS4,1614 | |
| 3 3 | 
             
            ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
         | 
| 4 4 | 
             
            ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
         | 
| 5 5 | 
             
            ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
         | 
| 6 | 
            -
            ai_edge_torch/version.py,sha256= | 
| 6 | 
            +
            ai_edge_torch/version.py,sha256=xD-MWAEa1ROHhyF3rY7MaL28xsuON0aJwaiXbJ04qfc,706
         | 
| 7 7 | 
             
            ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         | 
| 8 | 
            -
            ai_edge_torch/_convert/conversion.py,sha256= | 
| 8 | 
            +
            ai_edge_torch/_convert/conversion.py,sha256=SzbR16V2JEfkCjjPwRVAFUbFnzu-_1iHPKgGT9Yz7gQ,5678
         | 
| 9 9 | 
             
            ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
         | 
| 10 | 
            -
            ai_edge_torch/_convert/converter.py,sha256= | 
| 10 | 
            +
            ai_edge_torch/_convert/converter.py,sha256=075F8LRewk_033Ebsnft7FJr3KgtIbtZ_-8udIPy6ho,9980
         | 
| 11 11 | 
             
            ai_edge_torch/_convert/signature.py,sha256=rGpBNss3Y9FCCCcdBwDo16KqavJi8N5P0M_6Wnmoz6Y,2320
         | 
| 12 12 | 
             
            ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
         | 
| 13 13 | 
             
            ai_edge_torch/_convert/fx_passes/__init__.py,sha256=NVe-eGcm7j8jZpP2pcMhC8j5dVjgR1pPzyXhHdvKH4E,1267
         | 
| @@ -47,7 +47,7 @@ ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIX | |
| 47 47 | 
             
            ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=GhwtQZ1xuMyKJl8qdxU6uKavQnlm5US9xhKJvdmgACc,2309
         | 
| 48 48 | 
             
            ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=hsy4Gd7Inchi0p_Cc5yecH6vr9A7X4MvmQNfTt8N2sQ,2311
         | 
| 49 49 | 
             
            ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=N0jKVZA3qWKOaHVbIM3WmQh3u0Sq7MTw_oO3Zo16wCw,3456
         | 
| 50 | 
            -
            ai_edge_torch/generative/examples/gemma/gemma2.py,sha256= | 
| 50 | 
            +
            ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=roEwWVXASbk5BFj7jojjEJpHui6gCelT51l-TtN_ZaQ,9367
         | 
| 51 51 | 
             
            ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
         | 
| 52 52 | 
             
            ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
         | 
| 53 53 | 
             
            ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
         | 
| @@ -107,7 +107,7 @@ ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=l01oYyJo77INzRwN4xqX | |
| 107 107 | 
             
            ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         | 
| 108 108 | 
             
            ai_edge_torch/generative/examples/test_models/convert_toy_model.py,sha256=6-WaNHckq_LlXMVTh8x90MGWeWq2bu_T_XQd3w9FnGg,3261
         | 
| 109 109 | 
             
            ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=4113jZK-Hu3kYop__WTc8Bq-bG6YzQtADbxHtYPEB4w,5036
         | 
| 110 | 
            -
            ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256= | 
| 110 | 
            +
            ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=C9dzJFK3TybxKpM1vSdLjOKftkJ72DGjr8YR4H7vCe8,4664
         | 
| 111 111 | 
             
            ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         | 
| 112 112 | 
             
            ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=5rgbTIxHoFg8sTnzrGA_ekT-HJEt9p7Dla7cIY874jU,2338
         | 
| 113 113 | 
             
            ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=mhJ18rb9sxrYRzv1YSzhbNs97oUZck99avZDcUO2oV8,2800
         | 
| @@ -115,14 +115,14 @@ ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299f | |
| 115 115 | 
             
            ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
         | 
| 116 116 | 
             
            ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
         | 
| 117 117 | 
             
            ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         | 
| 118 | 
            -
            ai_edge_torch/generative/layers/attention.py,sha256= | 
| 118 | 
            +
            ai_edge_torch/generative/layers/attention.py,sha256=_OmamS3f0m_JtW73ljwGLwFPeMLL837JCLY-dJ3iRUg,12453
         | 
| 119 119 | 
             
            ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
         | 
| 120 120 | 
             
            ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
         | 
| 121 121 | 
             
            ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
         | 
| 122 122 | 
             
            ai_edge_torch/generative/layers/kv_cache.py,sha256=DhHIggaOQ2IAY4aRuMAuCLWZv1dBz5PYtmOEjkx9EQY,6291
         | 
| 123 123 | 
             
            ai_edge_torch/generative/layers/model_config.py,sha256=viX51T_naJ9sPpPxPoMnSueBPYE2zxWNOD0xn0f-_bM,7510
         | 
| 124 124 | 
             
            ai_edge_torch/generative/layers/normalization.py,sha256=h2btgRHMMjOcyLm8adEmcT0pG6imq4QcWblKJK5MYXA,7479
         | 
| 125 | 
            -
            ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256= | 
| 125 | 
            +
            ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=zbFTNgQdOT-tcKK1QaIX6fG-50syYwQX_ZbLhg2C98c,2691
         | 
| 126 126 | 
             
            ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
         | 
| 127 127 | 
             
            ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         | 
| 128 128 | 
             
            ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=9jKzOfiBQ66bp1ZnVIAoREIifVNFx4aTlQeYMAx2_pA,29062
         | 
| @@ -147,7 +147,7 @@ ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5l | |
| 147 147 | 
             
            ai_edge_torch/generative/utilities/converter.py,sha256=hIwWUWjgPvWLATtsYYG6RWbFQWhOr2RpPlMrd-4Am9U,5959
         | 
| 148 148 | 
             
            ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
         | 
| 149 149 | 
             
            ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
         | 
| 150 | 
            -
            ai_edge_torch/generative/utilities/model_builder.py,sha256= | 
| 150 | 
            +
            ai_edge_torch/generative/utilities/model_builder.py,sha256=q82-1E2zYlzpbFW6Vw-MWrJivRXHKpRh8jUxpR-w0sY,6349
         | 
| 151 151 | 
             
            ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
         | 
| 152 152 | 
             
            ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
         | 
| 153 153 | 
             
            ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
         | 
| @@ -160,16 +160,16 @@ ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=NP2mYhe5D2GjtqQfqqldp-ko3xtNgh | |
| 160 160 | 
             
            ai_edge_torch/hlfb/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         | 
| 161 161 | 
             
            ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=ivq0eVjuf31idfNY0E12F4FxdkSI9hwYXapLJBkIf8Q,4831
         | 
| 162 162 | 
             
            ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
         | 
| 163 | 
            -
            ai_edge_torch/lowertools/_shim.py,sha256= | 
| 163 | 
            +
            ai_edge_torch/lowertools/_shim.py,sha256=Mbg16tnCVK0YsHowfbpqpNX1qySuMLvpGI_-I5SIrG0,3276
         | 
| 164 164 | 
             
            ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjFXIGQbzjgwQo,4504
         | 
| 165 | 
            -
            ai_edge_torch/lowertools/odml_torch_utils.py,sha256= | 
| 165 | 
            +
            ai_edge_torch/lowertools/odml_torch_utils.py,sha256=dxg2pBuVhSZeY2Ouc0F6nHiZilWZmpSPA7I8kGqSkVI,8282
         | 
| 166 166 | 
             
            ai_edge_torch/lowertools/test_utils.py,sha256=mdxTlhqHABZEQ_GEmPFCL8LIAWtqRtYZUGdSY1ieZjw,1949
         | 
| 167 | 
            -
            ai_edge_torch/lowertools/torch_xla_utils.py,sha256= | 
| 167 | 
            +
            ai_edge_torch/lowertools/torch_xla_utils.py,sha256=tH5BW8-Up1uy5Iq1LdXiJInXBh4-YqNXJpSwwy3kwSg,9460
         | 
| 168 168 | 
             
            ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1K5vHK0RjgBFtxlE,5652
         | 
| 169 169 | 
             
            ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
         | 
| 170 170 | 
             
            ai_edge_torch/odml_torch/_torch_future.py,sha256=AJ0klpsbu2ZBTfiZlqSOoaYzBVITt40a1fYN8xKkEPw,3044
         | 
| 171 171 | 
             
            ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
         | 
| 172 | 
            -
            ai_edge_torch/odml_torch/export.py,sha256= | 
| 172 | 
            +
            ai_edge_torch/odml_torch/export.py,sha256=Wc_JM7U2IjZeBmXA6t1AZxREGOWjZ6EB-PIhEevWWeU,13207
         | 
| 173 173 | 
             
            ai_edge_torch/odml_torch/export_utils.py,sha256=QeA37Irlty6AiIBuqmHmJgn3lqahBQ5xsh6IKRoKm1g,4774
         | 
| 174 174 | 
             
            ai_edge_torch/odml_torch/tf_integration.py,sha256=lTFJPPEijLPFmn6qq2jbpVTQOo0YaOTK36kK6rCiyIE,5956
         | 
| 175 175 | 
             
            ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
         | 
| @@ -187,7 +187,7 @@ ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_ | |
| 187 187 | 
             
            ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
         | 
| 188 188 | 
             
            ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=OVmlPGwyhDXKhmG4SAeEsa6iLpJHEHV_jKqwfjYvetA,11643
         | 
| 189 189 | 
             
            ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
         | 
| 190 | 
            -
            ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256= | 
| 190 | 
            +
            ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
         | 
| 191 191 | 
             
            ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
         | 
| 192 192 | 
             
            ai_edge_torch/odml_torch/lowerings/decomp.py,sha256=UoJeZVcr4zAN_11i-HzfOhxGCxUm-7b1JXPVBxR2hSs,2414
         | 
| 193 193 | 
             
            ai_edge_torch/odml_torch/lowerings/registry.py,sha256=Tp2h11l5uTifO0aIkuUOWAF_ibEjmd65Xx99w3EXuGE,1924
         | 
| @@ -200,8 +200,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9 | |
| 200 200 | 
             
            ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         | 
| 201 201 | 
             
            ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
         | 
| 202 202 | 
             
            ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
         | 
| 203 | 
            -
            ai_edge_torch_nightly-0.3.0. | 
| 204 | 
            -
            ai_edge_torch_nightly-0.3.0. | 
| 205 | 
            -
            ai_edge_torch_nightly-0.3.0. | 
| 206 | 
            -
            ai_edge_torch_nightly-0.3.0. | 
| 207 | 
            -
            ai_edge_torch_nightly-0.3.0. | 
| 203 | 
            +
            ai_edge_torch_nightly-0.3.0.dev20241220.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
         | 
| 204 | 
            +
            ai_edge_torch_nightly-0.3.0.dev20241220.dist-info/METADATA,sha256=PfyYhqbf7VEibw2TEDRb8tBOIPG9dfXhT9tNNou_iZg,1966
         | 
| 205 | 
            +
            ai_edge_torch_nightly-0.3.0.dev20241220.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
         | 
| 206 | 
            +
            ai_edge_torch_nightly-0.3.0.dev20241220.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
         | 
| 207 | 
            +
            ai_edge_torch_nightly-0.3.0.dev20241220.dist-info/RECORD,,
         | 
| 
            File without changes
         | 
| 
            File without changes
         |