ai-edge-torch-nightly 0.3.0.dev20240923__py3-none-any.whl → 0.3.0.dev20240924__py3-none-any.whl

ai_edge_torch/generative/examples/openelm/openelm.py CHANGED
@@ -161,9 +161,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
         ),
         ff_config=cfg.FeedForwardConfig(
             type=cfg.FeedForwardType.SEQUENTIAL,
-            activation=cfg.ActivationConfig(
-                cfg.ActivationType.SILU_GLU, gate_is_front=True
-            ),
+            activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
             intermediate_size=get_intermediate_size(idx),
             pre_ff_norm_config=norm_config,
         ),
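Note: the `gate_is_front=True` argument can be dropped here without changing OpenELM's behavior, because the new `SwiGLU` class (see the builder.py hunk below) always gates with the front half of the input, which is exactly what the flag used to select. A minimal sketch of that equivalence, assuming only plain torch (the `old_silu_glu` helper is my reconstruction of the removed `build_glu(F.silu, gate_is_front=True)` path):

import torch
from torch import nn
import torch.nn.functional as F


class SwiGLU(nn.Module):
  """Copy of the SwiGLU added in builder.py below."""

  def forward(self, x: torch.Tensor):
    x, y = x.chunk(2, dim=-1)
    return F.silu(x) * y


def old_silu_glu(t: torch.Tensor) -> torch.Tensor:
  """What build_glu(F.silu, gate_is_front=True)(t) used to compute."""
  x, y = t.chunk(2, dim=-1)
  return F.silu(x) * y


t = torch.randn(2, 8)
assert torch.allclose(SwiGLU()(t), old_silu_glu(t))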
ai_edge_torch/generative/examples/phi/verify.py CHANGED
@@ -34,6 +34,7 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
     "The maximum size of the generated tokens.",
 )
 
+
 def main(_):
   checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
   verifier.log_msg("Loading the original model from", checkpoint)
ai_edge_torch/generative/layers/builder.py CHANGED
@@ -23,34 +23,35 @@ from torch import nn
 import torch.nn.functional as F
 
 
-def build_glu(
-    act: Callable[[torch.Tensor], torch.Tensor], gate_is_front: bool = False
-) -> Callable[[torch.Tensor], torch.Tensor]:
-  """Builds an activation function with GLU (Gated Linear Unit).
+class GeGLU(nn.Module):
+  """GeGLU is an activation function which is a variant of GELU.
 
-  If gate_is_front is True,
-    f(x) = act(x) * y
-  otherwise,
-    f(x) = x * act(y),
-  where x is the first half of the input and y is the second half of the input.
+  GeGLU(x) = (xW+b) * GELU(xV+c)
+  See: https://arxiv.org/abs/2002.05202v1
+  """
 
-  Args:
-    act (Callable[[torch.Tensor], torch.Tensor]): activation function to apply
-      to the gate.
-    gate_is_front: whether the gate is in front half of the input. Other part is
-      the output in GLU.
+  def __init__(self, d_in: int, d_out: int):
+    super().__init__()
+    self.proj = nn.Linear(d_in, d_out * 2)
 
-  Returns:
-    A callable activation function with GLU.
+  def forward(self, x: torch.Tensor):
+    x, gate = self.proj(x).chunk(2, dim=-1)
+    return x * F.gelu(gate)
+
+
+class SwiGLU(nn.Module):
+  """SwiGLU is an activation function which is a variant of GLU.
+
+  SwiGLU is the same as SILU_GLU, because the SiLU function is also known as
+  the Swish function.
+
+  SwiGLU(x) = Swish(xW+b) * (xV+c)
+  See: https://paperswithcode.com/method/swiglu
   """
 
-  def _glu(x):
+  def forward(self, x: torch.Tensor):
     x, y = x.chunk(2, dim=-1)
-    if gate_is_front:
-      return act(x) * y
-    return x * act(y)
-
-  return _glu
+    return F.silu(x) * y
 
 
 def build_norm(dim: int, config: cfg.NormalizationConfig):
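A quick semantics check of the new GeGLU (my own illustration, not part of the patch): the single `proj` linear layer fuses the paper's W and V matrices, and chunking its output recovers the two branches of GeGLU(x) = (xW+b) * GELU(xV+c):

import torch
from torch import nn
import torch.nn.functional as F


class GeGLU(nn.Module):
  """Copy of the GeGLU added in this hunk."""

  def __init__(self, d_in: int, d_out: int):
    super().__init__()
    self.proj = nn.Linear(d_in, d_out * 2)

  def forward(self, x: torch.Tensor):
    x, gate = self.proj(x).chunk(2, dim=-1)
    return x * F.gelu(gate)


m = GeGLU(d_in=16, d_out=32)
# Split the fused projection back into the paper's W/b and V/c halves.
W, V = m.proj.weight.chunk(2, dim=0)
b, c = m.proj.bias.chunk(2, dim=0)
x = torch.randn(2, 16)
manual = (x @ W.T + b) * F.gelu(x @ V.T + c)  # (xW+b) * GELU(xV+c)
assert torch.allclose(m(x), manual, atol=1e-6)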
@@ -151,10 +152,10 @@ def get_activation(config: cfg.ActivationConfig):
     # See: https://github.com/hendrycks/GELUs
     return lambda x: x * F.sigmoid(1.702 * x)
   elif config.type == cfg.ActivationType.GE_GLU:
-    return build_glu(F.gelu, config.gate_is_front)
+    return GeGLU(config.dim_in, config.dim_out)
   elif config.type == cfg.ActivationType.RELU:
     return F.relu
   elif config.type == cfg.ActivationType.SILU_GLU:
-    return build_glu(F.silu, config.gate_is_front)
+    return SwiGLU()
   else:
     raise ValueError("Unsupported activation type.")
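For callers, a minimal usage sketch of the updated get_activation (import paths taken from the RECORD listing below; the shapes are my own example):

import torch
from ai_edge_torch.generative.layers import builder
import ai_edge_torch.generative.layers.model_config as cfg

# SILU_GLU no longer needs any extra config fields.
swiglu = builder.get_activation(
    cfg.ActivationConfig(cfg.ActivationType.SILU_GLU)
)
print(swiglu(torch.randn(2, 32)).shape)  # SwiGLU halves the last dim: (2, 16)

# GE_GLU now takes its projection sizes from the config instead of a
# gate_is_front flag; dim_in/dim_out size the nn.Linear inside GeGLU.
geglu = builder.get_activation(
    cfg.ActivationConfig(cfg.ActivationType.GE_GLU, dim_in=16, dim_out=32)
)
print(geglu(torch.randn(2, 16)).shape)  # (2, 32)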
ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -118,9 +118,9 @@ class AttentionConfig:
 @dataclass
 class ActivationConfig:
   type: ActivationType = ActivationType.LINEAR
-  # Whether to GLU gate is the front part instead of the back part of input
-  # when ActivationType is `GE_GLU` or `SILU_GLU`.
-  gate_is_front: bool = False
+  # Dimension of input and output, used in GeGLU.
+  dim_in: Optional[int] = None
+  dim_out: Optional[int] = None
 
 
 @dataclass
ai_edge_torch/generative/layers/normalization.py CHANGED
@@ -183,8 +183,16 @@ def group_norm_with_hlfb(
   """
   x = torch.permute(x, (0, 2, 3, 1))
 
+  # TODO: b/366544750 - Change the "reduction_axes" field to an array rather
+  # than an int32 once the bug is fixed.
   builder = StableHLOCompositeBuilder(
-      name="odml.group_norm", attr={"num_groups": num_groups, "eps": eps}
+      name="odml.group_norm",
+      attr={
+          "num_groups": num_groups,
+          "eps": eps,
+          "reduction_axes": 3,
+          "channel_axis": 3,
+      },
   )
   x, w, b = builder.mark_inputs(x, w, b)
   x = torch.permute(x, (0, 3, 1, 2))
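The new reduction_axes/channel_axis attributes both record axis 3 because the composite is marked on the NHWC-permuted tensor; a tiny sketch of why (illustrative only, not from the patch):

import torch

x = torch.randn(1, 4, 8, 8)          # NCHW: channels at axis 1
x = torch.permute(x, (0, 2, 3, 1))   # NHWC: channels now at axis 3
assert x.shape[3] == 4               # hence channel_axis=3 and, until the TODO
                                     # above is resolved, the single int32
                                     # reduction_axes=3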
@@ -206,7 +214,7 @@ def layer_norm_with_hlfb(
   """Layer Normalization with high-level function boundary enabled.
 
   Args:
-    x (torch.Tensor): Input tensor for Layer Normalization.
+    x (torch.Tensor): Input tensor for Layer Normalization, with BCHW shape.
     w (torch.Tensor): The weight tensor for the normalization.
     b (torch.Tensor): The bias tensor for the normalization.
     eps (float): A small float value to ensure numerical stability.
@@ -216,7 +224,10 @@ def layer_norm_with_hlfb(
   Returns:
     The output tensor of Layer Normalization.
   """
-  builder = StableHLOCompositeBuilder(name="odml.layer_norm", attr={"eps": eps})
+  builder = StableHLOCompositeBuilder(
+      name="odml.group_norm",
+      attr={"num_groups": 1, "eps": eps, "channel_axis": 1},
+  )
   x, w, b = builder.mark_inputs(x, w, b)
   if use_input_shape:
     normalized_shape = x.shape
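Renaming the composite from odml.layer_norm to odml.group_norm with num_groups=1 leans on a standard identity: with a single group (and ignoring the affine parameters), group norm reduces over all of C, H, W per sample, which is exactly layer norm over the (C, H, W) shape. channel_axis is 1 here because this path keeps the BCHW layout rather than permuting to NHWC. A sanity sketch of the identity, assuming plain torch functional ops:

import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 8, 8)        # BCHW, matching the docstring above
ln = F.layer_norm(x, x.shape[1:])  # normalize over (C, H, W)
gn = F.group_norm(x, num_groups=1) # one group spans the whole sample
assert torch.allclose(ln, gn, atol=1e-5)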
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.3.0.dev20240923"
+__version__ = "0.3.0.dev20240924"
{ai_edge_torch_nightly-0.3.0.dev20240923.dist-info → ai_edge_torch_nightly-0.3.0.dev20240924.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20240923
+Version: 0.3.0.dev20240924
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -30,6 +30,7 @@ Requires-Dist: tabulate
 Requires-Dist: torch>=2.4.0
 Requires-Dist: torch-xla>=2.4.0
 Requires-Dist: tf-nightly>=2.18.0.dev20240722
+Requires-Dist: ai-edge-litert-nightly
 Requires-Dist: ai-edge-quantizer-nightly
 
 Library that supports converting PyTorch models into a .tflite format, which can
{ai_edge_torch_nightly-0.3.0.dev20240923.dist-info → ai_edge_torch_nightly-0.3.0.dev20240924.dist-info}/RECORD RENAMED
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=oxtOOEY9LJkV5vRrgr1EoSjAjuetYVNq7WQqMuauRkc,706
+ai_edge_torch/version.py,sha256=sQUcRP5rShDk3vfblz87j26JciN6PV8S8DJkiiZP5o8,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -48,12 +48,12 @@ ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=kSzn1ITJXqrtNQax
 ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=HBK2d8FcWFoxVDF5zk9sLSbKZEtwZQhX-K_zm4AvQtQ,5160
 ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=85FVEt6cKFP2UzCLC78tAkbwGlGhAArtG7Wa75NxJik,2185
-ai_edge_torch/generative/examples/openelm/openelm.py,sha256=gGkHELNrt4xqnu11fCh3sJbZ7OsPyvoiF1J1aKCs5r8,7532
+ai_edge_torch/generative/examples/openelm/openelm.py,sha256=VcU8A0B9nQR-FTPHXqNHSHZzeIZZ_As4yvKZMnoU2P4,7482
 ai_edge_torch/generative/examples/openelm/verify.py,sha256=QdFKymQSCYFJcYVvA63u5uIsn1YxJ0JZD5UqN6gxraI,2112
 ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=3go690yX6PFeXMdpY7y4JZorAwxX0HT_b_pKZieauvk,2169
 ai_edge_torch/generative/examples/phi/phi2.py,sha256=YwAszA53aOjvaMJ5wua2-5rP79N21Un_Y5yBCfFSYNU,6189
-ai_edge_torch/generative/examples/phi/verify.py,sha256=QPYX6weEZGMEXt_Vb2hNARPAECQBKzx-KCivd4dzOrw,2145
+ai_edge_torch/generative/examples/phi/verify.py,sha256=SwPyRjiupD4AsmWW_7FDcMSWaNRmDBu6uVFcBQRoM40,2146
 ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=zPrDTDeRVWFi9DS32uNi-RLpzOStFOk5MhNla4ixeew,2179
 ai_edge_torch/generative/examples/smollm/smollm.py,sha256=hyhMk-b5762Q2xmjdD47g85dcbBSNJXNPIsifm1DRto,3239
@@ -89,11 +89,11 @@ ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkD
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/layers/attention.py,sha256=Z0Y_G8IG0LmvLX2u9D8__Fkr22szB-az6wMNnZpzhkA,13233
 ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
-ai_edge_torch/generative/layers/builder.py,sha256=toT9Tl1x9o5KbG-eGOEViUr4fd_4f-XLZdMQT0Ae5_8,5130
+ai_edge_torch/generative/layers/builder.py,sha256=oE8DdqLA-oWkBC2zySSCh8JNAJg_hk8-W_UoMSrgDVk,5088
 ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
 ai_edge_torch/generative/layers/kv_cache.py,sha256=2El7kZYnQRCRcVc63xgiAdBh9oVOksDu35p9XggvaGE,6148
-ai_edge_torch/generative/layers/model_config.py,sha256=d0Y-EFb4Rr7iLZ4Bsdf1i92KuhY1BXRqyeUN2kuu510,6923
-ai_edge_torch/generative/layers/normalization.py,sha256=l_36uFdruJwqqyubnBTM0M-iGiJfeFafyXKPPK8KHVo,6713
+ai_edge_torch/generative/layers/model_config.py,sha256=l5Rb3h3GK2pux-Lg3BONTD6b7klxXqUbDDtYs_bGKLk,6879
+ai_edge_torch/generative/layers/normalization.py,sha256=LDczSHujMgo1WV8IhYVQe-egPkaBEmWFt8wZQ_tgshg,6991
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -166,8 +166,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20240923.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20240923.dist-info/METADATA,sha256=BgwLxDJ3AOPVn0fkngAQpf3YdmShufhMt3bANFevtiQ,1859
-ai_edge_torch_nightly-0.3.0.dev20240923.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
-ai_edge_torch_nightly-0.3.0.dev20240923.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20240923.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20240924.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20240924.dist-info/METADATA,sha256=BotYlw1pMxClnHOi8rSb5v6jX0zE7EqUo8b11xvqEII,1897
+ai_edge_torch_nightly-0.3.0.dev20240924.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ai_edge_torch_nightly-0.3.0.dev20240924.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20240924.dist-info/RECORD,,