ai-edge-torch-nightly 0.3.0.dev20240817__py3-none-any.whl → 0.3.0.dev20240822__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic. (The link to further details on the original page was not preserved.)

@@ -84,6 +84,9 @@ class TestConvertMultiSignature(googletest.TestCase):
84
84
  )
85
85
  )
86
86
 
87
+ @googletest.skip(
88
+ reason="Re-enable once the tflite converter issue is fixed.",
89
+ )
87
90
  def test_convert_mobilenet_v2_signature_helper(self):
88
91
  """Tests the ai_edge_torch.signature helper function works."""
89
92
  torch_module = torchvision.models.mobilenet_v2().eval()
@@ -203,7 +203,7 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
203
203
  final_norm_config=norm_config,
204
204
  parallel_residual=False,
205
205
  lm_head_use_bias=False,
206
- enable_hlfb=False,
206
+ enable_hlfb=True,
207
207
  final_logit_softcap=30.0,
208
208
  )
209
209
  return config
@@ -242,7 +242,6 @@ def define_and_run_2b() -> None:
242
242
  out = model.forward(tokens, input_pos)
243
243
  out_final = out[0, 8, :]
244
244
  assert torch.allclose(gemma2_goldens, out_final, atol=1e-04)
245
- print(out)
246
245
 
247
246
 
248
247
  if __name__ == "__main__":
@@ -99,14 +99,16 @@ def scaled_dot_product_attention_with_hlfb(
99
99
  The output tensor of scaled_dot_product_attention.
100
100
  """
101
101
 
102
- if softcap is not None:
103
- raise NotImplementedError("SDPA with HLFB not available with softcap.")
104
-
105
102
  if scale is None:
106
103
  scale = 1.0 / math.sqrt(head_size)
107
104
 
105
+ attrs = {"scale": scale}
106
+
107
+ if softcap is not None:
108
+ attrs["logit_cap"] = softcap
109
+
108
110
  builder = StableHLOCompositeBuilder(
109
- name="odml.scaled_dot_product_attention", attr={"scale": scale}
111
+ name="odml.scaled_dot_product_attention", attr=attrs
110
112
  )
111
113
  q, k, v, mask = builder.mark_inputs(q, k, v, mask)
112
114
 
@@ -72,7 +72,7 @@ def load_pytorch_statedict(full_path: str):
72
72
  patterns = []
73
73
  if os.path.isdir(full_path):
74
74
  patterns.append(os.path.join(full_path, "*.bin"))
75
- patterns.append(os.path.join(full_path, "*.pt"))
75
+ patterns.append(os.path.join(full_path, "*pt"))
76
76
  else:
77
77
  patterns.append(full_path)
78
78
  for pattern in patterns:
@@ -149,6 +149,7 @@ class ModelLoader:
149
149
  enabled.
150
150
  """
151
151
  state = self._loader(self._file_name)
152
+ state = state["model_state_dict"] if "model_state_dict" in state else state
152
153
  converted_state = dict()
153
154
  if self._names.embedding is not None:
154
155
  converted_state["tok_embedding.weight"] = state.pop(
@@ -200,7 +201,7 @@ class ModelLoader:
200
201
  if glob.glob(os.path.join(self._file_name, "*.safetensors")):
201
202
  return load_safetensors
202
203
  if glob.glob(os.path.join(self._file_name, "*.bin")) or glob.glob(
203
- os.path.join(self._file_name, "*.pt")
204
+ os.path.join(self._file_name, "*pt")
204
205
  ):
205
206
  return load_pytorch_statedict
206
207
 
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20240817"
16
+ __version__ = "0.3.0.dev20240822"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20240817
3
+ Version: 0.3.0.dev20240822
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,116
2
2
  ai_edge_torch/config.py,sha256=PCd9PVrbUNeVIUDFUCnW4goDWU4bjouK28yMYU6VOi0,877
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=5DYNpFVwvI1w0JbAC1hn83NJVGS1WPX7n742419PMqs,4558
5
- ai_edge_torch/version.py,sha256=C4yfJq9TbtZBH5gwhPSUtBgiIe04GkxvCq5TImNopww,706
5
+ ai_edge_torch/version.py,sha256=DjujCBc63P1BCTPwFWV93QP5GPGWMkSuuTWTRIg5YNA,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=kcv_QgNgeyDmrqwdzHicGNP68w6zF7GJg7YkMEIXp4Q,3759
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -28,7 +28,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitio
28
28
  ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
29
29
  ai_edge_torch/_convert/test/test_convert.py,sha256=y0ZRivdglGx217rnacze8N6nd7aafk28NkbBFUSa9DQ,13121
30
30
  ai_edge_torch/_convert/test/test_convert_composites.py,sha256=CBiOqq-m7QT2ggBI1jBl9MkTIT5d0nK1tA0BUga0LGs,7994
31
- ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=4jm5blAfzLMjvrJt0ntuG_Fgy4Ie3SoUOGBOy9tf6bg,4725
31
+ ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=cqXJK1YALC1huw87HJVjrdc9xe0LaahRY8tVm_RsKg4,4817
32
32
  ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=jLAmyHw5llT2ff8qA8mem3eVN57e_o5EpBnW72ZtP2I,3026
33
33
  ai_edge_torch/debug/__init__.py,sha256=N05Mmvi41KgSuK0JhuMejERESgP8QekiGdp9_PEyuKU,742
34
34
  ai_edge_torch/debug/culprit.py,sha256=7UYVpVWpiCXbMAyThVtHt_kc_poT7sCTh5UUPvcycgk,14832
@@ -53,7 +53,7 @@ ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIX
53
53
  ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=pseJExH35lSAK0ZtzSHB1sFtRtF_EuT2xcSpGU0gKVI,2524
54
54
  ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=w589IJETATd6Z9_1XCIWbrlCV3E92X_5ac3VVCVFXG0,2522
55
55
  ai_edge_torch/generative/examples/gemma/gemma.py,sha256=cCki-0cKvmGxK4Md6dRNdPDWZUyhkJUI854OCTFf3h0,6262
56
- ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=j-zxJ-JNRnQ_kDzUESmsyy_a_4IxWZ510HmIImc0LDc,8240
56
+ ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=q9Zil66EvRKrSpLVQHxKHu_8NL0HAgY2FbtThoTZVUY,8226
57
57
  ai_edge_torch/generative/examples/phi2/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
58
58
  ai_edge_torch/generative/examples/phi2/convert_to_tflite.py,sha256=ON6zLO-nFS8eJ2yhyWzT5x2Somr-Ca-VjpjT7OGFU10,2506
59
59
  ai_edge_torch/generative/examples/phi2/phi2.py,sha256=C_kFYsPrEQ9GJCnc6h-jh8B5qQryvEpI6O6t4FBxg1I,5858
@@ -94,7 +94,7 @@ ai_edge_torch/generative/layers/kv_cache.py,sha256=Ob8QeXWW5xt-6hcGA0uoC48eRQ8lf
94
94
  ai_edge_torch/generative/layers/model_config.py,sha256=WpZ9djUBAZddyeSODHDaVMG37EQqfzGGrlMPi8AA-Hc,5752
95
95
  ai_edge_torch/generative/layers/normalization.py,sha256=u8lv0p-ktKcRqCDlOqZQa9WQcfDK9JM2IaUQFQdn7xs,1860
96
96
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
97
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=x2bOmrTgOISXcb06IDP7X3xgftpPpxOjBXw_OxTMVns,3874
97
+ ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=VW-VP8e7FTSPCdu-6DVxpwNrIdgX0R_kq6F6MSEiyXE,3848
98
98
  ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
99
99
  ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=4a0wh0co8Avz1wvxS3XqsgrgL5G-X1GSARI5Rj3L-xg,26995
100
100
  ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
@@ -114,7 +114,7 @@ ai_edge_torch/generative/test/test_loader.py,sha256=1ZqAq0HY5uIioumsReOVIsbGBx0W
114
114
  ai_edge_torch/generative/test/test_model_conversion.py,sha256=52ciFy_Qol2Xuym6P6EqdL29oai35LSWGvsUwyEdFTo,8477
115
115
  ai_edge_torch/generative/test/test_quantize.py,sha256=3SmJm7Kq98gAneU6IGwwJrJYCVH1qwWR6oUxPfb6qiI,5346
116
116
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
117
- ai_edge_torch/generative/utilities/loader.py,sha256=bAWZ7FM4v_pPnX_AmEdGxHkDH65QdL-MjIP3PxscZmI,12649
117
+ ai_edge_torch/generative/utilities/loader.py,sha256=QFZ2lkeoYQ9MZ1CAFVxBHG4OT192SH74UtJCvbDsdeI,12727
118
118
  ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=pKp3AMSbS3otCvgwJRF5M1l4JRNKk-aCKimXzIMSrds,35679
119
119
  ai_edge_torch/generative/utilities/t5_loader.py,sha256=_UXcc1QKT-S92hikfo-fTBFhnYLzROqcyRqKonVsqj4,16885
120
120
  ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
@@ -137,8 +137,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
137
137
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
138
138
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
139
139
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
140
- ai_edge_torch_nightly-0.3.0.dev20240817.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
141
- ai_edge_torch_nightly-0.3.0.dev20240817.dist-info/METADATA,sha256=GZgUf21m2RQYBvHxmeTujeniBxbbUTVpQQB9vjNSTaM,1885
142
- ai_edge_torch_nightly-0.3.0.dev20240817.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
143
- ai_edge_torch_nightly-0.3.0.dev20240817.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
144
- ai_edge_torch_nightly-0.3.0.dev20240817.dist-info/RECORD,,
140
+ ai_edge_torch_nightly-0.3.0.dev20240822.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
141
+ ai_edge_torch_nightly-0.3.0.dev20240822.dist-info/METADATA,sha256=U8torMOW4U2TfZG7j8mwQ3egNgda43VRAtleRwZmyKw,1885
142
+ ai_edge_torch_nightly-0.3.0.dev20240822.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
143
+ ai_edge_torch_nightly-0.3.0.dev20240822.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
144
+ ai_edge_torch_nightly-0.3.0.dev20240822.dist-info/RECORD,,