ai-edge-torch-nightly 0.3.0.dev20240921__py3-none-any.whl → 0.3.0.dev20240924__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/openelm/openelm.py +1 -3
- ai_edge_torch/generative/examples/phi/verify.py +1 -0
- ai_edge_torch/generative/layers/builder.py +25 -24
- ai_edge_torch/generative/layers/model_config.py +3 -3
- ai_edge_torch/generative/layers/normalization.py +14 -3
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240921.dist-info → ai_edge_torch_nightly-0.3.0.dev20240924.dist-info}/METADATA +2 -1
- {ai_edge_torch_nightly-0.3.0.dev20240921.dist-info → ai_edge_torch_nightly-0.3.0.dev20240924.dist-info}/RECORD +11 -11
- {ai_edge_torch_nightly-0.3.0.dev20240921.dist-info → ai_edge_torch_nightly-0.3.0.dev20240924.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240921.dist-info → ai_edge_torch_nightly-0.3.0.dev20240924.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240921.dist-info → ai_edge_torch_nightly-0.3.0.dev20240924.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/openelm/openelm.py CHANGED

```diff
@@ -161,9 +161,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
         ),
         ff_config=cfg.FeedForwardConfig(
             type=cfg.FeedForwardType.SEQUENTIAL,
-            activation=cfg.ActivationConfig(
-                cfg.ActivationType.SILU_GLU, gate_is_front=True
-            ),
+            activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
             intermediate_size=get_intermediate_size(idx),
             pre_ff_norm_config=norm_config,
         ),
```
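Dropping `gate_is_front=True` is safe here because the new `SwiGLU` module introduced in builder.py below always treats the front half of the input as the gate. A minimal sketch of the simplified config, assuming the wheel above is installed; only the activation line is taken from this diff, and the `intermediate_size` value is a placeholder (OpenELM derives it per layer via `get_intermediate_size`):

```python
# Sketch only: values other than the activation are illustrative.
from ai_edge_torch.generative.layers import model_config as cfg

ff_config = cfg.FeedForwardConfig(
    type=cfg.FeedForwardType.SEQUENTIAL,
    activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
    intermediate_size=1024,  # placeholder for illustration
)
```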
ai_edge_torch/generative/layers/builder.py CHANGED

```diff
@@ -23,34 +23,35 @@ from torch import nn
 import torch.nn.functional as F
 
 
-def …(
-    …
-) -> Callable[[torch.Tensor], torch.Tensor]:
-  """Builds an activation function with GLU (Gated Linear Unit).
+class GeGLU(nn.Module):
+  """GeGLU is an activation function which is a variant of GELU.
 
-  …
-  …
-  …
-    f(x) = x * act(y),
-  where x is the first half of the input and y is the second half of the input.
+  GeGLU(x) = (xW+b) * GELU(xV+c)
+  See: https://arxiv.org/abs/2002.05202v1
+  """
 
-  …
-  …
-  …
-    gate_is_front: whether the gate is in front half of the input. Other part is
-      the output in GLU.
+  def __init__(self, d_in: int, d_out: int):
+    super().__init__()
+    self.proj = nn.Linear(d_in, d_out * 2)
 
-  …
-  …
+  def forward(self, x: torch.Tensor):
+    x, gate = self.proj(x).chunk(2, dim=-1)
+    return x * F.gelu(gate)
+
+
+class SwiGLU(nn.Module):
+  """SwiGLU is an activation function which is a variant of GLU.
+
+  SwiGLU is same as SiLU_GLU, because The SiLU function is also known as the
+  swish function.
+
+  SwiGLU(x) = Swish(xW+b) * (xV+c)
+  See: https://paperswithcode.com/method/swiglu
   """
 
-  def _glu(x):
+  def forward(self, x: torch.Tensor):
     x, y = x.chunk(2, dim=-1)
-    if gate_is_front:
-      return act(x) * y
-    return x * act(y)
-
-  return _glu
+    return F.silu(x) * y
 
 
 def build_norm(dim: int, config: cfg.NormalizationConfig):
@@ -151,10 +152,10 @@ def get_activation(config: cfg.ActivationConfig):
     # See: https://github.com/hendrycks/GELUs
     return lambda x: x * F.sigmoid(1.702 * x)
   elif config.type == cfg.ActivationType.GE_GLU:
-    return …
+    return GeGLU(config.dim_in, config.dim_out)
   elif config.type == cfg.ActivationType.RELU:
     return F.relu
   elif config.type == cfg.ActivationType.SILU_GLU:
-    return …
+    return SwiGLU()
   else:
     raise ValueError("Unsupported activation type.")
```
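For orientation, a hedged sketch of how the two new modules behave once the wheel above is installed: `GeGLU` owns a `Linear` projection, so it needs dimensions, while `SwiGLU` is parameter-free and simply halves whatever it is given (the tensor shapes below are illustrative):

```python
import torch
from ai_edge_torch.generative.layers import builder

geglu = builder.GeGLU(d_in=16, d_out=32)  # holds nn.Linear(16, 64) internally
x = torch.randn(2, 16)
print(geglu(x).shape)  # torch.Size([2, 32]): front half * GELU(back half)

swiglu = builder.SwiGLU()  # no weights; expects an input it can split in two
y = torch.randn(2, 64)
print(swiglu(y).shape)  # torch.Size([2, 32]): SiLU(front half) * back half
```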
ai_edge_torch/generative/layers/model_config.py CHANGED

```diff
@@ -118,9 +118,9 @@ class AttentionConfig:
 @dataclass
 class ActivationConfig:
   type: ActivationType = ActivationType.LINEAR
-  # …
-  # …
-  gate_is_front: bool = False
+  # Dimension of input and output, used in GeGLU.
+  dim_in: Optional[int] = None
+  dim_out: Optional[int] = None
 
 
 @dataclass
```
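Because `get_activation` now constructs `GeGLU(config.dim_in, config.dim_out)` (builder.py hunk above), a GE_GLU config must carry its projection dimensions where the old `gate_is_front` flag used to suffice. A sketch of assumed usage; the dimensions are arbitrary examples, not taken from this diff:

```python
# Illustrative dimensions only.
from ai_edge_torch.generative.layers import builder
from ai_edge_torch.generative.layers import model_config as cfg

act_config = cfg.ActivationConfig(
    cfg.ActivationType.GE_GLU, dim_in=768, dim_out=3072
)
act = builder.get_activation(act_config)  # returns a GeGLU(768, 3072) module
```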
ai_edge_torch/generative/layers/normalization.py CHANGED

```diff
@@ -183,8 +183,16 @@ def group_norm_with_hlfb(
   """
   x = torch.permute(x, (0, 2, 3, 1))
 
+  # TODO: b/366544750 - Change "reduction_axes" field as an array, rather than
+  # int32 when the bug is fixed.
   builder = StableHLOCompositeBuilder(
-      name="odml.group_norm", …
+      name="odml.group_norm",
+      attr={
+          "num_groups": num_groups,
+          "eps": eps,
+          "reduction_axes": 3,
+          "channel_axis": 3,
+      },
   )
   x, w, b = builder.mark_inputs(x, w, b)
   x = torch.permute(x, (0, 3, 1, 2))
@@ -206,7 +214,7 @@ def layer_norm_with_hlfb(
   """Layer Normalization with high-level function boundary enabled.
 
   Args:
-    x (torch.Tensor): Input tensor for Layer Normalization.
+    x (torch.Tensor): Input tensor for Layer Normalization, with BCHW shape.
     w (torch.Tensor): The weight tensor for the normalization.
     b (torch.Tensor): The bias tensor for the normalization.
     eps (float): A small float value to ensure numerical stability.
@@ -216,7 +224,10 @@
   Returns:
     The output tensor of Layer Normalization.
   """
-  builder = StableHLOCompositeBuilder(…
+  builder = StableHLOCompositeBuilder(
+      name="odml.group_norm",
+      attr={"num_groups": 1, "eps": eps, "channel_axis": 1},
+  )
   x, w, b = builder.mark_inputs(x, w, b)
   if use_input_shape:
     normalized_shape = x.shape
```
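The attributes recorded on the `odml.group_norm` composite describe the permuted BHWC view of the tensor, where both the channel axis and the (single) reduction axis land on axis 3. A plain-PyTorch sketch of the numerics being marked, independent of the package's `StableHLOCompositeBuilder` machinery; the shapes and `num_groups` are illustrative:

```python
import torch
import torch.nn.functional as F

b, c, h, w = 1, 8, 4, 4
num_groups, eps = 2, 1e-5
x = torch.randn(b, c, h, w)
weight, bias = torch.ones(c), torch.zeros(c)

ref = F.group_norm(x, num_groups, weight, bias, eps)  # BCHW reference

# Staged the way the composite sees it: permute to BHWC, normalize each group
# of channels (now on axis 3) over H, W and the within-group channels, then
# permute back to BCHW.
x_bhwc = x.permute(0, 2, 3, 1)
xg = x_bhwc.reshape(b, h, w, num_groups, c // num_groups)
mean = xg.mean(dim=(1, 2, 4), keepdim=True)
var = xg.var(dim=(1, 2, 4), keepdim=True, unbiased=False)
normed = ((xg - mean) / torch.sqrt(var + eps)).reshape(b, h, w, c)
out = (normed * weight + bias).permute(0, 3, 1, 2)
assert torch.allclose(out, ref, atol=1e-5)
```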
    
ai_edge_torch/version.py CHANGED

```diff
-__version__ = "0.3.0.dev20240921"
+__version__ = "0.3.0.dev20240924"
```
{ai_edge_torch_nightly-0.3.0.dev20240921.dist-info → ai_edge_torch_nightly-0.3.0.dev20240924.dist-info}/METADATA CHANGED

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20240921
+Version: 0.3.0.dev20240924
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -30,6 +30,7 @@ Requires-Dist: tabulate
 Requires-Dist: torch>=2.4.0
 Requires-Dist: torch-xla>=2.4.0
 Requires-Dist: tf-nightly>=2.18.0.dev20240722
+Requires-Dist: ai-edge-litert-nightly
 Requires-Dist: ai-edge-quantizer-nightly
 
 Library that supports converting PyTorch models into a .tflite format, which can
```
{ai_edge_torch_nightly-0.3.0.dev20240921.dist-info → ai_edge_torch_nightly-0.3.0.dev20240924.dist-info}/RECORD CHANGED

```diff
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=…
+ai_edge_torch/version.py,sha256=sQUcRP5rShDk3vfblz87j26JciN6PV8S8DJkiiZP5o8,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -48,12 +48,12 @@ ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=kSzn1ITJXqrtNQax…
 ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=HBK2d8FcWFoxVDF5zk9sLSbKZEtwZQhX-K_zm4AvQtQ,5160
 ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=85FVEt6cKFP2UzCLC78tAkbwGlGhAArtG7Wa75NxJik,2185
-ai_edge_torch/generative/examples/openelm/openelm.py,sha256=…
+ai_edge_torch/generative/examples/openelm/openelm.py,sha256=VcU8A0B9nQR-FTPHXqNHSHZzeIZZ_As4yvKZMnoU2P4,7482
 ai_edge_torch/generative/examples/openelm/verify.py,sha256=QdFKymQSCYFJcYVvA63u5uIsn1YxJ0JZD5UqN6gxraI,2112
 ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=3go690yX6PFeXMdpY7y4JZorAwxX0HT_b_pKZieauvk,2169
 ai_edge_torch/generative/examples/phi/phi2.py,sha256=YwAszA53aOjvaMJ5wua2-5rP79N21Un_Y5yBCfFSYNU,6189
-ai_edge_torch/generative/examples/phi/verify.py,sha256=…
+ai_edge_torch/generative/examples/phi/verify.py,sha256=SwPyRjiupD4AsmWW_7FDcMSWaNRmDBu6uVFcBQRoM40,2146
 ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=zPrDTDeRVWFi9DS32uNi-RLpzOStFOk5MhNla4ixeew,2179
 ai_edge_torch/generative/examples/smollm/smollm.py,sha256=hyhMk-b5762Q2xmjdD47g85dcbBSNJXNPIsifm1DRto,3239
@@ -89,11 +89,11 @@ ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkD…
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/layers/attention.py,sha256=Z0Y_G8IG0LmvLX2u9D8__Fkr22szB-az6wMNnZpzhkA,13233
 ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
-ai_edge_torch/generative/layers/builder.py,sha256=…
+ai_edge_torch/generative/layers/builder.py,sha256=oE8DdqLA-oWkBC2zySSCh8JNAJg_hk8-W_UoMSrgDVk,5088
 ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
 ai_edge_torch/generative/layers/kv_cache.py,sha256=2El7kZYnQRCRcVc63xgiAdBh9oVOksDu35p9XggvaGE,6148
-ai_edge_torch/generative/layers/model_config.py,sha256=…
-ai_edge_torch/generative/layers/normalization.py,sha256=…
+ai_edge_torch/generative/layers/model_config.py,sha256=l5Rb3h3GK2pux-Lg3BONTD6b7klxXqUbDDtYs_bGKLk,6879
+ai_edge_torch/generative/layers/normalization.py,sha256=LDczSHujMgo1WV8IhYVQe-egPkaBEmWFt8wZQ_tgshg,6991
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -166,8 +166,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9…
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20240921.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20240921.dist-info/METADATA,sha256=…
-ai_edge_torch_nightly-0.3.0.dev20240921.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
-ai_edge_torch_nightly-0.3.0.dev20240921.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20240921.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20240924.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20240924.dist-info/METADATA,sha256=BotYlw1pMxClnHOi8rSb5v6jX0zE7EqUo8b11xvqEII,1897
+ai_edge_torch_nightly-0.3.0.dev20240924.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ai_edge_torch_nightly-0.3.0.dev20240924.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20240924.dist-info/RECORD,,
```