ai-edge-torch-nightly 0.3.0.dev20240909__py3-none-any.whl → 0.3.0.dev20240911__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (36)
  1. ai_edge_torch/_convert/test/test_convert.py +35 -13
  2. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
  3. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
  4. ai_edge_torch/generative/examples/gemma/gemma.py +34 -18
  5. ai_edge_torch/generative/examples/gemma/gemma2.py +38 -17
  6. ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
  7. ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +31 -33
  8. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +58 -25
  9. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
  10. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +38 -22
  11. ai_edge_torch/generative/layers/attention.py +60 -63
  12. ai_edge_torch/generative/layers/kv_cache.py +160 -51
  13. ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +8 -22
  14. ai_edge_torch/generative/test/test_model_conversion.py +71 -33
  15. ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
  16. ai_edge_torch/generative/test/utils.py +54 -0
  17. ai_edge_torch/odml_torch/lowerings/_convolution.py +196 -74
  18. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
  19. ai_edge_torch/version.py +1 -1
  20. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/METADATA +1 -1
  21. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/RECORD +25 -35
  22. ai_edge_torch/generative/examples/experimental/gemma/__init__.py +0 -14
  23. ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +0 -88
  24. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
  25. ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
  26. ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
  27. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
  28. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
  29. ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
  30. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
  31. ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
  32. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
  33. /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
  34. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/LICENSE +0 -0
  35. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/WHEEL +0 -0
  36. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/top_level.txt +0 -0
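The headline change in this release is the promotion of the experimental external-KV-cache stack to the stable API: the experimental gemma, phi, and tiny_llama examples move out of generative/examples/experimental, phi2 is renamed to phi, generative/layers/kv_cache.py roughly doubles in size as it absorbs the former ekv_cache module, and test_experimental_ekv.py becomes test_kv_cache.py. A minimal sketch of the migrated import surface follows — the class and method names (KVCache, from_model_config) are read off the new file list and should be treated as assumptions, not verified documentation:

    # Sketch only: names assumed from the 0.3.0.dev20240911 file layout.
    from ai_edge_torch.generative.examples.gemma import gemma
    from ai_edge_torch.generative.layers import kv_cache as kv_utils

    model = gemma.build_2b_model('/path/to/gemma-2b', kv_cache_max_len=1024)
    kv = kv_utils.KVCache.from_model_config(model.config)  # replaces EKVCache.from_model_config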
{ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/RECORD
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,116
  ai_edge_torch/config.py,sha256=PCd9PVrbUNeVIUDFUCnW4goDWU4bjouK28yMYU6VOi0,877
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
  ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
- ai_edge_torch/version.py,sha256=r0y6crIySNGhJqtljkzyHxb1XMvLji2VLajLfUjW8b4,706
+ ai_edge_torch/version.py,sha256=vCTKdj1Lc6r2UbJhIZpLdXauJSS0KfBLzgy9e3D16AA,706
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/_convert/conversion.py,sha256=kcv_QgNgeyDmrqwdzHicGNP68w6zF7GJg7YkMEIXp4Q,3759
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -26,7 +26,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitio
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=L_x8BrF7UDah-SYl-pG11I6CIckdU9kBTUHcmwW4cts,2420
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=mzfL9cf0qBnpmxM_OlMQFvQsEZV2B_Mia9yEJV4J7rI,7135
  ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/_convert/test/test_convert.py,sha256=pUYSXuqFg8CAeJ8JkoYf7S0RDLRPVuZUwVOd0xObM6w,14411
+ ai_edge_torch/_convert/test/test_convert.py,sha256=FSufFZEeTLBpUnzE1Iy-LvNN0mhDynWMNg7Mei8RpLQ,14973
  ai_edge_torch/_convert/test/test_convert_composites.py,sha256=BCIODgxMI_3MxMLfNWYMGjcz-al-J3z5eDHCiZJXNwY,7992
  ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=6_C2R9--KyNR7_oezZIAfyTSR97tOeEWy4XGcbSxBDE,5778
  ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=1o-gUiwzIuO67FNAJ8DeyKv8fVUeZVNNNwofNVDjYeU,3024
@@ -39,24 +39,14 @@ ai_edge_torch/debug/test/test_search_model.py,sha256=-RuU0QsjqkfzZF2IbeA55MoeVOa
  ai_edge_torch/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/generative/examples/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/generative/examples/experimental/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py,sha256=lpiPFSh3SJd6WwuZ0QegSva3__iSz2tUD7L7QfkAe4I,3085
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py,sha256=aCoD86pf4nuquUMk7MOR-jsN5FqvySSEuMx9Psxjblk,7261
- ai_edge_torch/generative/examples/experimental/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py,sha256=DavrdGmqUgoThsGNRv3LXMW5tvJdYEvj66Hf1XRqkXU,3055
- ai_edge_torch/generative/examples/experimental/phi/phi2.py,sha256=Jxf3ZyYDpS78l6uh4_LGGIcHawrOhZ1vHoHFVxRaK40,6789
- ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py,sha256=xPVvHQjLJHFiRv_-Fy2sDm0Aft7SG8SXiV6o3rF03cQ,3108
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py,sha256=nUm0SQbCTmNAc5u-C9gbQRFPt7GDvUt6UjH6doTvH-I,6817
  ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=pseJExH35lSAK0ZtzSHB1sFtRtF_EuT2xcSpGU0gKVI,2524
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=w589IJETATd6Z9_1XCIWbrlCV3E92X_5ac3VVCVFXG0,2522
- ai_edge_torch/generative/examples/gemma/gemma.py,sha256=lc1-CfIObHj9D5VJy78BOtGTrQM4TYMI6NfVi8KM5qA,6747
- ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=OcUQLFR136e3QRVXRnmtYnRHXyHJS9EYEFlJ1ymXyRY,8859
- ai_edge_torch/generative/examples/phi2/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py,sha256=ON6zLO-nFS8eJ2yhyWzT5x2Somr-Ca-VjpjT7OGFU10,2506
- ai_edge_torch/generative/examples/phi2/phi2.py,sha256=FFnhv1kx4fHRhSeOreLGj8kAqPnmkz9pD1RRSDVlM_w,6332
+ ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=ZJvw8uFVu7FEJ7eXfpzn-pPKgPELoxkGz4Zg7LKKMSI,3048
+ ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=hM-fwjZG53p1UE_lkovLMmHRDHleJsb6_0ib0_k0v54,3040
+ ai_edge_torch/generative/examples/gemma/gemma.py,sha256=oVV1lXgi9cMPES6JmiV8fJOgBQruRdHpyJL7MmXU09M,7283
+ ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=X6WfUCDJDEqyyEAYGq1lmKtlDXcYLzy-n2moQPLJA_U,9769
+ ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
+ ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=vqEpZVmB0_wMKcAl6RXm7W57DqPTzEdVVN6W2Z-QYzI,3011
+ ai_edge_torch/generative/examples/phi/phi2.py,sha256=BzvUrClFx5HKf6PYzJc7ba2O3AwYUJE485u5GSOiPy4,6851
  ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
  ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=0WniBWQ6_NcQc5WycX3YRRX7Os9AGQSxfc1m2HKBqg8,4479
@@ -78,19 +68,18 @@ ai_edge_torch/generative/examples/t5/t5.py,sha256=Zobw5BV-PC0nlU9Z6fzb2O07rMeU8v
  ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=1lvbSlzyBwmd5Bs7-Up_v4iJQkCPIJx2RmMkLgy7l2Q,8508
  ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=5wj2RmQRIwD6O_R_pp-A_7gKGSdHWDSXyis97r1ELVI,5622
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=l9swUKTcDtnTibNSNExaMgLvDeJ4Er2tVh5ZW1EtRgk,5809
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=mQkcpSe6HlRLMkIRCEHc9ZXL7jxEp9RWSGUQjjd-r2w,4841
+ ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=PbWpfg3AOEZjI1FlnZCxRD-kIKtdkR9AOZ6l-9-TpRA,5664
  ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=CLRqO7ycMbpy7J3_Czp1sLx6hcdwGD9zVq04yRba0e8,2550
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=4ku0ni3MOWamhPrzLap0BmtdNFk7CH0hwjPNoRAKpvQ,6278
+ ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=y4LiWhwgflqrg4WWh3wq5ei3VOT_cV0A62x62qptQiM,3070
+ ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=RK7oisSwIPqUWwwE1P-hDJlEnRJJ_V29UjUCxt4xETE,6780
  ai_edge_torch/generative/fx_passes/__init__.py,sha256=fmNNXawJ722M4cTUuTx289rT0NHxBEsOy_k8baqCOms,1173
  ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=sXis0U4u-RoIp_NyrmWJNnqFqpqRuZOrhfsJIO6rMps,2028
  ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/generative/layers/attention.py,sha256=2UujQePRJ1LK02PN-hGcuMu0ooCJC6ETfPvzEYVFyho,12284
+ ai_edge_torch/generative/layers/attention.py,sha256=ee0KHRakhjLjawP32FY2EntxOkyPvjiEZChLnBn_HPc,12601
  ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
  ai_edge_torch/generative/layers/builder.py,sha256=xb7rjADv3Jm4qfmlYtg6oLLe7ReDE9UjsEqiejPpDD8,4346
  ai_edge_torch/generative/layers/feed_forward.py,sha256=uto7xtwx6jPkk1GZ2x7pSTentQzRrPSKw4_PSE12ahA,3525
- ai_edge_torch/generative/layers/kv_cache.py,sha256=Ob8QeXWW5xt-6hcGA0uoC48eRQ8lfvKca8JbWtFx2CE,3082
+ ai_edge_torch/generative/layers/kv_cache.py,sha256=WDu03NQwkDCrrrT9Du_3ZOxlURZz3XDbS1PLzFozhMI,6013
  ai_edge_torch/generative/layers/model_config.py,sha256=WpZ9djUBAZddyeSODHDaVMG37EQqfzGGrlMPi8AA-Hc,5752
  ai_edge_torch/generative/layers/normalization.py,sha256=u8lv0p-ktKcRqCDlOqZQa9WQcfDK9JM2IaUQFQdn7xs,1860
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
@@ -107,11 +96,12 @@ ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=4fgmP_GgeiFUOkIaC
  ai_edge_torch/generative/quantize/quant_recipes.py,sha256=0Kvr_o7pbMnE8VMe6Ml0FBxkHM6RJ3C14B2I1mjItjc,2030
  ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
  ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/generative/test/test_experimental_ekv.py,sha256=8qv_eVtJW9GPvBEf2hPQe3tpdJ33XShya6MCX1FqrZM,4355
+ ai_edge_torch/generative/test/test_kv_cache.py,sha256=FU2rmU03Lp-vZ5wWXXCao1WEw7xbpqebFMANL_O2chA,3713
  ai_edge_torch/generative/test/test_loader.py,sha256=_y5EHGgoNOmCuYonsB81UJScHVsTAQXUVd44czMAw6k,3379
- ai_edge_torch/generative/test/test_model_conversion.py,sha256=b3InJ8Rx03YtHpE9h-j0pSXAY1cCf-dLlx4Y5LSJnRQ,5174
- ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=9JXcd-rX8MpsYeEWUFEXf783GOwYOLY64KzDfFdmRJ8,4484
+ ai_edge_torch/generative/test/test_model_conversion.py,sha256=OmAHSGkxTNzDX5_kYjK7pxlPk0YZLqL9YiVIJQfuvPc,5889
+ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=F3q3K9ZgWBzlLy4WpE8-w6UWSuJ-UoJwMm3N6Zb3Y14,5016
  ai_edge_torch/generative/test/test_quantize.py,sha256=kY_NRpF-v1i4clqI1CFFWEagJv-5PzBDkeJ2fInl9_w,5913
+ ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
  ai_edge_torch/generative/utilities/loader.py,sha256=6J0aAP6-6LySeqeYIHKcchr5T9cVtSO34aoDr3V9gxY,12726
  ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=pKp3AMSbS3otCvgwJRF5M1l4JRNKk-aCKimXzIMSrds,35679
@@ -148,8 +138,8 @@ ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7i
  ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=GqYk6oBJw7KWeG4_6gxSu_OvYhjJcC2FpGzWPPEdH6w,933
  ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=wV8AUK8dvjLUy3qjqw_IxpiYVDWUMPNZRfi3XYE_hDs,6972
  ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
- ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=B6BILeu-UlwGB1O6g7111X1TaIFznsfxXrB72ygBsBA,3885
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=I0Y4IK7Zap8m6xfxMw7DfQ9Mg4htKOoypdHVAMHqx9c,10669
+ ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=v1VdKmL8YLJv3PR9VgyNghO83A25PpTzY2ZUAJqlq3Q,6847
+ ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=s-cT_tIQHu7w5hXl8MCixRxLlHplpXW-UCzHT9TY--o,10621
  ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
  ai_edge_torch/odml_torch/lowerings/registry.py,sha256=ES3x_RJ22T5rlmMrlomex2DdcZbhlyVJ7_HS3rjz3Uk,2851
  ai_edge_torch/odml_torch/lowerings/utils.py,sha256=NczqpsSd3Fn7yVcPC3qllemiZxxDAZgcW1T5l8-W9fE,5593
@@ -161,8 +151,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
- ai_edge_torch_nightly-0.3.0.dev20240909.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
- ai_edge_torch_nightly-0.3.0.dev20240909.dist-info/METADATA,sha256=s7SAIUvFciy8peNKMHvyhoNQWYx67Jerz4foeV7KiE0,1859
- ai_edge_torch_nightly-0.3.0.dev20240909.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
- ai_edge_torch_nightly-0.3.0.dev20240909.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
- ai_edge_torch_nightly-0.3.0.dev20240909.dist-info/RECORD,,
+ ai_edge_torch_nightly-0.3.0.dev20240911.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ ai_edge_torch_nightly-0.3.0.dev20240911.dist-info/METADATA,sha256=caHeAQX6pxEskue_BvgwkTfZEsG55rXHFwPDcV9oCN8,1859
+ ai_edge_torch_nightly-0.3.0.dev20240911.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ ai_edge_torch_nightly-0.3.0.dev20240911.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ ai_edge_torch_nightly-0.3.0.dev20240911.dist-info/RECORD,,
ai_edge_torch/generative/examples/experimental/gemma/__init__.py
@@ -1,14 +0,0 @@
- # Copyright 2024 The AI Edge Torch Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py
@@ -1,88 +0,0 @@
- # Copyright 2024 The AI Edge Torch Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- #
- # Note: This is an experimental version of Gemma with external KV cache.
- # Please use with caution.
-
-
- import os
- from pathlib import Path
-
- import ai_edge_torch
- from ai_edge_torch.generative.examples.experimental.gemma import gemma
- from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
- from ai_edge_torch.generative.quantize import quant_recipes
- import torch
-
-
- def convert_gemma_to_tflite(
-     checkpoint_path: str,
-     prefill_seq_len: int = 512,
-     kv_cache_max_len: int = 1024,
-     quantize: bool = True,
- ):
-   """An example of converting a Gemma 2B model to a multi-signature
-   tflite model.
-
-   Args:
-     checkpoint_path (str): The filepath to the model checkpoint, or directory
-       holding the checkpoint.
-     prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-       Defaults to 512.
-     kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
-       including both prefill and decode. Defaults to 1024.
-     quantize (bool, optional): Whether the model should be quantized. Defaults
-       to True.
-   """
-   pytorch_model = gemma.build_2b_model(
-       checkpoint_path, kv_cache_max_len=kv_cache_max_len
-   )
-   # Tensors used to trace the model graph during conversion.
-   prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
-   prefill_input_pos = torch.arange(0, prefill_seq_len)
-   decode_token = torch.tensor([[0]], dtype=torch.long)
-   decode_input_pos = torch.tensor([0], dtype=torch.int64)
-   kv = kv_utils.EKVCache.from_model_config(pytorch_model.config)
-
-   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-   edge_model = (
-       ai_edge_torch.signature(
-           'prefill',
-           pytorch_model,
-           sample_kwargs={
-               'tokens': prefill_tokens,
-               'input_pos': prefill_input_pos,
-               'kv_cache': kv,
-           },
-       )
-       .signature(
-           'decode',
-           pytorch_model,
-           sample_kwargs={
-               'tokens': decode_token,
-               'input_pos': decode_input_pos,
-               'kv_cache': kv,
-           },
-       )
-       .convert(quant_config=quant_config)
-   )
-   edge_model.export(
-       f'/tmp/gemma_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
-   )
-
-
- if __name__ == '__main__':
-   checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/gemma-2b')
-   convert_gemma_to_tflite(checkpoint_path)
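Under the new layout, the equivalent converter presumably differs from the deleted script above only in its import paths and cache type; the deleted TinyLlama converter at the end of this diff follows the same pattern. A hedged sketch of just the prefill signature under those assumptions (verify names against ai_edge_torch/generative/examples/gemma/convert_to_tflite.py in the new wheel):

    # Assumed migration of the deleted script above, not the verified new source.
    import ai_edge_torch
    import torch
    from ai_edge_torch.generative.examples.gemma import gemma
    from ai_edge_torch.generative.layers import kv_cache as kv_utils
    from ai_edge_torch.generative.quantize import quant_recipes

    pytorch_model = gemma.build_2b_model('/path/to/gemma-2b', kv_cache_max_len=1024)
    kv = kv_utils.KVCache.from_model_config(pytorch_model.config)  # was EKVCache
    edge_model = ai_edge_torch.signature(
        'prefill',
        pytorch_model,
        sample_kwargs={
            'tokens': torch.full((1, 512), 0, dtype=torch.long),
            'input_pos': torch.arange(0, 512),
            'kv_cache': kv,
        },
    ).convert(quant_config=quant_recipes.full_int8_dynamic_recipe())
    edge_model.export('/tmp/gemma_seq512_ekv1024.tflite')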
ai_edge_torch/generative/examples/experimental/gemma/gemma.py
@@ -1,219 +0,0 @@
- # Copyright 2024 The AI Edge Torch Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- # Example of building a Gemma model.
- #
- # Note: This is an experimental version of Gemma with external KV cache.
- # Please use with caution.
-
- import os
- from pathlib import Path
- from typing import Tuple
-
- from ai_edge_torch.generative.layers import builder
- import ai_edge_torch.generative.layers.attention_utils as attn_utils
- from ai_edge_torch.generative.layers.experimental import attention
- from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
- import ai_edge_torch.generative.layers.model_config as cfg
- import ai_edge_torch.generative.utilities.loader as loading_utils
- import numpy as np
- import torch
- from torch import nn
-
-
- TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
-     ff_up_proj="model.layers.{}.mlp.up_proj",
-     ff_down_proj="model.layers.{}.mlp.down_proj",
-     ff_gate_proj="model.layers.{}.mlp.gate_proj",
-     attn_query_proj="model.layers.{}.self_attn.q_proj",
-     attn_key_proj="model.layers.{}.self_attn.k_proj",
-     attn_value_proj="model.layers.{}.self_attn.v_proj",
-     attn_output_proj="model.layers.{}.self_attn.o_proj",
-     pre_attn_norm="model.layers.{}.input_layernorm",
-     post_attn_norm="model.layers.{}.post_attention_layernorm",
-     embedding="model.embed_tokens",
-     final_norm="model.norm",
-     lm_head=None,
- )
-
-
- class Gemma(nn.Module):
-   """A Gemma model built from the Edge Generative API layers."""
-
-   def __init__(self, config: cfg.ModelConfig):
-     super().__init__()
-
-     self.config = config
-     # Construct model layers.
-     self.tok_embedding = nn.Embedding(
-         config.vocab_size, config.embedding_dim, padding_idx=0
-     )
-     self.lm_head = nn.Linear(
-         config.embedding_dim,
-         config.vocab_size,
-         bias=config.lm_head_use_bias,
-     )
-     # Gemma re-uses the embedding as the head projection layer.
-     self.lm_head.weight.data = self.tok_embedding.weight.data
-     self.transformer_blocks = nn.ModuleList(
-         attention.TransformerBlock(config) for _ in range(config.num_layers)
-     )
-     self.final_norm = builder.build_norm(
-         config.embedding_dim,
-         config.final_norm_config,
-     )
-     self.rope_cache = attn_utils.build_rope_cache(
-         size=config.kv_cache_max,
-         dim=int(
-             config.attn_config.rotary_percentage * config.attn_config.head_dim
-         ),
-         base=10_000,
-         condense_ratio=1,
-         dtype=torch.float32,
-         device=torch.device("cpu"),
-     )
-     self.mask_cache = attn_utils.build_causal_mask_cache(
-         size=config.kv_cache_max,
-         dtype=torch.float32,
-         device=torch.device("cpu"),
-     )
-     self.config = config
-
-   @torch.inference_mode
-   def forward(
-       self,
-       tokens: torch.Tensor,
-       input_pos: torch.Tensor,
-       kv_cache: kv_utils.EKVCache,
-   ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
-     _, seq_len = tokens.size()
-     assert self.config.max_seq_len >= seq_len, (
-         f"Cannot forward sequence of length {seq_len}, max seq length is only"
-         f" {self.config.max_seq_len}"
-     )
-
-     cos, sin = self.rope_cache
-     cos = cos.index_select(0, input_pos)
-     sin = sin.index_select(0, input_pos)
-     mask = self.mask_cache.index_select(2, input_pos)
-     mask = mask[:, :, :, : self.config.kv_cache_max]
-
-     # token embeddings of shape (b, t, n_embd)
-     x = self.tok_embedding(tokens)
-     x = x * (self.config.embedding_dim**0.5)
-
-     updated_kv_entries = []
-     for i, block in enumerate(self.transformer_blocks):
-       kv_entry = kv_cache.caches[i] if kv_cache else None
-       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
-       if kv_entry:
-         updated_kv_entries.append(kv_entry)
-     updated_kv_cache = kv_utils.EKVCache(tuple(updated_kv_entries))
-
-     x = self.final_norm(x)
-     res = self.lm_head(x)  # (b, t, vocab_size)
-     return res, updated_kv_cache
-
-
- def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-   """Returns the model config for a Gemma 2B model.
-
-   Args:
-     kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-       is 1024.
-
-   Returns:
-     The model config for a Gemma 2B model.
-   """
-   attn_config = cfg.AttentionConfig(
-       num_heads=8,
-       head_dim=256,
-       num_query_groups=1,
-       rotary_percentage=1.0,
-   )
-   ff_config = cfg.FeedForwardConfig(
-       type=cfg.FeedForwardType.GATED,
-       activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
-       intermediate_size=16384,
-   )
-   norm_config = cfg.NormalizationConfig(
-       type=cfg.NormalizationType.RMS_NORM,
-       epsilon=1e-6,
-       zero_centered=True,
-   )
-   config = cfg.ModelConfig(
-       vocab_size=256000,
-       num_layers=18,
-       max_seq_len=8192,
-       embedding_dim=2048,
-       kv_cache_max_len=kv_cache_max_len,
-       attn_config=attn_config,
-       ff_config=ff_config,
-       pre_attention_norm_config=norm_config,
-       post_attention_norm_config=norm_config,
-       final_norm_config=norm_config,
-       parallel_residual=False,
-       lm_head_use_bias=False,
-       enable_hlfb=True,
-   )
-   return config
-
-
- def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
-   config = get_model_config_2b(kv_cache_max_len)
-   config.ff_config.intermediate_size = 128
-   config.vocab_size = 128
-   config.num_layers = 2
-   config.max_seq_len = 2 * kv_cache_max_len
-   return config
-
-
- def build_2b_model(
-     checkpoint_path: str, test_model: bool = False, **kwargs
- ) -> nn.Module:
-   """Instantiates the model instance and loads the checkpoint if provided."""
-   config = (
-       get_fake_model_config(**kwargs)
-       if test_model
-       else get_model_config_2b(**kwargs)
-   )
-   model = Gemma(config)
-   if checkpoint_path is not None:
-     loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-     # Since embedding and lm-head use the same weight, we need to set strict
-     # to False.
-     loader.load(model, strict=False)
-   model.eval()
-   return model
-
-
- def define_and_run_2b(checkpoint_path: str, test_model: bool = False) -> None:
-   """Instantiates and runs a Gemma 2B model."""
-
-   kv_cache_max_len = 1024
-   model = build_2b_model(
-       checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
-   )
-   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
-   tokens[0, :4] = idx
-   input_pos = torch.arange(0, kv_cache_max_len)
-   kv = kv_utils.EKVCache.from_model_config(model.config)
-   print("running an inference")
-   print(model.forward(tokens, input_pos, kv))
-
-
- if __name__ == "__main__":
-   input_checkpoint_path = os.path.join(Path.home(), "Downloads/gemma-2b")
-   define_and_run_2b(input_checkpoint_path)
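The forward pass above shows the idiom that the promoted kv_cache.py formalizes: the cache enters as an argument, each block returns a fresh cache entry, and the caller rebuilds and returns the whole cache instead of mutating buffers in place. Keeping the update a pure tensor-to-tensor operation is what allows the cache to be lifted to a model input/output in the exported TFLite signatures. A self-contained sketch of that idiom with illustrative names only (not the library's API):

    import torch

    def kv_update(k_cache, v_cache, input_pos, k, v):
      # The out-of-place index_copy returns new tensors with k/v written at
      # input_pos along the sequence axis, so the update stays functional.
      return (
          k_cache.index_copy(1, input_pos, k),
          v_cache.index_copy(1, input_pos, v),
      )

    # Decode step: write one token's K/V at position 4 of a 1024-slot cache.
    k_cache = torch.zeros(1, 1024, 1, 256)  # (batch, kv_len, kv_heads, head_dim)
    v_cache = torch.zeros(1, 1024, 1, 256)
    pos = torch.tensor([4])
    k_new = torch.randn(1, 1, 1, 256)
    v_new = torch.randn(1, 1, 1, 256)
    k_cache, v_cache = kv_update(k_cache, v_cache, pos, k_new, v_new)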
ai_edge_torch/generative/examples/experimental/phi/__init__.py
@@ -1,14 +0,0 @@
- # Copyright 2024 The AI Edge Torch Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py
@@ -1,14 +0,0 @@
- # Copyright 2024 The AI Edge Torch Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py
@@ -1,87 +0,0 @@
- # Copyright 2024 The AI Edge Torch Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- #
- # Note: This is an experimental version of TinyLlama with external KV cache.
- # Please use with caution.
-
-
- import os
- from pathlib import Path
-
- import ai_edge_torch
- from ai_edge_torch.generative.examples.experimental.tiny_llama import tiny_llama
- from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
- from ai_edge_torch.generative.quantize import quant_recipes
- import torch
-
-
- def convert_tiny_llama_to_tflite(
-     checkpoint_path: str,
-     prefill_seq_len: int = 512,
-     kv_cache_max_len: int = 1024,
-     quantize: bool = True,
- ):
-   """An example of converting a TinyLlama model to a multi-signature tflite model.
-
-   Args:
-     checkpoint_path (str): The filepath to the model checkpoint, or directory
-       holding the checkpoint.
-     prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-       Defaults to 512.
-     kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
-       including both prefill and decode. Defaults to 1024.
-     quantize (bool, optional): Whether the model should be quantized. Defaults
-       to True.
-   """
-   pytorch_model = tiny_llama.build_model(
-       checkpoint_path, kv_cache_max_len=kv_cache_max_len
-   )
-   # Tensors used to trace the model graph during conversion.
-   prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
-   prefill_input_pos = torch.arange(0, prefill_seq_len)
-   decode_token = torch.tensor([[0]], dtype=torch.long)
-   decode_input_pos = torch.tensor([0], dtype=torch.int64)
-   kv = kv_utils.EKVCache.from_model_config(pytorch_model.config)
-
-   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-   edge_model = (
-       ai_edge_torch.signature(
-           'prefill',
-           pytorch_model,
-           sample_kwargs={
-               'tokens': prefill_tokens,
-               'input_pos': prefill_input_pos,
-               'kv_cache': kv,
-           },
-       )
-       .signature(
-           'decode',
-           pytorch_model,
-           sample_kwargs={
-               'tokens': decode_token,
-               'input_pos': decode_input_pos,
-               'kv_cache': kv,
-           },
-       )
-       .convert(quant_config=quant_config)
-   )
-   edge_model.export(
-       f'/tmp/tiny_llama_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
-   )
-
-
- if __name__ == '__main__':
-   checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/tiny_llama')
-   convert_tiny_llama_to_tflite(checkpoint_path)