ai-edge-torch-nightly 0.3.0.dev20240921__py3-none-any.whl → 0.3.0.dev20240924__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -161,9 +161,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
161
161
  ),
162
162
  ff_config=cfg.FeedForwardConfig(
163
163
  type=cfg.FeedForwardType.SEQUENTIAL,
164
- activation=cfg.ActivationConfig(
165
- cfg.ActivationType.SILU_GLU, gate_is_front=True
166
- ),
164
+ activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
167
165
  intermediate_size=get_intermediate_size(idx),
168
166
  pre_ff_norm_config=norm_config,
169
167
  ),
@@ -34,6 +34,7 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
34
34
  "The maximum size of the generated tokens.",
35
35
  )
36
36
 
37
+
37
38
  def main(_):
38
39
  checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
39
40
  verifier.log_msg("Loading the original model from", checkpoint)
@@ -23,34 +23,35 @@ from torch import nn
23
23
  import torch.nn.functional as F
24
24
 
25
25
 
26
- def build_glu(
27
- act: Callable[[torch.Tensor], torch.Tensor], gate_is_front: bool = False
28
- ) -> Callable[[torch.Tensor], torch.Tensor]:
29
- """Builds an activation function with GLU (Gated Linear Unit).
26
+ class GeGLU(nn.Module):
27
+ """GeGLU is an activation function which is a variant of GELU.
30
28
 
31
- If gate_is_front is True,
32
- f(x) = act(x) * y
33
- otherwise,
34
- f(x) = x * act(y),
35
- where x is the first half of the input and y is the second half of the input.
29
+ GeGLU(x) = (xW+b) * GELU(xV+c)
30
+ See: https://arxiv.org/abs/2002.05202v1
31
+ """
36
32
 
37
- Args:
38
- act (Callable[[torch.Tensor], torch.Tensor]): activation function to apply
39
- to the gate.
40
- gate_is_front: whether the gate is in front half of the input. Other part is
41
- the output in GLU.
33
+ def __init__(self, d_in: int, d_out: int):
34
+ super().__init__()
35
+ self.proj = nn.Linear(d_in, d_out * 2)
42
36
 
43
- Returns:
44
- A callable activation function with GLU.
37
+ def forward(self, x: torch.Tensor):
38
+ x, gate = self.proj(x).chunk(2, dim=-1)
39
+ return x * F.gelu(gate)
40
+
41
+
42
+ class SwiGLU(nn.Module):
43
+ """SwiGLU is an activation function which is a variant of GLU.
44
+
45
+ SwiGLU is same as SiLU_GLU, because The SiLU function is also known as the
46
+ swish function.
47
+
48
+ SwiGLU(x) = Swish(xW+b) * (xV+c)
49
+ See: https://paperswithcode.com/method/swiglu
45
50
  """
46
51
 
47
- def _glu(x):
52
+ def forward(self, x: torch.Tensor):
48
53
  x, y = x.chunk(2, dim=-1)
49
- if gate_is_front:
50
- return act(x) * y
51
- return x * act(y)
52
-
53
- return _glu
54
+ return F.silu(x) * y
54
55
 
55
56
 
56
57
  def build_norm(dim: int, config: cfg.NormalizationConfig):
@@ -151,10 +152,10 @@ def get_activation(config: cfg.ActivationConfig):
151
152
  # See: https://github.com/hendrycks/GELUs
152
153
  return lambda x: x * F.sigmoid(1.702 * x)
153
154
  elif config.type == cfg.ActivationType.GE_GLU:
154
- return build_glu(F.gelu, config.gate_is_front)
155
+ return GeGLU(config.dim_in, config.dim_out)
155
156
  elif config.type == cfg.ActivationType.RELU:
156
157
  return F.relu
157
158
  elif config.type == cfg.ActivationType.SILU_GLU:
158
- return build_glu(F.silu, config.gate_is_front)
159
+ return SwiGLU()
159
160
  else:
160
161
  raise ValueError("Unsupported activation type.")
@@ -118,9 +118,9 @@ class AttentionConfig:
118
118
  @dataclass
119
119
  class ActivationConfig:
120
120
  type: ActivationType = ActivationType.LINEAR
121
- # Whether to GLU gate is the front part instead of the back part of input
122
- # when ActivationType is `GE_GLU` or `SILU_GLU`.
123
- gate_is_front: bool = False
121
+ # Dimension of input and output, used in GeGLU.
122
+ dim_in: Optional[int] = None
123
+ dim_out: Optional[int] = None
124
124
 
125
125
 
126
126
  @dataclass
@@ -183,8 +183,16 @@ def group_norm_with_hlfb(
183
183
  """
184
184
  x = torch.permute(x, (0, 2, 3, 1))
185
185
 
186
+ # TODO: b/366544750 - Change "reduction_axes" field as an array, rather than
187
+ # int32 when the bug is fixed.
186
188
  builder = StableHLOCompositeBuilder(
187
- name="odml.group_norm", attr={"num_groups": num_groups, "eps": eps}
189
+ name="odml.group_norm",
190
+ attr={
191
+ "num_groups": num_groups,
192
+ "eps": eps,
193
+ "reduction_axes": 3,
194
+ "channel_axis": 3,
195
+ },
188
196
  )
189
197
  x, w, b = builder.mark_inputs(x, w, b)
190
198
  x = torch.permute(x, (0, 3, 1, 2))
@@ -206,7 +214,7 @@ def layer_norm_with_hlfb(
206
214
  """Layer Normalization with high-level function boundary enabled.
207
215
 
208
216
  Args:
209
- x (torch.Tensor): Input tensor for Layer Normalization.
217
+ x (torch.Tensor): Input tensor for Layer Normalization, with BCHW shape.
210
218
  w (torch.Tensor): The weight tensor for the normalization.
211
219
  b (torch.Tensor): The bias tensor for the normalization.
212
220
  eps (float): A small float value to ensure numerical stability.
@@ -216,7 +224,10 @@ def layer_norm_with_hlfb(
216
224
  Returns:
217
225
  The output tensor of Layer Normalization.
218
226
  """
219
- builder = StableHLOCompositeBuilder(name="odml.layer_norm", attr={"eps": eps})
227
+ builder = StableHLOCompositeBuilder(
228
+ name="odml.group_norm",
229
+ attr={"num_groups": 1, "eps": eps, "channel_axis": 1},
230
+ )
220
231
  x, w, b = builder.mark_inputs(x, w, b)
221
232
  if use_input_shape:
222
233
  normalized_shape = x.shape
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20240921"
16
+ __version__ = "0.3.0.dev20240924"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20240921
3
+ Version: 0.3.0.dev20240924
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -30,6 +30,7 @@ Requires-Dist: tabulate
30
30
  Requires-Dist: torch>=2.4.0
31
31
  Requires-Dist: torch-xla>=2.4.0
32
32
  Requires-Dist: tf-nightly>=2.18.0.dev20240722
33
+ Requires-Dist: ai-edge-litert-nightly
33
34
  Requires-Dist: ai-edge-quantizer-nightly
34
35
 
35
36
  Library that supports converting PyTorch models into a .tflite format, which can
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
5
5
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
6
- ai_edge_torch/version.py,sha256=t9zajdsiowClI2fG0RkKVonPF-SUx9UBuUDOEZFU9y4,706
6
+ ai_edge_torch/version.py,sha256=sQUcRP5rShDk3vfblz87j26JciN6PV8S8DJkiiZP5o8,706
7
7
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
8
8
  ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
9
9
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -48,12 +48,12 @@ ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=kSzn1ITJXqrtNQax
48
48
  ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=HBK2d8FcWFoxVDF5zk9sLSbKZEtwZQhX-K_zm4AvQtQ,5160
49
49
  ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
50
50
  ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=85FVEt6cKFP2UzCLC78tAkbwGlGhAArtG7Wa75NxJik,2185
51
- ai_edge_torch/generative/examples/openelm/openelm.py,sha256=gGkHELNrt4xqnu11fCh3sJbZ7OsPyvoiF1J1aKCs5r8,7532
51
+ ai_edge_torch/generative/examples/openelm/openelm.py,sha256=VcU8A0B9nQR-FTPHXqNHSHZzeIZZ_As4yvKZMnoU2P4,7482
52
52
  ai_edge_torch/generative/examples/openelm/verify.py,sha256=QdFKymQSCYFJcYVvA63u5uIsn1YxJ0JZD5UqN6gxraI,2112
53
53
  ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
54
54
  ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=3go690yX6PFeXMdpY7y4JZorAwxX0HT_b_pKZieauvk,2169
55
55
  ai_edge_torch/generative/examples/phi/phi2.py,sha256=YwAszA53aOjvaMJ5wua2-5rP79N21Un_Y5yBCfFSYNU,6189
56
- ai_edge_torch/generative/examples/phi/verify.py,sha256=QPYX6weEZGMEXt_Vb2hNARPAECQBKzx-KCivd4dzOrw,2145
56
+ ai_edge_torch/generative/examples/phi/verify.py,sha256=SwPyRjiupD4AsmWW_7FDcMSWaNRmDBu6uVFcBQRoM40,2146
57
57
  ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
58
58
  ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=zPrDTDeRVWFi9DS32uNi-RLpzOStFOk5MhNla4ixeew,2179
59
59
  ai_edge_torch/generative/examples/smollm/smollm.py,sha256=hyhMk-b5762Q2xmjdD47g85dcbBSNJXNPIsifm1DRto,3239
@@ -89,11 +89,11 @@ ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkD
89
89
  ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
90
90
  ai_edge_torch/generative/layers/attention.py,sha256=Z0Y_G8IG0LmvLX2u9D8__Fkr22szB-az6wMNnZpzhkA,13233
91
91
  ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
92
- ai_edge_torch/generative/layers/builder.py,sha256=toT9Tl1x9o5KbG-eGOEViUr4fd_4f-XLZdMQT0Ae5_8,5130
92
+ ai_edge_torch/generative/layers/builder.py,sha256=oE8DdqLA-oWkBC2zySSCh8JNAJg_hk8-W_UoMSrgDVk,5088
93
93
  ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
94
94
  ai_edge_torch/generative/layers/kv_cache.py,sha256=2El7kZYnQRCRcVc63xgiAdBh9oVOksDu35p9XggvaGE,6148
95
- ai_edge_torch/generative/layers/model_config.py,sha256=d0Y-EFb4Rr7iLZ4Bsdf1i92KuhY1BXRqyeUN2kuu510,6923
96
- ai_edge_torch/generative/layers/normalization.py,sha256=l_36uFdruJwqqyubnBTM0M-iGiJfeFafyXKPPK8KHVo,6713
95
+ ai_edge_torch/generative/layers/model_config.py,sha256=l5Rb3h3GK2pux-Lg3BONTD6b7klxXqUbDDtYs_bGKLk,6879
96
+ ai_edge_torch/generative/layers/normalization.py,sha256=LDczSHujMgo1WV8IhYVQe-egPkaBEmWFt8wZQ_tgshg,6991
97
97
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
98
98
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
99
99
  ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -166,8 +166,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
166
166
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
167
167
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
168
168
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
169
- ai_edge_torch_nightly-0.3.0.dev20240921.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
170
- ai_edge_torch_nightly-0.3.0.dev20240921.dist-info/METADATA,sha256=SWy7BhOQDe0_SBF17deNndzt1bEYy7iXUxy0KznIPYM,1859
171
- ai_edge_torch_nightly-0.3.0.dev20240921.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
172
- ai_edge_torch_nightly-0.3.0.dev20240921.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
173
- ai_edge_torch_nightly-0.3.0.dev20240921.dist-info/RECORD,,
169
+ ai_edge_torch_nightly-0.3.0.dev20240924.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
170
+ ai_edge_torch_nightly-0.3.0.dev20240924.dist-info/METADATA,sha256=BotYlw1pMxClnHOi8rSb5v6jX0zE7EqUo8b11xvqEII,1897
171
+ ai_edge_torch_nightly-0.3.0.dev20240924.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
172
+ ai_edge_torch_nightly-0.3.0.dev20240924.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
173
+ ai_edge_torch_nightly-0.3.0.dev20240924.dist-info/RECORD,,